微调Helsinki-NLP-en-zh模型

Helsinki-NLP 是一个广泛使用的开源机器翻译(Machine Translation,MT)模型系列,基于 Marian NMT 框架

Hugggingface地址:https://huggingface.co/Helsinki-NLP/opus-mt-en-zh

原本的模型对于国内外公司的名称支持度很差,比如会把‘FireFox‘翻译成‘消防’,所以我需要在保留原本翻译能力的基础上,增强对公司名称的翻译能力。

1 数据集准备

我使用GPT-4这类大模型为我生成了500条公司名称中英文对,原本是.xlsx格式的文件,将其合并转为.tsv
在这里插入图片描述

2 冻结参数

为了保留原来的翻译能力,我们需要冻结法大部分模型的参数,只解冻少量参数用于训练,最大程度的不影响翻译能力。

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
import os# 配置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载本地模型和分词器
model_path = "/kaggle/input/helsinki-nlp-en-zh/pytorch/default/1/Helsinki-NLP-en-zh"  # 替换为本地模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)  # 自动检测并加载模型结构和权重# 将模型移动到设备
model = model.to(device)print("模型和分词器加载成功!")# 冻结编码器和解码器的低层参数,只解冻解码器高层和输出层
for name, param in model.named_parameters():if ("decoder.layers.5" in name or "lm_head" in name):param.requires_grad = Trueelse:param.requires_grad = False# 定义自定义数据集
class CompanyNameDataset(Dataset):def __init__(self, file_path, tokenizer, max_length=128):self.data = pd.read_csv(file_path, sep='\t', header=None, names=['source', 'target'])self.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.data)def __getitem__(self, idx):source = self.data.iloc[idx, 0]target = self.data.iloc[idx, 1]tokenized_data = tokenizer([source],                      # 输入源文本text_target=[target],          # 输入目标文本max_length=128,                # 设置最大长度padding="max_length",          # 填充到最大长度truncation=True,               # 截断到最大长度return_tensors="pt"            # 返回 PyTorch 张量)return {"input_ids": tokenized_data["input_ids"].squeeze(0),"attention_mask": tokenized_data["attention_mask"].squeeze(0),"labels": tokenized_data["labels"].squeeze(0),}

在冻结之前,可以通过下列代码查看模型的结构,会打印所有参数的名称,用以决定冻结哪些参数:

for name, param in model.named_parameters():print(name)

当运模型加载完毕后,我们再运行下面的代码查看解冻的参数:

for name, param in model.named_parameters():if param.requires_grad:print(f"解冻参数: {name}")

运行结果:

解冻参数: model.decoder.layers.5.self_attn.k_proj.weight
解冻参数: model.decoder.layers.5.self_attn.k_proj.bias
解冻参数: model.decoder.layers.5.self_attn.v_proj.weight
解冻参数: model.decoder.layers.5.self_attn.v_proj.bias
解冻参数: model.decoder.layers.5.self_attn.q_proj.weight
解冻参数: model.decoder.layers.5.self_attn.q_proj.bias
解冻参数: model.decoder.layers.5.self_attn.out_proj.weight
解冻参数: model.decoder.layers.5.self_attn.out_proj.bias
解冻参数: model.decoder.layers.5.self_attn_layer_norm.weight
解冻参数: model.decoder.layers.5.self_attn_layer_norm.bias
解冻参数: model.decoder.layers.5.encoder_attn.k_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.k_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn.v_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.v_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn.q_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.q_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn.out_proj.weight
解冻参数: model.decoder.layers.5.encoder_attn.out_proj.bias
解冻参数: model.decoder.layers.5.encoder_attn_layer_norm.weight
解冻参数: model.decoder.layers.5.encoder_attn_layer_norm.bias
解冻参数: model.decoder.layers.5.fc1.weight
解冻参数: model.decoder.layers.5.fc1.bias
解冻参数: model.decoder.layers.5.fc2.weight
解冻参数: model.decoder.layers.5.fc2.bias
解冻参数: model.decoder.layers.5.final_layer_norm.weight
解冻参数: model.decoder.layers.5.final_layer_norm.bias

我们选择把decoder的第5层解冻(即最靠近输出层的那一层),这样可以避免影响原始的翻译能力。

3 加载数据

# 加载数据
file_path = '/kaggle/input/company-logo-name-tsv/company_names.tsv'
dataset = CompanyNameDataset(file_path, tokenizer)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

4 训练模型

# 定义训练循环
def train_epoch(model, dataloader, optimizer, criterion, device):model.train()total_loss = 0for batch in dataloader:input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)labels = batch["labels"].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossoptimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()return total_loss / len(dataloader)def evaluate_epoch(model, dataloader, criterion, device):model.eval()total_loss = 0with torch.no_grad():for batch in dataloader:input_ids = batch["input_ids"].to(device)attention_mask = batch["attention_mask"].to(device)labels = batch["labels"].to(device)outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstotal_loss += loss.item()return total_loss / len(dataloader)# 定义 EarlyStopping 类
class EarlyStopping:def __init__(self, patience=3, verbose=False, delta=0):"""Args:patience (int): 等待验证损失改进的轮数verbose (bool): 是否打印详细信息delta (float): 最小的改进幅度"""self.patience = patienceself.verbose = verboseself.counter = 0self.best_loss = Noneself.early_stop = Falseself.delta = deltadef __call__(self, val_loss, model):if self.best_loss is None:self.best_loss = val_losselif val_loss > self.best_loss - self.delta:self.counter += 1if self.verbose:print(f"EarlyStopping counter: {self.counter} out of {self.patience}")if self.counter >= self.patience:self.early_stop = Trueelse:self.best_loss = val_lossself.counter = 0  # 重置等待计数器# 初始化 EarlyStopping
early_stopping = EarlyStopping(patience=3, verbose=True)  # 等待3轮验证损失无改进# 开始训练
num_epochs = 20
for epoch in range(num_epochs):train_loss = train_epoch(model, train_loader, optimizer, criterion, device)val_loss = evaluate_epoch(model, val_loader, criterion, device)print(f"Epoch {epoch+1}/{num_epochs}")print(f"Train Loss: {train_loss:.4f}")print(f"Validation Loss: {val_loss:.4f}")# 调用 EarlyStoppingearly_stopping(val_loss, model)if early_stopping.early_stop:print("Early stopping triggered. Training stopped.")break# 保存微调后的模型
torch.save(model.state_dict(), "./fine_tuned_marianmt.pth")

使用早停法,避免过拟合。

5 合并参数

原模型中的目录结如下:
在这里插入图片描述
其中pytorch_model.bin就是原模型的权重,现在我们要把fine_tuned_marianmt.pth加载进去:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer# 原始模型和分词器路径
original_model_path = "./original_model"  # 替换为原始模型的文件夹路径
fine_tuned_weights = "./fine_tuned_marianmt.pth"  # 微调后的权重路径
fine_tuned_model_path = "./fine_tuned_model"  # 微调后的模型保存路径# 加载原始模型架构
model = AutoModelForSeq2SeqLM.from_pretrained(original_model_path)# 加载微调后的权重
state_dict = torch.load(fine_tuned_weights)
model.load_state_dict(state_dict)# 保存微调后的模型到新目录
model.save_pretrained(fine_tuned_model_path)# 保存分词器到新目录(分词器未变化,可直接复制原始分词器配置)
tokenizer = AutoTokenizer.from_pretrained(original_model_path)
tokenizer.save_pretrained(fine_tuned_model_path)print(f"微调后的模型和分词器已保存到: {fine_tuned_model_path}")

微调后的参数就加载完毕了:
在这里插入图片描述
可以正确翻译了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/60235.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++中的初始化列表

初始化参数列表 用于在构造函数中初始化类的数据成员。 语法:构造函数():属性1(值1),属性2(值2)......{ } 性质: 1.只能在构造函数中使用 2.引用 或 常量…

QT基本绘图

QT绘图 1.概述 这篇文章介绍如何绘图 2.绘图基本操作 创建一个普通的widget类型的项目 在widget.h 文件中重写绘图事件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : p…

【IEEE独立出版 |往届均已成功检索】第八届大数据与应用统计国际学术研讨会(ISBDAS 2025)

重要信息 时间地点&#xff1a;2025年2月28日-3月2日 中国 广州 会议检索&#xff1a;EI Compendex, Scopus →点此投稿/参会/了解会议详情 组织单位 主办单位&#xff1a;广东省高等教育学会人工智能与高等教育研究分会 协办单位&#xff1a;北京师范大学人工智能与未…

C# 中的 LINQ:轻松处理集合和数据

C#中的LINQ&#xff08;Language Integrated Query&#xff09;&#xff0c;这是一个非常强大且实用的功能&#xff0c;可以简化集合操作和数据查询。以下是一篇关于C#中LINQ使用的文章。 引言 LINQ&#xff08;Language Integrated Query&#xff09;是C#语言的一个重要特性…

自动化立体仓库:详解

导语 大家好&#xff0c;我是社长&#xff0c;老K。专注分享智能制造和智能仓储物流等内容。 新书《智能物流系统构成与技术实践》人俱乐部 完整版文件和更多学习资料&#xff0c;请球友到知识星球【智能仓储物流技术研习社】自行下载。 自动化立体仓库&#xff08;Automated S…

Hash table类算法【leetcode】

哈希表中关键码就是数组的索引下标&#xff0c;然后通过下标直接访问数组中的元素 那么哈希表能解决什么问题呢&#xff0c;一般哈希表都是用来快速判断一个元素是否出现集合里。 例如要查询一个名字是否在这所学校里。 要枚举的话时间复杂度是O(n)&#xff0c;但如果使用哈希…

window 中安装 php 环境

window 中安装 php 环境 一、准备二、下载三、安装四、测试 一、准备 安装前需要安装 Apache &#xff0c;可以查看这篇博客。 二、下载 先到这里下载 这里选择版本为“VS16 x64 Thread Safe”&#xff0c;这个版本不要选择线程安全的&#xff0c;我试过&#xff0c;会缺少文…

Kubernetes部署Grafana详细教程

1. 概述 Grafana是一个强大的开源监控和可视化工具,可以帮助我们更好地理解和分析系统性能数据。在Kubernetes环境中部署Grafana可以让我们更方便地监控集群和应用的状态。 2. 准备工作 一个正常运行的Kubernetes集群kubectl命令行工具,已配置可以访问您的集群集群中已创建名…

嵌入式Linux学习之Linux基础再过部分——文件IO(1)

目录 先来看看Linux是如何操作文件IO的 文件描述符 打开文件open pathname flags mode 返回值 write 参数详解 返回值 在哪里你能使用write flags read 返回值 flags close lseek whence 参数常量 返回值 示例 1 示例 2 demo3 深入探究文件IO Linux 系统…

LeetCode 209 长度最小的子数组(滑动窗口)

跟着carl学算法&#xff0c;本系列博客仅做个人记录&#xff0c;建议大家都去看carl本人的博客&#xff0c;写的真的很好的&#xff01; 代码随想录 LeetCode 209 长度最小的子数组(滑动窗口) 给定一个含有 n 个正整数的数组和一个正整数 target 。 找出该数组中满足其总和大于…

C# 高级--反射 详解

一、反射是什么 1、C#编译运行过程 高级语言->编译->dll/exe文件->CLR/JIT->机器码 2、原理解析metadata&#xff1a;元数据数据清单&#xff0c;记录了dll中包含了哪些东西,是一个描述。IL&#xff1a;中间语言&#xff0c;编译把高级语言编译后得到的C#中最真…

【Web前端】Web API:构建Web应用核心

什么是 API API&#xff08;应用程序编程接口&#xff09;是一组定义了软件组件之间如何交互的规则和协议。它允许一个程序调用另一个程序的功能&#xff0c;而不用了解其内部实现细节。 Web 开发中&#xff0c;API 通常用于实现前端与后端之间的通信。 客户端 JavaScript 中的…

Telegram bot Mini-App开发实践---Telegram简单介绍与初始化小程序获取window.Telegram.WebApp对象并解析

➡️【好看的灵魂千篇一律,有趣的鲲志一百六七!】- 欢迎认识我~~ 作者:鲲志说 (公众号、B站同名,视频号:鲲志说996) 科技博主:极星会 星辉大使 后端研发:java、go、python、TS,前电商、现web3 主理人:COC杭州开发者社区主理人 、周周黑客松杭州主理人、 AI爱好…

VRT: 关于视频修复的模型

VRT: 关于视频修复的模型 1. 视频修复的背景与重要性背景介绍&#xff1a;重要性&#xff1a; 2. VRT的重要性和研究背景VRT的背景&#xff1a;VRT的重要性&#xff1a; 3. 视频修复概述3.1 定义与目标3.2 与单图像修复的区别3.3 对时间信息利用的需求 4. VRT模型详解4.1 整体框…

【linux】ubuntu下常用快捷键【笔记】

环境 硬件&#xff1a;通用PC 系统&#xff1a;Ubuntu 20.04 软件 &#xff1a; 打开终端窗口&#xff1a;Ctrl Alt T 关闭当前窗口&#xff1a;Alt F4 改变窗口大小&#xff1a;Alt F8 移动窗口&#xff1a; Alt F7 配合 “←”、“→”、“↑”、“↓”来移动窗口 …

java 增强型for循环 详解

Java 增强型 for 循环&#xff08;Enhanced for Loop&#xff09;详解 增强型 for 循环&#xff08;也称为 “for-each” 循环&#xff09;是 Java 从 JDK 5 开始引入的一种便捷循环语法&#xff0c;旨在简化对数组或集合类的迭代操作。 1. 基本语法 语法格式 for (类型 变量…

游戏引擎学习第17天

视频参考:https://www.bilibili.com/video/BV1LPUpYJEXE/ 回顾上一天的内容 1. 整体目标&#xff1a; 处理键盘输入&#xff1a;将键盘输入的处理逻辑从平台特定的代码中分离出来&#xff0c;放入更独立的函数中以便管理。优化消息循环&#xff1a;确保消息循环能够有效处理 …

jmeter常用配置元件介绍总结之配置元件

系列文章目录 1.windows、linux安装jmeter及设置中文显示 2.jmeter常用配置元件介绍总结之安装插件 3.jmeter常用配置元件介绍总结之线程组 4.jmeter常用配置元件介绍总结之函数助手 5.jmeter常用配置元件介绍总结之取样器 6.jmeter常用配置元件介绍总结之jsr223执行pytho…

python基础之学生成绩管理系统

声明&#xff1a;学习视频来自b站up主 泷羽sec&#xff0c;如涉及侵权马上删除文章 声明&#xff1a;本文主要用作技术分享&#xff0c;所有内容仅供参考。任何使用或依赖于本文信息所造成的法律后果均与本人无关。请读者自行判断风险&#xff0c;并遵循相关法律法规。 while…

关联度分析、灰色预测GM(1,1)、GM(1,1)残差模型——基于Python实现

关联度分析 import numpy as np import pandas as pd #关联度分析 #参考序列 Y_0[170,174,197,216.4,235.8] #被比较序列 Y_1[195.4,189.9,187.2,205,222.7] Y_2[308,310,295,346,367]#初始化序列 X_0np.array(Y_0)/Y_0[0] X_1np.array(Y_1)/Y_1[0] X_2np.array(Y_2)/Y_2[0]#计…