# hf.py

import torch
from torch import nn
from torch.nn import functional as F
from math import log, pi, exp


class SELayer(nn.Module):
    """Squeeze-and-Excitation channel attention over the concatenated feature splits."""

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x is a list of feature maps: concatenate along channels, then reweight them.
        x = torch.cat(x, dim=1)
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

    def reverse(self, x):
        # In the reverse pass, keep only the last (fully reconstructed) split.
        return x[-1]


class WeightedConcatLayer(nn.Module):
    """Concatenate feature splits, each scaled by a learned scalar weight."""

    def __init__(self, channel, activate='softmax'):
        super(WeightedConcatLayer, self).__init__()
        self.weight = nn.Parameter(torch.randn(channel), requires_grad=True)
        if activate == 'softmax':
            self.activate = nn.Softmax(-1)
        elif activate == 'sigmoid':
            self.activate = nn.Sigmoid()

    def forward(self, x):
        weights = self.activate(self.weight)
        out = []
        for idx, feat in enumerate(x):
            out.append(weights[idx] * feat)
        return torch.cat(out, dim=1)

    def reverse(self, x):
        # Weighted sum of the reversed splits.
        weights = self.activate(self.weight)
        out = 0
        for idx, feat in enumerate(x):
            out += weights[idx] * feat
        return out


class FlatConcatLayer(nn.Module):
    def __init__(self):
        super(FlatConcatLayer, self).__init__()

    def forward(self, x):
        return torch.cat(x, dim=1)

    def reverse(self, x):
        return x[-1]


class AdaIN(nn.Module):
    """Adaptive instance normalization: re-normalize content features to the given style statistics."""

    def __init__(self):
        super().__init__()

    def calc_mean_std(self, feat, eps=1e-5):
        size = feat.size()
        assert (len(size) == 4)
        N, C = size[:2]
        feat_var = feat.view(N, C, -1).var(dim=2) + eps
        feat_std = feat_var.sqrt().view(N, C, 1, 1)
        feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
        return feat_mean, feat_std

    def forward(self, content, style_mean, style_std):
        assert style_mean is not None
        assert style_std is not None
        size = content.size()
        content_mean, content_std = self.calc_mean_std(content)
        style_mean = style_mean.reshape(size[0], content_mean.shape[1], 1, 1)
        style_std = style_std.reshape(size[0], content_mean.shape[1], 1, 1)
        normalized_feat = (content - content_mean.expand(size)) / content_std.expand(size)
        sum_mean = style_mean.expand(size)
        sum_std = style_std.expand(size)
        return normalized_feat * sum_std + sum_mean


class Conv2dBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)
        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(norm_dim)
        elif norm == 'adain':
            # AdaptiveInstanceNorm2d is expected to be defined elsewhere in the original codebase.
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)
        # initialize convolution
        if norm == 'sn':
            # SpectralNorm is expected to be defined elsewhere in the original codebase.
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x


class LB(nn.Module):
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LB, self).__init__()
        use_bias = True
        # initialize fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)
        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        self.model = []
        self.model += [LB(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [LB(dim, dim, norm=norm, activation=activ)]
        self.model += [LB(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.model(x)
        return x


class StyleEncoder(nn.Module):
    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for i in range(2):
            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        for i in range(n_downsample - 2):
            self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        self.model += [nn.AdaptiveAvgPool2d(1)]  # global average pooling
        self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)


class HierarchyCoupling(nn.Module):
    def __init__(self, in_channel, out_channel, weight_type='fixed'):
        super(HierarchyCoupling, self).__init__()
        self.feat = None
        self.out_channel = out_channel
        self.in_channel = in_channel
        self.affine_net = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=in_channel * 2, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.InstanceNorm2d(in_channel * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=in_channel * 2, out_channels=in_channel * 2, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.InstanceNorm2d(in_channel * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=in_channel * 2, out_channels=out_channel, kernel_size=3, stride=1, padding=1, dilation=1),
            nn.ReLU(inplace=True),
        )
        self.adain = AdaIN()
        self.style_mlp = MLP(8, out_channel * 2, out_channel * 3, 3, norm='none', activ='relu')
        self.splits = self.out_channel // self.in_channel
        self.weight_type = weight_type
        self.fixed_weight = 0.5
        if self.weight_type == 'softmax' or self.weight_type == 'sigmoid':
            self.weight = WeightedConcatLayer(channel=self.splits, activate=self.weight_type)
        elif self.weight_type == 'attention':
            self.weight = SELayer(channel=self.out_channel, reduction=self.splits)
        elif self.weight_type == 'fixed':
            self.weight = FlatConcatLayer()
        elif self.weight_type == 'learned':
            self.weight = FlatConcatLayer()
            self.fixed_weight = nn.Parameter(torch.tensor(self.fixed_weight))
            self.fixed_weight.requires_grad = True
        else:
            raise NotImplementedError('Error: type {} for weight is not implemented.'.format(self.weight_type))

    def forward(self, input):
        b_size, n_channel, height, width = input.shape
        feature = self.affine_net(input)
        self.feat = feature
        output_list = []
        # Cumulatively subtract each predicted split from the input.
        out = input - feature[:, 0:n_channel]
        output_list.append(out)
        tmp_out = out
        for i in range(1, self.splits):
            tmp_out = tmp_out - feature[:, i * n_channel:(i + 1) * n_channel]
            output_list.append(tmp_out)
        return self.weight(output_list)

    def reverse(self, input, style):
        # Uses the features cached during the forward pass.
        feature = self.feat
        pred_style = self.style_mlp(style)
        mean, std = pred_style.chunk(2, 1)
        input = self.adain(input, mean, std)
        output_list = []
        tmp_out = input[:, -self.in_channel:] + feature[:, -self.in_channel:]
        output_list.append(tmp_out)
        for i in range(self.splits - 1, 0, -1):
            tmp_out = self.fixed_weight * (tmp_out + input[:, (i - 1) * self.in_channel:i * self.in_channel]) \
                      + feature[:, (i - 1) * self.in_channel:i * self.in_channel] * (1 - self.fixed_weight)
            output_list.append(tmp_out)
        return self.weight.reverse(output_list)


class HierarchyFlow(nn.Module):
    def __init__(self, pad_size=10, in_channel=3, out_channels=[30, 120], weight_type='fixed'):
        super(HierarchyFlow, self).__init__()
        self.pad_size = pad_size
        self.num_block = len(out_channels)
        self.in_channels = [in_channel]
        self.out_channels = out_channels
        self.padding = torch.nn.ReflectionPad2d(self.pad_size)
        self.blocks = nn.ModuleList()
        for i in range(self.num_block):
            self.blocks.append(HierarchyCoupling(in_channel=self.in_channels[i], out_channel=self.out_channels[i], weight_type=weight_type))
            self.in_channels.append(self.out_channels[i])
        self.style_net = StyleEncoder(n_downsample=2, input_dim=3, dim=64, style_dim=8, norm='none', activ='relu', pad_type='reflect')

    def forward(self, content, style):
        style_feat = self.style_net(style)
        content = self.padding(content)
        b_size, n_channel, height, width = content.shape
        # Forward (analysis) pass through the coupling blocks ...
        for i in range(self.num_block):
            content = self.blocks[i](content)
        # ... then the reverse (synthesis) pass with the encoded style injected.
        for i in range(self.num_block - 1, -1, -1):
            content = self.blocks[i].reverse(content, style_feat)
        # Crop away the reflection padding added above.
        content = content[:, :, self.pad_size:height - self.pad_size, self.pad_size:width - self.pad_size]
        return content
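

# Minimal usage sketch (not part of the original file): runs a single stylization
# pass with random tensors to check shapes. The batch size and the 256x256 input
# resolution are illustrative assumptions, not values taken from the original repo.
if __name__ == '__main__':
    model = HierarchyFlow(pad_size=10, in_channel=3, out_channels=[30, 120], weight_type='fixed')
    content = torch.randn(1, 3, 256, 256)  # content image batch (B, C, H, W)
    style = torch.randn(1, 3, 256, 256)    # style image batch
    with torch.no_grad():
        stylized = model(content, style)
    # The output keeps the content resolution once the padding is cropped away.
    print(stylized.shape)  # expected: torch.Size([1, 3, 256, 256])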