---
# config.yaml
# NOTE(review): recovered from a flattened web view (line-number prefixes,
# lost indentation). The nesting below is reconstructed -- confirm it against
# how the training code indexes the loaded config. The view's gutter counted
# 52 lines but only 46 were visible; check whether trailing keys were lost.

# logdir = os.path.join(output, task_name)
output: ./output_dir/
task_name: base_exp

dataset:
  train:
    rootA: 'flow_model/data/train/weighted_real'
    rootB: 'flow_model/data/train/maps_sample/subject04_crisp_v_180.npy'
    width: 512
    height: 256
    scale_l: 0.8
    scale_h: 1.0
    # options: ['h_flip', 'v_flip', 'crop', 'normalize', 'random_resized_crop']
    transform: []
    random_pair: true
    return_name: false
    batch_size: 1
  test:
    rootA: 'flow_model/data/test/weighted_real'
    rootB: 'flow_model/data/test/maps_sample/subject18_crisp_v_180.npy'
    width: 512
    height: 256
    scale_l: 0.8
    scale_h: 1.0
    # options: ['h_flip', 'v_flip', 'crop', 'normalize']
    transform: []
    random_pair: false
    return_name: true
    batch_size: 16

# Keep floats in plain decimal form: YAML 1.1 loaders such as PyYAML read
# exponent forms without a dot (e.g. 1e-4) as strings, not floats.
lr: 0.0001
epochs: 120
max_iter: 300000
print_freq: 450
save_freq: 450
resume: true
load_path: 'flow_model/checkpoint_for_resume/0.ckpt.pth.tar'

network:
  configurable: false  # options: [true, false]
  pad_size: 10
  in_channel: 3
  # options: [30, 120], [12, 60, 120], [30, 120, 480], [30, 120, 480, 1920]
  out_channels: [30, 120]
  # options: ['fixed', 'sigmoid', 'softmax', 'attention', 'learned']
  weight_type: 'learned'

loss:
  vgg_encoder: 'flow_model/model/losses/vgg_model/vgg_normalised.pth'
  k: 0.7
  weight: 0.7

lr_scheduler:
  type: cosine
  eta_min: 0.0