【PyTorch][chapter 20][李宏毅深度学习]【无监督学习][ GAN]【实战】

前言

 本篇主要是结合手写数字例子,结合PyTorch 介绍一下Gan 实战

第一轮训练效果

第20轮训练效果,已经可以生成数字了

68 轮


目录: 

  1.   谷歌云服务器(Google Colab)
  2.   整体训练流程
  3.   Python 代码

一  谷歌云服务器(Google Colab)

     个人用的一直是联想小新笔记本,虽然非常稳定方便。但是现在跑深度学习,性能确实有点跟不上. 

   1.1    打开谷歌云服务器(Google Colab)

      https://colab.research.google.com/

    1. 2  新建笔记

                 

1

 1.4  选择T4GPU 

1.5  点击运行按钮

可以看到当前硬件的情况

     


二  整体训练流程


三    PyTorch 例子

# -*- coding: utf-8 -*-
"""
Created on Fri Mar  1 13:27:49 2024@author: chengxf2
"""
import torch.optim as optim #优化器
import numpy as np 
import matplotlib.pyplot  as plt
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn#第一步加载手写数字集
def loadData():#同时归一化数据集(-1,1)style = transforms.Compose([transforms.ToTensor(),   #0-1 归一化0-1, channel,height,widthtransforms.Normalize(mean=0.5, std=0.5) #变成了-1,1 ])trainData = torchvision.datasets.MNIST('data',train=True,transform=style,download=True)dataloader = torch.utils.data.DataLoader(trainData,batch_size= 16,shuffle=True)imgs,_ = next(iter(dataloader))#torch.Size([64, 1, 28, 28])print("\n imgs shape ",imgs.shape)return dataloaderclass Generator(nn.Module):'''定义生成器输入:z 随机噪声[batch, input_size]输出:x: 图片 [batch, height, width, channel]'''def __init__(self,input_size):super(Generator,self).__init__()self.net = nn.Sequential(nn.Linear(in_features = input_size , out_features =256),nn.ReLU(),nn.Linear(in_features = 256 , out_features =512),nn.ReLU(),nn.Linear(in_features = 512 , out_features =28*28),nn.Tanh())def forward(self, z):# z 随机输入[batch, dim]x = self.net(z)#[batch, height, width, channel]#print(x.shape)x = x.view(-1,28,28,1)return xclass Discriminator(nn.Module):'''定义鉴别器输入:x: 图片 [batch, height, width, channel]输出:y:  二分类图片的概率: BCELoss 计算交叉熵损失'''def __init__(self):super(Discriminator,self).__init__()#开始的维度和终止的维度,默认值分别是1和-1self.flatten = nn.Flatten()self.net = nn.Sequential(nn.Linear(in_features = 28*28 , out_features =512),nn.LeakyReLU(), #负值的时候保留梯度信息nn.Linear(in_features = 512 , out_features =256),nn.LeakyReLU(),nn.Linear(in_features = 256 , out_features =1),nn.Sigmoid())def forward(self, x):x = self.flatten(x)#print(x.shape)out =self.net(x)return outdef gen_img_plot(model, epoch, test_input):out = model(test_input).detach().cpu()out = out.numpy()imgs = np.squeeze(out)fig = plt.figure(figsize=(4,4))for i in range(out.shape[0]):plt.subplot(4,4,i+1)img = (imgs[i]+1)/2.0#[-1,1]plt.imshow(img)plt.axis('off')plt.show()def train():#1 初始化参数device ='cuda' if torch.cuda.is_available() else 'cpu'#2 加载训练数据dataloader = loadData()test_input  = torch.randn(16,100,device=device)#3 超参数maxIter = 20 #最大训练次数input_size = 100batchNum = 16input_size =100#4 初始化模型gen = Generator(100).to(device)dis = Discriminator().to(device)#5 优化器,损失函数d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)g_optim = torch.optim.Adam(gen.parameters(),lr=1e-4)loss_fn = torch.nn.BCELoss()#6 loss 变化列表D_loss =[]G_loss= []for epoch in range(0,maxIter):d_epoch_loss = 0.0g_epoch_loss  =0.0#count = len(dataloader)for step ,(realImgs, _) in enumerate(dataloader):realImgs = realImgs.to(device)random_noise = torch.randn(batchNum, input_size).to(device)#先训练判别器d_optim.zero_grad()real_output = dis(realImgs)d_real_loss = loss_fn(real_output, torch.ones_like(real_output))d_real_loss.backward()#不要训练生成器,所以要生成器detachfake_img = gen(random_noise)fake_output = dis(fake_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossd_optim.step()#优化生成器g_optim.zero_grad()fake_output = dis(fake_img.detach())g_loss = loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss+= d_lossg_epoch_loss+= g_losscount = 16       with torch.no_grad():d_epoch_loss/=countg_epoch_loss/=countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)gen_img_plot(gen, epoch, test_input)print("Epoch: ",epoch)print("-----finised-----")if __name__ == "__main__":train()

参考:

10.完整课程简介_哔哩哔哩_bilibili

理论【PyTorch][chapter 19][李宏毅深度学习]【无监督学习][ GAN]【理论】-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/711894.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Open CASCADE学习|曲线曲面连续性

1、曲线的连续性 曲线的连续性是三维建模、动画设计等领域中非常重要的一个概念,它涉及到曲线在不同点之间的连接方式和光滑程度。下面将详细介绍曲线的连续性,包括C连续性和G连续性。 1.1C连续性(参数连续性) C连续性是指曲线…

使用MyBatisPlus实现向数据库中存储List类型的数据

使用MyBatisPlus实现向数据库中存储List类型的数据 问题描述 建表时,表中的这五个字段为json类型 但是在入库的时候既不能写入数据,也不能查询出数据。 解决方案: 1.首先明确,数据存入的时候是经过了数据类型转化&#xff0c…

数据之光:探索数据库技术的演进之路

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua,在这里我会分享我的知识和经验。&#x…

喜讯!持安科技CEO何艺获评安全419《2023年度十大优秀创业者》

近日,由网络安全产业资讯媒体安全419主办的《年度策划》2023年度十大优秀创业者正式出炉,零信任办公安全技术创新企业持安科技创始人兼CEO何艺,获评十大优秀创业者。 这是安全419第二届推出该项目的评选活动,安全419编辑老师在多年…

抽象类、模板方法模式

抽象类概述 在Java中abstract是抽象的意思,如果一个类中的某个方法的具体实现不能确定,就可以申明成abstract修饰的抽象方法(不能写方法体了),这个类必须用abstract修饰,被称为抽象类。 抽象方法定义&…

这些单片机汇编语言的错误,你还在犯错吗?

在单片机开发中,很多工程师会选择汇编语言来作为底层编程,来直接控制硬件和高校执行命令,然而因为汇编语言是直接与硬件交互,所以很容易出现错误,本文将基于Keil C51汇编器的环境总结单片机汇编语言常见的错误&#xf…

人工智能_大模型010_Centos7.9中CPU安装ChatGLM3-6B大模型_安装使用_010---人工智能工作笔记0145

从一个空的虚拟机开始安装: https://www.modelscope.cn/models/ZhipuAI/chatglm3-6b/files 可以看到这里有很多的数据文件,那么这里 这里点击模型文件就可以下载,这个就是chatglm3-6B的文件,需要点击每个文件,然后点击右边的下载,把文件都下载下来 右侧有下载按钮.点击下载可…

使用Fabric创建的canvas画布背景图片,自适应画布宽高

之前的文章写过vue2使用fabric实现简单画图demo,完成批阅功能;但是功能不完善,对于很大的图片就只能显示一部分出来,不符合我们的需求。这就需要改进,对我们设置的背景图进行自适应。 有问题的canvas画布背景 修改后的…

【rust】11、所有权

文章目录 一、背景二、Stack 和 Heap2.1 Stack2.2 Heap2.3 性能区别2.4 所有权和堆栈 三、所有权原则3.1 变量作用域3.2 String 类型示例 四、变量绑定背后的数据交互4.1 所有权转移4.1.1 基本类型: 拷贝, 不转移所有权4.1.2 分配在 Heap 的类型: 转移所有权 4.2 Clone(深拷贝)…

Quartz 任务调度框架源码阅读解析

概念: quartz 是一个基于JAVA的定时任务调度框架 案例: <dependency><groupId>org.quartz-scheduler</groupId><artifactId>quartz</artifactId><version>2.3.0</version></dependency>JobDetail job JobBuilder.newJob(Sc…

LeetCode 刷题 [C++] 第236题.二叉树的最近公共祖先

题目描述 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&#xff0c;满足 x 是 p、q 的祖先且 x 的深度尽可能大&#xff08;一个节点也可以…

大数据分析案例-基于SVM支持向量机算法构建手机价格分类预测模型

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

矩阵爆破逆向之条件断点的妙用

不知道你是否使用过IDA的条件断点呢&#xff1f;在IDA进阶使用中&#xff0c;它的很多功能都有大作用&#xff0c;比如&#xff1a;ida-trace来跟踪调用流程。同时IDA的断点功能也十分强大&#xff0c;配合IDA-python的输出语句能够大杀特杀&#xff01; 那么本文就介绍一下这…

【JAVA】JDK内置工具之appletviewer

下载java 下载java的时候会先下载Java jdk&#xff0c;Java Development Kit Java开发工具包。 然后会下载jre&#xff0c;也就是Java Runtime Environment Java运行环境。什么是JDK、JRE&#xff1f;_java中的jdk,jre代表什么-CSDN博客 下载之后先找到java下的bin文件&#x…

yolov9 tensorRT 的 C++ 部署

yolov9 tensorRT C 部署 本示例中&#xff0c;包含完整的代码、模型、测试图片、测试结果。 完整的代码、模型、测试图片、测试结果【github参考链接】 TensorRT版本&#xff1a;TensorRT-7.1.3.4 导出onnx模型 导出适配本实例的onnx模型参考【yolov9 瑞芯微芯片rknn部署、地平…

网络爬虫的危害,如何有效的防止非法利用

近年来&#xff0c;不法分子利用“爬虫”软件收集公民隐私数据案件屡见不鲜。2023年8月23日&#xff0c;北京市高级人民法院召开北京法院侵犯公民个人信息犯罪案件审判情况新闻通报会&#xff0c;通报侵犯公民个人隐私信息案件审判情况&#xff0c;并发布典型案例。在这些典型案…

获取PDF中的布局信息——如何获取段落

PDF解析是极其复杂的问题。不可能靠一个工具解决全部问题&#xff0c;尤其是五花八门&#xff0c;格式不统一的PDF文件。除非有钞能力。如果没有那就看看可以分为哪些问题。 提取文本内容&#xff0c;提取表格内容&#xff0c;提取图片。我认为这些应该是分开做的事情。python有…

DataSpell 2023:专注于数据,加速您的数据科学之旅 mac/win版

JetBrains DataSpell 2023是一款专为数据科学家和数据分析师设计的集成开发环境&#xff08;IDE&#xff09;。这款IDE提供了强大的数据分析和可视化工具&#xff0c;旨在帮助用户更快速、更高效地进行数据科学工作。 DataSpell 2023软件获取 DataSpell 2023在保持其一贯的数…

【多线程】常见锁策略详解(面试常考题型)

目录 &#x1f334; 乐观锁 vs 悲观锁&#x1f38d;重量级锁 vs 轻量级锁&#x1f340;自旋锁&#xff08;Spin Lock&#xff09;&#x1f38b;公平锁 vs ⾮公平锁&#x1f333;可重⼊锁 vs 不可重⼊锁&#x1f384;读写锁⭕相关面试题 常⻅的锁策略 注意: 接下来讲解的锁策略不…

cpp基础学习笔记03:类型转换

static_cast 静态转换 用于类层次结构中基类和派生类之间指针或者引用的转换。up-casting (把派生类的指针或引用转换成基类的指针或者引用表示)是安全的&#xff1b;down-casting(把基类指针或引用转换成子类的指针或者引用)是不安全的。用于基本数据类型之间的转换&#xff…