Transformer - Positional Encoding 位置编码 代码实现

Transformer - Positional Encoding 位置编码 代码实现

flyfish

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import mathclass PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer("pe", pe)def forward(self, x):x = x +  self.pe[:, : x.size(1)].requires_grad_(False)return self.dropout(x)# 词嵌⼊维度是64维
d_model = 64
# 置0⽐率为0.1
dropout = 0.1
# 句⼦最⼤⻓度
max_len=60x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)pe_result = pe(x)print("pe_result:", pe_result)

绘图

import numpy as np
import matplotlib.pyplot as plt
# 创建⼀张15 x 5⼤⼩的画布
plt.figure(figsize=(15, 5))pe = PositionalEncoding(d_model, 0, max_len)y = pe(torch.zeros(1, max_len, d_model))# 只查看3,4,5,6维的值.
plt.plot(np.arange(max_len), y[0, :, 3:7].data.numpy())plt.legend(["dim %d"%p for p in [3,4,5,6]])

在这里插入图片描述

register_buffer 的测试

# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transformsclass MLPNet (nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1 * 28 * 28, 128)self.fc2 =nn.Linear(128, 128)self.fc3 = nn.Linear(128, 10)self.dropout1=nn.Dropout2d(0.2)self.dropout2=nn.Dropout2d(0.2)self.tmp = torch.randn(size=(1, 3))pe = torch.randn(size=(1, 3))self.register_buffer('pe', pe)def forward(self, x):x = F.relu(self.fc1(x))x = self.dropout1(x)x = F.relu(self.fc2(x))x = self.dropout2(x)return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)print(torch.__version__)root="mydir/"trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)epochs = 1
for epoch in range(epochs):train_loss = 0train_acc = 0val_loss = 0val_acc = 0net.train()for i, (images, labels) in enumerate(train_loader):images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)optimizer.zero_grad()out = net(images)loss = criterion(out, labels)train_loss += loss.item()train_acc += (out.max(1)[1] == labels).sum().item()loss.backward()optimizer.step()avg_train_loss = train_loss / len(train_loader.dataset)avg_train_acc = train_acc / len(train_loader.dataset)net.eval()with torch.no_grad():for (images, labels) in test_loader:images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)out = net(images)loss = criterion(out, labels)val_loss += loss.item()acc = (out.max(1)[1] == labels).sum()val_acc += acc.item()avg_val_loss = val_loss / len(test_loader.dataset)avg_val_acc = val_acc / len(test_loader.dataset)print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'.format(epoch+1, epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))dir_name = 'output'
if not os.path.exists(dir_name):os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")
torch.save(net.state_dict(), model_save_path)model = MLPNet()
model.load_state_dict(torch.load(model_save_path))print(model.tmp)
print(model.pe)
# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transformsclass MLPNet (nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1 * 28 * 28, 128)self.fc2 =nn.Linear(128, 128)self.fc3 = nn.Linear(128, 10)self.dropout1=nn.Dropout2d(0.2)self.dropout2=nn.Dropout2d(0.2)self.tmp = torch.randn(size=(1, 3))pe = torch.randn(size=(1, 3))self.register_buffer('pe', pe)def forward(self, x):x = F.relu(self.fc1(x))x = self.dropout1(x)x = F.relu(self.fc2(x))x = self.dropout2(x)return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)dir_name = 'output'
if not os.path.exists(dir_name):os.mkdir(dir_name)model_save_path = os.path.join(dir_name, "model.pt")model = MLPNet()
model.load_state_dict(torch.load(model_save_path))print(model.tmp)
print(model.pe)

从模型加载的pe值,从未改变

tensor([[0.0566, 0.8944, 0.0873]])
tensor([[ 0.2529,  0.5227, -0.2610]])
tensor([[ 0.4632, -0.2602, -1.0032]])
tensor([[-0.3486,  0.8183, -1.3838]])
tensor([[ 0.7163,  0.5574, -0.0848]])
tensor([[-0.3415, -0.9013, -1.6136]])
tensor([[ 0.5490,  1.7691, -1.1375]])
tensor([[-0.3486,  0.8183, -1.3838]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/788422.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习理论基础(六)注意力机制

目录 深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息&…

基于 CentOS7 制作 Apache HTTPD 2.4.58 的RPM安装包

编译环境: 操作系统:CentOS7 httpd版本:2.4.58 制作工具:rpmbuild(这个之前的文章有介绍,看这里) 下载httpd源码: 官网目前的最新版本是2.4.58,下载备用&#xff0c…

http: server gave HTTP response to HTTPS client 分析一下这个问题如何解决中文告诉我详细的解决方案

这个错误信息表明 Docker 客户端在尝试通过 HTTPS 协议连接到 Docker 仓库时,但是服务器却返回了一个 HTTP 响应。这通常意味着 Docker 仓库没有正确配置为使用 HTTPS,或者客户端没有正确配置以信任仓库的 SSL 证书。以下是几种可能的解决方案&#xff1…

半导体制程离子注入注入的是哪些离子

离子注入是一种低温过程 通过该过程将一种元素的离子加速进入固体靶材,从而改变靶材的物理、化学或电学性质。离子注入用于半导体器件制造和金属精加工以及材料科学研究。如果离子停止并保留在目标中,则它们可以改变目标的元素成分(如果离子…

开源充电桩设备监控系统技术解决方案

开源 | 慧哥充电桩平台V2.5.2(支持 汽车 电动自行车 云快充1.5、云快充1.6 微服务 ) SpringBoot设备监控系统解决方案 一、引言 1.项目背景 随着物联网技术的快速发展,设备的智能化和网络化程度日益提高。在现代工业和信息化的背景下&#x…

6 个典型的Java 设计模式应用场景题

单例模式(Singleton) 场景: 在一个Web服务中,数据库连接池应当在整个应用生命周期中只创建一次,以减少资源消耗和提升性能。使用单例模式确保数据库连接池的唯一实例。 代码实现: import java.sql.Connection; import java.sql.SQLException;public class DatabaseConne…

上位机图像处理和嵌入式模块部署(qmacviusal边缘宽度测量)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面有一篇文章,我们了解了测量标定是怎么做的。即,我们需要提前知道测量的方向,灰度的方向,实际的…

“省钱有道”的太平鸟,如何真正“高飞”?

衣食住行产业中,服装品类消费弹性较大、可选属性较强,其发展可以显著反映当前的经济温度。 根据国家统计局数据,2023年1-12月,我国限额以上单位服装类商品零售额累计10352.9亿元,同比增长15.4%,增速比2022…

Python框架下的qt设计之JSON格式化转换小程序

JSON转换小程序 代码展示: 主程序代码: from PyQt6.QtWidgets import (QApplication, QDialog, QMessageBox )import sys import jsonclass MyJsonFormatter(jsonui.Ui_jsonFormatter,QDialog): # jsonui是我qt界面py文件名def __init__(self):super()…

【HTML】注册页面制作 案例二

(大家好,今天我们将通过案例实战对之前学习过的HTML标签知识进行复习巩固,大家和我一起来吧,加油!💕) 案例复习 通过综合案例,主要复习: 表格标签,可以让内容…

说明计算机视觉(CV)技术的优势和挑战

计算机视觉(Computer Vision,CV)技术是一种利用计算机科学和工程技术来处理和分析图像和视频的技术。以下是计算机视觉技术的优势和挑战的几个例子: 优势: 高效快速:计算机视觉技术可以在短时间内处理大量…

【Go】十七、进程、线程、协程

文章目录 1、进程、线程2、协程3、主死从随4、启动多个协程5、使用WaitGroup控制协程退出6、多协程操作同一个数据7、互斥锁8、读写锁9、deferrecover优化多协程 1、进程、线程 进程作为资源分配的单位,在内存中会为每个进程分配不同的内存区域 一个进程下面有多个…

集合的学习

为什么要有集合:集合会自动扩容 集合不能存基本数据类型(基本数据类型是存放真实的值,而引用数据类型是存放一个地址,这个地址存放在栈区,地址所指向的内容存放在堆区) 数组和集合的对比: 集…

Zookeeper 怎么实现分布式锁

基于ZooKeeper实现分布式锁的原理主要基于ZooKeeper提供的一些特性,包括有序性、唯一性、临时节点等。下面是基于ZooKeeper实现分布式锁的 基本原理 有序性:ZooKeeper保证所有写入操作的全局顺序性。当客户端向ZooKeeper写入数据时,ZooKeepe…

Flutter 开发学习笔记(3):第三方UI库的引入

文章目录 前言初始化程序Icon导入如何导入 Toast消息提示框引入简单封装简单使用 Charts图表导入新建pages文件夹存放page简单代码实现效果 总结 前言 Flutter已经发布了有10年了,生态也算比较完善了。用于安卓程序开发应该是非常的方便。我们这里就接入一些简单的…

Pytorch实用教程:TensorDataset和DataLoader的介绍及用法示例

TensorDataset TensorDataset是PyTorch中torch.utils.data模块的一部分,它包装张量到一个数据集中,并允许对这些张量进行索引,以便能够以批量的方式加载它们。 当你有多个数据源(如特征和标签)时,TensorD…

golang语言系列:Web框架+路由 之 Gin

云原生学习路线导航页(持续更新中) 本文是golang语言学习系列,本篇对Gin框架的基本使用方法进行学习 1.Gin框架是什么 Gin 是一个 Go (Golang) 编写的轻量级 http web 框架,运行速度非常快,如果你是性能和高效的追求者…

【JavaEE】_Spring MVC项目上传文件

目录 1. 文件上传具体实现 2. 保存文件 1. 文件上传具体实现 .java文件内容如下: package com.example.demo.controller;import com.example.demo.Person; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.Multip…

拒绝服务攻击(Dos)与Tomcat的解决方法

拒绝服务攻击Dos 拒绝服务攻击(Denial of Service,DoS)是一种网络攻击,旨在使目标系统无法提供正常的服务,使其无法响应合法用户的请求。这种攻击通过消耗目标系统的资源,例如带宽、处理能力或存储空间&am…

【C语言数据库】Sqlite3基础介绍

1. SQLite简介 SQLite is a C-language library that implements a small, fast, self-contained, high-reliability, full-featured, SQL database engine. SQLite is the most used database engine in the world. SQLite is built into all mobile phones and most computer…