#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import numpy as np
import matplotlib.pyplot as plt
c = 299792458   # speed of light in vacuum, m/s (wavelengths below are in metres)
pi = np.pi      # short alias used in the formulas
verbose = 6     # print threshold: >5 prints summary fields, >8 prints per-z field dumps
def read_data(dirname, distance, zshift):
    """Load every monitor text file of a results directory.

    Files are read from ``<dirname>/mon_x<x>mkm_media<m>_zshift<z>nm.txt``
    for every ``x`` in *distance*, every medium ``m`` in (1, 2) and every
    ``z`` in *zshift*.  Each file has one header line followed by rows of
    9 columns separated by ``", "``; columns 2-8 may use an ``i`` imaginary
    unit (e.g. ``1+2i``), which is converted to Python's ``j``.

    Args:
        dirname: directory holding the monitor files.
        distance: monitor distances (non-negative ints; ``mkm`` in the file
            name).  They are used directly as list indices and need not be
            contiguous or start at 1.
        zshift: monitor z-shift tags exactly as they appear in the file
            names (``str``, or anything ``str()`` turns into that tag).

    Returns:
        Nested list ``data`` where ``data[x][m][k]`` is the transposed
        contents of the file for distance ``x``, medium ``m`` and
        ``zshift[k]``: a complex array of shape (n_columns, n_rows).
        Entries for indices not in *distance* (including 0) and medium
        index 0 are empty placeholder lists, so ``data[x][m]`` can be
        addressed by the physical labels.
    """
    import os

    # 1 - positive zshift, 2 - negative zshift (need to add a minus sign for
    # the real shift).  analyze() treats 1 as the air side, 2 as the gold side.
    media = [1, 2]

    def _parse_complex(s):
        # Accept the "1+2i" notation in addition to Python's "1+2j".
        return complex(s.replace('i', 'j'))

    # Columns 0 and 1 are parsed as-is; columns 2..8 may carry an 'i' unit.
    converters = {0: complex, 1: complex}
    converters.update({col: _parse_complex for col in range(2, 9)})

    distance = list(distance)
    # Pre-size so data[x] is valid for every x in distance, whatever the
    # spacing of the distances.
    data = [[] for _ in range(max(distance, default=0) + 1)]
    for x in distance:
        per_medium = [[] for _ in range(max(media) + 1)]
        for m in media:
            for z in zshift:
                monitor_name = ("mon_x" + str(x) + "mkm_media" + str(m)
                                + "_zshift" + str(z) + "nm")
                path = os.path.join(dirname, monitor_name + ".txt")
                table = np.genfromtxt(path, delimiter=", ", skip_header=1,
                                      dtype=None, encoding=None,
                                      converters=converters)
                # Transpose so that row j holds file column j.
                per_medium[m].append(np.transpose(table))
        data[x] = per_medium
    return data


def find_nearest(array, value):
    """Return the element of 1-D *array* closest to *value*, and its index.

    *array* may be a NumPy array or any sequence ``np.asarray`` accepts
    (e.g. a plain list); complex entries are compared by ``abs`` distance.
    On ties the first (lowest-index) match wins, as with ``argmin``.

    Raises:
        ValueError: if *array* is empty.
    """
    arr = np.asarray(array)
    idx = np.abs(arr - value).argmin()
    return arr[idx], idx


def get_WLs_idx(WLs, data, dist=1, mmedia=1, shift=1):
    """Return, for each wavelength in *WLs*, the index of the nearest sample.

    The wavelength grid is row 0 of ``data[dist][mmedia][shift]``; the
    requested wavelengths are multiplied by 1e-9 (nm -> m) before matching.

    Args:
        WLs: requested wavelengths in nm.
        data: nested list as returned by read_data().
        dist: distance index (mkm) of the monitor that supplies the grid.
        mmedia: medium index of that monitor (1 = vacuum/air side).
        shift: position in the zshift list used when reading *data*.

    Returns:
        List of sample indices, one per entry of *WLs*.
    """
    wavelength_grid = data[dist][mmedia][shift][0, :]
    return [find_nearest(wavelength_grid, wl*1e-9)[1] for wl in WLs]


def check_field_match(data_in_air, data_in_gold,wl_idx,z_vec,kappa1,kappa2,eps2):
    """Print air- and gold-side fields extrapolated back to the interface.

    For each monitor offset in *z_vec* (nm, converted to metres), the Hy
    (column 6) and Ez (column 4) samples at wavelength index *wl_idx* are
    divided by the decay factor exp(-kappa*z) of their medium.  The gold
    Ez is additionally multiplied by eps2, for comparison with the air Ez.
    Nothing is returned; output is printed only when the module-level
    ``verbose`` exceeds 8.
    """
    hy_air, hy_gold = data_in_air[:, 6, wl_idx], data_in_gold[:, 6, wl_idx]
    ez_air, ez_gold = data_in_air[:, 4, wl_idx], data_in_gold[:, 4, wl_idx]

    for i, z_nm in enumerate(z_vec):
        z = z_nm*1e-9
        if verbose > 8: print("z =", z)
        decay_air = np.exp(-kappa1[wl_idx]*z)
        decay_gold = np.exp(-kappa2[wl_idx]*z)
        h_air0 = hy_air[i]/decay_air
        h_gold0 = hy_gold[i]/decay_gold
        e_air0 = ez_air[i]/decay_air
        e_gold0 = ez_gold[i]/decay_gold
        e_gold0_eps = e_gold0*eps2[wl_idx]
        if verbose <= 8:
            continue
        print("H0 air  (%5.4g %+5.4gj)" % (np.real(h_air0), np.imag(h_air0)),
              " from H1 (%5.4g %+5.4gj)" % (np.real(hy_air[i]), np.imag(hy_air[i])))
        print("H0 gold (%5.4g %+5.4gj)" % (np.real(h_gold0), np.imag(h_gold0)),
              " from H2 (%5.4g %+5.4gj)" % (np.real(hy_gold[i]), np.imag(hy_gold[i])))
        print("E0 air  (%5.4g %+5.4gj)" % (np.real(e_air0), np.imag(e_air0)),
              " from E1 (%5.4g %+5.4gj)" % (np.real(ez_air[i]), np.imag(ez_air[i])))
        print("E0*eps2 (%5.4g %+5.4gj)" % (np.real(e_gold0_eps), np.imag(e_gold0_eps)),
              " from E2 (%5.4g %+5.4gj)" % (np.real(ez_gold[i]), np.imag(ez_gold[i])))
        print("E0 gold (%5.4g %+5.4gj)" % (np.real(e_gold0), np.imag(e_gold0)))


def analyze(data, dist, z_vec, wl_idx):
    ''' Estimate the launched SPP power and plot its ratio to the dipole power.

    dist in mkm!!!  It is used both as the index into ``data`` and, times
    1e-6, as the monitor radius R in metres.

    Args:
        data: nested list from read_data(); data[dist][1] holds the air-side
            monitors and data[dist][2] the gold-side monitors, one array of
            shape (n_columns, n_wavelengths) per z-shift.
        dist: monitor distance in micrometres.
        z_vec: monitor z-shifts in nm, in the same order as in ``data``.
        wl_idx: wavelength index used for the diagnostic prints.

    Side effects:
        Prints diagnostics and saves ``<dirname>_power_ratio.<file_ext>``
        using the module-level ``dirname`` and ``file_ext``.
    '''
    # data = [dist][mmedia][shift] "lambda, dip.power, Ex, Ey, Ez, Hx, Hy, Hz, n_Au"
    #                               0     , 1        , 2 , 3 , 4 , 5 , 6 , 7 , 8
    data_in_air = np.array(data[dist][1])    # (n_z, n_columns, n_wl)
    data_in_gold = np.array(data[dist][2])
    lambd = data_in_air[0][0, :]
    omega = 2*pi*c/lambd
    dip_power = data_in_air[0][1, :]

    # Poynting vector at the first z-shift and the first wavelength sample.
    E = data_in_air[0, 2:5, 0]
    H = data_in_air[0, 5:8, 0]
    print("S from full field", np.real(np.cross(E, np.conj(H))))

    eps1 = complex(1)
    n_Au = data_in_air[0][8, :]
    eps2 = n_Au**2

    k_0 = omega/c  # air
    k_spp = k_0*np.sqrt(eps1*eps2/(eps1+eps2))
    kappa1 = np.sqrt(k_spp**2 - eps1*k_0**2)
    kappa2 = np.sqrt(k_spp**2 - eps2*k_0**2)

    # Diagnostics at the analysed wavelength (was a hard-coded sample [9]).
    print("lambda, nm:", 1e9*lambd[wl_idx])
    print("1/kappa1 (air), nm:", 1e9/kappa1[wl_idx])
    print("1/kappa2 (gold), nm:", 1e9/kappa2[wl_idx])
    check_field_match(data_in_air, data_in_gold, wl_idx, z_vec, kappa1, kappa2, eps2)

    H1 = data_in_air[:, 6]   # Hy for every z-shift
    E1 = data_in_air[:, 4]   # Ez for every z-shift

    z = z_vec[0]*1e-9

    if verbose > 5: print("Using data from air monitor at z =", z)
    # Undo the exp(-kappa1*z) decay to get the air-side fields at the interface.
    decay1 = np.exp(-kappa1*z)
    H1_0 = H1[0]/decay1
    E1_0 = E1[0]/decay1
    # Normal D is continuous (eps1*E1z = eps2*E2z at z = 0), so the gold-side
    # interface Ez must come from the *interface* value E1_0, not from the raw
    # sample E1[0] taken z above the surface.
    E2_0 = eps1*E1_0/eps2
    if verbose > 5:
        print("H0 air  (%5.4g %+5.4gj)"%(np.real(H1_0[wl_idx]), np.imag(H1_0[wl_idx])),
              " from H1 (%5.4g %+5.4gj)"%(np.real(H1[0][wl_idx]), np.imag(H1[0][wl_idx])))
        print("E0 air  (%5.4g %+5.4gj)"%(np.real(E1_0[wl_idx]), np.imag(E1_0[wl_idx])),
              " from E1 (%5.4g %+5.4gj)"%(np.real(E1[0][wl_idx]), np.imag(E1[0][wl_idx])))
        print("E0 gold (%5.4g %+5.4gj)"%(np.real(E2_0[wl_idx]), np.imag(E2_0[wl_idx])), " from E1")

    R = dist*1e-6
    print("R =", R)
    # Flux through a cylinder of radius R: -1/2 Re(Ez Hy*) (TODO check minus
    # sign; it assumes Ey = Hz = 0 on the monitor) integrated over z, using
    # int |exp(-kappa z)|^2 dz = 1/(2 Re kappa) on each side of the interface,
    # times the circumference 2*pi*R, times exp(2 Im(k_spp) R) to undo the
    # propagation loss between the source and the monitor (TODO check sign).
    plasmon_power = -1.0/2.0 * 2.0*np.pi*R * (
        np.real(E1_0*np.conj(H1_0)) / (2.0*np.real(kappa1))
        + np.real(E2_0*np.conj(H1_0)) / (2.0*np.real(kappa2))
    ) * np.exp(2.0*np.imag(k_spp)*R)

    # Summary at the first wavelength sample (index 0).
    eta0 = plasmon_power[0]/dip_power[0]*100
    ppw = plasmon_power[0]
    print("\n")
    print(dirname)
    print("Power: plasmon %4.3g W of dipoles %4.3g W, efficiency %5.3g%%  from:"%(ppw, float(np.abs(dip_power[0])),float(np.abs( eta0))), ppw, eta0)

    # The columns are read as complex; plot the real parts explicitly instead
    # of letting matplotlib discard the imaginary part with a ComplexWarning.
    plt.plot(np.real(lambd)*1e9, np.real(plasmon_power/dip_power))
    plt.ylim(0, 0.04)
    plt.xlim(550, 800)
    # Alternative quantities to inspect:
    # plt.plot(lambd*1e9, np.real(eps2))
    # plt.plot(lambd*1e9, np.real(k_spp))
    # plt.semilogy(lambd*1e9, np.absolute(plasmon_power/ dip_power))
    plt.xlabel(r'$\lambda$, nm')
    plt.ylabel(r'$P_{spp}/P_{dipole}$', labelpad=-5)
    plt.savefig(dirname+"_power_ratio."+file_ext)
    plt.clf()
    plt.close()
    
# Extension (and hence format) of the saved figure files.
file_ext = "pdf"

# Results directory to analyse.  Earlier runs are kept for quick switching:
# comment out the active assignment and uncomment one of these instead.
#dirname="template-dipole-on-sphere-on-surf-z.fsp.results"
#dirname="Au-JC-R100-Au-JC.fsp.results"
#dirname="Au-McPeak-R100-Si-Green.fsp.results"
#dirname="Au-McPeak-R100-Au-McPeak.fsp.results"
#dirname="sub-Au-R100-Si-wl450-800-sep10nm.fsp.results"
#dirname="bg-Au-sub-Au-dipole-W.fsp.results"
#dirname="bg-Au-sub-Si-dipole-W.fsp.results"
#dirname="Au-McPeak-R0.fsp.results"
#dirname="Au-McPeak-R100-Si-Green-1500.fsp.results"
#dirname="Au-McPeak-R100-Si-Green-1500-l.fsp.results"
#dirname="Au-McPeak-R50-Si-Green-1500-l.fsp.results"
#dirname="Au-sub-dipole.fsp.results"
#dirname="Au-sub-dipole-W.fsp.results"
#dirname="Au-sub-Au-dipole-W.fsp.results"
#dirname="Au-sub-Si-dipole-W.fsp.results"
dirname = "bg-Au-sub-dipole-W.fsp.results"
def main():
    """Run the SPP power analysis for the module-level ``dirname``.

    Loads every monitor file, finds the sample nearest to the requested
    wavelength and runs analyze() for the monitor at ``dist`` micrometres.
    """
    distance = list(range(1, 11))  # monitor distances, mkm
    zshift = ["5", "20"]           # monitor z-shifts as written in file names, nm
    # zshift = ["5","20","200","400","600"]
    z_vec = list(map(int, zshift))

    data = read_data(dirname, distance, zshift)

    # Other wavelength sets (nm) that were used:
    #WLs=[300,350,400,450,600,700,800]
    #WLs=[600,700, 800, 450]
    WLs = [800]  # ,1500]#, 450]
    WLs_idx = get_WLs_idx(WLs, data)

    dist = 10  # mkm
    analyze(data, dist, z_vec, WLs_idx[0])

    # Disabled: |Ex|*sqrt(R) versus monitor distance (presumably to check the
    # 1/sqrt(R) decay of the surface wave).
    # legend = []
    # mmedia = 1
    # for shift in range(len(zshift)):
    #     for i in range(len(WLs)):
    #         pl_data = []
    #         idx = WLs_idx[i]
    #         legend.append(zshift[shift]+"@"+str(WLs[i])+" nm")
    #         for dist in distance:
    #             pl_data.append(np.absolute(data[dist][mmedia][shift][2,idx]*np.sqrt(dist)))
    #         print(len(pl_data))
    #         plt.semilogy(distance, pl_data,marker="o")
    # plt.legend(legend)
    # # #plt.xlabel(r'THz')
    # plt.xlabel(r'Monitor R, $\mu$m')
    # plt.ylabel(r'$Abs(E_x) \sqrt{R}$',labelpad=-5)
    # # plt.title(' r = '+str(core_r))
    # plt.savefig(dirname+"_WLs."+file_ext)
    # plt.clf()
    # plt.close()
# Run the analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()