Python深度学习基于Tensorflow(3)Tensorflow 构建模型

文章目录

        • 数据导入和数据可视化
        • 数据集制作以及预处理
        • 模型结构
        • 低阶 API 构建模型
        • 中阶 API 构建模型
        • 高阶 API 构建模型
        • 保存和导入模型

这里以实际项目CIFAR-10为例,分别使用低阶,中阶,高阶 API 搭建模型。

这里以CIFAR-10为数据集,CIFAR-10为小型数据集,一共包含10个类别的 RGB 彩色图像:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图像的尺寸为 32×32(像素),3个通道 ,数据集中一共有 50000 张训练圄片和 10000 张测试图像。CIFAR-10数据集有3个版本,这里使用Python版本。

数据导入和数据可视化

这里不用书中给的CIFAR-10数据,直接使用TensorFlow自带的玩意导入数据,可能需要魔法,其实TensorFlow中的数据特别的经典。

![[Pasted image 20240506194103.png]]

接下来导入cifar10数据集并进行可视化展示

import matplotlib.pyplot as plt
import tensorflow as tf(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train.shape, y_train.shape, x_test.shape, y_test.shape
# ((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))index_name = {0:'airplane',1:'automobile',2:'bird',3:'cat',4:'deer',5:'dog',6:'frog',7:'horse',8:'ship',9:'truck'
}def plot_100_img(imgs, labels):fig = plt.figure(figsize=(20,20))for i in range(10):for j in range(10):plt.subplot(10,10,i*10+j+1)plt.imshow(imgs[i*10+j])plt.title(index_name[labels[i*10+j][0]])plt.axis('off')plt.show()plot_100_img(x_test[:100])

![[Pasted image 20240506200312.png]]

数据集制作以及预处理

数据集预处理很简单就能实现,直接一行代码。

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))# 提取出一行数据
# train_data.take(1).get_single_element()

这里接着对数据预处理操作,也很容易就能实现。

def process_data(img, label):img = tf.cast(img, tf.float32) / 255.0return img, labeltrain_data = train_data.map(process_data)# 提取出一行数据
# train_data.take(1).get_single_element()

这里对数据还有一些存储和提取操作

dataset 中 shuffle()、repeat()、batch()、prefetch()等函数的主要功能如下。
1)repeat(count=None) 表示重复此数据集 count 次,实际上,我们看到 repeat 往往是接在 shuffle 后面的。为何要这么做,而不是反过来,先 repeat 再 shuffle 呢? 如果shuffle 在 repeat 之后,epoch 与 epoch 之间的边界就会模糊,出现未遍历完数据,已经计算过的数据又出现的情况。
2)shuffle(buffer_size, seed=None, reshuffle_each_iteration=None) 表示将数据打乱,数值越大,混乱程度越大。为了完全打乱,buffer_size 应等于数据集的数量。
3)batch(batch_size, drop_remainder=False) 表示按照顺序取出 batch_size 大小数据,最后一次输出可能小于batch ,如果程序指定了每次必须输入进批次的大小,那么应将drop_remainder 设置为 True 以防止产生较小的批次,默认为 False。
4)prefetch(buffer_size) 表示使用一个后台线程以及一个buffer来缓存batch,提前为模型的执行程序准备好数据。一般来说,buffer的大小应该至少和每一步训练消耗的batch数量一致,也就是 GPU/TPU 的数量。我们也可以使用AUTOTUNE来设置。创建一个Dataset便可从该数据集中预提取元素,注意:examples.prefetch(2) 表示将预取2个元素(2个示例),而examples.batch(20).prefetch(2) 表示将预取2个元素(2个批次,每个批次有20个示例),buffer_size 表示预提取时将缓冲的最大元素数返回 Dataset。

![[Pasted image 20240506201344.png]]

最后我们对数据进行一些缓存操作

learning_rate = 0.0002
batch_size = 64
training_steps = 40000
display_step = 1000AUTOTUNE = tf.data.experimental.AUTOTUNE
train_data = train_data.map(process_data).shuffle(5000).repeat(training_steps).batch(batch_size).prefetch(buffer_size=AUTOTUNE)

目前数据准备完毕!

模型结构

模型的结构如下,现在使用低阶,中阶,高阶 API 来构建这一个模型

![[Pasted image 20240506202450.png]]

低阶 API 构建模型
import matplotlib.pyplot as plt
import tensorflow as tf## 定义模型
class CustomModel(tf.Module):def __init__(self, name=None):super(CustomModel, self).__init__(name=name)self.w1 = tf.Variable(tf.initializers.RandomNormal()([32*32*3, 256]))self.b1 = tf.Variable(tf.initializers.RandomNormal()([256]))self.w2 = tf.Variable(tf.initializers.RandomNormal()([256, 128]))self.b2 = tf.Variable(tf.initializers.RandomNormal()([128]))self.w3 = tf.Variable(tf.initializers.RandomNormal()([128, 64]))self.b3 = tf.Variable(tf.initializers.RandomNormal()([64]))self.w4 = tf.Variable(tf.initializers.RandomNormal()([64, 10]))self.b4 = tf.Variable(tf.initializers.RandomNormal()([10]))def __call__(self, x):x = tf.cast(x, tf.float32)x = tf.reshape(x, [x.shape[0], -1])x = tf.nn.relu(x @ self.w1 + self.b1)x = tf.nn.relu(x @ self.w2 + self.b2)x = tf.nn.relu(x @ self.w3 + self.b3)x = tf.nn.softmax(x @ self.w4 + self.b4)return x
model = CustomModel()## 定义损失
def compute_loss(y, y_pred):y_pred = tf.clip_by_value(y_pred, 1e-9, 1.)loss = tf.keras.losses.sparse_categorical_crossentropy(y, y_pred)return tf.reduce_mean(loss)## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义准确率
def compute_accuracy(y, y_pred):correct_pred = tf.equal(tf.argmax(y_pred, axis=1), tf.cast(tf.reshape(y, -1), tf.int64))correct_pred = tf.cast(correct_pred, tf.float32)return tf.reduce_mean(correct_pred)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)accuracy = compute_accuracy(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))return loss.numpy(), accuracy.numpy()## 开始训练loss_list, acc_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):loss, acc = train_one_epoch(batch_x, batch_y)loss_list.append(loss)acc_list.append(acc)if i % 10 == 0:print(f'第{i}次训练->', 'loss:' ,loss, 'acc:', acc)
中阶 API 构建模型
## 定义模型
class CustomModel(tf.Module):def __init__(self):super(CustomModel, self).__init__()self.flatten = tf.keras.layers.Flatten()self.dense_1 = tf.keras.layers.Dense(256, activation='relu')self.dense_2 = tf.keras.layers.Dense(128, activation='relu')self.dense_3 = tf.keras.layers.Dense(64, activation='relu')self.dense_4 = tf.keras.layers.Dense(10, activation='softmax')def __call__(self, x):x = self.flatten(x)x = self.dense_1(x)x = self.dense_2(x)x = self.dense_3(x)x = self.dense_4(x)return xmodel = CustomModel()## 定义损失以及准确率
compute_loss = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean()
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()## 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)## 定义一次epoch
def train_one_epoch(x, y):with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)grads = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(grads, model.trainable_variables))train_loss(loss)train_accuracy(y, y_pred)## 开始训练
loss_list, accuracy_list = [], []
for i, (batch_x, batch_y) in enumerate(train_data.take(1000), 1):train_one_epoch(batch_x, batch_y)loss_list.append(train_loss.result())accuracy_list.append(train_accuracy.result())if i % 10 == 0:print(f"第{i}次训练: loss: {train_loss.result()} accuarcy: {train_accuracy.result()}")
高阶 API 构建模型
## 定义模型
model = tf.keras.Sequential([tf.keras.layers.Input(shape=[32,32,3]),tf.keras.layers.Flatten(),tf.keras.layers.Dense(256, activation='relu'),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(10, activation='softmax'),
])## 定义optimizer,loss, accuracy
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),loss = tf.keras.losses.SparseCategoricalCrossentropy(),metrics=['accuracy']
)## 开始训练
model.fit(train_data.take(10000))
保存和导入模型

保存模型

tf.keras.models.save_model(model, 'model_folder')

导入模型

model = tf.keras.models.load_model('model_folder')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/7167.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

9.4.k8s的控制器资源(job控制器,cronjob控制器)

目录 一、job控制器 二、cronjob控制器 一、job控制器 job控制器就是一次性任务的pod控制器,pod完成作业后不会重启,其重启策略是:Never; 简单案例 启动一个pod,执行完成一个事件,然后pod关闭;…

初识指针(2)<C语言>

前言 前文介绍完了一些指针基本概念,下面介绍一下,const关键字、指针的运算、野指针的成因以及避免,assert函数等。 目录 const(常属性) 变量的常属性 指针的常属性 指针的运算 ①指针 -整数 ②指针-指针 ③指针与…

【Osek网络管理测试】[TG3_TC6]等待总线睡眠状态_2

🙋‍♂️ 【Osek网络管理测试】系列💁‍♂️点击跳转 文章目录 1.环境搭建2.测试目的3.测试步骤4.预期结果5.测试结果 1.环境搭建 硬件:VN1630 软件:CANoe 2.测试目的 验证DUT在满足进入等待睡眠状态的条件时是否进入该状态 …

17 内核开发-内核内部内联汇编学习

​ 17 内核开发-内核内部内联汇编学习 课程简介: Linux内核开发入门是一门旨在帮助学习者从最基本的知识开始学习Linux内核开发的入门课程。该课程旨在为对Linux内核开发感兴趣的初学者提供一个扎实的基础,让他们能够理解和参与到Linux内核的开发过程中…

英伟达推出视觉语言模型:VILA

NVIDIA和MIT的研究人员推出了一种新的视觉语言模型(VLM)预训练框架,名为VILA。这个框架旨在通过有效的嵌入对齐和动态神经网络架构,改进语言模型的视觉和文本的学习能力。VILA通过在大规模数据集如Coy0-700m上进行预训练,采用基于LLaVA模型的…

VBA编程之条件语句

上一篇我们讲述了条件语句以及分支。文章的最后用到了逻辑运算符“And“那么今天我们来聊一聊逻辑运算符和Select……Case结构。 在学习前我们先来了解一下,在生活中我们经常说”这个包括那个“,”你或者他“,”不是“等等。而这里”包括“和…

esp32+mqtt协议+paltformio+vscode+微信小程序+温湿度检测

花费两天时间完成了这个项目(不完全是,属于是在resnet模型训练和温湿度检测两头跑......模型跑不出来,又是第一次从头到尾独立玩硬件,属于是焦头烂额了......,完成这个项目后,我的第一反应是写个csdn&#…

[每日AI·0506]巴菲特谈 AI,李飞飞创业,苹果或将推出 AI 功能,ChatGPT 版搜索引擎

AI 资讯 苹果或将推出 AI 功能,随 iPhone 发布2024 年巴菲特股东大会,巴菲特将 AI 类比为核技术 巴菲特股东大会 5 万字实录消息称 OpenAI 将于 5 月 9 日发布 ChatGPT 版搜索引擎路透社消息,斯坦福大学 AI 领军人物李飞飞打造“空间智能”创…

论文辅助笔记:Tempo 之 model.py

0 导入库 import math from dataclasses import dataclass, asdictimport torch import torch.nn as nnfrom src.modules.transformer import Block from src.modules.prompt import Prompt from src.modules.utils import (FlattenHead,PoolingHead,RevIN, )1TEMPOConfig 1.…

【C++】 认识多态 + 多态的构成条件详细讲解

前言 C 目录 1. 多态的概念2 多态的定义及实现2 .1 虚函数:2 .2 虚函数的重写:2 .2.1 虚函数重写的两个例外: 2 .3 多态的两个条件(重点)2 .4 析构函数为啥写成虚函数 3 新增的两个关键字3.1 final的使用:3…

09_电子设计教程基础篇(电阻)

文章目录 前言一、电阻原理二、电阻种类1.固定电阻1、材料工艺1、线绕电阻2、非线绕电阻1、实心电阻1、有机实心电阻2、无机实心电阻 2、薄膜电阻(常用)1、碳膜电阻2、合成碳膜电阻3、金属膜电阻4、金属氧化膜电阻5、玻璃釉膜电阻 3、厚膜电阻&#xff0…

vue2实现生成二维码和复制保存图片功能(复制的同时会给图片加文字)

<template><divstyle"display: flex;justify-content: center;align-items: center;width: 100vw;height: 100vh;"><div><!-- 生成二维码按钮和输入二维码的输入框 --><input v-model"url" placeholder"输入链接" ty…

智能家居1 -- 实现语音模块

项目整体框架: 监听线程4&#xff1a; 1. 语音监听线程:用于监听语音指令&#xff0c; 当有语音指令过来后&#xff0c; 通过消息队列的方式给消息处理线程发送指令 2. 网络监听线程&#xff1a;用于监听网络指令&#xff0c;当有网络指令过来后&#xff0c; 通过消息队列的方…

SpringSecurity6 学习

学习介绍 网上关于SpringSecurity的教程大部分都停留在6以前的版本 但是&#xff0c;SpringSecurity6.x版本后的内容进行大量的整改&#xff0c;网上的教程已经不能够满足 最新的版本使用。这里我查看了很多教程 发现一个宝藏课程&#xff0c;并且博主也出了一个关于SpringSec…

【python】条件语句与循环语句

目录 一.条件语句 1.定义 2.条件语句格式 &#xff08;1&#xff09;if &#xff08;2&#xff09;if-else &#xff08;3&#xff09;elif功能 &#xff08;4&#xff09;if嵌套使用 3.猜拳游戏 二.循环语句 1. while循环 2.while嵌套 3.for循环 4.break和conti…

被问了n遍的小程序地理位置权限开通方法

小程序地理位置接口有什么功能&#xff1f; 在平时我们在开发小程序时&#xff0c;难免会需要用到用户的地理位置信息的功能&#xff0c;小程序开发者开放平台新规要求如果没有申请开通微信小程序地理位置接口( getLocation )&#xff0c;但是在代码中却使用到了相关接口&#…

人工智能概述与入门基础简述

人工智能&#xff08;AI&#xff09;是计算机科学的一个分支&#xff0c;它致力于创建能够执行通常需要人类智能的任务的机器。这篇科普文章将全面介绍人工智能的基本概念、发展历程、主要技术、实际应用以及如何入门这一领域。 一、人工智能的定义与发展历程 人工智能的概念…

springboot版本升级,及解决springsecurity漏洞问题

背景&#xff1a; 项目中要解决 Spring Security RegexRequestMatcher 认证绕过漏洞&#xff08;CVE-2022-22978&#xff09; 漏洞问题&#xff0c;并且需要将项目的版本整体升级到boot版本2.1.7&#xff0c;升级改造过程非常的痛苦&#xff0c;一方面对整个框架的代码不是很熟…

六淳科技IPO终止背后:十分着急上市,大额分红,实控人买豪宅

华西证券被暂停保荐业务资格6个月的影响力逐渐显现。 近日&#xff0c;深圳证券交易所披露的信息显示&#xff0c;东莞六淳智能科技股份有限公司&#xff08;下称“六淳科技”&#xff09;及其保荐人撤回上市申请材料。因此&#xff0c;深圳证券交易所决定终止对其首次公开发行…

LangChain 概念篇(喂饭级)

LangChain 介绍 LangChain 是一个用于开发由语言模型驱动的应用程序的框架。 LangChain 框架的设计目标 支持应用程序让其不仅会通过 API 调用语言模型&#xff0c;而且还会数据感知&#xff08;将语言模型连接到其他数据源&#xff09;&#xff0c;Be agentic&#xff08;允…