| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- import torch.nn as nn
- import torch
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of ``feat``.

    Args:
        feat: 4-D tensor. Dims 0 and 1 are treated as (N, C); the remaining
            two dims are flattened together before reducing.
        eps: added to the variance before the square root, so constant
            channels still get a strictly positive std.

    Returns:
        ``(mean, std)``, each of shape ``(N, C, 1, 1)`` so they broadcast
        against ``feat``. The std uses torch's default (unbiased) variance,
        so a channel with a single spatial element yields NaN.

    Raises:
        ValueError: if ``feat`` is not 4-D.
    """
    if feat.dim() != 4:
        raise ValueError(
            'calc_mean_std expects a 4-D tensor, got shape {}'.format(tuple(feat.shape)))
    n, c = feat.shape[:2]
    # reshape (not view) so non-contiguous inputs, e.g. after permute/transpose, also work.
    flat = feat.reshape(n, c, -1)
    feat_mean = flat.mean(dim=2).view(n, c, 1, 1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    return feat_mean, feat_std
def _build_vgg_encoder():
    """Build the VGG-19-shaped encoder (up to relu5_4) as a single nn.Sequential.

    The network opens with a 1x1 convolution (3 -> 3 channels), followed by
    five stages of [ReflectionPad2d(1) -> Conv2d 3x3 -> ReLU] units. Every
    stage after the first is preceded by a 2x2, stride-2 MaxPool2d with
    ceil_mode=True. Resulting module indices (these are the state_dict keys):

        0        conv 1x1 (3 -> 3)
        1-3      relu1_1    4-6    relu1_2                    7   maxpool
        8-10     relu2_1    11-13  relu2_2                    14  maxpool
        15-17    relu3_1    18-26  relu3_2 .. relu3_4         27  maxpool
        28-30    relu4_1 (last layer used)  31-39  relu4_2 .. relu4_4
                                                              40  maxpool
        41-43    relu5_1    44-52  relu5_2 .. relu5_4
    """
    stages = (  # (output channels, number of 3x3 conv units) per stage
        (64, 2),
        (128, 2),
        (256, 4),
        (512, 4),
        (512, 4),
    )
    modules = [nn.Conv2d(3, 3, (1, 1))]
    channels = 3
    for stage_no, (width, depth) in enumerate(stages):
        if stage_no:
            modules.append(nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True))
        for _ in range(depth):
            modules.extend((
                nn.ReflectionPad2d((1, 1, 1, 1)),
                nn.Conv2d(channels, width, (3, 3)),
                nn.ReLU(),
            ))
            channels = width
    return nn.Sequential(*modules)


vgg = _build_vgg_encoder()
def weighted_mse_loss_merge(input_mean, target_mean, input_std, target_std, k=0.8):
    """Squared errors of channel means/stds, ignoring the worst-matching channels.

    For every sample, channels (dim 1) are ranked by the squared error of
    their means. Only the ``int(C * k)`` channels with the smallest mean
    error are kept (``C = input_mean.shape[1]``). The remaining channels are
    zeroed in *both* the mean and the std error maps, so the same channels
    are dropped from each term.

    Args:
        input_mean, target_mean, input_std, target_std: tensors with channels
            on dim 1, presumably the ``(N, C, 1, 1)`` statistics produced by
            ``calc_mean_std`` (confirm against callers).
        k: fraction of channels to keep per sample.

    Returns:
        ``(loss_mean, loss_std)`` scalar tensors. Each is the mean over *all*
        elements, including the dropped (zeroed) ones.
    """
    loss_mean = (input_mean - target_mean) ** 2
    loss_std = (input_std - target_std) ** 2
    cut = int(loss_mean.shape[1] * k)
    _, order = torch.sort(loss_mean, dim=1)
    # Per-sample indices of the channels with the largest mean errors.
    # scatter along dim 1 zeroes them within each sample only; a plain
    # `t[:, idx]` index would apply every sample's indices to the whole batch.
    dropped = order[:, cut:]
    return loss_mean.scatter(1, dropped, 0.0).mean(), loss_std.scatter(1, dropped, 0.0).mean()
-
class VGGLoss(nn.Module):
    """Content and style losses measured on a frozen VGG encoder.

    Features are read at relu1_1, relu2_1, relu3_1 and relu4_1 (the outputs
    of ``enc_1`` .. ``enc_4``). The content loss compares mean/std-normalised
    relu4_1 features. The style loss compares per-channel feature means and
    standard deviations at all four layers.
    """

    def __init__(self, vgg_model):
        """Load encoder weights from ``vgg_model`` and split the encoder into stages.

        Args:
            vgg_model: anything ``torch.load`` accepts (e.g. a file path) that
                holds a state dict matching the module-level ``vgg``.

        NOTE: the weights are loaded in place into the module-level ``vgg``,
        and the stages below reuse its layer objects. Every VGGLoss instance
        therefore shares those layers, and the freezing below also applies to
        them inside ``vgg`` itself.
        """
        super().__init__()
        encoder = vgg
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only machine. load_state_dict copies the values into the
        # encoder's existing parameters, so their device is unchanged.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (torch>=1.13 also accepts weights_only=True).
        encoder.load_state_dict(torch.load(vgg_model, map_location='cpu'))
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.mse_loss = nn.MSELoss()
        # The encoder is a fixed feature extractor: never update its weights.
        for stage in self._stages():
            for param in stage.parameters():
                param.requires_grad = False

    def _stages(self):
        """Return the four encoder stages in forward order."""
        return (self.enc_1, self.enc_2, self.enc_3, self.enc_4)

    def encode_with_intermediate(self, input):
        """Return the list ``[relu1_1, relu2_1, relu3_1, relu4_1]`` of features of ``input``."""
        feats = []
        for stage in self._stages():
            input = stage(input)
            feats.append(input)
        return feats

    def encode(self, input):
        """Return only the relu4_1 features of ``input``."""
        for stage in self._stages():
            input = stage(input)
        return input

    @staticmethod
    def _check_pair(input, target):
        """Validate a (prediction, target) feature pair.

        Raises:
            ValueError: if the sizes differ, or if ``target`` requires grad
                (the target is a fixed reference and must not be optimised).
        """
        if input.size() != target.size():
            raise ValueError(
                'input and target sizes differ: {} vs {}'.format(
                    tuple(input.size()), tuple(target.size())))
        if target.requires_grad:
            raise ValueError('target must not require grad')

    def calc_content_loss(self, input, target):
        """Return the MSE between ``input`` and ``target`` after per-channel mean/std normalisation."""
        self._check_pair(input, target)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        # The (N, C, 1, 1) statistics broadcast over the spatial dimensions.
        normalized_input = (input - input_mean) / input_std
        normalized_target = (target - target_mean) / target_std
        return self.mse_loss(normalized_input, normalized_target)

    def calc_style_loss(self, input, target, k=0.8):
        """Return the loss between per-channel means/stds of ``input`` and ``target``.

        A negative ``k`` selects plain MSE on both statistics. Otherwise
        ``weighted_mse_loss_merge`` is used, which keeps the ``int(C * k)``
        channels with the smallest mean error.
        """
        self._check_pair(input, target)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        if k < 0.0:
            return self.mse_loss(input_mean, target_mean) + self.mse_loss(input_std, target_std)
        loss_mean, loss_std = weighted_mse_loss_merge(input_mean, target_mean, input_std, target_std, k)
        return loss_mean + loss_std

    def forward(self, content_images, style_images, stylized_images, k=0.8):
        """Return ``(content_loss, style_loss)`` for ``stylized_images``.

        The content loss compares relu4_1 features against ``content_images``.
        The style loss sums ``calc_style_loss`` over relu1_1..relu4_1 against
        ``style_images``.
        """
        # Targets are fixed references: encode them without autograd so no
        # graph is built for them, even if the input images require grad.
        with torch.no_grad():
            style_feats = self.encode_with_intermediate(style_images)
            content_feat = self.encode(content_images)
        stylized_feats = self.encode_with_intermediate(stylized_images)
        loss_c = self.calc_content_loss(stylized_feats[-1], content_feat)
        loss_s = self.calc_style_loss(stylized_feats[0], style_feats[0], k)
        for out_feat, style_feat in zip(stylized_feats[1:], style_feats[1:]):
            loss_s = loss_s + self.calc_style_loss(out_feat, style_feat, k)
        return loss_c, loss_s
|