Transformer - Positional Encoding 位置编码 代码实现

Transformer - Positional Encoding 位置编码 代码实现

flyfish

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import mathclass PositionalEncoding(nn.Module):def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer("pe", pe)def forward(self, x):x = x +  self.pe[:, : x.size(1)].requires_grad_(False)return self.dropout(x)# 词嵌⼊维度是64维
d_model = 64
# 置0⽐率为0.1
dropout = 0.1
# 句⼦最⼤⻓度
max_len=60x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)pe_result = pe(x)print("pe_result:", pe_result)

绘图

import numpy as np
import matplotlib.pyplot as plt
# 创建⼀张15 x 5⼤⼩的画布
plt.figure(figsize=(15, 5))pe = PositionalEncoding(d_model, 0, max_len)y = pe(torch.zeros(1, max_len, d_model))# 只查看3,4,5,6维的值.
plt.plot(np.arange(max_len), y[0, :, 3:7].data.numpy())plt.legend(["dim %d"%p for p in [3,4,5,6]])

在这里插入图片描述

register_buffer 的测试

# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transformsclass MLPNet (nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1 * 28 * 28, 128)self.fc2 =nn.Linear(128, 128)self.fc3 = nn.Linear(128, 10)self.dropout1=nn.Dropout2d(0.2)self.dropout2=nn.Dropout2d(0.2)self.tmp = torch.randn(size=(1, 3))pe = torch.randn(size=(1, 3))self.register_buffer('pe', pe)def forward(self, x):x = F.relu(self.fc1(x))x = self.dropout1(x)x = F.relu(self.fc2(x))x = self.dropout2(x)return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)print(torch.__version__)root="mydir/"trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=root, train=False, transform=trans, download=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'train_loader = DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)epochs = 1
for epoch in range(epochs):train_loss = 0train_acc = 0val_loss = 0val_acc = 0net.train()for i, (images, labels) in enumerate(train_loader):images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)optimizer.zero_grad()out = net(images)loss = criterion(out, labels)train_loss += loss.item()train_acc += (out.max(1)[1] == labels).sum().item()loss.backward()optimizer.step()avg_train_loss = train_loss / len(train_loader.dataset)avg_train_acc = train_acc / len(train_loader.dataset)net.eval()with torch.no_grad():for (images, labels) in test_loader:images, labels = images.view(-1, 28*28*1).to(device), labels.to(device)out = net(images)loss = criterion(out, labels)val_loss += loss.item()acc = (out.max(1)[1] == labels).sum()val_acc += acc.item()avg_val_loss = val_loss / len(test_loader.dataset)avg_val_acc = val_acc / len(test_loader.dataset)print ('Epoch [{}/{}], Loss: {loss:.4f}, val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'.format(epoch+1, epochs, loss=avg_train_loss, val_loss=avg_val_loss, val_acc=avg_val_acc))dir_name = 'output'
if not os.path.exists(dir_name):os.mkdir(dir_name)
model_save_path = os.path.join(dir_name, "model.pt")
torch.save(net.state_dict(), model_save_path)model = MLPNet()
model.load_state_dict(torch.load(model_save_path))print(model.tmp)
print(model.pe)
# -*- coding: utf-8 -*-
"""
@author: flyfish
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transformsclass MLPNet (nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(1 * 28 * 28, 128)self.fc2 =nn.Linear(128, 128)self.fc3 = nn.Linear(128, 10)self.dropout1=nn.Dropout2d(0.2)self.dropout2=nn.Dropout2d(0.2)self.tmp = torch.randn(size=(1, 3))pe = torch.randn(size=(1, 3))self.register_buffer('pe', pe)def forward(self, x):x = F.relu(self.fc1(x))x = self.dropout1(x)x = F.relu(self.fc2(x))x = self.dropout2(x)return F.relu(self.fc3(x))
net = MLPNet()
print(net.tmp)
print(net.pe)dir_name = 'output'
if not os.path.exists(dir_name):os.mkdir(dir_name)model_save_path = os.path.join(dir_name, "model.pt")model = MLPNet()
model.load_state_dict(torch.load(model_save_path))print(model.tmp)
print(model.pe)

从模型加载的pe值,从未改变

tensor([[0.0566, 0.8944, 0.0873]])
tensor([[ 0.2529,  0.5227, -0.2610]])
tensor([[ 0.4632, -0.2602, -1.0032]])
tensor([[-0.3486,  0.8183, -1.3838]])
tensor([[ 0.7163,  0.5574, -0.0848]])
tensor([[-0.3415, -0.9013, -1.6136]])
tensor([[ 0.5490,  1.7691, -1.1375]])
tensor([[-0.3486,  0.8183, -1.3838]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/788422.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习理论基础(六)注意力机制

目录 深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息&…

http: server gave HTTP response to HTTPS client 分析一下这个问题如何解决中文告诉我详细的解决方案

这个错误信息表明 Docker 客户端在尝试通过 HTTPS 协议连接到 Docker 仓库时,但是服务器却返回了一个 HTTP 响应。这通常意味着 Docker 仓库没有正确配置为使用 HTTPS,或者客户端没有正确配置以信任仓库的 SSL 证书。以下是几种可能的解决方案&#xff1…

半导体制程离子注入注入的是哪些离子

离子注入是一种低温过程 通过该过程将一种元素的离子加速进入固体靶材,从而改变靶材的物理、化学或电学性质。离子注入用于半导体器件制造和金属精加工以及材料科学研究。如果离子停止并保留在目标中,则它们可以改变目标的元素成分(如果离子…

6 个典型的Java 设计模式应用场景题

单例模式(Singleton) 场景: 在一个Web服务中,数据库连接池应当在整个应用生命周期中只创建一次,以减少资源消耗和提升性能。使用单例模式确保数据库连接池的唯一实例。 代码实现: import java.sql.Connection; import java.sql.SQLException;public class DatabaseConne…

上位机图像处理和嵌入式模块部署(qmacviusal边缘宽度测量)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面有一篇文章,我们了解了测量标定是怎么做的。即,我们需要提前知道测量的方向,灰度的方向,实际的…

“省钱有道”的太平鸟,如何真正“高飞”?

衣食住行产业中,服装品类消费弹性较大、可选属性较强,其发展可以显著反映当前的经济温度。 根据国家统计局数据,2023年1-12月,我国限额以上单位服装类商品零售额累计10352.9亿元,同比增长15.4%,增速比2022…

Python框架下的qt设计之JSON格式化转换小程序

JSON转换小程序 代码展示: 主程序代码: from PyQt6.QtWidgets import (QApplication, QDialog, QMessageBox )import sys import jsonclass MyJsonFormatter(jsonui.Ui_jsonFormatter,QDialog): # jsonui是我qt界面py文件名def __init__(self):super()…

【HTML】注册页面制作 案例二

(大家好,今天我们将通过案例实战对之前学习过的HTML标签知识进行复习巩固,大家和我一起来吧,加油!💕) 案例复习 通过综合案例,主要复习: 表格标签,可以让内容…

【Go】十七、进程、线程、协程

文章目录 1、进程、线程2、协程3、主死从随4、启动多个协程5、使用WaitGroup控制协程退出6、多协程操作同一个数据7、互斥锁8、读写锁9、deferrecover优化多协程 1、进程、线程 进程作为资源分配的单位,在内存中会为每个进程分配不同的内存区域 一个进程下面有多个…

集合的学习

为什么要有集合:集合会自动扩容 集合不能存基本数据类型(基本数据类型是存放真实的值,而引用数据类型是存放一个地址,这个地址存放在栈区,地址所指向的内容存放在堆区) 数组和集合的对比: 集…

Flutter 开发学习笔记(3):第三方UI库的引入

文章目录 前言初始化程序Icon导入如何导入 Toast消息提示框引入简单封装简单使用 Charts图表导入新建pages文件夹存放page简单代码实现效果 总结 前言 Flutter已经发布了有10年了,生态也算比较完善了。用于安卓程序开发应该是非常的方便。我们这里就接入一些简单的…

golang语言系列:Web框架+路由 之 Gin

云原生学习路线导航页(持续更新中) 本文是golang语言学习系列,本篇对Gin框架的基本使用方法进行学习 1.Gin框架是什么 Gin 是一个 Go (Golang) 编写的轻量级 http web 框架,运行速度非常快,如果你是性能和高效的追求者…

【JavaEE】_Spring MVC项目上传文件

目录 1. 文件上传具体实现 2. 保存文件 1. 文件上传具体实现 .java文件内容如下: package com.example.demo.controller;import com.example.demo.Person; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.Multip…

day4|gin的中间件和路由分组

中间件其实是一个方法, 在.use就可以调用中间件函数 r : gin.Default()v1 : r.Group("v1")//v1 : r.Group("v1").Use()v1.GET("test", func(c *gin.Context) {fmt.Println("get into the test")c.JSON(200, gin.H{"…

特征融合篇 | YOLOv8改进之将Neck网络更换为GFPN(附2种改进方法)

前言:Hello大家好,我是小哥谈。GFPN(Global Feature Pyramid Network)是一种用于目标检测的神经网络架构,它是在Faster R-CNN的基础上进行改进的,旨在提高目标检测的性能和效果。其核心思想是引入全局特征金字塔,通过多尺度的特征融合来提取更丰富的语义信息。具体来说,…

FPGA + 图像处理 (二) RGB转YUV色域、转灰度图及仿真

前言 具体关于色域的知识就不细说了,简单来讲YUV中Y通道可以理解为就是图像的灰度图,因此,将RGB转化为YUV是求彩色图的灰度直方图、进行二值化操作等的基础。 HDMI时序生成模块 这里先介绍一下仿真时用于生成HDMI时序,用这个时…

自贡市第一人民医院:超融合与 SKS 承载 HIS 等核心业务应用,加速国产化与云原生转型

自贡市第一人民医院始建于 1908 年,现已发展成为集医疗、科研、教学、预防、公共卫生应急处置为一体的三级甲等综合公立医院。医院建有“全国综合医院中医药工作示范单位”等 8 个国家级基地,建成高级卒中中心、胸痛中心等 6 个国家级中心。医院日门诊量…

【Docker】搭建便捷的Docker容器管理工具 - dockerCopilot

【Docker】搭建便捷的Docker容器管理工具 - dockerCopilot 前言 本教程基于绿联的NAS设备DX4600 Pro的docker功能进行搭建。前面有介绍过OneKey,而dockerCopilot便是OneKey的升级版,作者对其进行了重新命名,并且对界和功能都进行了全面的优…

负载均衡集群

一、集群的基本原理 集群:数据内容是一致的,集群可以被替代 分布式:各司其职,每台服务器存储自己独有的数据,对外作为单点被访问是访问整体的数据; 分布式是不能被替代的;分布式分为MFS、GFS、…

结构体内存对齐和位段(重点)!!!

乐观学习,乐观生活,才能不断前进啊!!! 我的主页:optimistic_chen 我的专栏:c语言 点击主页:optimistic_chen和专栏:c语言, 创作不易,大佬们点赞鼓…