import snlay as sn 
import numpy as np 
from scipy import interpolate
from noise import pnoise1
import random
from tmm import coh_tmm
import multiprocessing as mp

#import pyximport; pyximport.install(pyimport=True)
#import mtmm as mtmm



# Materials dictionary: integer material id -> path of a tabulated
# optical-constant file. Ids start at 1 and follow the order listed below.
matsdict = dict(enumerate(
    (
        './materials/gold.dat',
        './materials/silicon.dat',
        './materials/silica.dat',
        './materials/tio2.dat',
        './materials/silver.dat',
    ),
    start=1,
))

def get_nk(datafile, wavelengths):
    """Return the complex refractive index n + ik interpolated from a table.

    ``datafile`` is read with ``np.loadtxt`` (whitespace-separated columns).
    The first three columns are taken as wavelength, n and k. Both n and k
    are interpolated with ``interp1d`` (linear by default) at
    ``wavelengths``. The wavelengths must use the same unit as the file's
    first column and lie inside its tabulated range; ``interp1d`` raises
    ValueError otherwise.

    :datafile: path to a table with at least three columns
    :wavelengths: scalar or array of wavelengths to evaluate at
    :returns: complex n + ik with the shape of ``wavelengths``
    """
    table = np.loadtxt(datafile)
    lam_col = table[:, 0]
    n_of_lam = interpolate.interp1d(lam_col, table[:, 1])
    k_of_lam = interpolate.interp1d(lam_col, table[:, 2])
    n_part = n_of_lam(wavelengths)
    k_part = k_of_lam(wavelengths)
    return n_part + 1j*k_part

def nk_2_eps(n):
    """Convert a refractive index n + ik to the complex permittivity.

    Computes eps = (n + ik)**2 = (n**2 - k**2) + i*(2*n*k). The result is
    assembled from the real and imaginary parts, so purely real input still
    yields a complex result.

    :n: refractive index, scalar or ndarray, real or complex
    :returns: complex permittivity eps_r + i*eps_i
    """
    re_n = n.real
    im_n = n.imag
    return (re_n**2 - im_n**2) + 1j*(2*re_n*im_n)

def eps_2_nk(eps):
    """Convert a complex permittivity to the complex refractive index n + ik.

    Inverse of ``nk_2_eps``: returns the square root of ``eps`` with n >= 0,
    evaluated component-wise as

        n = sqrt((|eps| + eps_r) / 2),  k = +/- sqrt((|eps| - eps_r) / 2),

    where k takes the sign of Im(eps). The sign is negative only when
    Im(eps) < 0. Real input, including negative values, is accepted because
    |eps| >= |eps_r| keeps both radicands non-negative.

    :eps: permittivity, scalar or ndarray, real or complex
    :returns: complex refractive index n + ik
    """
    modeps = np.abs(eps)
    n_r = np.sqrt(0.5*(modeps + eps.real))
    n_i = np.sqrt(0.5*(modeps - eps.real))
    # Both square roots are non-negative, which drops the sign of k; restore
    # it so that eps with Im(eps) < 0 maps back to k < 0.
    n_i = np.where(eps.imag < 0, -n_i, n_i)
    return n_r + 1j*n_i

def LL_mixing(fH, n_H, n_L):
    """Effective refractive index of a two-phase mixture (Lorentz-Lorenz rule).

    Each constituent's factor (eps - 1)/(eps + 2) is weighted by its volume
    fraction (fH for the high-index phase, 1 - fH for the low-index phase).
    The sum K is turned back into a permittivity via
    eps_eff = (1 + 2K)/(1 - K) and then into a complex index.

    :fH: volume fraction of the higher-index material, from 0 to 1
    :n_H: refractive index of the higher-index material
    :n_L: refractive index of the lower-index material
    :returns: complex effective refractive index n + ik of the layer
    """
    e_high, e_low = nk_2_eps(n_H), nk_2_eps(n_L)
    big_k = fH*(e_high - 1)/(e_high + 2) + (1 - fH)*(e_low - 1)/(e_low + 2)
    return eps_2_nk((1 + 2*big_k)/(1 - big_k))

def bet01(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    :x: non-empty array-like of real numbers
    :returns: float ndarray of the same shape. A constant input gives all
        zeros instead of the NaNs a 0/0 division would produce.
    :raises ValueError: if ``x`` is empty (from the min reduction)
    """
    shifted = np.asarray(x) - np.amin(x)
    span = np.amax(shifted)
    if span == 0:
        # Constant input: there is no range to normalise by.
        return np.zeros(shifted.shape)
    return shifted / span

def make_qx(num_layers):
    """Generate a random q_x array with values between 0 and 1.

    Two candidate profiles are built:

    * a random walk with step sizes drawn from U(1, 10), rescaled to [0, 1]
      with ``bet01``. Each step goes up with probability 1 - p_motion, where
      p_motion ~ U(0, 1) is drawn once per call.
    * a smooth curve from quadratic interpolation through randomly many
      control points spread over [-1, num_layers + 1]. The control-point
      heights come, with equal chance, either from a clipped normal
      distribution or from clipped 1-D Perlin noise.

    With probability 0.8 the smooth curve (clipped to [0, 1]) is returned,
    otherwise the random walk. Both Python's ``random`` and ``np.random``
    are used, so seeding only one of them does not make the output
    reproducible.

    :num_layers: number of layers (length of the returned array), >= 1
    :returns: float ndarray of shape (num_layers,)
    :raises ValueError: if ``num_layers`` < 1
    """
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")

    uni = np.arange(num_layers, dtype=float)

    # Random walk starting at 0.
    rwalk = np.zeros(num_layers)
    p_motion = random.random()
    pos_counter = 0.0
    for i in range(1, num_layers):
        step_up = random.random() >= p_motion
        step = np.random.uniform(1, 10)
        pos_counter += step if step_up else -step
        rwalk[i] = pos_counter
    rwalk = bet01(rwalk)

    # Number of control points. Quadratic interpolation needs at least 3;
    # for small num_layers the random draw can fall below that.
    samps = int(np.random.uniform(int(num_layers/10) + 2, int(9*num_layers/10)))
    samps = max(samps, 3)
    scales = np.random.uniform(0.05, 0.75)

    s = np.linspace(-1, num_layers + 1, samps)
    fr = np.clip(np.random.normal(loc=0.5, scale=scales, size=samps), 0, 1)
    fr_p = np.zeros_like(s)
    for ctr, x in enumerate(s):
        # pnoise1 takes an integer octave count; draw 1..7 directly instead of
        # passing a float array.
        octa = int(np.random.randint(1, 8))
        pers = np.random.uniform(0.5, 2.2)
        fr_p[ctr] = np.clip(0.5 + pnoise1(x, octaves=octa, persistence=pers), 0, 1)

    heights = fr if np.random.randint(2) else fr_p
    f = interpolate.interp1d(s, heights, kind="quadratic", assume_sorted=True)

    if random.random() >= 0.20:
        rwalk = np.clip(f(uni), 0, 1)

    return rwalk



def tmm_eval2(
        qx,
        cthick,
        lam_low=400, #in nm
        lam_high=1200,
        lam_pts=100,
        ang_low=0,   #degrees
        ang_high=90,
        ang_pts=25,
        n_subs=1.52,
        ):
    """Reflectance map of the layer stack over incidence angle and wavelength.

    :qx: relative layer thicknesses (make_nxdx normalises them to a unit total)
    :cthick: total stack thickness, in the same unit as the wavelengths
    :lam_low: shortest vacuum wavelength of the sweep
    :lam_high: longest vacuum wavelength of the sweep
    :lam_pts: number of evenly spaced wavelengths
    :ang_low: smallest incidence angle, in degrees
    :ang_high: largest incidence angle, in degrees
    :ang_pts: number of evenly spaced angles
    :n_subs: refractive index of the substrate (exit medium)
    :returns: (Rs, Rp), reflectance in percent for s and p polarisation,
        each of shape (ang_pts, lam_pts). Row 0 corresponds to ang_high and
        the last row to ang_low.
    """
    degree = np.pi/180
    lams = np.linspace(lam_low, lam_high, endpoint=True, num=lam_pts)
    thetas = np.linspace(ang_high, ang_low, endpoint=True, num=ang_pts)
    Rs = np.zeros((thetas.size, lams.size))
    Rp = np.zeros((thetas.size, lams.size))

    # The stack does not depend on angle or wavelength, so build it once.
    # make_nxdx returns layer thicknesses summing to 1; scale them to cthick.
    d_unit, n_x = make_nxdx(qx=qx, n_substrate=n_subs)
    d_x = [d_unit[0]] + [cthick*d for d in d_unit[1:-1]] + [d_unit[-1]]

    # NOTE(review): assumes coh_tmm returns the reflectance as a number; the
    # upstream `tmm` package returns a dict (reflectance under 'R') -- confirm
    # which `tmm` is on the path.
    for tidx, theta in enumerate(thetas):
        th_0 = theta*degree
        for lidx, lam in enumerate(lams):
            Rs[tidx, lidx] = 100*coh_tmm('s', n_x, d_x, th_0=th_0, lam_vac=lam)
            Rp[tidx, lidx] = 100*coh_tmm('p', n_x, d_x, th_0=th_0, lam_vac=lam)
    return Rs, Rp


def tmm_eval(
        qx,
        cthick,
        lam_low=400, #in nm
        lam_high=1200,
        lam_pts=100,
        ang_low=0,   #degrees
        ang_high=90,
        ang_pts=25,
        n_subs=1.52,
        ):
    """Reflectance map of the layer stack over incidence angle and wavelength.

    :qx: relative layer thicknesses (make_nxdx normalises them to a unit total)
    :cthick: total stack thickness, in the same unit as the wavelengths
    :lam_low: shortest vacuum wavelength of the sweep
    :lam_high: longest vacuum wavelength of the sweep
    :lam_pts: number of evenly spaced wavelengths
    :ang_low: smallest incidence angle, in degrees
    :ang_high: largest incidence angle, in degrees
    :ang_pts: number of evenly spaced angles
    :n_subs: refractive index of the substrate (exit medium)
    :returns: (Rs, Rp), reflectance in percent for s and p polarisation,
        each of shape (ang_pts, lam_pts). Row 0 corresponds to ang_high and
        the last row to ang_low.
    """
    degree = np.pi/180
    lams = np.linspace(lam_low, lam_high, endpoint=True, num=lam_pts)
    thetas = np.linspace(ang_high, ang_low, endpoint=True, num=ang_pts)
    Rs = np.zeros((thetas.size, lams.size))
    Rp = np.zeros((thetas.size, lams.size))

    # The stack does not depend on angle or wavelength, so build it once.
    # make_nxdx returns layer thicknesses summing to 1; scale them to cthick.
    d_unit, n_x = make_nxdx(qx=qx, n_substrate=n_subs)
    d_x = [d_unit[0]] + [cthick*d for d in d_unit[1:-1]] + [d_unit[-1]]

    # NOTE(review): assumes coh_tmm returns the reflectance as a number; the
    # upstream `tmm` package returns a dict (reflectance under 'R') -- confirm
    # which `tmm` is on the path.
    for tidx, theta in enumerate(thetas):
        th_0 = theta*degree
        for lidx, lam in enumerate(lams):
            Rs[tidx, lidx] = 100*coh_tmm('s', n_x, d_x, th_0=th_0, lam_vac=lam)
            Rp[tidx, lidx] = 100*coh_tmm('p', n_x, d_x, th_0=th_0, lam_vac=lam)
    return Rs, Rp

def tmm_eval_wbk(
        qx,
        cthick,
        inc_ang,
        lam,
        n_subs=1.52
        ):
    """s-polarised reflectance of the layer stack at one angle and wavelength.

    :qx: relative layer thicknesses (make_nxdx normalises them to a unit total)
    :cthick: total stack thickness, in the same unit as ``lam``
    :inc_ang: angle of incidence, in degrees
    :lam: vacuum wavelength
    :n_subs: refractive index of the substrate (exit medium)
    :returns: s-polarised reflectance in percent
    """
    degree = np.pi/180
    # make_nxdx returns layer thicknesses summing to 1; scale them to cthick.
    d_unit, n_x = make_nxdx(qx=qx, n_substrate=n_subs)
    d_x = [d_unit[0]] + [cthick*d for d in d_unit[1:-1]] + [d_unit[-1]]
    # The Cython `mtmm` module is not imported in this file, so use the same
    # coh_tmm as the other evaluators.
    Rs = 100*coh_tmm('s', n_x, d_x, th_0=inc_ang*degree, lam_vac=lam)
    return Rs
    
def digitize_qx(q_x, dlevels=2):
    """Quantise values in [0, 1] onto ``dlevels`` evenly spaced levels.

    [0, 1] is split into ``dlevels`` equal bins, each closed on the right. A
    value in bin i is replaced by the i-th of ``dlevels`` evenly spaced
    levels from 0 to 1 inclusive. Values outside the widened outer edges
    [-0.02, 1.01] are clamped to the first/last level.

    :q_x: scalar or array of values, nominally in [0, 1]
    :dlevels: number of output levels (>= 1)
    :returns: quantised values with the shape of ``q_x``
    """
    bins = np.linspace(0, 1, num=dlevels + 1, endpoint=True)
    # With right=True a bin is (lo, hi], so the lower edge must sit below 0
    # for exactly 0 to land in the first bin. The upper edge is likewise
    # widened slightly past 1.
    bins[0] = -0.02
    bins[-1] = 1.01
    dig = np.digitize(q_x, bins, right=True) - 1
    # Out-of-range values would give -1 (wrapping to the top level) or dlevels
    # (IndexError); clamp them onto the outermost levels instead.
    dig = np.clip(dig, 0, dlevels - 1)
    levels = np.linspace(0, 1, num=dlevels, endpoint=True)
    return levels[dig]


def make_nxdx(qx, n_substrate=1.52, cthick=None, wavelen=None):
    """Build the thickness and refractive-index lists for a layered stack.

    The stack is: ambient (n = 1.0, semi-infinite) | layers | substrate
    (n = n_substrate, semi-infinite). The layers alternate between a high
    index (2.58) at even positions, starting with the first layer, and a low
    index (1.45) at odd positions. ``qx`` only sets their relative
    thicknesses.

    :qx: 1-D array of non-negative relative layer thicknesses
    :n_substrate: refractive index of the exit (substrate) medium
    :cthick: optional total stack thickness. When given, the layer
        thicknesses sum to it; otherwise they sum to 1.
    :wavelen: accepted for call compatibility. Currently unused because the
        layer indices are wavelength-independent constants.
    :returns: (d_x, n_x), lists of length len(qx) + 2 with np.inf as the
        thickness of the two semi-infinite media
    :raises ValueError: if ``qx`` is non-empty but its thicknesses sum to 0
    """
    qx = np.asarray(qx, dtype=float)
    total = np.sum(qx)
    if qx.size and total == 0:
        # Normalising would divide 0/0 and hand NaN thicknesses to the solver.
        raise ValueError("qx must contain at least one non-zero thickness")
    scale = 1.0 if cthick is None else cthick
    layer_d = (qx/total)*scale
    d_x = [np.inf] + layer_d.tolist() + [np.inf]

    # 1.0 marks a high-index layer, 0.0 a low-index layer.
    is_high = np.zeros_like(qx)
    is_high[::2] = 1.0
    sde = 1.45 + (2.58 - 1.45)*is_high
    n_x = [1.0] + sde.tolist() + [n_substrate]
    return d_x, n_x

def vgdr_eval_wsweep(
        qx,
        inc_ang,
        lam_low=0.25, # in units of the total stack thickness
        lam_high=1,
        lam_pts=256,
        n_subs=1.52,
        ):
    """Worst-case octave-averaged s-reflectance over a wavelength sweep.

    The s-polarised reflectance Rs (percent) is evaluated at ``lam_pts``
    wavelengths spaced evenly in 1/lambda between ``lam_low`` and
    ``lam_high``. For every sampled ``lam <= lam_high/2``, Rs is averaged
    over the window [lam, 2*lam], and the smallest of these means is
    returned. Wavelengths share the unit of make_nxdx's layer thicknesses,
    which sum to 1, i.e. they are in units of the total stack thickness.

    :qx: relative layer thicknesses
    :inc_ang: angle of incidence, in degrees
    :lam_low: shortest wavelength of the sweep
    :lam_high: longest wavelength of the sweep
    :lam_pts: number of sampled wavelengths
    :n_subs: refractive index of the substrate (exit medium)
    :returns: minimum window-averaged reflectance in percent (100 if no
        window fits inside the sweep)
    """
    degree = np.pi/180
    lam_inv = np.linspace(1/lam_low, 1/lam_high, num=lam_pts, endpoint=True)
    lams = 1.0/lam_inv

    # The stack does not depend on wavelength, so build it once.
    d_x, n_x = make_nxdx(qx=qx, n_substrate=n_subs)
    Rs = np.zeros(lams.size)
    for lidx, lam in enumerate(lams):
        Rs[lidx] = 100*coh_tmm('s', n_x, d_x, th_0=inc_ang*degree, lam_vac=lam)

    # 100 is the placeholder for start wavelengths whose octave would run
    # past lam_high.
    vgdr = 100*np.ones(lam_pts)
    for idx, lam in enumerate(lams):
        if lam <= lam_high/2.0:
            window = np.logical_and(lams >= lam, lams <= 2*lam)
            vgdr[idx] = np.mean(Rs[window])

    return np.amin(vgdr)
    
def vgdr2_eval_wsweep(
        qx,
        inc_ang,
        lam_low=0.25, # in units of the total stack thickness
        lam_high=2,
        lam_pts=256,
        n_subs=1.52,
        ):
    """Worst octave-averaged s-reflectance and where it occurs.

    Same evaluation as ``vgdr_eval_wsweep``: Rs (percent) is sampled evenly
    in 1/lambda between ``lam_low`` and ``lam_high`` and averaged over every
    window [lam, 2*lam] with ``lam <= lam_high/2``. Using lam_high/2 rather
    than a fixed 1 keeps every window inside the sweep for any lam_high; the
    two are identical for the default lam_high=2.

    :qx: relative layer thicknesses
    :inc_ang: angle of incidence, in degrees
    :lam_low: shortest wavelength of the sweep
    :lam_high: longest wavelength of the sweep
    :lam_pts: number of sampled wavelengths
    :n_subs: refractive index of the substrate (exit medium)
    :returns: (min_mean_R, 400/lam_min), where lam_min is the start of the
        worst window. NOTE(review): the second value is presumably the total
        thickness (nm) that puts that window's start at 400 nm -- confirm.
    """
    degree = np.pi/180
    lam_inv = np.linspace(1/lam_low, 1/lam_high, num=lam_pts, endpoint=True)
    lams = 1.0/lam_inv

    # The stack does not depend on wavelength, so build it once.
    d_x, n_x = make_nxdx(qx=qx, n_substrate=n_subs)
    Rs = np.zeros(lams.size)
    for lidx, lam in enumerate(lams):
        Rs[lidx] = 100*coh_tmm('s', n_x, d_x, th_0=inc_ang*degree, lam_vac=lam)

    # 100 is the placeholder for start wavelengths whose octave would run
    # past lam_high.
    vgdr = 100*np.ones(lam_pts)
    for idx, lam in enumerate(lams):
        if lam <= lam_high/2.0:
            window = np.logical_and(lams >= lam, lams <= 2*lam)
            vgdr[idx] = np.mean(Rs[window])

    return np.amin(vgdr), 400/lams[np.argmin(vgdr)]


def tmm_eval_wsweep(
        qx,
        inc_ang,
        lam_low=0.1, # in units of the total stack thickness
        lam_high=10,
        lam_pts=100,
        n_subs=1.52,
        ):
    """s-polarised reflectance spectrum of the stack at one incidence angle.

    Wavelengths are sampled evenly in 1/lambda between ``lam_low`` and
    ``lam_high``. They share the unit of make_nxdx's layer thicknesses,
    which sum to 1, i.e. they are in units of the total stack thickness.

    :qx: relative layer thicknesses
    :inc_ang: angle of incidence, in degrees
    :lam_low: shortest wavelength of the sweep
    :lam_high: longest wavelength of the sweep
    :lam_pts: number of sampled wavelengths
    :n_subs: refractive index of the substrate (exit medium)
    :returns: ndarray of shape (lam_pts,) with Rs in percent, ordered from
        lam_low to lam_high
    """
    degree = np.pi/180
    lam_inv = np.linspace(1/lam_low, 1/lam_high, num=lam_pts, endpoint=True)
    lams = 1.0/lam_inv

    # The stack does not depend on wavelength, so build it once.
    d_x, n_x = make_nxdx(qx=qx, n_substrate=n_subs)
    Rs = np.zeros(lams.size)
    for lidx, lam in enumerate(lams):
        Rs[lidx] = 100*coh_tmm('s', n_x, d_x, th_0=inc_ang*degree, lam_vac=lam)
    return Rs
    
    
# def tmm_wrapper2(arg):
#     args, kwargs = arg
#     return tmm_eval_wbk(*args, **kwargs)

# def tmm_lam_parallel(
#     q_x, 
#     cthick,
#     inc_ang, 
#     n_par=12,
#     lam_low=400, 
#     lam_high=1200, 
#     lam_pts=100, 
#     **kwargs):
    
#     jobs = []
#     pool=mp.Pool(n_par)
#     lams = np.linspace(lam_low, lam_high, endpoint=True, num=lam_pts)
#     for lam in lams:
#         jobs.append((q_x, cthick, inc_ang, lam))
#     arg = [(j, kwargs) for j in jobs]
#     answ = np.array(pool.map(tmm_wrapper2, arg))
#     pool.close()
#     return answ[:,0], answ[:,1] 
    
    

# def tmm_wrapper(arg):
#     args, kwargs = arg
#     return tmm_eval_wsweep(*args, **kwargs)

# def tmm_eval_parallel(q_x, cthick, n_ang= 25, n_par=10, **kwargs):
#     jobs = []
#     pool=mp.Pool(n_par)
#     angs = np.linspace(90, 0, endpoint=True, num=n_ang)
#     for ang in angs:
#         jobs.append((q_x, cthick, ang))
#     arg = [(j, kwargs) for j in jobs]
#     answ = np.array(pool.map(tmm_wrapper, arg))
#     pool.close()
#     return answ[:,0,:], answ[:,1,:]