CycleGAN深度学习项目

远程仓库

leftthomas/CycleGAN: A PyTorch implementation of CycleGAN based on ICCV 2017 paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (github.com)

运行准备

Anaconda

安装需要的库

指令

pip install pandas -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install torch==1.11.0 -i Simple Index

pip install torchvision==0.12.0 -i Simple Index

pip install dominate==2.4.0 -i Simple Index

pip install visdom==0.1.8.8 -i Simple Index

pip install tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

运行结果

数据集

我当前使用的数据集

leftthomas/CycleGAN: A PyTorch implementation of CycleGAN based on ICCV 2017 paper "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (github.com)

因为数据集太大,训练时间慢所以删掉了很多图片

A-副本和B-副本里面是原始的数据集

A B是我自己删了图片的数据集

如果使用其他数据集也可以训练,例如:从网上随便下载图片

运行结果

程序解读

从main.py的if __name__ == '__main__':开始看

因为程序从这里开始执行

parser = argparse.ArgumentParser(description='Train Model')
# common args
parser.add_argument('--data_root', default='horse2zebra', type=str, help='Dataset root path')
# 文件放的位置
parser.add_argument('--batch_size', default=1, type=int, help='Number of images in each mini-batch')
#每个小批量中的图像数量
parser.add_argument('--epochs', default=2, type=int, help='Number of epochs over the data to train')
# 多少轮训练
parser.add_argument('--lr', default=0.0002, type=float, help='Initial learning rate')
# 开始时学习率
parser.add_argument('--decay', default=2, type=int, help='Epoch to start linearly decaying lr to 0')
# 从第几轮开始学习率逐渐减为0
parser.add_argument('--save_root', default='result', type=str, help='Result saved root path')
# 训练出来的保存在哪里
# args parse
args = parser.parse_args()
data_root, batch_size, epochs, lr = args.data_root, args.batch_size, args.epochs, args.lr
decay, save_root = args.decay, args.save_root# data prepare
train_data = ImageDataset(data_root, 'train')
# 训练集
print("数据")
print(train_data.__len__())
# 打印出数据集的长度
test_data = ImageDataset(data_root, 'test')
# 验证集
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=8)

使用通义灵码解释

# optimizer setup
optimizer_G = Adam(itertools.chain(G_A.parameters(), G_B.parameters()), lr=lr, betas=(0.5, 0.999))
optimizer_DA = Adam(D_A.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_DB = Adam(D_B.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
lr_scheduler_DA = LambdaLR(optimizer_DA, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))
lr_scheduler_DB = LambdaLR(optimizer_DB, lr_lambda=lambda eiter: 1.0 - max(0, eiter - decay) / float(decay))

这段代码是用于设置优化器和学习率调度器的。

首先,使用Adam优化器来初始化G_A和G_B的参数以及D_A和D_B的参数。Adam优化器是一种基于梯度的优化算法,它利用了动量和自适应学习率的特性。itertools.chain函数用于将G_A和G_B的参数组合在一起。

然后,使用LambdaLR学习率调度器来设置学习率的衰减。LambdaLR调度器使用给定的函数来计算每个迭代步骤的学习率。这里使用了一个lambda函数,它在迭代次数eiter超过decay后开始衰减学习率,衰减的速度由decay参数控制。

这些优化器和学习率调度器将用于训练生成器和判别器模型。

# training loop
results = {'train_g_loss': [], 'train_da_loss': [], 'train_db_loss': []}
if not os.path.exists(save_root):os.makedirs(save_root)
for epoch in range(1, epochs + 1):g_loss, da_loss, db_loss = train(G_A, G_B, D_A, D_B, train_loader, optimizer_G, optimizer_DA, optimizer_DB)results['train_g_loss'].append(g_loss)results['train_da_loss'].append(da_loss)results['train_db_loss'].append(db_loss)val(G_A, G_B, test_loader)lr_scheduler_G.step()lr_scheduler_DA.step()lr_scheduler_DB.step()# save statisticsdata_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))data_frame.to_csv('{}/results.csv'.format(save_root), index_label='epoch')torch.save(G_A.state_dict(), '{}/GA.pth'.format(save_root))torch.save(G_B.state_dict(), '{}/GB.pth'.format(save_root))torch.save(D_A.state_dict(), '{}/DA.pth'.format(save_root))torch.save(D_B.state_dict(), '{}/DB.pth'.format(save_root))

这段代码是一个训练循环,用于训练深度学习模型。以下是代码的详细解释:

首先,定义一个字典results,用于存储训练过程中的损失值。

检查保存模型和结果的目录save_root是否存在,如果不存在则创建该目录。

使用for循环遍历epochs次,每次迭代都会进行一次训练和验证。

在每次迭代中,调用train函数训练生成器G_A、G_B和判别器D_A、D_B,并更新损失值。

将训练过程中的损失值分别添加到results字典中对应的列表中。

调用val函数对模型进行验证。

更新生成器和判别器的学习率。

将results字典转换为DataFrame,并将其保存为CSV文件。

保存生成器和判别器的模型参数。

这个训练循环的主要目的是在给定的训练数据集上训练生成对抗网络(GAN),并保存训练过程中的损失值和模型参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/46592.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LPRNet 车牌识别部署 rk3588(pt-onnx-rknn)包含各个步骤完整代码

虽然车牌识别技术很成熟了,但完全没有接触过。一直想搞一下、整一下、试一下、折腾一下,工作之余找了一个简单的例子入个门。本博客简单记录一下 LPRNet 车牌识别部署 rk3588流程,训练参考 LPRNet 官方代码。 1、导出onnx   导出onnx很容易…

【postgresql】权限(Privileges)

权限(privileges)是决定用户或角色可以对数据库对象(如表、视图、序列和函数)执行哪些操作的许可。权限对于维护安全性和控制对数据的访问至关重要。 权限分类 在 PostgreSQL 中,权限分为以下几种: SELEC…

数据库基本查询(表的增删查改)

一、增加 1、添加信息 insert 语法 insert into table_name (列名) values (列数据1,列数据2,列数据3...) 若插入时主键或唯一键冲突就无法插入。 但如果我们就是要修改一列信息也可以用insert insert into table_name (列名) values (列数据1&am…

客户端通过服务器进行TCP通信(三)

一. 对TCP的基础讲解 服务端 1. 首先创建一个套接字,TCP是面向字节流的套接字,故需要使用SOCK_STREAM 2. 然后使用bind()函数将套接字与服务器地址关联(如果是在本地测试,直接将地址设置为217.0.0.1或者localhost,端口号为1000…

内存函数(C语言)

内存函数 以下函数的头文件:string.h 针对内存块进行处理的函数 memcpy 函数原型: void* memcpy(void* destination, const void* source, size_t num);目标空间地址 源空间地址num,被拷贝的字节个数 返回目标空间的起始地…

Python与自动化脚本编写

Python与自动化脚本编写 Python因其简洁的语法和强大的库支持,成为了自动化脚本编写的首选语言之一。在这篇文章中,我们将探索如何使用Python来编写自动化脚本,以简化日常任务。 一、Python自动化脚本的基础 1. Python在自动化中的优势 Pyth…

1.31、基于长短记忆网络(LSTM)的发动机剩余寿命预测(matlab)

1、基于长短记忆网络(LSTM)的发动机剩余寿命预测的原理及流程 基于长短期记忆网络(LSTM)的发动机剩余寿命预测是一种常见的机器学习应用,用于分析和预测发动机或其他设备的剩余可用寿命。下面是LSTM用于发动机剩余寿命预测的原理和流程: 数据收集&#…

【Linux】 GCC/G++与Makefile使用

Linux GCC/G使用 GCC如何完成 格式:gcc [选项] 要编译的文件 [选项] [目标文件] 常用选项: -E:让gcc在预处理结束后停止编译过程,输出.i的C语言原始文件。-S:该选项只是进行编译而不是进行汇编,最终生成汇…

力扣144题:二叉树的先序遍历

给你二叉树的根节点 root ,返回它节点值的 前序 遍历。 示例 1: 输入:root [1,null,2,3] 输出:[1,2,3]示例 2: 输入:root [] 输出:[]示例 3: 输入:root [1] 输出&am…

C++入门学习——初始化列表

概念 初始化列表:以一个冒号开始,接着是一个以逗号分隔的数据成员列表,每个"成员变量"后面跟一个放在括 号中的初始值或表达式 class Date { public://初始化列表Date(int year,int month,int day):_year(year),_month(month),_d…

[Windows] 油.管视频下载神器 Gihosoft TubeGet Pro v9.3.88

描述 对于经常在互联网上进行操作的学生,白领等! 一款好用的软件总是能得心应手,事半功倍。 今天给大家带了一款高科技软件 管视频下载神器 无需额外付费,永久免费! 亲测可运行!! 内容 目前主…

高德地图显示圆形区域并在区域边上标注半径

bug:循环创建三个圆形区域 ,数组设置为[{raduis:500,color:“#FF0000”}],然后循环取颜色会莫名其妙报错修改为 strokeColor: [“#FF0000”, “#1EE3C2”, “#3772E9”][i]即可 initAMap() {AMapLoader.load({key: "130cca3be68a2ff0fd5…

记VMware网络适配器里的自定义特定虚拟网络一直加载问题解决办法

1、问题描述 VMware网络适配器里的自定义特定虚拟网络一直加载问题: 在自定义:特定虚拟网络选择的时候 没有上图所示的三个选择,而是正在加载虚拟网络.... 如下图所示: 2、解决办法 2.1、原因分析: 是安装时候出现…

安防视频监控/视频汇聚EasyCVR平台浏览器http可以播放,https不能播放,如何解决?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台基于云边端一体化架构,兼容性强、支持多协议接入,包括国标GB/T 28181协议、部标JT808、GA/T 1400协议、RTMP、RTSP/Onvif协议、海康Ehome、海康SDK、大华SDK、华为SDK、宇视SDK、乐橙SDK、萤石云SD…

7.15洛谷蓝题

二分答案的两个模板&#xff1a; 1.最小值的最大化&#xff1a; #define _CRT_SECURE_NO_WARNINGS 1 #include<bits/stdc.h> #include<iostream> #include<algorithm> #include<cstring> #include<vector> #include<queue> #include<…

Studying-代码随想录训练营day40| 198.打家劫舍、213.打家劫舍II、337.打家劫舍III

第40天&#xff0c;动态规划part07&#xff0c;动态规划经典题型“打家劫舍”(ง •_•)ง&#xff0c;编程语言&#xff1a;C 目录 198.打家劫舍 213.打家劫舍II 337.打家劫舍III 总结 198.打家劫舍 文档讲解&#xff1a;代码随想录打家劫舍 视频讲解&#xff1a;手…

【C++进阶学习】第七弹——AVL树——树形结构存储数据的经典模块

二叉搜索树&#xff1a;【C进阶学习】第五弹——二叉搜索树——二叉树进阶及set和map的铺垫-CSDN博客 目录 一、AVL树的概念 二、AVL树的原理与实现 AVL树的节点 AVL树的插入 AVL树的旋转 AVL树的打印 AVL树的检查 三、实现AVL树的完整代码 四、总结 前言&#xff1a…

JavaScript青少年简明教程:输入输出

JavaScript青少年简明教程&#xff1a;输入输出 JavaScript的输入输出情况相对复杂&#xff0c;因为它依赖于其运行的宿主环境&#xff08;如Web浏览器或Node.js&#xff09;来提供具体的输入输出机制。JavaScript的核心规范&#xff08;ECMAScript&#xff09;本身并不直接提…

C基础day9

一、思维导图 二、课后练习 1> 使用递归实现 求 n 的 k 次方 #include<myhead.h>int Pow(int n,int k) {if(k 0 ) //递归出口{return 1;}else{return n*Pow(n,k-1); //递归主体} }int main(int argc, const char *argv[]) {int n0,k0;printf("请输入n和k:&…

韩国coupang上线的卖家官网是什么?韩国电商有哪些平台?

根据Statista的调查报告&#xff0c;预计2024年电子商务市场收入将达到4.117亿美元。而韩国的电子商务市场是全球最具活力和创新性的市场之一&#xff0c;有数据显示2023年韩国电商市场规模已突破1700亿美元&#xff0c;全球排名第四。 韩国coupang上线的卖家官网是什么&#x…