import numpy as np
import os
import logging

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from model.losses import VGGLoss
from model.network.hf import HierarchyFlow
from model.utils.dataset import get_dataset
from model.utils.sampler import DistributedGivenIterationSampler, DistributedTestSampler
from model.utils.log_helper import init_log
from model.trainers.MR_signal_model import MRSignalModel, denormalize_per_image

init_log('pytorch hierarchy flow')
global_logger = logging.getLogger('pytorch hierarchy flow')
def save_checkpoint(state, filename):
    """Serialize a training state dict to `<filename>.pth.tar`."""
    torch.save(state, filename + '.pth.tar')


def load_checkpoint(checkpoint_fpath, model, optimizer):
    """Restore model/optimizer state and return the step stored in the checkpoint."""
    # map_location='cpu' keeps the load device-agnostic; the model has already
    # been moved to its target device before this is called.
    checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['step']


def reduce_mean(tensor, nprocs):
    """Average a tensor over all distributed processes (on a clone, not in place)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt
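# Usage note (illustrative): `reduce_mean` assumes torch.distributed has been
# initialised (dist.init_process_group) and that every rank calls it with a
# tensor of the same shape. Afterwards each rank holds the mean of the per-rank
# values, e.g. `loss = reduce_mean(loss, world_size)` before logging on rank 0.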
class Trainer():
    def __init__(self, cfg, local_rank, world_size):
        self.cfg = cfg
        self.rank = local_rank
        self.world_size = world_size

        model = HierarchyFlow(self.cfg.network.pad_size,
                              self.cfg.network.in_channel,
                              self.cfg.network.out_channels,
                              self.cfg.network.weight_type)
        model = model.cuda(self.rank)
        if self.rank == 0:
            global_logger.info(self.cfg)
            global_logger.info(model)

        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.lr)
        if self.cfg.resume:
            self.model, self.optimizer, self.resumed_step = load_checkpoint(self.cfg.load_path, model, optimizer)
            global_logger.info("=> loaded checkpoint '{}' at step {}".format(self.cfg.load_path, self.resumed_step))
        else:
            self.model = model
            self.optimizer = optimizer
            self.resumed_step = -1

        if self.cfg.lr_scheduler.type == 'cosine':
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, self.cfg.max_iter, self.cfg.lr_scheduler.eta_min)
        else:
            raise RuntimeError('lr_scheduler {} is not implemented'.format(self.cfg.lr_scheduler.type))

        self.criterion = VGGLoss(self.cfg.loss.vgg_encoder).cuda(self.rank)

        if self.rank == 0:
            self.logger = SummaryWriter(os.path.join(self.cfg.output, self.cfg.task_name, 'runs'))
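    # Fields the trainer reads from `cfg` (for reference; the config file itself
    # lives elsewhere in the repo): network.{pad_size,in_channel,out_channels,weight_type},
    # lr, max_iter, lr_scheduler.{type,eta_min}, loss.vgg_encoder, resume, load_path,
    # dataset.{train,test}.batch_size, epochs, print_freq, save_freq, output, task_name.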
    def train(self):
        train_dataset = get_dataset(self.cfg.dataset.train)
        # Single-process loader; the DistributedGivenIterationSampler variant is
        # kept here as a reference for multi-GPU runs.
        # train_sampler = DistributedGivenIterationSampler(
        #     train_dataset, self.cfg.max_iter, self.cfg.dataset.train.batch_size,
        #     world_size=self.world_size, rank=self.rank, last_iter=-1)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.cfg.dataset.train.batch_size,
            shuffle=True)

        for ep in range(self.cfg.epochs):
            for batch_id, batch in enumerate(train_loader):
                # Global step across epochs, used for logging and checkpoint names.
                b_id = batch_id + len(train_loader) * ep
                self.train_iter(b_id, batch)

        self.eval()
    def eval(self):
        test_dataset = get_dataset(self.cfg.dataset.test)
        test_sampler = DistributedTestSampler(test_dataset, world_size=self.world_size, rank=self.rank)
        test_loader = DataLoader(
            test_dataset,
            batch_size=self.cfg.dataset.test.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=False,
            sampler=test_sampler)

        self.model.eval()
        with torch.no_grad():
            for batch_id, batch in enumerate(test_loader):
                content_images = batch[0].cuda(self.rank)
                style_images = batch[1].cuda(self.rank)
                names = batch[2]

                outputs = self.model(content_images, style_images)
                outputs = torch.clamp(outputs, 0, 1).cpu()

                for idx in range(len(outputs)):
                    output_name = os.path.join(self.cfg.output, self.cfg.task_name, 'eval_results', 'pred', names[idx])
                    save_image(outputs[idx].unsqueeze(0), output_name)
                    if idx == 0:
                        # Save a content / style / prediction triplet for the first sample of each batch.
                        output_name = os.path.join(self.cfg.output, self.cfg.task_name, 'eval_results', 'cat_img', names[idx])
                        output_images = torch.stack((content_images[idx].cpu(), style_images[idx].cpu(), outputs[idx]), 0)
                        save_image(output_images, output_name, nrow=1)

                if self.rank == 0 and batch_id % 10 == 1:
                    global_logger.info('predicting batch {}...'.format(batch_id))

        if self.rank == 0:
            global_logger.info('Saved predictions to {}\nDone.'.format(
                os.path.join(self.cfg.output, self.cfg.task_name, 'eval_results')))
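    # Note: DistributedTestSampler (constructed with world_size/rank above) is
    # presumably responsible for sharding the test set across ranks, so in a
    # multi-process run each process writes only its own subset of predictions
    # into the shared eval_results directories.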
    def train_iter(self, batch_id, batch):
        content_images = batch[0].cuda(self.rank)
        style_images = batch[1].cuda(self.rank)
        denorm_images = batch[2].cuda(self.rank)
        max_vals = batch[3].cuda(self.rank)

        outputs = self.model(content_images, style_images)
        outputs = torch.clamp(outputs, 0, 1)

        # The VGG content/style criterion built in __init__ belongs to the original
        # style-transfer objective; this trainer currently optimizes only the MR
        # reconstruction MSE below.
        # loss_c, loss_s = self.criterion(content_images, style_images, outputs, self.cfg.loss.k)
        # loss = loss_c.mean() + self.cfg.loss.weight * loss_s.mean()

        # Undo the per-channel normalization on the network output, run the result
        # through the MR signal model, and compare the reconstructed signal against
        # the denormalized reference images with an MSE loss.
        denorm_outputs = denormalize_per_image(outputs, [0.0, 0.0, 0.0], [3.5, 0.45, 1.0])
        mr_model = MRSignalModel()
        reconstruct_images = mr_model(denorm_outputs, max_vals)
        loss_MSE = nn.MSELoss()
        loss = loss_MSE(reconstruct_images, denorm_images.float())
        # For multi-process runs, the loss can be averaged across ranks before
        # logging, e.g. loss = reduce_mean(loss, self.world_size) after a barrier.

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()

        if self.rank == 0:
            current_lr = self.lr_scheduler.get_last_lr()[0]
            self.logger.add_scalar("current_lr", current_lr, batch_id + 1)
            self.logger.add_scalar("loss", loss.item(), batch_id + 1)
            if batch_id % self.cfg.print_freq == 0:
                global_logger.info('batch: {}, loss: {}'.format(batch_id, loss.item()))
                output_name = os.path.join(self.cfg.output, self.cfg.task_name, 'img_save', str(batch_id) + '.jpg')
                output_images = torch.cat((content_images.cpu(),
                                           reconstruct_images.detach().cpu(),
                                           outputs.detach().cpu()), 0)
                save_image(output_images, output_name, nrow=1)
            if batch_id % self.cfg.save_freq == 0:
                save_checkpoint({
                    'step': batch_id,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict()
                }, os.path.join(self.cfg.output, self.cfg.task_name, 'model_save', str(batch_id) + '.ckpt'))
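# Minimal launch sketch (illustrative only; the project's real entry point and
# config loader are assumed to live elsewhere). It presumes an attribute-style
# config, e.g. omegaconf, exposing the fields listed above, and a single-GPU,
# single-process run.
if __name__ == '__main__':
    from omegaconf import OmegaConf  # hypothetical choice of config library
    cfg = OmegaConf.load('configs/train.yaml')  # hypothetical config path
    trainer = Trainer(cfg, local_rank=0, world_size=1)
    trainer.train()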