PyTorch Example: Writing Classical Chinese Poetry with a Transformer
1. Introduction
A long, long time ago, in the post TensorFlow2 学习——RNN生成古诗词 (CSDN blog), I already implemented poem writing with TensorFlow + RNN. This time, here is a PyTorch + Transformer example 😄. The data-processing logic is largely the same as in that post, so it is not repeated here. Kaggle Notebook: PyTorch示例-使用Transformer写古诗
2. Version Info
PyTorch: 2.1.2
Python: 3.10.13
3. Imports
import math
import numpy as np
from collections import Counter
import torch
from torch import nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import tqdm
import random
import sys

print("PyTorch version:", torch.__version__)
print("Python version:", sys.version)
PyTorch version: 2.1.2
Python version: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
4. Data and Preprocessing
Data download
Baidu Netdisk: https://pan.baidu.com/s/1HIROi4mPMv0RBWHIHvUDVg (extraction code: b2pp); Kaggle: https://www.kaggle.com/datasets/alionsss/poetry
A first look at the raw data
DATA_PATH = '/kaggle/input/poetry/poetry.txt'
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()

for i in range(0, 5):
    print(lines[i])
print(f"origin_line_count = {len(lines)}")
首春:寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
初晴落景:晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。
初夏:一朝春夏改,隔夜鸟花迁。阴阳深浅叶,晓夕重轻烟。哢莺犹响殿,横丝正网天。珮高兰影接,绶细草纹连。碧鳞惊棹侧,玄燕舞檐前。何必汾阳处,始复有山泉。
度秋:夏律昨留灰,秋箭今移晷。峨嵋岫初出,洞庭波渐起。桂白发幽岩,菊黄开灞涘。运流方可叹,含毫属微理。
仪鸾殿早秋:寒惊蓟门叶,秋发小山枝。松阴背日转,竹影避风移。提壶菊花岸,高兴芙蓉池。欲知凉气早,巢空燕不窥。
origin_line_count = 43030
Process the data and filter out malformed entries
MAX_LEN = 64
MIN_LEN = 5
DISALLOWED_WORDS = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']', '?', ';']

poetry = []
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()

for line in lines:
    fields = line.split(":")
    # Keep only lines of the form "title:content"
    if len(fields) != 2:
        continue
    content = fields[1]
    # Drop poems that are too long or too short
    if len(content) > MAX_LEN - 2 or len(content) < MIN_LEN:
        continue
    # Drop poems containing disallowed characters
    if any(word in content for word in DISALLOWED_WORDS):
        continue
    poetry.append(content.replace('\n', ''))

for i in range(0, 5):
    print(poetry[i])
print(f"current_line_count = {len(poetry)}")
寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。
夏律昨留灰,秋箭今移晷。峨嵋岫初出,洞庭波渐起。桂白发幽岩,菊黄开灞涘。运流方可叹,含毫属微理。
寒惊蓟门叶,秋发小山枝。松阴背日转,竹影避风移。提壶菊花岸,高兴芙蓉池。欲知凉气早,巢空燕不窥。
山亭秋色满,岩牖凉风度。疏兰尚染烟,残菊犹承露。古石衣新苔,新巢封古树。历览情无极,咫尺轮光暮。
current_line_count = 24375
Filter out characters that occur too rarely; they will all be treated as UNKNOWN later
MIN_WORD_FREQUENCY = 8

counter = Counter()
for line in poetry:
    counter.update(line)
# Keep only characters that appear at least MIN_WORD_FREQUENCY times
tokens = [token for token, count in counter.items() if count >= MIN_WORD_FREQUENCY]

# Peek at the first few character counts
for i, (token, count) in enumerate(counter.items()):
    print(token, "->", count)
    if i >= 4:
        break
寒 -> 2612
随 -> 1036
穷 -> 482
律 -> 118
变 -> 286
Define the vocabulary encoder: Tokenizer
class Tokenizer:
    """Vocabulary encoder"""
    UNKNOWN = "<unknown>"
    PAD = "<pad>"
    BOS = "<bos>"
    EOS = "<eos>"

    def __init__(self, tokens):
        tokens = [Tokenizer.UNKNOWN, Tokenizer.PAD, Tokenizer.BOS, Tokenizer.EOS] + tokens
        self.dict_size = len(tokens)
        self.token_id = {}
        self.id_token = {}
        for idx, word in enumerate(tokens):
            self.token_id[word] = idx
            self.id_token[idx] = word
        self.unknown_id = self.token_id[Tokenizer.UNKNOWN]
        self.pad_id = self.token_id[Tokenizer.PAD]
        self.bos_id = self.token_id[Tokenizer.BOS]
        self.eos_id = self.token_id[Tokenizer.EOS]

    def id_to_token(self, token_id):
        """id -> token"""
        return self.id_token.get(token_id)

    def token_to_id(self, token):
        """token -> id; falls back to UNKNOWN for out-of-vocabulary tokens"""
        return self.token_id.get(token, self.unknown_id)

    def encode(self, tokens):
        """token list -> <bos> id + id list + <eos> id"""
        token_ids = [self.bos_id, ]
        for token in tokens:
            token_ids.append(self.token_to_id(token))
        token_ids.append(self.eos_id)
        return token_ids

    def decode(self, token_ids):
        """id list -> token list (with <bos>/<eos> removed)"""
        tokens = []
        for idx in token_ids:
            if idx != self.bos_id and idx != self.eos_id:
                tokens.append(self.id_to_token(idx))
        return tokens

    def __len__(self):
        return self.dict_size
Define the dataset class MyDataset
class MyDataset(TensorDataset):
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, index):
        line = self.data[index]
        word_ids = self.encode_pad_line(line)
        return torch.LongTensor(word_ids)

    def __len__(self):
        return len(self.data)

    def encode_pad_line(self, line):
        # Encode the line and pad it to max_length with <pad>
        word_ids = self.tokenizer.encode(line)
        word_ids = word_ids + [self.tokenizer.pad_id] * (self.max_length - len(word_ids))
        return word_ids
A quick test of MyDataset, Tokenizer, and DataLoader
tokenizer = Tokenizer(tokens)
print("tokenizer_len: ", len(tokenizer))

my_dataset = MyDataset(poetry, tokenizer)
one_line_id = my_dataset[0].tolist()
print("one_line_id: ", one_line_id)

poetry_line = tokenizer.decode(one_line_id)
print("poetry_line: ", "".join([w for w in poetry_line if w != Tokenizer.PAD]))
tokenizer_len: 3428
one_line_id: [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 9, 21, 22, 23, 24, 25, 15, 26, 27, 28, 29, 30, 9, 31, 32, 33, 34, 35, 15, 36, 37, 16, 38, 39, 9, 40, 41, 42, 43, 44, 15, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
poetry_line: 寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
temp_dataloader = DataLoader(dataset=my_dataset, batch_size=8, shuffle=True)

one_batch_data = next(iter(temp_dataloader))
for poetry_line_id in one_batch_data.tolist():
    poetry_line = tokenizer.decode(poetry_line_id)
    print("".join([w for w in poetry_line if w != Tokenizer.PAD]))
曲江春草生,紫阁雪分明。汲井尝泉味,听钟问寺名。墨研秋日雨,茶试老僧<unknown>。地近劳频访,乌纱出送迎。
旧隐无何别,归来始更悲。难寻白道士,不见惠禅师。草径虫鸣急,沙渠水下迟。却将波浪眼,清晓对红梨。
举世皆问人,唯师独求己。一马无四蹄,顷刻行千里。应笑北原上,丘坟乱如蚁。
海燕西飞白日斜,天门遥望五侯家。楼台深锁无人到,落尽春风第一花。
良人犹远戍,耿耿夜闺空。绣户流宵月,罗帷坐晓风。魂飞沙帐北,肠断玉关中。尚自无消息,锦衾那得同。
天涯片云去,遥指帝乡忆。惆怅增暮情,潇湘复秋色。扁舟宿何处,落日羡归翼。万里无故人,江鸥不相识。
宝鸡辞旧役,仙凤历遗墟。去此近城阙,青山明月初。
夜帆时未发,同侣暗相催。山晓月初下,江鸣潮欲来。稍分扬子岸,不辨越王台。自客水乡里,舟行知几回。
5. Building the Model
Positional encoder: PositionalEncoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=2000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """x is the embedded input, e.g. shape (1, 7, 128): batch size 1, 7 tokens, embedding dim 128"""
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return self.dropout(x)
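For reference, the pe buffer built above implements the sinusoidal positional encoding from "Attention Is All You Need"; div_term is simply 10000^(-2i/d_model) computed in log space:

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$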
The poetry Transformer model
class PoetryModel(nn.Module):
    def __init__(self, num_embeddings=4096, embedding_dim=128):
        super(PoetryModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.transformer = nn.Transformer(d_model=embedding_dim, num_encoder_layers=3, num_decoder_layers=3,
                                          dim_feedforward=512)
        self.positional_encoding = PositionalEncoding(embedding_dim, dropout=0)
        self.predictor = nn.Linear(embedding_dim, num_embeddings)

    def forward(self, src, tgt):
        # Causal mask for the decoder, padding masks for both encoder and decoder inputs
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size()[-1]).to(DEVICE)
        src_key_padding_mask = PoetryModel.get_key_padding_mask(src).to(DEVICE)
        tgt_key_padding_mask = PoetryModel.get_key_padding_mask(tgt).to(DEVICE)

        src = self.embedding(src)
        tgt = self.embedding(tgt)
        src = self.positional_encoding(src)
        tgt = self.positional_encoding(tgt)

        # nn.Transformer defaults to batch_first=False, so permute to (seq_len, batch, dim)
        out = self.transformer(src.permute(1, 0, 2), tgt.permute(1, 0, 2),
                               tgt_mask=tgt_mask,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask)
        return out

    @staticmethod
    def get_key_padding_mask(tokens):
        # Mark <pad> positions with -inf so attention ignores them.
        # Compare against the pad id (tokens is a LongTensor of ids), not the string Tokenizer.PAD.
        key_padding_mask = torch.zeros(tokens.size())
        key_padding_mask[tokens == tokenizer.pad_id] = float('-inf')
        return key_padding_mask
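As a quick sanity check of the tensor shapes (this snippet is my own addition, not part of the original notebook), the forward pass can be exercised with dummy token ids. It assumes the tokenizer built in section 4; DEVICE is defined here only so the sketch runs on its own (it is defined again in the training section below).

# Hypothetical shape check, assuming `tokenizer` from section 4 exists
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
demo_model = PoetryModel(num_embeddings=len(tokenizer)).to(DEVICE)

demo_src = torch.randint(4, len(tokenizer), (8, 10)).to(DEVICE)   # (batch, src_len)
demo_tgt = torch.randint(4, len(tokenizer), (8, 20)).to(DEVICE)   # (batch, tgt_len)

out = demo_model(demo_src, demo_tgt)       # (tgt_len, batch, d_model), since batch_first=False
logits = demo_model.predictor(out)         # (tgt_len, batch, num_embeddings)
print(out.shape, logits.shape)             # expected: [20, 8, 128] and [20, 8, 3428]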
6. Training
EPOCH_NUM = 50
BATCH_SIZE = 64
DICT_SIZE = len(tokenizer)
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

my_dataset = MyDataset(poetry, tokenizer)
train_dataloader = DataLoader(dataset=my_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = PoetryModel(num_embeddings=DICT_SIZE).to(DEVICE)
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

print(model)
PoetryModel(
  (embedding): Embedding(3428, 128)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (multihead_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
          (dropout3): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
  (positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0, inplace=False)
  )
  (predictor): Linear(in_features=128, out_features=3428, bias=True)
)
for epoch in range(1, EPOCH_NUM + 1):
    model.train()
    total_loss = 0
    data_progress = tqdm.tqdm(train_dataloader, desc="Train...")
    for step, data in enumerate(data_progress, 1):
        data = data.to(DEVICE)

        # Split each batch at a random position e: the first e tokens feed the encoder
        e = random.randint(1, 20)
        src = data[:, :e]
        # tgt is the decoder input, tgt_y is tgt shifted right by one (the prediction target)
        tgt, tgt_y = data[:, e:-1], data[:, e + 1:]

        out = model(src, tgt)
        out = model.predictor(out)

        # out is (tgt_len, batch, dict_size); permute tgt_y to match before flattening
        loss = criteria(out.view(-1, out.size(-1)), tgt_y.permute(1, 0).contiguous().view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        data_progress.set_description(f"Train... [epoch {epoch}/{EPOCH_NUM}, loss {(total_loss / step):.5f}]")
Train... [epoch 1/50, loss 3.66444]: 100%|██████████| 381/381 [00:10<00:00, 35.40it/s]
Train... [epoch 2/50, loss 3.35216]: 100%|██████████| 381/381 [00:09<00:00, 39.61it/s]
Train... [epoch 3/50, loss 3.27860]: 100%|██████████| 381/381 [00:09<00:00, 39.44it/s]
Train... [epoch 4/50, loss 3.15286]: 100%|██████████| 381/381 [00:09<00:00, 39.10it/s]
Train... [epoch 5/50, loss 3.05621]: 100%|██████████| 381/381 [00:09<00:00, 39.32it/s]
Train... [epoch 6/50, loss 2.97613]: 100%|██████████| 381/381 [00:09<00:00, 39.42it/s]
Train... [epoch 7/50, loss 2.91857]: 100%|██████████| 381/381 [00:09<00:00, 38.83it/s]
Train... [epoch 8/50, loss 2.88052]: 100%|██████████| 381/381 [00:09<00:00, 39.59it/s]
Train... [epoch 9/50, loss 2.78789]: 100%|██████████| 381/381 [00:09<00:00, 39.19it/s]
Train... [epoch 10/50, loss 2.77379]: 100%|██████████| 381/381 [00:09<00:00, 38.24it/s]
......
Train... [epoch 41/50, loss 2.25991]: 100%|██████████| 381/381 [00:09<00:00, 39.89it/s]
Train... [epoch 42/50, loss 2.24437]: 100%|██████████| 381/381 [00:09<00:00, 39.72it/s]
Train... [epoch 43/50, loss 2.23779]: 100%|██████████| 381/381 [00:09<00:00, 39.09it/s]
Train... [epoch 44/50, loss 2.25092]: 100%|██████████| 381/381 [00:09<00:00, 39.16it/s]
Train... [epoch 45/50, loss 2.23653]: 100%|██████████| 381/381 [00:09<00:00, 39.90it/s]
Train... [epoch 46/50, loss 2.20175]: 100%|██████████| 381/381 [00:09<00:00, 39.51it/s]
Train... [epoch 47/50, loss 2.22046]: 100%|██████████| 381/381 [00:09<00:00, 39.83it/s]
Train... [epoch 48/50, loss 2.20892]: 100%|██████████| 381/381 [00:09<00:00, 39.84it/s]
Train... [epoch 49/50, loss 2.22276]: 100%|██████████| 381/381 [00:09<00:00, 39.35it/s]
Train... [epoch 50/50, loss 2.20212]: 100%|██████████| 381/381 [00:09<00:00, 39.75it/s]
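One detail of the training loop worth spelling out: each padded line is split at a random position e. The first e tokens become the encoder input src, the decoder input tgt runs from position e up to the second-to-last token, and the target tgt_y is tgt shifted right by one. A toy illustration with made-up ids:

# Hypothetical ids, only to illustrate the slicing used in the training loop above
data = torch.LongTensor([[2, 11, 12, 13, 14, 15, 16, 3]])  # <bos> x1 ... x6 <eos>
e = 3
src = data[:, :e]         # [[ 2, 11, 12]]         encoder input
tgt = data[:, e:-1]       # [[13, 14, 15, 16]]     decoder input
tgt_y = data[:, e + 1:]   # [[14, 15, 16,  3]]     target: tgt shifted right by one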
7. Inference
Direct (greedy) inference
model.eval()
with torch.no_grad():
    word_ids = tokenizer.encode("清明时节")
    # The encoder gets everything except the last prompt character; the decoder starts from that character
    src = torch.LongTensor([word_ids[:-2]]).to(DEVICE)
    tgt = torch.LongTensor([word_ids[-2:-1]]).to(DEVICE)
    for i in range(64):
        out = model(src, tgt)
        # Take the logits of the last position and greedily pick the most probable token
        predict = model.predictor(out[-1:])
        y = torch.argmax(predict, dim=2)
        tgt = torch.cat([tgt, y], dim=1)
        if y == tokenizer.eos_id:
            break

    src_decode = "".join([w for w in tokenizer.decode(src[0].tolist()) if w != Tokenizer.PAD])
    print(f"src = {src}, src_decode = {src_decode}")
    tgt_decode = "".join([w for w in tokenizer.decode(tgt[0].tolist()) if w != Tokenizer.PAD])
    print(f"tgt = {tgt}, tgt_decode = {tgt_decode}")
src = tensor([[ 2, 403, 235, 293]], device='cuda:0'), src_decode = 清明时
tgt = tensor([[ 197, 9, 571, 324, 571, 116, 14, 15, 61, 770, 158, 514, 934, 9, 228, 293, 493, 1108, 44, 15, 3]], device='cuda:0'), tgt_decode = 节,一夜一枝开。不是无人见,何时有鹤来。
Adding randomness to inference
def predict(model, src, tgt):
    out = model(src, tgt)
    # Logits of the last position; drop <unknown>/<pad>/<bos> (ids 0-2) so they can never be sampled
    _probas = model.predictor(out[-1:])[0, 0, 3:]
    # Softmax, then sample one token from the 10 most probable candidates
    _probas = torch.exp(_probas) / torch.exp(_probas).sum()
    values, indices = torch.topk(_probas, 10, dim=0)
    target_index = torch.multinomial(values, 1, replacement=True)
    y = indices[target_index]
    return y + 3  # shift back by the 3 special tokens that were dropped
def generate_random_poem(tokenizer, model, text):
    """Randomly generate a poem, or auto-continue the given opening text"""
    if text == None or text == "":
        # No opening given: start from a random character in the vocabulary
        text = tokenizer.id_to_token(random.randint(4, len(tokenizer) - 1))

    model.eval()
    with torch.no_grad():
        word_ids = tokenizer.encode(text)
        src = torch.LongTensor([word_ids[:-2]]).to(DEVICE)
        tgt = torch.LongTensor([word_ids[-2:-1]]).to(DEVICE)
        for i in range(64):
            y = predict(model, src, tgt)
            tgt = torch.cat([tgt, y.view(1, 1)], dim=1)
            if y == tokenizer.eos_id:
                break

        result = torch.cat([src, tgt], dim=1)
        result_decode = "".join([w for w in tokenizer.decode(result[0].tolist()) if w != Tokenizer.PAD])
        return result_decode


for i in range(0, 5):
    poetry_line = generate_random_poem(tokenizer, model, "清明")
    print(poetry_line)
清明日已长安,不独为君一病身。唯有诗人知处在,更愁人夜月明。
清明月在何时,夜久山川有谁。今日不知名利处,一枝花落第花枝。
清明月上,风急水声。山月随人远,天河度陇平。水深秋月在,江远夜砧迎。莫问东楼兴,空怀不可情。
清明夜夜月,秋月满池塘。夜坐中琴月,空阶下菊香。风回孤枕月,月冷一枝香。惆怅江南客,明朝是此中。
For the code to generate acrostic poems (藏头诗), please refer to my earlier article TensorFlow2 学习——RNN生成古诗词 (CSDN blog); a minimal adaptation for the model in this post is sketched below.
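If you want to try acrostic poems directly with the model from this post, here is a minimal sketch of my own (an illustrative adaptation, not the code from that article): it forces each line to begin with one of the given head characters and lets predict() sample the rest until a "。" is produced.

def generate_acrostic_poem(tokenizer, model, heads="日红山夜"):
    """Minimal acrostic sketch: each generated line starts with one character from `heads`."""
    model.eval()
    with torch.no_grad():
        result = []
        for head in heads:
            # Everything generated so far plus the forced head character becomes the new prompt
            context = "".join(result) + head
            word_ids = tokenizer.encode(context)
            src = torch.LongTensor([word_ids[:-2]]).to(DEVICE)
            tgt = torch.LongTensor([word_ids[-2:-1]]).to(DEVICE)
            line = head
            for _ in range(30):
                y = predict(model, src, tgt)
                tgt = torch.cat([tgt, y.view(1, 1)], dim=1)
                if y == tokenizer.eos_id:
                    break
                word = tokenizer.id_to_token(y.item())
                line += word
                if word == "。":  # stop this line at the end of a sentence
                    break
            result.append(line)
        return "".join(result)

print(generate_acrostic_poem(tokenizer, model, "日红山夜"))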
8. Further Learning Resources
Related articles
https://blog.csdn.net/zhaohongfei_358/article/details/126019181
https://blog.csdn.net/zhaohongfei_358/article/details/122861751
https://zhuanlan.zhihu.com/p/554013449
Related videos
https://www.bilibili.com/video/BV1Wv411h7kN?p=38
https://www.bilibili.com/video/BV1Wv411h7kN/?p=49
PyTorch official tutorial
https://pytorch.org/tutorials/beginner/transformer_tutorial.html