Contents
- Introduction
- Project Structure
- Deployment
  - Clone the Project
  - Install Dependencies
  - Train the Model
  - Start the Web UI
- Demo
  - Web UI
  - Learning Rate Schedule
  - Training History
- Code
  - Config
  - Dataset
  - Model
  - Utils
  - Train
  - Web UI
Introduction
This project is a Transformer-based Chinese couplet generator: the model is built with PyTorch and the Web UI with Gradio.
Dataset: https://www.kaggle.com/datasets/marquis03/chinese-couplets-dataset
GitHub repository: https://github.com/Marquis03/Chinese-Couplets-Generator-based-on-Transformer
Gitee repository: https://gitee.com/marquis03/Chinese-Couplets-Generator-based-on-Transformer
Project Structure
.
├── config
│ ├── __init__.py
│ └── config.py
├── data
│ ├── fixed_couplets_in.txt
│ └── fixed_couplets_out.txt
├── dataset
│ ├── __init__.py
│ └── dataset.py
├── img
│ ├── history.png
│ ├── lr_schedule.png
│ └── webui.gif
├── model
│ ├── __init__.py
│ └── model.py
├── trained
│ ├── vocab.pkl
│ └── CoupletsTransformer_best.pth
├── utils
│ ├── __init__.py
│ └── EarlyStopping.py
├── LICENSE
├── README.md
├── requirements.txt
├── train.py
└── webui.py
Deployment
Clone the Project
git clone https://github.com/Marquis03/Chinese-Couplets-Generator-based-on-Transformer.git
cd Chinese-Couplets-Generator-based-on-Transformer
Install Dependencies
pip install -r requirements.txt
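The exact version pins live in requirements.txt; judging from the imports used across the project, it needs at least the following packages (this list is an assumption based on the code, not the file's verbatim contents):

torch
gradio
loguru
tqdm
numpy
pandas
seaborn
matplotlib
scikit-learn
joblib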
Train the Model
python train.py
Kaggle Notebook: https://www.kaggle.com/code/marquis03/chinese-couplets-generator-based-on-transformer
Start the Web UI
python webui.py
Demo
Web UI
(Demo animation: img/webui.gif)
Learning Rate Schedule
(Plot: img/lr_schedule.png)
Training History
(Plot: img/history.png)
Code
Config
This section defines the project's configuration parameters: global, path, model, training, and logging settings.
The corresponding project file is config/config.py.
import os
import sys
import time
import torch
from loguru import logger


class Config:
    def __init__(self):
        # global
        self.seed = 0
        self.cuDNN = True
        self.debug = False
        self.num_workers = 0
        self.str_time = time.strftime("%Y-%m-%dT%H%M%S", time.localtime(time.time()))

        # path
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.dataset_dir = os.path.join(self.project_dir, "data")
        self.in_path = os.path.join(self.dataset_dir, "fixed_couplets_in.txt")
        self.out_path = os.path.join(self.dataset_dir, "fixed_couplets_out.txt")
        self.log_dir = os.path.join(self.project_dir, "logs")
        self.save_dir = os.path.join(self.log_dir, self.str_time)
        self.img_save_dir = os.path.join(self.save_dir, "images")
        self.model_save_dir = os.path.join(self.save_dir, "checkpoints")
        for path in (
            self.log_dir,
            self.save_dir,
            self.img_save_dir,
            self.model_save_dir,
        ):
            if not os.path.exists(path):
                os.makedirs(path)

        # model
        self.d_model = 256
        self.num_head = 8
        self.num_encoder_layers = 2
        self.num_decoder_layers = 2
        self.dim_feedforward = 1024
        self.dropout = 0.1

        # train
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 128
        self.val_ratio = 0.1
        self.epochs = 20
        self.warmup_ratio = 0.12
        self.lr_max = 1e-3
        self.lr_min = 1e-4
        self.beta1 = 0.9
        self.beta2 = 0.98
        self.epsilon = 10e-9
        self.weight_decay = 0.01
        self.early_stop = True
        self.patience = 4
        self.delta = 0

        # log
        logger.remove()
        level_std = "DEBUG" if self.debug else "INFO"
        logger.add(
            sys.stdout,
            colorize=True,
            format="[<green>{time:YYYY-MM-DD HH:mm:ss,SSS}</green>|<level>{level: <8}</level>|<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>] >>> <level>{message}</level>",
            level=level_std,
        )
        logger.add(
            os.path.join(self.save_dir, f"{self.str_time}.log"),
            format="[{time:YYYY-MM-DD HH:mm:ss,SSS}|{level: <8}|{name}:{function}:{line}] >>> {message}",
            level="INFO",
        )
        logger.info("### Config:")
        for key, value in self.__dict__.items():
            logger.info(f"### {key:20} = {value}")
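A minimal usage sketch: constructing Config has side effects (it creates the timestamped log directories and immediately logs every parameter), so the rest of the project only ever reads its attributes.

# Minimal sketch: Config holds parameters but also sets up directories
# and logging on construction.
from config import Config

config = Config()             # creates logs/<timestamp>/{images,checkpoints}
print(config.device)          # cuda if available, else cpu
print(config.model_save_dir)  # .../logs/<timestamp>/checkpoints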
Dataset
This section defines the vocabulary, the dataset, and related helpers: loading the data, building the vocabulary, wrapping the dataset, and creating the data loader.
The corresponding project file is dataset/dataset.py.
from collections import Counter

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


def load_data(filepaths, tokenizer=lambda s: s.strip().split()):
    raw_in_iter = iter(open(filepaths[0], encoding="utf8"))
    raw_out_iter = iter(open(filepaths[1], encoding="utf8"))
    return list(zip(map(tokenizer, raw_in_iter), map(tokenizer, raw_out_iter)))


class Vocab(object):
    UNK = "<unk>"  # 0
    PAD = "<pad>"  # 1
    BOS = "<bos>"  # 2
    EOS = "<eos>"  # 3

    def __init__(self, data=None, min_freq=1):
        counter = Counter()
        for lines in data:
            counter.update(lines[0])
            counter.update(lines[1])
        self.word2idx = {Vocab.UNK: 0, Vocab.PAD: 1, Vocab.BOS: 2, Vocab.EOS: 3}
        self.idx2word = {0: Vocab.UNK, 1: Vocab.PAD, 2: Vocab.BOS, 3: Vocab.EOS}
        idx = 4
        for word, freq in counter.items():
            if freq >= min_freq:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1

    def __len__(self):
        return len(self.word2idx)

    def __getitem__(self, word):
        return self.word2idx.get(word, 0)

    def __call__(self, word):
        if not isinstance(word, (list, tuple)):
            return self[word]
        return [self[w] for w in word]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple, np.ndarray, torch.Tensor)):
            return self.idx2word[int(indices)]
        return [self.idx2word[int(i)] for i in indices]


def pad_sequence(sequences, batch_first=False, padding_value=0):
    max_len = max([s.size(0) for s in sequences])
    out_tensors = []
    for tensor in sequences:
        padding_content = [padding_value] * (max_len - tensor.size(0))
        tensor = torch.cat([tensor, torch.tensor(padding_content)], dim=0)
        out_tensors.append(tensor)
    out_tensors = torch.stack(out_tensors, dim=1)
    if batch_first:
        out_tensors = out_tensors.transpose(0, 1)
    return out_tensors.long()


class CoupletsDataset(Dataset):
    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab
        self.PAD_IDX = self.vocab[self.vocab.PAD]
        self.BOS_IDX = self.vocab[self.vocab.BOS]
        self.EOS_IDX = self.vocab[self.vocab.EOS]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raw_in, raw_out = self.data[index]
        in_tensor_ = torch.LongTensor(self.vocab(raw_in))
        out_tensor_ = torch.LongTensor(self.vocab(raw_out))
        return in_tensor_, out_tensor_

    def collate_fn(self, batch):
        in_batch, out_batch = [], []
        for in_, out_ in batch:
            in_batch.append(in_)
            out_ = torch.cat(
                [
                    torch.LongTensor([self.BOS_IDX]),
                    out_,
                    torch.LongTensor([self.EOS_IDX]),
                ],
                dim=0,
            )
            out_batch.append(out_)
        in_batch = pad_sequence(in_batch, True, self.PAD_IDX)
        out_batch = pad_sequence(out_batch, True, self.PAD_IDX)
        return in_batch, out_batch

    def get_loader(self, batch_size, shuffle=False, num_workers=0):
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )
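A small round-trip sketch (the couplet pair here is invented for illustration) showing how Vocab maps characters to indices and back, and how collate_fn adds <bos>/<eos> and pads the batch; it assumes the classes and imports defined above:

# Toy example: each sample is a (tokenized upper line, tokenized lower
# line) pair, matching what load_data returns.
toy_data = [(list("春回大地"), list("福满人间"))]
vocab = Vocab(toy_data)

ids = vocab(list("春回大地"))  # [4, 5, 6, 7]; unseen characters map to 0 (<unk>)
print(vocab.to_tokens(ids))    # ['春', '回', '大', '地']

dataset = CoupletsDataset(toy_data, vocab)
src, tgt = next(iter(dataset.get_loader(batch_size=1)))
print(src.shape, tgt.shape)    # [1, 4] and [1, 6]: tgt gains <bos> and <eos>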
Model
This section defines the model: TokenEmbedding, PositionalEncoding, and CoupletsTransformer.
The corresponding project file is model/model.py.
import math
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        return self.embedding(tokens) * math.sqrt(self.emb_size)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class CoupletsTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        super(CoupletsTransformer, self).__init__()
        self.name = "CoupletsTransformer"
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_embedding = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size)
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, padding_value=0):
        src_embed = self.token_embedding(src)  # [batch_size, src_len, embed_dim]
        src_embed = self.pos_embedding(src_embed)  # [batch_size, src_len, embed_dim]
        tgt_embed = self.token_embedding(tgt)  # [batch_size, tgt_len, embed_dim]
        tgt_embed = self.pos_embedding(tgt_embed)  # [batch_size, tgt_len, embed_dim]
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(-1)).to(
            tgt.device
        )
        src_key_padding_mask = src == padding_value  # [batch_size, src_len]
        tgt_key_padding_mask = tgt == padding_value  # [batch_size, tgt_len]
        outs = self.transformer(
            src=src_embed,
            tgt=tgt_embed,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # [batch_size, tgt_len, embed_dim]
        logits = self.fc(outs)  # [batch_size, tgt_len, vocab_size]
        return logits

    def encoder(self, src):
        src_embed = self.token_embedding(src)
        src_embed = self.pos_embedding(src_embed)
        memory = self.transformer.encoder(src_embed)
        return memory

    def decoder(self, tgt, memory):
        tgt_embed = self.token_embedding(tgt)
        tgt_embed = self.pos_embedding(tgt_embed)
        outs = self.transformer.decoder(tgt_embed, memory=memory)
        return outs

    def generate(self, text, vocab):
        self.eval()
        device = next(self.parameters()).device
        max_len = len(text)
        src = torch.LongTensor(vocab(list(text))).unsqueeze(0).to(device)
        memory = self.encoder(src)
        l_out = [vocab.BOS]
        for i in range(max_len):
            tgt = torch.LongTensor(vocab(l_out)).unsqueeze(0).to(device)
            outs = self.decoder(tgt, memory)
            prob = self.fc(outs[:, -1, :])
            next_token = vocab.to_tokens(prob.argmax(1).item())
            if next_token == vocab.EOS:
                break
            l_out.append(next_token)
        return "".join(l_out[1:])
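A quick shape check with dummy tensors (the hyperparameters match the project's config; the batch size, lengths, and vocab size are arbitrary), assuming the classes and imports defined above:

# Dummy shape check: batch of 2, src length 7, tgt length 8.
model = CoupletsTransformer(
    vocab_size=1000,
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=1024,
    dropout=0.1,
)
src = torch.randint(4, 1000, (2, 7))  # indices >= 4 avoid the special tokens
tgt = torch.randint(4, 1000, (2, 8))
logits = model(src, tgt, padding_value=1)  # <pad> is index 1 in Vocab
print(logits.shape)  # torch.Size([2, 8, 1000])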
Utils
This section defines utility classes, namely EarlyStopping.
The corresponding project file is utils/EarlyStopping.py.
class EarlyStopping(object):
    def __init__(self, patience=7, delta=0):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
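A usage sketch with fabricated loss values, mirroring how train_model calls it below: the validation loss stops improving after epoch 3, so with patience=4 the stop flag trips four epochs later.

# Usage sketch with fabricated losses; `model` is accepted but unused
# by this implementation, so None is passed here.
early_stopping = EarlyStopping(patience=4, delta=0)
for epoch, val_loss in enumerate([1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.74], start=1):
    early_stopping(val_loss, model=None)
    if early_stopping.early_stop:
        print(f"Early stopping at epoch {epoch}")  # prints at epoch 7
        break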
Train
This section defines the training pipeline: training, validation, and saving the model.
The corresponding project file is train.py.
import os
import gc
import time
import math
import random
import joblib
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

sns.set_theme(
    style="darkgrid", font_scale=1.2, font="SimHei", rc={"axes.unicode_minus": False}
)

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR

from config import Config
from model import CoupletsTransformer
from dataset import load_data, Vocab, CoupletsDataset
from utils import EarlyStopping


def train_model(
    config, model, train_loader, val_loader, optimizer, criterion, scheduler
):
    model = model.to(config.device)
    best_loss = float("inf")
    history = []
    model_path = os.path.join(config.model_save_dir, f"{model.name}_best.pth")
    if config.early_stop:
        early_stopping = EarlyStopping(patience=config.patience, delta=config.delta)
    for epoch in tqdm(range(1, config.epochs + 1), desc=f"All"):
        train_loss = train_one_epoch(
            config, model, train_loader, optimizer, criterion, scheduler
        )
        val_loss = evaluate(config, model, val_loader, criterion)
        perplexity = math.exp(val_loss)
        history.append((epoch, train_loss, val_loss))
        msg = f"Epoch {epoch}/{config.epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, Perplexity: {perplexity:.4f}"
        logger.info(msg)
        if val_loss < best_loss:
            logger.info(
                f"Val loss decrease from {best_loss:>10.6f} to {val_loss:>10.6f}"
            )
            torch.save(model.state_dict(), model_path)
            best_loss = val_loss
        if config.early_stop:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                logger.info(f"Early stopping at epoch {epoch}")
                break
    logger.info(f"Save best model with val loss {best_loss:.6f} to {model_path}")
    model_path = os.path.join(config.model_save_dir, f"{model.name}_last.pth")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Save last model with val loss {val_loss:.6f} to {model_path}")
    history = pd.DataFrame(
        history, columns=["Epoch", "Train Loss", "Val Loss"]
    ).set_index("Epoch")
    history.plot(
        subplots=True, layout=(1, 2), sharey="row", figsize=(14, 6), marker="o", lw=2
    )
    history_path = os.path.join(config.img_save_dir, "history.png")
    plt.savefig(history_path, dpi=300)
    logger.info(f"Save history to {history_path}")


def train_one_epoch(config, model, train_loader, optimizer, criterion, scheduler):
    model.train()
    train_loss = 0
    for src, tgt in tqdm(train_loader, desc=f"Epoch", leave=False):
        src, tgt = src.to(config.device), tgt.to(config.device)
        output = model(src, tgt[:, :-1], config.PAD_IDX)
        output = output.contiguous().view(-1, output.size(-1))
        tgt = tgt[:, 1:].contiguous().view(-1)
        loss = criterion(output, tgt)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    return train_loss / len(train_loader)


def evaluate(config, model, val_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"Val", leave=False):
            src, tgt = src.to(config.device), tgt.to(config.device)
            output = model(src, tgt[:, :-1], config.PAD_IDX)
            output = output.contiguous().view(-1, output.size(-1))
            tgt = tgt[:, 1:].contiguous().view(-1)
            loss = criterion(output, tgt)
            val_loss += loss.item()
    return val_loss / len(val_loader)


def test_model(model, data, vocab):
    model.eval()
    for src_text, tgt_text in data:
        src_text, tgt_text = "".join(src_text), "".join(tgt_text)
        out_text = model.generate(src_text, vocab)
        logger.info(f"\nInput: {src_text}\nTarget: {tgt_text}\nOutput: {out_text}")


def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    config = Config()

    # Set random seed
    seed_everything(config.seed)
    logger.info(f"Set random seed to {config.seed}")

    # Set cuDNN
    if config.cuDNN:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    # Load data
    data = load_data([config.in_path, config.out_path])
    if config.debug:
        data = data[:1000]
    logger.info(f"Load {len(data)} couplets")

    # Build vocab
    vocab = Vocab(data)
    vocab_size = len(vocab)
    logger.info(f"Build vocab with {vocab_size} tokens")
    vocab_path = os.path.join(config.model_save_dir, "vocab.pkl")
    joblib.dump(vocab, vocab_path)
    logger.info(f"Save vocab to {vocab_path}")

    # Build dataset
    data_train, data_val = train_test_split(
        data, test_size=config.val_ratio, random_state=config.seed, shuffle=True
    )
    train_dataset = CoupletsDataset(data_train, vocab)
    val_dataset = CoupletsDataset(data_val, vocab)
    config.PAD_IDX = train_dataset.PAD_IDX
    logger.info(f"Build train dataset with {len(train_dataset)} samples")
    logger.info(f"Build val dataset with {len(val_dataset)} samples")

    # Build dataloader
    train_loader = train_dataset.get_loader(
        config.batch_size, shuffle=True, num_workers=config.num_workers
    )
    val_loader = val_dataset.get_loader(
        config.batch_size, shuffle=False, num_workers=config.num_workers
    )
    logger.info(f"Build train dataloader with {len(train_loader)} batches")
    logger.info(f"Build val dataloader with {len(val_loader)} batches")

    # Build model
    model = CoupletsTransformer(
        vocab_size=vocab_size,
        d_model=config.d_model,
        nhead=config.num_head,
        num_encoder_layers=config.num_encoder_layers,
        num_decoder_layers=config.num_decoder_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
    )
    logger.info(f"Build model with {model.name}")

    # Build optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=1,
        betas=(config.beta1, config.beta2),
        eps=config.epsilon,
        weight_decay=config.weight_decay,
    )

    # Build criterion
    criterion = nn.CrossEntropyLoss(ignore_index=config.PAD_IDX, reduction="mean")

    # Build scheduler
    lr_max, lr_min = config.lr_max, config.lr_min
    T_max = config.epochs * len(train_loader)
    warm_up_iter = int(T_max * config.warmup_ratio)

    def WarmupExponentialLR(cur_iter):
        gamma = math.exp(math.log(lr_min / lr_max) / (T_max - warm_up_iter))
        if cur_iter < warm_up_iter:
            return (lr_max - lr_min) * (cur_iter / warm_up_iter) + lr_min
        else:
            return lr_max * gamma ** (cur_iter - warm_up_iter)

    scheduler = LambdaLR(optimizer, lr_lambda=WarmupExponentialLR)

    df_lr = pd.DataFrame(
        [WarmupExponentialLR(i) for i in range(T_max)],
        columns=["Learning Rate"],
    )
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=df_lr, linewidth=2)
    plt.title("Learning Rate Schedule")
    plt.xlabel("Iteration")
    plt.ylabel("Learning Rate")
    lr_img_path = os.path.join(config.img_save_dir, "lr_schedule.png")
    plt.savefig(lr_img_path, dpi=300)
    logger.info(f"Save learning rate schedule to {lr_img_path}")

    # Garbage collect
    gc.collect()
    torch.cuda.empty_cache()

    # Train model
    train_model(
        config, model, train_loader, val_loader, optimizer, criterion, scheduler
    )

    # Test model
    test_model(model, data_val[:10], vocab)


if __name__ == "__main__":
    main()
Web UI
This section defines the Web UI: the input and output components and launching the app.
The corresponding project file is webui.py.
import random
import joblib

import torch
import gradio as gr

from dataset import Vocab
from model import CoupletsTransformer

data_path = "./data/fixed_couplets_in.txt"
vocab_path = "./trained/vocab.pkl"
model_path = "./trained/CoupletsTransformer_best.pth"

vocab = joblib.load(vocab_path)
vocab_size = len(vocab)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CoupletsTransformer(
    vocab_size,
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=1024,
    dropout=0.1,
).to(device)
model.load_state_dict(torch.load(model_path))
model.eval()

example = (
    line.replace(" ", "").strip() for line in iter(open(data_path, encoding="utf8"))
)
example = [line for line in example if len(line) > 5]
example = random.sample(example, 300)


def generate_couplet(vocab, model, src_text):
    if not src_text:
        return "上联不能为空"
    out_text = model.generate(src_text, vocab)
    return out_text


input_text = gr.Textbox(
    label="上联",
    placeholder="在这里输入上联",
    max_lines=1,
    lines=1,
    show_copy_button=True,
    autofocus=True,
)
output_text = gr.Textbox(
    label="下联",
    placeholder="在这里生成下联",
    max_lines=1,
    lines=1,
    show_copy_button=True,
)

demo = gr.Interface(
    fn=lambda x: generate_couplet(vocab, model, x),
    inputs=input_text,
    outputs=output_text,
    title="中文对联生成器",
    description="输入上联,生成下联",
    allow_flagging="never",
    submit_btn="生成下联",
    clear_btn="清空",
    examples=example,
    examples_per_page=50,
)

demo.launch()
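By default demo.launch() serves the app locally (Gradio's default address is http://127.0.0.1:7860). If you need a temporary public link, Gradio's launch accepts share=True; this is a standard Gradio option, not something the project enables itself:

# Optional: create a temporary public URL instead of a local-only server.
demo.launch(share=True)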