hyptune.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from sklearn.model_selection import GridSearchCV
  2. from keras.models import Sequential
  3. from keras.layers import Dense
  4. from keras.wrappers.scikit_learn import KerasRegressor
  5. import scnets as scn
  6. from sklearn.metrics import make_scorer
  7. import numpy as np
  8. import pandas as pd
  9. def print_tuning_results(cvresults, modelfunc):
  10. pd.set_option('precision',2)
  11. bestidx = np.argsort(cvresults['mean_test_score'])
  12. scorelist = cvresults['mean_test_score'][bestidx]
  13. parlist = cvresults['params']
  14. runtlist = (1/60.0)*(cvresults['mean_fit_time'])
  15. runtlist = runtlist.astype('int64')
  16. bestlist = [parlist[indx] for indx in bestidx]
  17. par_count =[]
  18. for elem in bestlist:
  19. model = modelfunc(**elem)
  20. par_count.append(model.count_params())
  21. parkeylist = [key for key in bestlist[0]]
  22. columns = parkeylist + ['MRE(%)', 'Total Params', 'mean_fit_time']
  23. df = pd.DataFrame(columns=columns)
  24. for colno in np.arange(len(bestlist[0])):
  25. df[columns[colno]] = [elem[parkeylist[colno]] for elem in bestlist]
  26. df[columns[len(bestlist[0]) + 0]] = scorelist
  27. df[columns[len(bestlist[0])+ 1]] = par_count
  28. df[columns[len(bestlist[0])+ 2]] = runtlist
  29. return df
  30. # This will return the MRE error as score
  31. def mre_score_func(ground_truth, predictions):
  32. diff = np.abs(ground_truth - predictions)/np.abs(ground_truth)
  33. return -100*np.mean(diff)
  34. def get_cv_grid(modelfunc, param_grid, num_epochs, x_train, y_train):
  35. mre_score = make_scorer(mre_score_func, greater_is_better=False)
  36. #build estimator
  37. model = KerasRegressor(build_fn=modelfunc,
  38. epochs=num_epochs,
  39. verbose=0)
  40. grid = GridSearchCV(estimator=model,
  41. param_grid=param_grid,
  42. n_jobs=1,
  43. scoring=mre_score,
  44. verbose=1)
  45. grid_result = grid.fit(x_train, y_train)
  46. return grid_result.cv_results_