Minimind 训练一个自己专属语言模型

发现了一个宝藏项目, 宣传是完全从0开始,仅用3块钱成本 + 2小时!即可训练出仅为25.8M的超小语言模型MiniMind,最小版本体积是 GPT-3 的 17000,做到最普通的个人GPU也可快速训练

https://github.com/jingyaogong/minimindhttps://github.com/jingyaogong/minimind

项目包含

  • MiniMind-LLM结构的全部代码(Dense+MoE模型)。
  • 包含Tokenizer分词器详细训练代码。
  • 包含Pretrain、SFT、LoRA、RLHF-DPO、模型蒸馏的全过程训练代码。
  • 收集、蒸馏、整理并清洗去重所有阶段的高质量数据集,且全部开源。
  • 从0实现预训练、指令微调、LoRA、DPO强化学习,白盒模型蒸馏。关键算法几乎不依赖第三方封装的框架,且全部开源。
  • 同时兼容transformerstrlpeft等第三方主流框架。
  • 训练支持单机单卡、单机多卡(DDP、DeepSpeed)训练,支持wandb可视化训练流程。支持动态启停训练。
  • 在第三方测评榜(C-Eval、C-MMLU、OpenBookQA等)进行模型测试。
  • 实现Openai-Api协议的极简服务端,便于集成到第三方ChatUI使用(FastGPT、Open-WebUI等)。
  • 基于streamlit实现最简聊天WebUI前端。

训练数据集下载地址 魔搭社区

创建./dataset目录, 存放训练数据集,该pretrain_hq.jsonl数据集是从 匠数大模型数据集 里清洗出字符<512长度的大约1.6GB的语料直接拼接而成

关于匠数大模型SFT数据集 “, 它是一个完整、格式统一、安全的大模型训练和研究资源。 从网络上的公开数据源收集并整理了大量开源数据集,对其进行了格式统一,数据清洗, 包含10M条数据的中文数据集和包含2M条数据的英文数据集。” 以上是官方介绍,下载文件后的数据总量大约在4B tokens,肯定是适合作为中文大语言模型的SFT数据的。 但是官方提供的数据格式很乱,全部用来sft代价太大。

预训练 pretrain_hq.jsonl 数据格式为

{"text": "如何才能摆脱拖延症? 治愈拖延症并不容易,但以下建议可能有所帮助..."}

关于提高语料质量,有一种基于query-utterance pair拼接方式,Query-Utterance Pair 拼接方式是一种多轮对话上下文建模方法。它将当前的用户输入(query)与历史对话中的某一句或多句用户发言(utterance)配对拼接,作为模型的输入。这种方式的核心是显式地利用对话历史中的关键信息,而不是简单地将所有历史对话拼接在一起。

假设一个多轮对话场景:

  • 历史对话:
    • 用户第1轮:我想订一张去上海的机票。
    • 模型第1轮:好的,请告诉我出发地和日期。
    • 用户第2轮:从北京出发,明天。
  • 当前输入(query):多少钱?

如果直接将所有历史对话拼接,可能得到:

我想订一张去上海的机票。好的,请告诉我出发地和日期。从北京出发,明天。多少钱?

这种方式虽然包含了完整上下文,但信息冗余且缺乏重点,可能导致模型难以聚焦关键信息。

Query-Utterance Pair 拼接则会选择与当前 query 最相关的历史 utterance

去上海,明天。多少钱?

通过这种配对,模型输入被简化为一个更紧凑、更相关的上下文表示

  • 增强上下文相关性
    • 多轮问答任务需要模型理解当前问题与历史对话的关联。Query-Utterance Pair 通过筛选关键历史 utterance,避免无关信息干扰,提升模型对上下文的理解能力。
  • 减少输入长度
    • 完整历史拼接可能导致输入过长(尤其在 Transformer 模型中会增加计算成本或超过最大长度限制)。这种方式减少了输入长度,提高了计算效率。
  • 提升多轮问题准确率
    • 你的实验表明,这种拼接方式使多轮问题准确率提升了 17%,说明它帮助模型更好地捕捉对话中的逻辑链条。例如,“多少钱”与“从北京出发,明天”的配对直接关联到价格查询意图。
  • 提高 BLEU 分数
    • BLEU(Bilingual Evaluation Understudy)衡量生成文本与参考文本的相似度,提升 12% 表明模型生成的回答更贴近预期,这得益于上下文输入的精准性。

sft_mini_512.jsonl 文件的数据格式

{"conversations": [{"role": "user","content": "打开音乐播放器。"}, {"role": "assistant","content": "好的,现在您的音乐播放器已经打开了。请问您想听哪首歌曲呢?"}, {"role": "user","content": "随便放一首歌吧。"}, {"role": "assistant","content": "好的,现在为您播放歌曲。请问还需要我为您做其他的什么吗?"}, {"role": "user","content": "给我推荐一部好看的电影。"}, {"role": "assistant","content": "好的,这里有一部近期口碑还不错的电影《大鱼海棠》。它是一部优秀的国产动画电影,非常值得一看。不知道您是否喜欢呢?"}, {"role": "user","content": "我想知道现在的天气情况。"}, {"role": "assistant","content": "好的,请问您现在的所在地点是哪里?我可以根据您所在的位置来查询天气情况。"}, {"role": "user","content": "上海。"}, {"role": "assistant","content": "好的,现在上海地区的天气情况如下:温度为20℃,多云转晴。请问还需要我为您做其他的什么吗?"}]
}

MiniMindLM 模型结构

MiniMindLM 是一个高效的 Transformer 语言模型,通过 MoE 机制在保持强大表达能力的同时降低计算成本。其结构设计清晰,兼顾训练和推理需求,适合生成任务和资源受限场景 ,MoE 的引入(MoEGate 和 MOEFeedForward)是最大亮点,通过稀疏激活和专家均衡优化,显著提升了模型的性能和可扩展性。查看下模型结构

class MiniMindLM(PreTrainedModel):config_class = LMConfigdef __init__(self, params: LMConfig = None):self.params = params or LMConfig()super().__init__(self.params)self.vocab_size, self.n_layers = params.vocab_size, params.n_layersself.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)self.dropout = nn.Dropout(params.dropout)self.layers = nn.ModuleList([MiniMindBlock(l, params) for l in range(self.n_layers)])self.norm = RMSNorm(params.dim, eps=params.norm_eps)self.output = nn.Linear(params.dim, params.vocab_size, bias=False)self.tok_embeddings.weight = self.output.weightself.register_buffer("pos_cis",precompute_pos_cis(dim=params.dim // params.n_heads, theta=params.rope_theta),persistent=False)self.OUT = CausalLMOutputWithPast()def forward(self,input_ids: Optional[torch.Tensor] = None,past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,use_cache: bool = False,**args):past_key_values = past_key_values or [None] * len(self.layers)start_pos = args.get('start_pos', 0)h = self.dropout(self.tok_embeddings(input_ids))pos_cis = self.pos_cis[start_pos:start_pos + input_ids.size(1)]past_kvs = []for l, layer in enumerate(self.layers):h, past_kv = layer(h, pos_cis,past_key_value=past_key_values[l],use_cache=use_cache)past_kvs.append(past_kv)logits = self.output(self.norm(h))aux_loss = sum(l.feed_forward.aux_loss for l in self.layers if isinstance(l.feed_forward, MOEFeedForward))self.OUT.__setitem__('logits', logits)self.OUT.__setitem__('aux_loss', aux_loss)self.OUT.__setitem__('past_key_values', past_kvs)return self.OUT

该模型是一个基于 Transformer 的语言模型,结合了混合专家模型(Mixture of Experts, MoE)技术,旨在通过高效的计算和稀疏激活提升性能。

整体架构

MiniMindLM 是一个典型的因果语言模型(Causal Language Model),其结构遵循 Transformer 的 Decoder-only 设计,类似于 GPT 系列,但加入了 MoE 机制以提升效率和性能。主要组成部分包括:

  • 输入嵌入层(tok_embeddings):将输入 token 映射为高维向量。
  • 多层 Transformer Block(MiniMindBlock):核心计算单元,包含注意力机制和前馈网络(可选 MoE)。
  • 归一化层(norm):RMSNorm 用于稳定训练。
  • 输出层(output):将隐藏状态映射回词汇表大小的 logits。
  • 位置编码(pos_cis):采用 RoPE(Rotary Position Embedding)来编码序列位置信息。
关键特点
  1. 因果性:通过 CausalLMOutputWithPast 输出,表明这是一个自回归模型,适用于生成任务。
  2. MoE 支持:通过 use_moe 参数控制是否使用 MOEFeedForward,替代传统的 FeedForward,引入稀疏专家机制。
  3. 缓存支持:past_key_values 和 use_cache 参数表明支持增量推理(incremental decoding),优化生成效率。
  4. 共享权重:tok_embeddings.weight = self.output.weight,输入嵌入和输出层的权重共享,减少参数量。

核心组件分析

(1) MiniMindBlock

这是 Transformer 的单层结构,包含以下子模块:

  • 注意力机制(Attention)
    • 使用多头自注意力(Multi-Head Self-Attention),头数由 n_heads 控制,每个头的维度为 head_dim = dim // n_heads。
    • 输入经过 attention_norm(RMSNorm)归一化后,进入注意力计算。
    • 支持缓存(past_key_value),用于加速推理。
    • 输出 h_attn 与输入残差连接(x + h_attn)。
  • 前馈网络(FeedForward 或 MOEFeedForward)
    • 默认使用标准前馈网络(FeedForward),但若 use_moe=True,则切换为 MOEFeedForward。
    • 输入经过 ffn_norm(RMSNorm)归一化后,进入前馈计算。
    • 输出与残差连接(h + feed_forward(...))。
  • 归一化:使用 RMSNorm 而非 LayerNorm,计算效率更高,且稳定性较好。

作用
MiniMindBlock 是模型的核心计算单元,负责捕捉序列中的依赖关系(注意力)和进行特征变换(前馈网络)。MoE 的引入使得前馈部分更高效,仅激活部分专家而非全部参数。

(2) MOEFeedForward

这是混合专家模型的前馈网络实现,替代传统全连接层。主要特点:

  • 专家模块(experts)
    • 包含 n_routed_experts 个独立的前馈网络(FeedForward),每个专家处理特定的输入子集。
  • 门控机制(gate)
    • 通过 MoEGate 决定每个 token 分配给哪些专家(topk_idx)及其权重(topk_weight)。
  • 共享专家(shared_experts)
    • 可选模块(n_shared_experts 不为 None 时启用),为所有 token 提供一个共享的前馈计算,增强通用性。
  • 训练与推理差异
    • 训练模式:输入重复 num_experts_per_tok 次,分别送入对应专家,输出加权求和。
    • 推理模式:通过 moe_infer 函数高效计算,仅激活必要专家。

作用
MOEFeedForward 通过稀疏激活减少计算量,同时利用多个专家捕捉不同模式,提升模型容量和表达能力。aux_loss(辅助损失)用于平衡专家的使用率,避免某些专家被过度忽略。

(3) MoEGate

这是 MoE 的门控机制,负责为每个 token 选择 Top-k 专家。主要逻辑:

  • 线性评分
    • 输入 hidden_states 通过线性层(F.linear)计算与 n_routed_experts 个专家的得分(logits)。
  • 得分归一化
    • 默认使用 softmax 将 logits 转为概率分布(scores)。
  • Top-k 选择
    • 使用 torch.topk 选取得分最高的 top_k 个专家及其权重。
    • 若 norm_topk_prob=True,对 Top-k 权重归一化(和为 1)。
  • 辅助损失(aux_loss)
    • 在训练时计算,用于鼓励专家均衡使用。
    • 有两种模式:
      • seq_aux=True:基于序列级别的专家使用率计算交叉熵。
      • seq_aux=False:基于全局专家使用率计算交叉熵。
    • 损失乘以超参数 alpha,加到总损失中。

作用
MoEGate 是 MoE 的核心调度器,确保每个 token 只激活少量专家(top_k),降低计算成本,同时通过 aux_loss 防止专家使用不均。

(4) MiniMindLM

顶层模型整合所有组件:

  • 输入处理
    • tok_embeddings 将 token ID 转为嵌入向量,加入 dropout。
    • pos_cis(RoPE 位置编码)动态截取,适配输入长度。
  • 层级计算
    • 依次通过 n_layers 个 MiniMindBlock,每层更新隐藏状态并缓存键值对。
  • 输出
    • 经过 norm 归一化后,output 层生成 logits。
    • 若使用 MoE,累加所有层的 aux_loss。

输出格式
CausalLMOutputWithPast 包含 logits(预测分布)、aux_loss(MoE 辅助损失)和 past_key_values(缓存)。

设计亮点

  • MoE 优化
    • 通过 top_k 和 n_routed_experts,模型只激活部分专家,大幅减少计算量。例如,若 n_routed_experts=8,top_k=2,每个 token 只调用 25% 的专家参数。
    • aux_loss 确保专家分配均衡,避免“专家坍缩”(某些专家从未被使用)。
  • 高效推理
    • moe_infer 使用 scatter_add_ 高效聚合专家输出,避免显式循环。
    • 缓存机制(past_key_values)支持自回归生成,适合对话或文本生成任务。
  • 灵活性
    • use_moe 参数允许切换传统 FFN 和 MoE FFN,便于实验对比。
    • n_shared_experts 提供通用专家,弥补稀疏专家的局限性。
  • 稳定性
    • RMSNorm 和 Kaiming 初始化(reset_parameters)提升训练稳定性。
    • 权重共享(嵌入和输出层)减少参数量,适合资源受限场景。
  • 计算复杂度
    • 传统 Transformer 的 FFN 复杂度为 O(bsz⋅seqlen⋅dim2)O(bsz \cdot seq_len \cdot dim^2)O(bsz⋅seql​en⋅dim2)。
    • MoE 模式下,每个 token 只激活 top_k 个专家,复杂度降为 O(bsz⋅seqlen⋅dim⋅topk⋅nroutedexperts/totalexperts)O(bsz \cdot seq_len \cdot dim \cdot top_k \cdot n_routed_experts / total_experts)O(bsz⋅seql​en⋅dim⋅topk​⋅nr​outede​xperts/totale​xperts),显著降低。
  • 内存需求:增加 n_routed_experts 会提升参数量,但实际激活的参数量由 top_k 控制,内存占用可控。
  • 训练开销:aux_loss 引入额外计算,但对性能提升至关重要,尤其在专家数量较多时。

评估下minimind的训练参数量

计算 MiniMindLM 的训练参数量,我们需要分析其所有可训练的模块,并根据代码中的配置参数(LMConfig)推导出具体的参数数量。按照默认的LMConfig

class LMConfig(PretrainedConfig):model_type = "minimind"def __init__(self,dim: int = 512,n_layers: int = 8,n_heads: int = 8,n_kv_heads: int = 2,vocab_size: int = 6400,hidden_dim: int = None,multiple_of: int = 64,norm_eps: float = 1e-5,max_seq_len: int = 8192,rope_theta: int = 1e6,dropout: float = 0.0,flash_attn: bool = True,##################################################### Here are the specific configurations of MOE# When use_moe is false, the following is invalid####################################################use_moe: bool = False,####################################################num_experts_per_tok: int = 2,n_routed_experts: int = 4,n_shared_experts: bool = True,scoring_func: str = 'softmax',aux_loss_alpha: float = 0.1,seq_aux: bool = True,norm_topk_prob: bool = True,**kwargs,)

从 LMConfig 中提取关键参数:

  • dim = 512(隐藏层维度)。
  • n_layers = 8(Transformer 层数)。
  • n_heads = 8(注意力头数)。
  • n_kv_heads = 2(键值头的数量,可能用于分组查询注意力 GQA,但这里先按标准计算)。
  • vocab_size = 6400(词汇表大小)。
  • hidden_dim = None(未指定,假设前馈网络中间层维度为 4 * dim,即 2048)。
  • max_seq_len = 8192(最大序列长度,仅影响缓冲区,不影响参数量)。
  • use_moe = False(默认不使用 MoE)。
  • MoE 相关参数(仅在 use_moe=True 时生效):
    • num_experts_per_tok = 2(每个 token 激活的专家数,Top-k)。
    • n_routed_experts = 4(路由专家数量)。
    • n_shared_experts = True(布尔值,但代码中应为整数,假设为 1)。
  • norm_eps 和 dropout 等不影响参数量。

由于 use_moe 默认值为 False,我将先计算非 MoE 模式下的参数量,然后再计算 use_moe=True 的情况以作对比。


2. 参数量计算(use_moe=False)

(1) 输入嵌入层(tok_embeddings)
  • 结构:nn.Embedding(vocab_size, dim)。
  • 参数量:vocab_size * dim = 6400 * 512 = 3,276,800。
  • 说明:嵌入层和输出层共享权重,因此只计算一次。
(2) 输出层(output)
  • 结构:nn.Linear(dim, vocab_size, bias=False)。
  • 参数量:dim * vocab_size = 512 * 6400 = 3,276,800。
  • 共享权重后,总嵌入参数仍为 3,276,800。
(3) MiniMindBlock(每层)

每层包含注意力模块、前馈网络和两个 RMSNorm。

注意力模块(Attention)
  • 假设为标准多头自注意力(未明确使用 GQA,但 n_kv_heads=2 暗示可能优化 KV 计算,暂按标准计算):
    • QKV 线性变换
      • 输入 dim,输出 dim(n_heads * head_dim,head_dim = 512 // 8 = 64)。
      • 参数量:dim * dim * 3 = 512 * 512 * 3 = 786,432。
    • 输出线性变换
      • 参数量:dim * dim = 512 * 512 = 262,144。
    • 总计:786,432 + 262,144 = 1,048,576。
RMSNorm(attention_norm 和 ffn_norm)
  • 每个 RMSNorm:dim = 512。
  • 两个 RMSNorm:2 * 512 = 1,024。
前馈网络(FeedForward)
  • 假设为标准两层 MLP,中间层维度 ffn_dim = 4 * dim = 2048(常见设置):
    • 第一层:dim -> ffn_dim,参数量 512 * 2048 = 1,048,576。
    • 第二层:ffn_dim -> dim,参数量 2048 * 512 = 1,048,576。
    • 无偏置假设,总计:1,048,576 + 1,048,576 = 2,097,152。
单层总参数量
  • 注意力:1,048,576。
  • 前馈:2,097,152。
  • RMSNorm:1,024。
  • 总计:1,048,576 + 2,097,152 + 1,024 = 3,146,752。
(4) 所有层
  • n_layers = 8。
  • 总计:8 * 3,146,752 = 25,174,016。
(5) 顶层 RMSNorm(norm)
  • 参数量:dim = 512。
总参数量(use_moe=False)


3. 参数量计算(use_moe=True)

假设 use_moe=True,并使用 MoE 参数:

  • n_routed_experts = 4。
  • n_shared_experts = 1(将布尔值 True 视为 1)。
(1) 输入嵌入层和输出层
  • 同上:3,276,800。
(2) MiniMindBlock(每层)

注意力模块和 RMSNorm 不变,变化在于 MOEFeedForward。

注意力模块
  • 同上:1,048,576。
RMSNorm
  • 同上:1,024。
MOEFeedForward
  • 专家网络(experts)
    • n_routed_experts = 4,每个专家是一个 FeedForward。
    • 单个专家:2,097,152(如上计算)。
    • 总计:4 * 2,097,152 = 8,388,608。
  • 共享专家(shared_experts)
    • n_shared_experts = 1,参数量:2,097,152。
  • 门控机制(MoEGate)
    • 权重:n_routed_experts * dim = 4 * 512 = 2,048。
  • MOEFeedForward 总计
    • 8,388,608 + 2,097,152 + 2,048 = 10,487,808。
单层总参数量
  • 注意力:1,048,576。
  • 前馈(MoE):10,487,808。
  • RMSNorm:1,024。
  • 总计:1,048,576 + 10,487,808 + 1,024 = 11,537,408。
(3) 所有层
  • n_layers = 8。
  • 总计:8 * 11,537,408 = 92,299,264。
(4) 顶层 RMSNorm
  • 同上:512。
总参数量(use_moe=True)


4. 结果对比

  • use_moe=False28,451,328 参数(约 28.45M)。
  • use_moe=True(n_routed_experts=4, n_shared_experts=1):95,576,576 参数(约 95.58M)。 后面可以看下模型文件大小满足该理论值

开启预训练 

python train_pretrain.py   预训练(学知识)

python train_full_sft.py 监督微调(学对话方式)

测试模型效果

确保需要测试的模型*.pth文件位于./out/目录下

# 默认为0:测试pretrain模型效果,设置为1:测试full_sft模型效果
python eval_model.py --model_mode 1

自动测试

模型转换下格式方便在 webui上使用

(spatiallm) [root@node126 minimind]# cd scripts/
(spatiallm) [root@node126 scripts]# python convert_model.py
模型参数: 25.829888 百万 = 0.025829888 B (Billion)
模型已保存为 Transformers 格式: ../MiniMind2-Small

修改下 web_demo.py里模型路径映射

# 模型路径映射
MODEL_PATHS = {"MiniMind2-Small (0.025829888 B)": ["../MiniMind2-Small", "MiniMind2-Small"],
}
selected_model = st.sidebar.selectbox('Models', list(MODEL_PATHS.keys()), index=0)

看下web demo的提示词是怎么写的

分析下是怎么组织提示词和关联多轮对话的

def setup_seed(seed):random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsedef main():model, tokenizer = load_model_tokenizer(model_path)# 初始化消息列表if "messages" not in st.session_state:st.session_state.messages = []st.session_state.chat_messages = []# Use session state messagesmessages = st.session_state.messages# 在显示历史消息的循环中for i, message in enumerate(messages):if message["role"] == "assistant":with st.chat_message("assistant", avatar=image_url):st.markdown(process_assistant_content(message["content"]), unsafe_allow_html=True)if st.button("×", key=f"delete_{i}"):# 删除当前消息及其之后的所有消息st.session_state.messages = st.session_state.messages[:i - 1]st.session_state.chat_messages = st.session_state.chat_messages[:i - 1]st.rerun()else:st.markdown(f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px;  background-color: gray; border-radius: 10px; color:white; ">{message["content"]}</div></div>',unsafe_allow_html=True)# 处理新的输入或重新生成prompt = st.chat_input(key="input", placeholder="给 MiniMind 发送消息")# 检查是否需要重新生成if hasattr(st.session_state, 'regenerate') and st.session_state.regenerate:prompt = st.session_state.last_user_messageregenerate_index = st.session_state.regenerate_index  # 获取重新生成的位置# 清除所有重新生成相关的状态delattr(st.session_state, 'regenerate')delattr(st.session_state, 'last_user_message')delattr(st.session_state, 'regenerate_index')if prompt:st.markdown(f'<div style="display: flex; justify-content: flex-end;"><div style="display: inline-block; margin: 10px 0; padding: 8px 12px 8px 12px;  background-color: gray; border-radius: 10px; color:white; ">{prompt}</div></div>',unsafe_allow_html=True)messages.append({"role": "user", "content": prompt})st.session_state.chat_messages.append({"role": "user", "content": prompt})with st.chat_message("assistant", avatar=image_url):placeholder = st.empty()random_seed = random.randint(0, 2 ** 32 - 1)setup_seed(random_seed)st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[-(st.session_state.history_chat_num + 1):]new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages,tokenize=False,add_generation_prompt=True)[-(st.session_state.max_new_tokens - 1):]x = torch.tensor(tokenizer(new_prompt)['input_ids'], device=device).unsqueeze(0)with torch.no_grad():res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=st.session_state.max_new_tokens,temperature=st.session_state.temperature,top_p=st.session_state.top_p, stream=True)try:for y in res_y:answer = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)if (answer and answer[-1] == '�') or not answer:continueplaceholder.markdown(process_assistant_content(answer), unsafe_allow_html=True)except StopIteration:print("No answer")assistant_answer = answer.replace(new_prompt, "")messages.append({"role": "assistant", "content": assistant_answer})st.session_state.chat_messages.append({"role": "assistant", "content": assistant_answer})with st.empty():if st.button("×", key=f"delete_{len(messages) - 1}"):st.session_state.messages = st.session_state.messages[:-2]st.session_state.chat_messages = st.session_state.chat_messages[:-2]st.rerun()if __name__ == "__main__":from transformers import AutoModelForCausalLM, AutoTokenizermain()

基于 Streamlit 的交互式对话界面,使用 MiniMindLM 自回归语言模型(通过 transformers.AutoModelForCausalLM 加载)进行多轮对话。

  • 处理输入:通过 st.chat_input 获取用户输入,生成提示词,调用模型生成回答,并更新会话状态。
  • 多轮对话:通过 st.session_state.chat_messages 维护对话历史,关联上下文。

提示词组织方式

提示词的构建主要发生在用户输入 prompt 后,通过以下步骤生成并传递给模型:

(1) 会话状态管理
  • st.session_state.messages
    • 存储所有对话消息,格式为 [{"role": "user/assistant", "content": "..."}, ...]。
    • 用于渲染历史消息和支持删除功能。
  • st.session_state.chat_messages
    • 与 messages 类似,但专门用于构建提示词,可能包含系统提示(system_prompt)和裁剪后的历史。
    • 通过 -(st.session_state.history_chat_num + 1) 限制历史长度。
(2)系统提示与历史拼接

st.session_state.chat_messages = system_prompt + st.session_state.chat_messages[ -(st.session_state.history_chat_num + 1):]

  • 系统提示(system_prompt)
    • 未在代码中显式定义,假设是一个预定义的列表(如 [{"role": "system", "content": "You are a helpful assistant."}])。
    • 作为对话的初始上下文,定义模型行为。
  • 历史裁剪
    • history_chat_num 控制保留的历史对话轮数(未定义,假设为一个整数,如 5)。
    • -(history_chat_num + 1) 从 chat_messages 末尾取最近的若干轮对话,加上当前输入。
    • 例如,若 history_chat_num=2,则保留最近 2 轮对话 + 当前输入。

(3) 提示词模板化

new_prompt = tokenizer.apply_chat_template(st.session_state.chat_messages,tokenize=False,add_generation_prompt=True
)[-(st.session_state.max_new_tokens - 1):]
    • 假设模板为简单拼接(如 <|system|>... <|user|>... <|assistant|>),最终生成类似:

      <|system|>You are a helpful assistant.<|user|>Hello!<|assistant|>Hi there!<|user|>What's the weather?

    • 长度截断:
      • -(max_new_tokens - 1) 限制提示词长度,确保加上生成 token 后不超过 max_new_tokens。
      • 若历史过长,只保留末尾部分,防止溢出。

    多轮对话关联机制

    多轮对话的上下文通过以下方式关联和维护:

    (1) 会话状态的持久化
    • Streamlit 的 st.session_state 是一个持久化的状态存储,跨页面刷新保留数据。
    • messages 和 chat_messages 在会话开始时初始化,并在每次用户输入或模型回复后更新。
    • 示例:
      • 用户输入 "Hello" → messages.append({"role": "user", "content": "Hello"})。
      • 模型回复 "Hi there!" → messages.append({"role": "assistant", "content": "Hi there!"})。
    (2) 历史消息的动态管理
    • 显示历史
      • 循环遍历 messages,根据 role 渲染用户或助手消息。
      • 支持删除:点击 "×" 按钮,截断 messages 和 chat_messages 到指定位置。
    • 重新生成支持
      • 若 st.session_state.regenerate=True,从 last_user_message 重新生成回答,并清除相关状态。
    (3) 上下文传递
    • chat_messages 将系统提示和最近历史拼接,确保模型接收到完整的上下文。
    • 示例:
      • 系统提示:[{"role": "system", "content": "You are a helpful assistant"}]
      • 第1轮:用户 "Hello" → 助手 "Hi there!"
      • 第2轮:用户 "What's next?" →
        • chat_messages = [{"role": "system", ...}, {"role": "user", "Hello"}, {"role": "assistant", "Hi there!"}, {"role": "user", "What's next?"}]
        • 模板化后:You are a helpful assistant. <|user|>Hello<|assistant|>Hi there!<|user|>What's next?

    webui测试结果

    测试下 Top-P 和 Temperature, 效果比较明显 Temperature 越大模型的发散思考能力越高,给出的回答更有创造性,也伴随着模型幻觉问题

    本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899934.shtml

    如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

    相关文章

    Spring Boot 与 Spring Integration 整合教程

    精心整理了最新的面试资料和简历模板&#xff0c;有需要的可以自行获取 点击前往百度网盘获取 点击前往夸克网盘获取 Spring Boot 与 Spring Integration 整合教程 简介 Spring Integration 是 Spring 生态系统中用于实现企业集成模式&#xff08;Enterprise Integration Pa…

    Nginx 核心配置详解与性能优化最佳实践

    1.什么是 Nginx&#xff1f; Nginx 是一个高性能的 Web 服务器和反向代理服务器。它轻量、高效&#xff0c;被广泛用于现代 Web 开发中。 2.为什么前端需要了解 Nginx&#xff1f; ★ 了解 本地开发&#xff1a;可以模拟生产环境 部署前端项目&#xff1a;作为静态文件服务器…

    LayaAir3.3.0-beta.3重磅更新!Spine4.2、2D物理、UI系统、TileMap等全面升级!

    正式版推出前&#xff0c;说明3.3的功能还没开发完。所以&#xff0c;又一大波更新来了~ 下面对重点更新进行说明。 Spine的重要更新 3.3.0-beta.3版本开始&#xff0c;新增了Spine 4.2 的运行时库&#xff0c;Spine动画上可以支持物理特性了。例如&#xff0c;下图右侧女孩在启…

    pip安装timm依赖失败

    在pycharm终端给虚拟环境安装timm库失败&#xff08; pip install timm&#xff09;&#xff0c;提示你要访问 https://rustup.rs/ 来下载并安装 Rust 和 Cargo 直接不用管&#xff0c;换一条命令 pip install timm0.6.13 成功安装 简单粗暴

    BUUCTF-web刷题篇(7)

    16.BackupFile 题目提示backupfile&#xff0c;是备份文件的意思&#xff1a; 查看源码没有什么有用信息&#xff0c;也没有登录界面&#xff0c;所以也不会用到蚁剑链接来找备份文件&#xff0c;所以大概率就是通过构造playload来查找备份文件。 注&#xff1a;备份文件常用…

    Maven 构建生命周期

    Maven 构建生命周期 引言 Maven 是一个强大的项目管理和构建自动化工具,广泛应用于 Java 开发领域。Maven 的核心概念之一是构建生命周期,它定义了从项目创建到构建、测试、打包、部署等一系列操作的流程。本文将详细介绍 Maven 的构建生命周期,帮助读者更好地理解和使用 …

    PyTorch 深度学习实战(29):目标检测与 YOLOv12 实战

    在上一篇文章中,我们探讨了对比学习与自监督表示学习。本文将深入计算机视觉的核心任务之一——目标检测,重点介绍最新的 YOLOv12 (You Only Look Once v12) 算法。我们将使用 PyTorch 实现 YOLOv12 模型,并在 COCO 数据集上进行训练和评估。 一、YOLOv12 基础 YOLOv12 是 …

    使用Leaflet对的SpringBoot天地图路径规划可视化实践-以黄花机场到橘子洲景区为例

    目录 前言 一、路径规划需求 1、需求背景 2、技术选型 3、功能简述 二、Leaflet前端可视化 1、内容布局 2、路线展示 3、转折路线展示 三、总结 前言 在当今数字化与智能化快速发展的时代&#xff0c;路径规划技术已经成为现代交通管理、旅游服务以及城市规划等领域的…

    深入理解 CSS 选择器:从基础到高级的样式控制

    引言 在网页设计与开发中&#xff0c;CSS&#xff08;层叠样式表&#xff09;扮演着至关重要的角色&#xff0c;它赋予了 HTML 页面丰富的视觉效果和交互性。而 CSS 选择器则是 CSS 的核心机制之一&#xff0c;通过选择器&#xff0c;我们能够精准地指定要应用样式的 HTML 元素…

    GitHub与Gitee各是什么?它们的区别与联系是什么?

    李升伟 整理 GitHub 介绍 GitHub 是一个基于 Git 的代码托管平台&#xff0c;主要用于版本控制和协作开发。它支持多人协作&#xff0c;提供代码托管、问题跟踪、代码审查、项目管理等功能。GitHub 是全球最大的开源社区&#xff0c;许多知名开源项目都在此托管。 主要功能&…

    ESLint语法报错

    ESLint语法报错 运行报错 You may use special comments to disable some warnings. Use // eslint-disable-next-line to ignore the next line. Use /* eslint-disable */ to ignore all warnings in a file.解决方案 关闭eslint的语法检测&#xff0c;在eslintrc.js文件中…

    单例模式与线程安全

    目录 线程安全和重⼊问题 死锁和活锁 死锁 死锁四个必要条件 活锁 STL,智能指针和线程安全 线程安全的单例模式 饿汉模式 懒汉模式 懒汉模式实现单例模式(线程安全版本) 饿汉模式实现单例模式 我们来学习单例模式与线程安全 线程安全和重⼊问题 线程安全&#xff…

    Python+AI提示词用贝叶斯样条回归拟合BSF方法分析樱花花期数据模型构建迹图、森林图可视化

    原文链接&#xff1a;https://tecdat.cn/?p41308 在数据科学的领域中&#xff0c;我们常常会遇到需要处理复杂关系的数据。在众多的数据分析方法中&#xff0c;样条拟合是一种非常有效的处理数据非线性关系的手段。本专题合集围绕如何使用PyMC软件&#xff0c;对樱花花期数据进…

    WPF学习路线

    WPF学习路线 学习准备学习技术栈学习路线 1-5&#xff08;1-2周&#xff09;6-8&#xff08;3-5周&#xff09; 学习准备 个人认为前端技术一般几个关键字&#xff1a;元素资源 控制元素资源组合或者动态交互 数据交互呈现分析关键字得到的就是几个方向 布局 样式 组装资源控件…

    31天Python入门——第20天:魔法方法详解

    你好&#xff0c;我是安然无虞。 文章目录 魔法方法1. __new__和__del__2. __repr__和__len__3. __enter__和__exit__4. 可迭代对象和迭代器5. 中括号[]数据操作6. __getattr__、__setattr__ 和 __delattr__7. 可调用的8. 运算符 魔法方法 魔法方法: Python中的魔法方法是一类…

    栈 —— 数据结构基础刷题路程

    一、P1739 表达式括号匹配 - 洛谷 算法代码&#xff1a; #include<bits/stdc.h> using namespace std; const int N300008; struct mystack {int a[N];int t-1;//压栈void push(int data){a[t]data; } //取栈顶元素int top(){return a[t]; } //弹出栈顶元素void pop(){i…

    瑞昱RTD2556QR显示器驱动芯片

    一、概述 RTD2556QR芯片是由Realtek公司精心研发的一款高性能显示驱动芯片&#xff0c;专为满足现代显示设备对高分辨率、多功能接口及稳定性能的需求而设计。该芯片凭借其卓越的技术特性和广泛的应用领域&#xff0c;在显示驱动市场中占据重要地位。它集成了多种先进的功能模…

    PyQt5和OpenCV车牌识别系统

    有需要请加文章底部Q哦 可远程调试 PyQt5和OpenCV车牌识别系统 一 介绍 此车牌识别系统基于PyQt5和OpenCV开发&#xff0c;蓝牌&#xff0c;新能源(绿牌)&#xff0c;黄牌&#xff0c;白牌均可以准确识别&#xff0c;支持中文识别&#xff0c;可以导出识别结果(Excel格式)。此…

    学有所记- 探索FastAPI在docker上的部署

    目标&#xff1a; 学习怎样在docker中安装部署FastAPI&#xff0c;完成项目结构的搭建以及hello world的运行 背景&#xff1a; 公司内服务器资源有限&#xff0c;为了共享算力资源&#xff0c;同时又能隔离运行环境&#xff0c;因此采用了docker部署的方式&#xff0c;进行各…

    HTTP keepalive 详解

    一、简介 HTTP协议早期版本&#xff0c;比如1.0&#xff0c;默认是不使用持久连接的&#xff0c;也就是每个请求/响应之后都会关闭TCP连接。这样的话&#xff0c;每次请求都需要重新建立连接&#xff0c;增加了延迟和资源消耗。Keep-Alive的作用是保持连接&#xff0c;让多个请…