pytorch生成对抗网络

# 生成对抗网络
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超参数
latent_size = 64 # 潜在空间(latent space)的维度数量
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],   # 1 for greyscale channels
                                     std=[0.5])])
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./datasets',
                                   train=True,
                                   transform=transform,
                                   download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)
for i,_ in data_loader:
    print(i.shape,i.max(),i.min(),torch.unique(_))
    break
# 鉴别器
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())
# 生成器
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1) # 裁剪到0-1
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()
torch.zeros(4, 1).shape
torch.randn(4,50).shape
# 生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,
# 判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由
# 真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数
# 鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断
# 它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差
# 特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是
#生成器模型中的参数
# 训练
total_step = len(data_loader) # 总批次数
for epoch in range(num_epochs):
    # 遍历每个批次数据
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        # 创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        # 训练鉴别器
        outputs = D(images) # 获取鉴别器对真实图片的鉴别分数
        d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失
        real_score = outputs # 真实鉴别分数
        # z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z) # 获取生成的图片
        outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数
        d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失
        fake_score = outputs
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失
        reset_grad()
        # 鉴别器损失反向传播
        d_loss.backward()
        # 根据梯度更新参数
        d_optimizer.step()
        # 训练生成器
        # 随机初始化一个噪音图片(用一定大小的向量表示)
        z = torch.randn(batch_size, latent_size).to(device)
        # 通过生成器生成图片
        fake_images = G(z)
        outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分
        # 生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差
        g_loss = criterion(outputs, real_labels)
        # 清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        # 每隔200批次打印日志
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    # 真实图片只需要保存一次
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    # 每个轮次都会保存一次生成器生成的图片
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

 

上面是真实图片,下面是经过200个轮次后的生成图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/63387.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

taro小程序进入腾讯验证码

接入原因 昨天突然晚上有人刷我们公司的登录发送短信接口,紧急将小程序的验证码校验更新上去了 接下来就是我们的接入方法,其实很简单,不过有时候可能大家着急就没有仔细看文档,腾讯验证码文档微信小程序地址,注意这里…

Node.js 新手教程

1、nodejs简介 Node.js 是一个开源和跨平台的 JavaScript 运行时环境。它是几乎所有类型项目的流行工具! Node.js 在浏览器之外运行 V8 JavaScript 引擎(Google Chrome 的核心)。这使得 Node.js 的性能非常出色。 Node.js 应用程序在单个进…

从GCC源码分析C语言编译原理——源码表层分析(脚本篇)

目录 一、目录结构 二、有意思的小功能 三、install脚本 脚本变量和设置 程序名称变量 模式和命令 参数解析 主要逻辑 四、主要功能脚本 ------------------------------------------------------------------------------------------------------------------------…

Mybatis 学习 之 XML 手册

目录 单次执行单次新增单次更新单次删除 批量执行批量新增批量更新for 循环执行更新for 循环生成多条 sql&#xff0c;一次执行 批量删除 参数传递预处理方式 (OGNL表达式 #{})数据类型转换 直接替换 (EL表达式 ${}) 安全 单次执行 单次新增 <insert id"insert"…

自定义指令,全局,局部,注册

让输入框自动获取焦点(每次刷新自动获取焦点&#xff09; <template><div><h3>自定义指令</h3><input ref"inp" type"text"></div> </template><script> export default {mounted(){this.$refs.inp.focus…

CSS3 动画详解,介绍、实现与应用场景详解

CSS3 动画概述 CSS3 动画是通过 CSS3 的新特性来实现元素的动态变化。与传统的 JavaScript 动画不同,CSS3 动画主要通过 CSS 属性的变化来实现动画效果,具有高效、轻量和易于实现的优点。CSS3 动画通常用于网页的动态交互效果、过渡效果、元素移动、缩放、旋转等场景。 一、…

二、点亮希望之光:寄存器与库函数驱动 LED 灯

文章目录 一、寄存器1、存储器映射2、存储器映射表3、寄存器4、寄存器映射5、寄存器重映射6、总线基地址、外设基地址、外设寄存器地址7、操作寄存器&#xff08;以操作一个GPIO口为例&#xff09;1. 寄存器地址定义部分2. GPIOD_Configuration 函数部分3. main 函数部分 二、库…

【特殊子序列 DP】力扣1137. 第 N 个泰波那契数

泰波那契序列 Tn 定义如下&#xff1a; T0 0, T1 1, T2 1, 且在 n > 0 的条件下 Tn3 Tn Tn1 Tn2 给你整数 n&#xff0c;请返回第 n 个泰波那契数 Tn 的值。 示例 1&#xff1a; 输入&#xff1a;n 4 输出&#xff1a;4 解释&#xff1a; T_3 0 1 1 2 T_4 1 …

梯度下降法求解局部最小值深入讨论以及 Python 实现

文章目录 0. 前期准备1. 局部最小值2. 例子3. 讨论3.1 增加迭代次数3.2 修改学习率 η \eta η3.3 修改初始值 x 0 x^0 x0 4. 总结参考 0. 前期准备 在开始讲梯度下降法求解函数的局部最小值之前&#xff0c; 你需要有梯度下降法求解函数的最小值的相关知识。 如果你还不是…

css部分

前面我们学习了HTML&#xff0c;但是HTML仅仅只是做数据的显示&#xff0c;页面的样式比较简陋&#xff0c;用户体验度不高&#xff0c;所以需要通过CSS来完成对页面的修饰&#xff0c;CSS就是页面的装饰者&#xff0c;给页面化妆&#xff0c;让它更好看。 1 层叠样式表&#…

7-zip如何分卷压缩文件并加密?

想要压缩的文件过大&#xff0c;想要在压缩过程中将文件拆分为几个压缩包并且同时为所有压缩包设置加密应该如何设置&#xff1f; 想要分卷压缩文件并加密一起操作就可以完成了&#xff0c;设置方法如下&#xff1a; 打开7-zip&#xff0c;选中需要压缩的文件&#xff0c;选择…

14.在 Vue 3 中使用 OpenLayers 自定义地图版权信息

在 WebGIS 开发中&#xff0c;默认的地图服务通常会带有版权信息&#xff0c;但有时候我们需要根据项目需求自定义版权信息或添加额外的版权声明。在本文中&#xff0c;我们将基于 Vue 3 的 Composition API 和 OpenLayers&#xff0c;完成自定义地图版权信息的实现。 最终效果…

【游戏】超休闲游戏:简单易上手的游戏风潮

1. 什么是超休闲游戏&#xff1f; 超休闲游戏&#xff08;Hyper-Casual Games&#xff09; 是一种专注于简单玩法、快速上手、短时间内能完成的游戏类型。这类游戏通常具有以下特点&#xff1a; 玩法简单&#xff1a;规则直观&#xff0c;用户无需教程即可上手。碎片化体验&a…

【RTKLIB源码阅读】rtklibexplorer博客翻译-Updated guide to the RTKLIB configuration file

我已经有一段时间没有更新我的RTKLIB配置文件指南了。 自从上次更新以来,我增加了一些新的功能,并对一些现有的功能有了更多的了解。 对于以前的更新,我只是更新了原帖,但这次我想我应该重新发布它,使它更容易被找到。、 RTKLIB 的一大优点是它的可配置性极强,有一系列的…

excel文件合并,每个excel名称插入excel列

import pandas as pd import os # 设置文件夹路径 folder_path rC:\test # 替换为您的下载文件夹路径 output_file os.path.join(folder_path, BOM材料.xlsx) # 创建一个空的 DataFrame 用于存储合并的数据 combined_data pd.DataFrame() # 遍历文件夹中的所有文件 for …

无法找到 msvcr120.dll,无法执行程序要怎么处理?修改msvcr120.dll文件

遇到“无法找到 msvcr120.dll&#xff0c;无法执行程序”这类错误通常指示着你的计算机缺少了一个核心组件&#xff0c;即 Microsoft Visual C 2013 Redistributable 的一部分&#xff1a;msvcr120.dll 文件。这个文件是许多依赖 Visual C 的应用程序运行的基础。目前修复此类问…

MongDB快速入门

一&#xff0c;MongoDB简介 MongoDB是一个开源&#xff0c;高性能&#xff0c;无模式的文档数据库&#xff0c;当初的设计就是用于简化开发和方便扩展&#xff0c;是NoSQL数据库产品的一种&#xff0c;也是最想关系型数据库的非关系型数据库 它支持的数据结构非常松散&#x…

Kruskal 算法在特定边权重条件下的性能分析及其实现

引言 Kruskal 算法是一种用于求解最小生成树(Minimum Spanning Tree, MST)的经典算法。它通过逐步添加权重最小的边来构建最小生成树,同时确保不会形成环路。在本文中,我们将探讨在特定边权重条件下 Kruskal 算法的性能,并分别给出伪代码和 C 语言实现。特别是,我们将分…

实战 | C# 中使用GPU加速YOLOv11 推理

导 读 本文主要介绍如何在C#中使用GPU加速YOLOv11推理。 YOLOv11介绍 C# 中使用YOLOv11 (GPU版本) 【1】环境和依赖项。 下载安装CUDA12.6和CUDNN9.6,截止文章日期最新版本,注意选择自己的版本,我的系统是win11 64位。

大数据-244 离线数仓 - 电商核心交易 ODS层 数据库结构 数据加载 DataX

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; Java篇开始了&#xff01; 目前开始更新 MyBatis&#xff0c;一起深入浅出&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff0…