matplotlib 动态显示训练过程中的数据和模型的决策边界

文章目录

  • Github
  • 官网
  • 文档
  • 简介
  • 动态显示训练过程中的数据和模型的决策边界
    • 安装
    • 源码

Github

  • https://github.com/matplotlib/matplotlib

官网

  • https://matplotlib.org/stable/

文档

  • https://matplotlib.org/stable/api/index.html

简介

matplotlib 是 Python 中最常用的绘图库之一,用于创建各种类型的静态、动态和交互式可视化。

动态显示训练过程中的数据和模型的决策边界

在这里插入图片描述

安装

pip install tensorflow==2.13.1
pip install matplotlib==3.7.5
pip install numpy==1.24.3

源码

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap# 生成数据
np.random.seed(0)
num_samples_per_class = 500
negative_samples = np.random.multivariate_normal(mean=[0, 3],cov=[[1, 0.5], [0.5, 1]],size=num_samples_per_class
)
positive_samples = np.random.multivariate_normal(mean=[3, 0],cov=[[1, 0.5], [0.5, 1]],size=num_samples_per_class
)inputs = np.vstack((negative_samples, positive_samples)).astype(np.float32)
targets = np.vstack((np.zeros((num_samples_per_class, 1)), np.ones((num_samples_per_class, 1)))).astype(np.float32)# 将数据分为训练集和测试集
train_size = int(0.8 * len(inputs))
X_train, X_test = inputs[:train_size], inputs[train_size:]
y_train, y_test = targets[:train_size], targets[train_size:]# 构建二分类模型
model = Sequential([# 输入层:输入形状为 (2,)# 第一个隐藏层:包含 4 个节点,激活函数使用 ReLUDense(4, activation='relu', input_shape=(2,)),# 输出层:包含 1 个节点,激活函数使用 Sigmoid(因为是二分类问题)Dense(1, activation='sigmoid')
])# 编译模型
# 指定优化器为 Adam,损失函数为二分类交叉熵,评估指标为准确率
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])# 准备绘图
fig, ax = plt.subplots()
cmap_light = ListedColormap(['#FFAAAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#0000FF'])# 动态绘制函数
def plot_decision_boundary(epoch, logs):ax.clear()x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),np.arange(y_min, y_max, 0.1))grid = np.c_[xx.ravel(), yy.ravel()]probs = model.predict(grid).reshape(xx.shape)ax.contourf(xx, yy, probs, alpha=0.8, cmap=cmap_light)ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train[:, 0], edgecolor='k', cmap=cmap_bold)ax.set_title(f'Epoch {epoch+1}')plt.draw()plt.pause(0.01)# 自定义回调函数
class PlotCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs=None):plot_decision_boundary(epoch, logs)# 训练模型并动态显示
plot_callback = PlotCallback()
model.fit(X_train, y_train, epochs=50, batch_size=16, callbacks=[plot_callback])# 评估模型
loss, accuracy = model.evaluate(X_test, y_test)
print(f"Test Loss: {loss}")
print(f"Test Accuracy: {accuracy}")plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/25203.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ghidra

https://github.com/NationalSecurityAgency/ghidra ghidra是一个so的逆向工具,功能和jadx-gui类似,但是和jadx-gui专注于java层的不同,ghidra专注于native层的代码反编译(从二进制到c语言)。 一、 安装 准备好java1…

解释一下I/O多路复用模型?

想象一下,你是一家小餐馆的老板,你的工作是接收顾客的订单,然后通知厨师开始准备。如果每次只能等一个顾客点完菜再接待下一个,那效率就太低了,顾客可能要等很久。 现在,有一种聪明的做法叫做“I/O多路复用…

js理解异步编程和回调

什么是异步 计算机在设计上是异步的。 异步意味着事情可以独立于主程序流发生。 当你打开一个网页,网页载入的过程,你又打开了编译器,那么你在网页载入时启动了编译器的行为就是计算机的异步, 可以看出计算机时一个超大的异步…

华为防火墙 1

华为防火墙1 实验拓扑: 实验步骤: 1.完成终端基本IP信息配置 2.配置防火墙: 2.1配置IP地址 sys Enter system view, return user view with CtrlZ. [USG6000V1]undo in e Info: Saving log files… Info: Information center is disabled. […

《科学,无尽的前沿》—— 程序员必读

一、总体概述 《科学,无尽的前沿》(Science The Endless Frontier)开创了大政治推动下的大科学工程新范式,被视为“美国科学政策的开山之作,促成了支持科学的“美国战后共识”,是“美国历史上最具影响力的政策文件之一”。 报告…

场内基金和场外基金的区别

场内基金就是只能在场内买卖的基金,只有股票账户才能买。场外就是在场内以外的交易场所买卖的基金。 场内基金和场外基金区别主要是费用和买卖价格。 场内基金和场外基金都有管理费和托管费。二者不同的主要是交易费用。场内基金买卖都要交交易费用,这…

基于小波脊线的一维时间序列信号分解方法(MATLAB R2018A)

信号分解技术是把一个复杂信号分解为若干含有时频信息的简单信号,研可通过分解后的简单信号来读取和分析复杂信号的有效特征。因此,信号分解技术对分析结果的影响是不言而喻的。 傅里叶分解是早期常用的信号分解方法,最初被用于分析热过程&a…

心链7 ----Redis的引入和实现以及缓存和定时任务应用

心链 — 伙伴匹配系统 Redis 数据查询慢怎么办? 用缓存:提前把数据取出来保存好(通常保存到读写更快的介质,比如内存),就可以更快地读写。 缓存 Redis(分布式缓存)memcached&…

JavaScript基础(十二)

截取字符串 //对象名.toLowerCase();将字符串转为小写 var strLAOWANG; strstr.toLowerCase(); console.log(str); //对象名.toUpperCase();将字符串转为大写 var str1laowang str1str1.toUpperCase(); console.log(str1); 截取字符串 //方法1:对象名.substr(a,b); …

Unity世界坐标下UI始终朝向摄像机

Unity世界坐标下UI始终朝向摄像机 1、第一种方法UI会反过来 void Update(){this.transform.LookAt(Camera.main.transform.position);}2、第二种方法 Transform m_Camera;void Start(){m_Camera Camera.main.transform;}void LateUpdate(){transform.rotation Quaternion.Lo…

kafka-生产者事务-数据传递语义事务介绍事务消息发送(SpringBoot整合Kafka)

文章目录 1、kafka数据传递语义2、kafka生产者事务3、事务消息发送3.1、application.yml配置3.2、创建生产者监听器3.3、创建生产者拦截器3.4、发送消息测试3.5、使用Java代码创建主题分区副本3.6、屏蔽 kafka debug 日志 logback.xml3.7、引入spring-kafka依赖3.8、控制台日志…

.shape 和 .size的区别

在 Python 中,尤其是使用 numpy 和 torch 库进行数组和张量操作时,.shape 和 .size() 是两个非常常见的方法。虽然它们有时可以互换使用,但它们确实有一些细微的区别。 .shape 属性 类型:.shape 是一个属性。返回值:…

关于烫烫烫和屯屯屯

微较的msvc编译器,调试模式下为了方便检测内存的非法访问,对于不同的内存做了初始化, 未初始化栈: 0xCCCCCCCC 未初始化堆: 0xCDCDCDCD 已释放的堆: 0xDDDDDDDD 0xCCCC解释为GB2312字符即是烫&#xff…

Django 鸡与蛋问题

"Django 的鸡与蛋问题"通常指的是在开始 Django 项目时,你可能会遇到的一个困境:是先设计数据库模型还是先编写视图和控制器(即视图函数)? 这个问题的实质是在于,Django 的核心部分是由数据库模…

Qt5/6使用SqlServer用户连接操作SqlServer数据库

网上下载SQLServer2022express版数据库,这里没啥可说的,随你喜欢,也可以下载Develop版本。安装完后,我们可以直接连接尝试, 不过一般来说,还是下载SQLServer管理工具来连接数据更加方便。 所以直接下载ssms, 我在用的时候,一开始只能用Windows身份登录。 所以首先,我…

入门matlab

常识 如何建一个新文件 创建新文件,点击新建,我们就可以开始写代码了 为什么要在代码开头加入clear 假如我们有2个文件,第一个文件里面给x赋值100,第二个文件为输出x 依次运行: 结果输出100,这是因为它们…

013-Linux交换分区管理

一、交换分区的作用 ”提升“内存容量,防止OOM(out of memory,内存溢出)。 从功能上讲,交换分区主要是在内存不够用的时候,将部分内存上的数据交换到swap空间上,以便让系统不会因内存不够用而…

ChatGPT Prompt技术全攻略-精通篇:Prompt工程技术的高级应用

系列篇章💥 No.文章1ChatGPT Prompt技术全攻略-入门篇:AI提示工程基础2ChatGPT Prompt技术全攻略-进阶篇:深入Prompt工程技术3ChatGPT Prompt技术全攻略-高级篇:掌握高级Prompt工程技术4ChatGPT Prompt技术全攻略-应用篇&#xf…

Web前端Hack:深入探索、挑战与防范

Web前端Hack:深入探索、挑战与防范 在数字化时代的浪潮中,Web前端作为用户与互联网世界交互的桥梁,其安全性日益受到关注。然而,Web前端也面临着各种潜在的攻击和风险。今天,我们将一起探索Web前端Hack的四个方面、五…

电脑缺失msvcp110.dll文件的解决方法,总结5种靠谱的方法

在计算机使用过程中,我们可能会遇到一些错误提示,其中之一就是“找不到msvcp110.dll”。这个错误提示通常出现在运行某些软件时,那么,它究竟会造成哪些问题呢? 一,msvcp110.dll文件概述 msvcp110.dll是Mic…