Pytorch Advanced(三) Neural Style Transfer

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本。

这里用pytorch重新实现一次,原理图如下:


from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as npdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载图像

def load_image(image_path, transform=None, max_size=None, shape=None):"""Load an image and convert it to a torch tensor."""image = Image.open(image_path)if max_size:scale = max_size / max(image.size)size = np.array(image.size) * scaleimage = image.resize(size.astype(int), Image.ANTIALIAS)if shape:image = image.resize(shape, Image.LANCZOS)if transform:image = transform(image).unsqueeze(0)return image.to(device)

这里用的模型是 VGG-19,所要用的是网络中的5个卷积层

class VGGNet(nn.Module):def __init__(self):"""Select conv1_1 ~ conv5_1 activation maps."""super(VGGNet, self).__init__()self.select = ['0', '5', '10', '19', '28'] self.vgg = models.vgg19(pretrained=True).featuresdef forward(self, x):"""Extract multiple convolutional feature maps."""features = []for name, layer in self.vgg._modules.items():x = layer(x)if name in self.select:features.append(x)return features

 模型结构如下,可以看到使用序列模型来写的VGG-NET,所以标号即层号,我们要保存的是['0', '5', '10', '19', '28'] 的输出结果。

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace)(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(24): ReLU(inplace)(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(26): ReLU(inplace)(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(31): ReLU(inplace)(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(33): ReLU(inplace)(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(35): ReLU(inplace)(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

 训练:

接下来对训练过程进行解释:

1、加载风格图像和内容图像,我们在之前的博客中使用的一幅加噪图进行训练,这里是用的内容图像的拷贝。

2、我们需要优化的就是作为目标的内容图像拷贝,可以看到target需要求导。

3、VGGnet参数是不需要优化的,所以设置为验证状态。

4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

6、风格损失:同样计算格拉姆风格矩阵,将每一层的风格损失叠加,得到总的风格损失,计算公式同样和keras版本有所不一样

7、反向传播

def main(config):# Image preprocessing# VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].# We use the same normalization statistics here.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])# Load content and style images# Make the style image same size as the content imagecontent = load_image(config.content, transform, max_size=config.max_size)style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])# Initialize a target image with the content imagetarget = content.clone().requires_grad_(True)optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])vgg = VGGNet().to(device).eval()for step in range(config.total_step):# Extract multiple(5) conv feature vectorstarget_features = vgg(target)content_features = vgg(content)style_features = vgg(style)style_loss = 0content_loss = 0for f1, f2, f3 in zip(target_features, content_features, style_features):# Compute content loss with target and content imagescontent_loss += torch.mean((f1 - f2)**2)# Reshape convolutional feature maps_, c, h, w = f1.size()f1 = f1.view(c, h * w)f3 = f3.view(c, h * w)# Compute gram matrixf1 = torch.mm(f1, f1.t())f3 = torch.mm(f3, f3.t())# Compute style loss with target and style imagesstyle_loss += torch.mean((f1 - f3)**2) / (c * h * w) # Compute total loss, backprop and optimizeloss = content_loss + config.style_weight * style_loss optimizer.zero_grad()loss.backward()optimizer.step()if (step+1) % config.log_step == 0:print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' .format(step+1, config.total_step, content_loss.item(), style_loss.item()))if (step+1) % config.sample_step == 0:# Save the generated imagedenorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))img = target.clone().squeeze()img = denorm(img).clamp_(0, 1)torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

写在if __name__=="__main__"后面的语句只会在本脚本中才能被执行,被调用时是不会被执行的。 

python的命令行工具:argparse,很优雅的添加参数

但是由于jupyter不支持添加外部参数,所以使用了外部博客的方法来支持(记住更改读取图片的位置)

import sys
if __name__ == "__main__":#解决方案来自于博客if '-f' in sys.argv:sys.argv.remove('-f')parser = argparse.ArgumentParser()parser.add_argument('--content', type=str, default='png/content.png')parser.add_argument('--style', type=str, default='png/style.png')parser.add_argument('--max_size', type=int, default=400)parser.add_argument('--total_step', type=int, default=2000)parser.add_argument('--log_step', type=int, default=10)parser.add_argument('--sample_step', type=int, default=500)parser.add_argument('--style_weight', type=float, default=100)parser.add_argument('--lr', type=float, default=0.003)#config = parser.parse_args()config = parser.parse_known_args()[0]   #参考博客 https://blog.csdn.net/ken_for_learning/article/details/89675904print(config)main(config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/76293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Json“牵手”亚马逊商品详情数据方法,亚马逊商品详情API接口,亚马逊API申请指南

亚马逊平台是美国最大的一家网络电子商务公司,亚马逊公司是1995年成立,刚开始只做网上书籍售卖业务,后来扩展到了其他产品。现在已经是全世界商品品种最多的网上零售商和第二互联网公司,亚马逊是北美洲、欧洲等地区的主流购物平台…

数据结构:线性表之-循环双向链表(万字详解)

目录 基本概念 1,什么是双向链表 2,与单向链表的区别 双向链表详解 功能展示: 1. 定义链表 2,创建双向链表 3,初始化链表 4,尾插 5,头插 6,尾删 判断链表是否被删空 尾删代码 7&a…

我们这一代人的机会是什么?

大家好,我是苍何,今天作为专业嘉宾参观了 2023 年中国国际智能产业博览会(智博会),是一场以「智汇八方,博采众长」为主题的汇聚全球智能技术和产业创新的盛会,感触颇深,随着中国商业…

9月11日作业

思维导图 代码 #include <iostream> #include<string.h>using namespace std;class myString { private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度 public://无参构造myString():size(10){str new char[size]; …

淘宝京东扣库存怎么实现的

1. 使用kv存储实时的库存&#xff0c;直接在kv里扣减&#xff0c;避免用分布式锁 2. 不要先查再扣&#xff0c;直接扣扣扣&#xff0c;扣到负数&#xff0c;&#xff08;增改就直接在kv里做&#xff09;&#xff0c;就说明超卖了&#xff0c;回滚刚才的扣减 3. 同时写MQ&…

JVM类加载机制

目录 一、Java为什么是一种跨平台的语言&#xff1f; 二、Java代码的执行流程 解释执行为主&#xff0c;编译执行为辅&#xff1a; 三、类加载的过程 3.1、加载 类加载器&#xff08;就是加载类的&#xff09;分为&#xff1a; 3.1.1、启动类加载器&#xff08;Bootstrap…

UMA 2 - Unity Multipurpose Avatar☀️三.给UMA设置默认服饰Recipes

文章目录 🟥 项目基础配置🟧 给UMA配置默认服饰Recipes🟨 设置服饰Recipes属性🟥 项目基础配置 将 UMA_DCS 预制体放到场景中创建空物体,添加DynamicCharacterAvatar 脚本,选择 HumanMaleDCS作为我们的基本模型配置默认Animator 🟧 给UMA配置默认服饰Recipes 服饰Re…

回归预测 | MATLAB实现PCA-BP主成分降维结合BP神经网络多输入单输出回归预测

回归预测 | MATLAB实现PCA-BP主成分降维结合BP神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现PCA-BP主成分降维结合BP神经网络多输入单输出回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 MATLAB实现PCA-BP主成分降维算法结合BP神经网络多输入单输出回…

【数据结构】串

串 串的顺序实现简单的模式匹配算法KMP算法KMP算法的进一步优化 串的顺序实现 初始化 #define MaxSize 50 typedef char ElemType;//顺序存储表示 typedef struct{ElemType data[MaxSize];int length; }SString;/*** 初始化串*/ void InitString(SString *string) {for (int …

Cmake入门(一文读懂)

目录 1、Cmake简介2、安装CMake3、CMakeLists.txt4、单目录简单实例4.1、CMakeLists.txt4.2、构建bulid内部构建外部构建 4.3、运行C语言程序 5、多目录文件简单实例5.1、根目录CMakeLists.txt5.2、源文件目录5.3、utils.h5.4、创建build 6、生成库文件和链接外部库文件7、注意…

Mysql5.7(Docker环境)实现主从复制

文章目录 前言一、MySQL主从数据库同步如何实现&#xff1f;(理论)1.1 为什么要使用数据库主从1.2 数据库主从实现原理是什么&#xff1f; 二、Docker环境配置MySQL5.7主从(实践)2.1 配置安装Master2.2 配置安装Slave 前言 本文章将以MySQL5.7版本来讲诉MySQL主从复制的原理以…

【计算机网络】UDP数据包是如何在网络中传输的?

List item 创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 更多计算机网络知识专栏&#xff1a;计算机网络&#x1f525;…

【uni-app】

准备工作&#xff08;Hbuilder&#xff09; 1.下载hbuilder&#xff0c;插件使用Vue3的uni-app项目 2.需要安装编译器 3.下载微信开发者工具 4.点击运行->微信开发者工具 5.打开微信开发者工具的服务端口 效果图 准备工作&#xff08;VScode&#xff09; 插件 uni-cr…

Android窗口层级(Window Type)分析

前言 Android的窗口Window分为三种类型&#xff1a; 应用Window&#xff0c;比如Activity、Dialog&#xff1b;子Window&#xff0c;比如PopupWindow&#xff1b;系统Window&#xff0c;比如Toast、系统状态栏、导航栏等等。 应用Window的Z-Ordered最低&#xff0c;就是在系…

Codeforces Round 827 (Div. 4) D 1e5+双重for循环技巧

Codeforces Round 827 (Div. 4) D 做题链接&#xff1a;Codeforces Round 827 (Div. 4) 给定一个由 n个正整数 a1,a2,…,an&#xff08;1≤ai≤1000&#xff09;组成的数组。求ij的最大值&#xff0c;使得ai和aj共质&#xff0c;否则−1&#xff0c;如果不存在这样的i&#…

Jetsonnano B01 笔记3:GPIO上拉下拉-输入输出读取

今日继续我的jetsonnano学习之路&#xff0c;今日学习的是GPIO的上拉下拉&#xff0c;输入输出的读取&#xff0c;文章贴出完整操作步骤过程&#xff0c;贴出源码。 目录 Linux常用文件命令&#xff1a; ls&#xff08;list&#xff09;列表&#xff1a; man&#xff1a; …

页面页脚部分CSS分享

先看效果&#xff1a; CSS部分&#xff1a;&#xff08;查看更多&#xff09; <style>body {display: grid;grid-template-rows: 1fr 10rem auto;grid-template-areas: "main" "." "footer";overflow-x: hidden;background: #F5F7FA;min…

【微服务】五. Nacos服务注册

Nacos服务注册 5.1 Nacos服务分级存储模型Nacos服务分级存储模型&#xff1a;服务集群属性&#xff1a;总结&#xff1a; 5.2 根据集群负载均衡总结 5.3 Nacos服务实例的权重设置总结&#xff1a; 5.6 环境隔离namespace总结 5.7 Nacos和Eureka的对比总结 5.1 Nacos服务分级存储…

科技云报道:AI时代,对构建云安全提出了哪些新要求?

科技云报道原创。 随着企业上云的提速&#xff0c;一系列云安全问题也逐渐暴露出来&#xff0c;云安全问题得到重视&#xff0c;市场不断扩大。 Gartner 发布“2022 年中国 ICT 技术成熟度曲线”显示&#xff0c;云安全已处于技术萌芽期高点&#xff0c;预期在2-5年内有望达到…

Material Design系列探究之LinearLayoutCompat

谷歌Material Design推出了许多非常好用的控件&#xff0c;所以我决定写一个专题来讲述MaterialDesign&#xff0c;今天带来Material Design系列的第一弹 LinearLayoutCompat。 以前要在LinearLayout布局之间的子View之间添加分割线&#xff0c;还需要自己去自定义控件进行添加…