动手学深度学习(Pytorch版)代码实践 -计算机视觉-49风格迁移

49风格迁移

在这里插入图片描述
在这里插入图片描述

读入内容图像:

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 读取内容图像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
plt.imshow(content_img)
plt.show()

在这里插入图片描述

读取风格图像:

# 读取风格图像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
plt.imshow(style_img)
plt.show()

在这里插入图片描述

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 读取内容图像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
# plt.imshow(content_img)
# plt.show()# 读取风格图像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
# plt.imshow(style_img)
# plt.show()# 预处理和后处理
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])# 函数preprocess对输入图像在RGB三个通道分别做标准化,
# 并将结果变换成卷积神经网络接受的输入格式
def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0) # 增加一个通道# 后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。 
# 由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。
def postprocess(img):# img[0] 表示移除批次维度,从批次中提取出第一个图像img = img[0].to(rgb_std.device) # 移除批次维度,并将图像张量移动到与 rgb_std 相同的设备img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 反转标准化过程return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))# ToPILImage() 期望的输入是 [C, H, W] 形式,因此需要再次将张量的通道维度移动到第一个位置。# 抽取图像特征
# 使用基于ImageNet数据集预训练的VGG-19模型
# VGG19包含了19个隐藏层(16个卷积层和3个全连接层)
pretrained_net = torchvision.models.vgg19(pretrained=True)"""一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。
"""
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# net 模型包含了 VGG-19 从第 0 层到第 28 层的所有层
net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])# 由于我们还需要中间层的输出,
# 因此这里我们逐层计算,并保留内容层和风格层的输出
def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles# 对内容图像抽取内容特征
def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Y# 对风格图像抽取风格特征
def get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y# 定义损失函数
# 由内容损失、风格损失和全变分损失3部分组成# 内容损失
# 内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异
# 平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。
def content_loss(Y_hat, Y):# 我们从动态计算梯度的树中分离目标:# 这是一个规定的值,而不是一个变量。return torch.square(Y_hat - Y.detach()).mean()# 风格损失
def gram(X): # 基于风格图像的格拉姆矩阵num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n)def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()# 全变分损失
# 合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 
# 一种常见的去噪方法是全变分去噪total variation denoising
def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())"""
风格转移的损失函数是内容损失、风格损失和总变化损失的加权和。
通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。
"""
content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分别计算内容损失、风格损失和全变分损失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 对所有损失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l# 初始化合成图像
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight# 函数创建了合成图像的模型实例,并将其初始化为图像X
def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer# 训练模型
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)  # 初始化合成图像和优化器scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = lp.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()  # 梯度清零contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)  # 提取特征contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)  # 计算损失l.backward()  # 反向传播计算梯度trainer.step()  # 更新模型参数scheduler.step()  # 更新学习率if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return Xdevice, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
plt.show()

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/38531.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Swift 递归搜索目录中文件的内容,同时支持 Glob 模式和正则表达式

文章目录 前言项目设置查找文件读取CODEOWNERS文件解析规则搜索匹配的文件确定文件所有者输出结果总结前言 如果你新加入一个团队,想要快速的了解团队的领域和团队中拥有的代码库的详细信息。 如果新团队中的代码库在 GitHub / GitLab 中并且你不熟悉代码所有权模型的概念或…

Unity开箱即用的UGUI面板的拖拽移动功能

文章目录 👉一、背景👉二、效果图👉三、原理👉四、核心代码👉五,总结 👉一、背景 之前做PC项目时常常有面板拖拽移动的需求,今天总结封装一下,做成一个随时随地可复用的…

Linux 安装 Redis 教程

优质博文:IT-BLOG-CN 一、准备工作 配置gcc:安装Redis前需要配置gcc: yum install gcc如果配置gcc出现依赖包问题,在安装时提示需要的依赖包版本和本地版本不一致,本地版本过高,出现如下问题&#xff1a…

Windows 11 安装 安卓子系统 (WSA)

How to Install Windows Subsystem for Android (WSA) on Windows 11 新手教程:如何安装Windows 11 安卓子系统 说明 Windows Subsystem for Android 或 WSA 是由 Hyper-V 提供支持的虚拟机,可在 Windows 11 操作系统上运行 Android 应用程序。虽然它需…

python基础_类

在Python中,类(Class)是面向对象编程(OOP)的核心概念之一。类提供了一种创建新对象的模板,这些对象通常被称为类的实例或对象。以下是关于Python类的一些关键点和特性: 定义类 类通过class关键…

ctfshow-web入门-命令执行(web71-web74)

目录 1、web71 2、web72 3、web73 4、web74 1、web71 像上一题那样扫描但是输出全是问号 查看提示:我们可以结合 exit() 函数执行php代码让后面的匹配缓冲区不执行直接退出。 payload: cvar_export(scandir(/));exit(); 同理读取 flag.txt cinclud…

文华财经博易大师盘立方多空波段止损画线指标公式

TT:PERIOD7; EMA120:EMA(C,120); RSV:(CLOSE-LLV(LOW,9))/(HHV(HIGH,9)-LLV(LOW,9))*100; K:SMA(RSV,3,1); D:SMA(K,3,1); J:3*K-2*D; DRAWTEXT(TT&&J<0,L,多),VALIGN0; DRAWTEXT(TT&&J>100,H,空),VALIGN2; IF(TT,EMA(C,60),NULL),RGB(255,255,2…

JavaScript数组对象 , 正则对象 , String对象以及自定义对象介绍

1. Array数组对象 数组对象是使用单独的变量名来存储一系列的值。 1.1创建一个数组 创建一个数组&#xff0c;有三种方法。 【1】常规方式: let 数组名 new Array();【2】简洁方式: 推荐使用 let 数组名 new Array(数值1,数值2,...);【3】字面:在js中创建数组使用中括号…

试用笔记之-收钱吧安卓版演示源代码,收钱吧手机版感受

首先下载&#xff1a; https://download.csdn.net/download/tjsoft/89499105 安卓手机安装 如果有收钱吧帐号输入收钱吧帐号和密码。 如果没有收钱吧帐号点我的注册 登录收钱吧帐号后就可以把手机当成收钱吧POS机用了&#xff0c;还可以扫客服的付款码哦 源代码技术交流QQ:42…

U盘数据恢复实战指南:原因、方案与预防措施

一、引言&#xff1a;U盘数据恢复概述 在数字化时代&#xff0c;U盘作为一种便携式存储设备&#xff0c;广泛应用于个人和企业中。然而&#xff0c;由于各种原因&#xff0c;U盘数据丢失的问题时有发生。U盘数据恢复技术便是在这种情况下应运而生&#xff0c;它帮助用户在数据…

TPS61085非同步650kHz,1.2MHz, 18.5V升压DCDC芯片

1 特点 TPS61085外观和丝印PMKI 2.3 V 至 6 V 输入电压范围 具有 2.0A 开关电流的 18.5V 升压转换器 650kHz/1.2MHz 可选开关频率 可调软启动 热关断 欠压闭锁 8引脚VSSOP封装 8引脚TSSOP封装 2 应用 手持设备 GPS接收器 数码相机 便携式应用 DSL调制解调器 PCMCIA卡 TFT LCD…

Java并发编程基础知识点

目录 Java并发编程基础知识点1、线程&#xff0c;进程概念及二者的关系进程相关概念线程相关概念进程与线程的关系补充小知识点&#xff1a; 2、线程的状态Java线程的状态&#xff1a;Java线程不同状态之间的切换图示 3、Java程序中如何创建线程&#xff1f;①、继承Thread类②…

【漏洞复现】用友NC——文件上传漏洞

声明&#xff1a;本文档或演示材料仅供教育和教学目的使用&#xff0c;任何个人或组织使用本文档中的信息进行非法活动&#xff0c;均与本文档的作者或发布者无关。 文章目录 漏洞描述漏洞复现测试工具 漏洞描述 用友NC是由用友公司开发的一套面向大型企业和集团型企业的管理软…

C语言 | Leetcode C语言题解之第207题课程表

题目&#xff1a; 题解&#xff1a; bool canFinish(int numCourses, int** prerequisites, int prerequisitesSize, int* prerequisitesColSize) {int** edges (int**)malloc(sizeof(int*) * numCourses);for (int i 0; i < numCourses; i) {edges[i] (int*)malloc(0);…

【Linux】部署NFS服务实现数据共享

&#x1f468;‍&#x1f393;博主简介 &#x1f3c5;CSDN博客专家   &#x1f3c5;云计算领域优质创作者   &#x1f3c5;华为云开发者社区专家博主   &#x1f3c5;阿里云开发者社区专家博主 &#x1f48a;交流社区&#xff1a;运维交流社区 欢迎大家的加入&#xff01…

4-数据提取方法2(xpath和lxml)(6节课学会爬虫)

4-数据提取方法2&#xff08;xpath和lxml&#xff09;&#xff08;6节课学会爬虫&#xff09; 1&#xff0c;Xpath语法&#xff1a;&#xff08;1&#xff09;选择节点&#xff08;标签&#xff09;&#xff08;2&#xff09;“//”:能从任意节点开始选择&#xff08;3&#xf…

计算机网络面试TCP篇之TCP三次握手与四次挥手

TCP 三次握手与四次挥手面试题 任 TCP 虐我千百遍&#xff0c;我仍待 TCP 如初恋。 巨巨巨巨长的提纲&#xff0c;发车&#xff01;发车&#xff01; PS&#xff1a;本次文章不涉及 TCP 流量控制、拥塞控制、可靠性传输等方面知识&#xff0c;这些知识在这篇&#xff1a; TCP …

第3章:数据结构

树 对稀疏矩阵的压缩方法有三种&#xff1a; 1、三元组顺序表 2、行逻辑连接的顺序表 3、十字链表 同义词才会占用同个位置&#xff0c;从而需要进行多次比较。这些关键字的第一个可以不是e的同义词&#xff0c;可以是排在e之前的关键字正好占了那个位置。 Dijkstra算法主要特点…

jenkins 发布服务到 windows服务器

1.环境准备 1.1 这些就不过多描述了&#xff0c;可以参考我的另一盘文章部署到linux。 jenkins 发布服务到linux服务器-CSDN博客 1.2 需要再windows上安装openssh 地址&#xff1a;Releases PowerShell/Win32-OpenSSH GitHub 到windows上执行安装&#xff0c;可以里面cmd命令…

Java的限制序列化和常用IO流

一.限制序列化 a.问题 出于安全考虑&#xff0c;对于一些比较敏感的信息(如用户密码)&#xff0c;应限制被序列化&#xff0c;如何实现? ◆使用transient关键字修改不需要序列化的对象属性 b.示例 ◆希望Person类对象中的年龄信息不被序列化 二.Java常用IO流有哪些&#x…