·请参考本系列目录:【mT5多语言翻译】之一——实战项目总览
[1] 模型翻译推理
在分别使用全量参数微调和PEFT微调训练完模型之后,我们来测试模型的翻译效果。推理代码如下:
# 导入模型
if conf.is_peft:model = AutoModelForSeq2SeqLM.from_pretrained(conf.peft_save)
else:model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)model.load_state_dict(torch.load(conf.save_path))
model.to(conf.device)
model.eval()
sentences = ["kor:我要去健身了","jpn:我要去健身了","kor:他说他会爱我一辈子","jpn:他说他会爱我一辈子",
]tokenizer = AutoTokenizer.from_pretrained(conf.pretrained_path)ids = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sentences,return_tensors='pt',padding='max_length',truncation=True,max_length=conf.max_seq_len,return_attention_mask=False
)
input_ids = ids['input_ids'].to(conf.device)output_tokens = model.generate(input_ids, num_beams=10, num_return_sequences=3)for token_set in output_tokens:print(tokenizer.decode(token_set, skip_special_tokens=True))
因为训练方式有全量参数微调和PEFT微调两种,不同方式保存的模型不同。前者是全量参数,后者是PEFT添加的少量参数。
【注】直接加载PEFT保存的少量参数,也可以加载到mT5模型本身的预训练参数。这是因为在peft模型保存的文件夹中有一个
adapter_config.json
文件,里面保存了基座模型的地址。
最终,可以观察到上述代码的输出为:
나는 피트니스에 가고 싶
나는 피트니스 클럽에 가
나는 피트니스 센터에 가
ジムに行きます。
ジムに行きたいです。
ジムに行くわ
그는 평생을 나를 사랑할
그는 평생 나를 사랑할 것
그는 평생 나를 사랑할 거
彼は私を愛してくれると言っていた。
彼は私を愛してくれると言った。
彼は私を愛してくれると言っていました。
[2] 第三方接口设计
我们把模型推理简单地设计成一个GET请求的接口,代码如下:
# coding: UTF-8
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertModel, T5Model
from conf import conf
from flask import Flask, request, jsonifyapp = Flask(__name__)# 导入模型
if conf.is_peft:model = AutoModelForSeq2SeqLM.from_pretrained(conf.peft_save)
else:model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)model.load_state_dict(torch.load(conf.save_path))
model.to(conf.device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(conf.pretrained_path)@app.route('/translate', methods=['GET'])
def translate():# 从GET请求中获取参数sentences = request.args.getlist('sentence')if not sentences:return jsonify({"error": "No sentences provided."}), 400# 对句子进行编码ids = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sentences,return_tensors='pt',padding='max_length',truncation=True,max_length=conf.max_seq_len,return_attention_mask=False)input_ids = ids['input_ids'].to(conf.device)# 生成翻译结果output_tokens = model.generate(input_ids, num_beams=10, num_return_sequences=3)# 解码翻译结果translations = [tokenizer.decode(token_set, skip_special_tokens=True) for token_set in output_tokens]# 返回结果return jsonify({"translations": translations})if __name__ == '__main__':app.run(host='0.0.0.0', port=5000)
然后就能去浏览器快乐地测试玩耍了。