机器学习---使用 TensorFlow 构建神经网络模型预测波士顿房价和鸢尾花数据集分类

1. 预测波士顿房价

1.1 导包
from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport itertoolsimport pandas as pd
import tensorflow as tftf.logging.set_verbosity(tf.logging.INFO)

最后一行设置了TensorFlow日志的详细程度:

tf.logging.DEBUG:最详细的日志级别,用于记录调试信息。

tf.logging.INFO:用于记录一般的信息性消息,比如训练过程中的指标和进度。

tf.logging.WARN:用于记录警告消息,表示可能存在潜在问题,但不会导致程序终止。

tf.logging.ERROR:仅记录错误消息,表示程序遇到了错误并可能终止执行。

tf.logging.FATAL:记录严重错误消息,并终止程序的执行。

1.2 处理数据集
COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age","dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm","age", "dis", "tax", "ptratio"]
LABEL = "medv"training_set = pd.read_csv("boston_train.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)
test_set = pd.read_csv("boston_test.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)
prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)

定义了一些列名和特征,并使用pd.read_csv函数读取了训练集、测试集和预测集的数据。

pd.read_csv函数来读取CSV文件,并将其转换为Pandas数据帧。

1.3  创建DNNRegressor对象
feature_cols = [tf.feature_column.numeric_column(k) for k in FEATURES]
regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,hidden_units=[50,50,50],model_dir="./boston_model")

tf.feature_column.numeric_column函数用于创建一个表示数值特征的特征列。在这种情况下,它

会遍历FEATURES列表中的每个特征名称,并为每个特征创建一个数值特征列。

创建DNNRegressor对象的参数:

  feature_columns:这是包含特征列的列表,用于定义输入的特征。在这里,您传递了之前创建

feature_cols,它包含了用于模型训练的数值特征列。

  hidden_units:这是一个整数列表,用于定义隐藏层的结构。在这个例子中,您定义了一个具

有3个隐藏层的DNN模型,每个隐藏层都有50个神经元。

model_dir:这是模型保存的目录路径。在这里,您指定了"./boston_model"作为模型保存的目录。

1.4 创建输入函数
def get_input_fn(data_set, num_epochs=None, shuffle=True):return tf.estimator.inputs.pandas_input_fn(x=pd.DataFrame({k: data_set[k].values for k in FEATURES}),y = pd.Series(data_set[LABEL].values),num_epochs=num_epochs,shuffle=shuffle)

该输入函数将Pandas数据帧作为输入,并将其转换为TensorFlow的输入格式。具体而言,它将特

征数据集(由FEATURES列表指定的列)转换为x,将标签数据(由LABEL指定的列)转换为y

1.5 训练评估预测
regressor.train(input_fn=get_input_fn(training_set), steps=5000)
ev = regressor.evaluate(input_fn=get_input_fn(test_set, num_epochs=1, shuffle=False))
loss_score = ev["loss"]
print("Loss: {0:f}".format(loss_score))
y = regressor.predict(input_fn=get_input_fn(prediction_set, num_epochs=1, shuffle=False))
# .predict() returns an iterator of dicts; convert to a list and print
# predictions
predictions = list(p["predictions"] for p in itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))

steps参数指定了训练的迭代步数,即模型将对训练数据执行多少次梯度下降更新。

使用get_input_fn获取输入函数,该函数将测试集(test_set)作为输入数据。num_epochs参数设

置为1,表示测试集只会被迭代一次,shuffle参数被设置为False,表示测试集不需要进行洗牌。

然后提取评估结果中的损失值(loss),并将其赋值给loss_score变量。

通过迭代预测结果的字典形式,将预测值提取出来,并将其存储在predictions列表中。

2. 鸢尾花数据集分类

import tensorflow as tf
import pandas as pdCOLUMN_NAMES = ['SepalLength', 'SepalWidth','PetalLength', 'PetalWidth', 'Species']# Import training dataset
training_dataset = pd.read_csv('iris_training.csv', names=COLUMN_NAMES, header=0)
train_x = training_dataset.iloc[:, 0:4]
train_y = training_dataset.iloc[:, 4]# Import testing dataset
test_dataset = pd.read_csv('iris_test.csv', names=COLUMN_NAMES, header=0)
test_x = test_dataset.iloc[:, 0:4]
test_y = test_dataset.iloc[:, 4]# Setup feature columns
columns_feat = [tf.feature_column.numeric_column(key='SepalLength'),tf.feature_column.numeric_column(key='SepalWidth'),tf.feature_column.numeric_column(key='PetalLength'),tf.feature_column.numeric_column(key='PetalWidth')
]# Build Neural Network - Classifier
classifier = tf.estimator.DNNClassifier(feature_columns=columns_feat,# Two hidden layers of 10 nodes each.hidden_units=[10, 10],# The model is classifying 3 classesn_classes=3)# Define train function
def train_function(inputs, outputs, batch_size):dataset = tf.data.Dataset.from_tensor_slices((dict(inputs), outputs))dataset = dataset.shuffle(1000).repeat().batch(batch_size)return dataset.make_one_shot_iterator().get_next()# Train the Model.
classifier.train(input_fn=lambda:train_function(train_x, train_y, 100),steps=1000)# Define evaluation function
def evaluation_function(attributes, classes, batch_size):attributes=dict(attributes)if classes is None:inputs = attributeselse:inputs = (attributes, classes)dataset = tf.data.Dataset.from_tensor_slices(inputs)assert batch_size is not None, "batch_size must not be None"dataset = dataset.batch(batch_size)return dataset.make_one_shot_iterator().get_next()# Evaluate the model.
eval_result = classifier.evaluate(input_fn=lambda:evaluation_function(test_x, test_y, 100))print('\nAccuracy: {accuracy:0.3f}\n'.format(**eval_result))

首先导入所需的库,包括 TensorFlow 和 Pandas。然后,定义了一个包含特征列的列

表 columns_feat,用于描述输入数据的特征。接下来,通过 Pandas 读取训练集和测试集的数

据,并将其分为输入特征和输出类别。

然后,使用 tf.estimator.DNNClassifier 类构建了一个多层感知机神经网络分类器。该分类器具

有两个隐藏层,每个隐藏层包含10个节点,输出层用于分类3个类别的鸢尾花。

然后,定义了一个训练函数 train_function 和一个评估函数 evaluation_function,用于转换输

入数据并创建 TensorFlow 数据集。训练函数将训练数据转换为 Dataset 对象,并进行随机化、重

复和分批处理。评估函数将测试数据转换为 Dataset 对象,并进行分批处理。

最后,通过调用 classifier.train 方法来训练模型,使用训练函数作为输入函数,并指定训练步

数。然后,通过调用 classifier.evaluate 方法来评估模型的性能,使用评估函数作为输入函数,

并指定评估时的批大小。评估结果包括准确率,并通过 print 函数进行输出。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/123099.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试测试工程师一般问什么问题?

面试和项目一起,是自学路上的两大拦路虎。面试测试工程师一般会被问什么问题,总结下来一般是下面这4类: 1.做好自我介绍 2.项目相关问题 3.技术相关问题 4.人事相关问题 接下来,主要从以上四个方向分别展开介绍。为了让大家更有获…

[双指针](一) Leetcode 283.移动零和1089.复写零

[双指针] Leetcode 283.移动零和1089.复写零 移动零 283. 移动零 1.题意分析 (1) 给你一个数组,将数组中的所有0移动到数组的末尾 (2) 保证非0元素在数组中相对位置不变 (3) 在原数组中操作 2.解题思路 由于题目要求我们移动数组内容(也就是交换两…

离线语音通断器开发-稳定之后顺应新需求

使用云知声的US516p6方案开发了一系列的离线语音通断器,目前已经取得了不小的收获,有1路的,3路的,4路的,唛头和扬声器包括唛头线材也在不断的更新打磨中找到了效果特别好的供应商。 离线语音通断器,家用控…

Beyond Compare比较规则设置 Beyond Compare怎么对比表格

在对文件进行比较时,文件夹内的文件可能存在不同类型、不同后缀名、不同内容等差异,这些差异会影响具体的比较结果,因此需要我们对软件的比较规则进行一些设置。接下来就让我们一起来学习一下Beyond Compare比较规则设置,Beyond C…

重构之美:Java Swing中 如何对指定行文本进行CSS样式渲染,三种实现思路分享

文章目录 需求分析Document 应用彩蛋 需求分析 在Swing中,如果期望实现对JTextArea 或者 TextPane等文本区域实现单行渲染改怎么做?如上图所示 总的来说有两种实现方案 文本行数可控,那么构造一组JLabel集合按表单顺序添加,这样可…

松下A6B伺服 马达不动问题解决

本人在用信捷XDH plc ethercat总线,连松下A6B伺服,轴配置完成轴调试时,出现能使能,但 马达不动的情况。 开始总怀疑时信捷PLC的原因,后面查明是输入口定义引起的。 用USB线连接伺服,打开PANAPARM软件,自…

在Mac上安装MongoDB 5.0

MongoDB 5.0安装 1、环境描述 操作系统:macOS 14.0 (23A344) 2、安装MongoDB 2.1、tar解压包安装 下载地址:Download MongoDB Community Server | MongoDB 创建一个目录,以便数据库将文件放入其中。(默认情况下,数据…

linux--

一、crond 任务调度 1、原理示意图 2、crontab 进行定时任务的设置 2.1. 概述 任务调度,是指系统在某个时间执行的特定的命令或程序。任务调度分类: 系统工作: 有些重要的工作必须周而复始地执行。如病毒扫描等 个别用户工作:个别用户可能希望执行某些…

深度学习:张量 介绍

张量[1]是向量和矩阵到 n 维的推广。了解它们如何相互作用是机器学习的基础。 简介 虽然张量看起来是复杂的对象,但它们可以理解为向量和矩阵的集合。理解向量和矩阵对于理解张量至关重要。 向量是元素的一维列表: 矩阵是向量的二维列表: 下标…

Ajax学习笔记第三天

做决定之前仔细考虑,一旦作了决定就要勇往直前、坚持到底! 【1 ikunGG邮箱注册】 整个流程展示: 1.文件目录 2.页面效果展示及代码 mysql数据库中的初始表 2.1 主页 09.html:里面代码部分解释 display: inline-block; 让块元素h1变成行内…

OpenCV C++ 图像处理实战 ——《缺陷检测》

OpenCV C++ 图像处理实战 ——《缺陷检测》 一、结果演示二、缺陷检测算法2.1、多元模板图像2.2、训练差异模型三、图像配准3.1 功能源码3.1 功能效果四、多元模板图像4.1 功能源码五、缺陷检测5.1 功能源码六、源码测试图像下载总结一、结果演示

pytorch笔记:TRIPLETMARGINLOSS

1 介绍 创建一个衡量三元组损失的标准,给定输入张量 x1​、x2​ 和 x3​ 以及一个大于0的间距值。这用于测量样本之间的相对相似性。一个三元组由a、p和n组成(锚点、正例和负例)。所有输入张量的形状都应为 (N,D) 2 基本使用方法 torch.nn.…

AD9371 官方例程HDL详解之JESD204B RX侧格式配置

AD9371 系列快速入口 AD9371ZCU102 移植到 ZCU106 : AD9371 官方例程构建及单音信号收发 采样率和各个时钟之间的关系 : AD9371 官方例程HDL详解之JESD204B TX侧时钟生成 (三) 参考资料: UltraScale Architecture G…

【AD9361 数字接口CMOS LVDSSPI】D 串行数据之SPI

【AD9361 数字接口CMOS &LVDS&SPI】D部分 接续 【AD9361 数字接口CMOS &LVDS&SPI】A 并行数据之CMOS 串行外设接口(SPI) SPI总线为AD9361的所有数字控制提供机制。每个SPI寄存器的宽度为8位,每个寄存器包含控制位、状态监视…

【OpenCV实现平滑图像金字塔,轮廓:入门】

文章目录 概要图像金字塔轮廓:入门 概要 文章内容的概要: 平滑图像金字塔: 图像金字塔是什么? 图像金字塔是指将原始图像按照不同的分辨率进行多次缩小(下采样)得到的一系列图像。这种处理方式常用于图像…

数据链路层和DNS之间的那些事~

数据链路层,考虑的是两个节点之间的传输。这里面的典型协议也很多,最知名的就是“以太网”。我们本篇主要介绍的就是以太网协议。这个协议规定了数据链路层,也规定了物理层的内容。 目录 以太网帧格式 帧头 载荷 帧尾 DNS 从输入URL到…

[读论文] On Joint Learning for Solving Placement and Routing in Chip Design

0. Abstract 由于 GPU 在加速计算方面的优势和对人类专家的依赖较少,机器学习已成为解决布局和布线问题的新兴工具,这是现代芯片设计流程中的两个关键步骤。它仍处于早期阶段,存在一些基本问题:可扩展性、奖励设计和端到端学习范…

获取IEEE会议论文的标题和摘要

获取IEEE会议论文的标题和摘要 – 潘登同学的爬虫笔记 文章目录 获取IEEE会议论文的标题和摘要 -- 潘登同学的爬虫笔记 打开IEEE的高级搜索环境准备完整爬虫过程获取文章地址翻译函数获取文章标题和摘要 前几天接到导师的一个任务,要我去找找IEEE Transactions on K…

vue源码分析(七)—— createComponent

文章目录 前言一、createComponent 参数说明二、createComponent 源码详解1.baseCtor的实际指向2.extend 方法3.判断Ctor是否是函数的判断4.installComponentHooks方法5.返回一个带标识的组件 vnode 前言 createComponent文件的路径: src\core\vdom\create-componen…

【Qt之控件QKeySequenceEdit】分析及使用

描述 QKeySequenceEdit小部件允许输入一个QKeySequence。 该小部件允许用户选择一个QKeySequence,通常用作快捷键。当小部件获取焦点时,录制将开始,并在用户释放最后一个键后的一秒钟结束。 用户可以使用输入键盘来输入键序列。通过调用get…