#! /usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import os
import math
import warnings
from Bio import PDB
from Bio import BiopythonWarning

# Drop every warning whose category is BiopythonWarning (or a subclass of it),
# presumably so PDB-parsing chatter does not clutter the OVERLAP report — confirm.
warnings.simplefilter(action='ignore', category=BiopythonWarning)

def usage():
    """Print the command-line help for this script and exit (status 0).

    Every option alias listed here is one that ``parsearg`` recognises.
    """
    print("\nUsage: python overlap.py -p object.pdb -t target.pdb [-r 3.5]")
    print("       This script calculates the OVERLAP between the atoms of the object and target pdb files")
    print("       The atom/atom contact cutoff defaults to 3.5A and can be reset with -r  \n")
    print("       -p :: object pdb file, usually the CALCULATED data, also with -obj")
    print("       -t :: target pdb file, usually the EXPERIMENT data, also with -target")
    print("       -t :: when given \"xxxx*\", it compares all files starting with \"xxxx\"")
    print("       -r :: atom-atom contact cutoff, default 3.5A, also with --rcut")
    print("       -x :: find maximum overlap, print ONLY the maximum overlap result, also with --maximum")
    print("       -c :: check details (together with -x, also print every comparison), also with -check")
    print("       -h :: print this help, also with --help")
    print("Output: ")
    print("       OVERLAP> chain_Obj chain_Target >>> Precision Recall  >>> obj_conn nobj  ||  target_conn ntarget")
    print("                Precision = obj_conn/nobj,  Recall = target_conn/ntarget")
    print("                obj_conn --> atoms in Object that are in contact with any Target atom (within rcutoff)")
    print("                target_conn --> atoms in Target that are in contact with any Object atom (within rcutoff)")
    print("                nobj --> total number of atoms in Object")
    print("                ntarget --> total number of atoms in Target")

    sys.exit()

def ppdist(i, j):
    """Return the Euclidean distance between two 3-D points.

    Only the first three items of each point are used, so any indexable
    xyz container (tuple, list, coordinate array) works.
    """
    dx = i[0] - j[0]
    dy = i[1] - j[1]
    dz = i[2] - j[2]
    return math.sqrt(dx ** 2 + dy ** 2 + dz ** 2)

def measure_overlap(obj, target, rcut):
    """Count how many atoms of each set are in contact with the other set.

    An atom is "in contact" when at least one atom of the other set lies at a
    distance (``ppdist``) strictly smaller than ``rcut``.

    Args:
        obj: sequence of xyz coordinates of the object atoms.
        target: sequence of xyz coordinates of the target atoms.
        rcut: contact cutoff, in the same length unit as the coordinates.

    Returns:
        ``(obj_conn, target_conn, len(obj), len(target))`` where ``obj_conn``
        counts object atoms touching any target atom and ``target_conn``
        counts target atoms touching any object atom.
    """
    # any() stops at the first partner found within the cutoff.
    obj_conn = sum(
        1 for o in obj if any(ppdist(o, t) < rcut for t in target)
    )
    target_conn = sum(
        1 for t in target if any(ppdist(o, t) < rcut for o in obj)
    )
    return obj_conn, target_conn, len(obj), len(target)

def matthews_correlation_coefficient(TP, TN, FP, FN):
    """Return the Matthews correlation coefficient of a binary confusion matrix.

    MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))

    For non-negative counts the result lies in [-1, 1]. When any of the four
    marginal sums is zero, the MCC is undefined (0/0). In that case 0.0 is
    returned, which is the common convention, rather than dividing by zero.
    """
    numerator = TP * TN - FP * FN
    # math.sqrt keeps this function free of a numpy dependency.
    denominator = math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
    if denominator == 0:
        return 0.0
    return numerator / denominator

def parsearg():
    """Parse ``sys.argv`` and return the run configuration.

    Returns:
        ``(rcutoff, objpdbf, targetpdbf, flag_maximum_show, flag_check)``.
        ``rcutoff`` is a float (default 3.5), and the two file names default
        to ''. Both flags are the strings 'true'/'false', because the
        caller compares them against the string 'true'.

    Behaviour:
        * Arguments that do not start with "-", and a bare "-", are reported
          and skipped.
        * "-h"/"--help" prints the usage text (``usage`` exits).
        * "-c"/"-check" and "-x"/"--maximum" are switches and take no value.
        * "-p"/"-obj", "-t"/"-target" and "-r"/"-rcut"/"--rcut" take the
          next argument as their value. Any other option is reported as
          unknown, and its value is skipped.
        * A value-taking option given last, or a non-numeric cutoff, prints
          an error and exits with status 1.
    """
    rcutoff = 3.5
    objpdbf = ''
    targetpdbf = ''
    flag_maximum_show = 'false'
    flag_check = 'false'

    opts = sys.argv           # opts[0] is the program name
    numargv = len(opts) - 1
    i = 1
    while i <= numargv:
        opt = opts[i]
        # startswith() rather than opt[0]: an empty-string argument must not raise IndexError.
        if not opt.startswith("-"):
            print("\ndonothing for option %s\n" % opt)
            i += 1
            continue
        if opt == "-":
            print("\nignored the %d unspecified option \"-\"\n" % i)
            i += 1
            continue

        if opt in ("-h", "--help"):
            usage()
        if opt in ("-c", "-check"):
            flag_check = 'true'
            i += 1
            continue
        if opt in ("-x", "--maximum"):
            flag_maximum_show = 'true'
            i += 1
            continue

        # Every remaining option consumes the following argument as its value.
        if i == numargv:
            print("\n-1 need parameter-data for option: %s\n" % opt)
            sys.exit(1)
        i += 1
        para = opts[i]

        if opt in ("-p", "-obj"):
            objpdbf = para
        elif opt in ("-t", "-target"):
            targetpdbf = para
        elif opt in ("-r", "-rcut", "--rcut"):
            try:
                rcutoff = float(para)
            except ValueError:
                print("\n-1 invalid cutoff value for option %s: %s\n" % (opt, para))
                sys.exit(1)
        else:
            print("\nignored unknown option %s %s\n" % (opt, para))
        i += 1

    return (rcutoff, objpdbf, targetpdbf, flag_maximum_show, flag_check)

def _chain_coords(chain):
    """Return the coordinates of every atom in *chain* (residue order, then atom order)."""
    return [atom.get_coord() for residue in chain for atom in residue]


def _format_overlap(tag, stats):
    """Render one comparison as the fixed-width report line, prefixed by *tag*.

    *stats* is an 8-tuple as yielded by ``_compare_structures``.
    """
    return ("%s %s  %s  >>> %5.2f  %5.2f  ::  %4d %4d  ||  %4d  %4d"
            % ((tag,) + tuple(stats)))


def _compare_structures(objstruct, targetstruct, rcut):
    """Yield overlap statistics for every (object chain, target chain) pair.

    Pairs are visited in this order: object model, target model, target
    chain, object chain. Each item is
    ``(obj_chain_id, target_chain_id, precision, recall,
    obj_conn, nobj, target_conn, ntarget)``, where precision is
    obj_conn/nobj and recall is target_conn/ntarget. Either ratio is 0.0
    for a chain with no atoms, instead of raising ZeroDivisionError.
    """
    for objmodel in objstruct:
        # Object-chain coordinates do not depend on the target: extract them once per model.
        objchains = [(objchain.id, _chain_coords(objchain)) for objchain in objmodel]
        for targetmodel in targetstruct:
            for targetchain in targetmodel:
                target = _chain_coords(targetchain)
                for objid, obj in objchains:
                    objconn, targetconn, nobj, ntarget = measure_overlap(obj, target, rcut)
                    precision = objconn / nobj if nobj else 0.0
                    recall = targetconn / ntarget if ntarget else 0.0
                    yield (objid, targetchain.id, precision, recall,
                           objconn, nobj, targetconn, ntarget)


(rcutoff, objpdbf, targetpdbf, flag_maximum_show, flag_check) = parsearg()

if not os.path.isfile(objpdbf):
    print(" -1  CANNOT find OBJ PDB file: %s\n" % objpdbf)
    sys.exit(1)

parser = PDB.PDBParser()
objstruct = parser.get_structure('input', objpdbf)
maxx = -1
maxfile = 'MAXFILE_NOTDEFINED'
maxinfo = 'MAXINFO_NOTDEFINED'

multi_target = targetpdbf.endswith("*")
if multi_target:  # MULTIPLE-TARGET-FILE-WITH-SAME-PREFIX-NAME
    # "dir/xxxx*" -> every regular file in dir/ (default: cwd) whose name starts with "xxxx".
    targetdir, nameprefix = os.path.split(targetpdbf.rsplit("*", 1)[0])
    try:
        entries = os.listdir(targetdir or os.getcwd())
    except OSError:
        entries = []
    target_files = sorted(
        os.path.join(targetdir, name) for name in entries
        if name.startswith(nameprefix) and os.path.isfile(os.path.join(targetdir, name))
    )
    if not target_files:
        print(" -1  CANNOT find TARGET PDB file: %s\n" % targetpdbf)
        sys.exit(1)
else:  # FOR A SINGLE-TARGET-FILE
    if not os.path.isfile(targetpdbf):
        print(" -1  CANNOT find TARGET PDB file: %s\n" % targetpdbf)
        sys.exit(1)
    target_files = [targetpdbf]

for targetfile in target_files:
    targetstruct = PDB.PDBParser().get_structure('input', targetfile)
    for stats in _compare_structures(objstruct, targetstruct, rcutoff):
        precision = stats[2]
        line = _format_overlap("OVERLAP>", stats)
        if flag_maximum_show == 'true':
            if maxx < precision:  # the "maximum overlap" is ranked by precision
                maxx = precision
                maxfile = targetfile
                maxinfo = _format_overlap("MAXOVERLAP>", stats)
            if flag_check != 'true':
                continue
        if multi_target:
            print(targetfile, " ---> ", line)
        else:
            print(line)

if flag_maximum_show == 'true':
    print()
    # The single-target report is always labelled with the target file name.
    print(maxfile if multi_target else targetpdbf, " ---> ", maxinfo, "\n")