100字范文,内容丰富有趣,生活中的好帮手!
100字范文 > 【学习笔记】pytorch-图像风格迁移实战

【学习笔记】pytorch-图像风格迁移实战

时间:2021-02-23 01:40:07

相关推荐

【学习笔记】pytorch-图像风格迁移实战

1. 风格迁移入门

我们首先选取一幅图像作为基准图像,也可以将其称为内容图像,然后选取另一幅或多幅图像作为我们希望获取相应风格的图像,也可以将其称为风格图像。图像风格迁移算法就是在保证内容图像的完整性的前提下,将风格图像的风格融入内容图像中,使得内容图像的原始风格最后发生了转变,最终输出的图像呈现的将是输入的内容图像的内容和风格图像风格之间的理想融合。

图像风格迁移实现的难点就是如何有效地提取一张图像的风格。和传统的图像风格提取方法不同,我们在基于神经网络的图像风格迁移算法中使用卷积神经网络来完成对图像风格的提取。

2. pytorch实战

首先我们需要获取一张内容图片和一张风格图片;然后定义两个度量值,一个度量叫做内容度量值,另一个叫做风格度量值,其中的内容度量值用于衡量图片之间的内容差异程度,风格度量值用于衡量图片之间的风格差异程度;最后,简历神经网络模型,对内容图片中的内容和风格图片的风格进行提取,以内容图片为基准将其输入建立的模型中,并不断调整内容度量值和风格度量值,让他们趋近最小,最后输出的图片就是内容与风格融合的图片。

2.1 图像的内容损失

内容度量值可以使用均方误差作为损失函数,在代码中定义的图像内容损失如下:

class Content_loss(torch.nn.Module):def __init__(self,weight,targrt):super(Content_loss, self).__init__()self.weight = weightself.target = target.detach()*weightself.loss_fn = torch.nn.MSELoss()def forward(self, input):self.loss = self.loss_fn(input*self.weight, self.target)return inputdef backward(self):self.loss.backward(retain_graph = True)return self.loss

以上代码中的target是通过卷积获取到的输入图像中的内容,weight是我们设置的一个权重参数,用来控制内容和风格对最后合成图像的影响程度;input代表输入图像,target.detach()用于对提取到的内容进行锁定,不需要进行梯度;forward函数用于计算输入图像和内容图像之间的损失值;backward函数根据计算得到的损失值进行后向传播,并返回损失值。

2.2 图像的风格损失

风格度量同样使用均方误差作为损失函数,代码如下:

class Style_loss(torch.nn.Module):def __init__(self,weight,target):super(Style_loss, self).__init__():self.weight = weightself.target = targetself.loss_fn = torch.nn.MSELoss()self.gram = Gram_matrix()def forward(self,input):self.Gram = self.gram(input.clone())self.Gram.mul_(self.weight)self.loss = self.loss_fn(self.Gram, self.target)return inputdef backward(self):self.loss.backward(retain_graph = True)return self.loss

风格损失计算的代码基本和内容损失计算的代码相似,不同之处是在代码中引入了一个Gram_matri类定义的实例参与风格损失的计算,这个类的代码如下:

class Gram_matrix(torch.nn.Module):def forward(self,input):a,b,c,d = input.size()feature = input.view(a*b,c*d)gram = torch.mm(feature, feature.t())return gram.div(a*b*c*d

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。