hf_trainer.py

import numpy as np
import os
#from np.lib.histograms import histogram
import torch
import torch.nn as nn  # needed for nn.MSELoss in train_iter
import torch.distributed as dist
from torchvision.utils import save_image
from torch.utils.data import DataLoader

from model.losses import VGGLoss
from model.network.hf import HierarchyFlow
from model.utils.dataset import get_dataset
from model.utils.sampler import DistributedGivenIterationSampler, DistributedTestSampler
from tensorboardX import SummaryWriter

import logging
from model.utils.log_helper import init_log
# Expected to supply MRSignalModel and denormalize_per_image, used in train_iter.
from model.trainers.MR_signal_model import *

init_log('pytorch hierarchy flow')
global_logger = logging.getLogger('pytorch hierarchy flow')

def save_checkpoint(state, filename):
    # '.pth.tar' is appended here, so any resume path must include it.
    torch.save(state, filename + '.pth.tar')

def load_checkpoint(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['step']

def reduce_mean(tensor, nprocs):
    # Average a tensor across all distributed processes (sum-reduce, then divide).
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt
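
# Example use of reduce_mean under distributed training (illustrative only;
# assumes torch.distributed has been initialised, as the commented-out calls
# in train_iter below suggest):
#
#   dist.barrier()
#   loss = reduce_mean(loss, dist.get_world_size())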

class Trainer():
    def __init__(self, cfg, local_rank, world_size):
        self.cfg = cfg
        self.rank = local_rank
        self.world_size = world_size

        model = HierarchyFlow(self.cfg.network.pad_size,
                              self.cfg.network.in_channel,
                              self.cfg.network.out_channels,
                              self.cfg.network.weight_type)
        #model.cuda(self.rank)
        model = model.to(device='cuda')
        if self.rank == 0:
            global_logger.info(self.cfg)
            global_logger.info(model)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.lr)

        #if self.cfg.eval_mode or (self.cfg.resume and os.path.isfile(self.cfg.load_path)):
        if self.cfg.resume:
            self.model, self.optimizer, self.resumed_step = load_checkpoint(
                self.cfg.load_path, model, optimizer)
            global_logger.info("=> loaded checkpoint '{}' with current step {}".format(
                self.cfg.load_path, self.resumed_step))
        else:
            self.model = model
            self.optimizer = optimizer
            self.resumed_step = -1

        if self.cfg.lr_scheduler.type == 'cosine':
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, self.cfg.max_iter, self.cfg.lr_scheduler.eta_min)
        else:
            raise RuntimeError('lr_scheduler {} is not implemented'.format(self.cfg.lr_scheduler.type))

        self.criterion = VGGLoss(self.cfg.loss.vgg_encoder).cuda(self.rank)
        if self.rank == 0:
            self.logger = SummaryWriter(os.path.join(self.cfg.output, self.cfg.task_name, 'runs'))
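
    # Note: save_checkpoint() appends '.pth.tar', so cfg.load_path should point
    # at the full '<step>.ckpt.pth.tar' file written by train_iter below.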

    def train(self):
        train_dataset = get_dataset(self.cfg.dataset.train)
        #train_sampler = DistributedGivenIterationSampler(train_dataset,
        #    self.cfg.max_iter, self.cfg.dataset.train.batch_size,
        #    world_size=self.world_size, rank=self.rank, last_iter=-1)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.cfg.dataset.train.batch_size,
            shuffle=True)  # sampler=train_sampler, num_workers=4, pin_memory=False
        for ep in range(self.cfg.epochs):
            for batch_id, batch in enumerate(train_loader):
                # Global step counter across epochs.
                b_id = batch_id + len(train_loader) * ep
                self.train_iter(b_id, batch)
        self.eval()

    def eval(self):
        test_dataset = get_dataset(self.cfg.dataset.test)
        test_sampler = DistributedTestSampler(test_dataset, world_size=self.world_size, rank=self.rank)
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.cfg.dataset.test.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=False,
            sampler=test_sampler)

        self.model.eval()
        with torch.no_grad():
            for batch_id, batch in enumerate(test_loader):
                content_images = batch[0].cuda(self.rank)
                style_images = batch[1].cuda(self.rank)
                names = batch[2]
                outputs = self.model(content_images, style_images)
                outputs = torch.clamp(outputs, 0, 1)
                outputs = outputs.cpu()
                for idx in range(len(outputs)):
                    output_name = os.path.join(self.cfg.output, self.cfg.task_name, 'eval_results', 'pred', names[idx])
                    save_image(outputs[idx].unsqueeze(0), output_name)
                    if idx == 0:
                        # Also save a content/style/prediction triplet for the first sample of the batch.
                        output_name = os.path.join(self.cfg.output, self.cfg.task_name, 'eval_results', 'cat_img', names[idx])
                        output_images = torch.stack((content_images[idx].cpu(), style_images[idx].cpu(), outputs[idx]), 0)
                        save_image(output_images, output_name, nrow=1)
                if self.rank == 0 and batch_id % 10 == 1:
                    global_logger.info('predicting {}th batch...'.format(batch_id))
        if self.rank == 0:
            global_logger.info('Saved predictions to {}\nDone.'.format(
                os.path.join(self.cfg.output, self.cfg.task_name, 'eval_results')))
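
    # eval() writes into <output>/<task_name>/eval_results/{pred,cat_img} and
    # assumes those directories already exist. A small hypothetical helper
    # (not part of the original pipeline) could create them up front:
    #
    #   for sub in ('pred', 'cat_img'):
    #       os.makedirs(os.path.join(self.cfg.output, self.cfg.task_name,
    #                                'eval_results', sub), exist_ok=True)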

    def train_iter(self, batch_id, batch):
        content_images = batch[0].cuda(self.rank)
        style_images = batch[1].cuda(self.rank)
        denorm_images = batch[2].cuda(self.rank)
        max_vals = batch[3].cuda(self.rank)

        outputs = self.model(content_images, style_images)
        #outputs = torch.clamp(outputs, 0)
        outputs = torch.clamp(outputs, 0, 1)

        # VGG style/content loss, currently disabled in favour of the MR
        # reconstruction loss below.
        #loss_c, loss_s = self.criterion(content_images, style_images, outputs, self.cfg.loss.k)
        #loss_c = loss_c.mean()
        #loss_s = loss_s.mean()
        #loss = loss_c + self.cfg.loss.weight * loss_s

        # Undo the per-channel normalisation, synthesise MR signal images and
        # supervise them against the un-normalised ground truth with MSE.
        denorm_outputs = denormalize_per_image(outputs, [0.0, 0.0, 0.0], [3.5, 0.45, 1.0])
        mr_model = MRSignalModel()
        reconstruct_images = mr_model(denorm_outputs, max_vals)
        #reconstruct_images = mr_model(outputs, max_vals)
        loss_MSE = nn.MSELoss()
        loss = loss_MSE(reconstruct_images, denorm_images.float())

        #torch.distributed.barrier()
        #loss = reduce_mean(loss, self.world_size)
        #loss_c = reduce_mean(loss_c, self.world_size)
        #loss_s = reduce_mean(loss_s, self.world_size)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()

        if self.rank == 0:
            current_lr = self.lr_scheduler.get_last_lr()[0]
            self.logger.add_scalar("current_lr", current_lr, batch_id + 1)
            self.logger.add_scalar("loss", loss.item(), batch_id + 1)
            if batch_id % self.cfg.print_freq == 0:
                global_logger.info('batch: {}, loss: {}'.format(batch_id, loss.item()))
                output_name = os.path.join(self.cfg.output, self.cfg.task_name, 'img_save', str(batch_id) + '.jpg')
                output_images = torch.cat((content_images.cpu(), reconstruct_images.cpu(), outputs.cpu()), 0)
                save_image(output_images, output_name, nrow=1)
            if batch_id % self.cfg.save_freq == 0:
                save_checkpoint({
                    'step': batch_id,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }, os.path.join(self.cfg.output, self.cfg.task_name, 'model_save', str(batch_id) + '.ckpt'))
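
# ---------------------------------------------------------------------------
# Minimal launch sketch (illustrative only; not part of the original file).
# The cfg layout is inferred from the attribute accesses above
# (cfg.network.*, cfg.dataset.*, cfg.lr_scheduler.*, ...), and the YAML path
# and EasyDict config wrapper are assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import yaml
    from easydict import EasyDict

    with open('configs/hf_train.yaml') as f:  # hypothetical config path
        cfg = EasyDict(yaml.safe_load(f))
    trainer = Trainer(cfg, local_rank=0, world_size=1)  # single-process run
    trainer.train()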