【深度学习】实验07 使用TensorFlow完成逻辑回归

文章目录

  • 使用TensorFlow完成逻辑回归
    • 1. 环境设定
    • 2. 数据读取
    • 3. 准备好placeholder
    • 4. 准备好参数/权重
    • 5. 计算多分类softmax的loss function
    • 6. 准备好optimizer
    • 7. 在session里执行graph里定义的运算
  • 附:系列文章

使用TensorFlow完成逻辑回归

TensorFlow是一种开源的机器学习框架,由Google Brain团队于2015年开发。它被广泛应用于图像和语音识别、自然语言处理、推荐系统等领域。

TensorFlow的核心是用于计算的数据流图。在数据流图中,节点表示数学操作,边表示张量(多维数组)。将操作和数据组合在一起的数据流图可以使 TensorFlow 对复杂的数学模型进行优化,同时支持分布式计算。

TensorFlow提供了Python,C++,Java,Go等多种编程语言的接口,让开发者可以更便捷地使用TensorFlow构建和训练深度学习模型。此外,TensorFlow还具有丰富的工具和库,包括TensorBoard可视化工具、TensorFlow Serving用于生产环境的模型服务、Keras高层封装API等。

TensorFlow已经发展出了许多优秀的模型,如卷积神经网络、循环神经网络、生成对抗网络等。这些模型已经在许多领域取得了优秀的成果,如图像识别、语音识别、自然语言处理等。

除了开源的TensorFlow,Google还推出了基于TensorFlow的云端机器学习平台Google Cloud ML,为用户提供了更便捷的训练和部署机器学习模型的服务。

解决分类问题里最普遍的baseline model就是逻辑回归,简单同时可解释性好,使得它大受欢迎,我们来用tensorflow完成这个模型的搭建。

1. 环境设定

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'import warnings
warnings.filterwarnings("ignore")import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time

2. 数据读取

#使用tensorflow自带的工具加载MNIST手写数字集合
mnist = input_data.read_data_sets('./data/mnist', one_hot=True) 
Extracting ./data/mnist/train-images-idx3-ubyte.gz
Extracting ./data/mnist/train-labels-idx1-ubyte.gz
Extracting ./data/mnist/t10k-images-idx3-ubyte.gz
Extracting ./data/mnist/t10k-labels-idx1-ubyte.gz
#查看一下数据维度
mnist.train.images.shape
(55000, 784)
#查看target维度
mnist.train.labels.shape
(55000, 10)

3. 准备好placeholder

batch_size = 128
X = tf.placeholder(tf.float32, [batch_size, 784], name='X_placeholder') 
Y = tf.placeholder(tf.int32, [batch_size, 10], name='Y_placeholder')

4. 准备好参数/权重

w = tf.Variable(tf.random_normal(shape=[784, 10], stddev=0.01), name='weights')
b = tf.Variable(tf.zeros([1, 10]), name="bias")
logits = tf.matmul(X, w) + b 

5. 计算多分类softmax的loss function

# 求交叉熵损失
entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y, name='loss')
# 求平均
loss = tf.reduce_mean(entropy)

6. 准备好optimizer

这里的最优化用的是随机梯度下降,我们可以选择AdamOptimizer这样的优化器

learning_rate = 0.01
optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)

7. 在session里执行graph里定义的运算

#迭代总轮次
n_epochs = 30with tf.Session() as sess:# 在Tensorboard里可以看到图的结构writer = tf.summary.FileWriter('../graphs/logistic_reg', sess.graph)start_time = time.time()sess.run(tf.global_variables_initializer())	n_batches = int(mnist.train.num_examples/batch_size)for i in range(n_epochs): # 迭代这么多轮total_loss = 0for _ in range(n_batches):X_batch, Y_batch = mnist.train.next_batch(batch_size)_, loss_batch = sess.run([optimizer, loss], feed_dict={X: X_batch, Y:Y_batch}) total_loss += loss_batchprint('Average loss epoch {0}: {1}'.format(i, total_loss/n_batches))print('Total time: {0} seconds'.format(time.time() - start_time))print('Optimization Finished!')# 测试模型preds = tf.nn.softmax(logits)correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y, 1))accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))n_batches = int(mnist.test.num_examples/batch_size)total_correct_preds = 0for i in range(n_batches):X_batch, Y_batch = mnist.test.next_batch(batch_size)accuracy_batch = sess.run([accuracy], feed_dict={X: X_batch, Y:Y_batch}) total_correct_preds += accuracy_batch[0]print('Accuracy {0}'.format(total_correct_preds/mnist.test.num_examples))writer.close()
   Average loss epoch 0: 0.36748782022571785    Average loss epoch 1: 0.2978815356126198    Average loss epoch 2: 0.27840628396797845    Average loss epoch 3: 0.2783186247437706    Average loss epoch 4: 0.2783641471138923    Average loss epoch 5: 0.2750668214473413           Average loss epoch 6: 0.2687560408126502    Average loss epoch 7: 0.2713795114126239    Average loss epoch 8: 0.2657588795522154    Average loss epoch 9: 0.26322007090686916    Average loss epoch 10: 0.26289192279735646    Average loss epoch 11: 0.26248606019989873       Average loss epoch 12: 0.2604622903056356    Average loss epoch 13: 0.26015280702939403    Average loss epoch 14: 0.2581879366319496    Average loss epoch 15: 0.2590309207117085    Average loss epoch 16: 0.2630510463581219    Average loss epoch 17: 0.25501730025578767    Average loss epoch 18: 0.2547102673000945    Average loss epoch 19: 0.258298404375851    Average loss epoch 20: 0.2549241428330784    Average loss epoch 21: 0.2546788509283866    Average loss epoch 22: 0.259556887067837    Average loss epoch 23: 0.25428259843365575    Average loss epoch 24: 0.25442713139565676    Average loss epoch 25: 0.2553852511383159    Average loss epoch 26: 0.2503043229415978    Average loss epoch 27: 0.25468004046828596    Average loss epoch 28: 0.2552785321479633    Average loss epoch 29: 0.2506257003663859    Total time: 28.603315353393555 seconds    Optimization Finished!    Accuracy 0.9187

附:系列文章

序号文章目录直达链接
1波士顿房价预测https://want595.blog.csdn.net/article/details/132181950
2鸢尾花数据集分析https://want595.blog.csdn.net/article/details/132182057
3特征处理https://want595.blog.csdn.net/article/details/132182165
4交叉验证https://want595.blog.csdn.net/article/details/132182238
5构造神经网络示例https://want595.blog.csdn.net/article/details/132182341
6使用TensorFlow完成线性回归https://want595.blog.csdn.net/article/details/132182417
7使用TensorFlow完成逻辑回归https://want595.blog.csdn.net/article/details/132182496
8TensorBoard案例https://want595.blog.csdn.net/article/details/132182584
9使用Keras完成线性回归https://want595.blog.csdn.net/article/details/132182723
10使用Keras完成逻辑回归https://want595.blog.csdn.net/article/details/132182795
11使用Keras预训练模型完成猫狗识别https://want595.blog.csdn.net/article/details/132243928
12使用PyTorch训练模型https://want595.blog.csdn.net/article/details/132243989
13使用Dropout抑制过拟合https://want595.blog.csdn.net/article/details/132244111
14使用CNN完成MNIST手写体识别(TensorFlow)https://want595.blog.csdn.net/article/details/132244499
15使用CNN完成MNIST手写体识别(Keras)https://want595.blog.csdn.net/article/details/132244552
16使用CNN完成MNIST手写体识别(PyTorch)https://want595.blog.csdn.net/article/details/132244641
17使用GAN生成手写数字样本https://want595.blog.csdn.net/article/details/132244764
18自然语言处理https://want595.blog.csdn.net/article/details/132276591

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/67659.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vlan笔记

在一个LAN中可能有很多设备节点,有很多协议使用广播,每次整个网络中所有设备都要处理广播,效率低。 一个LAN中用一个或多个二层交换机switch连接,交换机对广播透明。 在交换机上配置VLAN,把物理网络划分为多个逻辑网…

肖sir__设计测试用例方法之场景法04_(黑盒测试)

设计测试用例方法之场景法 1、场景法主要是针对测试场景类型的,顾也称场景流程分析法。 2、流程分析是将软件系统的某个流程看成路径,用路径分析的方法来设计测试用例。根据流程的顺序依次进行组合,使得流程的各个分支能走到。 举例说明&…

网易低代码引擎Tango正式开源

一、Tango简介 Tango 是一个用于快速构建低代码平台的低代码设计器框架,借助 Tango 只需要数行代码就可以完成一个基本的低代码平台前端系统的搭建。Tango 低代码设计器直接读取前端项目的源代码,并以源代码为中心,执行和渲染前端视图,并为用户提供低代码可视化搭建能力,…

uniapp从零到一的学习商城实战

涵盖的功能: 安装开发工具HBuilder:HBuilderX-高效极客技巧 创建项目步骤: 1.右键-项目: 2.选择vue2和默认模板: 3.完整的项目目录: 微信开发者工具调试: 1.安装微信开发者工具 2.打开…

GeoServe Web 管理界面 实现远程访问

文章目录 前言1.安装GeoServer2. windows 安装 cpolar3. 创建公网访问地址4. 公网访问Geo Servcer服务5. 固定公网HTTP地址 前言 GeoServer是OGC Web服务器规范的J2EE实现,利用GeoServer可以方便地发布地图数据,允许用户对要素数据进行更新、删除、插入…

Android Studio新版本New UI及相关设置丨遥遥领先版

1、前言 俗话说工欲善其事必先利其器嘛,工具用不好怎么行呢,借着Android Studio的更新,介绍一下新版本中的更新内容,以及日常开发中那些好用的设置。 2、关于新版本 2.1、最新正式版本 Android Studio Giraffe | 2022.3.1 Pat…

elementui el-table在有summary-method时,table数据行将合计行遮挡住了

前端使用框架:elementUI 使用组件:el-table 在表格内添加合计了合计行,根据业务多次调用数据渲染画面后,偶然导致画面变成如下图所示,table的数据行将合计行遮挡住了,且这个现象有时候好用,有…

Android图形-架构1

目录 引言 Android图形的关键组件: Android图形的pipeline数据流 BufferQueue是啥? 引言 Android提供用于2D和3D图形渲染的API,可与制造商的驱动程序实现代码交互,下面梳理一下Android图形的运作原理。 应用开发者通过三种方…

C++多态案例2----制作饮品

#include<iostream> using namespace std;//制作饮品的大致流程都为&#xff1a; //煮水-----冲泡-----倒入杯中----加入辅料//本案例利用多态技术&#xff0c;提供抽象类制作饮品基类&#xff0c;提供子类制作茶叶和咖啡class AbstractDrinking {public://煮水//冲水//倒…

Scala的集合操作之可变数组和不可变数组,可变List集合与不可变List集合,可变Set与不可变Set操作,可变和不可变Map集合和元组操作

Scala的集合操作之&#xff0c;可变数组和不可变数组&#xff0c;可变List集合与不可变List集合 不可变数组 /* traversable/ˈtrvəsəbl/adj.能越过的&#xff1b;可否认的*/ object Test01_ImmutableArray {def main(args: Array[String]): Unit {// 1. 创建数组val arr:…

视频监控/视频汇聚/视频云存储EasyCVR平台HLS流集成在小程序无法播放问题排查

安防视频/视频云存储/视频集中存储EasyCVR视频监控综合管理平台可以根据不同的场景需求&#xff0c;让平台在内网、专网、VPN、广域网、互联网等各种环境下进行音视频的采集、接入与多端分发。在视频能力上&#xff0c;视频云存储平台EasyCVR可实现视频实时直播、云端录像、视频…

字节前端实习的两道算法题,看看强度如何

最长严格递增子序列 题目描述 给你一个整数数组nums&#xff0c;找到其中最长严格递增子序列的长度。 子序列是由数组派生而来的序列&#xff0c;删除&#xff08;或不删除&#xff09;数组中的元素而不改变其余元素的顺序。例如&#xff0c;[3,6,2,7] 是数组 [0,3,1,6,2,2,7…

flink实现kafka、doris精准一次说明

前言说明:本文档只讨论数据源为kafka的情况实现kafka和doris的精准一次写入 flink的kafka连接器已经实现了自动提交偏移量到kafka,当flink中的数据写入成功后,flink会将这批次数据的offset提交到kafka,程序重启时,kafka中记录了当前groupId消费的offset位置,开始消费时将…

文件系统与inode编号

文件描述符fd 0&1&2 Linux 进程默认情况会有3个缺省打开的文件描述符&#xff0c;分别是标准输入0&#xff0c; 标准输出1&#xff0c; 标准错误2. 0,1,2对应的物理设备一般是&#xff1a;键盘&#xff0c;显示器&#xff0c;显示器 所以输入输出还可以采用如下方式 …

中国非晶纳米晶行业市场预测与投资战略报告(2023版)

内容简介&#xff1a; 由于性能优异&#xff0c;非晶材料从20世纪80年代开始成为中国外科学界研究重点&#xff0c;目前美、日、德已经具备完善的生产规模&#xff0c;大量的非晶合金产品逐渐取代硅钢、铁氧体等。 2021年之前&#xff0c;中国铁基非晶带材企业有12家&#xf…

AVR128单片机 自动售水机

一、系统方案 1、设计使用两个按键分别为S1和S2及一个发光二极管LED。S1为出水控制按键&#xff0c;当S1按下&#xff0c;表示售水机持续出水&#xff0c;继电器&#xff08;库元件relay&#xff09;接通&#xff0c;指示灯LED亮。S2为停水控制键&#xff0c;当S2按下&#xff…

CSS面试题

CSS面试题 1.说说flexbox&#xff08;弹性盒布局模型&#xff09;,以及适用场景&#xff1f;flex-directionflex-wrapflex-flowjustify-contentalign-itemsalign-content 2.让Chrome支持小于12px 的文字方式有哪些&#xff1f;区别&#xff1f;3.css选择器有哪些? 优先级? 哪…

OSCS 安全周报第 58 期:VMware Aria Operations SSH 身份验证绕过漏洞 (CVE-2023-34039)

​ 本周安全态势综述 OSCS 社区共收录安全漏洞 3 个&#xff0c;公开漏洞值得关注的是 VMware Aria Operations SSH 身份验证绕过漏洞( CVE-2023-34039 )、Apache Airflow Spark Provider 反序列化漏洞( CVE-2023-40195 )。 针对 NPM 仓库&#xff0c;共监测到 324 个不同版本…

9月3日,每日信息差

第一、中国中铁与广州市城中村改造做地主体签署战略合作框架协议。根据协议&#xff0c;双方将积极响应广州市统筹做地推进高质量发展工作精神&#xff0c;充分发挥双方优势资源&#xff0c;共同加大在物业复建安置、基础设施建设、综合开发投资、城中村改造&#xff08;微改造…

volatile 关键字 与 CPU cache line 的效率问题

分析&回答 Cache Line可以简单的理解为CPU Cache中的最小缓存单位。目前主流的CPU Cache的Cache Line大小都是64Bytes。假设我们有一个512字节的一级缓存&#xff0c;那么按照64B的缓存单位大小来算&#xff0c;这个一级缓存所能存放的缓存个数就是512/64 8个。具体参见下…