第T9周:Tensorflow实现猫狗识别(2)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

具体实现

(一)环境

语言环境:Python 3.10
编 译 器: PyCharm
框 架: Tensorflow 2.10.0

(二)具体步骤
from absl.logging import warning  
import tensorflow as tf  
from tensorflow.python.data import AUTOTUNE  from utils import GPU_ON  
import matplotlib.pyplot as plt  
# 目标:主要学习数据增强的方式方法# 第一步:准备环境  
# 查询tensorflow版本print("Tensorflow Version:", tf.__version__)# print(tf.config.experimental.list_physical_devices('GPU'))# 设置使用GPUgpus = tf.config.list_physical_devices("GPU")print(gpus)if gpus:gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存按需使用tf.config.set_visible_devices([gpu0], "GPU")>)# ##########output#############################################  
# Tensorflow Version: 2.10.0# [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]  
# [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]  
# ##########end output##########################################  
# 支持中文  
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来显示中文标签  
plt.rcParams['axes.unicode_minus'] = False     # 用来正常显示负号  import os, PIL, pathlib  # 隐藏警告  
import warnings  
warnings.filterwarnings('ignore')  # 第二步:导入数据  
data_dir = "./datasets/365-7-data"  
data_dir = pathlib.Path(data_dir)  
image_count = len(list(data_dir.glob('*/*')))  
print("图片总数为:", image_count)  
# ########output##############################################  
# 图片总数为: 3400# ########end output##########################################  # 第三步:数据预处理  
batch_size = 8  
img_height, img_width = 224, 224  
train_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="training",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size,  
)  
# ############output##########################################  
# Found 3400 files belonging to 2 classes.  
# Using 2720 files for training.  
##############end output######################################  val_ds = tf.keras.preprocessing.image_dataset_from_directory(  data_dir,  validation_split=0.2,  subset="validation",  seed=123,  image_size=(img_height, img_width),  batch_size=batch_size,  
)  
# ############output##########################################  
# Found 3400 files belonging to 2 classes.  
# Using 680 files for validation.  
# ###############end output##################################  
# 获取名称标签  
class_names = train_ds.class_names  
print(class_names)  
# #################output######################################  
# ['cat', 'dog']  
###################end output###################################  
# 检查一下数据  
for image_batch, labels_batch in train_ds:  print(image_batch.shape)  print(labels_batch.shape)  break  
# #############output########################################  
# (8, 224, 224, 3)  ---每一批8张图片,长224,宽224,RGB彩色通道(3)  
# (8,) --- 标签就是一批8张图片的标签  
# #############end output###################################  
# 预处理  
AUTOTUNE = tf.data.AUTOTUNE  def preprocess_image(image, label):  return (image / 255.0, label)  # 归一化处理  
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)  
val_ds = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)  # cache() ----将数据集缓存到内存当中 加速运行  
# shuffle() ----打乱数据  
# prefetch() ----预取数据,加速运行  
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)  
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)  # 可视化数据  
plt.figure(figsize=(15, 10))    # 创建一个顶层容器,大小是15*20英寸  
for images, labels in train_ds.take(1):  for i in range(8):  # 向当前图添加坐标轴, 我们想在1行显示8张图片,所以是1行8列  ax = plt.subplot(1, 8, i + 1)   print(images[i])  # imshow()--将数据显示为图像,支持的数据类型(M,N)标量数据/(M,N,3)RGB数据/(M,N,4)RGBA数据。本例中是RGB  plt.imshow(images[i])       plt.title(class_names[labels[i]])   # 显示坐标轴标签  plt.axis('off')     # 隐藏所有的轴信息  plt.show()  

image.png

# 第四步:构建VGG16网络模型  
from tensorflow.keras import layers, models, Input  
from tensorflow.keras.models import Model  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout  def VGG16(nb_classes, input_shape):  input_tensor = Input(shape=input_shape)  # 1st block  x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv1')(input_tensor)  x = Conv2D(64, (3,3), activation='relu', padding='same',name='block1_conv2')(x)  x = MaxPooling2D((2,2), strides=(2,2), name = 'block1_pool')(x)  # 2nd block  x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv1')(x)  x = Conv2D(128, (3,3), activation='relu', padding='same',name='block2_conv2')(x)  x = MaxPooling2D((2,2), strides=(2,2), name = 'block2_pool')(x)  # 3rd block  x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv1')(x)  x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv2')(x)  x = Conv2D(256, (3,3), activation='relu', padding='same',name='block3_conv3')(x)  x = MaxPooling2D((2,2), strides=(2,2), name = 'block3_pool')(x)  # 4th block  x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv1')(x)  x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv2')(x)  x = Conv2D(512, (3,3), activation='relu', padding='same',name='block4_conv3')(x)  x = MaxPooling2D((2,2), strides=(2,2), name = 'block4_pool')(x)  # 5th block  x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv1')(x)  x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv2')(x)  x = Conv2D(512, (3,3), activation='relu', padding='same',name='block5_conv3')(x)  x = MaxPooling2D((2,2), strides=(2,2), name = 'block5_pool')(x)  # full connection  x = Flatten()(x)  x = Dense(4096, activation='relu',  name='fc1')(x)  x = Dense(4096, activation='relu', name='fc2')(x)  output_tensor = Dense(nb_classes, activation='softmax', name='predictions')(x)  model = Model(input_tensor, output_tensor)  return model  model=VGG16(1000, (img_width, img_height, 3))  
model.summary()  

image.png

  
# 第五步:编译  
model.compile(loss='sparse_categorical_crossentropy',  # 损失函数  optimizer='adam',                 # 优化函数  metrics=['accuracy'])             # 模型评估的指标,一般是accuracy  
# 第六步:训练模型  
from tqdm import tqdm  
import tensorflow.keras.backend as K  epochs = 10  
lr = 1e-4  # 记录训练数据,方便后面分析  
history_train_loss = []  
history_val_loss = []  
history_train_accuracy = []  
history_val_accuracy = []  for epoch in range(epochs):  train_total = len(train_ds)  val_total = len(val_ds)  """  total: 预期的迭代数目  ncols: 控制进度条宽度  mininterval: 进度条更新最小间隔,以秒为单位(默认为0.1)  """    with tqdm(total=train_total,  desc=f'Epoch {epoch + 1}/{epochs}',  mininterval=1,  ncols=100) as pbar:  lr = lr * 0.92  K.set_value(model.optimizer.lr, lr)  for image, label in train_ds:  history = model.train_on_batch(image, label)  train_loss = history[0]  train_accuracy = history[1]  pbar.set_postfix({  "loss": "%.4f" % train_loss,  "accuracy": "%.4f" % train_accuracy,  "lr": K.get_value(model.optimizer.lr),  })  pbar.update(1)  history_train_loss.append(train_loss)  history_train_accuracy.append(train_accuracy)  print('开始验证!')  with tqdm(total=val_total,  desc=f'Epoch {epoch + 1}/{epochs}',  mininterval=0.3,  ncols=100) as pbar:  for image, label in val_ds:  history = model.test_on_batch(image, label)  val_loss = history[0]  val_accuracy = history[1]  pbar.set_postfix({  "loss": "%.4f" % val_loss,  "accuracy": "%.4f" % val_accuracy  })  pbar.update(1)  history_val_loss.append(val_loss)  history_val_accuracy.append(val_accuracy)  print('结束验证!')  print('验证loss为:%.4f'%val_loss)  print('验证准确率为:%.4f'%val_accuracy)  

image.png

# 第七步:评估模型  [# 采用加载的模型(new_model)来看预测结果  
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5  
plt.suptitle("预测结果展示")  for images, labels in val_ds.take(1):  for i in range(8):  ax = plt.subplot(1, 8, i + 1)  # 显示图片  plt.imshow(images[i].numpy())  # 需要给图片增加一个维度  img_array = tf.expand_dims(images[i], 0)  # 使用模型预测图片中的人物  predictions = model.predict(img_array)  plt.title(class_names[np.argmax(predictions)])  plt.axis("off")  
plt.show()](<epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)plt.plot(epochs_range, history_train_accuracy, label='Training Accuracy')
plt.plot(epochs_range, history_val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, history_train_loss, label='Training Loss')
plt.plot(epochs_range, history_val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()>) 

  
# 第八步:预测  
import numpy as np  
# 采用加载的模型(new_model)来看预测结果  
plt.figure(figsize=(18, 3))  # 图形的宽为18高为5  
plt.suptitle("预测结果展示")  for images, labels in val_ds.take(1):  for i in range(8):  ax = plt.subplot(1, 8, i + 1)  # 显示图片  plt.imshow(images[i].numpy())  # 需要给图片增加一个维度  img_array = tf.expand_dims(images[i], 0)  # 使用模型预测图片中的人物  predictions = model.predict(img_array)  plt.title(class_names[np.argmax(predictions)])  plt.axis("off")
plt.show()

image.png
image.png

总结
1. 我的GPU是nvidia RTX 4060 laptop,当batch设定为64时,报错:W tensorflow/core/common_runtime/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available. 但字面理解为显存不足. batch改为32则无问题。那么,对于 batch的大小,如何来界定?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62255.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式锁的实现原理

作者&#xff1a;来自 vivo 互联网服务器团队- Xu Yaoming 介绍分布式锁的实现原理。 一、分布式锁概述 分布式锁&#xff0c;顾名思义&#xff0c;就是在分布式环境下使用的锁。众所周知&#xff0c;在并发编程中&#xff0c;我们经常需要借助并发控制工具&#xff0c;如 mu…

搭建帮助中心到底有什么作用?

在当今快节奏的商业环境中&#xff0c;企业面临着日益增长的客户需求和竞争压力。搭建一个有效的帮助中心对于企业来说&#xff0c;不仅是提升客户服务体验的重要途径&#xff0c;也是优化内部知识管理和提升团队效率的关键。以下是帮助中心在企业运营中的几个关键作用&#xf…

深入浅出剖析典型文生图产品Midjourney

2022年7月,一个小团队推出了公测的 Midjourney,打破了 AIGC 领域的大厂垄断。作为一个精调生成模型,以聊天机器人方式部署在 Discord,它创作的《太空歌剧院》作品,甚至获得了美国「数字艺术/数码摄影」竞赛单元一等奖。 这一事件展示了 AI 在绘画领域惊人的创造力,让人们…

python+docx:(二)页眉页脚、表格操作

目录 页眉页脚 表格 表格样式 插入表格 插入行/列 合并单元格 单元格 页眉页脚 页眉页脚操作需要访问文件的section&#xff0c;可通过添加页脚来添加页码。 from docx import Document from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_ALIGN_PARAGRAPH, WD_CO…

Matlab Simulink 电力电子仿真-单相电压型半桥逆变电路分析

目录 一、单相电压型半桥逆变电路仿真模型 1.电路模型 2.电路模型参数 二、仿真分析 三、总结 1.优缺点 2.应用场景 一、单相电压型半桥逆变电路仿真模型 1.电路模型 单相电压型半桥逆变电路是一种常见的逆变电路&#xff0c;主要用于将直流电源转换为交流电源。 &…

Qt 编程专栏目录

Qt 编程专栏简介 Qt为开发者提供了一个强大的跨平台开发工具。无论你是刚刚接触Qt&#xff0c;还是已经在使用它构建复杂应用的开发者&#xff0c;这里都能为你提供有用的知识和实战技巧。 在这个专栏中&#xff0c;我们不仅讲解Qt的使用方法&#xff0c;还会结合实际开发场景…

C++入门——“C++11-lambda”

引入 C11支持lambda表达式&#xff0c;lambda是一个匿名函数对象&#xff0c;它允许在函数体中直接定义。 一、初识lambda lambda的结构是&#xff1a;[ ] () -> 返回值类型 { }。从左到右依次是&#xff1a;捕捉列表 函数参数 -> 返回值类型 函数体。 以下是一段用lam…

Day 2:Java 集合框架(List 和 Map)

目标&#xff1a;掌握日常工作中常用集合的基本操作。 理论知识&#xff1a; List&#xff1a; ArrayList 和 LinkedList 的区别。 特性ArrayListLinkedList底层实现基于动态数组实现&#xff0c;元素安索引存储基于双向链表实现&#xff0c;元素节点彼此连接访问速度随机访…

如何保护LabVIEW程序免遭反编译

在正常情况下&#xff0c;LabVIEW程序&#xff08;即编译后的可执行文件或运行时文件&#xff0c;如 .exe 或 .llb&#xff09;无法直接被反编译出源码。然而&#xff0c;有一些需要特别注意的点&#xff1a; 1. LabVIEW的编译机制 LabVIEW编译器会将源码&#xff08;.vi文件&a…

提升76%的关键-在ModelMapper中实现性能提升的几种方法

目录 前言 一、ModelMapper基础知识 1、深入ModelMapper 2、深入Configuration配置 3、深入MappingEngineImpl 二、默认加载模式 1、基础测试代码 三、持续优化&#xff0c;慢慢提升 1、增加忽略字段 2、设置忽略空值模式 3、设置命名模式 4、采用精准匹配模式 四、…

【C语言】结构体、联合体、枚举类型的字节大小详解

在C语言中&#xff0c;结构体&#xff08;struct&#xff09;和联合体&#xff08;union&#xff09; 是常用的复合数据类型&#xff0c;它们的内存布局和字节大小直接影响程序的性能和内存使用。下面为大家详细解释它们的字节大小计算方法&#xff0c;包括对齐规则、内存分配方…

【优选算法】位运算

目录 常见位运算总结1、基础位运算2、给一个数n&#xff0c;确定它的二进制位的第x位上是0还是13、将一个数n的二进制位的第x位改成14、将一个数n的二进制位的第x位改成05、位图的思想6、提取一个数n的二进制位中最右侧的17、将一个数n的二进制位中最右侧的1变为08、位运算的优…

jQuery九宫格抽奖,php处理抽奖信息

功能介绍 jQuery九宫格抽奖是一种基于jQuery库的前端抽奖效果。通过九宫格的形式展示抽奖项&#xff0c;用户点击抽奖按钮后&#xff0c;九宫格开始旋转&#xff0c;最终停在一个随机位置上&#xff0c;此位置对应的抽奖项为用户的中奖结果。 本文实现九宫格的步骤为&#xf…

AI界的信仰危机:单靠“规模化”智能增长的假设,正在面临挑战

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

Unity类银河战士恶魔城学习总结(P149 Screen Fade淡入淡出菜单)

【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili 教程源地址&#xff1a;https://www.udemy.com/course/2d-rpg-alexdev/ 本章节实现了进入游戏和死亡之后的淡入淡出动画效果 UI_FadeScreen.cs 1. Animator 组件的引用 (anim) 该脚本通过 Animator 控制 UI 元…

【C语言篇】探索 C 语言结构体:从基础语法到数据组织的初体验

我的个人主页 我的专栏&#xff1a;C语言&#xff0c;希望能帮助到大家&#xff01;&#xff01;&#xff01;点赞❤ 收藏❤ 目录 什么是结构体结构体的定义与使用结构体内存布局嵌套结构体与指针结构体数组的操作结构体与函数结构体内存对齐机制位域与结构体的结合动态内存分…

COMSOL工作站:配置指南与性能优化

COMSOL Multiphysics 求解的问题类型相当广泛&#xff0c;提供了仿真单一物理场以及灵活耦合多个物理场的功能&#xff0c;供工程师和科研人员来精确分析各个工程领域的设备、工艺和流程。 软件内置的#模型开发器#包含完整的建模工作流程&#xff0c;可实现从几何建模、材料参数…

全面解析LLM业务落地:RAG技术的创新应用、ReAct的智能化实践及基于业务场景的评估框架设计

1. 如何让LLM更好的业务落地常见方法 等待新的大型模型版本:但是,每个新版本也会有时间限制。 自己训练模型:这种方法成本高昂且耗时,需要大量基础设施。它也只是一个临时解决方案。 LoRA(低秩自适应)微调:这种方法更简单、更便宜,可以更频繁地进行,但不能在线进行。模…

大语言模型LLM的微调代码详解

代码的摘要说明 一、整体功能概述 这段 Python 代码主要实现了基于 Hugging Face Transformers 库对预训练语言模型&#xff08;具体为 TAIDE-LX-7B-Chat 模型&#xff09;进行微调&#xff08;Fine-tuning&#xff09;的功能&#xff0c;使其能更好地应用于生成唐诗相关内容的…

js中判断数组和判断对象的方法

判断数组 Array.isArray() 方法 这是最推荐的方法&#xff0c;简单明了。它可以检测数组的情况&#xff0c;并且不会误报其他类型。 const arr [1, 2, 3]; console.log(Array.isArray(arr)); // trueconst notArray { key: value }; console.log(Array.isArray(notArray))…