微调Helsinki-NLP-en-zh模型

Helsinki-NLP 是一个广泛使用的开源机器翻译(Machine Translation,MT)模型系列,基于 Marian NMT 框架

Hugggingface地址:https://huggingface.co/Helsinki-NLP/opus-mt-en-zh

原本的模型对于国内外公司的名称支持度很差,比如会把‘FireFox‘翻译成‘消防’,所以我需要在保留原本翻译能力的基础上,增强对公司名称的翻译能力。

1 数据集准备

我使用GPT-4这类大模型为我生成了500条公司名称中英文对,原本是.xlsx格式的文件,将其合并转为.tsv
在这里插入图片描述

2 冻结参数

为了保留原来的翻译能力,我们需要冻结法大部分模型的参数,只解冻少量参数用于训练,最大程度的不影响翻译能力。

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import os# 配置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载本地模型和分词器
model_path = "/kaggle/input/helsinki-nlp-en-zh/pytorch/default/1/Helsinki-NLP-en-zh"  # 替换为本地模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)  # 自动检测并加载模型结构和权重# 将模型移动到设备
model = model.to(device)print("模型和分词器加载成功!")# 冻结编码器和解码器的低层参数,只解冻解码器高层和输出层
for name, param in model.named_parameters():if ("decoder.layers.5" in name or "lm_head" in name):param.requires_grad = Trueelse:param.requires_grad = False# 定义自定义数据集
class CompanyNameDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.data = pd.read_csv(file_path, sep='\t', header=None, names=['source', 'target'])self.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.data)def __getitem__(self, idx):source = self.data.iloc[idx, 0]target = self.data.iloc[idx, 1]tokenized_data = tokenizer([source],                      # 输入源文本text_target=[target],          # 输入目标文本max_length=128,                # 设置最大长度padding="max_length",          # 填充到最大长度truncation=True,               # 截断到最大长度return_tensors="pt"            # 返回 PyTorch 张量)return {"input_ids": tokenized_data["input_ids"].squeeze(0),"attention_mask": tokenized_data["attention_mask"].squeeze(0),"labels": tokenized_data["labels"].squeeze(0),}

在冻结之前,可以通过下列代码查看模型的结构,会打印所有参数的名称,用以决定冻结哪些参数:

for name, param in model.named_parameters():print(name)

当运模型加载完毕后,我们再运行下面的代码查看解冻的参数:

for name, param in model.named_parameters():if param.requires_grad:print(f"解冻参数: {name}")

运行结果:

解冻参数: model.decoder.layers.5.self_attn.k_proj.weight
解冻参数: model.decoder.layers.5.self_attn.k_proj.bias
解冻参数: model.decoder.layers.5.self_attn.v_proj.weight
解冻参数: model.decoder.layers.5.self_attn.v_proj.bias
解冻参数: model.decoder.layers.5.self_attn.q_proj.weight
解冻参数: model.decoder.layers.5.self_attn.q_proj.bias
解冻参数: model.decoder.layers.5.self_attn.out_proj.weight
解冻参数: model.decoder.layers.5.self_attn.out_proj.bias
解冻参数: model.decoder.layers.5.self_attn_layer_norm.weight
解冻参数: model.decoder.layers.5.self_attn_layer_norm.bias
解冻参数: model.decoder.layers.5.encoder_attn.k_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.k_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn.v_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.v_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn.q_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.q_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn.out_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.out_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn_layer_norm.weight
解冻参数: model.decoder.layers.5.encoder_attn_layer_norm.bias
解冻参数: model.decoder.layers.5.fc1.weight
解冻参数: model.decoder.layers.5.fc1.bias
解冻参数: model.decoder.layers.5.fc2.weight
解冻参数: model.decoder.layers.5.fc2.bias
解冻参数: model.decoder.layers.5.final_layer_norm.weight
解冻参数: model.decoder.layers.5.final_layer_norm.bias

我们选择把decoder的第5层解冻(即最靠近输出层的那一层),这样可以避免影响原始的翻译能力。

3 加载数据

# 加载数据
file_path = '/kaggle/input/company-logo-name-tsv/company_names.tsv'
dataset = CompanyNameDataset(file_path, tokenizer)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

4 训练模型

# 定义训练循环
def train_epoch(model, dataloader, optimizer, criterion, device):model.train()total_loss = 0for batch in dataloader:input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)labels = batch["labels"].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(dataloader)def evaluate_epoch(model, dataloader, criterion, device):model.eval()total_loss = 0with torch.no_grad():for batch in dataloader:input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)labels = batch["labels"].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstotal_loss += loss.item()return total_loss / len(dataloader)# 定义 EarlyStopping 类
class EarlyStopping:def __init__(self, patience=3, verbose=False, delta=0):"""Args:patience (int): 等待验证损失改进的轮数verbose (bool): 是否打印详细信息delta (float): 最小的改进幅度"""self.patience = patienceself.verbose = verboseself.counter = 0self.best_loss = Noneself.early_stop = Falseself.delta = deltadef __call__(self, val_loss, model):if self.best_loss is None:self.best_loss = val_losselif val_loss > self.best_loss - self.delta:self.counter += 1if self.verbose:print(f"EarlyStopping counter: {self.counter} out of {self.patience}")if self.counter >= self.patience:self.early_stop = Trueelse:self.best_loss = val_lossself.counter = 0  # 重置等待计数器# 初始化 EarlyStopping
early_stopping = EarlyStopping(patience=3, verbose=True)  # 等待3轮验证损失无改进# 开始训练
num_epochs = 20
for epoch in range(num_epochs):train_loss = train_epoch(model, train_loader, optimizer, criterion, device)val_loss = evaluate_epoch(model, val_loader, criterion, device)print(f"Epoch {epoch+1}/{num_epochs}")print(f"Train Loss: {train_loss:.4f}")print(f"Validation Loss: {val_loss:.4f}")# 调用 EarlyStoppingearly_stopping(val_loss, model)if early_stopping.early_stop:print("Early stopping triggered. Training stopped.")break# 保存微调后的模型
torch.save(model.state_dict(), "./fine_tuned_marianmt.pth")

使用早停法,避免过拟合。

5 合并参数

原模型中的目录结如下:
在这里插入图片描述
其中pytorch_model.bin就是原模型的权重,现在我们要把fine_tuned_marianmt.pth加载进去:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer# 原始模型和分词器路径
original_model_path = "./original_model"  # 替换为原始模型的文件夹路径
fine_tuned_weights = "./fine_tuned_marianmt.pth"  # 微调后的权重路径
fine_tuned_model_path = "./fine_tuned_model"  # 微调后的模型保存路径# 加载原始模型架构
model = AutoModelForSeq2SeqLM.from_pretrained(original_model_path)# 加载微调后的权重
state_dict = torch.load(fine_tuned_weights)
model.load_state_dict(state_dict)# 保存微调后的模型到新目录
model.save_pretrained(fine_tuned_model_path)# 保存分词器到新目录(分词器未变化,可直接复制原始分词器配置)
tokenizer = AutoTokenizer.from_pretrained(original_model_path)
tokenizer.save_pretrained(fine_tuned_model_path)print(f"微调后的模型和分词器已保存到: {fine_tuned_model_path}")

微调后的参数就加载完毕了:
在这里插入图片描述
可以正确翻译了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/60235.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT基本绘图

QT绘图 1.概述 这篇文章介绍如何绘图 2.绘图基本操作 创建一个普通的widget类型的项目 在widget.h 文件中重写绘图事件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : p…

自动化立体仓库:详解

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》人俱乐部 完整版文件和更多学习资料&#xff0c;请球友到知识星球【智能仓储物流技术研习社】自行下载。 自动化立体仓库&#xff08;Automated S…

Hash table类算法【leetcode】

哈希表中关键码就是数组的索引下标&#xff0c;然后通过下标直接访问数组中的元素 那么哈希表能解决什么问题呢&#xff0c;一般哈希表都是用来快速判断一个元素是否出现集合里。 例如要查询一个名字是否在这所学校里。 要枚举的话时间复杂度是O(n)&#xff0c;但如果使用哈希…

window 中安装 php 环境

window 中安装 php 环境 一、准备二、下载三、安装四、测试 一、准备 安装前需要安装 Apache &#xff0c;可以查看这篇博客。 二、下载 先到这里下载 这里选择版本为“VS16 x64 Thread Safe”&#xff0c;这个版本不要选择线程安全的&#xff0c;我试过&#xff0c;会缺少文…

嵌入式Linux学习之Linux基础再过部分——文件IO(1)

目录 先来看看Linux是如何操作文件IO的 文件描述符 打开文件open pathname flags mode 返回值 write 参数详解 返回值 在哪里你能使用write flags read 返回值 flags close lseek whence 参数常量 返回值 示例 1 示例 2 demo3 深入探究文件IO Linux 系统…

C# 高级--反射 详解

一、反射是什么 1、C#编译运行过程 高级语言->编译->dll/exe文件->CLR/JIT->机器码 2、原理解析metadata&#xff1a;元数据数据清单&#xff0c;记录了dll中包含了哪些东西,是一个描述。IL&#xff1a;中间语言&#xff0c;编译把高级语言编译后得到的C#中最真…

【Web前端】Web API:构建Web应用核心

什么是 API API&#xff08;应用程序编程接口&#xff09;是一组定义了软件组件之间如何交互的规则和协议。它允许一个程序调用另一个程序的功能&#xff0c;而不用了解其内部实现细节。 Web 开发中&#xff0c;API 通常用于实现前端与后端之间的通信。 客户端 JavaScript 中的…

Telegram bot Mini-App开发实践---Telegram简单介绍与初始化小程序获取window.Telegram.WebApp对象并解析

➡️【好看的灵魂千篇一律,有趣的鲲志一百六七!】- 欢迎认识我~~ 作者:鲲志说 (公众号、B站同名,视频号:鲲志说996) 科技博主:极星会 星辉大使 后端研发:java、go、python、TS,前电商、现web3 主理人:COC杭州开发者社区主理人 、周周黑客松杭州主理人、 AI爱好…

VRT: 关于视频修复的模型

VRT: 关于视频修复的模型 1. 视频修复的背景与重要性背景介绍&#xff1a;重要性&#xff1a; 2. VRT的重要性和研究背景VRT的背景&#xff1a;VRT的重要性&#xff1a; 3. 视频修复概述3.1 定义与目标3.2 与单图像修复的区别3.3 对时间信息利用的需求 4. VRT模型详解4.1 整体框…

游戏引擎学习第17天

视频参考:https://www.bilibili.com/video/BV1LPUpYJEXE/ 回顾上一天的内容 1. 整体目标&#xff1a; 处理键盘输入&#xff1a;将键盘输入的处理逻辑从平台特定的代码中分离出来&#xff0c;放入更独立的函数中以便管理。优化消息循环&#xff1a;确保消息循环能够有效处理 …

jmeter常用配置元件介绍总结之配置元件

系列文章目录 1.windows、linux安装jmeter及设置中文显示 2.jmeter常用配置元件介绍总结之安装插件 3.jmeter常用配置元件介绍总结之线程组 4.jmeter常用配置元件介绍总结之函数助手 5.jmeter常用配置元件介绍总结之取样器 6.jmeter常用配置元件介绍总结之jsr223执行pytho…

Java基础知识(五)

文章目录 ObjectObject 类的常见方法有哪些&#xff1f; 和 equals() 的区别hashCode() 有什么用&#xff1f;为什么要有 hashCode&#xff1f;为什么重写 equals() 时必须重写 hashCode() 方法&#xff1f; 参考链接 Object Object 类的常见方法有哪些&#xff1f; Object 类…

【大模型】LLaMA: Open and Efficient Foundation Language Models

链接&#xff1a;https://arxiv.org/pdf/2302.13971 论文&#xff1a;LLaMA: Open and Efficient Foundation Language Models Introduction 规模和效果 7B to 65B&#xff0c;LLaMA-13B 超过 GPT-3 (175B)Motivation 如何最好地缩放特定训练计算预算的数据集和模型大小&…

2024 RISC-V中国峰会 安全相关议题汇总

安全之安全(security)博客目录导读 第四届 RISC-V 中国峰会(RISC-V Summit China 2024)于8月21日至23日在杭州成功举办。此次峰会汇聚了 RISC-V 国际基金会、百余家重点企业及研究机构,约3000人线下参与,并在19日至25日间举办了超过20场同期活动,与全球开发者共同…

Pyhon基础数据结构(列表)【蓝桥杯】

a [1,2,3,4,5] a.reverse() print("a ",a) a.reverse() print("a ",a)# 列表 列表&#xff08;list&#xff09;有由一系列按照特定顺序排序的元素组成 列表是有顺序的&#xff0c;访问任何元素需要通过“下标访问” 所谓“下标”就是指元素在列表从左…

【Visual Studio系列教程】如何在 VS 上编程?

上一篇博客中&#xff0c;我们介绍了《什么是 Visual Studio&#xff1f;》。本文&#xff0c;我们来看第2篇《如何在 VS 上编程&#xff1f;》。阅读本文大约10 分钟。我们会向文件中添加代码&#xff0c;了解 Visual Studio 编写、导航和了解代码的简便方法。 本文假定&…

MySQL更换瀚高语法更换

MySQL更换瀚高语法更换 一、前言二、语句 一、前言 水一篇,mysql更换瀚高之后&#xff0c;一些需要更换的语法介绍 > 二、语句 MySQL瀚高MySQL用法瀚高用法说明ifnull(x,y)coalesce(x,y)相同相同用于检查两个表达式并返回第一个非空表达式。如果第一个表达式不是 NULL&…

论文阅读——Intrusion detection systems using longshort‑term memory (LSTM)

一.基本信息 论文名称&#xff1a;Intrusion detection systems using longshort‑term memory (LSTM) 中文翻译&#xff1a;基于长短期记忆(LSTM)的入侵检测系统 DOI&#xff1a;10.1186/s40537-021-00448-4 作者&#xff1a;FatimaEzzahra Laghrissi1* , Samira Douzi2*, Kha…

大数据挖掘期末复习

大数据挖掘 数据挖掘 数据挖掘定义 技术层面&#xff1a; 数据挖掘就是从大量的、不完全的、有噪声的、模糊的、随机的实际应用数据中&#xff0c;提取隐含在其中、人们事先不知道的、但又潜在有用的信息的过程。 数据准备环节 数据选择 质量分析 数据预处理 数据仓库 …

等精度频率计的设计

目录 主控电路设计 频率测量与计算电路设计 顶层电路设计 功能扩展及应用 频率测量的三种方法 等精度频率计通过控制闸门信号与被测信号同步&#xff0c;消除了直接测频法中的计数误差&#xff0c;因而在被测信号频率范围内测量精度基本上是恒定的。 本节以设计能够测量信号…