从0-1实现大模型

 

目录

输入数据准备 

滑动窗口造数据 Data sampling with a sliding window

数据加载器的输出DataLoader

位置编码Encoding word positions

自注意力机制

点积的原理

QKV的原理

实现代码

Multi-head Attention 

线性层


github: LLMs-from-scratch/ch02/01_main-chapter-code

LLM Visualization

输入数据准备 

滑动窗口造数据 Data sampling with a sliding window

We train LLMs to generate one word at a time, so we want to prepare the training data accordingly where the next word in a sequence represents the target to predict:

 and ---->  establishedand established ---->  himselfand established himself ---->  inand established himself in ---->  a

数据加载器的输出DataLoader

批量Batch的输出形式是:

Inputs:tensor([[992, 993, 994, 995],[996, 997, 998, 999]])Targets:tensor([[ 993,  994,  995,  996],[ 997,  998,  999, 1000]])
无需Batch输出的单条形式:
[tensor([[0, 1, 2, 3]]), tensor([[1, 2, 3, 4]])]
[tensor([[1, 2, 3, 4]]), tensor([[2, 3, 4, 5]])]

位置编码Encoding word positions

嵌入层将标识符转换为相同的向量表示,无论它们在输入序列中的位置如何:

位置嵌入与标记嵌入向量相结合,形成大型语言模型的输入嵌入。

自注意力机制

点积的原理

Q*K 这个点积操作, 这是一种通用且简单的方法,确保每个输出元素都能受到输入向量中所有元素的影响(其中的影响由权重决定)。因此,它在神经网络中经常出现。

点积是衡量两个向量之间相似性的一种方式。如果它们非常相似,点积会很大。如果它们非常不同,点积会很小或为负。

我们对Q、K、V向量的每个输出单元重复此操作:

QKV的原理

我们的Q(查询)、K(键)和V(值)向量是用来做什么的?命名给我们提供了一个提示:“键”和“值”让人想起软件中的字典,键映射到值。然后“查询”是我们用来查找值的工具。

软件类比: 查找表: table = { "key0": "value0", "key1": "value1", ... } 查询过程: table["key1"] => "value1"

在自注意力的情况下,我们不是返回单个条目,而是返回条目的某种加权组合。为了找到那个加权,我们取Q向量和每个K向量之间的点积。我们规范化那个加权,然后最后用它来乘以相应的V向量,并将它们全部加起来。

我们的查找表中的{K, V}条目是过去的6列,而Q值是当前时间。

我们首先计算当前列(t = 5)的Q向量与之前每列的K向量之间的点积。然后将这些点积存储在注意力矩阵中相应的行(t = 5)。

实现代码

class CausalAttention(nn.Module):def __init__(self, d_in, d_out, context_length,dropout, qkv_bias=False):super().__init__()self.d_out = d_outself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.dropout = nn.Dropout(dropout) # Newself.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # Newdef forward(self, x):b, num_tokens, d_in = x.shape # New batch dimension bkeys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.transpose(1, 2) # Changed transposeattn_scores.masked_fill_(  # New, _ ops are in-placeself.mask.bool()[:num_tokens, :num_tokens], -torch.inf) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights) # Newcontext_vec = attn_weights @ valuesreturn context_vectorch.manual_seed(123)context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)context_vecs = ca(batch)print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Multi-head Attention 

这就是自注意力层中一个头的处理过程。因此,自注意力的主要目标是每一列都想要从其他列中找到相关信息并提取它们的值,这是通过将其查询向量与那些其他列的键向量进行比较来实现的。加上的一个限制是,它只能查看过去。

多头注意力机制的主要思想是使用不同的、学习得到的线性投影并行地多次运行注意力机制。这使得模型能够同时从不同的位置关注来自不同表示子空间的信息。

class MultiHeadAttentionWrapper(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])def forward(self, x):return torch.cat([head(x) for head in self.heads], dim=-1)torch.manual_seed(123)context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2
)context_vecs = mha(batch)print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

这就是一个完整的Transformer块!

这些块构成了任何GPT模型的主体,并会重复多次,一个块的输出会作为下一个块的输入,继续沿着残差路径前进。

正如在深度学习中常见的那样,很难确切地说这些层在做什么,但我们有一些大致的想法:早期层倾向于专注于学习低级特征和模式,而后期层则学习识别和理解更高层次的抽象和关系。在自然语言处理的背景下,较低层可能学习语法、句法和简单的词汇关联,而较高层可能捕捉更复杂的语义关系、话语结构和依赖上下文的含义。

线性层

最后,我们来到了模型的末端。最终的Transformer块的输出会通过一个层归一化,然后我们使用一个线性变换(矩阵乘法),这次没有偏置项。

这次最终的变换将每个列向量从长度C变换到长度nvocab。因此,它实际上是为每个列的词汇表中的每个词生成一个分数。这些分数有一个特殊的名字:logits。

"Logits"这个名字来源于"log-odds",即每个标记的赔率的对数。"Log"被使用是因为我们接下来应用的softmax会进行指数化以转换为"赔率"或概率。

为了将这些分数转换为漂亮的概率,我们通过一个softmax操作传递它们。现在,对于每个列,我们都有一个模型分配给词汇表中每个词的概率。

在这个特定的模型中,它有效地学习了所有如何排序三个字母的问题的答案,所以概率严重偏向于正确答案。

当我们按时间顺序运行模型时,我们使用最后一列的概率来确定要添加到序列中的下一个标记。例如,如果我们向模型提供了六个标记,我们将使用第六列的输出概率。

这个列的输出是一系列概率,我们实际上必须选择其中之一作为序列中的下一个。我们通过"从分布中采样"来做到这一点。也就是说,我们根据其概率加权随机选择一个标记。例如,一个有0.9概率的标记将有90%的机会被选择。

然而,这里还有其他选项,比如总是选择概率最高的标记。

我们还可以通过使用温度参数来控制分布的"平滑度"。较高的温度会使分布更加均匀,而较低的温度会使它更集中在最高概率的标记上。

我们通过在应用softmax之前将logits(线性变换的输出)除以温度来做到这一点。由于softmax中的指数化对较大的数字有很大的影响,使它们更接近会减少这种效果。

Finally, we come to the end of the model. The output of the final transformer block is passed through a layer normalization, and then we use a linear transformation (matrix multiplication), this time without a bias.

This final transformation takes each of our column vectors from length C to length nvocab. Hence, it's effectively producing a score for each word in the vocabulary for each of our columns. These scores have a special name: logits.

The name "logits" comes from "log-odds," i.e., the logarithm of the odds of each token. "Log" is used because the softmax we apply next does an exponentiation to convert to "odds" or probabilities.

To convert these scores into nice probabilities, we pass them through a softmax operation. Now, for each column, we have a probability the model assigns to each word in the vocabulary.

In this particular model, it has effectively learned all the answers to the question of how to sort three letters, so the probabilities are heavily weighted toward the correct answer.

When we're stepping the model through time, we use the last column's probabilities to determine the next token to add to the sequence. For example, if we've supplied six tokens into the model, we'll use the output probabilities of the 6th column.

This column's output is a series of probabilities, and we actually have to pick one of them to use as the next in the sequence. We do this by "sampling from the distribution." That is, we randomly choose a token, weighted by its probability. For example, a token with a probability of 0.9 will be chosen 90% of the time.

There are other options here, however, such as always choosing the token with the highest probability.

We can also control the "smoothness" of the distribution by using a temperature parameter. A higher temperature will make the distribution more uniform, and a lower temperature will make it more concentrated on the highest probability tokens.

We do this by dividing the logits (the output of the linear transformation) by the temperature before applying the softmax. Since the exponentiation in the softmax has a large effect on larger numbers, making them all closer together will reduce this effect.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/848346.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL之查询性能优化(六)

查询性能优化 查询优化器 9.等值传播 如果两个列的值通过等式关联,那么MySQL能够把其中一个列的WHERE条件传递到另一列上。例如,我们看下面的查询: mysql> SELECT film.film_id FROM film-> INNER JOIN film_actor USING(film_id)-> WHERE f…

四川汇聚荣聚荣科技有限公司综合实力怎么样?

在科技日新月异的今天,企业的综合实力成为衡量其市场竞争力的重要指标。四川汇聚荣聚荣科技有限公司作为一家在行业内具有一定影响力的企业,其综合实力如何,自然成为外界关注的焦点。以下将从多个维度深入分析该公司的实力。 一、公司概况与核…

模型测试优化

针对怼螺丝孔场景交叉测试 文章目录 修改一:修改二: 基于训练场景,进行修改,用以验证泛化性 模型说明:训练所用的物体模型上,有两个孔位,其中左侧为1号孔位,右侧为2号孔位 现状&…

QtCharts使用

1.基础配置 1.QGraphicsView提升为QChartView#include <QtCharts> QT_CHARTS_USE_NAMESPACE #include "ui_widget.h"2. QT charts 2.柱状图 2.1QBarSeries //1.创建Qchart对象QChart *chart new QChart();chart->setTitle("直方图演示");//设…

数据结构复习指导之归并排序、基数排序、计数排序

目录 1.归并排序 1.1二路归并操作的功能 1.2算法思想 1.3代码分析 1.4性能分析 2.基数排序 2.1算法思想 2.2基数排序的中间过程的分析 2.3性能分析 3.计数排序 3.1算法思想 3.2代码分析 3.3性能分析 知识回顾 1.归并排序 1.1二路归并操作的功能 归并排序与上述基…

HarmonyOS鸿蒙-DevEco Studio工具

一、官网下载DevEco Studio工具地址 文章内容: 1、下载工具 2、运行项目 3、安装启动器 https://developer.harmonyos.com/cn/develop/deveco-studio/https://developer.harmonyos.com/cn/develop/deveco-studio/ 下载不同平台工具目录 : 二、 安装DevEco Studio工具 安装的配置…

如何令谷歌浏览器搜索时,子页面使用新窗口,而不是迭代打开

1 问题描述 工作相关需要常用谷歌浏览器&#xff0c;但是现在设置就是每次搜索后&#xff0c;点击搜索结果进去之后&#xff0c;都会覆盖掉原来的父页面&#xff0c;也就是如果我看完了这个子页面的内容&#xff0c;关掉的话&#xff0c;我就需要重新google.com来一遍。。。很…

Dinky MySQLCDC 整库同步到 MySQL jar包冲突问题解决

资源&#xff1a;flink 1.17.0、dinky 1.0.2 问题&#xff1a;对于kafka相关的包内类找不到的情况 解决&#xff1a;使用 flink-sql-connector- 胖包即可&#xff0c;去掉 flink-connector- 相关瘦包&#xff0c;解决胖瘦包冲突 source使用 flink-sql-connector- 胖包&#…

Java【springBoot和springCould引入外部jar包】

在项目的研发过程中&#xff0c;我们经常需要导入外部系统提供的jar包&#xff0c;并且这种jar包并没有上传到开源的maven仓库&#xff0c;属于内部环境的包&#xff0c;那么应该如何添加呢&#xff1f; springBoot 1、首先&#xff0c;将你的 JAR 文件拷贝到项目的 resource…

基础数学-求平方根(easy)

一、问题描述 二、实现思路 1.题目不能直接调用Math.sqrt(x) 2.这个题目可以使用二分法来缩小返回值范围 所以我们在left<right时 使 mid (leftright)/21 当mid*mid>x时&#xff0c;说明right范围过大&#xff0c;rightright-1 当mid*mid<x时&#xff0c;说明left范…

使用Qt对word文档进行读写

目录 开发环境原理使用的QT库搭建开发环境准备word模板测试用例结果Gitee地址 开发环境 vs2022 Qt 5.9.1 msvc2017_x64&#xff0c;在文章最后提供了源码。 原理 Qt对于word文档的操作都是在书签位置进行插入文本、图片或表格的操作。 使用的QT库 除了基本的gui、core、…

JavaWeb1 Json+BOM+DOM+事件监听

JS对象-Json //Json 字符串转JS对象 var jsObject Json.parse(userStr); //JS对象转JSON字符串 var jsonStr JSON.stringify(jsObject);JS对象-BOM BOM是浏览器对象模型&#xff0c;允许JS与浏览器对话 它包括5个对象&#xff1a;window、document、navigator、screen、hi…

力扣hot100:138. 随机链表的复制(技巧,数据结构)

LeetCode&#xff1a;138. 随机链表的复制 这是一个经典的数据结构题&#xff0c;当做数据结构来学习。 1、哈希映射 需要注意的是&#xff0c;指针也能够当做unordered_map的键值&#xff0c;指针实际上是一个地址值&#xff0c;在unordered_map中&#xff0c;使用指针的实…

VXLAN技术

VXLAN技术 一、VXLAN简介 1、定义 VXLAN&#xff08;Virtual eXtensible Local Area Network&#xff09;&#xff1a;采用MAC in UDP&#xff08;User Datagram Protocol&#xff09;封装方式&#xff0c;是NVO3&#xff08;Network Virtualization over Layer 3&#xff09…

使用 Logback.xml 配置文件输出日志信息

官方链接&#xff1a;Chapter 3: Configurationhttps://logback.qos.ch/manual/configuration.html 配置使用 logback 的方式有很多种&#xff0c;而使用配置文件是较为简单的一种方式&#xff0c;下述就是简单描述一个 logback 配置文件基本的配置项&#xff1a; 由于 logba…

Vuforia AR篇(七)— 二维码识别

目录 前言一、什么是Barcode &#xff1f;二、使用步骤三、点击二维码显示信息四、效果 前言 在数字化时代&#xff0c;条形码和二维码已成为连接现实世界与数字信息的重要桥梁。Vuforia作为领先的AR开发平台&#xff0c;提供了Barcode Scanner功能&#xff0c;使得在Unity中实…

json和axion结合

目录 java中使用JSON对象 在pom.xml中导入依赖 使用 public static String toJSONString(Object object)把自定义对象变成JSON对象 json和axios综合案例 使用的过滤器 前端代码 响应和请求都是普通字符串 和 请求时普通字符串&#xff0c;响应是json字符串 响应的数据是…

MySQL换路径(文件夹)

#MySQL作为免费数据库很受欢迎&#xff0c;即使公司没有使用&#xff0c;自己也可以用。它是一个服务&#xff0c;在点击CtrlAltDelete选择任务管理器后&#xff0c;它在服务那个归类里。 经常整理计算机磁盘分类的小伙伴&#xff0c;如果你们安装了MySQL&#xff0c;并且想移…

插件:Plugins

一、安装网格插件

重大变化,2024软考!

根据官方发布的2024年度计算机技术与软件专业技术资格&#xff08;水平&#xff09;考试安排&#xff0c;2024年软考上、下半年开考科目有着巨大变化&#xff0c;我为大家整理了相关信息&#xff0c;大家可以看看&#xff01; &#x1f3af;2024年上半年&#xff1a;5月25日&am…