进行生成简单数字图片

1.之前只能做一些图像预测,我有个大胆的想法,如果神经网络正向就是预测图片的类别,如果我只有一个类别那就可以进行生成图片,专业术语叫做gan对抗网络
在这里插入图片描述
2.训练代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import os# 设置环境变量
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 定义生成器模型
class Generator(nn.Module):def __init__(self, input_dim=100, output_dim=784):super(Generator, self).__init__()self.fc1 = nn.Linear(input_dim, 256)self.fc2 = nn.Linear(256, 512)self.fc3 = nn.Linear(512, 1024)self.fc4 = nn.Linear(1024, output_dim)self.relu = nn.ReLU()self.tanh = nn.Tanh()def forward(self, x):x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.relu(self.fc3(x))x = self.tanh(self.fc4(x))return x# 定义判别器模型
class Discriminator(nn.Module):def __init__(self, input_dim=784, output_dim=1):super(Discriminator, self).__init__()self.fc1 = nn.Linear(input_dim, 1024)self.fc2 = nn.Linear(1024, 512)self.fc3 = nn.Linear(512, 256)self.fc4 = nn.Linear(256, output_dim)self.relu = nn.ReLU()self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.relu(self.fc3(x))x = self.sigmoid(self.fc4(x))return x# 加载 MNIST 手写数字图片数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
dataroot = "path_to_your_mnist_dataset"  # 替换为 MNIST 数据集的路径
dataset = dset.MNIST(root=dataroot, train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)# 创建生成器和判别器实例
input_dim = 100
output_dim = 784
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)# 定义优化器和损失函数
lr = 0.0002
beta1 = 0.5
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
criterion = nn.BCELoss()# 训练 GAN 模型
num_epochs = 50
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)
generator.to(device)
discriminator.to(device)
for epoch in range(num_epochs):for i, data in enumerate(dataloader, 0):real_images, _ = datareal_images = real_images.to(device)batch_size = real_images.size(0)  # 获取批次样本数量# 训练判别器optimizer_d.zero_grad()real_labels = torch.full((batch_size, 1), 1.0, device=device)fake_labels = torch.full((batch_size, 1), 0.0, device=device)noise = torch.randn(batch_size, input_dim, device=device)fake_images = generator(noise)real_outputs = discriminator(real_images.view(batch_size, -1))fake_outputs = discriminator(fake_images.detach())d_loss_real = criterion(real_outputs, real_labels)d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_d.step()# 训练生成器optimizer_g.zero_grad()noise = torch.randn(batch_size, input_dim, device=device)fake_images = generator(noise)fake_outputs = discriminator(fake_images)g_loss = criterion(fake_outputs, real_labels)g_loss.backward()optimizer_g.step()# 输出训练信息if i % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"% (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# 保存生成器的权重和图片示例if epoch % 10 == 0:with torch.no_grad():noise = torch.randn(64, input_dim, device=device)fake_images = generator(noise).view(64, 1, 28, 28).cpu().numpy()fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(12, 12), sharex=True, sharey=True)for i, ax in enumerate(axes.flatten()):ax.imshow(fake_images[i][0], cmap='gray')ax.axis('off')plt.subplots_adjust(wspace=0.05, hspace=0.05)plt.savefig("epoch_%d.png" % epoch)plt.close()torch.save(generator.state_dict(), "generator_epoch_%d.pth" % epoch)

3.测试模型的代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image# 定义生成器模型
class Generator(nn.Module):def __init__(self, input_dim, output_dim):super(Generator, self).__init__()self.fc1 = nn.Linear(input_dim, 256)self.fc2 = nn.Linear(256, 512)self.fc3 = nn.Linear(512, 1024)self.fc4 = nn.Linear(1024, output_dim)def forward(self, x):x = F.leaky_relu(self.fc1(x), 0.2)x = F.leaky_relu(self.fc2(x), 0.2)x = F.leaky_relu(self.fc3(x), 0.2)x = torch.tanh(self.fc4(x))return x# 创建生成器模型
generator = Generator(input_dim=100, output_dim=784)# 加载预训练权重
generator_weights = torch.load("generator_epoch_40.pth", map_location=torch.device('cpu'))# 将权重加载到生成器模型
generator.load_state_dict(generator_weights)# 生成随机噪声
noise = torch.randn(1, 100)# 生成图像
fake_image = generator(noise).view(1, 1, 28, 28)# 保存生成的图片
save_image(fake_image, "generated_image.png", normalize=False)

#测试结果,由于我的训练集是数字的,所以会生成各种各样的数字,下面明显的是1
在这里插入图片描述
#应该也是1
在这里插入图片描述

#再次运行,我也看不出来,不过只要我训练只有一个种类的问题就可以生成这个种类的图像
在这里插入图片描述
#搞定黑白图,那彩色图应该距离不远了,我需要改进的是把对抗网络的代码改为训练一个种类的图形,不过我感觉这种图形具有随机性,虽然通过训练我们得到了所有图像他们的规律,但是如果需要正常点的图片还是挺难的,就像是上面这张人都不一定知道他是什么东西(在没有颜色的情况下)总结就是精度不够,而且随机性太强了,现在普遍图片AI生成工具具有这个缺点(生成的物体可能会扭曲,挺阴间的),而且生成的图片速度慢,如果谁比较受益那一定是老黄(英伟达)哈哈哈
//比如下面这个图片生成视频的网站
https://app.runwayml.com/login

#每一帧看起来都没有问题,就是连起来变成视频不自然,如果有改进方法的话那可能需要引入重力/加速度/光处理 等等物理公式,来让图片更自然…
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/213230.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

飞天使-rsync大文件断点续传与vim批量删除

文章目录 rsync 断点续传,亲测有效vim 批量删除消息 rsync 断点续传,亲测有效 rsync -vzrtp -P --append -e "/usr/bin/ssh -p 22 -o StrictHostKeyCheckingno" m.tar.gz root10.0.0.1:/tmp后台运行 screem 既可 或者 nohup rsync -vzrt…

【华为od】存在一个m*n的二维数组,其成员取值范围为0,1。其中值为1的元素具备扩散性,每经过1S,将上下左右值为0的元素同化为1。

存在一个m*n的二维数组,其成员取值范围为0,1。其中值为1的元素具备扩散性,每经过1S,将上下左右值为0的元素同化为1。将数组所有成员初始化为0,将矩阵的[i, j]和[m,n]位置上元素修改成1后,在经过多长时间所有元素变为1。 输入描述 输入的前两个数字是矩阵大小。后面是数字…

盛域宏数合伙人张天:AI时代,数字化要以AI重构

大数据产业创新服务媒体 ——聚焦数据 改变商业 在这个飞速发展的科技时代,数字化已经深刻地改变了我们的生活和商业方式。信息技术的迅猛发展使得数据成为现代社会最宝贵的资源之一。数字化已经不再是可选项,而是企业持续发展的必由之路。背靠着数据的…

【React】路由的基础使用

react-router-dom6的基础使用 1、安装依赖 npm i react-router-dom默认安装最新版本的 2、在src/router/index.js import { createBrowserRouter } from "react-router-dom"/* createBrowserRouter:[/home]--h5路由createHashRouter:[/#/ho…

Linux访问NFS存储及自动挂载

本章主要介绍NFS客户端的使用 创建NFS服务器并通过NFS共享一个目录在客户端上访问NFS共享的目录自动挂载的配置和使用 1.1 访问NFS存储 前面那篇介绍了本地存储,本章就来介绍如何使用网络上上的存储设备。NFS即网络文件系统,所实现的是Linux和Linux之…

通信:mqtt学习网址

看这个网址:讲的很详细,后面补实战例子 第一章 - MQTT介绍 MQTT协议中文版 (gitbooks.io)https://mcxiaoke.gitbooks.io/mqtt-cn/content/mqtt/01-Introduction.html

【网络编程】-- 04 UDP

网络编程 6 UDP 6.1 初识Tomcat 服务端 自定义 STomcat S 客户端 自定义 C浏览器 B 6.2 UDP 6.2.1 udp实现发送消息 接收端: package com.duo.lesson03;import java.net.DatagramPacket; import java.net.DatagramSocket; import java.net.SocketExceptio…

【论文极速读】LVM,视觉大模型的GPT时刻?

【论文极速读】LVM,视觉大模型的GPT时刻? FesianXu 20231210 at Baidu Search Team 前言 这一周,LVM在arxiv上刚挂出不久,就被众多自媒体宣传为『视觉大模型的GPT时刻』,笔者抱着强烈的好奇心,在繁忙工作之…

m.2固态硬盘怎么选择?

一、什么是固态硬盘 固态硬盘又称SSD,是Solid State Drive的简称,由于采用了闪存技术,其处理速度远远超过传统的机械硬盘,这主要是因为固态硬盘的数据以电子的方式存储在闪存芯片中,不需要像机械硬盘那样通过磁头读写磁…

linux查看笔记本电池健康情况

本人的老电脑,笔记本x1 carbon 5th 用久了,电池不行了,实际容量只有27.657%,充电到40%的时候,一瞬间彪满100%。到某宝淘了一个 model: 01AV430的电池,等更换了再看看使用情况 $ upower --help 用法:upower…

Linux 安装 中间件 Tuxedo

安装步聚 一、首先,下载中间件安装包: tuxedo121300_64_Linux_01_x86 Tuxedo下载地址: Oracle Tuxedo Downloads 二、新建Oracle用户组(创建Oracle用户时,需要root权限操作,登陆) [rootloca…

【CiteSpace】引文可视化分析软件CiteSpace下载与安装

CiteSpace 译“引文空间”,是一款着眼于分析科学分析中蕴含的潜在知识,是在科学计量学、数据可视化背景下逐渐发展起来的引文可视化分析软件。由于是通过可视化的手段来呈现科学知识的结构、规律和分布情况,因此也将通过此类方法分析得到的可…

【Spring教程23】Spring框架实战:从零开始学习SpringMVC 之 SpringMVC简介与SpringMVC概述

目录 1,SpringMVC简介2、SpringMVC概述 欢迎大家回到《Java教程之Spring30天快速入门》,本教程所有示例均基于Maven实现,如果您对Maven还很陌生,请移步本人的博文《如何在windows11下安装Maven并配置以及 IDEA配置Maven环境》&…

python使用vtk与mayavi三维可视化绘图

VTK(Visualization Toolkit)是3D计算机图形学、图像处理和可视化的强大工具。它可以通过Python绑定使用,适合于科学数据的复杂可视化。Mayavi 依赖于 VTK (Visualization Toolkit),一个用于 3D 计算机图形、图像处理和可视化的强大…

AS安装目录

编辑器: sdk: gradle: gradle使用的jdk目录:Gradle使用的jdk是android studio安装目录下的jbr 成功项目的android studio配置:

H264码流结构

视频编码的码流结构是指视频经过编码之后得到的二进制数据是怎么组织的,或者说,就是编码后的码流我们怎么将一帧帧编码后的图像数据分离出来,以及在二进制码流数据中,哪一块数据是一帧图像,哪一块数据是另外一帧图像。…

C++面试宝典第4题:合并链表

题目 有一个链表,其节点声明如下: struct TNode {int nData;struct TNode *pNext;TNode(int x) : nData(x), pNext(NULL) {} }; 现给定两个按升序排列的单链表pA和pB,请编写一个函数,实现这两个单链表的合并。合并后,…

scheduleatfixedrate详解

scheduleatfixedrate详解 大家好,我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!在Java开发中,我们常常需要执行定时任务,并且需要保证任务按照一定…

使用Java实现基数排序算法

文章目录 基数排序算法 基数排序算法 (1)基本思想:将整数按位数切割成不同的数字,然后按每个位数分别比较。 (2)排序过程:将所有待比较数值(正整数)统一为同样的数位长…

Vuex快速上手

一、Vuex 概述 目标:明确Vuex是什么,应用场景以及优势 1.是什么 Vuex 是一个 Vue 的 状态管理工具,状态就是数据。 大白话:Vuex 是一个插件,可以帮我们管理 Vue 通用的数据 (多组件共享的数据)。例如:购…