【Pytorch和Keras】使用transformer库进行图像分类

目录

    • 一、环境准备
    • 二、基于Pytorch的预训练模型
      • 1、准备数据集
      • 2、加载预训练模型
      • 3、 使用pytorch进行模型构建
    • 三、基于keras的预训练模型
    • 四、模型测试
    • 五、参考

 现在大多数的模型都会上传到huggface平台进行统一的管理,transformer库能关联到huggface中对应的模型,并且提供简洁的transformer模型调用,这大大提高了开发人员的开发效率。本博客主要利用transformer库实现一个简单的模型微调,以进行图像分类的任务。


一、环境准备

 使用终端命令行安装对应的第三方包,具体安装命令输入如下:

pip install transformers datasets evaluate

二、基于Pytorch的预训练模型

  由于下面这些内容需要在huggface上申请账号权限,才能进行模型和数据集加载,如果之前有从huggface上拉取模型和数据集的经验,可以略过,如果没有配置过,可以参考笔者之前的文章https://blog.csdn.net/qq_40734883/article/details/143922095,然后直接申请Write权限就可以。

在这里插入图片描述
  后续所有涉及到的数据集food101和transformer模型都需要参考上述文章进行直接下载,才能运行整个程序,或者在google的colab直接运行。

  如果在google的colab上运行,请提前设置好电脑的GPU资源,同时加入huggface登录代码,具体如下:

from huggingface_hub import notebook_login
notebook_login()

  运行之后会提示进行token输入,按之前获取到的token输入即可。

1、准备数据集

  这里以food101数据集作为微调数据集,在imagenet-21k上训练完成的transformer模型(vit-base-patch16-224)进行优化

from datasets import load_datasetfood = load_dataset("food101", split="train[:5000]")# 划分数据集,训练集:测试集=8:2,food有两个键:一个train,一个test
food = food.train_test_split(test_size=0.2)  # 标签转换
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):label2id[label] = str(i)id2label[str(i)] = label

id2label为通过id访问标签的字典,后续会使用到。

2、加载预训练模型

from transformers import AutoImageProcessorcheckpoint = "google/vit-base-patch16-224-in21k"   # ImageNet-21k上的预训练模型
image_processor = AutoImageProcessor.from_pretrained(checkpoint)  # 从huggface拉取并加载模型

3、 使用pytorch进行模型构建

from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor# 数据预处理操作定义
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (image_processor.size["shortest_edge"]if "shortest_edge" in image_processor.sizeelse (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])# 对原始数据进行RGB及字典化
def transforms(examples):examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]del examples["image"]return examplesfood = food.with_transform(transforms)
# 验证
import evaluate
# 指定验证过程中的评价指标-准确率
accuracy = evaluate.load("accuracy")import numpy as np
def compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return accuracy.compute(predictions=predictions, references=labels)

  训练设置和运行,具体输入代码如下:

# 整合训练中的数据,以便在模型训练或评估过程中使用
from transformers import DefaultDataCollator
data_collator = DefaultDataCollator()from transformers import AutoModelForImageClassification, TrainingArguments, Trainer# 初始化模型
model = AutoModelForImageClassification.from_pretrained(checkpoint,num_labels=len(labels),id2label=id2label,label2id=label2id,
)# 设置模型优化参数
training_args = TrainingArguments(output_dir="my_awesome_food_model",remove_unused_columns=False,evaluation_strategy="epoch",save_strategy="epoch",learning_rate=5e-5,per_device_train_batch_size=16,gradient_accumulation_steps=4,per_device_eval_batch_size=16,num_train_epochs=3,warmup_ratio=0.1,logging_steps=10,load_best_model_at_end=True,metric_for_best_model="accuracy",push_to_hub=True,
)# 初始化训练实例
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=food["train"],eval_dataset=food["test"],tokenizer=image_processor,compute_metrics=compute_metrics,
)trainer.train()  # 开始训练trainer.push_to_hub()  # 推送到huggfacehub

  经过上述设置训练完成之后,会将模型微调结果推送到huggface平台,如果不想推送,可以不运行相关的命令行,并且training_args中的push_to_hub=False

  训练结果如下图所示:

在这里插入图片描述
  默认需要选择是否关联wandb,如果不想选择,直接根据设置提示跳过即可。

  如果选择了推送到huggfacehub(trainer.push_to_hub() )的话,在个人的huggface上会有一个名为my_awesome_food_model的模型,里面包含了模型训练的各个参数设置和测试结果。

在这里插入图片描述


三、基于keras的预训练模型

  使用transflow的keras API 进行模型的搭建,具体代码如下:

from transformers import create_optimizer# 超参数设置
batch_size = 16
num_epochs = 5
num_train_steps = len(food["train"]) * num_epochs
learning_rate = 3e-5
weight_decay_rate = 0.01# 定义优化方式和策略
optimizer, lr_schedule = create_optimizer(init_lr=learning_rate, num_train_steps=num_train_steps, weight_decay_rate=weight_decay_rate, num_warmup_steps=0)# 定义分类器
from transformers import TFAutoModelForImageClassification
model = TFAutoModelForImageClassification.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)# converting our train dataset to tf.data.Dataset
tf_train_dataset = food["train"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=True, batch_size=batch_size, collate_fn=data_collator)# converting our test dataset to tf.data.Dataset
tf_eval_dataset = food["test"].to_tf_dataset(columns="pixel_values", label_cols="label", shuffle=False, batch_size=batch_size, collate_fn=data_collator)# 定义损失函数
from tensorflow.keras.losses import SparseCategoricalCrossentropy
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)model.compile(optimizer=optimizer, loss=loss)from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback
# 定义验证指标
metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=tf_eval_dataset)
# 推送到huggface回调函数
push_to_hub_callback = PushToHubCallback(output_dir="food_classifier", tokenizer=image_processor, save_strategy="no")
callbacks = [metric_callback, push_to_hub_callback]# 开始训练
model.fit(tf_train_dataset, validation_data=tf_eval_dataset, epochs=num_epochs, callbacks=callbacks)

四、模型测试

 这里使用微调好的模型在food101上找一张验证图像进行简单的验证测试,具体代码如下:

# 验证food中验证集的某一张图像
ds = load_dataset("food101", split="validation[-5:-1]")
image = ds["image"][-1]# visualize image
import matplotlib.pyplot as plt
plt.imshow(image)
plt.axis('off')  
plt.show()

  测试图像如下所示:
在这里插入图片描述


from transformers import pipeline
# initialize classifier instance
classifier = pipeline("image-classification", model="my_awesome_food_model")
classifier(image)from transformers import AutoImageProcessor
import torch
# load pre-trained image processor
image_processor = AutoImageProcessor.from_pretrained("my_awesome_food_model")
inputs = image_processor(image, return_tensors="pt")from transformers import AutoModelForImageClassification
# laod pre-trained model
model = AutoModelForImageClassification.from_pretrained("my_awesome_food_model")
with torch.no_grad():logits = model(**inputs).logits# 输出测试结果
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])

  输出结果如下所示:

Device set to use cuda:0
[{'label': 'ramen', 'score': 0.9517934918403625},{'label': 'bruschetta', 'score': 0.7566707730293274},{'label': 'hamburger', 'score': 0.7004948854446411},{'label': 'chicken_wings', 'score': 0.6275856494903564},{'label': 'prime_rib', 'score': 0.5991673469543457}]

  预测结果为:ramen

  释义:“ramen”一词源于日语“ラーメン”,是“拉面”的意思。它进一步追溯至汉语“拉面”,是一种起源于中国、流行于日本及其他东亚地区的面条食品。在日本,拉面通常由小麦面粉制成的面条,搭配肉汤和各种配料,如叉烧、鸡蛋、蔬菜等。

五、参考

[1] https://huggingface.co/docs/transformers/main/tasks/image_classification

[2] https://github.com/huggingface/transformers/blob/main/docs/source/en/installation.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68940.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

relational DB与NoSQL DB有什么区别?该如何选型?

Relational Database(关系型数据库,简称RDB)与NoSQL Database(非关系型数据库)是两类常见的数据库类型。它们在设计理念、数据存储方式、性能优化、扩展性等方面有许多差异。下面我们将会详细分析它们的区别,以及如何根据应用场景进行选型。 一、数据模型的区别 关系型…

Flutter常用Widget小部件

小部件Widget是一个类,按照继承方式,分为无状态的StatelessWidget和有状态的StatefulWidget。 这里先创建一个简单的无状态的Text小部件。 Text文本Widget 文件:lib/app/app.dart。 import package:flutter/material.dart;class App exte…

智能小区物业管理系统推动数字化转型与提升用户居住体验

内容概要 在当今快速发展的社会中,智能小区物业管理系统的出现正在改变传统的物业管理方式。这种系统不仅仅是一种工具,更是一种推动数字化转型的重要力量。它通过高效的技术手段,将物业管理与用户居住体验紧密结合,无疑为社区带…

给AI加知识库

1、加载 Document Loader文档加载器 在 langchain_community. document_loaders 里有很多种文档加载器 from langchain_community. document_loaders import *** 1、纯文本加载器:TextLoader,纯文本(不包含任何粗体、下划线、字号格式&am…

游戏引擎 Unity - Unity 设置为简体中文、Unity 创建项目

Unity Unity 首次发布于 2005 年,属于 Unity Technologies Unity 使用的开发技术有:C# Unity 的适用平台:PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域:开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

小红的小球染色期望

B-小红的小球染色_牛客周赛 Round 79 题目描述 本题与《F.R小红的小球染色期望》共享题目背景,但是所求内容与范围均不同,我们建议您重新阅读题面。 有 n 个白色小球排成一排。小红每次将随机选择两个相邻的白色小球,将它们染成红色。小红…

ASP.NET Core与配置系统的集成

目录 配置系统 默认添加的配置提供者 加载命令行中的配置。 运行环境 读取方法 User Secrets 注意事项 Zack.AnyDBConfigProvider 案例 配置系统 默认添加的配置提供者 加载现有的IConfiguration。加载项目根目录下的appsettings.json。加载项目根目录下的appsettin…

Redis集群理解以及Tendis的优化

主从模式 主从同步 同步过程: 全量同步(第一次连接):RDB文件加缓冲区,主节点fork子进程,保存RDB,发送RDB到从节点磁盘,从节点清空数据,从节点加载RDB到内存增量同步&am…

沙皮狗为什么禁养?

各位铲屎官们,今天咱们来聊聊一个比较敏感的话题:沙皮狗为什么会被禁养?很多人对沙皮狗情有独钟,但有些地方却明确禁止饲养这种犬种,这背后到底是什么原因呢?别急,今天就来给大家好好揭秘&#…

物联网 STM32【源代码形式-ESP8266透传】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】

一、MQTT介绍 MQTT(Message Queuing Telemetry Transport,消息队列遥测传输协议)是一种基于发布/订阅模式的轻量级通讯协议,构建于TCP/IP协议之上。它最初由IBM在1999年发布,主要用于在硬件性能受限和网络状况不佳的情…

w186格障碍诊断系统spring boot设计与实现

🙊作者简介:多年一线开发工作经验,原创团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹赠送计算机毕业设计600个选题excel文…

题解 洛谷 Luogu P1955 [NOI2015] 程序自动分析 并查集 离散化 哈希表 C++

题目 传送门 P1955 [NOI2015] 程序自动分析 - 洛谷 | 计算机科学教育新生态https://www.luogu.com.cn/problem/P1955 思路 主要用到的知识是并查集 (如何实现并查集,这里不赘述了) 若 xi xj,则合并它们所在的集合。若 xi ! xj,则 i 和 …

无用知识之:std::initializer_list的秘密

先说结论,用std::initializer_list初始化vector,内部逻辑是先生成了一个临时数组,进行了拷贝构造,然后用这个数组的起终指针初始化initializer_list。然后再用initializer_list对vector进行初始化,这个动作又触发了拷贝…

97,【5】buuctf web [极客大挑战 2020]Greatphp

进入靶场 审代码 <?php // 关闭所有 PHP 错误报告&#xff0c;防止错误信息泄露可能的安全隐患 error_reporting(0);// 定义一个名为 SYCLOVER 的类 class SYCLOVER {// 定义类的公共属性 $sycpublic $syc;// 定义类的公共属性 $loverpublic $lover;// 定义魔术方法 __wa…

蓝桥杯单片机第七届省赛

前言 这套题不难&#xff0c;相对于第六套题这一套比较简单了&#xff0c;但是还是有些小细节要抓 题目 OK&#xff0c;以上就是全部的题目了&#xff0c;这套题目相对来说逻辑比较简单&#xff0c;四个按键&#xff0c;S4控制pwm占空比&#xff0c;S5控制计时时间&#xff0…

【C语言】自定义类型讲解

文章目录 一、前言二、结构体2.1 概念2.2 定义2.2.1 通常情况下的定义2.2.2 匿名结构体 2.3 结构体的自引用和嵌套2.4 结构体变量的定义与初始化2.5 结构体的内存对齐2.6 结构体传参2.7 结构体实现位段 三、枚举3.1 概念3.2 定义3.3 枚举的优点3.3.1 提高代码的可读性3.3.2 防止…

我问了DeepSeek和ChatGPT关于vue中包含几种watch的问题,它们是这么回答的……

前言&#xff1a;听说最近DeepSeek很火&#xff0c;带着好奇来问了关于Vue的一个问题&#xff0c;看能从什么角度思考&#xff0c;如果回答的不对&#xff0c;能不能尝试纠正&#xff0c;并帮我整理出一篇不错的文章。 第一次回答的原文如下&#xff1a; 在 Vue 中&#xff0c;…

纯后训练做出benchmark超过DeepseekV3的模型?

论文地址 https://arxiv.org/pdf/2411.15124 模型是AI2的&#xff0c;他们家也是玩开源的 先看benchmark&#xff0c;几乎是纯用llama3 405B后训练去硬刚出一个gpt4o等级的LLamA405 我们先看之前的机遇Lllama3.1 405B进行全量微调的模型 Hermes 3&#xff0c;看着还没缘模型…

UbuntuWindows双系统安装

做系统盘&#xff1a; Ubuntu20.04双系统安装详解&#xff08;内容详细&#xff0c;一文通关&#xff01;&#xff09;_ubuntu 20.04-CSDN博客 ubuntu系统调整大小&#xff1a; 调整指南&#xff1a; 虚拟机中的Ubuntu扩容及重新分区方法_ubuntu重新分配磁盘空间-CSDN博客 …

在 Zemax 中使用布尔对象创建光学光圈

在 Zemax 中&#xff0c;布尔对象用于通过组合或减去较简单的几何形状来创建复杂形状。布尔运算涉及使用集合运算&#xff08;如并集、交集和减集&#xff09;来组合或修改对象的几何形状。这允许用户在其设计中为光学元件或机械部件创建更复杂和定制的形状。 本视频中&#xf…