# VGG_loss.py (5.3 KB)
import torch
import torch.nn as nn
  3. def calc_mean_std(feat, eps=1e-5):
  4. size = feat.size()
  5. assert (len(size) == 4)
  6. N, C = size[:2]
  7. feat_var = feat.view(N, C, -1).var(dim=2) + eps
  8. feat_std = feat_var.sqrt().view(N, C, 1, 1)
  9. feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
  10. return feat_mean, feat_std
  11. vgg = nn.Sequential(
  12. nn.Conv2d(3, 3, (1, 1)),
  13. nn.ReflectionPad2d((1, 1, 1, 1)),
  14. nn.Conv2d(3, 64, (3, 3)),
  15. nn.ReLU(), # relu1-1
  16. nn.ReflectionPad2d((1, 1, 1, 1)),
  17. nn.Conv2d(64, 64, (3, 3)),
  18. nn.ReLU(), # relu1-2
  19. nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
  20. nn.ReflectionPad2d((1, 1, 1, 1)),
  21. nn.Conv2d(64, 128, (3, 3)),
  22. nn.ReLU(), # relu2-1
  23. nn.ReflectionPad2d((1, 1, 1, 1)),
  24. nn.Conv2d(128, 128, (3, 3)),
  25. nn.ReLU(), # relu2-2
  26. nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
  27. nn.ReflectionPad2d((1, 1, 1, 1)),
  28. nn.Conv2d(128, 256, (3, 3)),
  29. nn.ReLU(), # relu3-1
  30. nn.ReflectionPad2d((1, 1, 1, 1)),
  31. nn.Conv2d(256, 256, (3, 3)),
  32. nn.ReLU(), # relu3-2
  33. nn.ReflectionPad2d((1, 1, 1, 1)),
  34. nn.Conv2d(256, 256, (3, 3)),
  35. nn.ReLU(), # relu3-3
  36. nn.ReflectionPad2d((1, 1, 1, 1)),
  37. nn.Conv2d(256, 256, (3, 3)),
  38. nn.ReLU(), # relu3-4
  39. nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
  40. nn.ReflectionPad2d((1, 1, 1, 1)),
  41. nn.Conv2d(256, 512, (3, 3)),
  42. nn.ReLU(), # relu4-1, this is the last layer used
  43. nn.ReflectionPad2d((1, 1, 1, 1)),
  44. nn.Conv2d(512, 512, (3, 3)),
  45. nn.ReLU(), # relu4-2
  46. nn.ReflectionPad2d((1, 1, 1, 1)),
  47. nn.Conv2d(512, 512, (3, 3)),
  48. nn.ReLU(), # relu4-3
  49. nn.ReflectionPad2d((1, 1, 1, 1)),
  50. nn.Conv2d(512, 512, (3, 3)),
  51. nn.ReLU(), # relu4-4
  52. nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
  53. nn.ReflectionPad2d((1, 1, 1, 1)),
  54. nn.Conv2d(512, 512, (3, 3)),
  55. nn.ReLU(), # relu5-1
  56. nn.ReflectionPad2d((1, 1, 1, 1)),
  57. nn.Conv2d(512, 512, (3, 3)),
  58. nn.ReLU(), # relu5-2
  59. nn.ReflectionPad2d((1, 1, 1, 1)),
  60. nn.Conv2d(512, 512, (3, 3)),
  61. nn.ReLU(), # relu5-3
  62. nn.ReflectionPad2d((1, 1, 1, 1)),
  63. nn.Conv2d(512, 512, (3, 3)),
  64. nn.ReLU() # relu5-4
  65. )
  66. def weighted_mse_loss_merge(input_mean, target_mean, input_std, target_std, k=0.8):
  67. loss_mean = ((input_mean - target_mean) ** 2)
  68. sort_loss_mean,idx = torch.sort(loss_mean,dim=1)
  69. sort_loss_mean[:,int(sort_loss_mean.shape[1]*k):] = 0
  70. loss_std = ((input_std - target_std) ** 2)
  71. loss_std[:,idx[:,int(idx.shape[1]*k):]] = 0
  72. return sort_loss_mean.mean(),loss_std.mean()
  73. class VGGLoss(nn.Module):
  74. def __init__(self, vgg_model):
  75. super(VGGLoss, self).__init__()
  76. encoder = vgg
  77. encoder.load_state_dict(torch.load(vgg_model))
  78. enc_layers = list(encoder.children())
  79. self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1
  80. self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1
  81. self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
  82. self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
  83. self.mse_loss = nn.MSELoss()
  84. # fix the encoder
  85. for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
  86. for param in getattr(self, name).parameters():
  87. param.requires_grad = False
  88. # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
  89. def encode_with_intermediate(self, input):
  90. results = [input]
  91. for i in range(4):
  92. func = getattr(self, 'enc_{:d}'.format(i + 1))
  93. results.append(func(results[-1]))
  94. return results[1:]
  95. # extract relu4_1 from input image
  96. def encode(self, input):
  97. for i in range(4):
  98. input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
  99. return input
  100. def calc_content_loss(self, input, target):
  101. assert (input.size() == target.size())
  102. assert (target.requires_grad is False)
  103. size1 = input.size()
  104. size2 = target.size()
  105. input_mean, input_std = calc_mean_std(input)
  106. target_mean, target_std = calc_mean_std(target)
  107. normalized_feat1 = (input - input_mean.expand(size1)) / input_std.expand(size1)
  108. normalized_feat2 = (target - target_mean.expand(size2)) / target_std.expand(size2)
  109. return self.mse_loss(normalized_feat1, normalized_feat2)
  110. def calc_style_loss(self, input, target, k=0.8):
  111. assert (input.size() == target.size())
  112. assert (target.requires_grad is False)
  113. input_mean, input_std = calc_mean_std(input)
  114. target_mean, target_std = calc_mean_std(target)
  115. if k < 0.0:
  116. return self.mse_loss(input_mean, target_mean) + self.mse_loss(input_std, target_std)
  117. else:
  118. loss_mean, loss_std = weighted_mse_loss_merge(input_mean, target_mean, input_std, target_std, k)
  119. return loss_mean+loss_std
  120. def forward(self, content_images, style_images, stylized_images, k=0.8):
  121. style_feats = self.encode_with_intermediate(style_images)
  122. content_feat = self.encode(content_images)
  123. stylized_feats = self.encode_with_intermediate(stylized_images)
  124. loss_c = self.calc_content_loss(stylized_feats[-1], content_feat)
  125. loss_s = self.calc_style_loss(stylized_feats[0], style_feats[0], k)
  126. for i in range(1, 4):
  127. loss_s += self.calc_style_loss(stylized_feats[i], style_feats[i], k)
  128. return loss_c, loss_s