#!/usr/bin/env python

"""
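Compute the theta and phi angles of the Holliday junctions in a trajectory and write them out as a time series, together with averages, standard deviations and a phi-theta histogram. This version projects the arm vectors onto the plane perpendicular to the junction normal when computing phi (30/06/15).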
"""

import readers
import base
import numpy as np
import origami_utils as oru
import sys
import math

# Base-pair offset from the junction used to define each arm: get_arm_vec() is
# called with a second virtual base at vb +/- armveclen, and the arm direction
# is taken between the base pair at the junction and the one at that offset.
armveclen=3 # for each arm, offset (in bp) of the second base pair defining that arm's direction
# Isomer-check log: check_isomer() writes "<conf> OK" for every junction check
# that passes. Opened (and truncated) when the module is loaded; never
# explicitly closed in this file.
err_file = open("isomers.inf", "w")

def check_isomer(hj, origami, conf_count):
    """Test whether a Holliday junction has left its X-stacked isomer.

    The four junction base pairs are a=(vh1,vb1), b=(vh1,vb2), c=(vh2,vb1)
    and d=(vh2,vb2), where (vh1, vb1, vh2, vb2) are the first four entries of
    hj. Pairs on the same virtual helix (a-b, c-d) should be coaxially stacked
    and therefore closer together than pairs on different helices (a-c, b-d).
    Distances are between base-pair midpoints, measured with oru.min_distance
    in the simulation box.

    Returns True (after logging a warning) when the larger same-helix distance
    exceeds the smaller cross-helix distance; otherwise writes
    "<conf_count> OK" to the module-level err_file and returns False.
    """
    box = origami._sys._box
    nucleotides = origami._sys._nucleotides
    vh1, vb1, vh2, vb2 = hj[0], hj[1], hj[2], hj[3]

    # base-pair midpoints in the order a, b, c, d
    midpoints = []
    for vh, vb in ((vh1, vb1), (vh1, vb2), (vh2, vb1), (vh2, vb2)):
        first_id = origami.get_nucleotides(vh, vb)[0]
        partner_id = origami.complementary_list[first_id]
        midpoints.append(oru.get_pos_midpoint(nucleotides[first_id].get_pos_base(),
                                              nucleotides[partner_id].get_pos_base(),
                                              box))
    a, b, c, d = midpoints

    def separation(p, q):
        delta = oru.min_distance(p, q, box)
        return np.sqrt(np.dot(delta, delta))

    ds1 = separation(b, a)  # same helix vh1
    ds2 = separation(d, c)  # same helix vh2
    dl1 = separation(c, a)  # across helices at vb1
    dl2 = separation(d, b)  # across helices at vb2

    if max(ds1, ds2) > min(dl1, dl2):
        base.Logger.log("inter base-pair distance indicates isomer change. Have short distances %f, %f and long distances %f, %f" % (ds1, ds2, dl1, dl2), base.Logger.WARNING)
        # additionally flag when even the second-shortest distance is below 0.8
        if sorted((ds1, ds2, dl1, dl2))[1] < 0.8:
            base.Logger.log("2nd shortest is shorter than critical distance", base.Logger.WARNING)
        return True

    err_file.write("%d OK\n" % conf_count)
    return False

def get_arm_vec(vhc, vbc, vbn, origami):
    """Return the unit direction vector of one Holliday-junction arm.

    vhc is the arm's virtual helix, vbc the virtual base at the junction and
    vbn a second virtual base along the arm (main() passes vbc +/- armveclen).
    The arm vector is oru.min_distance between the midpoints of the base pairs
    at vbc and at vbn, normalised with oru.norm.
    """
    system = origami._sys
    box = system._box

    def bp_midpoint(vb):
        # first nucleotide at (vhc, vb) and its complement (presumably the
        # scaffold/staple pair -- confirm against Origami.get_nucleotides)
        nuc_id = origami.get_nucleotides(vhc, vb)[0]
        comp_id = origami.complementary_list[nuc_id]
        # Use the box-aware midpoint helper, as check_isomer and main do.
        # The previous raw sum of the two positions was twice the midpoint and
        # ignored the box; feeding doubled positions to oru.min_distance also
        # applies its box correction to a vector twice the real separation.
        return oru.get_pos_midpoint(system._nucleotides[nuc_id].get_pos_base(),
                                    system._nucleotides[comp_id].get_pos_base(),
                                    box)

    return oru.norm(oru.min_distance(bp_midpoint(vbc), bp_midpoint(vbn), box))

def _bin_index(value, lo, hi, nbins):
    """Return the index of the bin holding `value` among `nbins` equal bins on [lo, hi].

    The upper edge `hi` is counted in the last bin (e.g. theta = 90 or
    phi = 360), so every value in the closed interval gets a bin. Returns
    None for values outside [lo, hi], including NaN.
    """
    if not lo <= value <= hi:
        return None
    width = (hi - lo) / float(nbins)
    return min(int(math.floor((value - lo) / width)), nbins - 1)


def _hj_angles(hj, origami, s):
    """Compute the inter-arm and out-of-plane angles of one Holliday junction.

    hj holds (vh1, vb1, vh2, vb2) in its first four entries; s is the current
    system, whose nucleotides give the base positions defining the up->down
    direction.

    Returns (phi1, phi2, thetaA, thetaB, thetaC, thetaD) in degrees. phi1 is
    the angle between the in-plane projections of arms A and B, phi2 between
    those of C and D, each mapped into [0, 360] using n_updown for the sense.
    Each theta is the signed angle between an arm and the junction plane.
    """
    box = origami._sys._box
    vh1 = hj[0]
    vb1 = hj[1]
    vh2 = hj[2]
    vb2 = hj[3]

    # arm vectors - in cadnano representation:
    # A====>C
    #    ||
    # D<====B
    vA = oru.norm(get_arm_vec(vh1, vb1, vb1 - armveclen, origami))
    vB = oru.norm(get_arm_vec(vh2, vb2, vb2 + armveclen, origami))
    vC = oru.norm(get_arm_vec(vh1, vb2, vb2 + armveclen, origami))
    vD = oru.norm(get_arm_vec(vh2, vb1, vb1 - armveclen, origami))

    def pos_base(vh, vb):
        return s._nucleotides[origami.get_nucleotides(vh, vb)[0]].get_pos_base()

    # from top dhd to bottom: distance(midpoint_up, midpoint_down)
    mid_up = oru.get_pos_midpoint(pos_base(vh1, vb1), pos_base(vh1, vb2), box)
    mid_down = oru.get_pos_midpoint(pos_base(vh2, vb1), pos_base(vh2, vb2), box)
    n_updown = oru.norm(oru.min_distance(mid_up, mid_down, box))

    # plane normal
    n = oru.norm(np.cross(vA, vB) + np.cross(vB, vC) + np.cross(vC, vD) + np.cross(vD, vA))

    # Project the arms into the plane. Projection shortens each vector by the
    # cosine of its out-of-plane angle, so renormalise before taking arccos of
    # the dot product; otherwise phi is biased towards 90 degrees.
    vAp, vBp, vCp, vDp = [oru.norm(v - np.dot(n, v) * n) for v in (vA, vB, vC, vD)]

    # np.clip keeps rounding error from pushing |dot| past 1 (-> NaN)
    phi1 = np.arccos(np.clip(np.dot(vAp, vBp), -1., 1.)) * 180 / np.pi
    phi2 = np.arccos(np.clip(np.dot(vCp, vDp), -1., 1.)) * 180 / np.pi

    # incorporate sense of phi angle 20/09/12
    if np.dot(np.cross(vAp, vBp), n_updown) < 0:
        phi1 = 360. - phi1
    if np.dot(np.cross(vCp, vDp), n_updown) < 0:
        phi2 = 360. - phi2

    # incorporate sense of theta angle 27/09/12: A and C are negated when they
    # point along n_updown, B and D when they point against it
    thetas = []
    for v, sense in ((vA, 1), (vB, -1), (vC, 1), (vD, -1)):
        theta = abs(np.arcsin(np.clip(np.dot(n, v), -1., 1.))) * 180 / np.pi
        if sense * np.dot(v, n_updown) > 0:
            theta *= -1
        thetas.append(theta)

    return (phi1, phi2) + tuple(thetas)


def main():
    """Analyse Holliday-junction angles over every configuration of a trajectory.

    Usage: <script> configuration topology [origami]

    For every configuration read by readers.LorenzoReader and every junction
    from Origami.get_holliday_junctions(), this computes phi1, phi2 and the
    four arm thetas (see _hj_angles). It also runs check_isomer when the
    system has exactly four strands.

    Output:
      * with "origami": hj_angles_detailed.dat and hj_angles.dat (per-junction
        means / standard deviations) plus averages over all junctions on stdout;
      * otherwise: means and standard deviations of the first junction on stdout;
      * always, for the first junction only: phi_theta_series.dat (time series
        of mean phi and mean of thetaA/thetaB) and phi_theta_bin.dat (2D
        phi-theta histogram whose bins sum to 1).
    """
    if len(sys.argv) < 3:
        base.Logger.log("Usage is %s configuration topology [origami]" % sys.argv[0], base.Logger.CRITICAL)
        sys.exit()

    conffile = sys.argv[1]
    topfile = sys.argv[2]
    isorigami = len(sys.argv) > 3 and sys.argv[3] == "origami"

    r = readers.LorenzoReader(conffile, topfile)
    s = r.get_system()

    origami = oru.Origami(s, "virt2nuc")
    hjs = origami.get_holliday_junctions()
    if len(hjs) == 0:
        base.Logger.log("no Holliday junctions found, nothing to analyse", base.Logger.CRITICAL)
        sys.exit(1)

    # angle_rows[ii] collects one (phi1, phi2, thetaA, thetaB, thetaC, thetaD)
    # tuple per configuration for junction ii
    angle_rows = [[] for _ in hjs]
    conf_count = 0
    wrong_isomer_count = 0
    while s:
        origami.update_system(s)
        conf_count += 1
        base.Logger.log("reading conf %d" % conf_count, base.Logger.INFO)
        for ii, hj in enumerate(hjs):
            # assume no insertions/deletions: the isomer check only runs when
            # the system has exactly four strands
            if len(origami._sys._strands) == 4 and check_isomer(hj, origami, conf_count):
                wrong_isomer_count += 1
            angle_rows[ii].append(_hj_angles(hj, origami, s))
        s = r.get_system()

    if conf_count == 0:
        base.Logger.log("no configurations read from %s" % conffile, base.Logger.CRITICAL)
        sys.exit(1)

    # Per-junction arrays of shape (conf_count, 6), columns as above. np.std is
    # the population standard deviation, the same quantity as
    # sqrt(<x^2> - <x>^2) but without that formula's cancellation error, which
    # can go slightly negative and produce NaN.
    data = [np.array(rows) for rows in angle_rows]
    means = [d.mean(axis=0) for d in data]
    sigmas = [d.std(axis=0) for d in data]

    if isorigami:
        foutn = "hj_angles_detailed.dat"
        with open(foutn, "w") as fout:
            for hj, mean, sig in zip(hjs, means, sigmas):
                fout.write("%d %d %f %f %f %f %f %f %f %f\n" % (hj[0], hj[1], mean[0], sig[0], mean[1], sig[1], mean[2], sig[2], mean[3], sig[3]))
        base.Logger.log("wrote file %s" % foutn, base.Logger.INFO)

        fout2n = "hj_angles.dat"
        with open(fout2n, "w") as fout2:
            for hj, mean in zip(hjs, means):
                fout2.write("%d %d %f %f\n" % (hj[0], hj[1], (mean[0] + mean[1]) / 2, (mean[2] + mean[3]) / 2))
        base.Logger.log("wrote file %s" % fout2n, base.Logger.INFO)

        # averages over all junctions
        phi1_avav, phi2_avav, thetaA_avav, thetaB_avav = np.mean([m[:4] for m in means], axis=0)
        phi_avav = (phi1_avav + phi2_avav) / 2
        theta_avav = (abs(thetaA_avav) + abs(thetaB_avav)) / 2

        # single-argument print: same output under Python 2 and 3
        print("phi_av_all, abs(theta_av_all), phi1_av, phi2_av, thetaA_av, thetaB_av")
        print(" ".join(str(x) for x in (phi_avav, theta_avav, phi1_avav, phi2_avav, thetaA_avav, thetaB_avav)))
    else:
        print("phi1 phi1_sig phi2 phi2_sig")
        print(" ".join(str(x) for x in (means[0][0], sigmas[0][0], means[0][1], sigmas[0][1])))
        print("theta_av[ii] theta_sig[ii]")
        for theta_mean, theta_sig in zip(means[0][2:], sigmas[0][2:]):
            print("%s %s" % (theta_mean, theta_sig))

    # bin to get joint probability distribution (2d); each sample adds 1/N.
    # let's just assume there is 1 holliday junction: only hjs[0] is used below.
    first = data[0]
    phi_ser = (first[:, 0] + first[:, 1]) / 2
    theta_ser = (first[:, 2] + first[:, 3]) / 2

    with open("phi_theta_series.dat", "w") as f:
        for ii, (phi, theta) in enumerate(zip(phi_ser, theta_ser)):
            f.write("%d %f %f\n" % (ii, phi, theta))
    base.Logger.log("wrote time series data to phi_theta_series.dat", base.Logger.INFO)

    nbins1 = 72  # phi bins
    nbins2 = 36  # theta bins
    max1 = 360.  # for phi
    min1 = 0.
    max2 = 90.  # for theta
    min2 = -90.
    dx1 = (max1 - min1) / float(nbins1)
    dx2 = (max2 - min2) / float(nbins2)
    dy = 1. / len(phi_ser)

    hist = [[0 for x in range(nbins1)] for y in range(nbins2)]
    for phi, theta in zip(phi_ser, theta_ser):
        theta_bin = _bin_index(theta, min2, max2, nbins2)
        phi_bin = _bin_index(phi, min1, max1, nbins1)
        if theta_bin is None or phi_bin is None:
            base.Logger.log("out of range %f %f, dying" % (theta, phi), base.Logger.CRITICAL)
            sys.exit(1)
        hist[theta_bin][phi_bin] += dy

    with open("phi_theta_bin.dat", "w") as f:
        for ii, row in enumerate(hist):
            for jj, value in enumerate(row):
                f.write("%f %f %f\n" % ((ii + 0.5) * dx2 + min2, (jj + 0.5) * dx1 + min1, value))
            f.write("\n")
    base.Logger.log("wrote phi-theta plot to phi_theta_bin.dat", base.Logger.INFO)

    # make sure everything check_isomer wrote to the module-level err_file is on disk
    err_file.flush()
    base.Logger.log("%d wrong isomer configurations" % wrong_isomer_count, base.Logger.INFO)
    base.Logger.log("Isomer information written to isomers.inf", base.Logger.INFO)

# Only run the analysis when executed as a script, not when imported.
if __name__ == "__main__":
    main()
