1、forward
// One forward pass for `token` at position `pos`; returns a pointer to the output logits.
// In gdb, `p *logits` dereferences the float* and therefore shows only logits[0].
float* logits = forward(transformer, token, pos);
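输入 Transformer 的参数、当前 token 及其位置 pos，输出下一个 token 的预测值 logits（词表中每个 token 对应一个未归一化的得分）；Transformer 本身由矩阵乘法以及逐元素的加减乘除等运算构成。注意 logits 是一个 float 数组，gdb 中 `p *logits` 只打印第一个元素 logits[0]，想看更多元素可以用 `p *logits@10` 这样的语法。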
(gdb) p *logits
$9 = 2.19071054
// attention rmsnorm
// presumably writes the RMS-normalized x into s->xb, scaled by layer l's slice
// of rms_att_weight (offset l*dim) — confirm against rmsnorm's definition
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
// qkv matmuls for this position
// s->xb is a plain float buffer; quantize it into s->xq first so the matmul
// works on the same (quantized) representation as the weights
quantize(&s->xq, s->xb, dim);
// w->wq + l: presumably layer l's quantized query weight — confirm
matmul(s->q, &s->xq, w->wq + l, dim, dim);
(gdb) ptype s->xb
type = float *
对输入做量化，是为了让它与权重具有相同的数据类型：s->xb 是 float *（见上面 ptype 的输出），quantize 先把它量化到 s->xq，matmul 再用 s->xq 与权重 w->wq 相乘。
2、sample
2.1 未进入 sample（仍在处理输入 prompt）
if (pos < num_prompt_tokens - 1) {// if we are still processing the input prompt, force the next prompt tokennext = prompt_tokens[pos + 1];} else {// otherwise sample the next token from the logitsnext = sample(sampler, logits);}
**确定 next：** 如果仍在处理输入 prompt（即 pos < num_prompt_tokens - 1），下一个 token 不需要采样，直接取 prompt 中的下一个 token，即 prompt_tokens[pos + 1]；否则才调用 sample 从 logits 中采样得到 next。
即执行
// Still inside the prompt: the next token is forced to the prompt's next token (no sampling).
next = prompt_tokens[pos + 1];
得
(gdb) p pos
$10 = 0
(gdb) p next
$11 = 15043 //Hello
2.2 进入 sample（prompt 已处理完，开始生成）
(gdb) p *logits
$20 = 0.657589614
/*
 * Pick the next token id from `logits` according to the sampler settings.
 *
 * - temperature == 0: greedy argmax over the logits.
 * - otherwise: each logit is divided by the temperature and softmax is applied
 *   to the same buffer (so `logits` is modified in place). A random coin drawn
 *   from sampler->rng_state then selects a token either from the full
 *   distribution (topp <= 0 or topp >= 1) or via top-p (nucleus) sampling.
 *
 * Returns the chosen token id.
 *
 * Note: the original one-line paste let the first `//` comment swallow the
 * whole body; the line breaks are restored here with identical logic.
 */
int sample(Sampler* sampler, float* logits) {
    // sample the token given the logits and some hyperparameters
    int next;
    if (sampler->temperature == 0.0f) {
        // greedy argmax sampling: take the token with the highest probability
        next = sample_argmax(logits, sampler->vocab_size);
    } else {
        // apply the temperature to the logits
        for (int q = 0; q < sampler->vocab_size; q++) {
            logits[q] /= sampler->temperature;
        }
        // apply softmax to the logits to get the probabilities for next token
        softmax(logits, sampler->vocab_size);
        // flip a (float) coin (this is our source of entropy for sampling)
        float coin = random_f32(&sampler->rng_state);
        // we sample from this distribution to get the next token
        if (sampler->topp <= 0 || sampler->topp >= 1) {
            // simply sample from the predicted probability distribution
            next = sample_mult(logits, sampler->vocab_size, coin);
        } else {
            // top-p (nucleus) sampling, clamping the least likely tokens to zero
            next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
        }
    }
    return next;
}
3、decode
此时 token=1（BOS），next=15043（即 "Hello"）
调用
// Convert token id `next` to its text piece; the previous id `token` is passed so
// decode can strip the leading space of a piece that follows BOS (id 1).
char* piece = decode(tokenizer, token, next);
定义
/*
 * Map token id `token` to the text piece to print.
 *
 * - If the previous token was BOS (id 1) and the piece begins with a space,
 *   that space is skipped.
 * - Pieces of the form "<0xHH>" stand for a raw byte; for those, the returned
 *   pointer is into t->byte_pieces at offset byte_val * 2 (presumably one
 *   NUL-terminated single-byte string per value — confirm against the
 *   tokenizer setup).
 *
 * No allocation is made: the result points into tokenizer storage
 * (t->vocab or t->byte_pieces).
 *
 * Note: the original one-line paste let `//Hello//...` comment out everything
 * after the first statement (including the return); line breaks are restored
 * here with identical logic.
 */
char* decode(Tokenizer* t, int prev_token, int token)
{
    char *piece = t->vocab[token];
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ') {
        piece++;
    }
    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    // parse this and convert and return the actual byte
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
        piece = (char*)t->byte_pieces + byte_val * 2;
    }
    return piece;
}
(gdb) p piece
$17 = 0x55ae4f286661 "Hello"