第八节 LLaVA模型CLI推理构建custom推理代码Demo

文章目录

  • 前言
  • 一、parser 参数设定
    • 1、lora权重推理
    • 2、非lora权重推理
    • 3、量化权重推理
    • 4、实验总结
  • 二、初始化模型
  • 三、模型推理
  • 四、完整代码Demo


前言

我在第七节介绍了cli.py推理源码解读,而我也因项目需要构建了推理demo,我们是用来自动生成标签和推理需要。想了想,我还是用一节将我的代码记录于此,供有需求读者使用。本节,介绍更改cli.py代码,实现一张图像推理、也为需要grounding的读者提供如何在图上给出目标box。


一、parser 参数设定

为什么我要单独介绍参数设定?因为它很重要,正确的设定会减少模型错误概率。我将介绍三个部分设定,一个是使用lora权重,一个是合并权重,最后一个是使用量化方式。

1、lora权重推理

我们训练模型多数使用lora训练,而未将lora训练结果合并的权重加载方式的方法。如果我们是使用自己训练方法,可以使用如下方式给出参数:

    parser.add_argument("--model-path", type=str, default="/extend_disk/disk3/tj/LLaVA/checkpoints/llava-v1.5-13b-lora_vaild_1epoch_clean2/checkpoint-10200")parser.add_argument("--model-base", type=str, default="/extend_disk/disk3/tj/LLaVA/llava_v1.5_lora/vicuna-13b-v1.5")

如果我们是使用LLaVA自带lora方式,model-base基本不变,只需将model-path="/LLaVA/checkpoint/llava-v1.5-13b-lora",而权重下载我之前文章也介绍。

2、非lora权重推理

我们训练模型使用lora方法保存,想调用非lora方式,就需要将其转换。我们这里不说转换方法,给出非lora的权重加载方式。那这里只介绍官方给出权重加载参数设定,如下:

   parser.add_argument("--model-path", type=str, default="/LLaVA/llava_v1.5_lora/llava-v1.5-13b")parser.add_argument("--model-base", type=str, default=None)

3、量化权重推理

量化只需打开load-8bit或load-4bit参数,但量化必须是非lora权重加载方式,其代码如下:

   parser.add_argument("--load-8bit", action="store_true")# parser.add_argument("--load-4bit", default=True)parser.add_argument("--load-4bit", action="store_true")

当然量化显存占用测试,我们以LLAVA-13b量化显存测试:
不量化推理显存占用:28.4G
8bit量化推理显存占用:16.6G
4bit量化推理显存占用:10.6G

4、实验总结

我测试官方提供lora与非lora权重,我发现非lora效果会比lora好。当然这是我测试工程数据得到结论,只做参考。

二、初始化模型

我不在介绍,如下代码:

def llava_init(args):# Modeldisable_torch_init()model_name = get_model_name_from_path(args.model_path)tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)return tokenizer, model, image_processor, context_len,model_name

我想说,每个权重名称需包含v1字符,以便后续对话加载方式。

三、模型推理

模型推理,我将提示改成列表方式,我也对有框目标的文本预测做了图上画框操作。其它基本都是流程,我不在解读了。

四、完整代码Demo

最后,我给出完整的Demo,可以直接复制粘贴即可使用。若还想按照自己custom方式,读者也可根据我提供的方法来修改。其完整带阿米如下:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"import argparse
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamerdef img_drawingbox(image,conversation_info,res_img_path=None):from PIL import Image, ImageDraw, ImageFontimport rewidth, height = image.sizedraw = ImageDraw.Draw(image)box_lst = []for info in conversation_info['conversations']:value = info['value']gpt = info['from']if gpt == 'gpt':result = re.search(r'\[(.*?)\]', value)if result:content_in_brackets = result.group(1)# 将提取的内容转换为浮点数列表float_list = [float(num) for num in content_in_brackets.split(',')]if float_list not in box_lst:box_lst.append(float_list)if len(box_lst)>0:for b in box_lst:if  len(b)==4:x1,y1,x2,y2 = b[0]*width,b[1]*height,b[2]*width,b[3]*heightx1,y1,x2,y2=max(0,int(x1)),max(0,int(y1)),min(width,int(x2)),min(y2,height) box=(x1,y1,x2,y2)# 绘制矩形框draw.rectangle(box, outline="red", width=2)  # 红色边框,宽度为2像素if res_img_path is not None:image.save(res_img_path,encoding="utf-8")return imagedef load_image(image_file):if image_file.startswith('http://') or image_file.startswith('https://'):response = requests.get(image_file)image = Image.open(BytesIO(response.content)).convert('RGB')else:image = Image.open(image_file).convert('RGB')return imagedef llava_init(args):# Modeldisable_torch_init()model_name = get_model_name_from_path(args.model_path)tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)return tokenizer, model, image_processor, context_len,model_namedef llava_infer(image,test_prompt,args,tokenizer, model, image_processor, model_name='llava_v1.5'):assert isinstance(test_prompt,list), "test_prompt提示文本必须是问题构成的列表!"if 'llama-2' in model_name.lower():conv_mode = "llava_llama_2"elif "v1" in model_name.lower():conv_mode = "llava_v1"elif "mpt" in model_name.lower():conv_mode = "mpt"else:conv_mode = "llava_v0"if args.conv_mode is not None and conv_mode != args.conv_mode:print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))else:args.conv_mode = conv_modeconversations_json = {'conversations':[]}conv = conv_templates[args.conv_mode].copy()if "mpt" in model_name.lower():roles = ('user', 'assistant')else:roles = conv.roleswidth, height = image.size# Similar operation in model_worker.pyimage_tensor = process_images([image], image_processor, model.config)if type(image_tensor) is list:image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]else:image_tensor = image_tensor.to(model.device, dtype=torch.float16)for i ,inp in enumerate(test_prompt):conversations_json['conversations'].append({"from": "human","value":inp})if i==0:# first messageif model.config.mm_use_im_start_end:inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inpelse: #inp = DEFAULT_IMAGE_TOKEN + '\n' + inp  # 走这步变成  <image>\n描述图像内容conv.append_message(conv.roles[0], inp)else:# later messages # 后面循环对话添加内容conv.append_message(conv.roles[0], inp)conv.append_message(conv.roles[1], None)prompt = conv.get_prompt()input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2keywords = [stop_str]  # '</s>' ,这个是每句结束标志stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)# 下面开始走模型with torch.inference_mode():output_ids = model.generate(input_ids,images=image_tensor,do_sample=True if args.temperature > 0 else False,temperature=args.temperature,max_new_tokens=args.max_new_tokens,streamer=streamer,use_cache=True,stopping_criteria=[stopping_criteria])outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()  # ouput_ids中去除input_ids位置promptconv.messages[-1][-1] = outputsconversations_json['conversations'].append({"from": "gpt","value":outputs.replace('</s>','')})print(conversations_json)img_drawingbox(image,conversations_json,res_img_path=None)return conversations_jsondef parse_args():parser = argparse.ArgumentParser()## 直接使用合并后的模型进行推理# parser.add_argument("--model-path", type=str, default="/LLaVA/llava_v1.5_lora/llava-v1.5-13b")# parser.add_argument("--model-base", type=str, default=None)## lora推理方法parser.add_argument("--model-path", type=str, default="/LLaVA/checkpoints/llava-v1.5-13b-lora_vaild_1epoch/checkpoint-10200")parser.add_argument("--model-base", type=str, default="/LLaVA/llava_v1.5_lora/vicuna-13b-v1.5")parser.add_argument("--device", type=str, default="cuda")parser.add_argument("--conv-mode", type=str, default=None)parser.add_argument("--temperature", type=float, default=0.2)parser.add_argument("--max-new-tokens", type=int, default=512)parser.add_argument("--load-8bit", action="store_true")# parser.add_argument("--load-4bit", default=True)parser.add_argument("--load-4bit", action="store_true")parser.add_argument("--debug", action="store_true")args = parser.parse_args()return argsif __name__ == "__main__":args=parse_args()tokenizer, model, image_processor, context_len,model_name=llava_init(args)img_path = '/LLaVA/llava/serve/examples/1.jpg'images = load_image(img_path)test_prompt = ["图中是否有城市管理相关目标?若有,请提供相应坐标。"]predect_information_dict = llava_infer(images,test_prompt,args,tokenizer, model, image_processor, model_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/4946.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rust Vec<T> 集合使用教程

Rust Vec 集合使用教程 文章目录 Rust Vec<T> 集合使用教程1. 创建和初始化 Vec<T>代码示例运行结果 2. 访问和修改 Vec<T> 中的元素代码示例运行结果 3. 添加和删除 Vec<T> 中的元素代码示例运行结果 4. 遍历 Vec<T>代码示例运行结果 5. 使用 V…

web服务的部署及高级优化

搭建web服务器 1.1、配置主机IP以及软件仓库搭建 [rootserver129 ~]# vmset.sh 100 //主机IP配置为172.25.254.100 1.2、查看搭建web服务器所需的软件包 [rootserver100 ~]# dnf search nginx 名称 精准匹配&#xff1a;nginx nginx.x86_64 : A high performance web serve…

头歌实践教学平台:CG7-v2.0-实体消隐

第1关&#xff1a;立方体消隐 一. 任务描述 1. 本关任务 (1) 理解深度缓冲器算法(Z-Buffer)算法; (2) 将triangle函数和main函数中的空白部分补充完整。 2. 输入 (1) 代码将自动输入一个边长为1的obj正方体模型&#xff0c;具体模型如下图&#xff1a; (2) 代码会自动对将…

Kafka Exactly Once 语义实现原理:幂等性与事务消息

01 前言 在现代分布式系统中&#xff0c;确保数据处理的准确性和一致性是至关重要的。Apache Kafka&#xff0c;作为一个广泛使用的流处理平台&#xff0c;提供了强大的消息队列和流处理功能。随着业务需求的增长&#xff0c;Kafka 的事务消息功能应运而生&#xff0c;它允许应…

力扣279完全平方数

力扣279完全平方数 给你一个整数 n &#xff0c;返回 和为 n 的完全平方数的最少数量 。 完全平方数 是一个整数&#xff0c;其值等于另一个整数的平方&#xff1b;换句话说&#xff0c;其值等于一个整数自乘的积。例如&#xff0c;1、4、9 和 16 都是完全平方数&#xff0c;…

【OceanBase诊断调优】—— OceanBase 数据库日志解读

适用版本&#xff1a;V2.1.x、V2.2.x、V3.1.x、V3.2.x observer.log 日志 OBServer 启动日志 搜索关键字&#xff1a; [NOTICE] 日志说明&#xff1a; OBServer 启动过程中比较关键的日志信息。 [2023-05-11 14:19:09.703272] INFO [SERVER] ob_server.cpp:533 [95303][0]…

单链表的经典oj题(1)

前言 这次博客将要以图解的形式&#xff0c;把单链表的经典题目&#xff0c;讲解&#xff0c;绝对是干货&#xff0c;来吧兄弟萌 第一题 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 …

USB HID报告描述符学习

参考资料 HID 报告描述符 (qq.com)https://mp.weixin.qq.com/s?__bizMzU1ODI3MzQ1MA&mid2247485748&idx1&sn112bd8014eb96b03308b3b808549e8d4&chksmfc284ff1cb5fc6e770c2d2ece46c17bf2529901b45a357938978fa62163723556ad497b05c47&cur_album_id3340417…

力扣经典150题第四十七题:汇总区间

目录 题目描述和要求示例解释解题思路算法实现复杂度分析测试和验证总结和拓展参考资料 题目描述和要求 给定一个无重复元素的有序整数数组 nums&#xff0c;要求返回恰好覆盖数组中所有数字的最小有序区间范围列表。即&#xff0c;nums 的每个元素都恰好被某个区间范围所覆盖…

三、VLAN间路由(三层交换)

VLAN间路由可以通过二层交换机配合路由器来实现&#xff0c;也可以通过三层交换机来实现。 目录 1.单臂路由 2.通过三层交换机实现不同vlan的互访 1.单臂路由 注&#xff1a; 1.三层接口不能正确识别带vlan tag的数据帧 2.所有子接口与主接口共享MAC地址 命令 int g0/0/0.1…

Tom与Locust的渐入佳境

本书 第一章&#xff1a;Tom的Locust压力测试之旅 第二章&#xff1a;意外的挑战&#xff1a;系统性能问题的出现 第三章&#xff1a;高手相助&#xff1a;遇见性能测试专家 第四章&#xff1a;Locust初探&#xff1a;探寻压力测试工具 第五章&#xff1a;脚本编写&#xff1a…

Java Spring 中 Bean 的作用域(Scope)

在 Java Spring 框架中&#xff0c;Bean 的作用域&#xff08;Scope&#xff09;定义了 Bean 的生命周期以及其在 Spring 容器中的可见性。Spring 提供了几种不同的 Bean 作用域&#xff0c;以满足不同的应用需求。以下是 Spring 中主要的 Bean 作用域及其详细解释&#xff1a;…

试用了三个Ai音乐工具,我的偶像河图要完蛋了

试了三个生成音乐的ai工具&#xff0c;分别是爆火的suno,后期新秀udio&#xff0c;还有我们国内的天工。 先说感受&#xff0c;suno和天工我觉得稍微靠前&#xff0c;udio可能我的配置风格有问题&#xff0c;啪啪啪连选了好几个风格&#xff0c;生成的东西有点怪。 我随手写了…

语音识别的基本概念

语音识别的基本概念​​​​​​​ ​​​​​​​ 言语是一种复杂的现象。人们很少了解它是如何产生和感知的。天真的想法常常是语音是由单词构成的&#xff0c;而每个单词又由音素组成。不幸的是&#xff0c;现实却大不相同。语音是一个动态过程&#xff0c;没有明确区分的…

linux学习:线程安全(信号量+互斥锁读写锁+条件变量+可重入函数)

目录 信号量 有名信号量 步骤 api 创建、打开一个POSIX有名信号量 对 POSIX 有名信号量进行 P、V 操作 关闭、删除 POSIX 有名信号量 例子 无名信号量 步骤 api 初始化、销毁 POSIX 无名信号量 互斥锁读写锁 例子 两条线程 使用互斥锁来互斥地访问标准输出 在加锁…

算法人生(12):从“优先级队列算法”到“”六点优先工作法”

算法思想和生活中很多解决问题的思想有着异曲同工之妙&#xff0c;让我们来看下今天的“优先级队列算法”可以怎么应用到我们的生活中吧&#xff01; 优先级队列算法&#xff08;Priority Queue Algorithm&#xff09; 是一种特殊的数据结构&#xff0c;它在常规队列秉持着“先…

MySQL中START REPLICA 语句详解

在数据库管理和操作中&#xff0c;复制是保证数据可用性和分布式处理的关键技术之一。MySQL从8.0.22版本开始引入了START REPLICA语句&#xff0c;替代了原来的START SLAVE语句。本篇博文将详细介绍START REPLICA语句的用法和功能&#xff0c;帮助数据库管理员更有效地管理MySQ…

软件工程师,如何有效缓解工作压力

概述 在这个快速迭代、技术日新月异的数字时代&#xff0c;软件工程师们常常站在技术创新的最前沿。他们肩负着构建高效、可靠软件系统的重任&#xff0c;同时也面临着紧迫的截止日期、复杂的技术难题和持续的学习需求&#xff0c;这些因素共同构成了巨大的工作压力。如何在高压…

[SQL系列]从零开始学Clickhouse——集群篇

在上一篇中&#xff0c;我们通过Docker构建了一个简单的单点Clickhouse&#xff0c;但是如果要做大数据的处理的话&#xff0c;Clickhouse集群是必不可少的&#xff0c;今天我们先用Docker简单地搭建一个Clickhouse集群。 容器逐个部署 使用Docker部署ClickHouse集群涉及几个步…

1.认识USB协议

目录 前言 在嵌入式场景的具体体现 USB通信协议 总结 前言 在这之前&#xff0c;我们需要认识USB是什么东西&#xff0c;它是一种通信协议&#xff0c;协议只是规定数据的&#xff0c;在物理层面上&#xff0c;它可以有多种表现形式。在我们日常生活中也非常常见&#xff0…