【机器学习】—机器学习和NLP预训练模型探索之旅

目录

一.预训练模型的基本概念

1.BERT模型

2 .GPT模型

二、预训练模型的应用

1.文本分类

使用BERT进行文本分类

2. 问答系统

使用BERT进行问答

三、预训练模型的优化

 1.模型压缩

1.1 剪枝

权重剪枝

2.模型量化

2.1 定点量化

使用PyTorch进行定点量化

3. 知识蒸馏

3.1 知识蒸馏的基本原理

3.2 实例代码:使用知识蒸馏训练学生模型

四、结论


随着数据量的增加和计算能力的提升,机器学习和自然语言处理技术得到了飞速发展。预训练模型作为其中的重要组成部分,通过在大规模数据集上进行预训练,使得模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

一.预训练模型的基本概念

预训练模型是一种在大规模数据集上预先训练好的模型,可以作为其他任务的基础。预训练模型的优势在于其能够利用大规模数据集中的知识,提高模型的泛化能力和准确性。常见的预训练模型包括BERT(Bidirectional Encoder Representations from Transformers)、GPT(Generative Pre-trained Transformer)等。

1.BERT模型

BERT是由Google提出的一种双向编码器表示模型。BERT通过在大规模文本数据上进行掩码语言模型(Masked Language Model, MLM)和下一句预测(Next Sentence Prediction, NSP)的预训练,使得模型可以学习到深层次的语言表示。

2 .GPT模型

GPT由OpenAI提出,是一种基于Transformer的生成式预训练模型。GPT通过在大规模文本数据上进行自回归语言模型的预训练,使得模型可以生成连贯的文本。

二、预训练模型的应用

预训练模型在NLP领域有广泛的应用,包括但不限于文本分类、问答系统、机器翻译等。以下将介绍几个具体的应用实例。

1.文本分类

文本分类是将文本数据按照预定义的类别进行分类的任务。预训练模型可以通过在大规模文本数据上进行预训练,从而捕捉到丰富的语义信息,提高文本分类的准确性。

使用BERT进行文本分类

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)# 定义数据集
class TextDataset(Dataset):def __init__(self, texts, labels, tokenizer, max_len):self.texts = textsself.labels = labelsself.tokenizer = tokenizerself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]encoding = self.tokenizer.encode_plus(text,add_special_tokens=True,max_length=self.max_len,return_token_type_ids=False,padding='max_length',return_attention_mask=True,return_tensors='pt',)return {'text': text,'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'label': torch.tensor(label, dtype=torch.long)}# 准备数据
texts = ["I love this!", "I hate this!"]
labels = [1, 0]
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.1)train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len=32)
val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len=32)train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)# 训练模型
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(3):model.train()for batch in train_loader:optimizer.zero_grad()input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['label']outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids']attention_mask = batch['attention_mask']labels = batch['label']outputs = model(input_ids=input_ids, attention_mask=attention_mask)_, predicted = torch.max(outputs.logits, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Validation Accuracy: {correct / total:.2f}')

2. 问答系统

问答系统是从文本中自动提取答案的任务。预训练模型可以通过在大规模问答数据上进行预训练,从而提高答案的准确性和相关性。

使用BERT进行问答

from transformers import BertForQuestionAnswering# 加载预训练的BERT问答模型
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')# 输入问题和上下文
question = "What is the capital of France?"
context = "Paris is the capital of France."# 编码输入
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')# 模型预测
outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits# 获取答案的起始和结束位置
start_idx = torch.argmax(start_scores)
end_idx = torch.argmax(end_scores) + 1# 解码答案
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_idx:end_idx]))
print(f'Answer: {answer}')

三、预训练模型的优化

在实际应用中,预训练模型的优化至关重要。常见的优化方法包括模型压缩、量化和蒸馏等。

 1.模型压缩

模型压缩是通过减少模型参数数量和计算量来提高模型效率的方法。压缩后的模型不仅运行速度更快,还能减少存储空间和内存占用。常见的模型压缩技术包括剪枝、量化和知识蒸馏等。

1.1 剪枝

剪枝(Pruning)是一种通过删除模型中冗余或不重要的参数来减小模型大小的方法。剪枝可以在训练过程中或训练完成后进行。常见的剪枝方法包括:

  • 权重剪枝(Weight Pruning):删除绝对值较小的权重,认为这些权重对模型输出影响不大。
  • 结构剪枝(Structured Pruning):删除整个神经元或卷积核,减少模型的计算量和存储需求。

剪枝后的模型通常需要重新训练,以恢复或接近原始模型的性能。

权重剪枝
import torch
import torch.nn.utils.prune as prune# 定义一个简单的模型
class SimpleModel(torch.nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = torch.nn.Linear(10, 10)def forward(self, x):return self.fc(x)model = SimpleModel()# 对模型的全连接层进行权重剪枝
prune.l1_unstructured(model.fc, name='weight', amount=0.5)# 查看剪枝后的权重
print(model.fc.weight)

2.模型量化

模型量化是通过降低模型参数的精度来减少计算量的方法。量化通常通过将浮点数表示的权重和激活值转换为低精度表示(如8位整数)来实现。这可以显著减少模型的存储空间和计算开销,同时在硬件上加速模型推理。

2.1 定点量化

定点量化(Fixed-point Quantization)是将浮点数表示的权重和激活值转换为固定精度的整数表示。常见的定点量化包括8位整数量化(INT8),这种量化方法在不显著降低模型精度的情况下,可以大幅提升计算效率。

使用PyTorch进行定点量化
import torch
import torch.quantization# 加载预训练模型
model = SimpleModel()# 定义量化配置
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')# 准备量化模型
model = torch.quantization.prepare(model, inplace=True)# 模拟量化后的推理过程
# 这里应该使用训练数据对模型进行微调,但为了简单起见,省略此步骤
model = torch.quantization.convert(model, inplace=True)# 查看量化后的模型
print(model)

3. 知识蒸馏

知识蒸馏(Knowledge Distillation)是通过将大模型(教师模型,Teacher Model)的知识转移到小模型(学生模型,Student Model)的方法,从而提高小模型的性能和效率。知识蒸馏的核心思想是通过教师模型的软标签(soft labels)指导学生模型的训练。

3.1 知识蒸馏的基本原理

在知识蒸馏过程中,学生模型不仅学习训练数据的真实标签,还学习教师模型对训练数据的输出,即软标签。软标签包含了更多的信息,比如类别之间的相似性,使学生模型能够更好地泛化。

蒸馏损失函数通常由两部分组成:

  • 交叉熵损失:衡量学生模型输出与真实标签之间的差异。
  • 蒸馏损失:衡量学生模型输出与教师模型软标签之间的差异。

总体损失函数为这两部分的加权和。

3.2 实例代码:使用知识蒸馏训练学生模型

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset# 定义教师模型和学生模型
teacher_model = SimpleModel()
student_model = SimpleModel()# 加载示例数据
data = torch.randn(100, 10)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)# 定义蒸馏训练函数
def distillation_train(student_model, teacher_model, data_loader, optimizer, temperature=2.0, alpha=0.5):teacher_model.eval()student_model.train()for data, labels in data_loader:optimizer.zero_grad()# 教师模型输出with torch.no_grad():teacher_logits = teacher_model(data)# 学生模型输出student_logits = student_model(data)# 计算蒸馏损失loss_ce = F.cross_entropy(student_logits, labels)loss_kl = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),F.softmax(teacher_logits / temperature, dim=1),reduction='batchmean') * (temperature ** 2)loss = alpha * loss_ce + (1.0 - alpha) * loss_klloss.backward()optimizer.step()# 定义优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)# 进行蒸馏训练
for epoch in range(10):distillation_train(student_model, teacher_model, data_loader, optimizer)# 验证学生模型
student_model.eval()
correct = 0
total = 0
with torch.no_grad():for data, labels in data_loader:outputs = student_model(data)_, predicted = torch.max(outputs, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Student Model Accuracy: {correct / total:.2f}')

四、结论

预训练模型在机器学习和自然语言处理领域具有重要意义。通过在大规模数据集上进行预训练,模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/13979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高光谱成像技术简介,怎么选择成像方案?

目录 一、什么是光谱?二、光谱和光谱分析方法的类型三、多光谱和高光谱的区别四、高光谱在水果品质检测中的应用1. 高光谱成像系统2. 高光谱图像的获取方式3. 高光谱图像处理与分析4. 在水果品质检测中的应用总结 五、针对自己的应用场景怎么使用高光谱技术六、参考…

海山数据库(He3DB)从方法到实践,构建以场景为中心的体验管理体系

编者按:体验优化的过程中设计师经常会遇到几个阶段,发现问题、定义问题、优化问题、查看反馈,但在产品快速迭代的过程中,体验的问题经常被归类到“不紧急”需求中,并逐步转为长尾问题,这些不被重视的问题聚…

【OpenGL实践10】关于几何着色器

目录 一、说明 二、几何着色器 2.1 设置 2.2 基本几何着色器 2.2.1 输入类型 2.2.2 输出类型 2.2.3 顶点输入 2.2.4 顶点输出 2.3 创建几何着色器 2.4 几何着色器和顶点属性 三、动态生成几何体 四、结论 练习 一、说明 几何着色器的应用比较高级,关于…

Epson推出的FC2012AN晶体专为小尺寸、低ESR应用设计

在可穿戴设备、loT产品、无线通信模块等领域,对于小型化、低功耗和高精度的时钟需求日益增长。Epson推出的FC2012AN系列晶体单元凭借其小尺寸、低ESR等特性使其成为这些应用的理想选择。 FC2012AN系列是一款32.768K频率的晶体单元,频率偏差为 20x10-6…

【云原生】kubernetes中的service原理、应用实战案例解析

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

DSPy - prompt 优化

文章目录 一、关于 DSPy与神经网络的类比常见问题解答**DSPy 优化器可以调整什么?****我应该如何使用 DSPy 完成我的任务?****如果我对提示或合成数据生成有更好的想法怎么办?**DSPy 代表什么? 二、安装三、文档A) 教程B) 指南C) …

全球手游4月战报,《Monopoly GO!》荣耀再续!全球手游畅销榜冠军

易采游戏网5月22日消息,在刚刚过去的四月里,全球移动游戏市场的角逐愈发激烈。根据最新发布的数据,Scopely旗下的经典游戏《Monopoly GO!》再次蝉联全球手游畅销榜首冠军宝座,展现了无与伦比的市场魅力与玩家黏度。 4月Scopely《M…

中霖教育怎么样?二建继续教育几年一次?

中霖为大家介绍: 根据相关规定,二级建造师执业资格注册证书设定有效期限为三年。为确保持证人员的专业能力,在规定的期限内需要完成规定的继续教育课程并参加考核,以此来维护其职业资质的连续性。 在执业资格证书的有效期满前&a…

redis小知识

AOF与RDB的区别 AOF (Append Only File) 和 RDB (Redis Database) 都是Redis中的持久化机制,但有以下几点不同之处: 内容格式:AOF 以日志的形式记录所有写操作命令,而 RDB 则是在指定的时间间隔内对数据库进行快照,将数…

WDW-100G 高温拉力试验机 技术方案书

一、整机外观图: 二、项目简介: 微机控制高温拉力试验机是电子技术与机械传动相结合的新型材料试验机,它具有宽广准确的加载速度和测力范围,对载荷、变形、位移的测量和控制有较高的精度和灵敏度,还可以进行等速加载、…

FFmpeg开发笔记(三十)解析H.264码流中的SPS帧和PPS帧

《FFmpeg开发实战:从零基础到短视频上线》一书的“2.1.1 音视频编码的发展历程”介绍了H.26x系列的视频编码标准,其中H.264至今仍在广泛使用,无论视频文件还是网络直播,H.264标准都占据着可观的市场份额。 之所以H.264取得了巨大…

CDLinux下载网站

CDlinux - Browse /CDlinux-ISO at SourceForge.net

用神经网络预测三角形的面积

周末遛狗时,我想起一个老问题:神经网络能预测三角形的面积吗? 神经网络非常擅长分类,例如根据花瓣长度和宽度以及萼片长度和宽度预测鸢尾花的种类(setosa、versicolor 或 virginica)。神经网络还擅长一些回…

2024中青杯A题数学建模成品文章数据代码分享

人工智能视域下养老辅助系统的构建 摘要 随着全球人口老龄化的加剧,养老问题已经成为一个世界性的社会问题,对社会各个方面产生了深远影响,包括劳动力市场、医疗保健和养老金制度等。人口结构变化对养老服务的质量和覆盖面提出了更高要求。特…

ARP基本原理

相关概念 ARP报文 ARP报文分为ARP请求报文和ARP应答报文,报文格式如图1所示。 图1 ARP报文格式 Ethernet Address of destination(0–31)和Ethernet Address of destination(32–47)分别表示Ethernet Address of dest…

【算法】前缀和——除自身以外数组的乘积

本节博客是用前缀和算法求解“除自身以外数组的乘积”,有需要借鉴即可。 目录 1.题目2.前缀和算法3.变量求解4.总结 1.题目 题目链接:LINK 2.前缀和算法 1.创建两个数组 第一个数组第i位置表示原数组[0,i-1]之积第二个数组第i位置表示原数组[i1,n-1]…

Hadoop 客户端 FileSystem加载过程

如何使用hadoop客户端 public class testCreate {public static void main(String[] args) throws IOException {System.setProperty("HADOOP_USER_NAME", "hdfs");String pathStr "/home/hdp/shanshajia";Path path new Path(pathStr);Confi…

在DAYU200上实现OpenHarmony跳转拨号界面

一、简介 日常生活中,打电话是最常见的交流方式之一,那么如何在OpenAtom OpenHarmony(简称“OpenHarmony”)中进行电话服务相关的开发呢?今天我们可以一起来了解一下如何通过电话服务系统支持的API实现拨打电话的功能…

C#-根据日志等级进行日志的过滤输出

文章速览 概要具体实施创建Log系统动态修改日志等级 坚持记录实属不易,希望友善多金的码友能够随手点一个赞。 共同创建氛围更加良好的开发者社区! 谢谢~ 概要 方便后期对软件进行维护,需要在一些关键处添加log日志输出,但时间长…

【408精华知识】指令周期的数据流

文章目录 一、取指周期二、间址周期三、执行周期(一)数据传送类指令(mov/load/store)(二)运算类指令(加/减/乘/除/移位/与/或)(三)转移类指令(jmp/jxxx) 四、中断周期 CPU每取出并且执行一条指令所需要的全…