NLP(19)--大模型发展(3)

前言

仅记录学习过程,有问题欢迎讨论

大模型训练相关知识:

问题:

  1. 数据集过大,快速训练
  2. 模型过大,gpu跑不完

方案:

  • 数据并行训练:
    复制数据(batch_size)到多个gpu,各自计算loss,再反传更新;至少需要一张卡能训练一个样本
  • 模型并行:
    模型的不同层放到不同的gpu上,解决了单卡不够大的问题,但是需要更多的通讯时间了(一次训练需要传播很多次)【时间换空间】
  • 张量并行:
    将张量划分到不同的gpu上(张量左右分割),进一步减少对单卡的需求,【时间换空间】
    可以混合使用

浮点精度损失:
float32:
使用32位二进制来表示一个浮点数。第一位用于表示符号位(正或负),
接下来的八位用于表示指数,剩下的23位用于表示尾数或分数部分。Float32的数值范围大约在±3.4E+38之间
0.2 = 0.00110(B)~、
DeepSpeed(零冗余优化器)

优化训练的方案

ZeRo:

  • 相当于张量并行,gradients,parameters,分散到不同gpu计算。速度变慢
    ZeRo -offload
    添加内存帮助显卡计算哈哈

大模型的蒸馏: 小模型直接学习文本数据之外,还学习大模型预测的结果
大模型的剪枝: 删除大模型中不重要的部分,减少模型大小

PEFT微调:(当大模型对某个方面表现不好)

  • 预训练模型+微调数据集
    原始权重不动,增加一部分可训练的权重,去结合之前的部分

prompt tuning:

  • 预训练模型+提示词(提前告诉模型需要训练什么)(添加在input data)

P-tuning V2:训练虚拟token(添加在emb)

  • 该方法将 Prompt 转换为可以学习的 Embedding 层,并用MLP+LSTM的方式来对Prompt Embedding进行一层处理。

Adapter: 新增可训练的层(添加在fft后)

LoRa(目前比较流行):

  • 通过低秩分解来模拟参数的改变量,从而以极小的参数量来实现大模型的间接训练(最后再扩大)。

RAG(Retrieve Augmented Generation):

  • 模型幻觉问题,遇见不知道的问题,会乱答
    召回部分和Prompt相似的段落,重新输入模型,获取可靠答案。

优势(主要就这两条):
(1)可扩展性:减少模型大小和训练成本,并能够快速扩展知识。

(2)准确性:模型基于事实进行回答,减少幻觉的发生。

难点
(1)数据标注:需要对数据进行标注,以提供给模型进行检索。

(2)检索算法:需要设计高效的检索算法,以提高检索的效率和准确率。
在这里插入图片描述

BPE(Byte pair encoding):压缩算法

  • RAG针对词表vocab的以下问题
    1.测试过程出现词表没有的词
    2.词表过大
    3.不同语种,切分粒度不同

  • 如aaabdaaaabc ==> XdXac
    在nlp中,针对语料中出现的重复词,添加到vocab中,再切分,再添加

  • BPE通过构建跨语言共享的子词词汇表,提高了模型处理多种语言的能力,有效解决了不同语言间词汇差异大、低频词和OOV问题,增强了模型的泛化能力和翻译性能。

代码:

使用lora对大模型进行微调:
使用的是序列标注的ner代码
main.py

# -*- coding: utf-8 -*-import torch
import os
import random
import os
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data
from peft import get_peft_model, LoraConfig, TaskTypelogging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""def peft_wrapper(model):peft_config = LoraConfig(r=8,lora_alpha=32,lora_dropout=0.1,target_modules=["query", "value"])return get_peft_model(model, peft_config)def main(config):# 创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])# 加载训练数据train_data = load_data(config["train_data_path"], config)# 加载模型model = TorchModel(config)model = peft_wrapper(model)# 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()# 加载优化器optimizer = choose_optimizer(config, model)# 加载效果测试类evaluator = Evaluator(config, model, logger)# 训练for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data  # 输入变化时这里需要修改,比如多输入,多输出的情况loss = model(input_id, labels)loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)# torch.save(model.state_dict(), model_path)return model, train_dataif __name__ == "__main__":model, train_data = main(Config)

model.py

# -*- coding: utf-8 -*-import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
from transformers import BertModel"""
建立网络模型结构
"""class ConfigWrapper(object):def __init__(self, config):self.config = configdef to_dict(self):return self.configclass TorchModel(nn.Module):def __init__(self, config):super(TorchModel, self).__init__()self.config = ConfigWrapper(config)max_length = config["max_length"]class_num = config["class_num"]# self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)# self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)self.bert = BertModel.from_pretrained(config["bert_path"], return_dict=False)self.classify = nn.Linear(self.bert.config.hidden_size, class_num)self.crf_layer = CRF(class_num, batch_first=True)self.use_crf = config["use_crf"]self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  # loss采用交叉熵损失# 当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, target=None):# x = self.embedding(x)  #input shape:(batch_size, sen_len)# x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)x, _ = self.bert(x)predict = self.classify(x)  # ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)if target is not None:if self.use_crf:mask = target.gt(-1)return - self.crf_layer(predict, target, mask, reduction="mean")else:# (number, class_num), (number)return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))else:if self.use_crf:return self.crf_layer.decode(predict)else:return predictdef choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)if __name__ == "__main__":from config import Configmodel = TorchModel(Config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/15756.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[杂项]优化AMD显卡对DX9游戏(天谕)的支持

目录 关键词平台说明背景RDNA 1、2、3 架构的显卡支持游戏一、 优化方法1.1 下载 二、 举个栗子(以《天谕》为例)2.1 下载微星 afterburner 软件 查看游戏内信息(可跳过)2.2 查看D3D9 帧数2.3 关闭游戏,替换 dll 文件2…

精品PPT | MES设计与实践,业务+架构+实施(免费下载))

【1】关注本公众号,转发当前文章到微信朋友圈 【2】私信发送 MES设计与实践 【3】获取本方案PDF下载链接,直接下载即可。 如需下载本方案PPT/WORD原格式,请加入微信扫描以下方案驿站知识星球,获取上万份PPT/WORD解决方案&#x…

【探索数据结构】线性表之单链表

🎉🎉🎉欢迎莅临我的博客空间,我是池央,一个对C和数据结构怀有无限热忱的探索者。🙌 🌸🌸🌸这里是我分享C/C编程、数据结构应用的乐园✨ 🎈🎈&…

Autodl服务器中Faster-rcnn(jwyang)复现(一)

前言 在做实验时需要用到faster-rcnn做对比,本节首先完成代码复现,用的数据集是VOC2007~ 项目地址:https://github.com/jwyang/faster-rcnn.pytorch/tree/pytorch-1.0 复现环境:autodl服务器+python3.6+cuda11.3+Ubuntu20.04+Pytorch1.10.0 目录 一、环境配置二、编译cud…

2024年软考总结 信息系统管理师

选择题 英文题,我是一题也没把握,虽然我理解意思。 千万不要认为考死记硬背不对。目的不在于这。工程项目中有很多重要的数字,能记住说明你合格。 案例 几乎把答案全写在案例中了。 计算题 今年最简单。没有考成本。 只考了关键路径&a…

安卓开发:相机水印设置

1.更新水印 DecimalFormat DF new DecimalFormat("#"); DecimalFormat DF1 new DecimalFormat("#.#");LocationManager LM (LocationManager)getSystemService(Context.LOCATION_SERVICE); LM.requestLocationUpdates(LocationManager.GPS_PROVIDER, 2…

【学习笔记】计算机组成原理(七)

指令系统 文章目录 指令系统7.1 机器指令7.1.1 指令的一般格式7.1.2 指令字长 7.2 操作数类型和操作类型7.2.1 操作数类型7.2.2 数据在存储器中的存放方式7.2.3 操作类型 7.3 寻址方式7.3.1 指令寻址7.3.1.1 顺序寻址7.3.1.2 跳跃寻址 7.3.2 数据寻址7.3.2.1 立即寻址7.3.2.2 直…

精品PPT | 精益生产管理中MES系统的实现与应用(免费下载)

【1】关注本公众号,转发当前文章到微信朋友圈 【2】私信发送 MES系统的实现与应用 【3】获取本方案PDF下载链接,直接下载即可。 如需下载本方案PPT/WORD原格式,请加入微信扫描以下方案驿站知识星球,获取上万份PPT/WORD解决方案&…

Redis - 缓存场景

学习资料 学习的黑马程序员哔站项目黑马点评,用作记录和探究原理。 Redis缓存 缓存 :就是数据交换的缓冲区,是存储数据的临时地方,读写性能较高 缓存常见的场景: 数据库查询加速:通过将频繁查询的数据缓存起来&…

嵩山为什么称为三水之源

三水指黄河、淮河、济河,这三条河流环绕在嵩山周边。 黄河横亘在嵩山北部,其支流伊洛河从西南方环绕嵩山,然后汇入黄河。济河,古称济水,源自济源王屋山,自身河道在东晋时代被黄河夺占,从此消失。…

毕设 大数据校园卡数据分析

文章目录 0 前言1 课题介绍2 数据预处理2.1 数据清洗2.2 数据规约 3 模型建立和分析3.1 不同专业、性别的学生与消费能力的关系3.2 消费时间的特征分析 4 Web系统效果展示5 最后 0 前言 🔥 这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设…

【vue2配置】Vue Router

Vue Router官网 1、npm install vue-router4 2、创建模块,在src目录小创/views/map/MapIndex.vue模块和创router/index.js文件 3、在router/index.js配置路由 import Vue from "vue"; import Router from "vue-router"; // 引入模块 const Ma…

智慧校园建设的进阶之路

智慧校园的建设现已到达了老练的阶段,许多学校设备充满着数字化信息,进出宿舍楼,校园一卡通体系会记载下学生信息,外来人员闯入会报警,翻开电脑就能查到学生是否在宿舍等……学生的学习和日子都充满了数字化的痕迹。但…

光速入门python的OpenCV

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文整理python的OpenCV模块的关键知识点 争取用最短的时间入门OpenCV 并且做到笔记功能直接复制使用 OpenCV简介 不浪费时间的介绍: 就是类似于ps操作图片。 至于为什么不直接用ps,因为只有程序能…

【找出满足差值条件的下标 I】python

目录 暴力题解 优化:滑动窗口维护大小值 暴力题解 class Solution:def findIndices(self, nums: List[int], indexDifference: int, valueDifference: int) -> List[int]:nlen(nums)for i in range(n):for j in range(n-1,-1,-1):if abs(i-j)>indexDiffere…

海康威视NVR通过ehome协议接入视频监控平台,视频浏览显示3011超时错误的问题解决,即:The request timeout! 【3011】

目录 一、问题描述 二、问题分析 2.1 初步分析 2.2 查看日志 2.3 问题验证 1、查看防火墙 2、查看安全组 3、问题原因 三、问题解决 3.1 防火墙开放相关端口 3.2 安全组增加规则 3.3 测试 1、TCP端口能够联通的情况 2、TCP端口不能够联通的情况 四、验证 五、云…

「51媒体」如何与媒体建立良好关系?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 与媒体建立良好关系对于企业或个人来说都是一项重要的公关活动。 了解媒体:研究媒体和记者的兴趣,提供相关且有价值的信息。 建立联系:通过专业的方式…

牛客NC324 下一个更大的数(三)【中等 双指针 Java/Go/PHP/C++】参考lintcode 52 · 下一个排列

题目 题目链接: https://www.nowcoder.com/practice/475da0d4e37a481bacf9a09b5a059199 思路 第一步:获取数字上每一个数,组成数组arr 第二步:利用“下一个排列” 问题解题方法来继续作答,步骤:利用lintc…

C++进阶之路:何为拷贝构造函数,深入理解浅拷贝与深拷贝(类与对象_中篇)

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨ 🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢,在这里我会分享我的知识和经验。&am…

PostgreSQL基础(三):PostgreSQL的基础操作

文章目录 PostgreSQL的基础操作 一、用户操作 二、权限操作 三、操作任务