利用GAN原始框架生成手写数字

这一篇GAN文章只是让产生的结果尽量真实,还不能分类。

本次手写数字GAN的思想:

对于辨别器,利用真实的手写数字(真样本,对应的标签为真标签)和随机噪声经过生成器产生的样本(假样本,对应的标签为假标签)送入辨别器,分别得到两个损失值,最小化这两个损失值,这样的话就能保证辨别器能分清楚真假。

而对于生成器,用产生的随机噪声送入生成器,产生样本(假样本,对应的标签为真标签),得到损失值,最小化损失值,注意标签要改为真标签,因为这样才能以假乱真,正是辨别器和生成器博弈的过程,使得生成的数据能够以假乱真,为什么博弈到平衡的状态随机噪声能够以假乱真呢?看公式,论证了全局最优值,就是噪声和真实数据相等的时候。

详细推导过程

下面看代码:test做测试用的,不用管。

路径:

util.py代码:

import tensorflow as tf
import numpy as np
"""
从正太分布输出随机值
"""
def xavier_init(size):in_dim=size[0]xavier_stddev=tf.sqrt(2./in_dim)return tf.random_normal(shape=size,stddev=xavier_stddev)#生成模型的输入和参数初始化G_W1=tf.Variable(xavier_init(size=[100,128]))
G_b1=tf.Variable(tf.zeros(shape=[128]))G_W2=tf.Variable(xavier_init(size=[128,784]))
G_b2=tf.Variable(tf.zeros(shape=[784]))theta_G=[G_W1,G_W2,G_b1,G_b2]#判别模型的输入和参数初始化D_W1=tf.Variable(xavier_init(size=[784,128]))
D_b1=tf.Variable(tf.zeros(shape=[128]))D_W2=tf.Variable(xavier_init(size=[128,1]))
D_b2=tf.Variable(tf.zeros(shape=[1]))theta_D=[D_W1,D_W2,D_b1,D_b2]"""
随机噪声产生
"""
def sample_z(m,n):return np.random.uniform(-1.0,1.0,size=[m,n])
"""
生成模型:产生数据
"""
def generator(z):G_h1=tf.nn.relu(tf.matmul(z,G_W1)+G_b1)G_log_prob=tf.matmul(G_h1, G_W2) + G_b2G_prob=tf.nn.sigmoid(G_log_prob)return G_prob"""
判别模型:真实值和概率值
"""
def discriminator(x):D_h1=tf.nn.relu(tf.matmul(x,D_W1)+D_b1)D_logit=tf.matmul(D_h1, D_W2) + D_b2D_prob=tf.nn.sigmoid(D_logit)return D_prob,D_logit

main.py代码:

import tensorflow as tf
import numpy as np
from GAN.TWO import util
import os
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
#读入数据
mnist=input_data.read_data_sets('./data',one_hot=True)
# print(mnist)Z=tf.placeholder(tf.float32,shape=[None,100])X=tf.placeholder(tf.float32,shape=[None,784])
#喂入数据
G_sample=util.generator(Z)
D_real,D_logit_real=util.discriminator(X)
D_fake,D_logit_fake=util.discriminator(G_sample)
#计算loss
D_real_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real,labels=tf.ones_like(D_logit_real)))
D_fake_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake,labels=tf.zeros_like(D_logit_fake)))
D_loss=D_fake_loss+D_real_lossG_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake,labels=tf.ones_like(D_logit_fake)))D_optimizer=tf.train.AdamOptimizer().minimize(D_loss,var_list=util.theta_D)
G_optimizer=tf.train.AdamOptimizer().minimize(G_loss,var_list=util.theta_G)if not os.path.exists('out/'):os.makedirs('out/')
"""
画图
"""
def plot(samples):gs=gridspec.GridSpec(4,4)gs.update(wspace=0.05,hspace=.05)for i,sample in enumerate(samples):ax = plt.subplot(gs[i])plt.axis('off')ax.set_xticklabels([])ax.set_yticklabels([])ax.set_aspect('equal')plt.imshow(sample.reshape(28,28),cmap='Greys_r')
print("=====================开始训练============================")
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for it in range(100000):X_mb,_=mnist.train.next_batch(batch_size=128)# print(X_mb)_,D_loss_curr=sess.run([D_optimizer,D_loss],feed_dict={X:X_mb,Z:util.sample_z(128,100)})_, G_loss_curr = sess.run([G_optimizer, G_loss],feed_dict={Z: util.sample_z(128, 100)})if it%1000==0:print('====================打印出生成的数据============================')samples=sess.run(G_sample,feed_dict={Z: util.sample_z(16, 100)})plot(samples)plt.show()if it%1000==0:print('iter={}'.format(it))print('D_loss={}'.format(D_loss_curr))print('G_loss={}'.format(G_loss_curr))

打印结果:

迭代0次。

迭代50000次。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493548.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

DL也懂纹理吗——图像的纹理特征

工作中遇到一个问题:对于同一场景,训练好的DL模型能把大部分样本分类准确,而对于少量负样本,DL会错分到另外一个对立的类中。错分的样本可以认为是难分的样本,但是我们还想知道这两种样本到底是哪里的差异导致DL做出了…

排序算法--(冒泡排序,插入排序,选择排序,归并排序,快速排序,桶排序,计数排序,基数排序)

一.时间复杂度分析 - **时间复杂度**:对排序数据的总的操作次数。反应当n变化时,操作次数呈现什么规律 - **空间复杂度**:算法在计算机内执行时所需要的存储空间的容量,它也是数据规模n的函数。 1.例题: 有一个字符串数组&…

肠里细菌“肚里蛔虫”:肠脑研究缘何越来越热

来源:科学网最懂你大脑的,可能不是“肚子里的蛔虫”,而是肠子里的细菌——肠道菌群对神经系统、心理和行为方面的影响正成为一个新兴热点领域。在日前举办的美国神经科学学会年会上,一张海报上的大脑切片显微镜图像显示&#xff0…

SVM原理与实战

先看线性可分问题。对于线性可分,其实感知机就可以解决。但是感知机只是找到一个超平面将数据分开,而这样的超平面可能是平行的无限多个,我们需要在这其中找到最优的一个。怎么衡量一个超平面是不是最优的呢,直观上讲,…

2014-01-01

一:HyperlinkButton点击后打开新窗口的方法 1,直接在界面中写这段代码就可以了: <HyperlinkButton NavigateUri"http://www.cnblogs.com/wsdj-ITtech/" Content"Click Me" TargetName"_blank" FontSize"28" Height"50"…

李飞飞高徒:斯坦福如何打造基于视觉的智能医院?

作者&#xff1a;Albert Haque、Michelle Guo来源&#xff1a;机器之心自 2009 年担任斯坦福人工智能实验室和视觉实验室的负责人&#xff0c;李飞飞在推动计算机视觉方面研究的同时&#xff0c;还密切关注 AI 医疗的发展。昨日&#xff0c;李飞飞离任斯坦福 AI 实验室负责人一…

tensorflow知识点

一.bazel编译tensorflow注意版本号: 在/tensorflow/tensorflow/configure.py 查看bazel版本号 https://github.com/tensorflow/tensorflow https://github.com/bazelbuild/bazel/releases?after0.26.1 https://tensorflow.google.cn/ 二&#xff0c;基础知识点 1.打印出…

eclipse中如何导入jar包

如图&#xff0c;首先右键点击项目&#xff0c;选择最下面的properties&#xff0c; 然后进去之后点击java build path&#xff0c;右边会出来4个选项卡&#xff0c;选择libraries&#xff0c; 这时候最右边会有多个选项&#xff0c;第一个add jars是添加项目文件中的jar包&…

线性-LR-softmax傻傻分不清楚

softmax 对于分类网络&#xff0c;最后一层往往是全连接层&#xff0c;如果是N分类&#xff0c;那么最终的全连接层有N个结点。很显然&#xff0c;每个节点对应一个类&#xff0c;该节点的权重越大&#xff0c;说明网络越倾向于认为输入样本属于该类。这其实就是Softmax的思想…

一图看懂国外智能网联汽车传感器产业发展!

来源&#xff1a;赛迪智库编辑&#xff1a;煜 佳未来智能实验室是人工智能学家与科学院相关机构联合成立的人工智能&#xff0c;互联网和脑科学交叉研究机构。未来智能实验室的主要工作包括&#xff1a;建立AI智能系统智商评测体系&#xff0c;开展世界人工智能智商评测&#…

深度学习中的信息论——交叉熵

信息量 可以说就信息量是在将信息量化。首先信息的相对多少是有切实体会的&#xff0c;有的人一句话能包含很多信息&#xff0c;有的人说了等于没说。我们还可以直观地感觉到信息的多少和概率是有关的&#xff0c;概率大的信息也相对低一些。为了量化信息&#xff0c;一个做法…

传统手工特征--opencv

一&#xff0c;颜色特征&#xff1a; 简单点来说就是将一幅图上的各个像素点颜色统计出来&#xff0c;适用颜色空间&#xff1a;RGB&#xff0c;HSV等颜色空间&#xff0c; 具体操作&#xff1a;量化颜色空间&#xff0c;每个单元&#xff08;bin&#xff09;由单元中心代表&…

特写李飞飞:她激励了人工智能的发展,更要给人工智能赋予人的价值

文 | MrBear 编辑 | 杨晓凡来源&#xff1a;雷锋网摘要&#xff1a;李飞飞无疑是人工智能界最响亮的名字之一。她既对机器学习领域的发展做出了杰出的贡献&#xff0c;也是普通大众眼中温和的人工智能技术宣扬者&#xff0c;还是谷歌这一科技巨头的人工智能技术领导人之一。WI…

Chap-4 Section 4.2.4 指令修正方式

对于X86平台下的ELF文件的重定位入口所修正的指令寻址方式只有两种&#xff1a;绝对近址32寻址和相对近址32寻址。 这两种指令修正方式每个被修正的位置的长度都为32位&#xff0c;即4个字节&#xff0c;而且都是近址寻址&#xff0c;不用考虑Intel的段间远址寻址。r_info成员的…

没见过女人的小和尚——SVDD

是的&#xff0c;即便是出生在山上的小和尚&#xff0c;从来没有下过山&#xff0c;没有见过女人&#xff0c;但是一旦有女施主上山&#xff0c;小和尚依然可以轻松地区分出眼前的人是如此不同。 传统的SVM是寻找一个超平面&#xff0c;而SVDD寻找的超平面更进一步&#xff0c…

解读GAN及其 2016 年度进展

作者&#xff1a;程程 链接&#xff1a;https://zhuanlan.zhihu.com/p/25000523 来源&#xff1a;知乎 著作权归作者所有。商业转载请联系作者获得授权&#xff0c;非商业转载请注明出处。 GAN&#xff0c;全称为Generative Adversarial Nets&#xff0c;直译为生成式对抗网络…

全国首套中小学生人工智能教材在沪亮相

来源&#xff1a;网络大数据中小学 AI 教材正式亮相11 月 18 日&#xff0c;优必选与华东师范大学出版社共同发布了《AI 上未来智造者——中小学人工智能精品课程系列丛书》&#xff08;以下简称“AI 上未来智造者”丛书&#xff09;。据了解&#xff0c;该丛书根据教育部“义务…

numpy基础知识点

1. np.squeeze 一,np.squeeze """ np.squeeze 删除单维度的条 对多维度无效 """ import numpy as np anp.array([[1,2],[3,4],[4,5]]) print(a) print(a.shape) bnp.squeeze(a) print(b) ca.reshape(1,6,1) print(c) print(np.squeeze(c)) pri…

从智能交通到智能能源:智慧城市在7个方面的应用实践

来源&#xff1a;资本实验室目前&#xff0c;智慧城市已经成为全球众多城市未来规划和设计的方向&#xff0c;并致力于通过各种新技术的应用来改善城市居民的工作与生活。但什么样的技术应用能够推动智慧城市的建设&#xff1f;如何让新技术在智慧城市中的应用效率最大化&#…

别以为if slse很简单——决策树

怎么分——熵与Gini指数 熵&#xff0c;表示信息量的期望&#xff0c;含义是混乱程度&#xff0c;也是对随机变量编码所需的最小比特数。请参考之前的文章。 信息增益建立在熵之上&#xff0c;是选择某特征之后熵减少的多少&#xff08;熵减少即信息增加&#xff09;&#xf…