DAY12 Tensorflow 六步法搭建神经网络

六步法:

一.import   

导入各种库,比如:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model
import numpy as np
import pandas as pd
# 可能还会根据需求导入其他库,如用于数据可视化的 matplotlib 等
import matplotlib.pyplot as plt

二.train,test

准备训练数据和测试数据,比如:

# 以 MNIST 数据集为例
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()# 数据预处理
x_train, x_test = x_train / 255.0, x_test / 255.0

      首先,从相应的数据集中加载数据,如这里使用 mnist.load_data() 加载 MNIST 手写数字数据集,得到训练集的特征 x_train 和标签 y_train,以及测试集的特征 x_test 和标签 y_test。然后,对数据进行预处理,常见的预处理操作包括归一化、标准化等。在上述代码中,将图像像素值除以 255,将其缩放到 0 到 1 的范围,这有助于模型的训练和收敛。

    三.model=tf.keras.models.Sequential

    构建模型架构,比如:

    model = tf.keras.models.Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
    ])

    四.model.compile

    配置模型训练过程,比如:

    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

    这一步用于配置模型的训练过程,主要设置三个重要参数:

    optimizer:优化器,用于调整模型的参数以最小化损失函数。adam 是一种常用的优化器,它结合了 AdaGrad 和 RMSProp 的优点,具有自适应学习率的能力。


    loss:损失函数,用于衡量模型预测结果与真实标签之间的差异。其中sparse_categorical_crossentropy 适用于标签为整数编码的多分类问题。


    metrics:评估指标,用于在训练和测试过程中监控模型的性能。accuracy 表示准确率,即模型预测正确的样本数占总样本数的比例。

    五,model.fit

    进行模型训练,使用训练模型进行迭代训练。

    model.fit(x_train, y_train, epochs=5)

    六,model.summary

    这一步用于打印模型的结构信息,包括每一层的名称、输出形状和参数数量等。通过查看 model.summary() 的输出,你可以了解模型的整体架构和参数规模,帮助你检查模型是否符合预期,以及评估模型的复杂度。

    model.summary()

    各自的使用方法:

    Flatten只是把数值特征拉成一维数组

    Dense全连接

    后面是卷积神经网络层和循环神经网络层

    compile配置训练方法。

    validation_data和validation_split二选一,进行训练。

    validation_freq 多少轮训练后用测试集测试一次。

    model.summary()打印出统计结果,其中可以看到,总共的参数15个,可训练参数15个,不可训练参数0个。

    以下是用六步法搭建鸢尾花分类。

    import tensorflow as tf
    from sklearn import datasets
    import numpy as npx_train = datasets.load_iris().data
    y_train = datasets.load_iris().targetnp.random.seed(116)
    np.random.shuffle(x_train)
    np.random.seed(116)
    np.random.shuffle(y_train)
    tf.random.set_seed(116)model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
    ])model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)model.summary()
    

    首先import,然后两个train交代训练集和测试集。

    np到tf这几个random用于打乱数据集设置。

    model中设置网络结构。3指的是神经元个数是3,activation选用激活函数,最后是选用正则化方法。

    complie中配置训练方法,SGD优化器,学习率0.1,选用SparseCategoricalCrossentropy当做损失函数,由于神经末端使用softmax函数,输出不是原始分布,所以logits=False。

    鸢尾花数据集给的是0,1,2是数值,神经网络前向输出是概率分布,选择sparse_categorical_accuracy作为测评指标。

    fit中执行训练过程,分别是 输入特征,训练集标签,训练时一次喂给神经网络多少组数据batch_size,循环迭代次数,validation_split=0.2告知从训练集中选择百分之20数据当做测试集,validation_freq=20,表示迭代20次,在测试集中验证一次准确率。

    运行结果:

    可见,打印出了网络结构和参数统计。

    本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/71482.shtml

    如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

    相关文章

    Zookeeper分布式锁实现

    zookeeper最初设计的初衷就是为了保证分布式系统的一致性。本文将讲解如何利用zookeeper的临时顺序结点,实现分布式锁。 目录 1. 理论分析 1.1 结点类型 1.2 监听器 1.3 实现原理 2. 手写实现简易zookeeper分布式锁 1.1 依赖 1.2 常量定义 1.3 实现zookeeper分布式…

    Git是什么

    简单介绍: Git是一个分布式版本控制系统,用于跟踪文件的更改,特别是在多人协作开发的环境中。 Key: 分布式 版本控制 系统 最常用于软件开发,但也可以用于管理任何类型的文件和文件夹。 Git帮助团队跟踪和管理文件的历史版本&a…

    Pycharm 2024在解释器提供的python控制台中运行py文件

    2024版的界面发生了变化, run with python console搬到了这里:

    【分布式理论12】事务协调者高可用:分布式选举算法

    文章目录 一、分布式系统中事务协调的问题二、分布式选举算法1. Bully算法2. Raft算法3. ZAB算法 三、小结与比较 一、分布式系统中事务协调的问题 在分布式系统中,常常有多个节点(应用)共同处理不同的事务和资源。前文 【分布式理论9】分布式…

    免费deepseek的API获取教程及将API接入word或WPS中

    免费deepseek的API获取教程: 1 https://cloud.siliconflow.cn/中注册时填写邀请码:GAejkK6X即可获取2000 万 Tokens; 2 按照图中步骤进行操作 将API接入word或WPS中 1 打开一个word,文件-选项-自定义功能区-勾选开发工具-左侧的信任中心-信任中心设置…

    【SFRA】笔记

    GK_SFRA_INJECT(x) SFRA小信号注入函数,向控制环路注入一个小信号。如下图所示,当前程序,小信号注入是在固定占空比的基础叠加小信号,得到新的占空比,使用该占空比控制环路。 1.2 GK_SFRA_COLLECT(x, y) SFRA数据收集函数,将小信号注入环路后,该函数收集环路的数据,以…

    论文笔记-WSDM2024-LLMRec

    论文笔记-WSDM2024-LLMRec: Large Language Models with Graph Augmentation for Recommendation LLMRec: 基于图增强的大模型推荐摘要1.引言2.前言2.1使用图嵌入推荐2.2使用辅助信息推荐2.3使用数据增强推荐 3.方法3.1LLM作为隐式反馈增强器3.2基于LLM的辅助信息增强3.2.1用户…

    Ubuntu 系统 cuda12.2 安装 MMDetection3D

    DataBall 助力快速掌握数据集的信息和使用方式,会员享有 百种数据集,持续增加中。 需要更多数据资源和技术解决方案,知识星球: “DataBall - X 数据球(free)” 贵在坚持! ---------------------------------------…

    Tomcat的升级

    Tomcat 是一个开源的 Java Servlet 容器,用于部署 Java Servlet 和 JavaServer Pages(JSP)。随着新版本的发布,Tomcat 通常会带来性能改进、安全增强、新特性和对最新 Java 版本的更好支持。升级 Tomcat 服务器通常涉及到以下几个…

    Python常见面试题的详解10

    1. 哪些操作会导致 Python 内存溢出,怎么处理? 要点 1. 创建超大列表或字典:当我们一次性创建规模极为庞大的列表或字典时,会瞬间占用大量的内存资源。例如,以下代码试图创建一个包含 10 亿个元素的列表,在…

    多个用户如何共用一根网线传输数据

    前置知识 一、电信号 网线(如以太网线)中传输的信号主要是 电信号,它携带着数字信息。这些信号用于在计算机和其他网络设备之间传输数据。下面是一些关于网线传输信号的详细信息: 1. 电信号传输 在以太网中,数据是…

    华为昇腾 910B 部署 DeepSeek-R1 蒸馏系列模型详细指南

    本文记录 在 华为昇腾 910B(65GB) * 8 上 部署 DeepSeekR1 蒸馏系列模型(14B、32B)全过程与测试结果。 NPU:910B3 (65GB) * 8 (910B 有三个版本 910B1、2、3) 模型:DeepSeek-R1-Distill-Qwen-14B、DeepSeek…

    【前端】Vue组件库之Element: 一个现代化的 UI 组件库

    文章目录 前言一、官网1、官网主页2、设计原则3、导航4、组件 二、核心功能:开箱即用的组件生态1、丰富的组件体系2、特色功能亮点 三、快速上手:三步开启组件化开发1、安装(使用Vue 3)2、全局引入3、按需导入(推荐&am…

    关于uniApp的面试题及其答案解析

    我的血液里流淌着战意!力量与智慧指引着我! 文章目录 1. 什么是uniApp?2. uniApp与原生小程序开发有什么区别?3. 如何使用uniApp实现条件编译?4. uniApp支持哪些平台,各有什么特点?5. 在uniApp中…

    Ubuntu 下 nginx-1.24.0 源码分析 - ngx_pool_t 类型

    ngx_pool_t 定义在 src/core/ngx_core.h typedef struct ngx_pool_s ngx_pool_t; ngx_pool_s 定义在 src/core/ngx_palloc.h struct ngx_pool_s {ngx_pool_data_t d;size_t max;ngx_pool_t *current;ngx_chain_t *chain;ng…

    力扣 最长递增子序列

    动态规划,二分查找。 题目 由题,从数组中找一个最长子序列,不难想到,当这个子序列递增子序列的数越接近时是越容易拉长的。从dp上看,当遍历到这个数,会从前面的dp选一个最大的数加上当前数,注意…

    Linux | 进程控制(进程终止与进程等待)

    文章目录 Linux | 进程控制 — 进程终止 & 进程等待1、进程终止进程常见退出方法1.1退出码基本概念获取退出码的方式常见退出码约定使用场景 1.2 strerror函数 & errno宏1.3 _exit函数1.4_exit和exit的区别1.4.1 所属头文件与函数原型1.4.2 执行过程差异**结合现象分析…

    Android - Handler使用post之后,Runnable没有执行

    问题:子线程创建的Handler。如果 post 之后,在Handler.removeCallbacks(run)移除了,下次再使用Handler.postDelayed(Runnable)接口或者使用post时,Runnable是没有执行。导致没有收到消息。 解决办法:只有主线程创建的…

    鱼皮面试鸭30天后端面试营

    day1 1. MySQL的索引类型有哪些? MySQL里的索引就像是书的目录,能帮数据库快速找到你要的数据。以下是各种索引类型的通俗解释: 按数据结构分 B树索引:最常用的一种,数据像在一棵树上分层存放,能快速定位范围数据…

    【核心算法篇十二】《深入解剖DeepSeek多任务学习:共享表示层的24个设计细节与实战密码 》

    引言:为什么你的模型总在"精神分裂"? 想象你训练了一个AI实习生: 早上做文本分类时准确率90%下午做实体识别却把"苹果"都识别成水果公司晚上做情感分析突然开始输出乱码这就是典型的任务冲突灾难——模型像被不同任务"五马分尸"。DeepSeek通…