TensorFlow实现逻辑回归模型

逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将介绍如何使用TensorFlow框架实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来直观地观察模型的训练过程。

数据准备

首先,我们准备两类数据点,分别表示两个不同的类别。这些数据点将作为模型的输入特征。

# 1.散点输入
class1_points=np.array([[1.9,1.2],[1.5,2.1],[1.9,0.5],[1.5,0.9],[0.9,1.2],[1.1,1.7],[1.4,1.1]])
class2_points=np.array([[3.2,3.2],[3.7,2.9],[3.2,2.6],[1.7,3.3],[3.4,2.6],[4.1,2.3],[3.0,2.9]])

将两类数据点合并为一个矩阵,并为每个数据点分配相应的标签(0或1)。

#不用单独提取出x1_data 和x2_data
#框架会根据输入特征数自动提取
x_train=np.concatenate((class1_points,class2_points),axis=0)
y_train=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class2_points))))

将数据转换为TensorFlow张量,以便在模型中使用。

import tensorflow as tfx_train_tensor = tf.convert_to_tensor(x_train, dtype=tf.float32)
y_train_tensor = tf.convert_to_tensor(y_train, dtype=tf.float32)

模型定义

使用TensorFlow的tf.keras模块定义逻辑回归模型。模型包含一个输入层和一个输出层,输出层使用sigmoid激活函数。

def LogisticRegreModel():input = tf.keras.Input(shape=(2,))fc = tf.keras.layers.Dense(1, activation='sigmoid')(input)lr_model = tf.keras.models.Model(inputs=input, outputs=fc)return lr_modelmodel = LogisticRegreModel()

定义优化器和损失函数。这里使用随机梯度下降优化器和二元交叉熵损失函数。

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt, loss="binary_crossentropy")

训练过程

训练模型时,我们记录每个epoch的损失值,并动态绘制决策边界和损失曲线。

 

import matplotlib.pyplot as pltfig, (ax1, ax2) = plt.subplots(1, 2)epochs = 500
epoch_list = []
epoch_loss = []for epoch in range(1, epochs + 1):y_pre = model.fit(x_train_tensor, y_train_tensor, epochs=50, verbose=0)epoch_loss.append(y_pre.history["loss"][0])epoch_list.append(epoch)w1, w2 = model.get_weights()[0].flatten()b = model.get_weights()[1][0]slope = -w1 / w2intercept = -b / w2x_min, x_max = 0, 5x = np.array([x_min, x_max])y = slope * x + interceptax1.clear()ax1.plot(x, y, 'r')ax1.scatter(x_train[:len(class1_points), 0], x_train[:len(class1_points), 1])ax1.scatter(x_train[len(class1_points):, 0], x_train[len(class1_points):, 1])ax2.clear()ax2.plot(epoch_list, epoch_loss, 'b')plt.pause(1)

结果展示

训练完成后,决策边界图将显示模型如何将两类数据分开,损失曲线图将显示模型在训练过程中的损失值变化。生成结果基本如图所示:

通过动态绘制决策边界和损失曲线,我们可以直观地观察模型的训练过程,了解模型如何逐渐学习数据的分布并优化决策边界。

总结

本文介绍了如何使用TensorFlow实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来观察模型的训练过程。逻辑回归是一种简单而有效的分类算法,适用于二分类问题。通过TensorFlow框架,我们可以轻松地实现和训练逻辑回归模型,并利用其强大的功能来优化模型的性能。


完整代码

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# 1.散点输入
class1_points=np.array([[1.9,1.2],[1.5,2.1],[1.9,0.5],[1.5,0.9],[0.9,1.2],[1.1,1.7],[1.4,1.1]])
class2_points=np.array([[3.2,3.2],[3.7,2.9],[3.2,2.6],[1.7,3.3],[3.4,2.6],[4.1,2.3],[3.0,2.9]])#不用单独提取出x1_data 和x2_data
#框架会根据输入特征数自动提取
x_train=np.concatenate((class1_points,class2_points),axis=0)
y_train=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class2_points))))
#转化为张量
x_train_tensor=tf.convert_to_tensor(x_train,dtype=tf.float32)
y_train_tensor=tf.convert_to_tensor(y_train,dtype=tf.float32)#2.定义前向模型
# 使用类的方式
# 先设置一下随机数种子
seed=0
tf.random.set_seed(0)def LogisticRegreModel():input=tf.keras.Input(shape=(2,))fc=tf.keras.layers.Dense(1,activation='sigmoid')(input)lr_model=tf.keras.models.Model(inputs=input,outputs=fc)return lr_model
#实例化网络
model=LogisticRegreModel()
#3.定义损失函数和优化器
#定义优化器
#需要输入模型参数和学习率
lr=0.1
opt=tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt,loss="binary_crossentropy")# 最后画图
fig,(ax1,ax2)=plt.subplots(1,2)
#训练
epoches=500
epoch_list=[]
epoch_loss=[]
for epoch in range(1,epoches+1):# verbose=0 进度条不显示  epochs迭代次数y_pre=model.fit(x_train_tensor,y_train_tensor,epochs=50,verbose=0)# print(y_pre.history["loss"])epoch_loss.append(y_pre.history["loss"][0])epoch_list.append(epoch)w1,w2=model.get_weights()[0].flatten()b=model.get_weights()[1][0]#画左图# 使用斜率和截距画直线#目前将x2当作y轴 x1当作x轴# w1*x1+w2*x2+b=0#求出斜率和截距slope=-w1/w2intercept=-b/w2#绘制直线 开始结束位置x_min,x_max=0,5x=np.array([x_min,x_max])y=slope*x+interceptax1.clear()ax1.plot(x,y,'r')#画散点图ax1.scatter(x_train[:len(class1_points),0],x_train[:len(class1_points),1])ax1.scatter(x_train[len(class1_points):, 0],x_train[len(class1_points):, 1])#画右图ax2.clear()ax2.plot(epoch_list,epoch_loss,'b')plt.pause(1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69467.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity git版本管理

创建仓库的时候添加了Unity的.gitignore模版,在这个时候就能自动过滤不需要的文件 打开git bash之后,步骤git版本管理-CSDN博客 如果报错,尝试重新进git 第一次传会耗时较长,之后的更新就很快了

【AI论文】扩散对抗后训练用于一步视频生成总结

摘要:扩散模型被广泛应用于图像和视频生成,但其迭代生成过程缓慢且资源消耗大。尽管现有的蒸馏方法已显示出在图像领域实现一步生成的潜力,但它们仍存在显著的质量退化问题。在本研究中,我们提出了一种在扩散预训练后针对真实数据…

低代码系统-产品架构案例介绍、明道云(十一)

明道云HAP-超级应用平台(Hyper Application Platform),其实就是企业级应用平台,跟微搭类似。 通过自设计底层架构,兼容各种平台,使用低代码做到应用搭建、应用运维。 企业级应用平台最大的特点就是隐藏在冰山下的功能很深&#xf…

2025年AI手机集中上市,三星Galaxy S25系列上市

2025年被认为是AI手机集中爆发的一年,各大厂商都会推出搭载人工智能的智能手机。三星Galaxy S25系列全球上市了。 三星Galaxy S25系列包含S25、S25和S25 Ultra三款机型,起售价为800美元(约合人民币5800元)。全系搭载骁龙8 Elite芯…

【ESP32】ESP-IDF开发 | WiFi开发 | TCP传输控制协议 + TCP服务器和客户端例程

1. 简介 TCP(Transmission Control Protocol),全称传输控制协议。它的特点有以下几点:面向连接,每一个TCP连接只能是点对点的(一对一);提供可靠交付服务;提供全双工通信&…

2025数学建模美赛|赛题翻译|E题

2025数学建模美赛,E题赛题翻译 更多美赛内容持续更新中...

【Elasticsearch】Elasticsearch的查询

Elasticsearch的查询 DSL查询基础语句叶子查询全文检索查询matchmulti_match 精确查询termrange 复合查询算分函数查询bool查询 排序分页基础分页深度分页 高亮高亮原理实现高亮 RestClient查询基础查询叶子查询复合查询排序和分页高亮 数据聚合DSL实现聚合Bucket聚合带条件聚合…

什么是循环神经网络?

一、概念 循环神经网络(Recurrent Neural Network, RNN)是一类用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN具有循环连接,可以利用序列数据的时间依赖性。正因如此,RNN在自然语言处理、时间序列预测、语…

深入探索C++17的std::any:类型擦除与泛型编程的利器

文章目录 基本概念构建方式构造函数直接赋值std::make_anystd::in_place_type 访问值值转换引用转换指针转换 修改器emplaceresetswap 观察器has_valuetype 使用场景动态类型的API设计类型安全的容器简化类型擦除实现 性能考虑动态内存分配类型转换和异常处理 总结 在C17的标准…

物管系统赋能智慧物业管理提升服务质量与工作效率的新风潮

内容概要 在当今的物业管理领域,物管系统的崛起为智慧物业管理带来了新的机遇和挑战。这些先进的系统能够有效整合各类信息,促进数字化管理,从而提升服务质量和工作效率。通过物管系统,物业管理者可以实时查看和分析各种数据&…

分组表格antd+ react +ts

import React from "react"; import { Table, Tag } from "antd"; import styles from "./index.less"; import GroupTag from "../Tag"; const GroupTable () > {const columns [{title: "姓名",dataIndex: "nam…

【JAVA实战】如何使用 Apache POI 在 Java 中写入 Excel 文件

大家好!🌟 在这篇文章中,我们将带你深入学习如何使用 Apache POI 在 Java 中编写 Excel 文件的技巧!📊📚 如果你是 Java 开发者,或者正在探索如何处理 Excel 文件的数据,那么这篇文章…

使用Avalonia UI实现DataGrid

1.Avalonia中的DataGrid的使用 DataGrid 是客户端 UI 中一个非常重要的控件。在 Avalonia 中,DataGrid 是一个独立的包 Avalonia.Controls.DataGrid,因此需要单独通过 NuGet 安装。接下来,将介绍如何安装和使用 DataGrid 控件。 2.安装 Dat…

C#分页思路:双列表数据组合返回设计思路

一、应用场景 需要分页查询(并非全表查载入物理内存再筛选),返回列表1和列表2叠加的数据时 二、实现方式 列表1必查,列表2根据列表1的查询结果决定列表2的分页查询参数 三、示意图及其实现代码 1.示意图 黄色代表list1的数据&a…

【Linux】磁盘

没有被打开的文件 文件在磁盘中的存储 认识磁盘 磁盘的存储构成 磁盘的效率 与磁头运动频率有关。 磁盘的逻辑结构 把一面展开成线性。 通过扇区的下标编号可以推算出在磁盘的位置。 磁盘的寄存器 控制寄存器:负责告诉磁盘是读还是写。 数据寄存器:给…

第13章 深入volatile关键字(Java高并发编程详解:多线程与系统设计)

1.并发编程的三个重要特性 并发编程有三个至关重要的特性,分别是原子性、有序性和可见性 1.1 原子性 所谓原子性是指在一次的操作或者多次操作中,要么所有的操作全部都得到了执行并 且不会受到任何因素的干扰而中断,要么所有的操作都不执行…

记录 | Docker的windows版安装

目录 前言一、1.1 打开“启用或关闭Windows功能”1.2 安装“WSL”方式1:命令行下载方式2:离线包下载 二、Docker Desktop更新时间 前言 参考文章:Windows Subsystem for Linux——解决WSL更新速度慢的方案 参考视频:一个视频解决D…

stack 和 queue容器的介绍和使用

1.stack的介绍 1.1stack容器的介绍 stack容器的基本特征和功能我们在数据结构篇就已经详细介绍了,还不了解的uu, 可以移步去看这篇博客哟: 数据结构-栈数据结构-队列 简单回顾一下,重要的概念其实就是后进先出,栈在…

JUC--ConcurrentHashMap底层原理

ConcurrentHashMap底层原理 ConcurrentHashMapJDK1.7底层结构线程安全底层具体实现 JDK1.8底层结构线程安全底层具体实现 总结JDK 1.7 和 JDK 1.8实现有什么不同?ConcurrentHashMap 中的 CAS 应用 ConcurrentHashMap ConcurrentHashMap 是一种线程安全的高效Map集合…

C++17 std::variant 详解:概念、用法和实现细节

文章目录 简介基本概念定义和使用std::variant与传统联合体union的区别 多类型值存储示例初始化修改判断variant中对应类型是否有值获取std::variant中的值获取当前使用的type在variant声明中的索引 访问std::variant中的值使用std::get使用std::get_if 错误处理和访问未初始化…