import sys
import math
from numpy import *
from scipy import *
from bricks2 import *

class MSQfeedbackSwitch:
  """Time-stepped model of a TAL-based feedback switch with two reporters.

  Two mutually repressing arms are integrated with a forward Euler step of
  one time unit per iteration:

  * arm A -- BFP and TAL-B:KRAB (basal production from COPIES[0]) and
    TAL-A:VP16 (COPIES[1]); activated by TAL-A:VP16, repressed by TAL-A:KRAB;
  * arm B -- mCitrine (MCT) and TAL-A:KRAB (basal production from COPIES[2])
    and TAL-B:VP16 (COPIES[3]); activated by TAL-B:VP16, repressed by
    TAL-B:KRAB.

  Pristinamycin (PIR) and erythromycin (ETR) pulses additionally scale TAL
  production. Stable states of the reporter pair are detected and classified
  as BFP-dominant, MCT-dominant or ambiguous.
  """

  COPIES    = []     # plasmid amounts; rescaled in __init__ to sum to CAPACITY
  TIME      = []     # simulation steps (range(time) after __init__)
  PIR       = []     # 0/1 pristinamycin input per step
  ETR       = []     # 0/1 erythromycin input per step
  NOISE     = True   # perturb transfer rates randomly every step
  DEBUG     = True   # interactive pylab output and relative data paths
  TFEQ      = False  # not referenced in this file
  SSP       = 1e-5   # max relative reporter change per step for a stable state
  SSD       = 0.5    # weaker/stronger reporter ratio below which a state is not ambiguous
  CAPACITY  = 100    # total plasmid amount after normalisation
  MASS      = 200    # not referenced in this file

  # Marker colour and legend label for each stable-state class.
  _STATE_STYLE = {
    'AMB': ('yellow', 'Stable state (AMB)'),
    'BFP': ('blue', 'Stable state (BFP)'),
    'MCT': ('green', 'Stable state (MCT)'),
  }

  # TODO: protocol for distributing DNA according to ng
  # TODO: compute the (deterministic) plasmid distribution inside the cell
  # TODO: validate the DNA amounts
  # TODO: find out the binding rates of TAL, PIP, ETR / ratio

  def __init__(self, debug, time, noise, copies, signals):
    """Prepare a simulation of `time` steps.

    Args:
      debug: if true, run() plots interactively with pylab and data files are
        read from the relative 'dat/' directory; otherwise results are saved
        as PNG files and data comes from the deployment directory.
      time: number of simulation steps.
      noise: if true, transfer rates are randomly perturbed every step.
      copies: per-plasmid amounts (at least 10 entries are read by run()).
        A rescaled copy summing to CAPACITY is stored; the argument itself is
        left unmodified.
      signals: [pir_starts, pir_lengths, etr_starts, etr_lengths]; every
        start/length pair sets the inducer to 1 on [start, start + length).
    """
    self.DEBUG = debug
    self.TIME = range(time)
    self.NOISE = noise
    self.PIR = self._pulseTrain(len(self.TIME), signals[0], signals[1])
    self.ETR = self._pulseTrain(len(self.TIME), signals[2], signals[3])

    # Build a new list rather than rescaling the caller's sequence in place.
    total = float(sum(copies))
    self.COPIES = [c / total * self.CAPACITY for c in copies]

  def _pulseTrain(self, n, starts, lengths):
    """Return an int array of length n that is 1 on each [start, start+length) window."""
    train = zeros(n, dtype=int)
    for k, start in enumerate(starts):
      train[start:start + lengths[k]] = 1
    return train

  def _classifyState(self, bfp, mct):
    """Classify a stable reporter pair as 'BFP', 'MCT' or 'AMB' (ambiguous).

    A state belongs to the dominant reporter only when the weaker one is below
    SSD times the stronger one; otherwise (including ties) it is ambiguous.
    The test is a ratio, so raw and percent-normalised levels classify the
    same way (up to rounding).
    """
    if bfp > mct and float(mct) / bfp < self.SSD:
      return 'BFP'
    if mct > bfp and float(bfp) / mct < self.SSD:
      return 'MCT'
    return 'AMB'

  def _normalizePair(self, a, b):
    """Scale two series to percent of their common maximum.

    An all-zero pair would otherwise divide by zero; it is returned as zeros.
    """
    peak = max(float(max(a)), float(max(b))) or 1.0
    return [100 * v / peak for v in a], [100 * v / peak for v in b]

  def _plotSeries(self, fig, curves, ylabel, ylim):
    """Draw (data, label, color) curves against time on the left half of fig."""
    ax = fig.add_subplot(121)
    for data, label, color in curves:
      ax.plot(data, label=label, color=color)
    ax.set_xlabel('Time [min]')
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)
    ax.grid(True)
    # Legend outside the axes, in the empty right half of the figure.
    ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    return ax

  def plotResults(self,BFP,MCT,TALA_KRAB,TALB_KRAB,TALA_VP16,TALB_VP16,PIR,ETR, SS_BFP, SS_MCT, SS_AMB, SS_TIME):
    """Render the simulation results to five PNG files and return their paths.

    Figures: reporters, TAL:KRAB pair, TAL:VP16 pair (each pair rescaled to
    percent of its common maximum), inducer inputs, and the BFP/mCitrine
    phase plane with stable states marked at the SS_TIME indices. Stable
    states are classified from the plotted levels with _classifyState;
    SS_BFP, SS_MCT and SS_AMB are accepted for interface compatibility.

    Files go to './out/' in DEBUG mode, otherwise to '/var/www/imodel/img/'.
    Returns None without drawing when running on Python 3 or later.
    """
    if sys.version_info[0] > 2:
      return

    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure
    from random import random

    # NOTE(review): a 4-digit random nonce can collide between runs and
    # overwrite earlier images.
    NONCESIZE = 10000
    NONCE = int(random()*NONCESIZE)

    # pylab figures in DEBUG mode, bare (canvas-less) Figures otherwise.
    newFigure = figure if self.DEBUG else Figure
    figs = [newFigure() for _ in range(5)]
    for fig in figs:
      fig.set_size_inches(15, 4.5)
      fig.set_facecolor('white')

    BFP, MCT = self._normalizePair(BFP, MCT)
    TALA_KRAB, TALB_KRAB = self._normalizePair(TALA_KRAB, TALB_KRAB)
    TALA_VP16, TALB_VP16 = self._normalizePair(TALA_VP16, TALB_VP16)

    self._plotSeries(figs[0], [(MCT, 'mCitrine', 'green'), (BFP, 'BFP', 'blue')],
                     'Relative level [%]', [0, 120])
    self._plotSeries(figs[1], [(TALA_KRAB, 'TAL-A:KRAB', 'green'), (TALB_KRAB, 'TAL-B:KRAB', 'blue')],
                     'Relative level [%]', [0, 120])
    self._plotSeries(figs[2], [(TALB_VP16, 'TAL-B:VP16', 'green'), (TALA_VP16, 'TAL-A:VP16', 'blue')],
                     'Relative level [%]', [0, 120])
    self._plotSeries(figs[3], [(ETR, 'Erithromycin', 'green'), (PIR, 'Pristinamycin', 'blue')],
                     'Signal [A.U.]', [-0.1, 1.1])

    # Phase plane with one legend entry per stable-state class.
    s4 = figs[4].add_subplot(121)
    s4.plot(BFP, MCT, 'c.', label="Phase")
    labelled = set()
    for st in SS_TIME:
      state = self._classifyState(BFP[st], MCT[st])
      pcolor, plabel = self._STATE_STYLE[state]
      if state in labelled:
        s4.plot(BFP[st], MCT[st], 'd', color=pcolor)
      else:
        s4.plot(BFP[st], MCT[st], 'd', color=pcolor, label=plabel)
        labelled.add(state)

    # Dashed coordinate axes through the origin.
    bmax = max(BFP)
    mmax = max(MCT)
    s4.plot(linspace(-bmax*0.1, bmax*1.1, len(BFP)), [0] * len(BFP), linestyle='dashed', color='gray')
    s4.plot([0] * len(BFP), linspace(-mmax*0.1, mmax*1.1, len(MCT)), linestyle='dashed', color='gray')
    s4.set_xlim(-bmax*0.1, bmax*1.1)
    s4.set_ylim(-mmax*0.1, mmax*1.1)
    s4.set_xlabel('BFP relative level [%]')
    s4.set_ylabel('mCitrine relative level [%]')
    s4.grid(True)
    s4.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    if self.DEBUG:
      path = './out/'+self.__class__.__name__
    else:
      path = '/var/www/imodel/img/'+self.__class__.__name__
    paths = [path + '_s' + str(k) + '_' + str(NONCE) + '.png' for k in range(len(figs))]
    for fig, target in zip(figs, paths):
      if not self.DEBUG:
        # Bare Figure objects get an Agg canvas attached before saving.
        FigureCanvas(fig)
      fig.savefig(target, format='png')
    return paths

  def _loadTransferFunctions(self):
    """Load the eight TranscriptionFactor models used by run(), keyed by short name.

    DEBUG mode reads from the relative 'dat/' directory, otherwise from the
    deployment mirror; both modes bind every factor to its own data file.
    """
    if self.DEBUG:
      datDir = 'dat/'
    else:
      datDir = '/home/root/cellulator-mirror/dat/'
    specs = [
      ('talBkrab', 'TAL-B:KRAB', 'talBkrabcmv.dat'),
      ('talAkrab', 'TAL-A:KRAB', 'talAkrabcmv.dat'),
      ('talBvp16', 'TAL-B:VP16', 'talBvp16min.dat'),
      ('talAvp16', 'TAL-A:VP16', 'talAvp16min.dat'),
      ('priAct', 'PIP:Act-System', 'prismin.dat'),
      ('ectAct', 'E:Act-System', 'ekrabmin.dat'),
      ('priRep', 'PIP:Rep-System', 'pipcmv.dat'),
      ('ectRep', 'E:Rep-System', 'etrcmv.dat'),
    ]
    return dict((key, TranscriptionFactor(label, datDir + fname))
                for key, label, fname in specs)

  def _noisy(self, value, spread):
    """Return value perturbed by a random relative amount of up to +/- spread.

    Both random draws happen even when NOISE is off; value is then returned
    unchanged.
    """
    # NOTE(review): `random` is bound by the module's star imports and is
    # expected to be a no-argument uniform [0, 1) generator -- confirm.
    return value + int(self.NOISE)*value*spread*random()*sign(0.5 - random())

  def _showDebugPlots(self, BFP, MCT, TALA_KRAB, TALB_KRAB, SS_TIME):
    """Show a 2x2 pylab summary of a DEBUG run: reporters, inducers, TAL:KRAB, phase plane."""
    figure(1)
    subplot(2,2,1)
    title('Reporters')
    plot(BFP,'blue')
    plot(MCT,'green')
    plot(SS_TIME, [max(BFP[i], MCT[i]) for i in SS_TIME], 'd', color='yellow')
    grid(True)

    subplot(2,2,2)
    title('Induction')
    plot(self.ETR,'green')
    plot(self.PIR,'blue')
    ylim([-0.1,1.1])
    grid(True)

    subplot(2,2,3)
    plot(TALA_KRAB,'green')
    plot(TALB_KRAB,'blue')
    title('TAL:KRAB')
    grid(True)

    subplot(2,2,4)
    plot(BFP,MCT,color='black')
    for st in SS_TIME:
      pcolor = self._STATE_STYLE[self._classifyState(BFP[st], MCT[st])][0]
      plot(BFP[st],MCT[st], 'd', color=pcolor)
    bmax = max(BFP)
    mmax = max(MCT)
    plot(linspace(-bmax*0.1,bmax*1.1,len(BFP)), [0] * len(BFP), linestyle='dashed', color='gray')
    plot([0] * len(BFP), linspace(-mmax*0.1,mmax*1.1,len(MCT)), linestyle='dashed', color='gray')
    title('Phase')
    xlabel('BFP')
    ylabel('MCT')
    xlim(-bmax*0.1,bmax*1.1)
    ylim(-mmax*0.1,mmax*1.1)
    grid(True)
    show()

  def run(self):
    """Simulate the switch over self.TIME and report the trajectories.

    COPIES[0..3] set basal production (see the class docstring); COPIES[4..9]
    only enter the plasmid ratios that select the transfer rates.

    Returns:
      None in DEBUG mode (results are shown with pylab); otherwise the list
      of PNG paths returned by plotResults().
    """
    tf = self._loadTransferFunctions()
    C = self.COPIES

    # Plasmid ratios for the inducible (PIP / E) systems
    rqe = C[9]/(C[6]*(C[2]+C[3]))
    rqp = C[8]/(C[4]*(C[0]+C[1]))
    rae = C[9]/(C[7]*(C[2]+C[3]))
    rap = C[8]/(C[5]*(C[0]+C[1]))

    # Inducible-system transfer rates and noise amplitudes
    qETR = tf['ectRep'].getTransferRate(rqe)
    qPIR = tf['priRep'].getTransferRate(rqp)
    aETR = tf['ectAct'].getTransferRate(rae)
    aPIR = tf['priAct'].getTransferRate(rap)
    noiqETR = tf['ectRep'].getTransferNoise()
    noiqPIR = tf['priRep'].getTransferNoise()
    noiaETR = tf['ectAct'].getTransferNoise()
    noiaPIR = tf['priAct'].getTransferNoise()

    # Plasmid ratios for the switch itself
    # TODO: include the inducible-system plasmids as well?
    rqA = (C[2]+C[6])/(C[0]+C[1])
    rqB = (C[0]+C[4])/(C[2]+C[3])
    raA = (C[1]+C[5])/C[0]
    raB = (C[3]+C[7])/C[2]

    # TAL transfer rates and noise amplitudes
    qA = tf['talAkrab'].getTransferRate(rqA)
    qB = tf['talBkrab'].getTransferRate(rqB)
    aA = tf['talAvp16'].getTransferRate(raA)
    aB = tf['talBvp16'].getTransferRate(raB)
    noiqA = tf['talAkrab'].getTransferNoise()
    noiqB = tf['talBkrab'].getTransferNoise()
    noiaA = tf['talAvp16'].getTransferNoise()
    noiaB = tf['talBvp16'].getTransferNoise()

    # Trajectories, all starting at zero
    n = len(self.TIME)
    BFP = zeros(n)
    MCT = zeros(n)
    TALA_KRAB = zeros(n)
    TALA_VP16 = zeros(n)
    TALB_KRAB = zeros(n)
    TALB_VP16 = zeros(n)
    species = (BFP, MCT, TALA_KRAB, TALB_KRAB, TALA_VP16, TALB_VP16)

    # Stable-state stacks: [BFP, MCT] points per class, plus detection steps
    SS_BFP  = []
    SS_MCT  = []
    SS_AMB  = []
    SS_TIME = []
    ssStacks = {'BFP': SS_BFP, 'MCT': SS_MCT, 'AMB': SS_AMB}
    sslock = False  # true while the current stable state is already logged

    # Tuned so that activator and repressor give the same level as the control.
    # Only RATES[0] (basal production per plasmid unit) is used.
    RATES = [0.01, 10.0]
    kdp = 0.02  # first-order degradation rate of every species

    # Half-saturation levels of repression (m) and activation (n); subject to adjustment
    mA = 1.0*300
    mB = 1.0*300
    nA = 1.0*300
    nB = 1.0*300

    for t in self.TIME:
      if t == 0:
        continue
      p = t - 1  # all rates below depend only on the previous step

      # Basal production
      dBFP = C[0] * RATES[0]
      dMCT = C[2] * RATES[0]
      dTALB_KRAB = C[0] * RATES[0]
      dTALA_KRAB = C[2] * RATES[0]
      dTALB_VP16 = C[3] * RATES[0]
      dTALA_VP16 = C[1] * RATES[0]

      # Noisy per-step transfer rates (equal to the base rates when NOISE is off)
      qAy = self._noisy(qA, noiqA)
      qBy = self._noisy(qB, noiqB)
      aAy = self._noisy(aA, noiaA)
      aBy = self._noisy(aB, noiaB)
      qETRy = self._noisy(qETR, noiqETR)
      qPIRy = self._noisy(qPIR, noiqPIR)
      aETRy = self._noisy(aETR, noiaETR)
      aPIRy = self._noisy(aPIR, noiaPIR)

      # Activation, saturating in the VP16 level. Applied only when the
      # activator is present, so basal production survives its absence.
      if TALA_VP16[p] > 0:
        satA = TALA_VP16[p]/(nA+TALA_VP16[p])
        dBFP = aAy*dBFP*satA
        dTALB_KRAB = aAy*dTALB_KRAB*satA
        dTALA_VP16 = aAy*dTALA_VP16*satA
      if TALB_VP16[p] > 0:
        satB = TALB_VP16[p]/(nB+TALB_VP16[p])
        dMCT = aBy*dMCT*satB
        dTALA_KRAB = aBy*dTALA_KRAB*satB
        dTALB_VP16 = aBy*dTALB_VP16*satB

      # Induction acts on TAL production only, not on the reporters
      if self.PIR[t] == 1:
        dTALB_KRAB = dTALB_KRAB * aPIRy
        dTALA_VP16 = dTALA_VP16 * aPIRy
        dTALA_KRAB = dTALA_KRAB * qPIRy
        dTALB_VP16 = dTALB_VP16 * qPIRy
      if self.ETR[t] == 1:
        dTALB_KRAB = dTALB_KRAB * qETRy
        dTALA_VP16 = dTALA_VP16 * qETRy
        dTALA_KRAB = dTALA_KRAB * aETRy
        dTALB_VP16 = dTALB_VP16 * aETRy

      # Repression: divide by 1 + k*x/(x+m) with k = 1/q - 1, so a saturating
      # repressor level scales production by the (noisy) rate q.
      kAy = -1 + qAy**-1
      kBy = -1 + qBy**-1
      repA = (1+kAy*TALA_KRAB[p]/(TALA_KRAB[p]+mA))**-1
      repB = (1+kBy*TALB_KRAB[p]/(TALB_KRAB[p]+mB))**-1
      dBFP = dBFP * repA
      dTALB_KRAB = dTALB_KRAB * repA
      dTALA_VP16 = dTALA_VP16 * repA
      dMCT = dMCT * repB
      dTALA_KRAB = dTALA_KRAB * repB
      dTALB_VP16 = dTALB_VP16 * repB

      # Degradation
      dBFP -= kdp * BFP[p]
      dMCT -= kdp * MCT[p]
      dTALA_KRAB -= kdp * TALA_KRAB[p]
      dTALB_KRAB -= kdp * TALB_KRAB[p]
      dTALA_VP16 -= kdp * TALA_VP16[p]
      dTALB_VP16 -= kdp * TALB_VP16[p]

      # Euler step
      BFP[t] = BFP[p] + dBFP
      MCT[t] = MCT[p] + dMCT
      TALA_KRAB[t] = TALA_KRAB[p] + dTALA_KRAB
      TALB_KRAB[t] = TALB_KRAB[p] + dTALB_KRAB
      TALA_VP16[t] = TALA_VP16[p] + dTALA_VP16
      TALB_VP16[t] = TALB_VP16[p] + dTALB_VP16

      # Stable state: each reporter is either flat at zero or changes by less
      # than SSP relative to the higher of the two reporters.
      sshigh = max(BFP[p], MCT[p])
      if BFP[p] == 0:
        bfpFlat = (dBFP == 0)
      else:
        bfpFlat = math.fabs(float(dBFP)/sshigh) < self.SSP
      if MCT[p] == 0:
        mctFlat = (dMCT == 0)
      else:
        mctFlat = math.fabs(float(dMCT)/sshigh) < self.SSP
      sstate = bfpFlat and mctFlat

      if not sstate:
        sslock = False
      elif not sslock:
        # First step of a new stable state: log it once and hold the lock
        # until the reporters start moving again.
        state = self._classifyState(BFP[p], MCT[p])
        ssStacks[state].append([BFP[p], MCT[p]])
        SS_TIME.append(t)
        sslock = True

      # Keep all levels non-negative
      for series in species:
        if series[t] < 0:
          series[t] = 0.0

    if self.DEBUG:
      self._showDebugPlots(BFP, MCT, TALA_KRAB, TALB_KRAB, SS_TIME)
      return None
    return self.plotResults(BFP,MCT,TALA_KRAB,TALB_KRAB,TALA_VP16,TALB_VP16,self.PIR,self.ETR, SS_BFP, SS_MCT, SS_AMB, SS_TIME)