李宏毅机器学习2022-HW7-BERT-Question Answering

文章目录

  • Task
  • Baseline
    • Medium
    • Strong
    • Boss
  • Code Link

Task

HW7的任务是通过BERT完成Question Answering。

数据预处理流程梳理

数据解压后包含3个json文件:hw7_train.json, hw7_dev.json, hw7_test.json。

DRCD: 台達閱讀理解資料集 Delta Reading Comprehension Dataset

ODSQA: Open-Domain Spoken Question Answering Dataset

  • train: DRCD + DRCD-TTS
    • 10524 paragraphs, 31690 questions
  • dev: DRCD + DRCD-TTS
    • 1490 paragraphs, 4131 questions
  • test: DRCD + ODSQA
    • 1586 paragraphs, 4957 questions

{train/dev/test}_questions:

  • List of dicts with the following keys:
  • id (int)
  • paragraph_id (int)
  • question_text (string)
  • answer_text (string)
  • answer_start (int)
  • answer_end (int)

{train/dev/test}_paragraphs:

  • List of strings
  • paragraph_ids in questions correspond to indexs in paragraphs
  • A paragraph may be used by several questions

读取这三个文件,每个文件返回相应的question数据和paragraph数据,都是文本数据,不能作为模型的输入。

利用Tokenization将question和paragraph文本数据先按token为单位分开,再转换为tokens_to_ids数字数据。Dataset选取paragraph中固定长度的片段(固定长度为150),片段需包含answer部分,然后使用Tokenization 以CLS + question + SEP + document+ CLS + padding(不足的补0)的形式作为训练输入。

Total sequence length = question length + paragraph length + 3 (special tokens)
Maximum input sequence length of BERT is restricted to 512

在这里插入图片描述
在这里插入图片描述

training

在这里插入图片描述

testing

对于每个窗口,模型预测一个开始分数和一个结束分数,取最大值作为答案

在这里插入图片描述

Baseline

Medium

应用linear learning rate decay+change doc_stride

这里linear learning rate decay选用了两种方法

  • 手动调整学习率

    假设初始学习率为 η 0 η_0 η0,总的步骤数为 T T T,那么在第 t t t步时的学习率 η t η_t ηt 可以表示为:

    η t = η 0 − η 0 T × t η_t=η_0−\frac{η_0}{T}×t ηt=η0Tη0×t

    其中:

    • η 0 η_0 η0 是初始学习率。
    • T T T是总的步骤数(total_step)。
    • t t t 是当前的步骤数(从 0 开始计数)。

    optimizer.param_groups[0]["lr"] -= learning_rate / total_step η t = η t − η 0 T η t η_t=η_t−\frac{η_0}{T}η_t ηt=ηtTη0ηt

    • optimizer.param_groups[0]["lr"] 对应 η t η_t ηt
    • learning_rate 对应 η 0 η_0 η0
    • total_step 对应 T T T
    • i 对应 t t t
    # Medium--Learning rate dacay
    # Method 1: adjust learning rate manually
    total_step = 1000
    for i in range(total_step):optimizer.param_groups[0]["lr"] -= learning_rate / total_step
    
  • 通过scheduler自动调整学习率

    • (recommend) transformer
    • torch.optim
    # Method 2: Adjust learning rate automatically by scheduler# (Recommend) https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup
    from transformers import get_linear_schedule_with_warmup
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)# https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
    # 这里如果要用pytorch的ExponentialLR,一定要导入optim模块,并且前面的AdamW是从transformers中import的这里要重新import
    import torch.optim as optim
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    

change doc_stride在QA_Dataset的时候修改段落滑动窗口的步长

##### TODO: Change value of doc_stride #####
# 段落滑动窗口的步长
self.doc_stride = 30  # Medium

Strong

应用➢ Improve preprocessing ➢ Try other pretrained models

  • 尝试其他预训练模型

比如bert-base-multilingual-case,因为它可以避免英文无法tokenization输出[UNK],但是计算量大

model = BertForQuestionAnswering.from_pretrained("hfl/chinese-macbert-large").to(device)
tokenizer = BertTokenizerFast.from_pretrained("hfl/chinese-macbert-large")
  • preprocessing ,在QA_Dataset中修改截取答案的窗口

    1. 随机窗口选择 Random Window Selection
      随机选择窗口的起始位置

      • 随机范围的下界
        start_min = max(0, answer_end_token - self.max_paragraph_len + 1) 答案结束位置向前移动 self.max_paragraph_len - 1 个标记后的位置和 0 较大的那个
      • 随机范围的上界
        start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)
        • len(tokenized_paragraph) - self.max_paragraph_len:计算段落长度减去窗口长度后的位置,确保窗口不会超出段落末尾。
        • min(answer_start_token, ...):确保上界不超过答案开始位置,避免答案被截断。
      • 随机选择
        paragraph_start = random.randint(start_min, start_max)在计算出的下界和上界之间随机选择一个整数作为窗口的起始位置。
      • 计算窗口结束位置
        paragraph_end = paragraph_start + self.max_paragraph_len确保窗口长度为 self.max_paragraph_len
    2. 滑动窗口大小 Dynamic window size

            ##### TODO: Preprocessing Strong ###### Hint: How to prevent model from learning something it should not learnif self.split == "train":# Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraphanswer_start_token = tokenized_paragraph.char_to_token(question["answer_start"])answer_end_token = tokenized_paragraph.char_to_token(question["answer_end"])# A single window is obtained by slicing the portion of paragraph containing the answer# 在training中paragraph的截取依据的是answer的position id"""mid = (answer_start_token + answer_end_token) // 2paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))paragraph_end = paragraph_start + self.max_paragraph_len"""# Strong# Method 1: Random window selectionstart_min = max(0, answer_end_token - self.max_paragraph_len + 1)  # 计算答案结束位置向前移动 self.max_paragraph_len - 1 个标记后的位置start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)start_max = max(start_min, start_max)paragraph_start = random.randint(start_min, start_max + 1)paragraph_end = paragraph_start + self.max_paragraph_len"""# Method 2: Dynamic window size # 这个会造成窗口的大小大于max_paragraph_len,那么会造成输入序列的长度不一致,后面padding也要改,这里暂不采用answer_length = answer_end_token - answer_start_tokendynamic_window_size = max(self.max_paragraph_len, answer_length + 20)  # 添加一些额外的空间paragraph_start = max(0, min(answer_start_token - dynamic_window_size // 2, len(tokenized_paragraph) - dynamic_window_size))paragraph_end = paragraph_start + dynamic_window_size"""

Boss

➢ Improve postprocessing ➢ Further improve the above hints

doc_stride + max_length+ learning rate scheduler + preprocessing+ postprocessing + new model + no validation

与strong baseline相比,最大的改变有两个,一是换pretrain model,在hugging face中搜索chinese + QA的模型,根据model card描述选择最好的模型,使用后大概提升2.5%的精度,二是更近一步的postprocessing,查看提交文件可看到很多answer包含CLS, SEP, UNK等字符,CLS和SEP的出现表示预测位置有误,UNK的出现说明有某些字符无法正常编码解码(例如一些生僻字),错误字符的问题均可在evaluate函数中改进,这个步骤提升了大概1%的精度。其他的修改主要是针对overfitting问题,包括减少了learning rate,提升dataset里面的paragraph max length, 将validation集合和train集合进行合并等。另外可使用的办法有ensemble,大概能提升0.5%的精度,改变random seed,也有提升的可能性。

if start_index > end_index or start_index < paragraph_start or end_index > paragraph_end:continueif '[UNK]' in answer:print('发现 [UNK],这表明有文字无法编码, 使用原始文本')#print("Paragraph:", paragraph)#print("Paragraph:", paragraph_tokenized.tokens)print('--直接解码预测:', answer)#找到原始文本中对应的位置raw_start =  paragraph_tokenized.token_to_chars(origin_start)[0]raw_end = paragraph_tokenized.token_to_chars(origin_end)[1]answer = paragraph[raw_start:raw_end]print('--原始文本预测:',answer)

Code Link

github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/55805.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

openpnp - 底部相机视觉识别CvPipeLine的参数bug修正

文章目录 openpnp - 底部相机视觉识别的CvPipeLine的参数bug概述笔记openpnp的视觉识别参数的错误原因备注END openpnp - 底部相机视觉识别的CvPipeLine的参数bug 概述 底部相机抓起一个SOD323的元件&#xff0c;进行视觉识别。 识别出的矩形错了&#xff0c;是一个很长的长方…

Qt_软件添加版本信息

文章内容: 给生成的软件添加软件的版权等信息 #include <windows.h> //中文的话增加下面这一行 #pragma code_page(65001)VS_VERSION_INFO VERSIONINFO

TEI text-embeddings-inference文本嵌入模型推理框架

参看: https://github.com/huggingface/text-embeddings-inference#docker 文本嵌入模型榜单 https://huggingface.co/spaces/mteb/leaderboard bge模型下载 https://huggingface.co/BAAI/bge-m3/tree/main export HF_ENDPOINT=https://hf-mirror.comhuggingface-cli dow…

STM32-HAL库 - MAX30102心率血氧传感器 —— 2024.10.15

一、教程简介 本教程使用CubeMX配合Keil5编写代码&#xff0c;带你10分钟拿下MAX30102。在官方例程的基础上进行移植和封装&#xff0c;测量数据准确。采用模拟I2C&#xff0c;任意三个引脚均可驱动。 二、MAX30102简介 MAX30102是一个集成的脉搏血氧仪和心率监测仪生物传感器…

Tortoise SVN 安装汉化教程(乌龟SVN)

1.首先下载 去官网下载 如果下载比较慢的&#xff0c;链接自取 https://pan.quark.cn/s/cb6f2eee3f90 2. 安装Tortoise SVN 无脑next到完成 最后到桌面右键 你就发现svn出来了&#xff0c;但是是英文的&#xff01;&#xff01;&#xff01;&#xff01; 像我这种英文不好的…

流体力学笔记

目录 1、名词2、湍流与涡流3 涡激振动4 压力面与吸力面参考&#xff1a;[空气动力学的“他山之石”](https://zhuanlan.zhihu.com/p/412542513) 1、名词 转列&#xff1a;transition 涡脱落&#xff1a;vortex shedding 涡分离&#xff1a;vortex rupture 气动噪声&#xff1a…

【java Web如何开发?】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

rel,npt时间服务器

服务器端配置 配置文件/etc/chrony.conf 开放端口 查看123端口是否开放 关闭防火墙 客户端配置 配置文件/etc/chrony.conf 重启文件和查看是否成功 注意事项要在ping通下实现&#xff0c;要是能ping通配置好了还是不行可以查看防火墙是否关闭

EMCMO--多任务优化求解约束多目标问题

EMCMO–多任务优化求解约束多目标问题 title&#xff1a; An Evolutionary Multitasking Optimization Framework for Constrained Multi-objective Optimization Problems author&#xff1a; Kangjia Qiao, Kunjie Yu, BoyangQu, Jing Liang, Hui Song, and Caitong Yue. …

Redis7 数据类型

Redis7 数据类型 文章目录 Redis7 数据类型1. Redis键&#xff08;Key&#xff09;2. Redis字符串&#xff08;String&#xff09;3. Redis列表&#xff08;List&#xff09;4. Redis哈希表&#xff08;Hash&#xff09;5. Redis集合&#xff08;Set&#xff09;5.1 常用操作5.…

Atlas800昇腾服务器(型号:3000)—驱动与固件安装(一)

服务器配置如下&#xff1a; CPU/NPU&#xff1a;鲲鹏 CPU&#xff08;ARM64&#xff09;A300I pro推理卡 系统&#xff1a;Kylin V10 SP1【下载链接】【安装链接】 驱动与固件版本版本&#xff1a; Ascend-hdk-310p-npu-driver_23.0.1_linux-aarch64.run【下载链接】 Ascend-…

Vue3学习:vite项目中图片不能显示,报错 require is not defined

今天做了一个案例“给你喜欢的人送花”&#xff0c;如果喜欢谁&#xff0c;就给谁送花&#xff0c;最多可以送5朵。运行效果如下。 这个项目是使用 npm create vitelatest 命令创建的。 包括2个组件&#xff1a; 根组件App.vue子组件HelloVote.vue。 目录结构如图所示&#x…

秋招面试题记录_半结构化面试

c八股(可能问的多一点) 1.简单说说C11语法特性 答&#xff1a; 1.auto以及decltype自动类型推导&#xff0c;避免手动声明复杂类型&#xff0c;减少冗长代码提升了可读性和安全性。 2.智能指针 自动释放内存 (具体说说) 有shared和unique 差异主要体现在所有权、内存开销、…

Java项目-基于springboot框架的校园在线拍卖系统项目实战(附源码+文档)

作者&#xff1a;计算机学长阿伟 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、ElementUI等&#xff0c;“文末源码”。 开发运行环境 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBoot、Vue、Mybaits Plus、ELementUI工具&#xff1a;IDEA/…

Java的walkFileTree方法用法【FileVisitor接口】

在Java旧版本中遍历文件系统只能通过File类通过递归来实现&#xff0c;但是这种方法不仅消耗资源大而且效率低。 NIO.2的Files工具类提供了两个静态工具方法walk()和walkFileTree()可用来高效并优雅地遍历文件系统。walkFileTree()功能更强&#xff0c;可自定义实现更多功能&am…

(二十)、从宿主机访问 k8s(minikube) 发布的 redis 服务

文章目录 1、环境准备2、具体操作2.1、启动 minikube (start/stop)2.2、准备 redis-deployment.yaml2.3、执行 redis-deployment.yaml2.3.1、查看 pod 信息和日志 2.4、检查部署和服务状态2.4.1、如果需要删除 3、查看 IP 的几个命令3.1、查看IP的几个命令3.2、解读3.3、宿主机…

k8s的部署和安装

k8s的部署和安装 一、Kubernets简介及部署方法 1.1 应用部署方式演变 在部署应用程序的方式上&#xff0c;主要经历了三个阶段&#xff1a; 传统部署&#xff1a;互联网早期&#xff0c;会直接将应用程序部署在物理机上 优点&#xff1a;简单&#xff0c;不需要其它技术的参…

HarmonyOS Next模拟器异常问题及解决方法

1、问题1&#xff1a;Failed to get the device apiVersion. 解决方法&#xff1a;关闭模拟器清除用户数据重启

电子商务网站维护技巧:保持WordPress、主题和插件的更新

在这个快节奏的数字时代&#xff0c;维护一个电子商务网站的首要任务之一是保持WordPress、主题和插件的最新状态。过时的软件不仅可能导致功能故障&#xff0c;还可能带来安全风险。本文将深入探讨如何有效地更新和维护您的WordPress网站&#xff0c;以确保其安全性和性能。 …

【天池比赛】【零基础入门金融风控 Task2赛题理解】【2.3.6】

【天池比赛】【零基础入门金融风控 Task2赛题理解】【2.3.1-2.3.5】 2.3.6 变量分布可视化 2.3.6.1 单一变量分布可视化 对于 pandas.core.series.Series 类型的变量&#xff1a; index&#xff1a;含义&#xff1a;它表示 Series 对象的索引&#xff0c;也就是每个数据点对…