Pytorch Advanced(三) Neural Style Transfer

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本。

这里用pytorch重新实现一次,原理图如下:


from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as npdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载图像

def load_image(image_path, transform=None, max_size=None, shape=None):"""Load an image and convert it to a torch tensor."""image = Image.open(image_path)if max_size:scale = max_size / max(image.size)size = np.array(image.size) * scaleimage = image.resize(size.astype(int), Image.ANTIALIAS)if shape:image = image.resize(shape, Image.LANCZOS)if transform:image = transform(image).unsqueeze(0)return image.to(device)

这里用的模型是 VGG-19,所要用的是网络中的5个卷积层

class VGGNet(nn.Module):def __init__(self):"""Select conv1_1 ~ conv5_1 activation maps."""super(VGGNet, self).__init__()self.select = ['0', '5', '10', '19', '28'] self.vgg = models.vgg19(pretrained=True).featuresdef forward(self, x):"""Extract multiple convolutional feature maps."""features = []for name, layer in self.vgg._modules.items():x = layer(x)if name in self.select:features.append(x)return features

 模型结构如下,可以看到使用序列模型来写的VGG-NET,所以标号即层号,我们要保存的是['0', '5', '10', '19', '28'] 的输出结果。

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace)(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(24): ReLU(inplace)(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(26): ReLU(inplace)(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(31): ReLU(inplace)(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(33): ReLU(inplace)(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(35): ReLU(inplace)(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

 训练:

接下来对训练过程进行解释:

1、加载风格图像和内容图像,我们在之前的博客中使用的一幅加噪图进行训练,这里是用的内容图像的拷贝。

2、我们需要优化的就是作为目标的内容图像拷贝,可以看到target需要求导。

3、VGGnet参数是不需要优化的,所以设置为验证状态。

4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

6、风格损失:同样计算格拉姆风格矩阵,将每一层的风格损失叠加,得到总的风格损失,计算公式同样和keras版本有所不一样

7、反向传播

def main(config):# Image preprocessing# VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].# We use the same normalization statistics here.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])# Load content and style images# Make the style image same size as the content imagecontent = load_image(config.content, transform, max_size=config.max_size)style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])# Initialize a target image with the content imagetarget = content.clone().requires_grad_(True)optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])vgg = VGGNet().to(device).eval()for step in range(config.total_step):# Extract multiple(5) conv feature vectorstarget_features = vgg(target)content_features = vgg(content)style_features = vgg(style)style_loss = 0content_loss = 0for f1, f2, f3 in zip(target_features, content_features, style_features):# Compute content loss with target and content imagescontent_loss += torch.mean((f1 - f2)**2)# Reshape convolutional feature maps_, c, h, w = f1.size()f1 = f1.view(c, h * w)f3 = f3.view(c, h * w)# Compute gram matrixf1 = torch.mm(f1, f1.t())f3 = torch.mm(f3, f3.t())# Compute style loss with target and style imagesstyle_loss += torch.mean((f1 - f3)**2) / (c * h * w) # Compute total loss, backprop and optimizeloss = content_loss + config.style_weight * style_loss optimizer.zero_grad()loss.backward()optimizer.step()if (step+1) % config.log_step == 0:print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' .format(step+1, config.total_step, content_loss.item(), style_loss.item()))if (step+1) % config.sample_step == 0:# Save the generated imagedenorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))img = target.clone().squeeze()img = denorm(img).clamp_(0, 1)torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

写在if __name__=="__main__"后面的语句只会在本脚本中才能被执行,被调用时是不会被执行的。 

python的命令行工具:argparse,很优雅的添加参数

但是由于jupyter不支持添加外部参数,所以使用了外部博客的方法来支持(记住更改读取图片的位置)

import sys
if __name__ == "__main__":#解决方案来自于博客if '-f' in sys.argv:sys.argv.remove('-f')parser = argparse.ArgumentParser()parser.add_argument('--content', type=str, default='png/content.png')parser.add_argument('--style', type=str, default='png/style.png')parser.add_argument('--max_size', type=int, default=400)parser.add_argument('--total_step', type=int, default=2000)parser.add_argument('--log_step', type=int, default=10)parser.add_argument('--sample_step', type=int, default=500)parser.add_argument('--style_weight', type=float, default=100)parser.add_argument('--lr', type=float, default=0.003)#config = parser.parse_args()config = parser.parse_known_args()[0]   #参考博客 https://blog.csdn.net/ken_for_learning/article/details/89675904print(config)main(config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/76293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零碎的c++二

虚函数 虚函数是C中实现多态的一种机制,它允许通过基类指针或引用来调用派生类的成员函数。虚函数的作用是实现动态绑定,即在运行时根据对象的实际类型来确定调用哪个函数。虚函数的声明方式是在函数前加上关键字virtual,如: cl…

Json“牵手”亚马逊商品详情数据方法,亚马逊商品详情API接口,亚马逊API申请指南

亚马逊平台是美国最大的一家网络电子商务公司,亚马逊公司是1995年成立,刚开始只做网上书籍售卖业务,后来扩展到了其他产品。现在已经是全世界商品品种最多的网上零售商和第二互联网公司,亚马逊是北美洲、欧洲等地区的主流购物平台…

数据结构:线性表之-循环双向链表(万字详解)

目录 基本概念 1,什么是双向链表 2,与单向链表的区别 双向链表详解 功能展示: 1. 定义链表 2,创建双向链表 3,初始化链表 4,尾插 5,头插 6,尾删 判断链表是否被删空 尾删代码 7&a…

我们这一代人的机会是什么?

大家好,我是苍何,今天作为专业嘉宾参观了 2023 年中国国际智能产业博览会(智博会),是一场以「智汇八方,博采众长」为主题的汇聚全球智能技术和产业创新的盛会,感触颇深,随着中国商业…

JVM相关知识点

Java可以跨平台的原因 Java可以跨平台的原因是因为它使用了Java虚拟机(JVM)作为中间层。Java源代码首先被编译成字节码,然后由JVM解释执行或即时编译成本地机器代码。这样,在不同的操作系统上,只需要安装适合该操作系…

9月11日作业

思维导图 代码 #include <iostream> #include<string.h>using namespace std;class myString { private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度 public://无参构造myString():size(10){str new char[size]; …

从零学算法2849

2849.给你四个整数 sx、sy、fx、fy 以及一个 非负整数 t 。 在一个无限的二维网格中&#xff0c;你从单元格 (sx, sy) 开始出发。每一秒&#xff0c;你 必须 移动到任一与之前所处单元格相邻的单元格中。 如果你能在 恰好 t 秒 后到达单元格 (fx, fy) &#xff0c;返回 true &a…

目标检测YOLO实战应用案例100讲-毫米波辐射图像去模糊重建与目标检测(续)

目录 3.3基于RSRN模型的毫米波辐射图像去模糊重建方法 3.3.2非线性映射 3.3.3多尺度模糊提取

淘宝京东扣库存怎么实现的

1. 使用kv存储实时的库存&#xff0c;直接在kv里扣减&#xff0c;避免用分布式锁 2. 不要先查再扣&#xff0c;直接扣扣扣&#xff0c;扣到负数&#xff0c;&#xff08;增改就直接在kv里做&#xff09;&#xff0c;就说明超卖了&#xff0c;回滚刚才的扣减 3. 同时写MQ&…

vue中打印指定dom元素

和window.print()效果一样&#xff0c;调出打印窗口&#xff0c;只是当前使用的插件是vue-print-nb 官网地址&#xff1a;vue-print-nb vue2中使用 安装插件 npm install vue-print-nb --save导入插件 import Print from vue-print-nb // 全局使用 Vue.use(Print);//or// 单…

如何确保ChatGPT的文本生成对特定行业术语的正确使用?

确保ChatGPT在特定行业术语的正确使用是一个重要而复杂的任务。这涉及到许多方面&#xff0c;包括数据预处理、模型训练、微调、评估和监控。下面我将详细介绍如何确保ChatGPT的文本生成对特定行业术语的正确使用&#xff0c;并探讨这一过程中的关键考虑因素。 ### 1. 数据预处…

JVM类加载机制

目录 一、Java为什么是一种跨平台的语言&#xff1f; 二、Java代码的执行流程 解释执行为主&#xff0c;编译执行为辅&#xff1a; 三、类加载的过程 3.1、加载 类加载器&#xff08;就是加载类的&#xff09;分为&#xff1a; 3.1.1、启动类加载器&#xff08;Bootstrap…

UMA 2 - Unity Multipurpose Avatar☀️三.给UMA设置默认服饰Recipes

文章目录 🟥 项目基础配置🟧 给UMA配置默认服饰Recipes🟨 设置服饰Recipes属性🟥 项目基础配置 将 UMA_DCS 预制体放到场景中创建空物体,添加DynamicCharacterAvatar 脚本,选择 HumanMaleDCS作为我们的基本模型配置默认Animator 🟧 给UMA配置默认服饰Recipes 服饰Re…

回归预测 | MATLAB实现PCA-BP主成分降维结合BP神经网络多输入单输出回归预测

回归预测 | MATLAB实现PCA-BP主成分降维结合BP神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现PCA-BP主成分降维结合BP神经网络多输入单输出回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 MATLAB实现PCA-BP主成分降维算法结合BP神经网络多输入单输出回…

Linux命令(78)之read

linux命令之read 1.read介绍 linux命令read用来接收键盘或其它文件的输入&#xff0c;得到输入后&#xff0c;read命令将接收的数据放入到标准变量中。 2.read用法 read [参数] [变量名称] read常用参数 参数说明-p后面跟提示信息-e可以使用命令补全功能-n输入文本的长度-s…

【数据结构】串

串 串的顺序实现简单的模式匹配算法KMP算法KMP算法的进一步优化 串的顺序实现 初始化 #define MaxSize 50 typedef char ElemType;//顺序存储表示 typedef struct{ElemType data[MaxSize];int length; }SString;/*** 初始化串*/ void InitString(SString *string) {for (int …

点云从入门到精通技术详解100篇-基于车载激光点云的道路标线提取及分类方法

目录 前言 车载 LiDAR 技术基础理论 2.1 车载 LiDAR 系统组成 2.2 车载 LiDAR 系统工作原理

Python基础continue和break关键字

continue 和 break 关键字 continue 含义&#xff1a;表示跳过本次循环&#xff0c;继续下次循环 注意&#xff1a;continue在while循环中不可以使用 例子&#xff1a; for i in range(5): if i 3: continue else: print(i) #结果为0,1,2,4 当i 3的时候&#xff0c;跳…

变压器耐压试验电压及电源容量的计算

被试变压器的额定电压为&#xff08;11081. 25%&#xff09; /10. 5kV&#xff0c; 联接组标号为 YNd11。 试验时高压分接开关置于第 1 分接位置&#xff0c; 即高压侧电压为 126kV&#xff0c; 高、 低压电压比 K1126/&#xff08;√310. 5&#xff09; 6. 93。 现以 A 相试验…

Cmake入门(一文读懂)

目录 1、Cmake简介2、安装CMake3、CMakeLists.txt4、单目录简单实例4.1、CMakeLists.txt4.2、构建bulid内部构建外部构建 4.3、运行C语言程序 5、多目录文件简单实例5.1、根目录CMakeLists.txt5.2、源文件目录5.3、utils.h5.4、创建build 6、生成库文件和链接外部库文件7、注意…