pytorch生成对抗网络

# 生成对抗网络
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超参数
latent_size = 64 # 潜在空间(latent space)的维度数量
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5],   # 1 for greyscale channels
                                     std=[0.5])])
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./datasets',
                                   train=True,
                                   transform=transform,
                                   download=True)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)
for i,_ in data_loader:
    print(i.shape,i.max(),i.min(),torch.unique(_))
    break
# 鉴别器
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())
# 生成器
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1) # 裁剪到0-1
def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()
torch.zeros(4, 1).shape
torch.randn(4,50).shape
# 生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,
# 判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由
# 真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数
# 鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断
# 它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差
# 特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是
#生成器模型中的参数
# 训练
total_step = len(data_loader) # 总批次数
for epoch in range(num_epochs):
    # 遍历每个批次数据
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        # 创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        # 训练鉴别器
        outputs = D(images) # 获取鉴别器对真实图片的鉴别分数
        d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失
        real_score = outputs # 真实鉴别分数
        # z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z) # 获取生成的图片
        outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数
        d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失
        fake_score = outputs
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失
        reset_grad()
        # 鉴别器损失反向传播
        d_loss.backward()
        # 根据梯度更新参数
        d_optimizer.step()
        # 训练生成器
        # 随机初始化一个噪音图片(用一定大小的向量表示)
        z = torch.randn(batch_size, latent_size).to(device)
        # 通过生成器生成图片
        fake_images = G(z)
        outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分
        # 生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差
        g_loss = criterion(outputs, real_labels)
        # 清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        # 每隔200批次打印日志
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    # 真实图片只需要保存一次
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
    # 每个轮次都会保存一次生成器生成的图片
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

 

上面是真实图片,下面是经过200个轮次后的生成图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/63387.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从GCC源码分析C语言编译原理——源码表层分析(脚本篇)

目录 一、目录结构 二、有意思的小功能 三、install脚本 脚本变量和设置 程序名称变量 模式和命令 参数解析 主要逻辑 四、主要功能脚本 ------------------------------------------------------------------------------------------------------------------------…

自定义指令,全局,局部,注册

让输入框自动获取焦点(每次刷新自动获取焦点&#xff09; <template><div><h3>自定义指令</h3><input ref"inp" type"text"></div> </template><script> export default {mounted(){this.$refs.inp.focus…

二、点亮希望之光:寄存器与库函数驱动 LED 灯

文章目录 一、寄存器1、存储器映射2、存储器映射表3、寄存器4、寄存器映射5、寄存器重映射6、总线基地址、外设基地址、外设寄存器地址7、操作寄存器&#xff08;以操作一个GPIO口为例&#xff09;1. 寄存器地址定义部分2. GPIOD_Configuration 函数部分3. main 函数部分 二、库…

梯度下降法求解局部最小值深入讨论以及 Python 实现

文章目录 0. 前期准备1. 局部最小值2. 例子3. 讨论3.1 增加迭代次数3.2 修改学习率 η \eta η3.3 修改初始值 x 0 x^0 x0 4. 总结参考 0. 前期准备 在开始讲梯度下降法求解函数的局部最小值之前&#xff0c; 你需要有梯度下降法求解函数的最小值的相关知识。 如果你还不是…

css部分

前面我们学习了HTML&#xff0c;但是HTML仅仅只是做数据的显示&#xff0c;页面的样式比较简陋&#xff0c;用户体验度不高&#xff0c;所以需要通过CSS来完成对页面的修饰&#xff0c;CSS就是页面的装饰者&#xff0c;给页面化妆&#xff0c;让它更好看。 1 层叠样式表&#…

7-zip如何分卷压缩文件并加密?

想要压缩的文件过大&#xff0c;想要在压缩过程中将文件拆分为几个压缩包并且同时为所有压缩包设置加密应该如何设置&#xff1f; 想要分卷压缩文件并加密一起操作就可以完成了&#xff0c;设置方法如下&#xff1a; 打开7-zip&#xff0c;选中需要压缩的文件&#xff0c;选择…

14.在 Vue 3 中使用 OpenLayers 自定义地图版权信息

在 WebGIS 开发中&#xff0c;默认的地图服务通常会带有版权信息&#xff0c;但有时候我们需要根据项目需求自定义版权信息或添加额外的版权声明。在本文中&#xff0c;我们将基于 Vue 3 的 Composition API 和 OpenLayers&#xff0c;完成自定义地图版权信息的实现。 最终效果…

【游戏】超休闲游戏:简单易上手的游戏风潮

1. 什么是超休闲游戏&#xff1f; 超休闲游戏&#xff08;Hyper-Casual Games&#xff09; 是一种专注于简单玩法、快速上手、短时间内能完成的游戏类型。这类游戏通常具有以下特点&#xff1a; 玩法简单&#xff1a;规则直观&#xff0c;用户无需教程即可上手。碎片化体验&a…

无法找到 msvcr120.dll,无法执行程序要怎么处理?修改msvcr120.dll文件

遇到“无法找到 msvcr120.dll&#xff0c;无法执行程序”这类错误通常指示着你的计算机缺少了一个核心组件&#xff0c;即 Microsoft Visual C 2013 Redistributable 的一部分&#xff1a;msvcr120.dll 文件。这个文件是许多依赖 Visual C 的应用程序运行的基础。目前修复此类问…

Kruskal 算法在特定边权重条件下的性能分析及其实现

引言 Kruskal 算法是一种用于求解最小生成树(Minimum Spanning Tree, MST)的经典算法。它通过逐步添加权重最小的边来构建最小生成树,同时确保不会形成环路。在本文中,我们将探讨在特定边权重条件下 Kruskal 算法的性能,并分别给出伪代码和 C 语言实现。特别是,我们将分…

实战 | C# 中使用GPU加速YOLOv11 推理

导 读 本文主要介绍如何在C#中使用GPU加速YOLOv11推理。 YOLOv11介绍 C# 中使用YOLOv11 (GPU版本) 【1】环境和依赖项。 下载安装CUDA12.6和CUDNN9.6,截止文章日期最新版本,注意选择自己的版本,我的系统是win11 64位。

大数据-244 离线数仓 - 电商核心交易 ODS层 数据库结构 数据加载 DataX

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; Java篇开始了&#xff01; 目前开始更新 MyBatis&#xff0c;一起深入浅出&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff0…

C# 高效编程指南:从命名空间到异常处理的技巧与最佳实践

在这条成长的路上&#xff0c;有一些核心概念将成为你开发过程中的得力助手—命名空间、预处理指令、正则表达式、异常处理和文件输入输出&#xff0c;这些看似独立的技术&#xff0c;实际上在大多数应用中都紧密相连&#xff0c;共同构成了C#开发的基础。 目录 C#命名空间 C…

gitee常见命令

目录 1.本地分支重命名 2.更新远程仓库分支 3.为当前分支设置远程跟踪分支 4.撤销已经push远程的代码 5.idea->gitee的‘还原提交’ 需要和本地当前的代码解决冲突 解决冲突 本地工作区的差异代码显示 本地commit和push远程 6.idea->gitee的‘将当前分支重置到此…

简易图书管理系统

javawebjspservlet 实体类 package com.ghx.entity;/*** author &#xff1a;guo* date &#xff1a;Created in 2024/12/6 10:13* description&#xff1a;* modified By&#xff1a;* version:*/ public class Book {private int id;private String name;private double pri…

什么是甘特图?使用甘特图制定项目计划表的步骤

在项目管理中&#xff0c;项目计划是项目的核心要素&#xff0c;它详细记录了项目任务详情、责任人、时间规划以及所需资源。 这份计划不仅为项目推进提供指引&#xff0c;更是控制范围蔓延、争取更多支持的有力工具。 在项目管理中&#xff0c;项目计划是项目的核心要素&…

mock.js介绍

mock.js http://mockjs.com/ 1、mock的介绍 *** 生成随机数据&#xff0c;拦截 Ajax 请求。** 通过随机数据&#xff0c;模拟各种场景&#xff1b;不需要修改既有代码&#xff0c;就可以拦截 Ajax 请求&#xff0c;返回模拟的响应数据&#xff1b;支持生成随机的文本、数字…

16.[极客大挑战 2019]Upload1

进入靶场 是文件上传类题目 随便传个图片 制作个含木马的 算了&#xff0c;直接抓包只传文件改格式好了 第一次传的是<?php eval($_POST[attack]);?> 第二次传的是GIF89a? <script language"php">eval($_REQUEST[1])</script> "GIF89a&…

标书里的“废标雷区”:你踩过几个?

在投标领域&#xff0c;标书的质量不仅决定了中标的可能性&#xff0c;更是体现企业专业度的关键。但即便是经验丰富的投标人&#xff0c;也难免会在标书编制过程中踩中“废标雷区”。这些雷区可能隐藏在技术方案的细节中&#xff0c;也可能是投标文件格式的规范问题。以下&…

电脑投屏到电脑:Windows,macOS及Linux系统可以相互投屏!

本篇其实是电脑远程投屏到另一台电脑的操作介绍。本篇文章的方法可用于Windows&#xff0c;macOS及Linux系统的相互投屏。 为了避免介绍过程中出现“这台电脑”投屏到“那台电脑”的混乱表述&#xff0c;假定当前屏幕投出端是Windows系统电脑&#xff0c;屏幕接收端是Linux系统…