用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriterbatch_size=16
#load data
dataloader = create_netCDF_Dem_trainLoader(batch_size)# Generator with Conv2D structure
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(512),nn.ReLU(),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(256),nn.ReLU(),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(32),nn.ReLU(),nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, z):img = self.model(z)return img# Discriminator with Conv2D structure
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),)def forward(self, img):validity = self.model(img)return validity# Initialize GAN components
generator = Generator()
discriminator = Discriminator()# Define loss function and optimizers
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0# Training loop
num_epochs = 200
for epoch in range(num_epochs):for batch_idx, real_data in enumerate(dataloader):real_data = real_data.to(device)# Train Discriminatoroptimizer_D.zero_grad()real_labels = torch.ones(real_data.size(0), 1).to(device)fake_labels = torch.zeros(real_data.size(0), 1).to(device)z = torch.randn(real_data.size(0), 100, 1, 1).to(device)fake_data = generator(z)real_pred = discriminator(real_data)fake_pred = discriminator(fake_data.detach())d_loss_real = criterion(real_pred, real_labels)d_loss_fake = criterion(fake_pred, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_D.step()# Train Generatoroptimizer_G.zero_grad()z = torch.randn(real_data.size(0), 100, 1, 1).to(device)fake_data = generator(z)fake_pred = discriminator(fake_data)g_loss = criterion(fake_pred, real_labels)g_loss.backward()optimizer_G.step()# Print progressif batch_idx % 100 == 0:print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")with torch.no_grad():img_grid_real = torchvision.utils.make_grid(fake_data#, normalize=True,)img_grid_fake = torchvision.utils.make_grid(real_data#, normalize=True)writer_fake.add_image("fake_img", img_grid_fake, global_step=step)writer_real.add_image("real_img", img_grid_real, global_step=step)step += 1# After training, you can generate a 2D array by sampling from the generator
z = torch.randn(1, 100, 1, 1).to(device)
generated_array = generator(z)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/115791.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QTday05(TCP的服务端客户端通信)

实现聊天室功能 服务端代码&#xff1a; pro文件需要导入 network 头文件&#xff1a; #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTcpServer>//服务端 #include <QTcpSocket>//客户端 #include <QList> #include <QMes…

超声波清洗机频率如何选择?高频和低频有什么区别

超声波清洗原理就是在清洗液中产生“空化效应”&#xff0c;即清洗液产生拉伸和压缩现象&#xff0c;清洗液拉伸时会产生大量微小气泡&#xff0c;清洗液压缩时气泡会被压碎破裂。这些气泡产生和破裂的局部压强可达到上千个大气压的冲击力&#xff0c;这种极强大的压力足以使得…

Leetcode 1089. 复写零

复写零 题目链接1089. 复写零 给你一个长度固定的整数数组 arr &#xff0c;请你将该数组中出现的每个零都复写一遍&#xff0c;并将其余的元素向右平移。 注意&#xff1a;请不要在超过该数组长度的位置写入元素。请对输入的数组 就地 进行上述修改&#xff0c;不要从函数返回…

网络协议--UDP:用户数据报协议

11.1 引言 UDP是一个简单的面向数据报的运输层协议&#xff1a;进程的每个输出操作都正好产生一个UDP数据报&#xff0c;并组装成一份待发送的IP数据报。这与面向流字符的协议不同&#xff0c;如TCP&#xff0c;应用程序产生的全体数据与真正发送的单个IP数据报可能没有什么联…

并查集学习笔记

在一些有 n n n 个元素的集合应用问题中&#xff0c;我们通常是在开始时让每个元素构成一个单元素的集合&#xff0c;我们通常是在开始时让每个元素构成一个单元素的集合&#xff0c;然后按一定顺序将属于同一族的元素所在的集合合并&#xff0c;期间要反复查找一个元素在哪个…

老卫带你学---leetcode刷题(56. 合并区间)

56. 合并区间 问题&#xff1a; 以数组 intervals 表示若干个区间的集合&#xff0c;其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&#xff0c;并返回 一个不重叠的区间数组&#xff0c;该数组需恰好覆盖输入中的所有区间 。 示例 1&#xff1a;输…

一分钟带你了解什么是0day攻击什么是Nday攻击

1. 什么是零日漏洞 零日攻击是指利用零日漏洞对系统或软件应用发动的网络攻击。 零日漏洞也称零时差漏洞&#xff0c;通常是指还没有补丁的安全漏洞。由于零日漏洞的严重级别通常较高&#xff0c;所以零日攻击往往也具有很大的破坏性。目前&#xff0c;任何安全产品或解决方案…

吃瓜教程2|线性模型

线性回归 “广义的线性模型”&#xff08;generalized linear model&#xff09;&#xff0c;其中&#xff0c;g&#xff08;*&#xff09;称为联系函数&#xff08;link function&#xff09;。 线性几率回归&#xff08;逻辑回归&#xff09; 线性判别分析 想让同类样本点的…

ATPCS:ARM-Thumb程序调用的基本规则

为了使单独编译的c文件和汇编文件之间能够互相调用&#xff0c;需要制定一系列的规则&#xff0c;AAPCS就是ARM程序和Thumb程序中子程序调用的基本规则。 1、ATPCS概述 ATPCS规定了子程序调用过程中寄存器的使用规程、数据站的使用规则、参数的传递规则。为了适应一些特殊的需…

10月22日,每日信息差

今天是2023年10月22日&#xff0c;以下是为您准备的13条信息差 第一、库迪咖啡计划到2025年底全球门店数量达2万家&#xff0c;库迪咖啡开业一周年全球门店数量达到6061家&#xff0c;位居全球第四 第二、超高速纯硅调制器取得创纪录突破&#xff0c;国际上首次把纯硅调制器带…

【C++】继承 ⑧ ( 继承 + 组合 模式的类对象 构造函数 和 析构函数 调用规则 )

文章目录 一、继承 组合 模式的类对象 构造函数和析构函数调用规则1、场景说明2、调用规则 二、完整代码示例分析1、代码分析2、代码示例 一、继承 组合 模式的类对象 构造函数和析构函数调用规则 1、场景说明 如果一个类 既 继承了 基类 ,又 在类中 维护了一个 其它类型 的…

腾讯地图基本使用(撒点位,点位点击,弹框等...功能) 搭配Vue3

腾讯地图的基础注册账号 展示地图等基础功能在专栏的上一篇内容 大家有兴趣可以去看一看 今天说的是腾讯地图的在稍微一点的基础操作 话不多说 直接上代码 var marker ref(null) var map var center ref(null) // 地图初始化 const initMap () > {//定义地图中心点坐标…

Mysql数据库 3.SQL语言 DML数据操纵语言 增删改

DML语句&#xff1a;用于完成对数据表中数据的插入、删除、修改操作 一.表数据插入 插入数据语法&#xff1a; 步骤例&#xff1a; 1.声明数据库&#xff1a;use 数据库名; 2.删除操作&#xff1a;drop table if exists 表名; 3.创建数据库中的表&#xff1a;create table 表…

Java,回形数

回形数基本思路&#xff1a; 用不同的四个分支分别表示向右向下向左向上&#xff0c;假如i表示数组的行数&#xff0c;j表示数组的列数&#xff0c;向右向左就是控制j的加减&#xff0c;向上向下就是控制i的加减。 class exercise {public static void main(String[] args){//回…

【AOA-VMD-LSTM分类故障诊断】基于阿基米德算法AOA优化变分模态分解VMD的长短期记忆网络LSTM分类算法(Matlab代码)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

Python配置镜像源

Python3安装pika的准备 Windows下配置镜像源可以按照如下操作。 1.winR执行%APPDATA% %APPDATA%后&#xff0c;创建pip文件夹&#xff0c;并创建pip.ini配置文件 查看此目录下是否有pip目录&#xff0c;如果没有则需要创建&#xff0c;并在pip目录下以文本方式添加pip.ini文件…

Node学习笔记之fs模块

fs 全称为 file system &#xff0c;称之为 文件系统 &#xff0c;是 Node.js 中的 内置模块 &#xff0c;可以对计算机中的磁盘进行操 作。 本章节会介绍如下几个操作&#xff1a; 文件写入文件读取文件移动与重命名文件删除文件夹操作查看资源状态 一、文件写入 文件写入就…

maven以及配置

oss oss配置 <!--oss--> <dependency><groupId>com.aliyun.oss</groupId><artifactId>aliyun-sdk-oss</artifactId><version>3.6.0</version></dependency> lombok <!--lombok--><dependency><gro…

什么是工程化思维

什么是工程化思维&#xff1f; 工程化思维&#xff08;Engineering Thinking&#xff09;是一种解决问题和做决策的方法&#xff0c;它强调系统性、逻辑性和效率。这种思维方式在工程学和其他技术领域中尤为常见&#xff0c;但其实它可以应用于生活的各个方面。以下是工程化思…

Telent

Telnet协议是一种远程登录协议&#xff0c;它允许用户通过网络连接到远程主机并在远程主机上执行命令。 Telnet协议是TCP/IP协议族中的一员&#xff0c;是Internet远程登录服务的标准协议和主要方式。它为用户提供了在本地计算机上完成远程主机工作的能力。在终端使用者的电脑…