NLP(17)--大模型发展(1)

前言

仅记录学习过程,有问题欢迎讨论

大模型的演化:

ElMO : 类似双向lstm 结果和词向量拼接 预训练鼻祖

GPT :使用了Transformer 模型 开始使用Token (发现预训练的作用)

Bert:认为双向比单向好 MLM(双向) 优于 LTR

Ernie-baidu:中文强于bert mask训练改为词组而非token

GPT2:继续使用Transformer 使用单向 more data

UNILM:使用Transformer + Mask attention 所以双向和单向都可以训

Transformer-XL & XLNet: 循环机制(利用attention|)解决bert输入长度限制(max_len = 512)

Roberta: bert变体 ,more data/去掉NSP/长样本/动态改变Mask位置 效果更好

Span bert:去掉NSP 随机mask连续token SBO

ALBert: 试图解决bert模型过大的问题(因式分解 vh = vE + E*h) 想办法减少参数( 跨层参数共享)

T5:Text-to-Text Transfer Transformer:Seq2Seq理论解决一切NLP问题

GPT3:参数量1750亿 目标:像人一样学习

总结

使用类bert形式训练 还是需要一个线性层配合任务改变 使用还是不便
更加理想的方式,使用文本来描述任务答案,训练方式统一(text-text),推理方式统一(无需解码结果 index-result)

Instruct GPT:

SFT训练 从续写到回答 输入的x 为问题 x+y为答案
在这里插入图片描述

RW训练(判断答案好坏的模型):强化学习 一个问题对应多个答案,由人工排序打分 y(w)=更好的答案 y(L) =差的 (类似于loss)

RL训练 在SFT上发展,但期望和SFT训练后结果分布相差不大 目标是使得RW model最大

In context learning:

训练内容直接在对话中产生,学习后输出内容(模型权重没变)
Z-s o-s f-s (zero/one/few shot) 问题中是否包含样例(和prompt)

提示工程:

  1. 提示词(prompt) 提示词是模型在训练过程中学习到的,模型在训练过程中会学习到一些通用的模式,这些模式可以作为提示词来指导模型的训练。
  2. 复杂任务可以拆解为多个步骤一步一训练,训练完成后再组合起来

代码

使用bert实现自回归训练模型,
添加mask attention 来实现

标题/内容任务,也是自回归模型,
对于输入输出长度不一致的问题,可以仔细看看怎么实现的,mask怎么实现的(采用mask使得x,y相互不可见)

# coding:utf8
import jsonimport torch
import torch.nn as nn
import numpy as np
import math
import random
import os
import refrom torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer"""
基于pytorch的LSTM语言模型
使用SFT训练 Q&A
"""class LanguageModel(nn.Module):def __init__(self, input_dim, vocab_size):super(LanguageModel, self).__init__()# self.embedding = nn.Embedding(len(vocab), input_dim)# self.layer = nn.LSTM(input_dim, input_dim, num_layers=1, batch_first=True)self.bert = BertModel.from_pretrained(r"D:\NLP\video\第六周\bert-base-chinese", return_dict=False)self.classify = nn.Linear(input_dim, vocab_size)# self.dropout = nn.Dropout(0.1)self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)# 当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None, mask=None):# x = self.embedding(x)  # output shape:(batch_size, sen_len, input_dim)# 使用mask来防止提前预知结果if y is not None:x, _ = self.bert(x, attention_mask=mask)y_pred = self.classify(x)return self.loss(y_pred.view(-1, y_pred.shape[-1]), y.view(-1))else:x, _ = self.bert(x)y_pred = self.classify(x)return torch.softmax(y_pred, dim=-1)# 加载语料
def load_corpus(path):title_list = []content_list = []with open(path, encoding="utf8") as f:for i, line in enumerate(f):line = json.loads(line)title_list.append(line["title"])content_list.append(line["content"])return [title_list, content_list]# 建立数据集
# sample_length 输入需要的样本数量。需要多少生成多少
# vocab 词表
# window_size 样本长度
# corpus 语料字符串
def build_dataset(sample_length, tokenizer, corpus, max_len, valid_flag=False):# dataset_x = []# dataset_y = []# dataset_mask = []dataset = []for i in range(sample_length):dataiter = build_sample(tokenizer, corpus, max_len, valid_flag)# dataset_x.append(x)# dataset_y.append(y)# dataset_mask.append(mask)dataset.append(dataiter)return DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)# 构建一个的mask
def create_mask(question_size, answer_size):len_s1 = question_size + 2  # cls + seplen_s2 = answer_size + 1  # sep# 创建掩码张量mask = torch.ones(len_s1 + len_s2, len_s1 + len_s2)# 遍历s1的每个tokenfor i in range(len_s1):# s1的当前token不能看到s2的任何tokenmask[i, len_s1:] = 0# 遍历s2的每个tokenfor i in range(len_s2):# s2的当前token不能看到后面的s2 tokenmask[len_s1 + i, len_s1 + i + 1:] = 0return maskdef pad_mask(tensor, target_shape):# 获取输入张量和目标形状的长宽height, width = tensor.shapetarget_height, target_width = target_shape# 创建一个全零张量,形状为目标形状result = torch.zeros(target_shape, dtype=tensor.dtype, device=tensor.device)# 计算需要填充或截断的区域h_start = 0w_start = 0h_end = min(height, target_height)w_end = min(width, target_width)# 将原始张量对应的部分填充到全零张量中result[h_start:h_end, w_start:w_end] = tensor[:h_end - h_start, :w_end - w_start]return resultdef build_sample(tokenizer, corpus, max_len, valid_flag=False):x_list, y_list = corpus# 随机获取一组样本:random_index = random.randint(0, len(x_list) - 1)x = x_list[random_index]if valid_flag:print(x)y = y_list[random_index]# 中文的文本转化为tokenizer的id 不添加【CLS】input_ids_x = tokenizer.encode(x, add_special_tokens=False)input_ids_y = tokenizer.encode(y, add_special_tokens=False)pad_x = [tokenizer.cls_token_id] + input_ids_x + [tokenizer.sep_token_id] + input_ids_ypad_y = len(input_ids_x) * [-1] + [tokenizer.sep_token_id] + input_ids_y# 自己对长度做处理  padding 到 max_lenpad_x = pad_x[:max_len] + [0] * (max_len - len(pad_x))pad_y = pad_y[:max_len] + [0] * (max_len - len(pad_y))# 构建一个的mask矩阵,让prompt内可以交互,answer中上下文之间没有交互mask = create_mask(len(input_ids_x), len(input_ids_y))mask = pad_mask(mask, (max_len, max_len))return [torch.LongTensor(pad_x), torch.LongTensor(pad_y), mask]# 建立模型
def build_model(vocab_size, char_dim):model = LanguageModel(char_dim, vocab_size)return model# 采样方式
def sampling_strategy(prob_distribution):if random.random() > 0.1:strategy = "greedy"else:strategy = "sampling"if strategy == "greedy":return int(torch.argmax(prob_distribution))elif strategy == "sampling":prob_distribution = prob_distribution.cpu().numpy()return np.random.choice(list(range(len(prob_distribution))), p=prob_distribution)def evaluate(openings, model, tokenizer, corpus):model.eval()# 转化为input_idopenings = tokenizer.encode(openings)# 控制生成的字数with torch.no_grad():while len(openings) <= 50:x = torch.LongTensor([openings])if torch.cuda.is_available():x = x.cuda()# 因为有了mask 所以输入就是输出,看最后一个字的预测y_pred = model(x)[0][-1]index = sampling_strategy(y_pred)openings.append(index)return tokenizer.decode(openings)def train(corpus_path, save_weight=True):epoch_num = 15  # 训练轮数batch_size = 32  # 每次训练样本个数train_sample = 1000  # 每轮训练总共训练的样本总数char_dim = 768  # 每个字的维度tokenizer = BertTokenizer.from_pretrained(r"D:\NLP\video\第六周\bert-base-chinese")vocab_size = 21128max_len = 50corpus = load_corpus(corpus_path)  # 加载语料model = build_model(vocab_size, char_dim)  # 建立模型if torch.cuda.is_available():model = model.cuda()optim = torch.optim.Adam(model.parameters(), lr=0.001)  # 建立优化器# datadataset = build_dataset(train_sample, tokenizer, corpus, max_len)print("文本词表模型加载完毕,开始训练")for epoch in range(epoch_num):model.train()watch_loss = []for x, y, mask in dataset:  # 构建一组训练样本if torch.cuda.is_available():x, y, mask = x.cuda(), y.cuda(), mask.cuda()optim.zero_grad()  # 梯度归零loss = model(x, y, mask)  # 计算lossloss.backward()  # 计算梯度optim.step()  # 更新权重watch_loss.append(loss.item())print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))result1 = evaluate("阿根廷歹徒抢服装尺码不对拿回店里换", model, tokenizer, corpus)print(result1)if not save_weight:returnelse:base_name = os.path.basename(corpus_path).replace("txt", "pth")model_path = os.path.join("model", base_name)torch.save(model.state_dict(), model_path)returnif __name__ == "__main__":train("sample_data.json", False)# mask = torch.tril(torch.ones(4, 4)).unsqueeze(0).unsqueeze(0)# print(mask)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17024.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】POSIX线程库——线程控制

目录 1.线程创建方法 例&#xff1a;多线程创建 2.线程终止 2.1 return nulptr; 2.2 pthread_exit(nullptr); 3. 线程等待 3.1 等待原因 3.2 等待方法 线程终止的返回值问题 4.线程取消 5. 线程分离 5.1 分离原因 5.2 分离方法 6.封装线程 用的接口是POSIX线程库…

音视频开发—音频相关概念:数模转换、PCM数据与WAV文件详解

文章目录 前言1.模拟数字转换&#xff08;ADC&#xff09;1.1ADC的关键步骤&#xff1a; 2.数字模拟转换&#xff08;DAC&#xff09;2.1DAC 的基本流程包括&#xff1a; 3.PCM数据3.1PCM 数据的关键要素包括&#xff1a; 4.WAV文件4.1 WAV的构成4.2WAV文件的标准块结构4.3WAV的…

OpenLayers6入门,OpenLayers实现在地图上拖拽编辑修改绘制图形

专栏目录: OpenLayers6入门教程汇总目录 前言 在前面一章中,我们已经学会了如何绘制基础的三种图形线段、圆形和多边形:《OpenLayers6入门,OpenLayers图形绘制功能,OpenLayers实现在地图上绘制线段、圆形和多边形》,那么本章将在此基础上实现图形的拖拽编辑功能,方便我…

使用Java 读取PDF表格数据并保存到TXT或Excel

目录 导入相关Java库 Java读取PDF表格数据并保存到TXT Java读取PDF表格数据并保存到Excel 在日常工作中&#xff0c;我们经常需要处理来自各种来源的数据。其中&#xff0c;PDF 文件是常见的数据来源之一。这类文件通常包含丰富的信息&#xff0c;其中可能包含重要的表格数据…

FreeRtos进阶——栈保存现场的几种场景

MCU架构 在认识栈的结构前&#xff0c;我们先来认识以下单片机的简单架构。在我们的CPU中有着很重要的一个模块——寄存器&#xff08;R0-R15&#xff09;&#xff0c;其中R13&#xff0c;R14&#xff0c;R15的别称分别为SP栈顶指针、LR返回地址、PC当前指令地址。外部RAM是单片…

Android Gradle plugin 版本和Gradle 版本

1.当看到这两个版本时&#xff0c;确实有点迷糊。但是他们是独立的&#xff0c;没有太大关联。 就是说在Android studio中看到的两个版本信息&#xff0c;并无太大关联&#xff0c;是相互独立的。Gradle插件版本决定了你的项目是如何构建的&#xff0c;而Gradle版本是执行构建…

对竞品分析的理解

一、竞品分析是什么 竞品分析即对竞争对手进行分析&#xff0c;是市场研究中的一项重要工作&#xff0c;它可以帮助企业了解竞争对手的产品、策略、市场表现等信息&#xff0c;通过竞品分析可以为自己的产品制定更加精准的策略。 二、为什么要做竞品分析 1.了解市场情况 了解…

vue/core源码中ref源码的js化

起源&#xff1a; 当看见reactivity文件中的ref.ts文件长达五百多的ts代码后&#xff0c;突发奇想想看下转化成js有多少行。 进行转化&#xff1a; let shouldTrack true; // Define shouldTrack variable let activeEffect null; // Define activeEffect variable// 定义…

M2m中的采样

采样的完整代码 import torch import numpy as np from torchvision import datasets, transforms from torch.utils.data import DataLoader, WeightedRandomSampler, SubsetRandomSamplerdef get_oversampled_data(dataset, num_sample_per_class):""" Gener…

Upstream最新发布2024年汽车网络安全报告-百度网盘下载

Upstream最新发布2024年汽车网络安全报告-百度网盘下载 2024年2月7日&#xff0c;Upstream Security发布了2024年Upstream《GLOBAL AUTOMOTIVE CYBERSECURITY REPORT》。这份报告的第六版着重介绍了汽车网络安全的拐点&#xff1a;从实验性的黑客攻击发展到规模庞大的攻击&…

fpga系列 HDL 00 : 可编程逻辑器件原理

一次性可编程器件&#xff08;融保险丝实现&#xff09; 一次性可编程器件&#xff08;One-Time Programmable Device&#xff0c;简称 OTP&#xff09;是一种在制造后仅能编程一次的存储设备。OTP器件在编程后数据不可更改。这些器件在很多应用场景中具有独特的优势和用途。 …

【LeetCode算法】第83题:删除排序链表中的重复元素

目录 一、题目描述 二、初次解答 三、官方解法 四、总结 一、题目描述 二、初次解答 1. 思路&#xff1a;双指针法&#xff0c;只需遍历一遍。使用low指向前面的元素&#xff0c;high用于查找low后面与low不同内容的节点。将具有不同内容的节点链接在low后面&#xff0c;实…

【c++】菱形虚拟继承的虚函数表如何继承

请看如下代码 #include <iostream>// 基类 class Base { public:virtual void foo() { std::cout << "Base::foo()" << std::endl; }virtual void bar() { std::cout << "Base::bar()" << std::endl; } };// 虚拟继承的中间…

西门子S7-1200加入MRP 环网用法

MRP&#xff08;介质冗余&#xff09;功能概述 SIMATIC 设备采用标准的冗余机制为 MRP&#xff08;介质冗余协议&#xff09;&#xff0c;符合 IEC62439-2 标准&#xff0c;典型重新组态时间为 200ms&#xff0c;每个环网最多支持 50个设备。​博途TIA/WINCC社区VX群 ​博途T…

Linux 批量网络远程PXE

一、搭建PXE远程安装服务器 1、yum -y install tftp-server xinetd #安装tftp服务 2、修改vim /etc/xinetd.d/tftpTFTP服务的配置文件 systemctl start tftp systemctl start xinetd 3、yum -y install dhcp #---安装服务 cp /usr/share/doc/dhc…

利用Python队列生产者消费者模式构建高效爬虫

目录 一、引言 二、生产者消费者模式概述 三、Python中的队列实现 四、生产者消费者模式在爬虫中的应用 五、实例分析 生产者类&#xff08;Producer&#xff09; 消费者类&#xff08;Consumer&#xff09; 主程序 六、总结 一、引言 随着互联网的发展&#xff0c;信…

Bug:Linux用户拥有r权限但无法打开文件【Linux权限体系】

Bug&#xff1a;Linux用户拥有r权限但无法打开文件【Linux权限体系】 0 问题描述&解决 问题描述&#xff1a; 通过go编写了一个程序&#xff0c;产生的/var/log/xx日志文件发现普通用户无权限打开 - 查看文件权限发现该文件所有者、所有者组、其他用户均有r权限 - 查看该日…

5个好用的AI写论文网站推荐

目录 1.AIQuora论文写作 2.passyyds 答辩PPT 3.AIPassgo论文降AIGC 4.文状元 5.passyyds论文写作 毕业论文是每个毕业生的痛&#xff0c;不管你是本科还是硕士要想顺利毕业你就不得不面对论文。然而&#xff0c;面对论文写作时常常感到无从下手&#xff1a;有时缺乏灵感&a…

【JavaEE进阶】——要想代码不写死,必须得有spring配置(properties和yml配置文件)

目录 本章目标&#xff1a; &#x1f6a9;配置文件 &#x1f6a9;SpringBoot配置文件 &#x1f388;配置⽂件的格式 &#x1f388; properties 配置⽂件说明 &#x1f4dd;properties语法格式 &#x1f4dd;读取配置文件 &#x1f4dd;properties 缺点分析 &#x1f3…

修改uniapp内置组件checkbox的样式

默认情况下 <view style"margin-bottom: 20rpx;"><label style"display: flex;align-items: center;width: fit-content;" click"handleCheck(cxm4s)"><checkbox /><text>车信盟出险4S维保</text></label>…