论文阅读:Bayesian GAN

Bayesian GAN

点击访问paper
官方github
半监督学习对比算法

1.简介

贝叶斯 GAN(Saatchi 和 Wilson,2017)是生成对抗网络(Goodfellow,2014)的贝叶斯公式,我们在其中学习生成器参数 θ g \theta_g θg 和鉴别器参数 θ d \theta_d θd 的分布,而不是优化 用于点估计。 贝叶斯方法的优点包括在参数空间中灵活地建模多模态,以及在最大似然(非贝叶斯)情况下防止模式崩溃的能力。

我们通过称为“随机梯度哈密顿蒙特卡罗(SGHMC)”的近似推理算法来学习贝叶斯 GAN,这是一种基于梯度的 MCMC 方法,其样本近似于 θ g \theta_g θg θ d \theta_d θd 的真实后验分布。

贝叶斯 GAN 训练过程从固定分布(通常是标准 d-dim 正态分布)中采样噪声 z z z 开始。 噪声被馈送到生成器,其中参数 θ g \theta_g θg 从后验分布 p ( θ g ∣ D ) p(\theta_g | D) p(θgD) 中采样。 给定参数 θ g \theta_g θg ( G ( z ∣ θ g ) G(z|\theta_g) G(zθg)) 生成的图像以及真实数据呈现给鉴别器,其参数是从其后验分布 p ( θ d ∣ D ) p(\theta_d|D) p(θdD) 中采样的 。 我们使用梯度 ∂ log ⁡ p ( θ g ∣ D ) ∂ θ g \frac{\partial \log p(\theta_g|D) }{\partial \theta_g } θglogp(θgD) ∂ log ⁡ p ( θ d ∣ D ) ∂ θ d \frac{\partial \log p(\theta_d|D) }{\partial \theta_d } θdlogp(θdD) 更新后验与随机梯度哈密顿蒙特卡罗 (SGHMC)。

SGHMC 通过优化噪声损失

首先,观察到除了噪声 n \boldsymbol{n} n 之外,更新规则与动量 SGD 类似。 事实上,如果没有 n \boldsymbol{n} n,这相当于执行动量 SGD,损失为 − ∑ i = 1 J g ∑ k = 1 J d log ⁡ posterior - \sum_{i=1}{J_g} \sum_{k=1}^{J_d} \log \text{posterior} i=1Jgk=1Jdlogposterior。 为了简单起见,我们将描述 J g = J d = 1 J_g = J_d=1 Jg=Jd=1 的情况。

我们使用主要损失 L = − log ⁡ p ( θ ∣ . . ) \mathcal{L} = - \log p(\theta | ..) L=logp(θ∣..) 并添加噪声损失 L noise = 1 η θ ⋅ n \mathcal{L}_\text{noise} = \frac{1}{\eta } \theta \cdot \boldsymbol{n} Lnoise=η1θn 其中 n ∼ N ( 0 , 2 α η I ) \boldsymbol{n} \sim \mathcal{N}(0, 2 \alpha \eta I) nN(0,2αηI) 从而优化损失函数 L + L noise \mathcal{L} + \mathcal{L}_\text{noise} L+Lnoise 与动量 SGD 相当于执行 SGHMC 更新步骤。
在这里插入图片描述

2. 算法

下面(公式 3 和 4)是后验概率,其中每个误差项对应其负对数概率。
在这里插入图片描述
其中 K K K表示对象的总类别的数量。 D ( x ( i ) = y ( i ) ; θ d ) D(x^{(i)} = y^{(i)}; \theta_d) D(x(i)=y(i);θd) 表示辨别器认为样本 x ( i ) x^{(i)} x(i) 属于 y ( i ) y^{(i)} y(i) 类的概率。

2.1 对于鉴别器

errD = errD_real + errD_fake + err_sup + errD_prior + errD_noise

不知道 errD_noise是个啥

2.1.1 errD_real

需要保证无监督学习的差异性(优化分类)
errD_real = ∏ i = 1 n d ∑ y = 1 K D ( x ( i ) = y ; θ d ) \text{errD\_real}=\prod_{i=1}^{n_d}\sum_{y=1}^KD(x^{(i)}=y;\theta_d) errD_real=i=1ndy=1KD(x(i)=y;θd)
只需要输出 size = * × number of classs

errD_real = criterion_comp(output)
2.1.2 errD_fake

需要保证能够鉴别出假数据(优化鉴别)
errD_fake = ∏ i = 1 n g D ( G ( z ( i ) ; θ g ) = 0 ; θ d ) \text{errD\_fake}=\prod_{i=1}^{n_g}D(G(z^{(i)};\theta_g)=0;\theta_d) errD_fake=i=1ngD(G(z(i);θg)=0;θd)
需要辨别器以及标签全为0

output = netD(fake.detach())
labelv = Variable(torch.LongTensor(fake.data.shape[0]).cuda().fill_(fake_label))
errD_fake = criterion(output, labelv)
2.1.3 errD_sup

(优化监督分类)
errD_sup = ∏ i = 1 n s ∑ y = 1 K D ( x s ( i ) = y s ( i ) ; θ d ) \text{errD\_sup}=\prod_{i=1}^{n_s}\sum_{y=1}^KD(x^{(i)}_s=y_s^{(i)};\theta_d) errD_sup=i=1nsy=1KD(xs(i)=ys(i);θd)

output_sup = netD(input_sup_v)
err_sup = criterion(output_sup, target_sup_v)
2.2.4 errD_prior

p ( θ d ∣ α d ) p(\theta_d|\alpha_d) p(θdαd)

errD_prior = dprior_criterion(netD.parameters())
errD_prior.backward()
errD_noise = dnoise_criterion(netD.parameters())
errD_noise.backward()

2.2 生成器

2.2.1 errG

errG = ∏ i = 1 n g ∑ y = 1 K D ( z s ( i ) = y ; θ d ) \text{errG}=\prod_{i=1}^{n_g}\sum_{y=1}^KD(z^{(i)}_s=y;\theta_d) errG=i=1ngy=1KD(zs(i)=y;θd)

output = netD(fake)
errG = criterion_comp(output)
2.2.2 errG_prior
if opt.bayes:for netG in netGs:errG += gprior_criterion(netG.parameters())errG += gnoise_criterion(netG.parameters())

第三个链接中得到了如下图像。证明了用生成数据能够提升模型的泛化能力。接下来将详细分析泛化能力的来源
在这里插入图片描述

iteration = 0
for epoch in range(opt.niter):top1 = AverageMeter()top1_weakD = AverageMeter()for i, data in enumerate(dataloader):iteration += 1######## 1. real inputnetD.zero_grad()_input, _ = databatch_size = _input.size(0)if opt.cuda:_input = _input.cuda()input.resize_as_(_input).copy_(_input)       label.resize_(batch_size).fill_(real_label)  inputv = Variable(input)labelv = Variable(label)output = netD(inputv)errD_real = criterion_comp(output)errD_real.backward()# calculate D_x, the probability that real data are classified D_x = 1 - torch.nn.functional.softmax(output,dim=1).data[:, 0].mean()######## 2. Generated inputfakes = []for _idxz in range(opt.numz):noise.resize_(batch_size, opt.nz, 1, 1).normal_(0, 1)noisev = Variable(noise)for _idxm in range(opt.num_mcmc):idx = _idxz*opt.num_mcmc + _idxmnetG = netGs[idx]_fake = netG(noisev)fakes.append(_fake)fake = torch.cat(fakes)output = netD(fake.detach())labelv = Variable(torch.LongTensor(fake.data.shape[0]).cuda().fill_(fake_label))errD_fake = criterion(output, labelv)errD_fake.backward()D_G_z1 = 1 - torch.nn.functional.softmax(output,dim=1).data[:, 0].mean()######## 3. Labeled Data Part (for semi-supervised learning)for ii, (input_sup, target_sup) in enumerate(dataloader_semi):input_sup, target_sup = input_sup.cuda(), target_sup.cuda()breakinput_sup_v = Variable(input_sup.cuda())# convert target indicies from 0 to 9 to 1 to 10target_sup_v = Variable( (target_sup + 1).cuda())output_sup = netD(input_sup_v)err_sup = criterion(output_sup, target_sup_v)err_sup.backward()prec1 = accuracy(output_sup.data, target_sup + 1, topk=(1,))[0]top1.update(prec1.item(), input_sup.size(0))if opt.bayes:errD_prior = dprior_criterion(netD.parameters())errD_prior.backward()errD_noise = dnoise_criterion(netD.parameters())errD_noise.backward()errD = errD_real + errD_fake + err_sup + errD_prior + errD_noiseelse:errD = errD_real + errD_fake + err_supoptimizerD.step()# 4. Generatorfor netG in netGs:netG.zero_grad()labelv = Variable(torch.FloatTensor(fake.data.shape[0]).cuda().fill_(real_label))output = netD(fake)errG = criterion_comp(output)# print(errG)if opt.bayes:for netG in netGs:errG += gprior_criterion(netG.parameters())errG += gnoise_criterion(netG.parameters())errG.backward()D_G_z2 = 1 - torch.nn.functional.softmax(output,dim=1).data[:, 0].mean()for optimizerG in optimizerGs:optimizerG.step()# 5. Fully supervised training (running in parallel for comparison)netD_fullsup.zero_grad()input_fullsup = Variable(input_sup)target_fullsup = Variable((target_sup + 1))output_fullsup = netD_fullsup(input_fullsup)err_fullsup = criterion_fullsup(output_fullsup, target_fullsup)optimizerD_fullsup.zero_grad()err_fullsup.backward()optimizerD_fullsup.step()# 6. get test accuracy after every intervalif iteration % opt.stats_interval == 0:# get test accuracy on train and testnetD.eval()get_test_accuracy(netD, iteration, label='semi')get_test_accuracy(netD_fullsup, iteration, label='sup')netD.train()# 7. Report for this iterationcur_val, ave_val = top1.val, top1.avglog_value('train_acc', top1.avg, iteration)print('[%d/%d][%d/%d] Loss_D: %.2f Loss_G: %.2f D(x): %.2f D(G(z)): %.2f / %.2f | Acc %.1f / %.1f'% (epoch, opt.niter, i, len(dataloader),errD.data.item(), errG.item(), D_x, D_G_z1, D_G_z2, cur_val, ave_val))# after each epoch, save imagesvutils.save_image(_input,'%s/real_samples.png' % opt.outf,normalize=True)for _zid in range(opt.numz):for _mid in range(opt.num_mcmc):idx = _zid*opt.num_mcmc + _midnetG = netGs[idx]fake = netG(fixed_noise)vutils.save_image(fake.data,'%s/fake_samples_epoch_%03d_G_z%02d_m%02d.png' % (opt.outf, epoch, _zid, _mid),normalize=True)for ii, netG in enumerate(netGs):torch.save(netG.state_dict(), '%s/netG%d_epoch_%d.pth' % (opt.outf, ii, epoch))torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))torch.save(netD_fullsup.state_dict(), '%s/netD_fullsup_epoch_%d.pth' % (opt.outf, epoch))

接下来我们将借鉴此框架,融合这篇论文训练生成视频的算法。并用于视频分类。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/621695.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mybatisplus(service CRUD 接口)

一、我们在控制器层都是调用Service层&#xff0c;不会直接调用仓储层。现在我给大家介绍一下怎么快速实现Service 的CRUD 定义接口&#xff1a;IProductService 继承IService<实体> package com.saas.plusdemo;import com.baomidou.mybatisplus.extension.service.ISe…

Bootsrap-导航、栅格、及使用案例

文章目录 一、下载并导入Bootstrap中文文档二、Bootstrap初体验三、Boostrap导航栏四、Boostrap栅格五、博客案例六、用户登录界面七、后台管理界面八、引入图标九、Bootstrap动态效果 一、下载并导入Bootstrap中文文档 二、Bootstrap初体验 实现提交按钮&#xff0c;去中文文…

SpringBoot 入门 SpringBoot 与其他项目整合 集成 Druid 数据库连接池 集成 Log 日志 配置修改

目录 1.SpringBoot简介 1.1.什么是SpringBoot 1.2.特点 2.SpringBoot快速入门 2.1.创建SpringBoot项目 2.2.项目目录介绍 2.3.配置修改 2.4.启动SpringBoot 3.SpringBoot与其他项目整合 3.1.整合JDBC 3.2.整合Druid数据库连接池 3.3.整合MyBatis 3.4.整合Log日志 …

科研绘图(四)火山图

火山图是生物信息学中常用的一种图表&#xff0c;用来显示基因表达数据的变化。它通常将每个点表示为一个基因&#xff0c;x轴显示对数比率&#xff08;log ratio&#xff09;&#xff0c;表示基因表达的变化大小&#xff1b;y轴显示-log10(p-value)&#xff0c;表示变化的统计…

跨镜动线分析丨用AI解读顾客行为,助力零售企业运营与增长

步入数字时代&#xff0c;先进技术让传统零售焕发新生。智慧零售以用户为中心&#xff0c;“人”的数据化价值将反哺生产、渠道、销售、运营全场景。 悠络客正式推出“跨镜动线分析”&#xff0c;运用AI技术&#xff0c;深度分析顾客的进店、逛店等一系列行为&#xff0c;助力零…

host没有管理员权限

1 以管理员身份运行 Windows PowerShell 2 输入 notepad C:\Windows\System32\drivers\etc\hosts 3在自动弹出的host文件里添加信息&#xff0c;然后保存即可

Fluids —— Viscosity: honey

目录 Fixed viscosity: honey Point variable viscosity: honey Fixed viscosity: honey SOP FLIP提供的粘性解释器&#xff0c;可对恒定或变化的粘性&#xff1b;以下是恒定粘性的蜂蜜模拟&#xff0c;蜂蜜的特性与粘度和表面张力等参数相关&#xff0c;可观察到典型的缠绕和…

机器学习周报第28周

目录 摘要Abstract一、文献阅读1.题目&#xff1a;2.摘要3.问题描述4.过去方案5.论文方案6.论文模型7.相关代码 摘要 本周阅读了一篇混沌时间序列预测的论文&#xff0c;论文模型主要使用的是时间卷积网络&#xff08;Temporal Convolutional Network&#xff0c;TCN&#xff…

2624. 蜗牛排序

说在前面 &#x1f388;不知道大家对于算法的学习是一个怎样的心态呢&#xff1f;为了面试还是因为兴趣&#xff1f;不管是出于什么原因&#xff0c;算法学习需要持续保持。 题目描述 请你编写一段代码为所有数组实现 snail(rowsCount&#xff0c;colsCount) 方法&#xff0c;…

5.Pytorch模型单机多GPU训练原理与实现

文章目录 Pytorch的单机多GPU训练1)多GPU训练介绍2)pytorch中使用单机多GPU训练DistributedDataParallel(DDP)相关变量及含义a)初始化b)数据准备c)模型准备d)清理e)运行 3)使用DistributedDataParallel训练模型的一个简单实例 欢迎访问个人网络日志&#x1f339;&#x1f339;知…

数学建模day15-时间序列分析

时间序列也称动态序列&#xff0c;是指将某种现象的指标数值按照时间顺序排列而成的数值序列。时间序列分析大致可分成三大部分&#xff0c;分别是描述过去、分析规律和预测未来&#xff0c;本讲将主要介绍时间序列分析中常用的三种模型&#xff1a;季节分解、指数平滑方法和AR…

WEB服务器-Tomcat

3. WEB服务器-Tomcat 3.1 简介 3.1.1 服务器概述 服务器硬件 指的也是计算机&#xff0c;只不过服务器要比我们日常使用的计算机大很多。 服务器&#xff0c;也称伺服器。是提供计算服务的设备。由于服务器需要响应服务请求&#xff0c;并进行处理&#xff0c;因此一般来说…

【AI】人工智能和水下机器视觉

目录 一、初识水下机器视觉 ——不同点 ——难点 二、AI如何助力水下机器视觉 三、应用场景 四、关键技术 水下机器视觉&#xff0c;非常复杂&#xff0c;今天来简单讨论一下。因为目标识别更难。 水下机器视觉是机器视觉技术在水下环境中的应用&#xff0c;它与普通机器…

基于Springboot的网上点餐系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的网上点餐系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&am…

【2024】OAK智能深度相机校准教程

编辑&#xff1a;OAK中国 首发&#xff1a;oakchina.cn 喜欢的话&#xff0c;请多多&#x1f44d;⭐️✍ 内容可能会不定期更新&#xff0c;官网内容都是最新的&#xff0c;请查看首发地址链接。 ▌前言 Hello&#xff0c;大家好&#xff0c;这里是OAK中国&#xff0c;我是Ash…

机器人跟踪性能量化指标

衡量机械臂关节轨迹跟踪控制的性能可以通过以下几个方面来进行&#xff1a; 跟踪精度&#xff1a;这是衡量机械臂关节轨迹跟踪控制性能的最重要的指标。它反映了机械臂实际运动轨迹与期望运动轨迹之间的偏差。跟踪精度越高&#xff0c;说明机械臂的控制性能越好。运动范围&…

抖音小店怎么选品?分享如何培养选爆品的思维,每个人都要学会

选品定店铺生死。 一个店铺能不能出单&#xff0c;能不能赚钱&#xff0c;店铺的商品占主要部分&#xff0c;商品才是电商店铺最核心的内容&#xff0c;一个货真价实&#xff0c;物美价廉的产品才是店铺的核心竞争力&#xff0c;运营和找达人都是让产品卖的更多&#xff0c;更…

三、MySQL实例初始化、设置、服务启动关闭、环境变量配置、客户端登入(一篇足以从白走到黑)

目录 1、选择安装的电脑类型、设置端口号 2、选择mysql账号密码加密规则 3、设置root账户密码 4、设置mysql服务名和服务启动策略 5、执行设置&#xff08;初始化mysql实例&#xff09; 6、完成设置 7、MySQL数据库服务的启动和停止 方式一&#xff1a;图形化方式 方式…

AI智能剪辑,快速剪辑出需要的视频

AI智能剪辑技术&#xff0c;是一种基于人工智能的技术&#xff0c;它能够通过机器学习和深度学习算法&#xff0c;自动识别视频中的内容&#xff0c;并根据用户的需求和喜好&#xff0c;快速地剪辑出需要的视频。 所需工具 &#xff1a; 一个【媒体梦工厂】软件 视频素材 …

软件报错msvcp120.dll丢失怎么办?总共有6个msvcp120.dll丢失的解决方法分享

一、msvcp120.dll是什么文件&#xff1f; msvcp120.dll是Microsoft Visual C Redistributable Package的一部分&#xff0c;它是运行许多Windows应用程序所必需的动态链接库文件之一。它包含了许多C函数和类&#xff0c;用于支持各种应用程序的正常运行。因此&#xff0c;当ms…