GAN的损失函数和二元交叉熵损失的对应及代码

以下解释为GPT生成

在这里插入图片描述
在这里插入图片描述在这里插入图片描述
这里有个问题,使用二元交叉熵,的时候生成器的损失如何体现
在这里插入图片描述
看代码

import torch
import torch.nn as nn
import torch.optim as optim# 设置设备为GPU或CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义生成器 (Generator)
class Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(input_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, output_size),nn.Tanh()  # 输出范围在 [-1, 1] 之间)def forward(self, x):return self.model(x)# 定义判别器 (Discriminator)
class Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(input_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, output_size),nn.Sigmoid()  # 输出概率 [0, 1])def forward(self, x):return self.model(x)# 超参数设置
input_size = 100  # 生成器输入噪声向量的维度
hidden_size = 128
output_size = 1   # 判别器的输出是一个概率
data_size = 784   # 假设输入数据的维度(例如,MNIST 图片 28x28 展开成 784 维向量)# 初始化生成器和判别器
G = Generator(input_size, hidden_size, data_size).to(device)
D = Discriminator(data_size, hidden_size, output_size).to(device)# 定义损失函数(二元交叉熵损失)
criterion = nn.BCELoss()# 优化器
lr = 0.0002
optimizer_G = optim.Adam(G.parameters(), lr=lr)
optimizer_D = optim.Adam(D.parameters(), lr=lr)# 生成随机噪声的函数(用于生成器)
def generate_noise(batch_size, input_size):return torch.randn(batch_size, input_size).to(device)# 假设我们有一个简单的生成数据和真实数据的函数
def get_real_data(batch_size):# 这里我们用随机生成的假数据来模拟真实数据return torch.randn(batch_size, data_size).to(device)# 训练步骤
epochs = 1000
batch_size = 64for epoch in range(epochs):# 训练判别器real_data = get_real_data(batch_size)  # 获取真实数据noise = generate_noise(batch_size, input_size)  # 生成噪声fake_data = G(noise)  # 生成数据# 判别器的目标:正确区分真实数据和生成数据real_labels = torch.ones(batch_size, 1).to(device)  # 真实数据的标签为1fake_labels = torch.zeros(batch_size, 1).to(device)  # 生成数据的标签为0# 判别器对真实数据的损失outputs_real = D(real_data)D_loss_real = criterion(outputs_real, real_labels)# 判别器对生成数据的损失outputs_fake = D(fake_data.detach())  # 对生成数据的判别(生成数据不传递梯度给生成器)D_loss_fake = criterion(outputs_fake, fake_labels)# 判别器总损失D_loss = D_loss_real + D_loss_fake# 更新判别器optimizer_D.zero_grad()D_loss.backward()optimizer_D.step()# 训练生成器noise = generate_noise(batch_size, input_size)  # 生成新的噪声fake_data = G(noise)  # 生成新的假数据# 生成器的目标:欺骗判别器,让判别器认为生成的数据是真实的outputs_fake = D(fake_data)G_loss = criterion(outputs_fake, real_labels)  # 生成器希望生成的数据被判为真实数据,因此标签设为1# 更新生成器optimizer_G.zero_grad()G_loss.backward()optimizer_G.step()# 每隔一段时间打印损失if epoch % 100 == 0:print(f"Epoch [{epoch}/{epochs}] | D Loss: {D_loss.item():.4f} | G Loss: {G_loss.item():.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/54115.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

EndnoteX9安装及使用教程

EndnoteX9安装及使用教程 一、EndNote安装 1.1 下载 这里提供一个下载链接: 链接:https://pan.baidu.com/s/1RlGJksQ67YDIhz4tBmph6Q 提取码:5210 解压完成后,如下所示: 1.2 安装 双击右键进行安装 安装比较简单…

【C++11 —— 线程库】

C11 —— 线程库 thread类介绍线程函数参数原子性操作库(atomic)lock_guard与unique_lockmutex的种类lock_guardunique_lock 两个线程交替打印奇偶数 thread类介绍 在C11之前,涉及到多线程的问题,都是和平台相关的,比如windows和Linux下各有…

AI 时代程序员的应变之道

一、AI 浪潮来袭,编程界风云变幻 随着 AIGC 大语言模型如 ChatGPT、Midjourney、Claude 等的涌现,AI 辅助编程工具日益普及,程序员的工作方式正经历着深刻的变革。 分析公司 OReilly 日前发布的《2023 Generative AI in the Enterprise》报告…

【Linux基础】冯诺依曼体系结构操作系统的理解

目录 前言一,冯诺依曼体系1. 为什么有内存结构?2. 对硬件中数据流动的再理解 二,操作系统(Operator System)1. 概念2. 操作系统结构的层状划分3. 操作系统对硬件管理的理解4. 用户与操作系统的关系的理解5. 系统调用和库函数的关系6. 为什么要有操作系统…

策略路由与路由策略的区别

🐣个人主页 可惜已不在 🐤这篇在这个专栏 华为_可惜已不在的博客-CSDN博客 🐥有用的话就留下一个三连吧😼 目录 一、主体不同 二、方式不同 三、规则不同 四、定义和基本概念 一、主体不同 1、路由策略:是为了改…

android 删除系统原有的debug.keystore,系统运行的时候,重新生成新的debug.keystore,来完成App的运行。

1、先上一个图:这个是keystore无效的原因 之前在安装这个旧版本android studio的时候呢,安装过一版最新的android studio,然后通过模拟器跑过测试的demo。 2、运行旧的项目到模拟器的时候,就报错了: Execution failed…

后台数据管理系统 - 项目架构设计-Vue3+axios+Element-plus(0916)

接口文档: https://apifox.com/apidoc/shared-26c67aee-0233-4d23-aab7-08448fdf95ff/api-93850835 接口根路径: http://big-event-vue-api-t.itheima.net 本项目的技术栈 本项目技术栈基于 ES6、vue3、pinia、vue-router 、vite 、axios 和 element-plus http:/…

RabbitMQ(高阶使用)死信队列

文章内容是学习过程中的知识总结,如有纰漏,欢迎指正 文章目录 一、什么是死信队列? 二、死信队列使用场景 三、死信队列如何使用 四、打车超时处理 1.打车超时实现 以下是本篇文章正文内容 一、什么是死信队列? 先从概念解释上搞…

idea插件推荐之Cool Request

Cool Request是一款基于IDEA的HTTP调试工具,可以看成是轻量版的postman,它会自动扫描项目代码中所有API路径,按项目分组管理。一个类被定义为Controller且其中的方法被RequestMapping或者XXXMapping注解标注以后就会被扫描到。 对应方法左侧会…

智能硬件从零开始的设计生产流程

文章目录 市场分析团队组建ID设计结构设计pcba设计软件开发手板EVT开模DVTPVTMP 智能硬件研发是一个复杂的过程, 当然一件事要发出萌芽必须得有人, 有一天,几个合伙人凑在一起,说一起开发个智能硬件产品吧,于是故事开始了. 市场分析 合伙人: 万物互联的时代, 智能音箱已经成为…

LDR6020,单C口OTG,充放一体新潮流!

PD(Power Delivery)芯片实现单Type-C接口输入和输出OTG(On-The-Go)功能,主要是通过支持USB Power Delivery规范和OTG功能的特定硬件和软件设计来实现的。以下是对这一过程的具体解释: 一、PD芯片基础功能 …

Chainlit集成Langchain并使用通义千问实现和数据库交互的网页对话应用增强扩展(text2sql)

前言 我在上一篇文章中《Chainlit集成Langchain并使用通义千问实现和数据库交互的网页对话应用(text2sql)》 利用langchain 中create_sql_agent 创建一个数据库代理智能体,但是实测中发现,使用 create_sql_agent 在对话中&#x…

Qt控制开发板的LED

Qt控制开发板的LED 使用开发板的IO接口进行控制是嵌入式中非常重要的一点,就像冯诺依曼原理说的一样,一个计算机最起码要有输入输出吧,我们有了信息的接收和处理,那我们就要有输出。 我们在开发板上一般都是使用开发板的GPIO接口…

七、垃圾收集器ParNewCMS与底层三色标记算法详解

文章目录 垃圾收集算法分代收集理论标记-复制算法标记-清除算法标记-整理算法 垃圾收集器1.1 Serial收集器(-XX:UseSerialGC -XX:UseSerialOldGC)1.2 Parallel Scavenge收集器(-XX:UseParallelGC(年轻代),-XX:UseParallelOldGC(老年代))1.3 ParNew收集器(-XX:UseParNewGC)1.4 C…

MATLAB 可视化基础:绘图命令与应用

目录 1. 绘制子图1.1基本绘图命令1.2. 使用 subplot 函数1.3. 绘图类型 2.MATLAB 可视化进阶(以下代码均居于以上代码的数据定义上实现)2.1. 极坐标图2.3. 隐函数的绘制 3.总结 在数据分析和科学计算中,数据可视化是理解和解释结果的关键工具。今天,我将…

Text2vec -文本转向量

文章目录 一、关于 Text2vec1、Text2vec 是什么2、Features3、Demo4、News5、Evaluation英文匹配数据集的评测结果:中文匹配数据集的评测结果: 6、Release Models 二、Install三、使用1、文本向量表征1.2 Usage (HuggingFace Transformers)1.3 Usage (se…

标准库标头 <barrier>(C++20)学习

此头文件是线程支持库的一部分。 类模板 std::barrier 提供一种线程协调机制,阻塞已知大小的线程组直至该组中的所有线程到达该屏障。不同于 std::latch,屏障是可重用的:一旦到达的线程组被解除阻塞,即可重用同一屏障。与 std::l…

深度学习之微积分预备知识点

极限(Limit) 定义:表示某一点处函数趋近于某一特定值的过程,一般记为 极限是一种变化状态的描述,核心思想是无限靠近而永远不能到达 公式: 表示 x 趋向 a 时 f(x) 的极限。 知识点口诀解释极限的存在左…

C语言 | Leetcode C语言题解之第412题Fizz Buzz

题目&#xff1a; 题解&#xff1a; /*** Note: The returned array must be malloced, assume caller calls free().*/ char ** fizzBuzz(int n, int* returnSize) {/*定义字符串数组*/char **answer (char**)malloc(sizeof(char*)*n);for(int i 1;i<n;i){/*分配单个字符串…

React学习day06-异步操作、ReactRouter的概念及简单使用

13、续 &#xff08;8&#xff09;异步状态操作 1&#xff09;在子仓库中 ①创建仓库 ②解构需要的方法 ③安装axios ④封装并导出请求 ⑤在reducer中为newsList赋值 ⑥获取并导出reducer函数 2&#xff09;在入口文件index.js中&#xff0c;注入 3&#xff09;在App.js中&a…