LSTM 词语模型上的动态量化

原文链接 

(beta) Dynamic Quantization on an LSTM Word Language Model — PyTorch Tutorials 2.3.0+cu121 documentation

引言

量化涉及将模型的权重和激活值从浮点数转换为整数,这样可以缩小模型大小,加快推理速度,但对准确性的影响很小。
在本教程中,我们将把最简单的量化形式--动态量化--应用到基于 LSTM 的下一个单词预测模型中,这与 PyTorch 示例中的单词语言模型密切相关。

# imports
import os
from io import open
import timeimport torch
import torch.nn as nn
import torch.nn.functional as F

 定义模型

  在此,我们按照单词语言模型示例中的模型,定义 LSTM 模型架构。

class LSTMModel(nn.Module):"""Container module with an encoder, a recurrent module, and a decoder."""def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):super(LSTMModel, self).__init__()self.drop = nn.Dropout(dropout)self.encoder = nn.Embedding(ntoken, ninp)self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)self.decoder = nn.Linear(nhid, ntoken)self.init_weights()self.nhid = nhidself.nlayers = nlayersdef init_weights(self):initrange = 0.1self.encoder.weight.data.uniform_(-initrange, initrange)self.decoder.bias.data.zero_()self.decoder.weight.data.uniform_(-initrange, initrange)def forward(self, input, hidden):emb = self.drop(self.encoder(input))output, hidden = self.rnn(emb, hidden)output = self.drop(output)decoded = self.decoder(output)return decoded, hiddendef init_hidden(self, bsz):weight = next(self.parameters())return (weight.new_zeros(self.nlayers, bsz, self.nhid),weight.new_zeros(self.nlayers, bsz, self.nhid))

加载文本数据

 接下来,我们将 Wikitext-2 数据集加载到[Corpus]{.title-ref}中,同样按照单词语言模型示例进行预处理。

class Dictionary(object):def __init__(self):self.word2idx = {}self.idx2word = []def add_word(self, word):if word not in self.word2idx:self.idx2word.append(word)self.word2idx[word] = len(self.idx2word) - 1return self.word2idx[word]def __len__(self):return len(self.idx2word)class Corpus(object):def __init__(self, path):self.dictionary = Dictionary()self.train = self.tokenize(os.path.join(path, 'train.txt'))self.valid = self.tokenize(os.path.join(path, 'valid.txt'))self.test = self.tokenize(os.path.join(path, 'test.txt'))def tokenize(self, path):"""Tokenizes a text file."""print(path)assert os.path.exists(path), f"Error: The path {path} does not exist."# Add words to the dictionarywith open(path, 'r', encoding="utf8") as f:for line in f:words = line.split() + ['<eos>']for word in words:self.dictionary.add_word(word)# Tokenize file contentwith open(path, 'r', encoding="utf8") as f:idss = []for line in f:words = line.split() + ['<eos>']ids = []for word in words:ids.append(self.dictionary.word2idx[word])idss.append(torch.tensor(ids).type(torch.int64))ids = torch.cat(idss)return idsmodel_data_filepath = ".\data\\"corpus = Corpus(model_data_filepath + 'wikitext-2')

加载预训练模型

 这是一个关于动态量化的教程,一种在模型训练完成后应用的量化技术。因此,我们只需将一些预先训练好的权重加载到该模型架构中;这些权重是通过使用单词语言模型示例中的默认设置进行五次历时训练获得的。

ntokens = len(corpus.dictionary)model = LSTMModel(ntoken=ntokens,ninp=512,nhid=256,nlayers=5,
)# model.load_state_dict(
#     torch.load(
#         model_data_filepath + 'word_language_model_quantize.pth',
#         map_location=torch.device('cpu')
#     )
# )model.eval()
print(model)

现在让我们生成一些文本,以确保预训练模型正常工作 - 与之前类似,我们遵循此处

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000with open(model_data_filepath + 'out.txt', 'w') as outf:with torch.no_grad():  # no tracking historyfor i in range(num_words):output, hidden = model(input_, hidden)word_weights = output.squeeze().div(temperature).exp().cpu()word_idx = torch.multinomial(word_weights, 1)[0]input_.fill_(word_idx)word = corpus.dictionary.idx2word[word_idx]outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))if i % 100 == 0:print('| Generated {}/{} words'.format(i, 1000))with open(model_data_filepath + 'out.txt', 'r') as outf:all_output = outf.read()print(all_output)

虽然不是 GPT-2,但看起来模型已经开始学习语言结构了!
我们差不多可以演示动态量化了。我们只需要再定义几个辅助函数:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1# create test data set
def batchify(data, bsz):# Work out how cleanly we can divide the dataset into ``bsz`` parts.nbatch = data.size(0) // bsz# Trim off any extra elements that wouldn't cleanly fit (remainders).data = data.narrow(0, 0, nbatch * bsz)# Evenly divide the data across the ``bsz`` batches.return data.view(bsz, -1).t().contiguous()test_data = batchify(corpus.test, eval_batch_size)# Evaluation functions
def get_batch(source, i):seq_len = min(bptt, len(source) - 1 - i)data = source[i:i + seq_len]target = source[i + 1:i + 1 + seq_len].reshape(-1)return data, targetdef repackage_hidden(h):"""Wraps hidden states in new Tensors, to detach them from their history."""if isinstance(h, torch.Tensor):return h.detach()else:return tuple(repackage_hidden(v) for v in h)def evaluate(model_, data_source):# Turn on evaluation mode which disables dropout.model_.eval()total_loss = 0.hidden = model_.init_hidden(eval_batch_size)with torch.no_grad():for i in range(0, data_source.size(0) - 1, bptt):data, targets = get_batch(data_source, i)output, hidden = model_(data, hidden)hidden = repackage_hidden(hidden)output_flat = output.view(-1, ntokens)total_loss += len(data) * criterion(output_flat, targets).item()return total_loss / (len(data_source) - 1)

测试动态量化

最后,我们可以在模型上调用 torch.quantization.quantize_dynamic!具体来说就是
我们指定要对模型中的 nn.LSTM 和 nn.Linear 模块进行量化
我们指定要将权重转换为 int8 值

import torch.quantizationquantized_model = torch.quantization.quantize_dynamic(model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)# 模型看起来没有变化,这对我们有什么好处呢?首先,我们看到模型的尺寸大幅缩小:
def print_size_of_model(model):torch.save(model.state_dict(), "temp.p")print('Size (MB):', os.path.getsize("temp.p") / 1e6)os.remove('temp.p')print_size_of_model(model)
print_size_of_model(quantized_model)

其次,我们看到推理时间更快,而评估损失没有区别:
注:我们将单线程比较的线程数设为一个,因为量化模型是单线程运行的。

torch.set_num_threads(1)def time_model_evaluation(model, test_data):s = time.time()loss = evaluate(model, test_data)elapsed = time.time() - sprint('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

在本地 MacBook Pro 上运行这个程序,在不进行量化的情况下,推理时间约为 200 秒,而在进行量化的情况下,推理时间仅为 100 秒左右。

 结论

动态量化是减少模型大小的一种简单方法,但对准确性的影响有限。
感谢您的阅读!我们一如既往地欢迎任何反馈,如果您有任何问题,请在此创建一个问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/26994.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GitHub每日最火火火项目(6.14)

以下是按照要求对每个项目的总结&#xff1a; 项目名称&#xff1a;huggingface / diffusers 项目介绍&#xff1a;diffusers 是一个强大的工具库&#xff0c;专注于图像和音频生成的扩散模型。它提供了一系列预训练模型和便捷的接口&#xff0c;使开发者能够轻松地探索和利用…

vue2项目更换element-ui的主题色(绝对有效,操作简单)

vue2项目更换element-ui的主题色(绝对有效&#xff0c;操作简单) 前言&#xff1a;使用vue2element-ui开发web端项目的朋友应该会有修改element-ui主题色的需求&#xff0c;然而 网上几年前就各种传言element-ui不再维护了&#xff0c;官网显示的最后一次更新日期为2023-08-24…

leetcode打卡#day42 62. 不同路径、63. 不同路径 II、343. 整数拆分、96. 不同的二叉搜索树

62. 不同路径 class Solution { public://动态规划int uniquePaths(int m, int n) {//dp数组&#xff0c;记录到达目的地的路径数vector<vector<int>> dp(m, vector(n, 0));//初始化for(int i0; i< m; i) dp[i][0] 1;for(int i0; i< n; i) dp[0][i] 1;//遍…

【智能家居控制系统项目】一、项目系统镜像烧录与系统登录

前言 完成本章节将可以获得本项目的系统UI界面功能。本章节主要介绍如何烧录项目系统镜像以及进入系统。配套的视频介绍可以点击跳转到智能家居项目复刻配套视频 1.系统功能页面介绍 完成本章全部步骤&#xff0c;我们将可使用以下项目系统功能界面。 1.1 家居总览界面 主界面…

【成品设计】基于STM32的单相瞬时值反馈逆变器

《基于STM32的单相瞬时值反馈逆变器》 整体功能&#xff1a; 图13 软件框图 如图13所示&#xff0c;由于本设计中需要通过定时器中断执行一些程序&#xff0c;故首先对中断进行初始化。中断初始化以后即为对串口进行初始化&#xff0c;总共初始化了两个串口&#xff0c;第一个…

Ubuntu软件操作的相关命令

更新源 : sudo apt-get update 安装包 : sudo apt-get install package 删除包 : sudo apt-get remove package 搜索软件包 : sudo apt-cache search package 获取包的相关信息&#xff0c;如说明、⼤⼩、版本等 : sudo apt-cache show package 重新安装包 : sudo apt-get…

SQL SERVER触发器记录指定的几笔资料更新记录

1.目的 为了记录数据库表中资料数据动态的变更&#xff0c;实时动态且方便调整记录的范围。 2.分析 需求:记录UPDATE 修改的记录 if exists(select 1 from inserted) and exists(select 1 from deleted) &#xff1a;修改if (exists (select 1 from inserted) and n…

Unity 设置窗口置顶超级详解版

目录 前言 一、user32.dll 1.什么是user32.dll 2.如何使用user32.dll 二、句柄Handle 1.句柄 2.句柄的功能 3.拿句柄的方法 三、窗口置顶 1.窗口置顶的方法 2.参数说明 3.使用方法 四、作者的碎碎念 前言 up依旧挑战全网讲解最详细版本~~ 本篇文章讲解的是unity…

基于软件在环的飞控机建模仿真

安全关键系统&#xff08;Safety-Critical System&#xff0c;SCS&#xff09;是指由于某些行为或组合行为能够引发整体系统失效&#xff0c;继而导致财物损失、人员受伤等严重影响的系统&#xff0c;诸多安全关键领域如航空航天、核电系统、医疗设备、交通运输等领域的系统都属…

网络编程---Java飞机大战联机

解析服务器端代码 代码是放在app/lib下的src下的main/java&#xff0c;而与之前放在app/src/main下路径不同 Main函数 Main函数里只放着创建MyServer类的一行 public static void main(String args[]){new MyServer();} MyServer构造函数 1.获取本机IP地址 //获取本机IP地…

捋清UITableView展示不同类型数据的差异

背景&#xff1a; UITableView可以展示分组数据和单组数据&#xff0c;一般这两种数据有4种情况&#xff1a; 单组数据的简单类型&#xff0c;本身为字典数组&#xff0c;内部字典key对应的value全为基本数据类型。&#xff08;如lol英雄展示案例&#xff0c;不分组且组内信息…

一五零、MAC 安装mysql可视化工具连接

mysql安装&#xff0c;按照网上教程一步步安装&#xff08;官网下载安装包->解压->完成安装&#xff09;&#xff0c;最后在「系统偏好设置」无法启动mysql。 原因&#xff1a;下载的版本是8.0最新版本&#xff0c;MAC上这种方法无法启动成功。 解决方法 换低版本的mys…

如何利用 Go 高效地构建大规模并发网络应用?

要利用Go高效地构建大规模并发网络应用&#xff0c;可以考虑以下几个方面&#xff1a; 使用Goroutine并发处理&#xff1a;Goroutine是Go语言中的轻量级线程&#xff0c;可以轻松创建成千上万个并发的任务。通过使用Goroutine&#xff0c;可以高效地处理大量的并发请求&#xf…

C#版 iText7——画发票PDF(完整)

显示描述&#xff1a; 1、每页显示必须带有发票头、“销售方和购买方信息” 2、明细填充为&#xff1a;当n≤8 行时&#xff0c;发票总高度140mm&#xff0c;每条发票明细行款高度4.375mm&#xff1b; 当8<n≤12行时&#xff0c;发票高度增加17.5mm&#xff0c;不换页&#…

我们一起聊聊 Go 性能工具

从开发到部署的整个过程都离不开基本的负载测试和性能剖析。利用 Go 的 pprof 和跟踪工具&#xff0c;开发人员可以深入了解性能瓶颈、CPU 使用率和内存分配情况。 在开发过程中&#xff0c;从一开始到应用程序的推出都充满了挑战&#xff0c;而负载测试则是其中至关重要的一项…

Kettle 数据抽取工具使用教程:从入门到实战

一、简介 Kettle 是 Pentaho Data Integration (PDI) 的一个组成部分&#xff0c;是一个开源的数据集成工具。它被广泛用于数据的抽取、转换和加载 (ETL) 过程。Kettle 提供了一个易于使用的图形界面&#xff0c;可以轻松设计和执行 ETL 流程。 github 源码地址&#xff1a;ht…

postman教程-19-mock测试

上一小节我们学习了Postman接口参数化方法&#xff0c;本小节我们讲解一下Postman mock测试的方法。 一、什么叫mock测试 mock测试就是在测试过程中&#xff0c;对某些不容易构造或者不容易获取的对象&#xff0c;用一个虚拟的对象来创建以便于测试的一种测试方法&#xff0c…

chatgpt 生成的 左侧导航功能的网页

目录 一、左侧导航 1、效果如下&#xff1a; 2、代码如下&#xff1a; 3、技术点&#xff1a; 1)、箭头居中 2)、导航区域 3)、导航隐藏时&#xff0c;正文重新居中 4)、设置视口高度 这是用chatgpt生成的网页&#xff0c;其实&#xff0c;不是一下子就生成了满足需求的…

Syncovery:跨平台高效文件备份与同步的得力助手

在数字化时代&#xff0c;数据安全与文件同步已成为个人及企业不可或缺的需求。Syncovery作为一款专为Mac和Windows用户设计的文件备份和同步工具&#xff0c;凭借其高效、安全和易用的特点&#xff0c;赢得了广泛赞誉。 一、强大备份功能 Syncovery支持多种备份方案和数据格…

LeetCode 119.杨辉三角 II

1.题目要求如图所示: 示例 1:输入: rowIndex 3 输出: [1,3,3,1]示例 2:输入: rowIndex 0 输出: [1]示例 3:输入: rowIndex 1 输出: [1,1]先用malloc函数创造一个二维数组&#xff0c;变成杨辉三角&#xff0c;然后再用一维数组找到所指的那一行: /*** Note: The returned…