[pai-diffusion]pai的easynlp的clip模型训练

EasyNLP带你玩转CLIP图文检索 - 知乎作者:熊兮、章捷、岑鸣、临在导读随着自媒体的不断发展,多种模态数据例如图像、文本、语音、视频等不断增长,创造了互联网上丰富多彩的世界。为了准确建模用户的多模态内容,跨模态检索是跨模态理解的重要任务,…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/528476134

initialize_easynlp()->train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path("alibaba-pai/clip_chinese_roberta_base_vit_base"),data_file="MUGE_MR_train_base64_part.tsv",max_seq_length=32,input_schema="text:str:1,image:str:1",first_sequence="text",second_sequence="image",is_training=True)
valid_dataset = CLIPDataset()model = get_application_model(app_name='clip',...)
- easynlp.appzoo.api.ModelMapping->CLIPApp
- easynlp.appzoo.clip.model.py->CLIPApp
- CHINESE_CLIP->
- self.visual = VisualTransformer()
- self.bert = BertModel()trainer = Trainer(model,train_dataset,user_defined_parameters,  evaluator=get_application_evaluator(app_name="clip",valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))trainer.train()
- for _epoch in range(self._first_epoch,int(args.epoch_num)):for _step,batch in enumerate(self._train_loader):    label_ids = batch.pop()forward_outputs = self._model(batch)loss_dict = self.model_module.compute_loss(forward_outputs,label_ids)_loss = loss_dict('loss')_loss.backward()model = get_application_model_evaluation()
evaluator = get_application_evaluator()
evaluator.evaluate(model)

数据处理:

import os
import base64
import multiprocessing
from tqdm import tqdmdef process_image(image_path):# 从图片路径中提取中文描述image_name = os.path.basename(image_path)description = os.path.splitext(image_name)[0]# 将图片转换为 Base64 编码with open(image_path, 'rb') as f:image_data = f.read()base64_data = base64.b64encode(image_data).decode('utf-8')return description, base64_datadef generate_tsv(directory):image_paths = [os.path.join(directory, filename) for filename in os.listdir(directory) iffilename.endswith(('.jpg', '.png'))]with multiprocessing.Pool() as pool, tqdm(total=len(image_paths), desc='Processing Images') as pbar:results = []for result in pool.imap_unordered(process_image, image_paths):results.append(result)pbar.update(1)with open('/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train.tsv','w', encoding='utf-8') as f:for description, base64_data in results:line = f"{description}\t{base64_data}\n"f.write(line)if __name__ == '__main__':target_directory = "/home/image_team/image_team_docker_home/lgd/e_commerce_sd/data/vcg_furnitures_text_image/vcg_furnitures_train/img_download/"# import pdb;pdb.set_trace()generate_tsv(target_directory)

训练代码:

import torch.cuda
from easynlp.appzoo import CLIPDataset
from easynlp.appzoo import get_application_predictor, get_application_model, get_application_evaluator, \get_application_model_for_evaluation
from easynlp.core import Trainer, PredictorManager
from easynlp.utils import initialize_easynlp, get_args, get_pretrain_model_path
from easynlp.utils.global_vars import parse_user_defined_parametersdef main():# /root/.easynlp/modelzoo中train_dataset = CLIPDataset(pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),data_file=args.tables.split(",")[0],max_seq_length=args.sequence_length,input_schema=args.input_schema,first_sequence=args.first_sequence,second_sequence=args.second_sequence,is_training=True)valid_dataset = CLIPDataset(# 预训练模型名称路径,这里我们使用封装好的get_pretrain_model_path函数,来处理模型名称"alibaba-pai/clip_chinese_roberta_base_vit_base"以得到其路径,并自动下载模型pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),data_file=args.tables.split(",")[-1],# "data/pai/MUGE_MR_valid_base64_part.tsv"max_seq_length=args.sequence_length,  # 文本最大长度,超过将截断,不足将paddinginput_schema=args.input_schema,  # 输入tsv数据的格式,逗号分隔的每一项对应数据文件中每行以\t分隔的一项,每项开头为其字段标识,如label、sent1等first_sequence=args.first_sequence,  # 用于说明input_schema中哪些字段作为第一/第二列输入数据second_sequence=args.second_sequence,is_training=False)  # 是否为训练过程,train_dataset为True,valid_dataset为Falsemodel = get_application_model(app_name=args.app_name,  # 任务名称,这里选择文本分类"clip"pretrained_model_name_or_path=get_pretrain_model_path(args.pretrained_model_name_or_path),user_defined_parameters=user_defined_parameters# user_defined_parameters:用户自定义参数,直接填入刚刚处理好的自定义参数user_defined_parameters)trainer = Trainer(model=model,train_dataset=train_dataset,user_defined_parameters=user_defined_parameters,evaluator=get_application_evaluator(app_name=args.app_name,valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32))trainer.train()# 模型评估model = get_application_model_for_evaluation(app_name=args.app_name,pretrained_model_name_or_path=args.checkpoint_dir,user_defined_parameters=user_defined_parameters)evaluator = get_application_evaluator(app_name=args.app_name,valid_dataset=valid_dataset,user_defined_parameters=user_defined_parameters,eval_batch_size=32)model.to(torch.cuda.current_device())evaluator.evaluate(model=model)# 模型预测if test:predictor = get_application_predictor(app_name="clip",model_dir="./outputs/clip_model/",first_sequence="text",second_sequence="image",sequence_length=32,user_defined_parameters=user_defined_parameters)predictor_manager = PredictorManager(predictor=predictor,input_file="data/vcg_furnitures_text_image/vcg_furnitures_test.tsv",input_schema="text:str:1",output_file="text_feat.tsv",output_schema="text_feat",append_cols="text",batch_size=2)predictor_manager.run()if __name__ == "__main__":initialize_easynlp()args = get_args()user_defined_parameters = parse_user_defined_parameters('pretrain_model_name_or_path=alibaba-pai/clip_chinese_roberta_base_vit_base')args.checkpoint_dir = "./outputs/clip_model/"args.pretrained_model_name_or_path = "alibaba-pai/clip_chinese_roberta_base_vit_base"# args.n_gpu = 3# args.worker_gpu = "1,2,3"args.app_name = "clip"args.tables = "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"# "data/vcg_furnitures_text_image/vcg_furnitures_train.tsv," \#               "data/vcg_furnitures_text_image/vcg_furnitures_test.tsv"# "data/pai/MUGE_MR_train_base64_part.tsv,data/pai/MUGE_MR_valid_base64_part.tsv"args.input_schema = "text:str:1,image:str:1"args.first_sequence = "text"args.second_sequence = "image"args.learning_rate = 1e-4args.epoch_num = 1000args.random_seed = 42args.save_checkpoint_steps = 200args.sequence_length = 32# args.train_batch_size = 2args.micro_batch_size = 32test = Falsemain()# python -m torch.distributed.launch --nproc_per_node 4 tools/train_pai_chinese_clip.py

说一点自己的想法,在我自己工作之初,我很喜欢去拆解一些框架,例如openmm系列,但其实大部分在训练过程上都是相似的,大可不必,在改动上,也没有必要对其进行流程上的大改动,兼具百家之长,了解整体pipeline,更加专注在pipeline实现和效果导向型的结果提交更加有效。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/84226.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ctfshow web入门 代码审计篇 web301-web310 详细题解 全

CTFshow 代码审计 web301 下载的附件的目录结构如下: 开题后界面,看见输入框,感觉是sql。 大概浏览一遍源码,我们可以发现在checklogin.php文件中有无过滤的SQL语句,SQL注入没得跑了。 这题SQL注入有三种做法。 方法一…

Linux:GlusterFS 集群

GlusterFS介绍 1)Glusterfs是一个开源的分布式文件系统,是Scale存储的核心,能够处理千数量级的客户端.在传统的解决 方案中Glusterfs能够灵活的结合物理的,虚拟的和云资源去体现高可用和企业级的性能存储. 2)Glusterfs通过TCP/IP或InfiniBand RDMA网络链…

Everything + Cpolar,打造在线搜索的终极神器

文章目录 前言1. 下载安装注册cpolar2. Everything安装和设置2.1 进入Everything官网进行下载2.2 对Everything文件进行设定 3. 创建cpolar内网穿透隧道4. 公网访问测试Everything5. 固定连接公网地址 前言 你还在用Windows资源管理器自带的搜索工具来搜索文件吗?这…

轮换对称性

二重积分 普通对称性–D关于 y x yx yx对称: ∬ D f ( x , y ) d σ { 2 ∬ D 1 f ( x , y ) d σ f ( x , y ) f ( y , x ) 0 f ( x , y ) − f ( y , x ) \iint_{D}f(x,y)d\sigma\begin{cases} 2\iint_{D_1}f(x,y)d\sigma\ \ \ \ \ \ f(x,y)f(y,x) \\ 0 \ \…

OpenStack创建云主机并连接CRT

文章目录 OpenStackT版创建云主机并连接CRT命令行操作(1)创建镜像(2)创建实例(3)创建网络创建内网创建外网 (4)创建安全组(5)创建路由(6&#xff…

1952-2018年中国各省份人均GDP数据(消涨处理)

1952-2018年中国各省份人均GDP数据(消涨处理) 1、时间:1952-2018年 2、范围:30省市 3、指标:人均GDP 4、来源:《新中国60周年统计汇编》和各省年鉴 5、指标解释: 过程为环比人均GDP指数转…

YSA Toon (Anime/Toon Shader)

这是一个Toon着色器/Cel阴影着色器,用于Unity URP 此着色器的目的是使角色或物体阴影实时看起来尽可能接近真实的动画或卡通效果 可以用于游戏,渲染,插图等 着色器特性,如:面的法线平滑、轮廓修复、先进的边缘照明、镜面照明、完全平滑控制 这个文档包括所有的功能https:/…

Eclipse ABAP ADT 集成详细安装教程

最近看到网上有个源码使用CDS做的,然后看了一下原来还可以用eclipse,趁热打铁,试了一把,最后成功了,中间可能会有一些报错,可以自己慢慢解决,大概就是这样的。 SAP的开发,有三种开发…

Java————List

一 、顺序表和链表 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构, 常见的线性表:顺序表、链表、栈、队列… 线性表在逻辑上是线性结构,也就说是连续的一条直…

微信小程序与idea后端如何进行数据交互

交互使用的其实就是调用的req.get(url)方法 进行路径访问,你要先保证自己的springboot项目已经成功运行了: 如下: 如何交互的? 微信小程序:如下为index.js页面 在onLoad()事件中调用方法Project.findAllCities() 要…

贝叶斯滤波计算4d毫米波聚类目标动静属性

机器人学中有些问题是二值问题,对于这种二值问题的概率评估问题可以用二值贝叶斯滤波器binary Bayes filter来解决的。比如机器人前方有一个门,机器人想判断这个门是开是关。这个二值状态是固定的,并不会随着测量数据变量的改变而改变。就像门…

rv1126-rv1109-test

测试指令 播放音频:aplay aigei.wav 测试时间: 查看系统时间:date 设置时间:date -s "2023-09-21 16:00:00" 设置芯片时间:hwclock -w 查看芯片时间:hwclock 测试背光: echo 0 > sys/class/backlight/backlight/brightness echo 50 > sys/class/backlig…

期权如何交易?期权如何做模拟交易?

买卖期权的第一步就是要有期权账户,国内的期权品种有商品期权和ETF期权以及股指期权,每种的开户方式和要求都不同,下文为大家介绍期权如何交易?期权如何做模拟交易? 一、期权交易需要开立一个期权账户,可以…

【Spring Boot 源码学习】OnBeanCondition 详解

Spring Boot 源码学习系列 OnBeanCondition 详解 引言往期内容主要内容1. getOutcomes 方法2. getMatchOutcome 方法2.1 ConditionalOnBean 注解处理2.2 ConditionalOnSingleCandidate 注解处理2.3 ConditionalOnMissingBean 注解处理 3. getMatchingBeans 方法 总结 引言 上篇…

实战演练 | Navicat 常用功能之转储与运行 SQL 文件

数据库管理工作中,"转储 SQL 文件"和"运行 SQL 文件"是两个极为常见操作。一般来说,用户使用数据库管理工具或命令行工具来完成。Navicat 管理开发工具中的“转储 SQL 文件”和“运行 SQL 文件”功能具有直观易用的界面、多种文件格…

Linux内核源码分析 (B.x)Linux物理内存的初始化

Linux内核源码分析 (B.x)Linux物理内存的初始化 文章目录 Linux内核源码分析 (B.x)Linux物理内存的初始化一、DDR简介二、内存节点三、内存管理区域ZONE四、 struct zone五、 struct page六、mem_map数组七、伙伴系统简介八、迁移类型九、内存初始化十、总结 一、DDR简介 详细可…

threejs给3d模型中的物体换肤(修改材质)

变成这样 this.otherModel.traverse(function (child) {if (child instanceof THREE.Mesh && child.name Cylinder240) {// 导入纹理const textureLoader new THREE.TextureLoader();const floorColortextureLoader.load(require(../../../public/img/color.jpg));co…

5个GitHub热门算法岗面试攻略,附资源下载

最近听行内的大佬们讨论,说今年的秋招情况依旧卷的激烈,不知道大家有没有都拿到满意的offer? 已经拿到的同学我先给赞个,还没有获得心仪offer的同学也不要着急,当下最该做的就是抓紧时间提升自己的硬实力,…

openssl创建CA证书教程

配置生成CA证书 总示意图: (1),通过openssl创建CA证书 第一步:创建一个秘钥,这个便是CA证书的根本,之后所有的东西都来自这个秘钥 # 通过rsa算法生成2048位长度的秘钥 openssl genrsa -out myCA.key 2048 第二步&#…

小谈设计模式(5)—开放封闭原则

小谈设计模式(5)—开放封闭原则 专栏介绍专栏地址专栏介绍 开放封闭原则核心思想关键词概括扩展封闭 解释抽象和接口多态 代码示例代码解释 优缺点优点可扩展性可维护性可复用性高内聚低耦合 缺点抽象设计的复杂性需要预留扩展点可能引入过度设计 总结 专…