TensorFlow 简单的二分类神经网络的训练和应用流程

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括:

1. 数据准备与预处理

2. 构建模型

3. 编译模型

4. 训练模型

5. 评估模型

6. 模型应用与部署

加载和应用已训练的模型


1. 数据准备与预处理

在本例中,数据准备是通过两个 Numpy 数组来完成的:

  • x:输入特征,形状为 (8, 2),包含 8 个数据点,每个数据点有 2 个特征。
  • y:标签,形状为 (8,),包含对应的 0 或 1 标签,表示每个输入点的类别。
x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1], [0.7, 0.7], [0.7, -0.7], [-0.7, -0.7], [-0.7, 0.7]])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])

2. 构建模型

使用 Keras 的 Sequential 模型来构建神经网络。此模型包含两个全连接层(Dense 层):

  • 第一个 Dense 层有 3 个单位,激活函数是 Sigmoid。
  • 第二个 Dense 层有 1 个单位,激活函数是 Sigmoid,输出层的激活函数将模型输出的值映射到 0 到 1 之间,适合二分类任务。
l1 = tf.keras.layers.Dense(units=3, activation='sigmoid')
l2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
model = tf.keras.Sequential([l1, l2])

3. 编译模型

在编译阶段,我们选择了优化器、损失函数和评估指标:

  • 优化器:SGD(随机梯度下降),学习率设置为 0.9。
  • 损失函数:binary_crossentropy,适用于二分类任务。
  • 评估指标:accuracy,表示训练过程中对分类准确率的衡量。
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])

4. 训练模型

通过 model.fit() 函数来训练模型。我们传入训练数据 x 和标签 y,并设置训练的 epoch 数量为 2000。

model.fit(x, y, epochs=2000)

5. 评估模型

在此示例中,评估部分通过训练后的 model 来进行,并没有显式写出 evaluate() 函数。评估通常是在训练之后,通过测试集或验证集对模型性能进行评估,具体可以使用 model.evaluate() 来查看损失和准确度。

6. 模型应用与部署

训练完成后,我们保存了训练好的模型。保存后的模型可以被加载和应用于新的数据集。

model.save('my_model.keras')  # 保存模型

7.加载和应用已训练的模型

加载保存的模型,并用其对新数据进行预测。model.predict() 方法返回的是预测的概率值,我们通过设置阈值(如 0.9)将其转换为类别(0 或 1)。

model = tf.keras.models.load_model('my_model.keras')  # 加载模型
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])  # 新的输入数据
predictions = model.predict(nx)  # 获取预测结果
print(predictions)  # 输出概率# 将概率转化为类别
predicted_classes = (predictions > 0.9).astype(int)
print(predicted_classes)  # 输出最终的类别预测

8.完整代码
test.py 训练模型

import tensorflow as tf
import numpy as np
# 创建int32类型的0维张量,即标量
l1=tf.keras.layers.Dense(units=3,activation='sigmoid')
l2=tf.keras.layers.Dense(units=1,activation='sigmoid')
model=tf.keras.Sequential([l1,l2])
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
x=np.array([[1,1],[1,-1],[-1,1],[-1,-1],[0.7,0.7],[0.7,-0.7],[-0.7,-0.7],[-0.7,0.7]])
y=np.array([1,1,1,1,0,0,0,0])
model.fit(x,y,epochs=2000)
# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')

 test2.py加载模型并进行预测:

import tensorflow as tf
import numpy as np# 加载训练好的模型
model = tf.keras.models.load_model('my_model.keras')# 预测数据
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])# 获取预测结果
predictions = model.predict(nx)# 输出预测结果
print(predictions)# 如果需要将概率转化为类别(0或1)
predicted_classes = (predictions > 0.9).astype(int)# 输出最终的类别预测
print(predicted_classes)

9.视频分享


初识TensorFlow 
https://v.douyin.com/ifG2mmLH/
复制此链接,打开Dou音搜索,直接观看视频!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/68901.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无人机PX4飞控 | PX4源码添加自定义uORB消息并保存到日志

PX4源码添加自定义uORB消息并保存到日志 0 前言 PX4的内部通信机制主要依赖于uORB(Micro Object Request Broker),这是一种跨进程的通信机制,一种轻量级的中间件,用于在PX4飞控系统的各个模块之间进行高效的数据交换…

XCCL、NCCL、HCCL通信库

XCCL提供的基本能力 XCCL提供的基本能力 不同的XCCL 针对不同的网络拓扑,实现的是不同的优化算法的(不同CCL库最大的区别就是这) 不同CCL库还会根据自己的硬件、系统,在底层上面对一些相对应的改动; 但是对上的API接口…

Docker快速部署高效照片管理系统LibrePhotos搭建私有云相册

文章目录 前言1.关于LibrePhotos2.本地部署LibrePhotos3.LibrePhotos简单使用4. 安装内网穿透5.配置LibrePhotos公网地址6. 配置固定公网地址 前言 想象一下这样的场景:你有一大堆珍贵的回忆照片,但又不想使用各种网盘来管理。怎么办?别担心…

【Java计算机毕业设计】基于Springboot的物业信息管理系统【源代码+数据库+LW文档+开题报告+答辩稿+部署教程+代码讲解】

源代码数据库LW文档(1万字以上)开题报告答辩稿 部署教程代码讲解代码时间修改教程 一、开发工具、运行环境、开发技术 开发工具 1、操作系统:Window操作系统 2、开发工具:IntelliJ IDEA或者Eclipse 3、数据库存储&#xff1a…

深入解析Python机器学习库Scikit-Learn的应用实例

深入解析Python机器学习库Scikit-Learn的应用实例 随着人工智能和数据科学领域的迅速发展,机器学习成为了当下最炙手可热的技术之一。而在机器学习领域,Python作为一种功能强大且易于上手的编程语言,拥有庞大的生态系统和丰富的机器学习库。其…

高斯光束介绍及光斑处理

常规激光器的光斑为高斯光斑,即中心能量集中,边缘能量较低。一般定义光强的处为高斯光束的半径。高斯光斑的传输由光斑半径、远场发散角、波长等决定。 其中为位置z处的光斑半径,w(z), k2pi/λ为波矢,λ为光波长,R为高…

C++哈希(链地址法)(二)详解

文章目录 1.开放地址法1.1key不能取模的问题1.1.1将字符串转为整型1.1.2将日期类转为整型 2.哈希函数2.1乘法散列法(了解)2.2全域散列法(了解) 3.处理哈希冲突3.1线性探测(挨着找)3.2二次探测(跳…

【Redis】List 类型的介绍和常用命令

1. 介绍 Redis 中的 list 相当于顺序表,并且内部更接近于“双端队列”,所以也支持头插和尾插的操作,可以当做队列或者栈来使用,同时也存在下标的概念,不过和 Java 中的下标不同,Redis 支持负数下标&#x…

携程Java开发面试题及参考答案 (200道-上)

说说四层模型、七层模型。 七层模型(OSI 参考模型) 七层模型,即 OSI(Open System Interconnection)参考模型,是一种概念模型,用于描述网络通信的架构。它将计算机网络从下到上分为七层,各层的功能和作用如下: 物理层:物理层是计算机网络的最底层,主要负责传输比特流…

IM 即时通讯系统-51-MPush开源实时消息推送系统

IM 开源系列 IM 即时通讯系统-41-开源 野火IM 专注于即时通讯实时音视频技术,提供优质可控的IMRTC能力 IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统 IM 即时通讯系统-43-简单的仿QQ聊天安卓APP IM 即时通讯系统-44-仿QQ即…

AlexNet论文代码阅读

论文标题: ImageNet Classification with Deep Convolutional Neural Networks 论文链接: https://volctracer.com/w/BX18q92F 代码链接: https://github.com/dansuh17/alexnet-pytorch 内容概述 训练了一个大型的深度卷积神经网络&#xf…

扩散模型(三)

相关阅读: 扩散模型(一) 扩散模型(二) Latent Variable Space 潜在扩散模型(LDM;龙巴赫、布拉特曼等人,2022 年)在潜在空间而非像素空间中运行扩散过程,这…

git基础使用--4---git分支和使用

文章目录 git基础使用--4---git分支和使用1. 按顺序看2. 什么是分支3. 分支的基本操作4. 分支的基本操作4.1 查看分支4.2 创建分支4.3 切换分支4.4 合并冲突 git基础使用–4—git分支和使用 1. 按顺序看 -git基础使用–1–版本控制的基本概念 -git基础使用–2–gti的基本概念…

8.攻防世界Web_php_wrong_nginx_config

进入题目页面如下 尝试弱口令密码登录 一直显示网站建设中,尝试无果,查看源码也没有什么特别漏洞存在 用Kali中的dirsearch扫描根目录试试 命令: dirsearch -u http://61.147.171.105:53736/ -e* 登录文件便是刚才登录的界面打开robots.txt…

【漫话机器学习系列】076.合页损失函数(Hinge Loss)

Hinge Loss损失函数 Hinge Loss(合页损失),也叫做合页损失函数,广泛用于支持向量机(SVM)等分类模型的训练过程中。它主要用于二分类问题,尤其是支持向量机中的优化目标函数。 定义与公式 对于…

python算法和数据结构刷题[5]:动态规划

动态规划(Dynamic Programming, DP)是一种算法思想,用于解决具有最优子结构的问题。它通过将大问题分解为小问题,并找到这些小问题的最优解,从而得到整个问题的最优解。动态规划与分治法相似,但区别在于动态…

本地Deepseek添加个人知识库(Page Assist/AnythingLLM)

本地Deepseek两种方法建立知识库 前言 (及个人测试结论)法一、在Page Assist建立知识库step1 下载nomic-embed-textstep2 加载进Page Assiststep3 添加知识step4 对话框添加知识库 法二、在AnythingLLM建立知识库准备工作1.下载nomic-embed-text2.下载An…

记8(高级API实现手写数字识别

目录 1、Keras:2、Sequential模型:2.1、建立Sequential模型:modeltf.keras.Sequential()2.2、添加层:model.add(tf.keras.layers.层)2.3、查看摘要:model.summary()2.4、配置训练方法:model.compile(loss,o…

grpc 和 http 的区别---二进制vsJSON编码

gRPC 和 HTTP 是两种广泛使用的通信协议,各自适用于不同的场景。以下是它们的详细对比与优势分析: 一、核心特性对比 特性gRPCHTTP协议基础基于 HTTP/2基于 HTTP/1.1 或 HTTP/2数据格式默认使用 Protobuf(二进制)通常使用 JSON/…

文字投影效果

大家好,我是喝西瓜汁的兔叽,今天给大家分享一个常见的文字投影效果。 效果展示 我们来实现一个这样的文字效果。 思路分析 这样的效果如何实现的呢? 实际上是两组相同的文字,叠合在一块,只不过对应的css不同罢了。 首先&…