karpathy make more -- 4

1 Introduction

这个部分要完成一个网络的模块化,然后实现一个新的网络结构。

2 使用torch的模块化功能

2.1 模块化

将输入的字符长度变成8,并将之前的代码模块化

# Near copy paste of the layers we have developed in Part 3# -----------------------------------------------------------------------------------------------class Linear:def __init__(self, fan_in, fan_out, bias=True):self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5 # note: kaiming initself.bias = torch.zeros(fan_out) if bias else Nonedef __call__(self, x):self.out = x @ self.weightif self.bias is not None:self.out += self.biasreturn self.outdef parameters(self):return [self.weight] + ([] if self.bias is None else [self.bias])# -----------------------------------------------------------------------------------------------
class BatchNorm1d:def __init__(self, dim, eps=1e-5, momentum=0.1):self.eps = epsself.momentum = momentumself.training = True# parameters (trained with backprop)self.gamma = torch.ones(dim)self.beta = torch.zeros(dim)# buffers (trained with a running 'momentum update')self.running_mean = torch.zeros(dim)self.running_var = torch.ones(dim)def __call__(self, x):# calculate the forward passif self.training:if x.ndim == 2:dim = 0elif x.ndim == 3:dim = (0,1)xmean = x.mean(dim, keepdim=True) # batch meanxvar = x.var(dim, keepdim=True) # batch varianceelse:xmean = self.running_meanxvar = self.running_varxhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit varianceself.out = self.gamma * xhat + self.beta# update the buffersif self.training:with torch.no_grad():self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmeanself.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvarreturn self.outdef parameters(self):return [self.gamma, self.beta]# -----------------------------------------------------------------------------------------------
class Tanh:def __call__(self, x):self.out = torch.tanh(x)return self.outdef parameters(self):return []# -----------------------------------------------------------------------------------------------
class Embedding:def __init__(self, num_embeddings, embedding_dim):self.weight = torch.randn((num_embeddings, embedding_dim))def __call__(self, IX):self.out = self.weight[IX]return self.outdef parameters(self):return [self.weight]class Flatten:def __call__(self, x):self.out = x.view(x.shape[0], -1)return self.outdef parameters(self):return []# -----------------------------------------------------------------------------------------------
class FlattenConsecutive:def __init__(self, n):self.n = ndef __call__(self, x):B, T, C = x.shapex = x.view(B, T//self.n, C*self.n)if x.shape[1] == 1:x = x.squeeze(1)self.out = xreturn self.outdef parameters(self):return []# -----------------------------------------------------------------------------------------------
class Sequential:def __init__(self, layers):self.layers = layersdef __call__(self, x):for layer in self.layers:x = layer(x)self.out = xreturn self.outdef parameters(self):# get parameters of all layers and stretch them out into one listreturn [p for layer in self.layers for p in layer.parameters()]

定义网络结构

block_size = 8
n_emb = 10
n_batch = 32
n_hidden = 200
g = torch.Generator().manual_seed(2147483647)
model = Sequential([Embedding(vocab_size, n_emb),Flatten(),Linear(n_emb * block_size, n_hidden, bias=False),BatchNorm1d(n_hidden),Tanh(),Linear(n_hidden, vocab_size),
])
with torch.no_grad():model.layers[-1].weight *= 0.1print(sum(p.nelement() for p in model.parameters()))
for p in model.parameters():p.requires_grad = Truefor layer in model.layers:if isinstance(layer, BatchNorm1d):layer.training = True

进行训练

import torch.nn.functional as F
max_iter = 200000
lossi = []
ud = []
for i in range(max_iter):ix = torch.randint(0, Xtr.shape[0], (n_batch,), generator=g)Xb, Yb = Xtr[ix], Ytr[ix]logits = model(Xb)loss = F.cross_entropy(logits, Yb)for p in model.parameters():p.grad = Noneloss.backward()lr = 0.1 if i < 100000 else 0.01for p in model.parameters():p.data -= lr * p.grad.datalossi.append(loss.item())with torch.no_grad():ud.append([((-lr * p.grad).std() / p.data.std()).log10().item() for p in model.parameters()])if i % 1000 == 0:print(f"Iteration: {i}/{max_iter}, Loss: {loss.item()}")# break

显示曲线

import matplotlib.pyplot as plt
plt.plot(torch.tensor(lossi).view(-1, 1000).mean(dim=1, keepdim=False))

在这里插入图片描述
比较训练和测试的误差

@torch.no_grad()
def batch_infer(datasets):X, Y = {'train' : (Xtr, Ytr),'val' : (Xdev, Ydev),'test' : (Xte, Yte),}[datasets]logits = model(X)loss = F.cross_entropy(logits, Y)print(f'{datasets}, loss is: {loss}')for layer in model.layers:if isinstance(layer, BatchNorm1d):layer.training = False
batch_infer('train')
batch_infer('val')

train, loss is: 1.926148533821106
val, loss is: 2.028862237930298

网络现在有一点过拟合了。
看一下输出的结果

for _ in range(20):context = [0] * block_sizech = []while(True):X = torch.tensor([context])logits = model(X)probs = torch.softmax(logits, dim=-1).squeeze(0)ix = torch.multinomial(probs, num_samples=1).item()context = context[1:] + [ix]ch.append(itos[ix])if ix == 0:breakprint(''.join(ch))

quab.
nomawa.
brenne.
sevanille.
razlyn.
zile.
audaina.
zaralynn.
dawsyn.
wyle.
yalikbi.
zuria.
endrame.
mesty.
nooap.
dangele.
ellania.
bako.
memaisee.
zailan.

2.2 加上wavenet

在这里插入图片描述
这个图表示,两个点使用相同的参数矩阵C,进行映射。
首先来看矩阵乘法的表示
在这里插入图片描述
然后再来看我们这个问题,
在这里插入图片描述

代码表示为:

class FlattenConsecutive:def __init__(self, n):self.n = ndef __call__(self, x):B, T, C = x.shapex = x.view(B, T//self.n, C*self.n)if x.shape[1] == 1:x = x.squeeze(1)self.out = xreturn self.outdef parameters(self):return []

定义新的完整网络

block_size = 8
n_emb = 24
n_batch = 32
n_hidden = 128
g = torch.Generator().manual_seed(2147483647)
model = Sequential([Embedding(vocab_size, n_emb),FlattenConsecutive(2), Linear(n_emb * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),FlattenConsecutive(2), Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),FlattenConsecutive(2), Linear(n_hidden * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),Linear(n_hidden, vocab_size)
])
with torch.no_grad():model.layers[-1].weight *= 0.1print(sum(p.nelement() for p in model.parameters()))
for p in model.parameters():p.requires_grad = Truefor layer in model.layers:if isinstance(layer, BatchNorm1d):layer.training = True

这里注意一个问题,因为我们采用是batchnormal, 也就是说除了最后一维的数据,其他的数据需要normalize

class BatchNorm1d:def __init__(self, dim, eps=1e-5, momentum=0.1):self.eps = epsself.momentum = momentumself.training = True# parameters (trained with backprop)self.gamma = torch.ones(dim)self.beta = torch.zeros(dim)# buffers (trained with a running 'momentum update')self.running_mean = torch.zeros(dim)self.running_var = torch.ones(dim)def __call__(self, x):# calculate the forward passif self.training:if x.ndim == 2:dim = 0elif x.ndim == 3:dim = (0,1)xmean = x.mean(dim, keepdim=True) # batch meanxvar = x.var(dim, keepdim=True) # batch varianceelse:xmean = self.running_meanxvar = self.running_varxhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit varianceself.out = self.gamma * xhat + self.beta# update the buffersif self.training:with torch.no_grad():self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmeanself.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvarreturn self.outdef parameters(self):return [self.gamma, self.beta]

其他的内容和之前的网络相同,最后看一下训练的结果
在这里插入图片描述

train, loss is: 1.7904815673828125
val, loss is: 1.9868937730789185

sabris.
lilly.
pryce.
antwling.
lakelyn.
dayre.
theora.
hunna.
michael.
amillia.
zivy.
zuri.
florby.
jairael.
aiyank.
anahit.
madelynn.
briani.
payzleigh.
sola.

2.3 convolution

我们只执行了这里的黑色部分的代码,如果完整执行就是一个convolutional neural network。
在这里插入图片描述

References

[1] WaveNet 2016 from DeepMind https://arxiv.org/abs/1609.03499

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/5778.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

8. Django 表单与模型

8. 表单与模型 表单是搜集用户数据信息的各种表单元素的集合, 其作用是实现网页上的数据交互, 比如用户在网站输入数据信息, 然后提交到网站服务器端进行处理(如数据录入和用户登录注册等).网页表单是Web开发的一项基本功能, Django的表单功能由Form类实现, 主要分为两种: dj…

Odoo14修改登录界面,实现炫酷粒子效果

目录 原登录界面 最终效果 实现步骤 插件下载 原登录界面 最终效果 实现步骤 1 odoo创建插件web_login 2 在static目录下编写css和js文件 login.css代码 html, body {position:fixed;top:0px;left:0px;height:100%;width:100%;/*Fallback if gradeints dont work */b…

【项目学习01_2024.05.01_Day03】

学习笔记 3.6 开发业务层3.6.1 创建数据字典表3.6.2 编写Service3.6.3 测试Service 3.7 接口测试3.7.1 接口完善3.7.2 Httpclient测试 3.8 前后端联调3.8.1 准备环境3.8.2 安装系统管理服务3.8.3 解决跨域问题解决跨域的方法&#xff1a;我们准备使用方案2解决跨域问题。在内容…

hadoop学习---基于hive的航空公司客户价值的LRFCM模型案例

案例需求&#xff1a; RFM模型的复习 在客户分类中&#xff0c;RFM模型是一个经典的分类模型&#xff0c;模型利用通用交易环节中最核心的三个维度——最近消费(Recency)、消费频率(Frequency)、消费金额(Monetary)细分客户群体&#xff0c;从而分析不同群体的客户价值。在某些…

CTFHub-Web-文件上传

CTFHub-Web-文件上传-WP 一、无验证 1.编写一段PHP木马脚本 2.将编写好的木马进行上传 3.显示上传成功了 4.使用文件上传工具进行尝试 5.连接成功进入文件管理 6.上翻目录找到flag文件 7.打开文件查看flag 二、前端验证 1.制作payload进行上传发现不允许这种类型的文件上传 …

手机测试之-adb

一、Android Debug Bridge 1.1 Android系统主要的目录 1.2 ADB工具介绍 ADB的全称为Android Debug Bridge,就是起到调试桥的作用,是Android SDK里面一个多用途调试工具,通过它可以和Android设备或模拟器通信,借助adb工具,我们可以管理设备或手机模拟器的状态。还可以进行很多…

数字旅游以科技创新为核心:推动旅游服务的智能化、精准化、个性化,为游客提供更加贴心、专业、高效的旅游服务

目录 一、引言 二、数字旅游以科技创新推动旅游服务智能化 1、智能化技术的应用 2、提升旅游服务的效率和质量 三、数字旅游以科技创新推动旅游服务精准化 1、精准化需求的识别与满足 2、精准化营销与推广 四、数字旅游以科技创新推动旅游服务个性化 1、个性化服务的创…

FIFO Generate IP核使用——Native Ports页配置

在使用FIFO Generate IP核时&#xff0c;如果在Basic选项页选择了Naitve接口&#xff0c;就需要配置Native Ports页&#xff0c;该页提供了针对FIFO核心的性能选项&#xff08;读取模式&#xff09;、数据端口参数、ECC&#xff08;错误检查和纠正&#xff09;以及初始化选项。…

「生存即赚」链接现实与游戏,打造3T平台生态

当前&#xff0c;在线角色扮演游戏&#xff08;RPG&#xff09;在区块链游戏市场中正迅速崛起&#xff0c;成为新宠。随着区块链技术的不断进步&#xff0c;众多游戏开发者纷纷将其游戏项目引入区块链领域&#xff0c;以利用这一新兴技术实现商业价值的最大化。在这一趋势中&am…

Flutter笔记:Widgets Easier组件库(8)使用图片

Flutter笔记 Widgets Easier组件库&#xff08;8&#xff09;&#xff1a;使用图片 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress o…

redis核心数据结构——跳表项目设计与实现(跳表结构介绍,节点类设计,随机层级函数)

跳表结构介绍。跳表是redis等知名软件的核心数据结构&#xff0c;其实现的前提是有序链表&#xff0c;思想的本质是在原有一串存储数据的链表中&#xff0c;间隔地抽出一半元素作为上一级链表&#xff0c;并将抽提出的元素和原先的位置相关联&#xff0c;这样重复下去直到最上层…

前端鼠标放上去显示更多内容demo

参考文献: title - HTML&#xff08;超文本标记语言&#xff09; | MDN (mozilla.org) <div class"up-detail" title"我是二五仔、总督小号、单曲切片人。 你甚至能在音 手 头条 管 港台bili ytb看到嘎的单曲。我是二五仔、总督小号、单曲切片人。 你甚至能…

【Mac】Axure RP 9(交互原型设计软件)安装教程

软件介绍 Axure RP 9是一款强大的原型设计工具&#xff0c;广泛用于用户界面和交互设计。它提供了丰富的功能和工具&#xff0c;能够帮助设计师创建高保真的交互原型&#xff0c;用于展示和测试软件应用或网站的功能和流程。以下是Axure RP 9的主要特点和功能&#xff1a; 交…

acwing算法提高之数据结构--平衡树Treap

目录 1 介绍2 训练 1 介绍 本博客用来记录使用平衡树求解的题目。 插入、删除、查询操作的时间复杂度都是O(logN)。 动态维护一个有序序列。 2 训练 题目1&#xff1a;253普通平衡树 C代码如下&#xff0c; #include <cstdio> #include <cstring> #include …

程序设计基础--C语言【五】

数组 目录 数组 5.1.一维数组 5.1.1.一维数组的引用 5.1.2.一维数组的初始化 5.1.3.一维数组的程序举例 5.2.二维数组 5.2.1.二维数组的定义 5.2.2.二维数组的引用 5.2.3.二维数组的初始化 5.2.4.举例 5.3.字符数组与字符串 5.3.1.字符组的初始化 5.3.2.字符数组…

【介绍下大数据组件之Storm】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

MySQL-分页查询

MySQL分页查询 MySQL 分页查询原则&#xff1a; 在 MySQL 数据库中使用 LIMIT 子句进行分页查询。MySQL 分页中开始位置为 0。分页子句在查询语句的最后侧。 LIMIT子句 SELECT 投影列 FROM 表名 WHERE 条件 ORDER BY LIMIT 开始位置&#xff0c;查询数量;示例&#xff1a; …

Delta lake with Java--利用spark sql操作数据2

上一篇文章尝试了建库&#xff0c;建表&#xff0c;插入数据&#xff0c;还差删除和更新&#xff0c;所以在这篇文章补充一下&#xff0c;代码很简单&#xff0c;具体如下&#xff1a; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession;publi…

C++ | Leetcode C++题解之第62题不同路径

题目&#xff1a; 题解&#xff1a; class Solution { public:int uniquePaths(int m, int n) {long long ans 1;for (int x n, y 1; y < m; x, y) {ans ans * x / y;}return ans;} };

附录6-4 黑马优购项目-分类和购物车

目录 1 分类 1.1 接口 1.2 窗口限制 1.3 选中状态样式判断 1.4 点击左侧时右侧会到顶点 1.5 源码 2 购物车 2.1 store 2.2 tabBar徽标 2.3 滑动删除 2.4 结算 2.4.1 结算前登录 2.4.2 结算功能 2.5 触发组件事件 2.6 源码 1 分类 分类最上部是…