【TensorFlow】随机训练和批训练的比较与实现

一、随机训练和批训练

  1. 随机训练:一次随机抽样训练数据和目标数据对完成训练。
  2. 批训练:一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。
  3. 批训练和随机训练的差异优化器方法收敛的不同
  4. 批训练的难点在于:确定合适的batch_size
  5. 二者比较
训练类型优点缺点
随机训练脱离局部最小一般需要更多的迭代次数才收敛
批训练快速得到最小损失耗费更多的计算资源

二、实现随机训练

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.python.framework import ops
ops.reset_default_graph()
# 一、随机训练:# 1.创建计算图
sess = tf.Session()# 2.创建数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)# 3.创建变量
A = tf.Variable(tf.random_normal(shape=[1]))# 4.增加图操作
my_output = tf.multiply(x_data, A)# 5.声明L2正则损失
loss = tf.square(my_output - y_target)# 6.声明优化器 学习率为0.02
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)# 7.初始化变量
init = tf.global_variables_initializer()
sess.run(init)# 8.保存loss数据用于绘图
loss_stochastic = []# 9.开始训练
for i in range(100):rand_index = np.random.choice(100)rand_x = [x_vals[rand_index]]rand_y = [y_vals[rand_index]]sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})if (i+1)%5==0:print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})print('Loss = ' + str(temp_loss))loss_stochastic.append(temp_loss)
# 输出结果
Step #5 A = [2.0631378]
Loss = [60.90259]
Step #10 A = [3.560384]
Loss = [35.39518]
Step #15 A = [4.7225595]
Loss = [37.812637]
Step #20 A = [5.681144]
Loss = [13.796157]
Step #25 A = [6.4919457]
Loss = [13.752169]
Step #30 A = [7.1609416]
Loss = [9.70855]
Step #35 A = [7.710085]
Loss = [5.826261]
Step #40 A = [8.253489]
Loss = [7.3934216]
Step #45 A = [8.671478]
Loss = [2.5475926]
Step #50 A = [8.993064]
Loss = [1.32571]
Step #55 A = [9.101872]
Loss = [0.67589337]
Step #60 A = [9.256593]
Loss = [5.34419]
Step #65 A = [9.329251]
Loss = [0.58555096]
Step #70 A = [9.421848]
Loss = [3.088755]
Step #75 A = [9.563117]
Loss = [6.0601945]
Step #80 A = [9.661991]
Loss = [0.05205128]
Step #85 A = [9.8208685]
Loss = [2.3963788]
Step #90 A = [9.8652935]
Loss = [0.19284673]
Step #95 A = [9.842097]
Loss = [4.9211507]
Step #100 A = [10.044914]
Loss = [4.2354054]

三、实现批训练

import numpy as np
import tensorflow as tf
import matplotlib as pltfrom tensorflow.python.framework import ops
ops.reset_default_graph()sess = tf.Session()# 1.声明批量大小(一次传入多少训练数据)
batch_size = 20# 2.声明模型的数据、占位符和变量。
# 这里能做的是改变占位符的形状,占位符有两个维度:
# 第一个维度为None,第二个维度是批量训练中的数据量。
# 我们能显式地设置维度为20,也能设为None。
# 我们必须知道训练模型中的维度,从而阻止不合法的矩阵操作
x_vals = np.random.normal(1,0.1,100)
y_vals = np.repeat(10.,100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
A = tf.Variable(tf.random_normal(shape=[1,1]))# 3.现在在计算图中增加矩阵乘法操作,
# 切记矩阵乘法不满足交换律,所以在matmul()函数中的矩阵参数顺序要正确:
my_output = tf.multiply(x_data, A)# 4.改变损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))# 5.声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)# 6.在训练中通过循环迭代优化模型算法。
# 为了绘制损失值图与随机训练对比
# 这里初始化一个列表每间隔5次迭代保存损失函数# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)loss_batch = []
for i in range(100):# 每次用0~100中取20个数作为索引值rand_index = np.random.choice(100, size=batch_size)# 转置rand_x = np.transpose([x_vals[rand_index]])rand_y = np.transpose([y_vals[rand_index]])sess.run(train_step, feed_dict={x_data:rand_x,y_target:rand_y})if (i+1)%5 == 0:print("Step # " + str(i+1) + ' A = ' + str(sess.run(A)))temp_loss = sess.run(loss, feed_dict={x_data:rand_x,y_target:rand_y})print('Loss = ' + str(temp_loss))loss_batch.append(temp_loss)
在这里插入代码片
# 输出结果
Step # 5 A = [[2.626382]]
Loss = 55.444374
Step # 10 A = [[3.980196]]
Loss = 36.855064
Step # 15 A = [[5.0858808]]
Loss = 22.765038
Step # 20 A = [[5.9751787]]
Loss = 15.496961
Step # 25 A = [[6.713659]]
Loss = 12.349718
Step # 30 A = [[7.2950797]]
Loss = 7.5467796
Step # 35 A = [[7.782353]]
Loss = 5.17468
Step # 40 A = [[8.20625]]
Loss = 4.1199327
Step # 45 A = [[8.509094]]
Loss = 2.6329637
Step # 50 A = [[8.760488]]
Loss = 1.9998455
Step # 55 A = [[8.967735]]
Loss = 1.6577679
Step # 60 A = [[9.1537]]
Loss = 1.4356906
Step # 65 A = [[9.317189]]
Loss = 1.9666836
Step # 70 A = [[9.387019]]
Loss = 1.9287064
Step # 75 A = [[9.499526]]
Loss = 1.7477573
Step # 80 A = [[9.594302]]
Loss = 1.719229
Step # 85 A = [[9.666611]]
Loss = 1.4769726
Step # 90 A = [[9.711805]]
Loss = 1.1235845
Step # 95 A = [[9.784608]]
Loss = 1.9176414
Step # 100 A = [[9.849552]]
Loss = 1.1561565

四、绘制图像

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

在这里插入图片描述
从图中可以看出批训练损失更平滑,随机训练损失更不规则

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/481105.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「小公式」平均数与级数

喵喵喵,小夕最近准备复习一下数学和基础算法,所以可能会推送或者附带推送点数学和基础算法的小文章。说不定哪天就用(考)到了呢( ̄∇ ̄)注意哦,与头条位的文章推送不同,「小公式」和「…

最新出炉-阿里 2020届算法工程师-自然语言处理(实习生)以及补充:快递最短路径

问题2感觉跟下面的分苹果类似; 问题 G: 分梨 题目描述 zzq非常喜欢吃梨,有一天他得到了ACMCLUB送给他的一筐梨子。由于他比较仗义,就打算把梨子分给好朋友们吃。现在他要把M个梨子放到N个盘子里面(我们允许有的盘子为空&#xff0…

如何匹配两段文本的语义?

喵喵喵,好久不见啦。首先很抱歉大家期待的调参手册(下)迟迟没有出稿,最近两个月连着赶了4个DDL,整个人都不好了。最近几天终于有时间赶一下未完成的稿子了。在赶DDL的时候夹着写了这篇文章,就先发布这一篇吧…

【TensorFlow】实现简单的鸢尾花分类器

代码实现及说明 # python 3.6 # TensorFlow实现简单的鸢尾花分类器 import matplotlib.pyplot as plt import tensorflow as tf import numpy as np from sklearn import datasetssess tf.Session()#导入数据 iris datasets.load_iris() # 是否是山鸢尾 0/1 binary_target …

偏差-方差全面解释

偏差(Bias)与方差(Variance) 目录: 为什么会有偏差和方差?偏差、方差、噪声是什么?泛化误差、偏差和方差的关系?用图形解释偏差和方差。偏差、方差窘境。偏差、方差与过拟合、欠拟合…

「小算法」回文数与数值合法性检验

喵喵喵,小夕最近准备复习一下数学和基础算法,尽量每篇推送下面会附带点数学和基础算法的小文章。说不定哪天就用(考)到了呢( ̄∇ ̄)注意哦,与头条位的文章推送不同,「小公式」和「小算…

【TensorFlow】实现、训练并评估简单的回归模型和分类模型

1 回归模型 回归算法模型用来预测连续数值型,其目标不是分类值而是数字。为了评估这些回归预测值是否与实际目标相符,我们需要度量两者间的距离,打印训练过程中的损失,最终评估模型损失。 这里使用的例子是从均值为1、标准差为0…

史上最通熟易懂的检索式聊天机器人讲解

喵喵喵,一不小心又匿了三个月,突然诈尸害不害怕( ̄∇ ̄) 小夕从7月份开始收到第一场面试邀请,到9月初基本结束了校招(面够了面够了T_T),深深的意识到今年的对话系统/chatbot方向是真的…

Python pandas数据分析中常用方法

官方教程 读取写入文件 官方IO 读取 写入 read_csv       to_csv read_excel      to_excel read_hdf       to_hdf read_sql       to_sql read_json      to_json read_msgpack (experimental)   to_msgpack (experimental) read_html    …

小哥哥,检索式chatbot了解一下?

喵喵喵,一不小心又匿了三个月,突然诈尸害不害怕( ̄∇ ̄) 小夕从7月份开始收到第一场面试邀请,到9月初基本结束了校招(面够了面够了T_T),深深的意识到今年的对话系统/chatbot方向是真的…

领域应用 | 中医临床术语系统

本文转载自公众号中医药知识组织与标准。什么是中医药术语系统?它是干什么用的呢?中医药术语系统是运用计算机与信息技术等工具,对中医药学各领域中的事物、现象、特性、关系和过程进行标记和概括,并为每个概念赋予指称形成概念体…

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Image图片组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Image图片组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Image组件 Image 用来加载并显示图片的基础组件,它支持从内存、本…

梁家卿 | 百科知识图谱同步更新

本文转载自公众号知识工场。 本文整理自复旦大学知识工场梁家卿博士在IJCAI 2017 会议上的论文报告,题目为《How to Keep a Knowledge Base Synchronized with Its Encyclopedia Source》,作者包括:梁家卿博士(复旦大学&#xff0…

别再搜集面经啦!小夕教你斩下NLP算法岗offer!

推完上一篇文章,订阅号和知乎后台有好多小伙伴跟小夕要面经(还有个要买简历的是什么鬼),然鹅小夕真的没有整理面经呀,真的木有时间(。 ́︿ ̀。)。不过话说回来,面经有多大用呢&#…

肖仰华 | 基于知识图谱的问答系统

本文转载自公众号知识工场。 本文整理自复旦大学知识工场肖仰华教授在VLDB 2017 会议上的论文报告,题目为《KBQA: Learning Question Answering over QA Corpora and Knowledge Bases》,作者包括:崔万云博士(现上海财经大学讲师&a…

【小夕精选】YJango 7分钟带你领略你未曾想过的线性代数+微积分

小夕很早之前就想转一些精彩的技术文章,这样哪怕没有时间写作的时候,也能把优质的干货分享给大家~然鹅,由于我也不知道是什么的原因,就不小心拖到了现在╮( ̄▽ ̄"")╭之前有不少粉丝希…

白硕 | 基于区块链的众包社区激励机制

本文整理自白硕老师在 YOCSEF 武汉专题论坛:“人工智能遇到区块链,是惊鸿一瞥还是天长地久?”的报告。 很高兴有这个机会跟大家交流。我先讲几个案例作为引子。第一个案例与知识图谱有关。这个公司做的是非常垂直的一个领域,安全教…

【小夕精选】多轮对话之对话管理(Dialog Management)

这一篇是一段时间之前小夕初入对话领域时刷到的徐阿衡小姐姐写的一篇文章,写的深入浅出,十分适合有一定基础的情况下想快速了解对话管理技术的童鞋阅读~另外顺手推一下阿衡小姐姐的订阅号「徐阿衡」,干货满满不要错过哦~这一篇想写一写对话管…

KD Tree的原理及Python实现

1. 原理篇我们用大白话讲讲KD-Tree是怎么一回事。1.1 线性查找假设数组A为[0, 6, 3, 8, 7, 4, 11],有一个元素x,我们要找到数组A中距离x最近的元素,应该如何实现呢?比较直接的想法是用数组A中的每一个元素与x作差,差的…

漆桂林 | 知识图谱的应用

本文作者为东南大学漆桂林老师,首发于知乎专栏知识图谱和智能问答 前面一篇文章“知识图谱之语义网络篇”已经提到了知识图谱的发展历史,回顾一下有以下几点: 1. 知识图谱是一种语义网络,即一个具有图结构的知识库,这里…