从代码中学习:评估模型的性能

从代码中学习:评估模型的性能

在这篇博客中,我们将逐步解析一段Python代码,并解释每一行的作用。这段代码主要用于加载数据集、加载预训练模型、进行推理并评估模型的性能。我们将以简单易懂的方式解释每一部分,确保即使是小学生也能理解。

1. 导入必要的库

首先,我们需要导入一些Python库,这些库将帮助我们完成后续的任务。

import datasets
import tempfile
import logging
import random
import config
import os
import yaml
import logging
import difflib
import pandas as pdimport transformers
import datasets
import torchfrom tqdm import tqdm
from utilities import *
from transformers import AutoTokenizer, AutoModelForCausalLMlogger = logging.getLogger(__name__)
global_config = None

解释:

  • datasets:用于加载和处理数据集的库。
  • tempfile:用于创建临时文件和目录。
  • logging:用于记录日志信息。
  • random:用于生成随机数。
  • config:用于管理配置文件的库。
  • os:用于与操作系统交互,如文件路径操作。
  • yaml:用于解析YAML格式的配置文件。
  • difflib:用于比较文本的差异。
  • pandas as pd:用于数据处理和分析。
  • transformers:用于加载和使用预训练的自然语言处理模型。
  • torch:PyTorch库,用于深度学习。
  • tqdm:用于显示进度条。
  • AutoTokenizerAutoModelForCausalLM:用于自动加载预训练的分词器和语言模型。

2. 加载数据集

接下来,我们加载一个名为lamini/lamini_docs的数据集,并从中提取测试集。

dataset = datasets.load_dataset("lamini/lamini_docs")
test_dataset = dataset["test"]

解释:

  • datasets.load_dataset("lamini/lamini_docs"):加载名为lamini/lamini_docs的数据集。
  • dataset["test"]:从数据集中提取测试集。

3. 打印测试集中的第一个问题及答案

我们可以打印出测试集中的第一个问题和对应的答案,以便查看数据的内容。

print(test_dataset[0]["question"])
print(test_dataset[0]["answer"])

解释:

  • test_dataset[0]["question"]:获取测试集中第一个问题的文本。
  • test_dataset[0]["answer"]:获取测试集中第一个问题的答案。

4. 加载预训练模型和分词器

我们加载一个名为lamini/lamini_docs_finetuned的预训练模型及其对应的分词器。

model_name = "lamini/lamini_docs_finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

解释:

  • model_name = "lamini/lamini_docs_finetuned":指定预训练模型的名称。
  • AutoTokenizer.from_pretrained(model_name):加载与模型对应的分词器。
  • AutoModelForCausalLM.from_pretrained(model_name):加载预训练的语言模型。

5. 定义一个简单的评估函数

我们定义一个函数is_exact_match,用于判断两个字符串是否完全匹配。

def is_exact_match(a, b):return a.strip() == b.strip()

解释:

  • a.strip() == b.strip():去除字符串ab两端的空白字符后,判断它们是否相等。

6. 设置模型为评估模式

在进行推理之前,我们需要将模型设置为评估模式。

model.eval()

解释:

  • model.eval():将模型设置为评估模式,这会关闭一些在训练时使用的功能,如Dropout。

7. 定义推理函数

我们定义一个函数inference,用于生成模型的预测结果。

def inference(text, model, tokenizer, max_input_tokens=1000, max_output_tokens=100):# Tokenizetokenizer.pad_token = tokenizer.eos_tokeninput_ids = tokenizer.encode(text,return_tensors="pt",truncation=True,max_length=max_input_tokens)# Generatedevice = model.devicegenerated_tokens_with_prompt = model.generate(input_ids=input_ids.to(device),max_length=max_output_tokens)# Decodegenerated_text_with_prompt = tokenizer.batch_decode(generated_tokens_with_prompt, skip_special_tokens=True)# Strip the promptgenerated_text_answer = generated_text_with_prompt[0][len(text):]return generated_text_answer

解释:

  • tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=max_input_tokens):将输入文本编码为模型可以理解的输入ID。
  • model.generate(input_ids=input_ids.to(device), max_length=max_output_tokens):使用模型生成文本。
  • tokenizer.batch_decode(generated_tokens_with_prompt, skip_special_tokens=True):将生成的token解码为文本。
  • generated_text_with_prompt[0][len(text):]:去除提示文本,只保留生成的答案。

8. 运行模型并比较预测结果与预期答案

我们使用模型生成答案,并将其与预期答案进行比较。

test_question = test_dataset[0]["question"]
generated_answer = inference(test_question, model, tokenizer)
print(test_question)
print(generated_answer)answer = test_dataset[0]["answer"]
print(answer)exact_match = is_exact_match(generated_answer, answer)
print(exact_match)

解释:

  • test_question = test_dataset[0]["question"]:获取测试集中的第一个问题。
  • generated_answer = inference(test_question, model, tokenizer):使用模型生成答案。
  • answer = test_dataset[0]["answer"]:获取测试集中的第一个问题的答案。
  • exact_match = is_exact_match(generated_answer, answer):判断生成的答案与预期答案是否完全匹配。

9. 在整个数据集上运行模型并评估

我们可以在整个测试集上运行模型,并计算准确率。

n = 10
metrics = {'exact_matches': []}
predictions = []
for i, item in tqdm(enumerate(test_dataset)):print("i Evaluating: " + str(item))question = item['question']answer = item['answer']try:predicted_answer = inference(question, model, tokenizer)except:continuepredictions.append([predicted_answer, answer])exact_match = is_exact_match(predicted_answer, answer)metrics['exact_matches'].append(exact_match)if i > n and n != -1:break
print('Number of exact matches: ', sum(metrics['exact_matches']))

解释:

  • n = 10:设置最大评估样本数为10。
  • metrics = {'exact_matches': []}:用于存储每个样本的匹配结果。
  • predictions = []:用于存储预测结果和真实答案。
  • for i, item in tqdm(enumerate(test_dataset)):遍历测试集中的每个样本。
  • predicted_answer = inference(question, model, tokenizer):使用模型生成答案。
  • exact_match = is_exact_match(predicted_answer, answer):判断生成的答案与预期答案是否完全匹配。
  • if i > n and n != -1: break:如果评估的样本数超过n,则停止评估。

10. 将预测结果保存为DataFrame

我们可以将预测结果保存为Pandas DataFrame,以便进一步分析。

df = pd.DataFrame(predictions, columns=["predicted_answer", "target_answer"])
print(df)

解释:

  • pd.DataFrame(predictions, columns=["predicted_answer", "target_answer"]):将预测结果和真实答案保存为DataFrame。

11. 加载评估数据集

我们还可以加载另一个评估数据集,以便进行更全面的评估。

evaluation_dataset_path = "lamini/lamini_docs_evaluation"
evaluation_dataset = datasets.load_dataset(evaluation_dataset_path)

解释:

  • datasets.load_dataset(evaluation_dataset_path):加载评估数据集。

12. 尝试ARC基准测试

最后,我们可以尝试使用ARC基准测试来评估模型的性能。

!python lm-evaluation-harness/main.py --model hf-causal --model_args pretrained=lamini/lamini_docs_finetuned --tasks arc_easy --device cpu

解释:

  • !python lm-evaluation-harness/main.py:运行ARC基准测试脚本。
  • --model hf-causal:指定模型类型为因果语言模型。
  • --model_args pretrained=lamini/lamini_docs_finetuned:指定预训练模型。
  • --tasks arc_easy:指定任务为ARC Easy。
  • --device cpu:指定在CPU上运行。

总结

通过这篇博客,我们逐步解析了一段用于加载数据集、加载预训练模型、进行推理并评估模型性能的Python代码。希望这些解释能帮助你更好地理解每一行代码的作用,并激发你对Python编程的兴趣!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/65395.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue BPMN Modeler流程图

1、参考地址 git clone https://github.com/evanyangg/vue-bpmn-modeler.git 2、安装bpmn.js npm install bpmn-js --save 3、使用bpmn.js <template><div class"containers"><div class"canvas" ref"canvas"></div&g…

STM32完全学习——FATFS0.15移植SD卡

一、下载FATFS源码 大家都知道使用CubMAX可以很快的将&#xff0c;FATFS文件管理系统移植到单片机上&#xff0c;但是别的芯片没有这么好用的工具&#xff0c;就需要自己从官网下载源码进行移植。我们首先解决SD卡的驱动问题&#xff0c;然后再移植FATFS文件管理系统。 二、SD…

5、栈应用-表达式求值

本章内容使用上述栈结构函数&#xff0c;来完成表达式求值操作。 表达式例如&#xff1a;3*(7-2) 或者 (0-12)*((5-3)*32)/(22) 。 1、实现思路 a、建立OPTR&#xff08;运算符&#xff09;和OPND&#xff08;数字&#xff09;两个栈&#xff0c;后输入字符串以结束 b、自左向…

【递归与回溯深度解析:经典题解精讲(下篇)】—— Leetcode

文章目录 有效的数独解数独单词搜索黄金矿工不同的路径||| 有效的数独 递归解法思路 将每个数独的格子视为一个任务&#xff0c;依次检查每个格子是否合法。 如果当前格子中的数字违反了数独规则&#xff08;在行、列或 33 小方块中重复&#xff09;&#xff0c;直接返回 Fals…

Llama 3 预训练(二)

目录 3. 预训练 3.1 预训练数据 3.1.1 网络数据筛选 PII 和安全过滤 文本提取与清理 去重&#xff08;De-duplication&#xff09; 启发式过滤&#xff08;Heuristic Filtering&#xff09; 基于模型的质量过滤 代码和数学推理数据处理 多语言数据处理 3.1.2 确定数…

双指针——查找总价格为目标值的两个商品

一.题目描述 LCR 179. 查找总价格为目标值的两个商品 - 力扣&#xff08;LeetCode&#xff09; 二.题目解析 这个题目非常简单&#xff0c;其实就是判断有没有两个数加起来等于target。 三.算法解析 1.暴力解法 暴力解法的话我们可以枚举出所有的情况&#xff0c;然后判…

sqlserver镜像设置

本案例是双机热备&#xff0c;只设置主体服务器&#xff08;主&#xff09;和镜像服务器&#xff08;从&#xff09;&#xff0c;不设置见证服务器 设置镜像前先检查是否启用了 主从服务器数据库的 TCP/IP协议 和 RemoteDAC &#xff08;1&#xff09;打开SQL Server配置管理器…

Elasticsearch:analyzer(分析器)

一、概述 可用于将字符串字段转换为单独的术语&#xff1a; 添加到倒排索引中&#xff0c;以便文档可搜索。级查询&#xff08;如 生成搜索词的 match查询&#xff09;使用。 分析器分为内置分析器和自定义的分析器&#xff0c;它们都是由若干个字符过滤器&#xff08;chara…

ElementPlus 自定义封装 el-date-picker 的快捷功能

文章目录 需求分析 需求 分析 我们看到官网上给出的案例如下&#xff0c;但是不太满足我们用户想要的快捷功能&#xff0c;因为不太多&#xff0c;因此需要我们自己封装一些&#xff0c;方法如下 外部自定义该组件的快捷内容 export const getPickerOptions () > {cons…

低代码开发平台排名2024

低代码开发平台在过去几年中迅速崛起&#xff0c;成为企业数字化转型的重要工具。这些平台通过可视化界面和拖放组件&#xff0c;使业务人员和技术人员都能快速构建应用程序&#xff0c;大大缩短了开发周期。以下是一些在2024年值得关注和使用的低代码开发平台。 一、Zoho Cre…

计算机网络——期末复习(4)协议或技术汇总、思维导图

思维导图 协议与技术 物理层通信协议&#xff1a;曼彻斯特编码链路层通信协议&#xff1a;CSMA/CD &#xff08;1&#xff09;停止-等待协议&#xff08;属于自动请求重传ARQ协议&#xff09;&#xff1a;确认、否认、重传、超时重传、 &#xff08;2&#xff09;回退N帧协…

【MySQL学习笔记】关于索引

文章目录 【MySQL学习笔记】关于索引1.索引数据结构2.索引存储3.联合索引3.1 联合索引的b树结构3.2 索引覆盖&#xff1f;回表&#xff1f;3.3 联合索引最左匹配原则3.5 索引下推 4.索引失效 【MySQL学习笔记】关于索引 1.索引数据结构 索引是一种能提高查询速度的数据结构。…

D104【python 接口自动化学习】- pytest进阶参数化用法

day104 pytest参数化parametrize单参数 学习日期&#xff1a;20241223 学习目标&#xff1a;pytest基础用法 -- pytest参数化parametrize单参数 学习笔记&#xff1a; 参数化 parametrize 参数化可以组装测试数据&#xff0c;在测试前定义好测试数据&#xff0c;并在测试用…

第T4周:TensorFlow实现猴痘识别(Tensorboard的使用)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 目标&#xff1a; 1、学习tensorboard的使用 具体实现&#xff1a; &#xff08;一&#xff09;环境&#xff1a; 语言环境&#xff1a;Python 3.10 编 译 器…

Docker-构建自己的Web-Linux系统-镜像webtop:ubuntu-kde

介绍 安装自己的linux-server,可以作为学习使用&#xff0c;web方式访问&#xff0c;基于ubuntu构建开源项目 https://github.com/linuxserver/docker-webtop安装 docker run -d -p 1336:3000 -e PASSWORD123456 --name webtop lscr.io/linuxserver/webtop:ubuntu-kde登录 …

小米路由器开启SSH,配置阿里云ddns,开启外网访问SSH和WEB管理界面

文章目录 前言一、开启SSH二、配置阿里云ddns1.准备工作2.创建ddns脚本3.添加定时任务 三、开启外网访问SSH和WEB管理界面1、解除WEB管理页面访问限制2.手动添加防火墙端口转发规则&#xff0c;开启外网访问WEB管理和SSH 前言 例如&#xff1a;随着人工智能的不断发展&#xf…

什么是ESC ---- 防止车辆打滑并提高驾驶时稳定性的技术

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 所谓鸡汤&#xff0c;要么蛊惑你认命&#xff0c;要么怂恿你拼命&#xff0c;但都是回避问题的根源&…

LinkedList类 (链表)

目录 一. LinkedList 基本介绍 二. LinkedList 中的法及其应用 1. 添加元素 (1) add() (2) addAll() (3) addFirst() (4) addLast() 2. 删除元素 (1) remove() (2) removeAll() (3) removeFirst() (4) removeLast() 3. 遍历元素 (1) for 循环遍历 (2) for - each …

复习打卡大数据篇——Hadoop MapReduce

目录 1. MapReduce基本介绍 2. MapReduce原理 1. MapReduce基本介绍 什么是MapReduce MapReduce是一个分布式运算程序的编程框架&#xff0c;核心功能是将用户编写的业务逻辑代码和自带默认组件整合成一个完整的分布式运算程序&#xff0c;并发运行在Hadoop集群上。 MapRed…

Java基础知识(四) -- 面向对象(下)

1.类变量和类方法 1.1 类变量背景 有一群小孩在玩堆雪人,不时有新的小孩加入,请问如何知道现在共有多少人在玩? 思路分析: 核心在于如何让变量count被所有对象共享 public class Child {private String name;// 定义静态变量(所有Child对象共享)public static int count 0;p…