动手学深度学习(Pytorch版)代码实践 -计算机视觉-49风格迁移

49风格迁移

在这里插入图片描述
在这里插入图片描述

读入内容图像:

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 读取内容图像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
plt.imshow(content_img)
plt.show()

在这里插入图片描述

读取风格图像:

# 读取风格图像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
plt.imshow(style_img)
plt.show()

在这里插入图片描述

import torch
import torchvision
from torch import nn
import matplotlib.pylab as plt
import liliPytorch as lp
from d2l import torch as d2l# 读取内容图像
content_img = d2l.Image.open('../limuPytorch/images/rainier.jpg')
# plt.imshow(content_img)
# plt.show()# 读取风格图像
style_img = d2l.Image.open('../limuPytorch/images/autumn-oak.jpg')
# plt.imshow(style_img)
# plt.show()# 预处理和后处理
rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])# 函数preprocess对输入图像在RGB三个通道分别做标准化,
# 并将结果变换成卷积神经网络接受的输入格式
def preprocess(img, image_shape):transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(image_shape),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])return transforms(img).unsqueeze(0) # 增加一个通道# 后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。 
# 由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。
def postprocess(img):# img[0] 表示移除批次维度,从批次中提取出第一个图像img = img[0].to(rgb_std.device) # 移除批次维度,并将图像张量移动到与 rgb_std 相同的设备img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1) # 反转标准化过程return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))# ToPILImage() 期望的输入是 [C, H, W] 形式,因此需要再次将张量的通道维度移动到第一个位置。# 抽取图像特征
# 使用基于ImageNet数据集预训练的VGG-19模型
# VGG19包含了19个隐藏层(16个卷积层和3个全连接层)
pretrained_net = torchvision.models.vgg19(pretrained=True)"""一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。 为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。 我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。
"""
style_layers, content_layers = [0, 5, 10, 19, 28], [25]
# net 模型包含了 VGG-19 从第 0 层到第 28 层的所有层
net = nn.Sequential(*[pretrained_net.features[i] for i inrange(max(content_layers + style_layers) + 1)])# 由于我们还需要中间层的输出,
# 因此这里我们逐层计算,并保留内容层和风格层的输出
def extract_features(X, content_layers, style_layers):contents = []styles = []for i in range(len(net)):X = net[i](X)if i in style_layers:styles.append(X)if i in content_layers:contents.append(X)return contents, styles# 对内容图像抽取内容特征
def get_contents(image_shape, device):content_X = preprocess(content_img, image_shape).to(device)contents_Y, _ = extract_features(content_X, content_layers, style_layers)return content_X, contents_Y# 对风格图像抽取风格特征
def get_styles(image_shape, device):style_X = preprocess(style_img, image_shape).to(device)_, styles_Y = extract_features(style_X, content_layers, style_layers)return style_X, styles_Y# 定义损失函数
# 由内容损失、风格损失和全变分损失3部分组成# 内容损失
# 内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异
# 平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。
def content_loss(Y_hat, Y):# 我们从动态计算梯度的树中分离目标:# 这是一个规定的值,而不是一个变量。return torch.square(Y_hat - Y.detach()).mean()# 风格损失
def gram(X): # 基于风格图像的格拉姆矩阵num_channels, n = X.shape[1], X.numel() // X.shape[1]X = X.reshape((num_channels, n))return torch.matmul(X, X.T) / (num_channels * n)def style_loss(Y_hat, gram_Y):return torch.square(gram(Y_hat) - gram_Y.detach()).mean()# 全变分损失
# 合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 
# 一种常见的去噪方法是全变分去噪total variation denoising
def tv_loss(Y_hat):return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())"""
风格转移的损失函数是内容损失、风格损失和总变化损失的加权和。
通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。
"""
content_weight, style_weight, tv_weight = 1, 1e3, 10def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):# 分别计算内容损失、风格损失和全变分损失contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(contents_Y_hat, contents_Y)]styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(styles_Y_hat, styles_Y_gram)]tv_l = tv_loss(X) * tv_weight# 对所有损失求和l = sum(10 * styles_l + contents_l + [tv_l])return contents_l, styles_l, tv_l, l# 初始化合成图像
class SynthesizedImage(nn.Module):def __init__(self, img_shape, **kwargs):super(SynthesizedImage, self).__init__(**kwargs)self.weight = nn.Parameter(torch.rand(*img_shape))def forward(self):return self.weight# 函数创建了合成图像的模型实例,并将其初始化为图像X
def get_inits(X, device, lr, styles_Y):gen_img = SynthesizedImage(X.shape).to(device)gen_img.weight.data.copy_(X.data)trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)styles_Y_gram = [gram(Y) for Y in styles_Y]return gen_img(), styles_Y_gram, trainer# 训练模型
def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)  # 初始化合成图像和优化器scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)animator = lp.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs],legend=['content', 'style', 'TV'],ncols=2, figsize=(7, 2.5))for epoch in range(num_epochs):trainer.zero_grad()  # 梯度清零contents_Y_hat, styles_Y_hat = extract_features(X, content_layers, style_layers)  # 提取特征contents_l, styles_l, tv_l, l = compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)  # 计算损失l.backward()  # 反向传播计算梯度trainer.step()  # 更新模型参数scheduler.step()  # 更新学习率if (epoch + 1) % 10 == 0:animator.axes[1].imshow(postprocess(X))animator.add(epoch + 1, [float(sum(contents_l)),float(sum(styles_l)), float(tv_l)])return Xdevice, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)
plt.show()

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/38531.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Swift 递归搜索目录中文件的内容,同时支持 Glob 模式和正则表达式

文章目录 前言项目设置查找文件读取CODEOWNERS文件解析规则搜索匹配的文件确定文件所有者输出结果总结前言 如果你新加入一个团队,想要快速的了解团队的领域和团队中拥有的代码库的详细信息。 如果新团队中的代码库在 GitHub / GitLab 中并且你不熟悉代码所有权模型的概念或…

Unity开箱即用的UGUI面板的拖拽移动功能

文章目录 👉一、背景👉二、效果图👉三、原理👉四、核心代码👉五,总结 👉一、背景 之前做PC项目时常常有面板拖拽移动的需求,今天总结封装一下,做成一个随时随地可复用的…

More Effective C++ 35个改善编程与设计的有效方法笔记与心得 3

三. 异常 条款9: 利用destructors避免泄露资源 ‌‌‌‌  在编程中,"资源"可以指任何系统级的有限资源,如内存、文件句柄、网络套接字等。"泄露"则是指在应用程序中分配了资源,但在不再需要这些资源时没有…

Linux 安装 Redis 教程

优质博文:IT-BLOG-CN 一、准备工作 配置gcc:安装Redis前需要配置gcc: yum install gcc如果配置gcc出现依赖包问题,在安装时提示需要的依赖包版本和本地版本不一致,本地版本过高,出现如下问题&#xff1a…

Jupyter无法导入库,但能在终端导入的问题

Jupyter无法导入库,但能在终端导入 ❌错误问题描述:conda activate LLMs激活某个Conda的环境后,尽管已经通过conda或者pip在这个环境中安装了一些🐍Python的库,但无法在Jupyter中导入,却能在终端成功导入。…

京东商品详情数据接口(JD.item_get)丨京东API实时接口指南

京东商品详情API接口(JD.item_get)是京东开放平台提供的一个数据接口,用于获取京东平台上单个商品的详细信息。 通过这个接口,开发者可以获取到包括商品名称、品牌、产地、规格参数、价格信息、销量、评价、图片、描述等在内的详…

Node.js开发实战 视频教程 下载

ode.js开发实战 视频教程 下载 下载地址 https://download.csdn.net/download/m0_67912929/89487510 01-课程介绍.mp4 02-内容综述.mp4 03-Node.js是什么? .mp4 04-Node.js可以用来做什么?.mp4 05-课程实战项目介绍.mp4 06-什么是技术预研? .mp4 07-Node.js开发环境…

Windows 11 安装 安卓子系统 (WSA)

How to Install Windows Subsystem for Android (WSA) on Windows 11 新手教程:如何安装Windows 11 安卓子系统 说明 Windows Subsystem for Android 或 WSA 是由 Hyper-V 提供支持的虚拟机,可在 Windows 11 操作系统上运行 Android 应用程序。虽然它需…

【JS】注意考点

1.声明变量时所遵循的规则: (1)可以使用一个保留关键字var同时声明多个变量 (2)可以在声明变量的同时对其赋值, (3)如果只是声明了变量,并未对其赋值,其值就默认为 Undefined。 (4)保留关键字var可以用作for语句和for…in语句…

python基础_类

在Python中,类(Class)是面向对象编程(OOP)的核心概念之一。类提供了一种创建新对象的模板,这些对象通常被称为类的实例或对象。以下是关于Python类的一些关键点和特性: 定义类 类通过class关键…

PostgreSQL的系统视图pg_stat_wal

PostgreSQL的系统视图pg_stat_wal 在 PostgreSQL 数据库中,pg_stat_wal 视图提供了与 WAL(Write-Ahead Logging)日志有关的统计信息。WAL 是 PostgreSQL 用于确保数据一致性和持久性的重要机制。因此,监控和分析 WAL 活动对于数据…

ctfshow-web入门-命令执行(web71-web74)

目录 1、web71 2、web72 3、web73 4、web74 1、web71 像上一题那样扫描但是输出全是问号 查看提示:我们可以结合 exit() 函数执行php代码让后面的匹配缓冲区不执行直接退出。 payload: cvar_export(scandir(/));exit(); 同理读取 flag.txt cinclud…

文华财经博易大师盘立方多空波段止损画线指标公式

TT:PERIOD7; EMA120:EMA(C,120); RSV:(CLOSE-LLV(LOW,9))/(HHV(HIGH,9)-LLV(LOW,9))*100; K:SMA(RSV,3,1); D:SMA(K,3,1); J:3*K-2*D; DRAWTEXT(TT&&J<0,L,多),VALIGN0; DRAWTEXT(TT&&J>100,H,空),VALIGN2; IF(TT,EMA(C,60),NULL),RGB(255,255,2…

JavaScript数组对象 , 正则对象 , String对象以及自定义对象介绍

1. Array数组对象 数组对象是使用单独的变量名来存储一系列的值。 1.1创建一个数组 创建一个数组&#xff0c;有三种方法。 【1】常规方式: let 数组名 new Array();【2】简洁方式: 推荐使用 let 数组名 new Array(数值1,数值2,...);【3】字面:在js中创建数组使用中括号…

【ubuntu 】使用samba配置共享用户home目录和其他具体路径

目录 1 安装samba 2 修改Samba配置文件 3 增加Rose用户的samba帐号 4 重启samba 5 测试 1 安装samba 使用如下命令安装samba&#xff1a; sudo apt-get updatesudo apt-get install samba openssh-server 2 修改Samba配置文件 sudo cp /etc/samba/smb.conf /etc/samba…

试用笔记之-收钱吧安卓版演示源代码,收钱吧手机版感受

首先下载&#xff1a; https://download.csdn.net/download/tjsoft/89499105 安卓手机安装 如果有收钱吧帐号输入收钱吧帐号和密码。 如果没有收钱吧帐号点我的注册 登录收钱吧帐号后就可以把手机当成收钱吧POS机用了&#xff0c;还可以扫客服的付款码哦 源代码技术交流QQ:42…

Docker安装MySQL5

Docker安装MySQL5 前言 MySQL 是一个开源的关系型数据库管理系统&#xff0c;广泛用于各种 Web 应用程序的开发和生产环境中。MySQL 5 是 MySQL 数据库的一个较早版本&#xff0c;虽然不再是最新版本&#xff0c;但仍然被一些项目所使用和支持。 在 Docker 中安装 MySQL 5 可…

Docker 手册

帮助命令 docker 命令 --help镜像命令 docker images (-a所有 &#xff5c; -q只显示容器的ID) docker search 镜像名 docker pull 镜像名&#xff1a;版本号 docker rmi -f ID&#xff5c;镜像名&#xff1a;版本号 // 删除本地一个或多个镜像 docker rmi -f $(docker …

U盘数据恢复实战指南:原因、方案与预防措施

一、引言&#xff1a;U盘数据恢复概述 在数字化时代&#xff0c;U盘作为一种便携式存储设备&#xff0c;广泛应用于个人和企业中。然而&#xff0c;由于各种原因&#xff0c;U盘数据丢失的问题时有发生。U盘数据恢复技术便是在这种情况下应运而生&#xff0c;它帮助用户在数据…

TPS61085非同步650kHz,1.2MHz, 18.5V升压DCDC芯片

1 特点 TPS61085外观和丝印PMKI 2.3 V 至 6 V 输入电压范围 具有 2.0A 开关电流的 18.5V 升压转换器 650kHz/1.2MHz 可选开关频率 可调软启动 热关断 欠压闭锁 8引脚VSSOP封装 8引脚TSSOP封装 2 应用 手持设备 GPS接收器 数码相机 便携式应用 DSL调制解调器 PCMCIA卡 TFT LCD…