#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import numpy as np
from matplotlib import patches
from matplotlib.path import Path
# Default radius; the plotting loop further down reassigns `r` for every panel.
r = 115

# Dataset suffixes, one input file "Fig3<suffix>.txt" each.
# The order matters: the plotting code places dataset number i at
# grid position (row i % 3, column i // 3), so every group of three
# suffixes below fills one column of the 3x4 figure from top to bottom.
to_plot = ["a", "e", "i",
           "b", "f", "j",
           "c", "g", "k",
           "d", "h", "l"]

# Read every file and keep only the last entry of the transposed table
# (the last column, for a multi-column text file).
data = [
    np.transpose(np.loadtxt("Fig3%s.txt" % suffix))[-1]
    for suffix in to_plot
]

# Width (in pixels) of the zero border added on every side of each map.
space = 25
px = 190 - space

for idx, column in enumerate(data):
    # Each dataset is a flattened square map: recover its side length.
    side = int(np.sqrt(len(column)))
    grid = np.reshape(column, (side, side))
    # Surround the map with a `space`-pixel frame of zeros.
    padded = np.pad(grid, space, mode="constant", constant_values=0.0)
    # Reference formula noted by the author:
    #   Re(\epsilon) = 3.681^2 - 2.7 * 10^{-21} * N_e (N_e in cm^{-3}).
    # Only the density-dependent term is kept, i.e. the change of Re(\epsilon).
    # NOTE(review): presumably the loaded values are densities in units of
    # 5e22 cm^-3 -- confirm against the data files.
    data[idx] = -2.7 * padded.T * 5e22 / 1e21

nm_scale = 2  # nm per pixel

# Physical axis coordinates (nm), symmetric about zero.
half_extent = (250 - px) * nm_scale
n_samples = 500 - 2 * px
scale_x = np.linspace(-half_extent, half_extent, n_samples)
scale_y = np.linspace(-half_extent, half_extent, n_samples)

    
import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib import cm

import matplotlib

# Base font size; labels elsewhere in the script are sized relative to it.
my_fontsize = 8

# Match Overleaf template font for Nanoscale journal: STIX math font.
# (A custom 'TeX Gyre Termes' mathtext set was tried as an alternative.)
matplotlib.rcParams.update({
    'font.size': my_fontsize,
    'mathtext.fontset': 'stix',
})

# 3 rows x 4 columns of panels, each with its own axes (nothing shared).
fig, axs = plt.subplots(nrows=3, ncols=4)

# Colour of spines, ticks, outlines and panel letters drawn over the maps.
# Use "black" instead for light colour maps.
axis_color = "white"

# Radius used for the dashed outline in each grid row, top to bottom.
row_radii = (75, 100, 115)

for idx, field in enumerate(data):
    # Datasets fill the grid column by column, three rows per column.
    col, row = divmod(idx, 3)
    r = row_radii[row]
    if col == 0:
        # First-column maps are multiplied by 100; the colorbar label
        # below carries the matching "x 100".
        field = field * 100
        data[idx] = field
    print(r)
    ax = axs[row][col]

    # The colour scale spans from a fraction of the map minimum up to zero.
    # The fraction is hand-picked per column, with a special case for the
    # bottom panel of the third column.
    if row == 2 and col == 2:
        clip = 0.29
    elif col == 0:
        clip = 0.85
    elif col == 2:
        clip = 0.55
    else:
        clip = 0.65
    max_tick = 0
    min_tick = np.amin(field) * clip

    # Large limits are truncated to whole numbers so integer labels fit.
    if max_tick > 10:
        max_tick = int(max_tick)
    if min_tick < -10:
        min_tick = int(min_tick)
    scale_ticks = np.linspace(min_tick, max_tick, 2)

    image = ax.imshow(
        field,
        interpolation='none',
        cmap=cm.hot_r,
        vmin=min_tick,
        vmax=max_tick,
        extent=(min(scale_x), max(scale_x), min(scale_y), max(scale_y)),
        aspect='equal',
    )

    # Vertical colorbar with ticks only at the two ends of the scale.
    cbar = fig.colorbar(image, ticks=list(scale_ticks), ax=ax,
                        fraction=0.042, pad=0.04)
    if max_tick > 10 or min_tick <= -10:
        tick_text = ['%i' % (int(t)) for t in scale_ticks]
    else:
        tick_text = ['%2.1f' % (t) for t in scale_ticks]
    cbar.ax.set_yticklabels(tick_text, va="center", ha="left")

    # Label padding depends on how wide the tick labels are.
    if max_tick > 100:
        lp_cbar = -4.7
    elif max_tick > 10:
        lp_cbar = -1.5
    else:
        lp_cbar = -5
    if col == 0:
        cbar_label = r'$\Delta \mathrm{Re}(\epsilon) \times 100$'
    else:
        cbar_label = r'$\Delta \mathrm{Re}(\epsilon)$'
    cbar.ax.set_ylabel(cbar_label, labelpad=lp_cbar,
                       fontsize=my_fontsize+1, rotation=90+180)

    # Frame and tick marks are drawn in the overlay colour.
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color(axis_color)
    ax.tick_params(axis='both', color=axis_color, width=1.2)

    # Only the bottom row (r == 115) keeps its X axis label and tick labels.
    lp1 = 1.0
    if r != 115:
        ax.set_xticklabels([])
    else:
        ax.set_xlabel(r'$Z,\rm nm$', labelpad=lp1, fontsize=my_fontsize+2)

    ax.axis("image")

    # Dashed circle of radius r, centred at (1, -1) in axis units.
    outline = patches.Arc((1., -1.), 2.0 * r, 2.0 * r, angle=0.0, zorder=1.8,
                          theta1=0.0, theta2=360.0, linewidth=1.3,
                          color=axis_color,
                          linestyle='--')
    ax.add_patch(outline)

# axs[0][0].annotate('',
#             xy=(-0.45, 1.2), xycoords='axes fraction',
#             xytext=(-0.45, 0.9), textcoords='offset points',
#             size=10,
#             color="green",
#             # bbox=dict(boxstyle="round", fc="0.8"),
#             arrowprops=dict(arrowstyle="<->",
#                             #fc="0.6", ec="none",
#                             #patchB=el,
#                             ))
# Headings drawn outside the panels, all positioned in axes-fraction
# coordinates of the panel they are attached to.
corner = axs[0][0]
heading_style = dict(xycoords='axes fraction', fontsize=my_fontsize+4,
                     color="black")

# "E" marker with a vertical double-headed arrow, left of the top-left panel.
# The arrow annotations pass the (empty) text positionally: the keyword was
# called `s` in old Matplotlib and `text` since 3.3, so naming it either way
# breaks on some releases.
corner.annotate(r'${E}$', xy=(-0.48, 1.05),
                horizontalalignment='center', verticalalignment='bottom',
                **heading_style)
corner.annotate('', xy=(-0.38, 1.24),
                xytext=(-0.38, 0.84),
                size=12, xycoords='axes fraction',
                arrowprops=dict(arrowstyle='<->', linewidth=1.5))

# "k" marker with a horizontal single-headed arrow next to it.
corner.annotate(r'${k}$', xy=(-0.2, 1.05),
                horizontalalignment='center', verticalalignment='bottom',
                **heading_style)
corner.annotate('', xy=(-0.36, 1.04),
                xytext=(-0.04, 1.04),
                size=12, xycoords='axes fraction',
                arrowprops=dict(arrowstyle='<-', linewidth=1.5))

# Row headings (radius), rotated and placed left of the first column.
row_captions = (r'$R=\rm 75\;nm$', r'$R=\rm 100\; nm$', r'$R=\rm 115\; nm$')
for row, caption in enumerate(row_captions):
    axs[row][0].annotate(caption, xy=(-0.53, 0.5),
                         horizontalalignment='left', verticalalignment='center',
                         rotation=90, **heading_style)

# Column headings (stage number), placed above the first row.
column_captions = (r'$\rm Stage\;1$', r'$\rm Stage\;2$',
                   r'$\rm Stage\;3$', r'$\rm Stage\;4$')
for col, caption in enumerate(column_captions):
    axs[0][col].annotate(caption, xy=(0.5, 1.2),
                         horizontalalignment='center', verticalalignment='top',
                         **heading_style)
# axs[0][0].annotate(r'$\rm Stage 1$', xy=(0.5, -0.0), xycoords='axes fraction', fontsize=my_fontsize+4,
#                        horizontalalignment='center', verticalalignment='top', color="white")

# Panel letters (a)-(l) in the top-left corner of every panel, running
# left to right along each row.
# NOTE(review): `r` only ever holds 75, 100 or 115 in this script, so this
# guard is always true; it looks like a leftover from per-radius figures
# (the original conditions were r==75 / r==100 / r==115).
if r:
    for row, letters in enumerate(("abcd", "efgh", "ijkl")):
        for col, letter in enumerate(letters):
            axs[row][col].annotate(
                '(%s)' % letter,
                xy=(0.015, 0.96),
                xycoords='axes fraction',
                fontsize=my_fontsize+2,
                horizontalalignment='left',
                verticalalignment='top',
                color=axis_color,
            )


# Y axis label on the first column only; the negative pad pulls the label
# in towards the axis.
lp2 = -10.0
for row in range(3):
    axs[row][0].set_ylabel(r'$X,\rm nm$', labelpad=lp2, fontsize=my_fontsize+2)

for idx in range(len(data)):
    print(idx)
    # Same column-by-column layout as the main plotting loop.
    col, row = divmod(idx, 3)
    panel = axs[row][col]
    # Tick labels stay only on the outer-left column and the bottom row.
    if col != 0:
        panel.set_yticklabels([])
    if row != 2:
        panel.set_xticklabels([])
    # Limit both axes to at most four tick intervals.
    for which in ('x', 'y'):
        panel.locator_params(axis=which, nbins=4)
    

# Final layout: a negative hspace moves the rows closer than the default grid.
# NOTE(review): presumably this compensates for the equal-aspect images not
# filling their grid cells -- confirm visually if the grid shape changes.
fig.subplots_adjust(hspace=-0.43, wspace=0.25)
#fig.tight_layout()

# Write the figure to "plasma-grid.pdf" in the current working directory,
# cropped to its contents with a small margin.
plt.savefig("plasma-grid.pdf",pad_inches=0.02, bbox_inches='tight')

# Redraw the current figure (the file has already been written above).
plt.draw()

# Uncomment to open an interactive window instead of only saving the file.
#    plt.show()

# Clear and close the current figure.
plt.clf()
plt.close()