from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout
from keras.layers import Reshape, UpSampling1D, Conv1D
from keras.layers import Flatten, Activation
from keras.utils import np_utils, multi_gpu_model
from keras.regularizers import l2
from keras.wrappers.scikit_learn import KerasRegressor
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import PReLU
from keras.models import Model
from keras.layers import Input, Add
from keras.layers.normalization import BatchNormalization
from keras.layers import PReLU
from keras.utils import to_channels_first



#staging area for new models
def plot_training_history(history, red_factor):
    """Plot train/validation loss curves on a log-y axis.

    history: Keras History object; must contain 'loss' and 'val_loss'
             in history.history (i.e. the model was fit with validation data).
    red_factor: scale factor the raw losses are divided by before plotting
                (used to convert the summed loss back into an MRE percentage).

    Returns the matplotlib Figure so callers can save or show it
    (the original built the figure but never exposed it).
    """
    loss = np.asarray(history.history['loss']) / red_factor
    val_loss = np.asarray(history.history['val_loss']) / red_factor
    epochs = np.arange(1, len(loss) + 1)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.semilogy(epochs, loss, label='train error')
    ax.semilogy(epochs, val_loss, label='validation error')
    ax.set_xlabel('Epoch number')
    ax.set_ylabel('Mean Relative Error (MRE) (%)')
    ax.legend(loc="best")
    return fig
    
#numpy-side metric for evaluating a trained model on the test set
def calc_mre(y_true, y_pred):
    """Return the mean relative error (in %) of y_pred w.r.t. y_true.

    Division is by the signed y_true (only |y_true - y_pred| is rectified),
    so spectra are assumed positive.
    """
    pct_dev = 100 * np.abs(y_true - y_pred) / y_true
    return pct_dev.mean()

#backend-side metric, usable in model.compile(metrics=[...])
def calc_mre_K(y_true, y_pred):
    """Mean relative error (in %) computed with Keras backend ops.

    y_true / y_pred are symbolic backend tensors here, so every op must be a
    backend op: the original used np.abs, which is not guaranteed to work on
    symbolic tensors under graph execution — replaced with K.abs.
    """
    y_err = 100 * K.abs(y_true - y_pred) / y_true
    return K.mean(y_err)

#naive percentage loss: sum over all elements of |err| / y_true
def relerr_loss(y_true, y_pred):
    """Relative-error loss for training (backend tensors in, scalar tensor out).

    The original applied np.abs to symbolic tensors; a Keras loss must stay in
    backend ops end-to-end, so this uses K.abs instead.
    """
    y_err = K.abs(y_true - y_pred) / y_true
    return K.sum(K.flatten(y_err))





def fullycon( in_size=8, 
             out_size=256, 
             batch_size=32,
             N_hidden=3, 
             N_neurons=250, 
             N_gpus=1):
    """
    Returns a fully-connected model which will take a normalized size vector
    and return a spectrum.

    in_size: length of the size vector
    out_size: length of the spectrum vector
    batch_size: unused here; kept for call-site compatibility
    N_hidden: number of ADDITIONAL hidden layers after the first one
              (the model has N_hidden + 1 hidden Dense layers in total;
              the original docstring understated this by one)
    N_neurons: number of neurons in each of the hidden layers
    N_gpus: when > 1, compile() is passed a `context` list of gpu devices.
            NOTE(review): `context` is a keras-mxnet extension, not a standard
            Keras compile() argument — confirm the backend before using it.
    """
    model = Sequential()
    model.add(Dense(N_neurons, input_dim=in_size,
                    kernel_initializer='normal', activation='relu',
                    name='first'))
    # hidden stack: H0 .. H(N_hidden-1)
    for h in range(N_hidden):
        model.add(Dense(N_neurons, kernel_initializer='normal',
                        activation='relu', name="H" + str(h)))

    # linear output layer producing the spectrum
    model.add(Dense(out_size, kernel_initializer='normal', name='last'))

    # Compile model
    if N_gpus == 1:
        model.compile(loss=relerr_loss, optimizer='adam', metrics=[calc_mre_K])
    else:
        gpu_list = ["gpu(%d)" % i for i in range(N_gpus)]
        model.compile(loss=relerr_loss, optimizer='adam',
                      metrics=[calc_mre_K], context=gpu_list)
    return model



def conv1dmodel(in_size=8, 
        out_size=256,
        batch_size=32,
        c1_nf=64,
        clayers=2,
        ker_size=3):
    """Dense front end followed by upsampling and a 1D-convolutional stack.

    in_size: length of the input size vector
    out_size: length of the output spectrum; must be divisible by 64
              (the original hard-coded Reshape((4, 64)), which silently broke
              for any out_size != 256 — now derived from out_size)
    batch_size: unused here; kept for call-site compatibility
    c1_nf: number of filters in the first conv layer
    clayers: number of additional 32-filter conv layers
    ker_size: kernel size for all conv layers

    Output length after Flatten is (out_size // 64) * 2 * 32 == out_size,
    since 'same' padding keeps the sequence length through every conv.
    """
    # create model
    model = Sequential()

    model.add(Dense(out_size, input_dim=in_size, 
        kernel_initializer='normal',
        name='first', activation='relu' ))

    # lay the dense output out as a (steps, channels) sequence, then double
    # the temporal resolution before convolving
    model.add(Reshape((out_size // 64, 64), name='Reshape1'))
    model.add(UpSampling1D(size=2, name='Up1'))

    model.add(Conv1D(filters=c1_nf, 
        kernel_size=ker_size, strides=1, padding='same', 
        dilation_rate=1, name='Conv1', 
        kernel_initializer='normal', activation='relu'))

    # Conv2 .. Conv(clayers+1), all 32 filters
    for cl in range(clayers):
        model.add(Conv1D(filters=32, 
            kernel_size=ker_size, 
            strides=1, 
            padding='same', 
            dilation_rate=1, 
            name='Conv'+ str(cl+2),
            kernel_initializer='normal',
            activation='relu'))

    model.add(Flatten()) 
    
    model.compile(loss=relerr_loss, optimizer='adam', metrics=[calc_mre_K])
    return model


def convprel(in_size=8, 
        out_size=256,
        batch_size=32,
        c1_nf=64,
        clayers=2,
        ker_size=3):
    """Variant of conv1dmodel using PReLU activations instead of ReLU.

    in_size: length of the input size vector
    out_size: length of the output spectrum; must be divisible by 64
              (the original hard-coded Reshape((4, 64)), which silently broke
              for any out_size != 256 — now derived from out_size)
    batch_size: unused here; kept for call-site compatibility
    c1_nf: number of filters in the first conv layer
    clayers: number of additional 32-filter conv layers
    ker_size: kernel size for all conv layers
    """
    # create model
    model = Sequential()

    model.add(Dense(out_size, input_dim=in_size, 
        kernel_initializer='normal',
        name='first'))
    model.add(PReLU(alpha_initializer='zeros', alpha_regularizer=None))
    # lay the dense output out as a (steps, channels) sequence, then double
    # the temporal resolution before convolving
    model.add(Reshape((out_size // 64, 64), name='Reshape1'))
    model.add(UpSampling1D(size=2, name='Up1'))

    model.add(Conv1D(filters=c1_nf, 
        kernel_size=ker_size, strides=1, padding='same', 
        dilation_rate=1, name='Conv1', 
        kernel_initializer='normal'))
    model.add(PReLU(alpha_initializer='zeros', alpha_regularizer=None))

    # Conv2 .. Conv(clayers+1), all 32 filters, each followed by PReLU
    for cl in range(clayers):
        model.add(Conv1D(filters=32, 
            kernel_size=ker_size, 
            strides=1, 
            padding='same', 
            dilation_rate=1, 
            name='Conv'+ str(cl+2),
            kernel_initializer='normal'))
        model.add(PReLU(alpha_initializer='zeros', alpha_regularizer=None))

    model.add(Flatten()) 
    
    model.compile(loss=relerr_loss, optimizer='adam', metrics=[calc_mre_K])
    return model


def resblock2(Input):
    """Plain residual block: two 32-filter convolutions plus a skip connection.

    NOTE(review): the parameter name shadows keras.layers.Input inside this
    function; kept as-is for call-site compatibility. The final Add requires
    the incoming tensor to already carry 32 channels — confirm at the caller.
    """
    shortcut = Input
    x = Conv1D(filters=32, kernel_size=3, strides=1, padding='same',
               dilation_rate=1,
               kernel_initializer='normal')(shortcut)
    x = Activation('relu')(x)
    x = Conv1D(filters=32, kernel_size=3, strides=1, padding='same',
               dilation_rate=1,
               kernel_initializer='normal')(x)
    return Add()([x, shortcut])



def resblock(Input, ker_size, red_dim):
    """Bottleneck residual block: 1x1 reduce -> ker_size conv -> 1x1 expand.

    Channel widths go red_dim -> red_dim -> 32, each conv followed by batch
    norm (ReLU after the first two); the skip-add result passes through PReLU.
    NOTE(review): the parameter name shadows keras.layers.Input inside this
    function, and the final Add requires the incoming tensor to carry 32
    channels — confirm at the caller.
    """
    skip = Input
    x = Conv1D(filters=red_dim, kernel_size=1, strides=1, padding='same',
               dilation_rate=1,
               kernel_initializer='normal')(skip)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv1D(filters=red_dim, kernel_size=ker_size, strides=1, padding='same',
               dilation_rate=1,
               kernel_initializer='normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv1D(filters=32, kernel_size=1, strides=1, padding='same',
               dilation_rate=1,
               kernel_initializer='normal')(x)
    x = BatchNormalization()(x)
    x = Add()([x, skip])
    return PReLU(alpha_initializer='zeros', alpha_regularizer=None)(x)

def resnet(in_size=8, 
        out_size=256,
        num_units=2,
        red_dim=8,
        batch_size=32,
        ker_size=3):
    """1D ResNet mapping a size vector to a spectrum via stacked resblocks.

    in_size: length of the input size vector
    out_size: length of the output spectrum; must be divisible by 32.
              The original ignored this parameter (Dense(256) and
              Reshape((8, 32)) were hard-coded) — now wired through;
              behavior is unchanged at the default out_size=256.
    num_units: number of bottleneck resblocks
    red_dim: bottleneck width inside each resblock
    batch_size: unused here; kept for call-site compatibility
    ker_size: kernel size of the middle conv in each resblock
    """
    a = Input(shape=(in_size,))
    first = Dense(out_size, kernel_initializer='normal')(a)
    first = PReLU(alpha_initializer='zeros', alpha_regularizer=None)(first)
    # 32 channels to match the fixed expand width of resblock's final conv
    first = Reshape((out_size // 32, 32))(first)

    for _ in range(num_units):
        first = resblock(first, ker_size, red_dim)

    # Flatten restores length (out_size // 32) * 32 == out_size
    last = Flatten()(first)

    model = Model(inputs=a, outputs=last)

    #compile model
    model.compile(loss=relerr_loss, optimizer='adam', metrics=[calc_mre_K])

    return model