基于CycleGAN的图像风格转换

基于CycleGAN的图像风格转换

  • 1.导入所需要的包和库:
  • 2.将一个Tensor转换为图像:
  • 3.数据加载:
  • 4.图像变换:
  • 5.加载和预处理训练数据:
  • 6.定义了一个残差块:
  • 7.生成器:
  • 8.判断器:
  • 9.数据缓存器:
  • 10.执行生成器的训练步骤:
  • 11.训练判别器:
  • 12.损失打印,存储伪造图片:

1.导入所需要的包和库:

from random import randint
import numpy as np 
import torch
torch.set_default_tensor_type(torch.FloatTensor)
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import shutil
import cv2
import random
from PIL import Image
import itertools

2.将一个Tensor转换为图像:

def to_img(x):out = 0.5 * (x + 1)out = out.clamp(0, 1)  out = out.view(-1, 3, 256, 256)  return out

3.数据加载:

data_path = os.path.abspath('D:/XUNLJ/data')
image_size = 256
batch_size = 1

4.图像变换:

  • 首先,图像会被调整到略大于原始大小,然后随机裁剪回原始大小,接着进行水平翻转,转换为张量格式,最后进行标准化处理
transform = transforms.Compose([transforms.Resize(int(image_size * 1.12), Image.BICUBIC), transforms.RandomCrop(image_size), transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

5.加载和预处理训练数据:

  • 文件夹中随机选择一批A类和B类图像,应用预定义的图像变换,并将它们转换为适合神经网络输入的张量格式
def _get_train_data(batch_size=1):train_a_filepath = data_path + '\\trainA\\'train_b_filepath = data_path + '\\trainB\\'train_a_list = os.listdir(train_a_filepath)train_b_list = os.listdir(train_b_filepath)train_a_result = []train_b_result = [] numlist = random.sample(range(0, len(train_a_list)), batch_size)for i in numlist:a_filename = train_a_list[i]a_img = Image.open(train_a_filepath + a_filename).convert('RGB')res_a_img = transform(a_img)train_a_result.append(torch.unsqueeze(res_a_img, 0))b_filename = train_b_list[i]b_img = Image.open(train_b_filepath + b_filename).convert('RGB')res_b_img = transform(b_img)train_b_result.append(torch.unsqueeze(res_b_img, 0))return torch.cat(train_a_result, dim=0), torch.cat(train_b_result, dim=0)

6.定义了一个残差块:

  • 定义了一个简单的残差块,它包含两个卷积层和实例归一化,以及ReLU激活函数
class ResidualBlock(nn.Module):def __init__(self, in_features):super(ResidualBlock, self).__init__()self.block_layer = nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features),nn.ReLU(inplace=True),nn.ReflectionPad2d(1),nn.Conv2d(in_features, in_features, 3),nn.InstanceNorm2d(in_features))def forward(self, x):return x + self.block_layer(x)

7.生成器:

  • 网络包含卷积层、下采样层、残差块和上采样层,用于将噪声输入转换为高质量的图像输出
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()model = [nn.ReflectionPad2d(3), nn.Conv2d(3, 64, 7), nn.InstanceNorm2d(64), nn.ReLU(inplace=True)]in_features = 64out_features = in_features * 2for _ in range(2):model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True)]in_features = out_featuresout_features = in_features*2for _ in range(9):model += [ResidualBlock(in_features)]out_features = in_features // 2for _ in range(2):model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), nn.InstanceNorm2d(out_features), nn.ReLU(inplace=True)]in_features = out_featuresout_features = in_features // 2model += [nn.ReflectionPad2d(3), nn.Conv2d(64, 3, 7), nn.Tanh()]self.gen = nn.Sequential( * model)def forward(self, x):x = self.gen(x)return x 

8.判断器:

  • 用于判断输入图像的真实性,含卷积层和LeakyReLU激活函数,用于从输入图像中提取特征,通过平均池化和重塑来生成一个与图像真实性相关的分数
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.dis = nn.Sequential(nn.Conv2d(3, 64, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, 128, 4, 2, 1, bias=False),nn.InstanceNorm2d(128),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.InstanceNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, 4, padding=1),nn.InstanceNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, 4, padding=1))        def forward(self, x):x = self.dis(x)return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

9.数据缓存器:

  • 用于存储和复用生成器生成的图像
class ReplayBuffer():def __init__(self, max_size=50):self.max_size = max_sizeself.data = []
  • 将新的数据推入缓存,并弹出旧的数据;如果缓存未满,则将数据推入缓存。如果缓存已满,则随机替换缓存中的一个数据。
   def push_and_pop(self, data):to_return = []for element in data.data:element = torch.unsqueeze(element, 0)if len(self.data) < self.max_size:self.data.append(element)to_return.append(element)else:if random.uniform(0,1) > 0.5:i = random.randint(0, self.max_size-1)to_return.append(self.data[i].clone())self.data[i] = elementelse:to_return.append(element)return Variable(torch.cat(to_return))
  • 实例化ReplayBuffer类,分别用于存储生成的A类和B类图像
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()
  • 定义生成器网络,用于从A类图像生成B类图像
netG_A2B = Generator()
netG_B2A = Generator()
  • 定义判别器网络,用于判断A类和B类图像的真实性
netD_A = Discriminator()
netD_B = Discriminator()
  • 定义GAN损失函数和循环一致性损失函数
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
  • 定义身份损失函数
criterion_identity = torch.nn.L1Loss()
  • 定义优化器的参数
d_learning_rate = 3e-4  # 3e-4
  • 定义生成器和判别器的学习器
g_learning_rate = 3e-4
optim_betas = (0.5, 0.999)g_optimizer = optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), 
lr=d_learning_rate)
da_optimizer = optim.Adam(netD_A.parameters(), lr=d_learning_rate)
db_optimizer = optim.Adam(netD_B.parameters(), lr=d_learning_rate)
  • 定义训练的轮数
num_epochs = 1000 

10.执行生成器的训练步骤:

  • 计算多个损失函数的值,综合考虑了图像的身份、对抗和循环一致性,来生成更真实的图像
same_B = netG_A2B(real_b).float()loss_identity_B = criterion_identity(same_B, real_b) * 5.0   same_A = netG_B2A(real_a).float()loss_identity_A = criterion_identity(same_A, real_a) * 5.0fake_B = netG_A2B(real_a).float()pred_fake = netD_B(fake_B).float()loss_GAN_A2B = criterion_GAN(pred_fake, target_real)fake_A = netG_B2A(real_b).float()pred_fake = netD_A(fake_A).float()loss_GAN_B2A = criterion_GAN(pred_fake, target_real)recovered_A = netG_B2A(fake_B).float()loss_cycle_ABA = criterion_cycle(recovered_A, real_a) * 10.0recovered_B = netG_A2B(fake_A).float()loss_cycle_BAB = criterion_cycle(recovered_B, real_b) * 10.0  loss_G = (loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB)loss_G.backward()    g_optimizer.step()

11.训练判别器:

  • 训练判别器A:通过计算真实图像和生成图像的对抗损失,来训练判别器以更准确地进行区分
da_optimizer.zero_grad()pred_real = netD_A(real_a).float()loss_D_real = criterion_GAN(pred_real, target_real)fake_A = fake_A_buffer.push_and_pop(fake_A)pred_fake = netD_A(fake_A.detach()).float()loss_D_fake = criterion_GAN(pred_fake, target_fake)loss_D_A = (loss_D_real + loss_D_fake) * 0.5loss_D_A.backward()da_optimizer.step()

训练判别器B:

db_optimizer.zero_grad()pred_real = netD_B(real_b)loss_D_real = criterion_GAN(pred_real, target_real)fake_B = fake_B_buffer.push_and_pop(fake_B)pred_fake = netD_B(fake_B.detach())loss_D_fake = criterion_GAN(pred_fake, target_fake)loss_D_B = (loss_D_real + loss_D_fake) * 0.5loss_D_B.backward()db_optimizer.step()

12.损失打印,存储伪造图片:

print('Epoch[{}],loss_G:{:.6f} ,loss_D_A:{:.6f},loss_D_B:{:.6f}'.format(epoch, loss_G.data.item(), loss_D_A.data.item(), loss_D_B.data.item()))if (epoch + 1) % 20 == 0 or epoch == 0:  b_fake = to_img(fake_B.data)a_fake = to_img(fake_A.data)a_real = to_img(real_a.data)b_real = to_img(real_b.data)save_image(a_fake, '../tmp/a_fake.png') save_image(b_fake, '../tmp/b_fake.png') save_image(a_real, '../tmp/a_real.png') save_image(b_real, '../tmp/b_real.png') 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/849592.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编写程序,提示用户输入以米/秒(m/s)为单位的速度v和以米/秒的平方(m/s)为单位的加速度 a,然后显示最短跑道长度。

(物理:求出跑道长度)假设一个飞机的加速度是a而起飞速度是v&#xff0c;那么可以使用下 面的公式计算出飞机起飞所需的最短跑道长度: 编写程序&#xff0c;提示用户输入以米/秒(m/s)为单位的速度v和以米/秒的平方(m/s)为单 位的加速度 a&#xff0c;然后显示最短跑道长度。下面…

LCM — Least Common Multiple 最小公倍数

因为任何一个数都可以表示为若干个质数幂的乘积。 比如75 3*5*5&#xff0c;即 2^0 * 3^1 * 5^2 * 7^0 ... 那么对于两个数来说&#xff0c;gcd就是他们取每个质数的较小幂的乘积&#xff0c;lcm则相反。显然&#xff0c;这些幂加起来就是他们乘积。 gcd(a,b) * lcm(a,b) a…

CorelDRAW2024发布更新啦!设计师们的得力助手

在数字化的今天&#xff0c;视觉设计已经成为我们生活中不可或缺的一部分。从手机界面到广告海报&#xff0c;从网页布局到包装设计&#xff0c;每一个细节都离不开设计师们的专业与创意。然而&#xff0c;面对日益增长的设计需求和不断提升的审美标准&#xff0c;许多设计师开…

【算法专题--栈】最小栈--高频面试题(图文详解,小白一看就会!!)

目录 一、前言 二、题目描述 三、解题方法 ⭐解题方法--1 ⭐解题方法--2 四、总结 五、共勉 一、前言 最小栈这道题&#xff0c;可以说是--栈专题--&#xff0c;比较经典的一道题&#xff0c;也是在面试中频率较高的一道题目&#xff0c;通常在面试中&#xff0c;面试官可…

OpenAI发布GPT-4思维破解新策略,Ilya亦有贡献!

OpenAI正在研究如何破解GPT-4的思维&#xff0c;并公开了超级对齐团队的工作&#xff0c;Ilya Sutskever也在作者名单中。 论文地址&#xff1a;https://cdn.openai.com/papers/sparse-autoencoders.pdf 代码&#xff1a;https://github.com/openai/sparse_autoencoder 特征可…

【Unity游戏制作】地精寻宝Gnome‘s Well That Ends Well卷轴动作游戏【一】场景搭建

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 专栏交流&#x1f9e7;&…

Redisson分布式锁原理解析

前言 首先Redis执行命令是单线程的&#xff0c;所以可以利用Redis实现分布式锁&#xff0c;而对于Redis单线程的问题&#xff0c;是其线程模型的问题&#xff0c;本篇重点是对目前流行的工具Redisson怎么去实现的分布式锁进行深入理解&#xff1b;开始之前&#xff0c;我们可以…

MyBatis Plus<=3.5.6 存在 SQL 注入漏洞

MyBatis Plus<3.5.6 存在 SQL 注入漏洞 漏洞描述 MyBatis Plus 属于 MyBatis 的增强工具&#xff0c;目的时用于简化数据库开发&#xff0c;并提高开发效率。 收到 SQL 注入漏洞影响的版本&#xff0c;由于 UpdateWrapper 类未对用户可控的参数进行过滤导致存在 SQL 注入漏…

什么情况下要配置DNS服务

什么是DNS 一、DNS就是域名解析 我们上网的方式通常都由ip地址组成&#xff0c;但是为了有个规范&#xff0c;而且我们也不可能去记住那么多一串Ip数字&#xff0c;首先域名就会比ip好记很多&#xff0c;其次固定性&#xff0c;一旦服务器换了&#xff0c;只要重新绑定域名对…

汇编指令——ARM Cortex-M指令分析

cpsid i 这条指令 cpsid i 是 ARM Cortex-M 处理器的汇编语言指令&#xff0c;用于关闭全局中断。在 ARM Cortex-M 处理器中&#xff0c;cpsid i 指令的作用是将处理器的中断&#xff08;IRQ&#xff09;禁用&#xff0c;以防止中断干扰当前的执行流程。这意味着在执行这条指令…

Mac - Node/Java 配置安装全流程

Mac - Node/Java 配置安装全流程 一. Git 安装二. Java 相关安装2.1 jenv 版本控制工具2.2 JDK1.8 和 JDK21的安装2.3 maven 安装 三. Node 相关安装3.1 nvm 版本控制工具3.2 Node 版本安装 一. Git 安装 1.我们首先安装一下Homebrew&#xff0c;这个工具很有用&#xff0c;能…

LLM的基础模型7:Positional Encoding

大模型技术论文不断&#xff0c;每个月总会新增上千篇。本专栏精选论文重点解读&#xff0c;主题还是围绕着行业实践和工程量产。若在某个环节出现卡点&#xff0c;可以回到大模型必备腔调或者LLM背后的基础模型新阅读。而最新科技&#xff08;Mamba,xLSTM,KAN&#xff09;则提…

单列集合.java

单列集合 为了存储不同类型的多个对象&#xff0c;Java提供了一些特殊系列的类&#xff0c;这些类可以存储任意类型的对象&#xff0c;并且存储的长度可变&#xff0c;这些类统称为集合。可以简单的理解为一个长度可变&#xff0c;可以存储不同数据类型的动态数组。集合都位于j…

【机器学习】原理与应用场景 Python代码展现

机器学习&#xff1a;原理、应用与实例深度解析 引言一、机器学习的基本原理二、机器学习的应用范围三、机器学习实例解析四、机器学习部分讲解五、机器学习的挑战与未来 引言 随着大数据和计算能力的飞速发展&#xff0c;机器学习&#xff08;Machine Learning, ML&#xff0…

【UML用户指南】-10-对高级结构建模-高级类

目录 1、类目 2、高级类 3、可见性 4、实例范围和静态范围 5、抽象元素、叶子元素和多态性元素 6、多重性 7、属性 8、操作 9、模板类 10、标准元素 1、类目 类目 &#xff08;classifier&#xff09;是描述结构特征和行为特征的机制。类目包括类、关联、接口、数据类…

补充SimGNN

补充SimGNN 理解Test函数&#xff1a; 理解Test函数&#xff1a; 理解test&#xff08;&#xff09;函数中部分代码&#xff1a; 假设数据&#xff1a; test_dataset [ {“norm_ged”: 0.1, “edge_index_1”: …, “edge_index_2”: …, “features_1”: …, “features_2”:…

常见硬件工程师面试题(一)

大家好&#xff0c;我是山羊君Goat。 对于硬件工程师&#xff0c;学习的东西主要和电路硬件相关&#xff0c;所以在硬件工程师的面试中&#xff0c;对于经验是十分看重的&#xff0c;像PCB设计&#xff0c;电路设计原理&#xff0c;模拟电路&#xff0c;数字电路等等相关的知识…

人工智能治理国内外政策与标准分析

文│阿里巴巴标准化部 朱红儒、彭骏涛、孙勇&#xff1b;中国信息通信研究院安全研究所 静静 人工智能&#xff08;AI&#xff09;作为新一轮科技革命的重要驱动力量&#xff0c;正在有效推动着数字化转型&#xff0c;其带来巨大机遇的同时&#xff0c;也伴随着新的风险和挑战…

数据库设计步骤、E-R图转关系模式、E-R图的画法

一、数据库设计步骤 ①需求分析阶段 准确了解与分析用户需求。 ②概念结构设计阶段 通过对用户需求进行综合、归纳与抽象&#xff0c;形成一个独立于具体数据库管理系统的概念模型。 ③逻辑结构设计阶段 将概念结构转换为某个数据库管理系统所支持的数据模型&am…

“安全生产月”专题报道:AI智能监控技术如何助力安全生产

今年6月是第23个全国“安全生产月”&#xff0c;6月16日为全国“安全宣传咨询日”。今年全国“安全生产月”活动主题为“人人讲安全、个个会应急——畅通生命通道”。近日&#xff0c;国务院安委会办公室、应急管理部对开展好2024年全国“安全生产月”活动作出安排部署。 随着科…