#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 31 11:19:11 2017

@author: udacity cs222 adapted by curzua 
"""

# Reuses code from the Udacity course "Differential Equations in Action" (cs222).
 
import numpy as np
import matplotlib.pyplot as plt

def si4r_model(h,
               transmission_coeff1,
               transmission_coeff2,
               infectious_time1,
               infectious_time2,
               virulence1,
               virulence2,
               fitness,
               spike_in_timepoint,
               end_time,
               total_cells=1e8,
               i1_init=1e5,
               i2_init=1e5):
    """Integrate the two-strain SI4R model with the forward Euler method.

    Compartments (one array each in the returned tuple):
        s    -- susceptible to both strains
        i1   -- infected with strain 1 (first infection, from s)
        i2   -- infected with strain 2 (first infection, from s)
        sr1  -- recovered from strain 1 (immune to it), still infectable by 2
        sr2  -- recovered from strain 2 (immune to it), still infectable by 1
        i1p  -- infected with strain 1 after recovering from strain 2
        i2p  -- infected with strain 2 after recovering from strain 1
        r    -- recovered from both strains
        d    -- cells removed by virulence; they flow back into s, sr1, sr2
                and r at rate ``fitness``, in proportion to those sizes

    Parameters:
        h                   -- integration step (time units); coerced to float
        transmission_coeff1 -- strain-1 transmission coefficient
        transmission_coeff2 -- strain-2 transmission coefficient
        infectious_time1    -- mean time an i1 / i1p cell stays infected
        infectious_time2    -- mean time an i2 / i2p cell stays infected
        virulence1          -- rate at which strain-1-infected cells move to d
        virulence2          -- rate at which strain-2-infected cells move to d
        fitness             -- rate at which d cells return to the uninfected
                               compartments
        spike_in_timepoint  -- accepted for interface compatibility; not used
                               by the model
        end_time            -- total simulated time; int(end_time / h) steps
        total_cells         -- initial total population (default 1e8)
        i1_init, i2_init    -- initial i1 / i2 populations (default 1e5 each);
                               s starts with the remaining cells

    Returns:
        Tuple (s, i1, i2, sr1, sr2, i1p, i2p, r, d) of numpy arrays of length
        int(end_time / h) + 1; element k is the state at time k * h.  The sum
        of all nine compartments is conserved (up to rounding).
    """
    # Without this, Python 2 truncates h / infectious_time to 0 for ints.
    h = float(h)
    num_steps = int(end_time / h)

    s, i1, i2, sr1, sr2, i1p, i2p, r, d = [np.zeros(num_steps + 1)
                                           for _ in range(9)]

    s[0] = total_cells - i1_init - i2_init
    i1[0] = i1_init
    i2[0] = i2_init

    for step in range(num_steps):
        # Infectious load of each strain (first + second infections).
        all_i1 = i1[step] + i1p[step]
        all_i2 = i2[step] + i2p[step]

        # New infections.
        s2i1 = h * transmission_coeff1 * s[step] * all_i1
        s2i2 = h * transmission_coeff2 * s[step] * all_i2
        sr2_2_i1p = h * transmission_coeff1 * sr2[step] * all_i1
        sr1_2_i2p = h * transmission_coeff2 * sr1[step] * all_i2

        # Recoveries.
        i1_2sr1 = h / infectious_time1 * i1[step]
        i2_2sr2 = h / infectious_time2 * i2[step]
        i1p_2_r = h / infectious_time1 * i1p[step]
        i2p_2_r = h / infectious_time2 * i2p[step]

        # Removal to d, computed per compartment (rate * size) so that an
        # absent strain contributes 0 rather than 0/0 = nan.
        i1_2_d = h * virulence1 * i1[step]
        i1p_2_d = h * virulence1 * i1p[step]
        i2_2_d = h * virulence2 * i2[step]
        i2p_2_d = h * virulence2 * i2p[step]

        # Return from d, shared among the uninfected compartments in
        # proportion to their size.
        all_div = s[step] + sr1[step] + sr2[step] + r[step]
        if all_div > 0:
            d_2_s = h * fitness * d[step]
            regrow = d_2_s / all_div
        else:
            # Nowhere to return to: keep the cells in d instead of making nan.
            d_2_s = 0.0
            regrow = 0.0

        s[step + 1] = s[step] - s2i1 - s2i2 + s[step] * regrow

        i1[step + 1] = i1[step] + s2i1 - i1_2sr1 - i1_2_d
        i2[step + 1] = i2[step] + s2i2 - i2_2sr2 - i2_2_d

        sr1[step + 1] = sr1[step] + i1_2sr1 - sr1_2_i2p + sr1[step] * regrow
        sr2[step + 1] = sr2[step] + i2_2sr2 - sr2_2_i1p + sr2[step] * regrow

        i1p[step + 1] = i1p[step] + sr2_2_i1p - i1p_2_r - i1p_2_d
        i2p[step + 1] = i2p[step] + sr1_2_i2p - i2p_2_r - i2p_2_d

        r[step + 1] = r[step] + i1p_2_r + i2p_2_r + r[step] * regrow

        d[step + 1] = d[step] + i1_2_d + i1p_2_d + i2_2_d + i2p_2_d - d_2_s

    return s, i1, i2, sr1, sr2, i1p, i2p, r, d

h = 0.5  # integration step in time units (lower for better resolution)

transmission_coeff1 = 5e-9  # 1 / (time unit * cell)
transmission_coeff2 = 5e-9  # 1 / (time unit * cell)

infectious_time1 = 5.  # time units
infectious_time2 = 6.  # time units

virulence1 = 0.1  # rate of strain-1-infected cells moving to d, 1 / time unit
virulence2 = 0.1  # rate of strain-2-infected cells moving to d, 1 / time unit

fitness = 0.01  # rate of d cells returning to s, sr1, sr2, r; 1 / time unit

# Only used to draw a dotted marker at this x-position on the plots' time axis;
# si4r_model ignores it.  NOTE(review): originally documented as "step units"
# (which would be time 75 at h = 0.5) -- confirm the intended unit.
spike_in_timepoint = 150
end_time = 200.0  # total simulated time, same units as h (noted as days originally)

# Sample times matching the int(end_time / h) + 1 states returned by the model.
times = h * np.arange(int(end_time / h) + 1)

s, i1, i2, sr1, sr2, i1p, i2p, r, d = si4r_model(h,
                                                 transmission_coeff1,
                                                 transmission_coeff2,
                                                 infectious_time1,
                                                 infectious_time2,
                                                 virulence1,
                                                 virulence2,
                                                 fitness,
                                                 spike_in_timepoint,
                                                 end_time)


# Raw plot: one curve per model compartment, with a dotted marker at the
# spike-in time, saved to raw_results.pdf.

fig = plt.figure(1)
axes = plt.gca()

axes.plot(times, s, label='s')
axes.plot(times, i1, label='inf1')
axes.plot(times, i2, label='inf2')
axes.plot(times, sr1, label='sr1')
axes.plot(times, sr2, label='sr2')
axes.plot(times, i1p, label='i1p')
axes.plot(times, i2p, label='i2p')
axes.plot(times, r, label='r')
axes.plot(times, d, label='d')
axes.axvline(x=spike_in_timepoint, linestyle=':', color="black")

# The explicit label tuple is applied to the curves in plotting order.
axes.legend(('s', 'i1', 'i2', 'sr1', 'sr2', 'i1p', 'i2p', 'r', 'd'),
            loc='upper right')
axes.set_xlabel('Time units')
axes.set_ylabel('Number of cells')

fig.savefig("raw_results.pdf")

# Summary plot: infections of each strain pooled over first and second
# infections.  A fresh figure is used so the curves and spike-in marker of the
# raw-results plot (figure 1) are not redrawn here.
fig = plt.figure()

plt.plot(times, s, label='susceptible')
plt.plot(times, i1 + i1p, label='infected_with_preprogrammed')
plt.plot(times, i2 + i2p, label='infected_with_non_preprogrammed')
plt.plot(times, sr1, label='Immune_to_preprogrammed')
plt.plot(times, sr2, label='Immune_to_non_preprogrammed')
plt.plot(times, r, label='immune_to_both')
# Build the legend from the line labels so each name matches its own curve.
plt.legend(loc='upper right')
axes = plt.gca()
axes.set_xlabel('Time units')
axes.set_ylabel('Number of cells')

fig.savefig("compartments_both.pdf")

#### Plot of our color signature
# A fresh figure is used; drawing into figure 1 again would overlay these curves
# on the earlier plots and mislabel the legend.
fig = plt.figure()

# Each contributing compartment adds half its size to a signal.
# RFP: only s, i2 and sr2 contribute (i1, sr1, i1p, i2p and r add nothing).
rfp = (s + i2 + sr2) / 2
# GFP: every compartment except d contributes.
gfp = (s + i1 + i2 + sr1 + sr2 + i1p + i2p + r) / 2
plt.plot(times, gfp, 'g', label='GFP')
plt.plot(times, rfp, 'r', label='RFP')

# Legend built from the line labels so GFP/RFP always name their own curves.
plt.legend(loc='upper right')
axes = plt.gca()
axes.set_xlabel('Time units')
axes.set_ylabel('Color intensity')

fig.savefig("colors_both.pdf")