from paddle.vision.models.vgg import make_layers, cfgs
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
import paddle as P
import numpy as np
import PIL
from paddle.vision import transforms as T
from matplotlib import pyplot as plt
import paddle
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import MutableMapping /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Iterable, Mapping /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Sized
class ConvBlock(nn.Layer):
    """One VGG-style stage: up to four 3x3 conv + ReLU layers, then 2x2 pooling.

    Args:
        input_channels: Channel count of the tensor fed to the first conv.
        output_channels: Channel count produced by every conv in the stage.
        groups: Number of conv layers in the stage. Values 2, 3 and 4 add
            ``conv2`` .. ``conv4``; any other value leaves only ``conv1``.
        name: Accepted for call-site compatibility; not used.
    """

    def __init__(self, input_channels, output_channels, groups, name=None):
        super(ConvBlock, self).__init__()
        self.groups = groups

        def _conv3x3(in_channels):
            # Every conv in the stage shares this configuration; only the
            # number of input channels differs for the first one.
            return Conv2D(
                in_channels=in_channels,
                out_channels=output_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias_attr=True)

        self.conv1 = _conv3x3(input_channels)
        if groups in (2, 3, 4):
            self.conv2 = _conv3x3(output_channels)
        if groups in (3, 4):
            self.conv3 = _conv3x3(output_channels)
        if groups == 4:
            self.conv4 = _conv3x3(output_channels)
        # The attribute keeps its historical name ``max_pool`` so existing
        # references still resolve, but the layer is an average pool.
        self.max_pool = AvgPool2D(kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Run the stage.

        Returns:
            A tuple ``(pooled, conv1, conv2)`` where ``pooled`` is the stage
            output after pooling, and ``conv1`` / ``conv2`` are the outputs of
            the first and second conv layers *before* ReLU. ``conv2`` is
            ``None`` when the stage has no second conv layer.
        """
        conv1 = self.conv1(inputs)
        x = self.relu(conv1)
        # Default for single-conv stages; previously this name was unbound and
        # the return statement raised UnboundLocalError.
        conv2 = None
        if self.groups in (2, 3, 4):
            conv2 = self.conv2(x)
            x = self.relu(conv2)
        if self.groups in (3, 4):
            x = self.relu(self.conv3(x))
        if self.groups == 4:
            x = self.relu(self.conv4(x))
        x = self.max_pool(x)
        return x, conv1, conv2
class VGGNet(nn.Layer):
    """Five-stage convolutional feature extractor with 2, 2, 4, 4, 4 convs per stage.

    Args:
        stop_grad_layers: Number of leading stages whose parameters are marked
            non-trainable. The default of 5 freezes every stage.
    """

    def __init__(self, stop_grad_layers=5):
        # Initialise the Layer machinery before assigning any attribute.
        super(VGGNet, self).__init__()
        self.stop_grad_layers = stop_grad_layers
        self.groups = [2, 2, 4, 4, 4]
        self.conv_block_1 = ConvBlock(3, 64, self.groups[0])
        self.conv_block_2 = ConvBlock(64, 128, self.groups[1])
        self.conv_block_3 = ConvBlock(128, 256, self.groups[2])
        self.conv_block_4 = ConvBlock(256, 512, self.groups[3])
        self.conv_block_5 = ConvBlock(512, 512, self.groups[4])

        stages = [
            self.conv_block_1, self.conv_block_2, self.conv_block_3,
            self.conv_block_4, self.conv_block_5
        ]
        for position, block in enumerate(stages, start=1):
            # Freeze stage ``position`` when it lies within the first
            # ``stop_grad_layers`` stages.
            if self.stop_grad_layers >= position:
                for param in block.parameters():
                    param.trainable = False

    def forward(self, inputs):
        """Return the feature maps used by the style-transfer losses.

        Returns:
            ``[conv4_2, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1]``: the
            second conv output of stage 4, followed by the first conv output
            of stages 1 through 5, each as returned by ``ConvBlock.forward``.
        """
        x, conv1_1, _ = self.conv_block_1(inputs)
        x, conv2_1, _ = self.conv_block_2(x)
        x, conv3_1, _ = self.conv_block_3(x)
        x, conv4_1, conv4_2 = self.conv_block_4(x)
        # The pooled output of the last stage is not needed.
        _, conv5_1, _ = self.conv_block_5(x)
        return [conv4_2, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1]
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only use this with a weights file from a trusted source.
npy = np.load('work/vgg19.npy', encoding='latin1', allow_pickle=True)
my19 = VGGNet()

# The archive is consumed as consecutive (weight, bias) pairs, walking the
# stages in order and the conv layers inside each stage in order.
# Presumably the stored kernels are (height, width, in, out); the transpose
# moves them to (out, in, height, width) -- TODO confirm against the file.
param_index = 0
for stage_no, conv_count in enumerate((2, 2, 4, 4, 4), start=1):
    stage_block = getattr(my19, f'conv_block_{stage_no}')
    for conv_no in range(1, conv_count + 1):
        conv_layer = getattr(stage_block, f'conv{conv_no}')
        conv_layer.set_dict({'weight': npy[param_index].transpose([3, 2, 0, 1])})
        conv_layer.set_dict({'bias': npy[param_index + 1]})
        param_index += 2
def load_img(path_to_img, max_dim=512):
    """Load an image file as a float32 tensor shaped (1, channels, height, width).

    Pixel values are scaled to [0, 1] and the image is resized so that its
    longer side equals ``max_dim`` while keeping the aspect ratio (the shorter
    side is truncated to an integer).

    Args:
        path_to_img: Path of the image file, read with ``plt.imread``.
        max_dim: Target length of the longer image side, in pixels.

    Returns:
        The resized image tensor with a leading batch axis of size 1.
    """
    raw = plt.imread(path_to_img)
    # Drop an alpha channel if present; the network expects three channels.
    if raw.ndim == 3 and raw.shape[2] == 4:
        raw = raw[:, :, :3]

    img = P.to_tensor(raw, P.float32)
    # plt.imread yields integer arrays for most formats but floats already in
    # [0, 1] for PNG, so only integer data is rescaled (uint8 -> divide by 255).
    if np.issubdtype(raw.dtype, np.integer):
        img = img / float(np.iinfo(raw.dtype).max)

    # (height, width, channels) -> (channels, height, width)
    img = img.transpose([2, 0, 1])

    height, width = img.shape[1:]
    scale = max_dim / max(height, width)
    new_size = [int(height * scale), int(width * scale)]

    img = T.Resize(new_size)(img)
    return img.unsqueeze(0)
def imshow(image, title=None):
    """Draw a channels-first image on the current matplotlib axes.

    Args:
        image: Image with channels on the first axis; inputs with more than
            three axes are squeezed first to drop size-1 axes.
        title: Optional title; nothing is set when it is falsy.
    """
    has_extra_axes = len(image.shape) > 3
    frame = image.squeeze() if has_extra_axes else image
    # matplotlib wants channels last.
    plt.imshow(frame.transpose([1, 2, 0]))
    if title:
        plt.title(title)
# Load the two inputs of the style transfer.
content_image = load_img('work/YellowLabradorLooking_new.jpg')
style_image = load_img('work/sky.jpeg')  # alternative style: 'work/kandinsky5.jpg'

# Preview them side by side: content on the left, style on the right.
previews = ((content_image, 'Content Image'), (style_image, 'Style Image'))
for slot, (picture, caption) in enumerate(previews, start=1):
    plt.subplot(1, 2, slot)
    imshow(picture, caption)

print(f'[info]content shape:{content_image.shape} style shape:{style_image.shape}')
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working if isinstance(obj, collections.Iterator): /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working return list(data) if isinstance(data, collections.MappingView) else data
[info]content shape:[1, 3, 422, 512] style shape:[1, 3, 320, 512]
# Imports repeated by this notebook cell.
from paddle import optimizer
from paddle.vision import transforms as T
from matplotlib import pyplot as plt
from paddle.optimizer import Adam

# Fixed-size resize transform; in this part of the file it is referenced only
# by commented-out code.
h, w = 224, 224
resize = T.Compose([T.Resize((h, w))])

# Run everything that follows on the GPU.
P.set_device('gpu')
CUDAPlace(0)
# Use the loaded images at their native size (no fixed-size resize).
content, style = content_image, style_image

# The optimised image starts as a fresh copy of the content image and is the
# only tensor that receives gradients.
transfer = load_img('work/YellowLabradorLooking_new.jpg')
transfer.stop_gradient = False

optim = Adam(
    learning_rate=0.02,
    beta1=0.99,
    epsilon=1e-1,
    parameters=[transfer])

# Inference mode for the feature extractor, then one forward pass per image
# (style, content, transfer -- in that order) to inspect the feature shapes.
my19.eval()
style_features, conte_features, trans_features = (
    my19(picture) for picture in (style, content, transfer))

print(f'[info]style features:{[x.shape for x in style_features]}')
print(f'[info]conte features:{[x.shape for x in conte_features]}')
print(f'[info]trans features:{[x.shape for x in trans_features]}')
[info]style features:[[1, 512, 40, 64], [1, 64, 320, 512], [1, 128, 160, 256], [1, 256, 80, 128], [1, 512, 40, 64], [1, 512, 20, 32]] [info]conte features:[[1, 512, 52, 64], [1, 64, 422, 512], [1, 128, 211, 256], [1, 256, 105, 128], [1, 512, 52, 64], [1, 512, 26, 32]] [info]trans features:[[1, 512, 52, 64], [1, 64, 422, 512], [1, 128, 211, 256], [1, 256, 105, 128], [1, 512, 52, 64], [1, 512, 26, 32]]
class GramMatrix(nn.Layer):
    """Compute per-sample Gram matrices of a 4-D feature map."""

    def __init__(self):
        super().__init__()

    def forward(self, imgs):
        """Return the channel-by-channel Gram matrix of ``imgs``.

        ``imgs`` must have four axes; the last two are flattened into one,
        the result is multiplied by its own transpose per sample, and the
        product is divided by the flattened length.
        """
        _, channels, height, width = imgs.shape
        positions = height * width
        flat = paddle.reshape(imgs, (-1, channels, positions))
        flat_t = paddle.transpose(flat, [0, 2, 1])
        gram = paddle.bmm(flat, flat_t)
        return gram / positions
class GramMSEloss(nn.Layer):
    """Mean squared error between the Gram matrices of two feature maps."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return ``MSE(Gram(input), Gram(target))``."""
        # The helper layers hold no parameters; they are built per call.
        gram_input = GramMatrix()(input)
        gram_target = GramMatrix()(target)
        mse = nn.MSELoss()
        return mse(gram_input, gram_target)
def high_pass_x_y(image):
    """Return neighbouring-pixel differences along width and height.

    Args:
        image: 4-D array or tensor laid out as (batch, channels, height,
            width), which is the layout ``load_img`` produces.

    Returns:
        ``(x_var, y_var)``: ``x_var`` holds differences between horizontally
        adjacent pixels (last axis, one element shorter); ``y_var`` holds
        differences between vertically adjacent pixels (axis 2, one element
        shorter).
    """
    # The previous slices assumed channels-last data: on channels-first input
    # they differenced across height and across *channels* instead.
    x_var = image[:, :, :, 1:] - image[:, :, :, :-1]
    y_var = image[:, :, 1:, :] - image[:, :, :-1, :]
    return x_var, y_var
def total_variation_loss(image):
    """Return the summed absolute values of both ``high_pass_x_y`` outputs."""
    first_deltas, second_deltas = high_pass_x_y(image)
    first_total = paddle.sum(paddle.abs(first_deltas))
    second_total = paddle.sum(paddle.abs(second_deltas))
    return first_total + second_total
# Per-layer loss weights: one content weight followed by five style weights.
# Other settings that were tried: style 1e-2 each or [0.5, 1, 1e2, 1e3, 1e4];
# content 1e4.
style_weights = [1, 1, 1, 1, 1]
content_weights = [1e-10]
weights = content_weights + style_weights
total_variation_weight = 0.5
debug = False


def _show_transfer(epoch):
    """Display the current state of the optimised image, clipped to [0, 1]."""
    pic = transfer[0].detach().transpose([1, 2, 0])
    pic = paddle.clip(pic, min=0., max=1.)
    im = PIL.Image.fromarray(np.uint8(pic.numpy() * 255))
    plt.imshow(im)
    plt.title(f'epochs:{epoch}')
    plt.show()


# The style and content images never change, so their features -- and the
# loss functions applied to them -- are computed once instead of every step.
style_features = my19(style)
conte_features = my19(content)
# Feature 0 is compared against the content image, the rest against the style.
target_features = conte_features[:1] + style_features[1:]
loss_funcs = [nn.MSELoss()] + [GramMSEloss()] * (len(style_features) - 1)

for e in range(2010):
    optim.clear_grad()
    trans_features = my19(transfer)

    layer_loss = [
        loss_f(t_out, out)
        for loss_f, t_out, out in zip(loss_funcs, trans_features, target_features)
    ]
    if debug:
        print('[info]layer_loss:', [x.numpy().tolist() for x in layer_loss])

    loss_mid = [loss * weight for loss, weight in zip(layer_loss, weights)]
    if debug:
        print(f'[info]layer_loss * weight:{[x.numpy().tolist() for x in loss_mid]}')

    # Content term plus the mean of the style terms.
    content_term = sum(loss_mid[:1])
    style_term = sum(loss_mid[1:]) / len(style_weights)
    loss = content_term + style_term
    if debug:
        print('[info]sum_content:', content_term)
        print('[info]sum_style:', style_term)

    # Total-variation regularisation is currently disabled:
    # loss += total_variation_weight*total_variation_loss(transfer)
    loss.backward()
    optim.step()

    # Preview every 10 steps during the first 100, then every 200 steps
    # after step 200.
    early_preview = e < 100 and e % 10 == 0
    late_preview = e > 200 and e % 200 == 0
    if early_preview or late_preview:
        _show_transfer(e)