from pylab import *
from numpy import exp
from numpy import log
from numpy import array
from copy import deepcopy
from scipy.optimize import curve_fit
import math
import numpy as np

# Transcription factor classes; TranscriptionFactor.effClass holds one of these.
ACTIVATOR, REPRESSOR = 0, 1



def actTransfer(x, a, b):
  '''Activator transfer function: y = a * x**(1/b).

  x : input ratio (scalar or numpy array)
  a : multiplicative scale
  b : reciprocal of the power-law exponent; must be non-zero
  Returns a scalar or array, following numpy broadcasting of x.
  '''
  # 1.0/b rather than 1/b: under Python 2 an integer b would truncate the
  # exponent to 0 (or -1) and turn the curve into a constant.
  return a * x ** (1.0 / b)

def repTransfer(x,a,b,c):
  '''Repressor transfer function: y = a**(-b*x) + c.

  Decays from 1 + c at x = 0 towards the floor c as x grows (for a > 1, b > 0).
  Works element-wise when x is a numpy array.
  '''
  decay = a ** (-b * x)
  return decay + c
  
class AvgTranscriptionFactor:
  '''Average of two fitted transcription factors.

  Combines two TranscriptionFactor-like objects by taking the pointwise mean
  of their transfer functions.

  Attributes:
    eff1, eff2 : the two factors; each must provide ``parsOpt`` (used for
                 plotting) and ``getTransferRate(ratio)``
    xmax       : upper end of the input-ratio axis used when plotting
    lbl        : legend label for the averaged curve
  '''
  eff1 = None
  eff2 = None
  xmax = 0
  lbl = ''

  def __init__(self, e1, e2, xmx, l):
    self.eff1 = e1
    self.eff2 = e2
    self.xmax = xmx
    self.lbl = l

  def _plotAverage(self, transfer, xtitle, ytitle, **legendArgs):
    '''Plot, in black over [0, xmax], the mean of ``transfer`` evaluated with both factors' parameters.'''
    datax = np.linspace(0, self.xmax, 1000)
    datay = (transfer(datax, *self.eff1.parsOpt) +
             transfer(datax, *self.eff2.parsOpt)) / 2.0
    suptitle('')
    plot(datax, datay, 'k-', label=self.lbl)
    xlabel(xtitle)
    ylabel(ytitle)
    legend(**legendArgs)

  def plotTransferFunctionAct(self):
    '''Plot the averaged activator transfer function (legend bottom-right).'''
    self._plotAverage(actTransfer, 'Activator/Reporter input ratio',
                      'Activation rate', loc="lower right")

  def plotTransferFunctionRep(self):
    '''Plot the averaged repressor transfer function.'''
    self._plotAverage(repTransfer, 'Repressor/Reporter input ratio',
                      'Repression rate')

  def getTransferRate(self, r):
    '''Return the mean of both factors' transfer rates at ratio ``r``.

    Negative averages are clamped to 0.
    '''
    # Both factors must contribute; averaging eff1 with itself would ignore eff2.
    rate = (self.eff1.getTransferRate(r) + self.eff2.getTransferRate(r)) / 2.0
    return int(rate > 0) * rate


class TranscriptionFactor:
  '''TranscriptionFactor plasmid group instance.

  Built from a tab-separated measurement file. The factor is classified as an
  activator or a repressor and fitted with actTransfer or repTransfer
  respectively.

  Attributes:
    tag           : name tag, wrapped in square brackets (e.g. '[TAL-A:VP16]')
    mass          : mass status (set to 0 by the constructor)
    dataFile      : path of the measurement file
    inducer       : bound inducer (None after construction)
    targetProduct : target product (None after construction)
    dataChart     : raw rows of the measurement file, as lists of floats
    dataRLA       : mean relative luciferase activity (column 2 / column 3),
                    indexed [factor level][reporter level]; unmeasured cells are 0.0
    dataRLAstd    : standard deviation matching dataRLA
    dataEffUnq    : sorted unique TranscriptionFactor amounts (column 0)
    dataRepUnq    : sorted unique reporter amounts (column 1)
    dictTransfer  : {(factor amount, reporter amount): mean RLA}
    dataRatio     : sorted unique TranscriptionFactor/Reporter ratios (numpy array)
    dataTransfer  : RLA at each ratio, normalised to the factor-free
                    (factor amount 0) baseline and averaged over equal ratios
    dataNoise     : mean of column 4 of the measurement file
    effClass      : ACTIVATOR or REPRESSOR
    parsOpt       : optimal parameters of the transfer function (curve_fit)
    parsCov       : covariance of parsOpt (curve_fit)
  '''
  tag = ''
  mass = 0
  targetProduct = None
  dataFile      = None
  inducer       = None
  dataChart     = []
  dataRLA       = []
  dataRLAstd    = []
  dataEffUnq    = []
  dataRepUnq    = []
  dataRatio     = []
  dataTransfer  = []
  dictTransfer  = {}
  # Legacy names kept for compatibility; the fit is stored in parsOpt/parsCov.
  optPars       = []
  optCov        = []
  dataNoise     = 0

  def __init__(self, tg, f):
    '''Load measurement file ``f`` and fit the transfer function.

    Raises ValueError if the file has no data rows or if no ratio point
    can be normalised against a factor-free baseline.
    '''
    self.tag = '[' + tg + ']'
    self.dataFile      = f
    self.mass          = 0
    self.inducer       = None
    self.targetProduct = None
    self.dataChart     = []
    self.dataRLA       = []
    self.dataRLAstd    = []
    self.dataEffUnq    = []
    self.dataRepUnq    = []
    self.dataRatio     = []
    self.dataTransfer  = []
    self.dictTransfer  = {}

    # Previously fitted parameters (server optimisation), kept for reference:
    #   [TAL-B:KRAB] REPRESSOR [4.78563811, 1.79830432, 0.04734081]
    #   [TAL-A:KRAB] REPRESSOR [8.54972992, 4.04837695, 0.03521075]
    #   [TAL-B:VP16] ACTIVATOR [1286.65504068, 2.01045855]
    #   [TAL-A:VP16] ACTIVATOR [1729.33769067, 3.84351035]

    self._loadChart(f)
    self._buildRLA()
    self._buildTransferCurve()
    self._fitTransferFunction()

  def _loadChart(self, path):
    '''Read the tab-separated data file into dataChart and compute dataNoise.'''
    with open(path, 'r') as fp:
      for line in fp:
        if not line.strip():
          continue  # tolerate blank/trailing lines instead of failing in float()
        self.dataChart.append([float(e) for e in line.split('\t')])
    if not self.dataChart:
      raise ValueError('no data rows in {0}'.format(path))
    self.dataNoise = sum(v[4] for v in self.dataChart) / len(self.dataChart)

  def _buildRLA(self):
    '''Average column2/column3 per (factor, reporter) cell into dataRLA, dataRLAstd and dictTransfer.'''
    effUnq = sorted(set(v[0] for v in self.dataChart))
    repUnq = sorted(set(v[1] for v in self.dataChart))

    # Group per-row RLA values by cell in one pass rather than rescanning the
    # whole chart for every (factor, reporter) combination.
    cells = {}
    for v in self.dataChart:
      cells.setdefault((v[0], v[1]), []).append(v[2] / v[3])

    avgRLAspace = []
    stdRLAspace = []
    for eu in effUnq:
      avgRow = []
      stdRow = []
      for ru in repUnq:
        d = cells.get((eu, ru))
        if d:
          avgRow.append(np.mean(d))
          stdRow.append(np.std(d))
        else:
          # Unmeasured cell: 0.0 excludes it from the ratio curve later on.
          avgRow.append(0.0)
          stdRow.append(0.0)
      avgRLAspace.append(avgRow)
      stdRLAspace.append(stdRow)

    self.dataEffUnq = effUnq
    self.dataRepUnq = repUnq
    self.dataRLA = avgRLAspace
    self.dataRLAstd = stdRLAspace

    for j, ru in enumerate(repUnq):
      for i, eu in enumerate(effUnq):
        self.dictTransfer[(eu, ru)] = avgRLAspace[i][j]

  def _buildTransferCurve(self):
    '''Build dataRatio/dataTransfer: baseline-normalised RLA averaged per unique factor/reporter ratio.'''
    groups = {}
    for ru in self.dataRepUnq:
      if ru <= 0:
        continue
      baseline = self.dictTransfer.get((0.0, ru), 0.0)
      if baseline <= 0:
        # No usable factor-free reference for this reporter amount; dividing
        # by it would raise KeyError or produce inf values that break curve_fit.
        continue
      for eu in self.dataEffUnq:
        value = self.dictTransfer[(eu, ru)]
        if value > 0:
          groups.setdefault(eu / ru, []).append(value / baseline)

    ratios = sorted(groups)
    self.dataRatio = np.array(ratios)
    self.dataTransfer = np.array([np.mean(groups[x]) for x in ratios])

  def _fitTransferFunction(self):
    '''Classify as ACTIVATOR/REPRESSOR and fit the matching transfer function.'''
    if len(self.dataRatio) == 0:
      raise ValueError('no usable factor/reporter ratio points in {0}'.format(self.dataFile))
    # dataTransfer is relative to the factor-free baseline, so a mean above 1.0
    # means the factor increases the reporter signal on average.
    actThreshold = 1.0
    if np.mean(self.dataTransfer) > actThreshold:
      self.effClass = ACTIVATOR
      transfer = actTransfer
    else:
      self.effClass = REPRESSOR
      transfer = repTransfer
    self.parsOpt, self.parsCov = curve_fit(transfer, self.dataRatio, self.dataTransfer)

  def _transferFunction(self):
    '''Return the transfer function matching effClass.'''
    return actTransfer if self.effClass == ACTIVATOR else repTransfer

  def plotTransferFunction(self, xmax=2.0):
    '''Plot measured data and the fitted curve (0..xmax) in a new figure.'''
    datax = np.linspace(0, xmax, 1000)
    datay = self._transferFunction()(datax, *self.parsOpt)

    fig = figure()
    fig.suptitle(self.getTag())
    plot(self.dataRatio, self.dataTransfer, 'ko', label="Original Data")
    plot(datax, datay, 'r-', label="Fitted Curve")
    xlabel('Transcription Factor/Reporter Ratio')
    ylabel('Effect Rate')
    legend()

  def plotTransferFunctionAct(self):
    '''Plot data and the dashed fitted curve up to the largest measured ratio (red for TAL-A:VP16, else blue).'''
    datax = np.linspace(0, self.dataRatio[-1], 1000)
    datay = self._transferFunction()(datax, *self.parsOpt)

    c = 'r' if self.getTag() == '[TAL-A:VP16]' else 'b'
    plot(self.dataRatio, self.dataTransfer, c + 'o', label="Original Data ({0})".format(self.getTag()))
    plot(datax, datay, c + '-', linestyle='dashed', label="Activation Rate ({0})".format(self.getTag()))
    xlabel('Activator/Reporter input ratio')
    ylabel('Activation rate')
    legend(loc="lower right")

  def plotTransferFunctionRep(self, xmax=3):
    '''Plot data below xmax and the dashed fitted curve over 0..xmax (red for TAL-A:KRAB, else blue).'''
    datax = np.linspace(0, xmax, 1000)
    datay = self._transferFunction()(datax, *self.parsOpt)

    suptitle('')
    c = 'r' if self.getTag() == '[TAL-A:KRAB]' else 'b'
    # dataRatio is sorted ascending, so the first len(xm) transfer values
    # belong to the ratios kept below xmax.
    xm = [i for i in self.dataRatio if i < xmax]
    ym = self.dataTransfer[0:len(xm)]
    plot(xm, ym, c + 'o', label="Original Data ({0})".format(self.getTag()))
    plot(datax, datay, c + '-', label=self.getTag(), linestyle='dashed')
    xlabel('Repressor/Reporter input ratio')
    ylabel('Repression rate')
    legend()

  def getTransferRate(self, ratio):
    '''Evaluate the fitted transfer function at ``ratio``; negative results are clamped to 0.'''
    r = self._transferFunction()(ratio, *self.parsOpt)
    return int(r > 0) * r

  def getTransferNoise(self):
    return self.dataNoise

  def getDataChart(self):
    return np.array(self.dataChart)

  def getdataEffUnq(self):
    return np.array(self.dataEffUnq)

  def getDataRepUnq(self):
    return np.array(self.dataRepUnq)

  def getMass(self):
    return self.mass

  def getDataRLA(self):
    return np.array(self.dataRLA)

  def getTag(self):
    return self.tag

  def __str__(self):
    return self.getTag() + " TranscriptionFactor instance at " + str(id(self)) + ", amount: " + str(self.mass)