深度学习(三)

5.Functional API 搭建神经网络模型
5.1利用Functional API编写宽深神经网络模型进行手写数字识别
import numpy as npimport pandas as pdimport matplotlib.pyplot as pltfrom sklearn.datasets import load_irisfrom sklearn.model_selection import train_test_splitfrom tensorflow.keras.layers import Input, Dense, concatenatefrom tensorflow.keras.models import Modeliris = load_iris()x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23)X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12)print(X_valid.shape)print(X_train.shape)inputs = Input(shape=X_train.shape[1:])hidden1 = Dense(300, activation="relu")(inputs)hidden2 = Dense(100, activation="relu")(hidden1)concat = concatenate([inputs, hidden2])output = Dense(10, activation="softmax")(concat)model_wide_deep = Model(inputs=inputs, outputs=output)

iris = load_iris():加载iris数据集,这是一个常用的多类分类数据集,包含了150个样本,每个样本有4个特征,属于3个不同的类别。

x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=23):将iris数据集分割为训练集和测试集。test_size=0.2表示测试集的大小为原始数据的20%,random_state=23是一个随机种子,确保分割的可重复性。

X_train, X_valid, y_train, y_valid = train_test_split(x_train, y_train, test_size=0.2, random_state=12):进一步将训练集分割为训练集和验证集。同样,test_size=0.2表示验证集的大小为分割后训练数据的20%,random_state=12确保分割的可重复性。

print(X_valid.shape):打印验证集的特征数据的形状。

print(X_train.shape):打印新的训练集的特征数据的形状。

inputs = Input(shape=X_train.shape[1:]):定义模型的输入层,shape=X_train.shape[1:]指定输入的形状,由于X_train是一个二维数组,shape[1:]表示除了第一维(样本数量)之外的所有维度。

hidden1 = Dense(300, activation="relu")(inputs):定义第一个隐藏层,它有300个神经元,并使用ReLU激活函数。

hidden2 = Dense(100, activation="relu")(hidden1):定义第二个隐藏层,它有100个神经元,并使用ReLU激活函数。

concat = concatenate([inputs, hidden2]):将输入层和第二个隐藏层的输出拼接起来,形成更宽的网络。

output = Dense(10, activation="softmax")(concat):定义输出层,它有10个神经元(对应于3个类别和一个额外的神经元,这是常见的做法),并使用softmax激活函数输出概率分布。

model_wide_deep = Model(inputs=inputs, outputs=output):创建一个Keras模型,将输入层和输出层连接起来。

使用scikit-learn库中的load_iris函数来加载iris数据集,然后使用train_test_split函数将数据集分割为训练集和测试集,以及进一步的训练集和验证集。接着,它定义了一个宽深网络(wide and deep network)模型,其中包含了输入层、两个隐藏层和一个输出层。

model_wide_deep.summary()

model_wide_deep.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])h = model_wide_deep.fit(X_train, y_train, batch_size=32, epochs=30,validation_data=(X_valid, y_valid))

# 绘图pd.DataFrame(h.history).plot(figsize=(8,5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()

# 使用 model_wide_deep 评估测试集test_loss, test_accuracy = model_wide_deep.evaluate(x_test, y_test, batch_size=32)print(f"Test Loss: {test_loss}")print(f"Test Accuracy: {test_accuracy}")

6.SubClassing API 搭建神经网络模型

以前馈全连接神经网络手写数字识别为例

import numpy as npimport pandas as pdimport matplotlib.pyplot as pltfrom sklearn.datasets import load_irisfrom sklearn.model_selection import train_test_splitfrom tensorflow.keras.layers import Input, Dense, concatenatefrom tensorflow.keras.models import Modelfrom tensorflow.keras import backend as K# 加载数据集iris = load_iris()X = iris.datay = iris.target# 分割数据集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=23)X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=12)# 打印验证集和训练集的形状print(X_valid.shape)print(X_train.shape)# 定义 Model_sub_fn 类class Model_sub_fn(Model):def __init__(self, units_1, units_2, units_out, activation="relu"):super(Model_sub_fn, self).__init__()self.hidden1 = Dense(units_1, activation=activation)self.hidden2 = Dense(units_2, activation=activation)self.main_output = Dense(units_out, activation="softmax")def call(self, inputs):x = self.hidden1(inputs)x = self.hidden2(x)return self.main_output(x)

定义了一个名为Model_sub_fn的类,该类继承自tensorflow.keras.Model。这个类用于创建一个简单的神经网络模型,它包含两个隐藏层和一个输出层。

class Model_sub_fn(Model)定义一个名为Model_sub_fn的类,它继承自tensorflow.keras.Model。这意味着Model_sub_fn类可以访问和继承Model类的所有属性和方法。

def __init__(self, units_1, units_2, units_out, activation="relu"):定义类的构造函数__init__,它接受四个参数:units_1(第一个隐藏层的神经元数量)、units_2(第二个隐藏层的神经元数量)、units_out(输出层的神经元数量)和activation(激活函数类型,默认为ReLU)。

super(Model_sub_fn, self).__init__():调用父类的构造函数,这是继承自Model类的标准做法。

self.hidden1 = Dense(units_1, activation=activation):定义第一个隐藏层,它有units_1个神经元,并使用activation作为激活函数。

self.hidden2 = Dense(units_2, activation=activation):定义第二个隐藏层,它有units_2个神经元,并使用activation作为激活函数。

self.main_output = Dense(units_out, activation="softmax"):定义输出层,它有units_out个神经元,并使用softmax作为激活函数。

def call(self, inputs):定义call方法,这是所有Keras模型必须定义的方法,它用于前向传播。在这个方法中,输入数据通过两个隐藏层,最后通过输出层。

x = self.hidden1(inputs):将输入数据通过第一个隐藏层。

x = self.hidden2(x):将第一个隐藏层的输出通过第二个隐藏层。

return self.main_output(x):将第二个隐藏层的输出通过输出层,并返回结果。

model_sub_fn = Model_sub_fn(units_1=64, units_2=32, units_out=3)# 创建 Model_sub_fn 实例model_sub_fn = Model_sub_fn(300, 100, 3, activation="relu")  # 假设输出层有3个单元,因为Iris数据集有3个类别# 编译模型model_sub_fn.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])# 训练模型history = model_sub_fn.fit(X_train, y_train, batch_size=32, epochs=30, validation_data=(X_valid, y_valid))

model_sub_fn.summary()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/22004.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Stable diffusion文生图大模型——隐扩散模型原理解析

1、前言 本篇文章,我们将讲这些年非常流行的文生图大模型——Stable Diffusion。该模型也不难,甚至说很简单。创新点也相对较少,如果你学会了我以前的文章讲过的模型,学习这个也自然水到渠成! 参考论文:H…

AdminController

目录 1、 AdminController 1.1、 Students 1.2、 StudentDetail 1.3、 EnrolledStudent AdminController using ITM_College.Data; using ITM_College.Models; using Microsoft.AspNetCore.Mvc; using Microsoft.EntityFrameworkCore; namespace ITM_College.Cont…

大白话说明:k8s-Service资源的理解以及与Ingress Controller 做区分

一、什么是Service? 1、Service只是后段一堆Pod的集合抽象概念而已。 2、当你创建Service资源时会同时创建一个名称为Endpoints的资源对象来List 这些实际的Pod的IP。 二、Service代理流量吗? 1、不代理。 2、实际处理流量是每个节点上的kube-prox…

tomcat-valve通过servlet处理请求

上一节说到请求url定位servlet的过程,tomcat会把请求url和容器的映射关系保存到MappingData中,org.apache.catalina.connector.Request类实现了HttpServletRequest,其中定义了属性mappingDataprotected final MappingData mappingData new M…

腾讯云 TDMQ for Apache Pulsar 多地区高可用容灾实践

作者介绍 林宇强 腾讯云高级工程师 专注于消息队列、API网关、微服务、数据同步等 PaaS 领域。有多年的开发和维护经验,目前在腾讯云从事 TDMQ Pulsar 商业化产品方向的研发工作。 导语 本文将从四个维度,深入剖析 Pulsar 在多可用区高可用领域的容…

大数据—元数据管理

在大数据环境中,元数据管理是确保数据资产有效利用和治理的关键组成部分。元数据是描述数据的数据,它提供了关于数据集的上下文信息,包括数据的来源、格式、结构、关系、质量、处理历史和使用方式等。有效的元数据管理有助于提高数据的可发现…

Amazon云计算AWS(四)

目录 八、其他Amazon云计算服务(一)快速应用部署Elastic Beanstalk和服务模板CloudFormation(二)DNS服务Router 53(三)虚拟私有云VPC(四)简单通知服务和简单邮件服务(五&…

【劲舞团game】

编写《劲舞团》这样的游戏代码是一个复杂的过程,涉及到游戏引擎的使用、图形渲染、物理模拟、音频处理、网络通信等多个方面。以下是一个非常简化的步骤,用于说明如何开始编写一个基于Unity引擎的简单舞蹈游戏: 1. 准备开发环境 下载并安装…

LeetCode刷题之HOT100之全排列

九点半了&#xff0c;做题吧。聊天聊到十一点多哈哈。 1、题目描述 2、逻辑分析 给定一个不重复数组&#xff0c;要求返回所有可能的全排列。这道题跟我上一道题思想一致&#xff0c;都是使用到回溯的算法思想来解决。直接用代码来解释吧 3、代码演示 public List<List&…

MongoDB环境搭建

一.下载安装包 Download MongoDB Community Server | MongoDB 二、双击下载完成后的安装包开始安装&#xff0c;除了以下两个部分需要注意操作&#xff0c;其他直接next就行 三.可视化界面安装 下载MongoDB-compass&#xff0c;地址如下 MongoDB Compass Download (GUI) | M…

Golang | Leetcode Golang题解之第129题求根节点到叶节点数字之和

题目&#xff1a; 题解&#xff1a; type pair struct {node *TreeNodenum int }func sumNumbers(root *TreeNode) (sum int) {if root nil {return}queue : []pair{{root, root.Val}}for len(queue) > 0 {p : queue[0]queue queue[1:]left, right, num : p.node.Left, …

overpass-api 部署(docker)

简介 Overpass是一个用于访问和查询OpenStreetMap&#xff08;OSM&#xff09;数据的开放式数据API和查询语言。OpenStreetMap是一个由社区驱动的免费开放地图项目&#xff0c;用户可以贡献地理数据并使用它来创建自由和开放的地图。 Overpass API提供了一种强大的方式来获取和…

Spire.PDF for .NET【文档操作】演示:在 C# 中向 PDF 文件添加图层

Spire.PDF 完美支持将多页 PDF 拆分为单页。但是&#xff0c;更常见的情况是&#xff0c;您可能希望提取选定的页面范围并保存为新的 PDF 文档。在本文中&#xff0c;您将学习如何通过 Spire.PDF 在 C#、VB.NET 中根据页面范围拆分 PDF 文件。 Spire.PDF for .NET 是一款独立 …

群体优化算法---蝙蝠优化算法分类Iris数据集

介绍 蝙蝠算法&#xff08;Bat Algorithm, BA&#xff09;是一种基于蝙蝠回声定位行为的优化算法。要将蝙蝠算法应用于分类问题&#xff0c;可以通过将蝙蝠算法用于优化分类器的参数&#xff0c;图像分割等 本文示例 我们使用一个经典的分类数据集&#xff0c;如Iris数据集&…

Python开发运维:VSCode与Pycharm 部署 Anaconda虚拟环境

目录 一、实验 1.环境 2.Windows 部署 Anaconda 3.Anaconda 使用 4.VSCode 部署 Anaconda虚拟环境 5.Pycharm 部署 Anaconda虚拟环境 6.Windows使用命令窗口版 Jupyter Notebook 7.Anaconda 图形化界面 二、问题 1.VSCode 运行.ipynb代码时报错 2.pip 如何使用国内…

JS本地存储

cookie Cookie用于储存不超过 4KB 的小型文本数据&#xff1b;可设置数据过期时间。 // 设置cookie function setCookie(name, value, daysToLive) {let cookie name "" encodeURIComponent(value);if (typeof daysToLive "number") {cookie "…

构造,CF862C. Mahmoud and Ehab and the xor

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 Problem - 862C - Codeforces 二、解题报告 1、思路分析 非常松的一道构造题目 我们只需让最终的异或和为x即可 下面给出个人一种构造方式&#xff1a; 先选1~N-3&#xff0c;然后令o (1 << 17) …

Oracle表分区的基本使用

什么是表空间 是一个或多个数据文件的集合&#xff0c;所有的数据对象都存放在指定的表空间中&#xff0c;但主要存放的是表&#xff0c;所以称为表空间 什么是表分区 表分区就是把一张大数据的表&#xff0c;根据分区策略进行分区&#xff0c;分区设置完成之后&#xff0c;…

Redis5学习笔记之三:事务、锁和集成

3. 事务&#xff0c;锁和集成 3.1 事务 3.1.1 基本应用 redis事务的本质&#xff1a;一组命令的集合&#xff0c;一个事务中的所有命令都会被序列化&#xff0c;在执行事务的过程中&#xff0c;会按照顺序执行 redis事务的特点&#xff1a; redis单条命令能够保证原子性&…

第五讲:独立键盘、矩阵键盘的检测原理及实现

IO口电平检测 检测IO口的电平时&#xff0c;需要先给高电平 之后便进入输出状态 #include <reg52.h>void main() {// 配置P1.0为输出模式&#xff0c;并输出高电平P1 0x01; // 将P1.0置为高电平// 读取P1.0的电平状态if (P1 & 0x01) {// 如果P1.0为高电平&#x…