Python深度学习基于Tensorflow(3)Tensorflow 构建模型

文章目录

        • 数据导入和数据可视化
        • 数据集制作以及预处理
        • 模型结构
        • 低阶 API 构建模型
        • 中阶 API 构建模型
        • 高阶 API 构建模型
        • 保存和导入模型

这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。

这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。

数据导入和数据可视化

这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。

![[Pasted image 20240506194103.png]]

接下来导入cifar10数据集并进行可视化展示

import matplotlib.pyplot as plt
import tensorflow as tf(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))index_name = {0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'
}def plot_100_img(imgs, labels):fig = plt.figure(figsize=(20,20))for i in range(10):for j in range(10):plt.subplot(10,10,i*10+j+1)plt.imshow(imgs[i*10+j])plt.title(index_name[labels[i*10+j][0]])plt.axis('off')plt.show()plot_100_img(x_test[:100])

![[Pasted image 20240506200312.png]]

数据集制作以及预处理

数据集预处理很简单就能实现,直接一行代码。

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))# 提取出一行数据
# train_data.take(1).get_single_element()

这里接着对数据预处理操作,也很容易就能实现。

def process_data(img, label):img = tf.cast(img, tf.float32) / 255.0return img, labeltrain_data = train_data.map(process_data)# 提取出一行数据
# train_data.take(1).get_single_element()

这里对数据还有一些存储和提取操作

dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。

![[Pasted image 20240506201344.png]]

最后我们对数据进行一些缓存操作

learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

目前数据准备完毕!

模型结构

模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型

![[Pasted image 20240506202450.png]]

低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf## 定义模型
class CustomModel(tf.Module):def __init__(self, name=None):super(CustomModel, self).__init__(name=name)self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))def __call__(self, x):x = tf.cast(x, tf.float32)x = tf.reshape(x, [x.shape[0], -1])x = tf.nn.relu(x @ self.w1 + self.b1)x = tf.nn.relu(x @ self.w2 + self.b2)x = tf.nn.relu(x @ self.w3 + self.b3)x = tf.nn.softmax(x @ self.w4 + self.b4)return x
model = CustomModel()## 定义损失
def compute_loss(y, y_pred):y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)return tf.reduce_mean(loss)## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义准确率
def compute_accuracy(y, y_pred):correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))correct_pred = tf.cast(correct_pred, tf.float32)return tf.reduce_mean(correct_pred)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)accuracy = compute_accuracy(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))return loss.numpy(), accuracy.numpy()## 开始训练loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):loss, acc = train_one_epoch(batch_x, batch_y)loss_list.append(loss)acc_list.append(acc)if i % 10 == 0:print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):def __init__(self):super(CustomModel, self).__init__()self.flatten = tf.keras.layers.Flatten()self.dense_1 = tf.keras.layers.Dense(256, activation='relu')self.dense_2 = tf.keras.layers.Dense(128, activation='relu')self.dense_3 = tf.keras.layers.Dense(64, activation='relu')self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')def __call__(self, x):x = self.flatten(x)x = self.dense_1(x)x = self.dense_2(x)x = self.dense_3(x)x = self.dense_4(x)return xmodel = CustomModel()## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss)train_accuracy(y, y_pred)## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):train_one_epoch(batch_x, batch_y)loss_list.append(train_loss.result())accuracy_list.append(train_accuracy.result())if i % 10 == 0:print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([tf.keras.layers.Input(shape=[32,32,3]),tf.keras.layers.Flatten(),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax'),
])## 定义optimizer,loss, accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),loss = tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)## 开始训练
model.fit(train_data.take(10000))
保存和导入模型

保存模型

tf.keras.models.save_model(model, 'model_folder')

导入模型

model = tf.keras.models.load_model('model_folder')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/7167.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

9.4.k8s的控制器资源(job控制器,cronjob控制器)

目录 一、job控制器 二、cronjob控制器 一、job控制器 job控制器就是一次性任务的pod控制器,pod完成作业后不会重启,其重启策略是:Never; 简单案例 启动一个pod,执行完成一个事件,然后pod关闭;…

初识指针(2)<C语言>

前言 前文介绍完了一些指针基本概念,下面介绍一下,const关键字、指针的运算、野指针的成因以及避免,assert函数等。 目录 const(常属性) 变量的常属性 指针的常属性 指针的运算 ①指针 -整数 ②指针-指针 ③指针与…

Enhanced-Rtmp支持H265

Enhanced-Rtmp支持H265 千呼万唤使出来,rtmp/flv算是有统一支持H265的国际版本。本文介绍一下: 现存rtmp/flv支持H265的方式;Enhanced-Rtmp协议如何支持H265;ffmpeg/obs/srs/media-server各个开源的实现;国内方案与国外方案的兼容性问题; 1. rtmp/flv…

安卓内存机制

目录 前言一、内存 LowMemoryKiller二、常用的内存调优分析命令: 前言 安卓内存知识,不定期更新… 一、内存 LowMemoryKiller Android的设计理念之一,便是应用程序退出,但进程还会继续存在系统以便再次启动时提高响应时间. 这样的设计会带…

初识JDBC

1、JDBC是什么? Java DataBase Connectivity(Java语言连接数据库) 2、JDBC的本质是什么? JDBC是SUN公司制定的一套接口(interface) java.sql.*;(这个包下有很多接口) 接口都有调用者和实现者。 面向接口调用、面向接口写实现类,这都属于…

mysql的导入与导出

mysql表的导入与导出 导出 直接在命令行中输入&#xff08;注意不需要进入mysql&#xff09; mysqldump -u root -p my_database > C:/Users/xxx/Desktop/all.sql然后他会要求你输入数据库的密码 导入 同样也是直接在命令行中输入 mysql -u root -p my_database < …

JAVA安装linux环境安装maven

linux安装java 下载java8,地址为&#xff1a; https://www.oracle.com/cn/java/technologies/downloads/#java8&#xff0c;下载后缀为tar.gz的解压 tar -zxvf jdk-8u381-linux-x64.tar.gz移动 mv jdk1.8.0_381/ /usr/local/环境变量 export JAVA_HOME/usr/local/jdk1.8.0_38…

【Osek网络管理测试】[TG3_TC6]等待总线睡眠状态_2

&#x1f64b;‍♂️ 【Osek网络管理测试】系列&#x1f481;‍♂️点击跳转 文章目录 1.环境搭建2.测试目的3.测试步骤4.预期结果5.测试结果 1.环境搭建 硬件&#xff1a;VN1630 软件&#xff1a;CANoe 2.测试目的 验证DUT在满足进入等待睡眠状态的条件时是否进入该状态 …

17 内核开发-内核内部内联汇编学习

​ 17 内核开发-内核内部内联汇编学习 课程简介&#xff1a; Linux内核开发入门是一门旨在帮助学习者从最基本的知识开始学习Linux内核开发的入门课程。该课程旨在为对Linux内核开发感兴趣的初学者提供一个扎实的基础&#xff0c;让他们能够理解和参与到Linux内核的开发过程中…

英伟达推出视觉语言模型:VILA

NVIDIA和MIT的研究人员推出了一种新的视觉语言模型(VLM)预训练框架&#xff0c;名为VILA。这个框架旨在通过有效的嵌入对齐和动态神经网络架构&#xff0c;改进语言模型的视觉和文本的学习能力。VILA通过在大规模数据集如Coy0-700m上进行预训练&#xff0c;采用基于LLaVA模型的…

VBA编程之条件语句

上一篇我们讲述了条件语句以及分支。文章的最后用到了逻辑运算符“And“那么今天我们来聊一聊逻辑运算符和Select……Case结构。 在学习前我们先来了解一下&#xff0c;在生活中我们经常说”这个包括那个“&#xff0c;”你或者他“&#xff0c;”不是“等等。而这里”包括“和…

面对对象之封装

Python面向对象之封装 【一】什么是封装&#xff0c;为什么要封装 封装就是指&#xff0c;把数据与功能都整合到一起 就是将某些地方隐藏起来&#xff0c;在程序外部看不到&#xff0c;其他程序无法调用 封装最主要的原因就是为了保护隐私&#xff0c;将不想让用户看到的功能…

esp32+mqtt协议+paltformio+vscode+微信小程序+温湿度检测

花费两天时间完成了这个项目&#xff08;不完全是&#xff0c;属于是在resnet模型训练和温湿度检测两头跑......模型跑不出来&#xff0c;又是第一次从头到尾独立玩硬件&#xff0c;属于是焦头烂额了......&#xff0c;完成这个项目后&#xff0c;我的第一反应是写个csdn&#…

[每日AI·0506]巴菲特谈 AI,李飞飞创业,苹果或将推出 AI 功能,ChatGPT 版搜索引擎

AI 资讯 苹果或将推出 AI 功能&#xff0c;随 iPhone 发布2024 年巴菲特股东大会&#xff0c;巴菲特将 AI 类比为核技术 巴菲特股东大会 5 万字实录消息称 OpenAI 将于 5 月 9 日发布 ChatGPT 版搜索引擎路透社消息&#xff0c;斯坦福大学 AI 领军人物李飞飞打造“空间智能”创…

切比雪夫滤波器

切比雪夫滤波器&#xff0c;也被称为车比雪夫滤波器&#xff0c;是一种在通带或阻带上频率响应幅度等波纹波动的滤波器。它基于切比雪夫多项式的理论&#xff0c;并且是以俄罗斯数学家巴夫尼提列波维其切比雪夫&#xff08;Пафнутий Лвович Чебышёв&#…

论文辅助笔记:Tempo 之 model.py

0 导入库 import math from dataclasses import dataclass, asdictimport torch import torch.nn as nnfrom src.modules.transformer import Block from src.modules.prompt import Prompt from src.modules.utils import (FlattenHead,PoolingHead,RevIN, )1TEMPOConfig 1.…

【C++】 认识多态 + 多态的构成条件详细讲解

前言 C 目录 1. 多态的概念2 多态的定义及实现2 .1 虚函数&#xff1a;2 .2 虚函数的重写&#xff1a;2 .2.1 虚函数重写的两个例外&#xff1a; 2 .3 多态的两个条件&#xff08;重点&#xff09;2 .4 析构函数为啥写成虚函数 3 新增的两个关键字3.1 final的使用&#xff1a;3…

09_电子设计教程基础篇(电阻)

文章目录 前言一、电阻原理二、电阻种类1.固定电阻1、材料工艺1、线绕电阻2、非线绕电阻1、实心电阻1、有机实心电阻2、无机实心电阻 2、薄膜电阻&#xff08;常用&#xff09;1、碳膜电阻2、合成碳膜电阻3、金属膜电阻4、金属氧化膜电阻5、玻璃釉膜电阻 3、厚膜电阻&#xff0…

vue2实现生成二维码和复制保存图片功能(复制的同时会给图片加文字)

<template><divstyle"display: flex;justify-content: center;align-items: center;width: 100vw;height: 100vh;"><div><!-- 生成二维码按钮和输入二维码的输入框 --><input v-model"url" placeholder"输入链接" ty…

智能家居1 -- 实现语音模块

项目整体框架: 监听线程4&#xff1a; 1. 语音监听线程:用于监听语音指令&#xff0c; 当有语音指令过来后&#xff0c; 通过消息队列的方式给消息处理线程发送指令 2. 网络监听线程&#xff1a;用于监听网络指令&#xff0c;当有网络指令过来后&#xff0c; 通过消息队列的方…