专门做离异相亲的网站/专业营销策划团队

专门做离异相亲的网站,专业营销策划团队,高端建站用什么软件,西安网站建设公司电话去做具体的事,然后稳稳托举自己 —— 25.3.17 数据文件: 通过网盘分享的文件:Ner命名实体识别任务 链接: https://pan.baidu.com/s/1fUiin2um4PCS5i91V9dJFA?pwdyc6u 提取码: yc6u --来自百度网盘超级会员v3的分享 一、配置文件 config.py …

去做具体的事,然后稳稳托举自己

                                        —— 25.3.17

数据文件:

通过网盘分享的文件:Ner命名实体识别任务
链接: https://pan.baidu.com/s/1fUiin2um4PCS5i91V9dJFA?pwd=yc6u 提取码: yc6u 
--来自百度网盘超级会员v3的分享

一、配置文件 config.py

1.模型与数据路径

model_path:模型训练完成后保存的位置。例如:保存最终的模型权重文件。

schema_path:数据结构定义文件,通常用于描述数据的格式(如字段名、标签类型)。
在NER任务中,可能定义实体类别(如 {"PERSON": "人名", "ORG": "组织"})。

train_data_path:训练数据集路径,通常为标注好的文本文件(如 train.txt 或 JSON 格式)。

valid_data_path: 验证数据集路径,用于模型训练时的性能评估和超参数调优。

vocab_path:​字符词汇表文件,记录模型中使用的字符集(如中文字符、字母、数字等)。


2.模型架构

max_length:输入文本的最大序列长度。超过此长度的文本会被截断或填充(如用 [PAD])。

hidden_size:模型隐藏层神经元的数量,影响模型容量和计算复杂度。

num_layers:模型的堆叠层数(如LSTM、Transformer的编码器/解码器层数)。

class_num:任务类别总数。例如:NER任务中可能有9种实体类型。

vocab_size:词表大小


3.训练配置

epoch:训练轮数。每轮遍历整个训练数据集一次。

batch_size:每次梯度更新所使用的样本数量。较小的批次可能更适合内存受限的环境。

optimizer:优化器类型,用于调整模型参数。Adam是常用优化器,结合动量梯度下降。

learning_rate:学习率,控制参数更新的步长。值过小可能导致训练缓慢,过大易过拟合。

use_crf:是否启用条件随机场(CRF)​层。在序列标注任务(如NER)中,CRF可捕捉标签间的依赖关系,提升准确性。


4.预训练模型

bert_path:预训练BERT模型的路径。BERT是一种强大的预训练语言模型,此处可能用于微调或特征提取。

# -*- coding: utf-8 -*-"""
配置参数信息
"""Config = {"model_path": "model_output","schema_path": "ner_data/schema.json","train_data_path": "ner_data/train","valid_data_path": "ner_data/test","vocab_path":"chars.txt","max_length": 100,"hidden_size": 256,"num_layers": 2,"epoch": 20,"batch_size": 16,"optimizer": "adam","learning_rate": 1e-3,"use_crf": False,"class_num": 9,"bert_path": r"F:\人工智能NLP/NLP资料\week6 语言模型/bert-base-chinese","vocab_size": 20000
}

二、数据加载 loader.py

1.初始化数据加载类

def __init__(self, data_path, config):构造函数接收数据路径和配置对象。

data_path:数据文件存储路径

config:包含训练 / 数据配置的字典

self.config:保存包含训练 / 数据配置的字典

self.path:保存数据文件存储路径

self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具

self.sentences:初始化句子列表

self.schema:加载实体标签与索引的映射关系表

self.load:调用 load() 方法从 data_path 加载原始数据,进行分词、编码、填充/截断等预处理。

    def __init__(self, data_path, config):self.config = configself.path = data_pathself.tokenizer = load_vocab(config["bert_path"])self.sentences = []self.schema = self.load_schema(config["schema_path"])self.load()

2.加载数据并预处理

① 初始化数据容器 ——>

② 文件读取与分段处理 ——>

③ 逐段解析字符与标签 ——>

④ 句子编码与填充 ——>

⑤ 数据封装与返回 

self.path:数据文件的存储路径(如 train.txt),由类初始化时传入的 data_path 参数赋值。

f:文件对象,用于读取 self.path 指向的原始数据文件。

segments:是按双换行符分隔的段落列表,每个段落对应一个样本(如一个句子及其标注序列)。

segment:遍历 segments 时的单个样本段落,进一步按行分割处理为字符和标签

labels:存储当前样本的标签序列,[8]可能表示 [CLS] 标记的 ID,用于序列起始符,之后将每个字符的标签转换为ID。

char:当前行的字符(如 "中"),属于句子中的一个基本单元。

lable:当前行的原始标签字符串(如 "B-LOC"),​尚未映射为 ID

input_ids:将字符序列编码为模型输入所需的 ID 序列(如 BERT 分词后的 Token ID)

self.data:列表,存储预处理后的数据样本,每个样本由输入张量和标签张量组成

sentence:由字符列表拼接而成的完整句子(如 "中国科技大学"),存入 self.sentences 供后续可视化或调试。

open():打开文件并返回文件对象,支持读/写/追加等模式。

参数名类型说明
file字符串文件路径(绝对/相对路径)
mode字符串打开模式(如 r-只读、w-写入、a-追加)
encoding字符串文件编码(如 utf-8,文本模式需指定)
errors字符串编码错误处理方式(如 ignorereplace

文件对象.read():读取文件内容,返回字符串或字节流

参数名类型说明
size整数可选,指定读取的字节数(默认读取全部内容)

split():按分隔符分割字符串,返回子字符串列表

参数名类型说明
delimiter字符串分隔符(默认空格)
maxsplit整数可选,最大分割次数(默认-1表示全部)

strip():去除字符串首尾指定字符(默认空白字符)

参数名类型说明
chars字符串可选,指定需去除的字符集合

join():用分隔符连接可迭代对象的元素,返回新字符串

参数名类型说明
iterable可迭代对象需连接的元素集合(如列表、元组)
sep字符串分隔符(默认空字符串)

列表.append():在列表末尾添加元素

参数名类型说明
obj任意类型要添加的元素
    def load(self):self.data = []with open(self.path, encoding="utf8") as f:segments = f.read().split("\n\n")for segment in segments:sentence = []labels = [8]  # cls_tokenfor line in segment.split("\n"):if line.strip() == "":continuechar, label = line.split()sentence.append(char)labels.append(self.schema[label])sentence = "".join(sentence)self.sentences.append(sentence)input_ids = self.encode_sentence(sentence)labels = self.padding(labels, -1)# print(self.decode(sentence, labels))# input()self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])return

3.加载字 / 词表

vocab_path:字 / 词表的存储路径

BertTokenizer.from_pretrained():Hugging Face Transformers 库中用于加载预训练 BERT 分词器的核心方法。它支持从 Hugging Face 模型库或本地路径加载预训练的分词器,并允许通过参数调整分词行为。

参数名类型默认值说明
pretrained_model_name_or_pathstr必填预训练模型名称(如 bert-base-uncased)或本地路径。若为名称,自动从 Hugging Face 下载
cache_dirstrNone模型缓存目录。若指定,下载的模型文件会存储在此路径下
force_downloadboolFalse是否强制重新下载模型,即使本地已缓存
resume_downloadboolFalse是否断点续传下载任务
do_lower_caseboolTrue(英文模型)是否将文本转为小写。​中文模型需注意:若设为 False,可能导致英文单词被识别为 [UNK]
add_special_tokensboolTrue是否在输入文本中添加 [CLS] 和 [SEP] 等特殊标记
tokenize_chinese_charsboolTrue是否对中文字符进行逐字分词(如将“你好”拆分为“你”和“好”)
strip_accentsboolNone是否去除重音符号(如将 é 转换为 e
use_fastboolTrue是否启用快速分词模式(基于 Rust 实现,速度更快)
def load_vocab(vocab_path):return BertTokenizer.from_pretrained(vocab_path)

4.加载映射关系表 

        加载位于指定路径的 JSON 格式的模式文件,并将其内容解析为 Python 对象以便在数据生成过程中使用。

path:映射关系表schema的存储路径

open():打开文件并返回文件对象,用于读写文件内容。

参数名类型默认值说明
file_namestr文件路径(需包含扩展名)
modestr'r'文件打开模式:
'r': 只读
'w': 只写(覆盖原文件)
'a': 追加写入
'b': 二进制模式
'x': 创建新文件(若存在则报错)
bufferingintNone缓冲区大小(仅二进制模式有效)
encodingstrNone文件编码(仅文本模式有效,如 'utf-8'
newlinestr'\n'行结束符(仅文本模式有效)
closefdboolTrue是否在文件关闭时自动关闭文件描述符
dir_fdint-1文件描述符(高级用法,通常忽略)
flagsint0Linux 系统下的额外标志位
modestr(重复参数,实际使用中只需指定 mode

json.load():从已打开的 JSON 文件对象中加载数据,并将其转换为 Python 对象(如字典、列表)。

参数名类型默认值说明
fpio.TextIO已打开的文件对象(需处于读取模式)
indentint/strNone缩进空格数(美化输出,如 4 或 " "
sort_keysboolFalse是否对 JSON 键进行排序
load_hookcallableNone自定义对象加载回调函数
object_hookcallableNone自定义对象解析回调函数
    def load_schema(self, path):with open(path, encoding="utf8") as f:return json.load(f)

5.封装数据

Ⅰ、初始化DataGenerator:初始化DataGenerator实例dg,传入data_path和config

Ⅱ、创建 DataLoader 对象:创建DataLoader实例dl,使用dg、batch_size和shuffle参数

Ⅲ、返回 DataLoader 迭代器:返回dl

data_path:数据文件的路径(如 train.txt),用于初始化 DataGenerator,指向原始数据文件。

config:配置参数字典,通常包含 batch_sizebert_pathschema_path 等参数,用于控制数据加载逻辑。

dg:自定义数据集对象,继承 torch.utils.data.Dataset,负责数据加载、预处理和样本生成。

dl:封装 DataGenerator 的迭代器,实现批量加载、多进程加速等功能,直接用于模型训练。

DataLoader():PyTorch 模型训练的标配工具,通过合理的参数配置(如 batch_sizenum_workersshuffle),可以显著提升数据加载效率,尤其适用于大规模数据集和复杂预处理任务。其与 Dataset 类的配合使用,是构建高效训练管道的核心。

参数名类型默认值说明
datasetDatasetNone必须参数,自定义数据集对象(需继承 torch.utils.data.Dataset)。
batch_sizeint1每个批次的样本数量。
shuffleboolFalse是否在每个 epoch 开始时打乱数据顺序(训练时推荐设为 True)。
num_workersint0使用多线程加载数据的工人数量(需大于 0 时生效)。
pin_memoryboolFalse是否将数据存储在 pinned memory 中(加速 GPU 数据传输)。
drop_lastboolFalse如果数据集长度无法被 batch_size 整除,是否丢弃最后一个不完整的批次。
persistent_workersboolFalse是否保持工作线程在 epoch 之间持续运行(减少多线程初始化开销)。
worker_init_fncallableNone自定义工作线程初始化函数。
# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dl

6.对于输入文本做截断 / 填充

Ⅰ、截断过长序列​(超过预设最大长度)

Ⅱ、填充过短序列​(用 pad_token 补齐到预设最大长度)

    #补齐或截断输入的序列,使其可以在一个batch内运算def padding(self, input_id, pad_token=0):input_id = input_id[:self.config["max_length"]]input_id += [pad_token] * (self.config["max_length"] - len(input_id))return input_id

7.类内魔术方法

self.data:表示数据集对象本身存储的数据容器

index:表示数据集中某个样本的索引值,用于定位并返回特定位置的样本。

__len__():用于定义对象的“长度”,通过内置函数 len() 调用时返回该值。它通常用于容器类(如列表、字典、自定义数据结构),表示容器中元素的个数

__getitem__():允许对象通过索引或键值访问元素,支持 obj[index] 或 obj[key] 语法。它使对象表现得像序列(如列表)或映射(如字典)

    def __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]

8.对于输入的文本编码

调用分词器编码(参数控制标准化)

self.tokenizer:将文本数据转换为深度学习模型(如 BERT)可处理的输入格式的核心工具

self.tokenizer.encode():Hugging Face Transformers 库中 BertTokenizer 的核心方法,用于将原始文本转换为模型可处理的输入形式。

参数名类型默认值说明
textstr 或 List[str]必填输入文本(单句或句子对)。
text_pairstrNone第二段文本(用于句子对任务,如问答),与 text 拼接后生成 [CLS] text [SEP] text_pair [SEP]
add_special_tokensboolTrue是否添加 [CLS] 和 [SEP] 标记。关闭后仅返回原始分词索引
max_lengthint512最大序列长度。超长文本会被截断,不足则填充
paddingstr 或 boolFalse填充策略:True/'longest'(按批次最长填充)、'max_length'(按 max_length 填充)
truncationstr 或 boolFalse截断策略:True(按 max_length 截断)、'only_first'(仅截断第一句)
return_tensorsstrNone

返回张量类型:

'pt'(PyTorch)、'tf'(TensorFlow)、'np'(NumPy)

return_attention_maskboolTrue是否生成 attention_mask,标识有效内容(1)与填充部分(0)
    def encode_sentence(self, text, padding=True):return self.tokenizer.encode(text,padding="max_length",max_length=self.config["max_length"],truncation=True)

9.对于编码后的输入文本作解码

(04+): 匹配以 0(B-LOCATION)开头,后接多个 4(I-LOCATION)的连续标签

(15+)(26+)(37+)分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。

sentence:输入的原句(添加 $ 后的版本),用于根据标签索引提取实体文本。

lables:模型输出的标签序列,转换为字符串后通过正则匹配定位实体位置。

results:存储提取的实体,键为实体类型(如 "LOCATION"),值为该类型实体的文本列表。

location:正则匹配结果,通过 span() 获取实体在 sentence 中的起止位置,用于提取具体文本片段。

join():将可迭代对象(列表、元组等)中的元素按指定分隔符连接成一个字符串。调用该方法的字符串作为分隔符。

参数名类型默认值说明
iterable可迭代对象必填需连接的元素集合,所有元素必须是字符串类型。若为空,返回空字符串。

str():将其他数据类型(整数、浮点数、布尔值等)转换为字符串类型。支持格式化输出和复杂对象的字符串表示。

参数名类型默认值说明
object任意类型必填需转换的对象,如整数、列表、字典等。
encoding字符串可选编码格式(仅对字节类型有效),如 utf-8
errors字符串可选编码错误处理策略,如 ignorereplace

defaultdict():创建字典的子类,为不存在的键自动生成默认值。需指定 default_factory(如 listint)定义默认值类型。

参数名类型默认值说明
default_factory可调用对象或无参数函数None用于生成默认值的函数。若未指定,访问不存在的键会抛出 KeyError
**kwargs关键字参数可选其他初始化字典的键值对,如 name="Alice"

re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match 对象

参数名类型说明
patternstr 或正则表达式对象要匹配的正则表达式模式
stringstr要搜索的字符串
flagsint (可选)正则匹配标志(如 re.IGNORECASE

.span():返回正则匹配的起始和结束索引(左闭右开区间)

列表.append():向列表末尾添加单个元素,直接修改原列表

参数名类型说明
element任意要添加的元素
    def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence) + 2]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()print("location", s, e)results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()print("org", s, e)results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()print("per", s, e)results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()print("time", s, e)results["TIME"].append(sentence[s:e])return results

完整代码

DataLoader():PyTorch 中用于高效加载和管理数据集的核心工具

参数名类型默认值说明
datasetDataset必填加载的数据集对象,需实现 __len__ 和 __getitem__ 方法
batch_sizeint1每个批次包含的样本数
shuffleboolFalse是否在每个训练周期(epoch)开始时打乱数据顺序。若 sampler 被指定,则忽略此参数。
samplerSamplerNone自定义数据采样策略(如随机采样 RandomSampler 或顺序采样 SequentialSampler
batch_samplerSamplerNone自定义批次采样策略(需与 batch_sizeshuffle 等参数互斥)
num_workersint0用于加载数据的子进程数。0 表示在主进程加载;大于 0 时启用多进程加速
collate_fnCallableNone合并多个样本为批次的函数(如填充序列长度)。默认将 NumPy 数组转为 Tensor
pin_memoryboolFalse若为 True,将数据复制到 CUDA 固定内存中,加速 GPU 数据传输
drop_lastboolFalse若为 True,丢弃最后一个不完整的批次(当数据集样本数无法被 batch_size 整除时)
timeoutfloat0等待从子进程收集批次的超时时间(秒)。0 表示无限等待
worker_init_fnCallableNone子进程初始化函数(如设置随机种子)
prefetch_factorint2每个子进程预加载的批次数量(需 num_workers > 0
persistent_workersboolFalse是否在训练周期结束后保留子进程(减少重复创建进程的开销)

.shape: ​NumPy 数组或 ​PyTorch 张量的属性,用于获取数据的维度信息。

input():Python 的内置函数,用于从标准输入(如键盘)读取用户输入的字符串。

参数名类型默认值说明
promptstr""可选提示信息,显示在输入前(如 input("请输入:")
返回值str-返回用户输入的字符串,需手动转换为其他类型(如 int(input())
# -*- coding: utf-8 -*-import json
import re
import os
import torch
import random
import jieba
import numpy as np
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from transformers import BertTokenizer"""
数据加载
"""class DataGenerator:def __init__(self, data_path, config):self.config = configself.path = data_pathself.tokenizer = load_vocab(config["bert_path"])self.sentences = []self.schema = self.load_schema(config["schema_path"])self.load()def load(self):self.data = []with open(self.path, encoding="utf8") as f:segments = f.read().split("\n\n")for segment in segments:sentenece = []labels = [8]  # cls_tokenfor line in segment.split("\n"):if line.strip() == "":continuechar, label = line.split()sentenece.append(char)labels.append(self.schema[label])sentence = "".join(sentenece)self.sentences.append(sentence)input_ids = self.encode_sentence(sentenece)labels = self.padding(labels, -1)# print(self.decode(sentence, labels))# input()self.data.append([torch.LongTensor(input_ids), torch.LongTensor(labels)])returndef encode_sentence(self, text, padding=True):return self.tokenizer.encode(text,padding="max_length",max_length=self.config["max_length"],truncation=True)def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence) + 2]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()print("location", s, e)results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()print("org", s, e)results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()print("per", s, e)results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()print("time", s, e)results["TIME"].append(sentence[s:e])return results# 补齐或截断输入的序列,使其可以在一个batch内运算def padding(self, input_id, pad_token=0):input_id = input_id[:self.config["max_length"]]input_id += [pad_token] * (self.config["max_length"] - len(input_id))return input_iddef __len__(self):return len(self.data)def __getitem__(self, index):return self.data[index]def load_schema(self, path):with open(path, encoding="utf8") as f:return json.load(f)def load_vocab(vocab_path):return BertTokenizer.from_pretrained(vocab_path)# 用torch自带的DataLoader类封装数据
def load_data(data_path, config, shuffle=True):dg = DataGenerator(data_path, config)dl = DataLoader(dg, batch_size=config["batch_size"], shuffle=shuffle)return dlif __name__ == "__main__":from config import Configdg = DataGenerator("ner_data/train", Config)dl = DataLoader(dg, batch_size=32)for x, y in dl:print(x.shape, y.shape)print(x[1], y[1])input()

三、模型建立 model.py

1.代码运行流程

输入 x → 嵌入层 → 双向LSTM → 全连接分类层 → 分支判断:│├── 有 target → CRF? → 是:计算 CRF 损失(通过维特比算法计算序列概率)│                 ││                 └→ 否:计算交叉熵损失(logits 展平后与标签计算交叉熵)│└── 无 target → CRF? → 是:解码最优标签序列(使用CRF的decode方法)│└→ 否:返回原始 logits(全连接层输出的未归一化分数)

2.模型初始化

代码运行流程

输入 x → BERT预训练模型 → 分类层 → 分支判断:│├── 有 target → CRF? → 是:计算 CRF 损失(通过转移矩阵计算序列联合概率)│                 ││                 └→ 否:计算交叉熵损失(logits 与标签的逐位置交叉熵)│└── 无 target → CRF? → 是:维特比解码最优路径(考虑标签转移约束)│└→ 否:返回原始 logits(全连接层输出的未归一化分数)

hidden_size:定义LSTM隐藏层的维度(即每个时间步输出的特征数量

vocab_size:词表大小,即嵌入层(Embedding)可处理的词汇总数

max_length:输入序列的最大长度,用于数据预处理(如截断或填充)

class_num:分类任务的类别数量,决定线性层(nn.Linear)的输出维度

num_layers:堆叠的LSTM层数,用于增加模型复杂度

BertModel.from_pretrained():加载预训练的 BERT 模型,支持从本地或 Hugging Face 模型库加载

参数名类型默认值说明
pretrained_model_name字符串预训练模型名称或路径(如 bert-base-chinese
config字典/类默认配置自定义模型配置,覆盖默认参数(如隐藏层维度、注意力头数)
cache_dir字符串None模型缓存目录
output_hidden_states布尔值False是否返回所有隐藏层输出(用于特征提取)

nn.Linear():实现全连接层的线性变换(y = xW^T + b

参数名类型默认值说明
in_features整数输入特征维度(如词向量维度 hidden_size
out_features整数输出特征维度(如分类类别数 class_num
bias布尔值True是否启用偏置项

CRF():条件随机场层,用于序列标注任务中约束标签转移逻辑。

参数名类型默认值说明
num_tags整数标签类别数(如 class_num
batch_first布尔值False输入张量是否为 (batch_size, seq_len) 格式

torch.nn.CrossEntropyLoss():计算交叉熵损失,常用于分类任务

参数名类型默认值说明
ignore_index整数-1忽略指定索引的标签(如填充符 -1
reduction字符串mean损失聚合方式(可选 nonesummean
    def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1max_length = config["max_length"]class_num = config["class_num"]num_layers = config["num_layers"]# self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)# self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)self.classify = nn.Linear(hidden_size * 2, class_num)self.crf_layer = CRF(class_num, batch_first=True)self.use_crf = config["use_crf"]self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失

3.前向计算 

代码运行流程

输入 x → 嵌入层 → LSTM层 → 分类层 → 分支判断:│├── 有 target → CRF? → 是:计算 CRF 损失│                 ││                 └→ 否:计算交叉熵损失│└── 无 target → CRF? → 是:解码最优标签序列│└→ 否:返回预测 logits

x:输入序列的 Token ID 矩阵,代表一个批次的文本数据(如 [[101, 234, ...], [103, 456, ...]])。

target:真实标签序列(如实体标注),若不为 None 表示训练阶段,需计算损失;否则为预测阶段。

predict:分类层输出的每个位置标签的未归一化分数(logits),用于后续的 CRF 或交叉熵损失计算。

mask:标记序列中有效 Token 的位置(非填充部分),target.gt(-1) 表示标签值大于 -1 的位置有效。

gt():张量的逐元素比较函数,返回布尔型张量,标记输入张量中大于指定值的元素位置。常用于生成掩码(如忽略填充符)

参数名类型默认值说明
otherTensor/标量比较的阈值或张量。若为标量,则张量中每个元素与该值比较;若为张量,需与输入张量形状相同。
outTensorNone可选输出张量,用于存储结果。

shape():返回张量的维度信息,描述各轴的大小。

view():调整张量的形状,支持自动推断维度(通过-1占位符)。常用于数据展平或维度转换。

参数名类型默认值说明
*shape可变参数目标形状的维度序列,如view(2, 3)view(-1, 28)-1表示自动计算。
    #当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, target=None):x = self.embedding(x)  #input shape:(batch_size, sen_len)x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)if target is not None:if self.use_crf:mask = target.gt(-1)# loss 是 crf 的相反数,即 - crf(predict, target, mask)return - self.crf_layer(predict, target, mask, reduction="mean")else:#(number, class_num), (number)return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))else:if self.use_crf:return self.crf_layer.decode(predict)else:return predict

4.选择优化器 

代码运行流程

输入 config → 提取参数 → 分支判断:│├── optimizer == "adam" → 返回 Adam 优化器实例│└── optimizer == "sgd" → 返回 SGD 优化器实例

config:这个参数应该是一个字典,里面存储了配置信息。

model:这是传入的模型对象,通常是一个神经网络模型。优化器需要模型的参数来更新权重

optimizer:从config中获取的字符串,决定使用哪种优化器。比如"adam"对应Adam优化器,"sgd"对应随机梯度下降。

learning_rate:学习率,是优化器的一个重要超参数,控制权重更新的步长

Adam():自适应矩估计优化器(Adaptive Moment Estimation),结合动量和 RMSProp 的优点。

参数名类型默认值说明
lrfloat1e-3学习率。
betastuple(0.9, 0.999)动量系数(β₁, β₂)。
epsfloat1e-8防止除零误差。
weight_decayfloat0权重衰减率。
amsgradboolFalse是否启用 AMSGrad 优化。
foreachboolFalse是否为每个参数单独计算梯度。

SGD():随机梯度下降优化器(Stochastic Gradient Descent)

参数名类型默认值说明
lrfloat1e-3学习率。
momentumfloat0动量系数(如 momentum=0.9)。
weight_decayfloat0权重衰减率。
dampeningfloat0动力衰减系数(用于 SGD with Momentum)。
nesterovboolFalse是否启用 Nesterov 动量。
foreachboolFalse是否为每个参数单独计算梯度。

parameters():返回模型所有可训练参数的迭代器,常用于参数初始化或梯度清零。

参数名类型默认值说明
filtercallableNone过滤条件函数(如 lambda p: p.requires_grad)。默认返回所有参数。
def choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)

5.模型建立

# -*- coding: utf-8 -*-import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torchcrf import CRF
import torch
from transformers import BertModel"""
建立网络模型结构
"""class TorchModel(nn.Module):def __init__(self, config):super(TorchModel, self).__init__()hidden_size = config["hidden_size"]vocab_size = config["vocab_size"] + 1max_length = config["max_length"]class_num = config["class_num"]num_layers = config["num_layers"]self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)# self.layer = BertModel.from_pretrained(config["bert_path"], hidden_size=hidden_size, num_layers=num_layers)self.layer = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True, num_layers=num_layers)self.classify = nn.Linear(hidden_size * 2, class_num)self.crf_layer = CRF(class_num, batch_first=True)self.use_crf = config["use_crf"]self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1)  #loss采用交叉熵损失#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, target=None):x = self.embedding(x)  #input shape:(batch_size, sen_len)x, _ = self.layer(x)      #input shape:(batch_size, sen_len, input_dim)predict = self.classify(x) #ouput:(batch_size, sen_len, num_tags) -> (batch_size * sen_len, num_tags)if target is not None:if self.use_crf:mask = target.gt(-1)# loss 是 crf 的相反数,即 - crf(predict, target, mask)return - self.crf_layer(predict, target, mask, reduction="mean")else:#(number, class_num), (number)return self.loss(predict.view(-1, predict.shape[-1]), target.view(-1))else:if self.use_crf:return self.crf_layer.decode(predict)else:return predictdef choose_optimizer(config, model):optimizer = config["optimizer"]learning_rate = config["learning_rate"]if optimizer == "adam":return Adam(model.parameters(), lr=learning_rate)elif optimizer == "sgd":return SGD(model.parameters(), lr=learning_rate)if __name__ == "__main__":from config import Configmodel = TorchModel(Config)

四、模型效果测试 evaluate.py

1.代码运行流程

输入验证集 → 数据加载 → 模型预测 → 分支判断:│├── 启用CRF → 直接解码标签序列 → 实体提取│└── 禁用CRF → argmax获取预测标签 → 实体提取→ 统计指标计算 → 分支判断:│├── 按实体类别统计 → 计算precision/recall/F1(LOCATION/TIME/PERSON/ORGANIZATION)│└── 全局统计 → 计算micro-F1 → 输出综合评估结果

2.初始化

Ⅰ、加载配置文件、模型及日志模块 ——>

Ⅱ、读取验证集数据(固定顺序,避免随机性干扰评估)——>

Ⅲ、初始化统计字典 stats_dict,按实体类别记录正确识别数、样本实体数等

config:存储运行时配置,例如数据路径、超参数(如批次大小 batch_size)、是否使用CRF层等。通过 config["valid_data_path"] 动态获取验证集路径。

model:待评估的模型实例,用于调用预测方法(如 model(input_id)),需提前完成训练和加载。

logger:记录运行日志,例如输出评估指标(准确率、F1值)到文件或控制台,便于调试和监控。

valid_data:验证数据集,用于模型训练时的性能评估和超参数调优。

load_data():数据加载类中,用torch自带的DataLoader类封装数据的函数

    def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)

 3.统计模型效果

Ⅰ、​输入验证与初始化

        通过 assert 确保输入的三组数据长度一致(labelspred_resultssentences)。

        若模型未使用 CRF 层(use_crf=False),将预测结果通过 torch.argmax 转换为标签索引序列


Ⅱ、逐样本处理

        遍历每个样本的真实标签、预测标签及原始句子。

        若未使用 CRF,将预测标签从 GPU Tensor 转换为 CPU List(避免内存泄漏)。

        调用 decode() 方法解码标签序列,得到真实实体字典 true_entities 和预测实体字典 pred_entities


Ⅲ、实体统计

对每个实体类别(如 PERSONLOCATION):

        正确识别数:遍历预测实体列表,统计与真实实体完全匹配的数量(ent in true_entities[key])。

        ​样本实体数:统计真实实体列表的长度。

        ​识别出实体数:统计预测实体列表的长度。


Ⅳ、输出统计结果

        最终统计结果存储在 self.stats_dict 中,后续可通过该字典计算准确率(正确识别数 / 识别出实体数)和召回率(正确识别数 / 样本实体数

labels:真实标签序列(如实体标注的整数 ID 列表),用于与预测结果对比计算评估指标

pred_results:模型预测结果,若使用 CRF,为标签序列,否则为每个位置的 logits(未归一化概率)。

sentences:原始文本句子列表(如 ["中国北京", "今天天气"]),用于解码标签序列到具体实体。

use_crf:控制是否使用 CRF 层

pred_label:单个样本的预测标签序列,若未使用 CRF,需从 logits 中提取(argmax)并转换为列表。

true_label:单个样本的真实标签序列(如 [0, 4, 4, 8]),已从 GPU 张量转换为 CPU 列表。

true_entities:解码后的真实实体字典,如 {"LOCATION": ["北京"], "PERSON": []}

pred_entities:解码后的预测实体字典,用于与真实实体对比统计正确识别数。

key:字符串,实体类别名称(如 "PERSON"),遍历四类实体以分别统计指标。

assert:Python 的 ​调试断言工具,主要用于在开发阶段验证程序内部的逻辑条件是否成立

        assert expression [, message]  

参数类型是否必填作用
expression布尔表达式需要验证的条件。若结果为 False,则触发断言失败;若为 True,程序继续执行。
message字符串(可选)断言失败时输出的自定义错误信息,用于辅助调试。若省略,则输出默认错误提示。

len():返回对象的元素数量(字符串、列表、元组、字典等)

参数名类型说明
object任意可迭代对象如字符串、列表、字典等

torch.argmax():返回张量中最大值所在的索引

参数名类型说明
inputTensor输入张量
dimint沿指定维度查找最大值
keepdimbool是否保持输出维度一致

cpu():将张量从GPU移动到CPU内存

zip():将多个可迭代对象打包成元组列表

参数名类型说明
iterables多个可迭代对象如列表、元组、字符串

.detach():从计算图中分离张量,阻止梯度传播

.tolist():将张量或数组转换为Python列表

    def write_stats(self, labels, pred_results, sentences):assert len(labels) == len(pred_results) == len(sentences)if not self.config["use_crf"]:pred_results = torch.argmax(pred_results, dim=-1)for true_label, pred_label, sentence in zip(labels, pred_results, sentences):if not self.config["use_crf"]:pred_label = pred_label.cpu().detach().tolist()true_label = true_label.cpu().detach().tolist()true_entities = self.decode(sentence, true_label)pred_entities = self.decode(sentence, pred_label)# 正确率 = 识别出的正确实体数 / 识别出的实体数# 召回率 = 识别出的正确实体数 / 样本的实体数for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])self.stats_dict[key]["样本实体数"] += len(true_entities[key])self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])return

4.可视化统计模型效果

精确率 (Precision):正确预测实体数 / 总预测实体数

召回率 (Recall):正确预测实体数 / 总真实实体数​

F1值:精确率与召回率的调和平均

F1:F1分数:准确率与召回率的调和平均数,综合衡量模型的精确性与覆盖能力。

F1_scores:存储四个实体类别的 F1 分数,用于计算宏观平均。

precision:准确率:模型预测为某类实体的结果中,正确的比例。反映模型预测的精确度。

recall:召回率:真实存在的某类实体中,被模型正确识别的比例。反映模型对实体的覆盖能力。

key:当前处理的实体类别(如 "PERSON""LOCATION")。

correct_pred:总正确识别数:所有类别中被正确识别的实体总数。

total_pred:总识别实体数:模型预测出的所有实体数量(含错误识别)。

true_enti:总样本实体数:验证数据中真实存在的所有实体数量。

micro_precision:微观准确率:全局视角下的准确率,所有实体类别的正确识别数与总识别数的比例。

micro_recall:微观召回率:全局视角下的召回率,所有实体类别的正确识别数与总样本实体数的比例。

micro_f1:微观F1分数:微观准确率与微观召回率的调和平均数。

列表.append():在列表末尾添加元素

参数名类型说明
element任意要添加的元素

logger.info():记录日志信息(需配置日志模块)

参数名类型说明
formatstr格式化字符串
*args可变参数格式化参数

sum():计算可迭代对象的元素总和

参数名类型说明
iterable可迭代对象如列表、元组
start数值(可选)初始累加值

列表推导式:通过简洁语法生成新列表,语法:[表达式 for item in iterable if 条件]

    def show_stats(self):F1_scores = []for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:# 正确率 = 识别出的正确实体数 / 识别出的实体数# 召回率 = 识别出的正确实体数 / 样本的实体数precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])F1 = (2 * precision * recall) / (precision + recall + 1e-5)F1_scores.append(F1)self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))self.logger.info("Macro-F1: %f" % np.mean(F1_scores))correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])micro_precision = correct_pred / (total_pred + 1e-5)micro_recall = correct_pred / (true_enti + 1e-5)micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)self.logger.info("Micro-F1 %f" % micro_f1)self.logger.info("--------------------")return

5.评估模型效果

模型切换为评估模式:关闭Dropout等训练层​

批次处理数据

         提取原始句子 sentences

   将数据迁移至GPU(若可用)

         预测时禁用梯度计算(torch.no_grad())优化内存

统计结果:调用 write_stats 对比预测与真实标签

epoch:当前训练轮次,用于日志。

logger:记录日志的工具。

stats_dict:统计字典,记录各实体类别的指标。

valid_data:验证数据集,通常由 load_data 加载(如 config["valid_data_path"] 指定路径)

index: 循环中的批次索引

batch_data: 循环中的数据。

sentences:当前批次的原始句子

pred_results:模型预测结果

write_stats():写入统计信息

show_stats():显示统计结果

logger.info():记录日志信息(需配置日志模块)

参数名类型说明
formatstr格式化字符串
*args可变参数格式化参数

defaultdict():创建带有默认值工厂的字典

参数名类型说明
default_factory可调用对象如int、list、自定义函数

model.eval():将模型设置为评估模式(关闭Dropout等训练层)

enumerate():返回索引和元素组成的枚举对象

参数名类型说明
iterable可迭代对象如列表、字符串
startint(可选)起始索引,默认为0

torch.cuda.is_available():检查当前环境是否支持CUDA(GPU加速)

cuda():将张量或模型移动到GPU

参数名类型说明
deviceint/str指定GPU设备号,如"cuda:0"

torch.no_grad():禁用梯度计算,节省内存并加速推理

    def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)self.stats_dict = {"LOCATION": defaultdict(int),"TIME": defaultdict(int),"PERSON": defaultdict(int),"ORGANIZATION": defaultdict(int)}self.model.eval()for index, batch_data in enumerate(self.valid_data):sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况with torch.no_grad():pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测self.write_stats(labels, pred_results, sentences)self.show_stats()return

6.解码

​根据代码中,Schema文件映射的定义对标签序列预处理:将数值标签拼接为字符串(如 [0,4,4] → "044"

正则匹配实体

   04+B-LOCATION(0)后接多个I-LOCATION(4)

   15+B-ORGANIZATION(1)后接I-ORGANIZATION(5)

           其他实体类别同理

索引对齐:根据匹配位置截取原始句子中的实体文本

Ⅰ、输入预处理

在原句首添加 $ 符号,通常用于对齐标签与字符位置(例如避免索引越界)

        sentence = "$" + sentence

Ⅱ、标签序列转换

将整数标签序列转换为字符串,并截取长度与 sentence 对齐

str.join():将可迭代对象中的字符串元素按指定分隔符连接成一个新字符串

参数名类型说明
iterable可迭代对象元素必须为字符串类型

str():将对象转换为字符串表示形式,支持自定义类的 __str__ 方法

参数名类型说明
object任意要转换的对象

len():返回对象的长度或元素个数(适用于字符串、列表、字典等)

参数名类型说明
object可迭代对象如字符串、列表等

列表推导式:通过简洁语法生成新列表,支持条件过滤和多层循环

        [expression for item in iterable if condition]

部分类型说明
expression表达式对 item 处理后的结果
item变量迭代变量
iterable可迭代对象如列表、range() 生成的序列
condition条件表达式 (可选)过滤不符合条件的元素
        labels = "".join([str(x) for x in labels[:len(sentence)+1]])

Ⅲ、初始化结果容器

创建默认值为列表的字典,存储四类实体:

        (LOCATION、ORGANIZATION、PERSON、TIME)的识别结果

defaultdict():创建默认值字典,当键不存在时自动生成默认值(基于工厂函数)

参数名类型说明
default_factory可调用对象如 intlist 或自定义函数
        results = defaultdict(list)

Ⅳ、正则表达式匹配

    (04+): 匹配以 0(B-LOCATION)开头,后接多个 4(I-LOCATION)的连续标签

    (15+)(26+)(37+)分别对应 ORGANIZATION(B=1, I=5)、PERSON(B=2, I=6)、TIME(B=3, I=7)的标签模式。

re.finditer():在字符串中全局搜索正则表达式匹配项,返回一个迭代器,每个元素为 Match 对象

参数名类型说明
patternstr 或正则表达式对象要匹配的正则表达式模式
stringstr要搜索的字符串
flagsint (可选)正则匹配标志(如 re.IGNORECASE

.span():返回正则匹配的起始和结束索引(左闭右开区间)

列表.append():向列表末尾添加单个元素,直接修改原列表

参数名类型说明
element任意要添加的元素
        for location in re.finditer("(04+)", labels):s, e = location.span()results["LOCATION"].append(sentence[s:e])

Ⅴ、完整代码 

    '''Schema文件{"B-LOCATION": 0,"B-ORGANIZATION": 1,"B-PERSON": 2,"B-TIME": 3,"I-LOCATION": 4,"I-ORGANIZATION": 5,"I-PERSON": 6,"I-TIME": 7,"O": 8}'''def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence)+1]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()results["TIME"].append(sentence[s:e])return results

7.完整代码 

# -*- coding: utf-8 -*-
import torch
import re
import numpy as np
from collections import defaultdict
from loader import load_data"""
模型效果测试
"""class Evaluator:def __init__(self, config, model, logger):self.config = configself.model = modelself.logger = loggerself.valid_data = load_data(config["valid_data_path"], config, shuffle=False)def eval(self, epoch):self.logger.info("开始测试第%d轮模型效果:" % epoch)self.stats_dict = {"LOCATION": defaultdict(int),"TIME": defaultdict(int),"PERSON": defaultdict(int),"ORGANIZATION": defaultdict(int)}self.model.eval()for index, batch_data in enumerate(self.valid_data):sentences = self.valid_data.dataset.sentences[index * self.config["batch_size"]: (index+1) * self.config["batch_size"]]if torch.cuda.is_available():batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况with torch.no_grad():pred_results = self.model(input_id) #不输入labels,使用模型当前参数进行预测self.write_stats(labels, pred_results, sentences)self.show_stats()returndef write_stats(self, labels, pred_results, sentences):assert len(labels) == len(pred_results) == len(sentences)if not self.config["use_crf"]:pred_results = torch.argmax(pred_results, dim=-1)for true_label, pred_label, sentence in zip(labels, pred_results, sentences):if not self.config["use_crf"]:pred_label = pred_label.cpu().detach().tolist()true_label = true_label.cpu().detach().tolist()true_entities = self.decode(sentence, true_label)pred_entities = self.decode(sentence, pred_label)# 正确率 = 识别出的正确实体数 / 识别出的实体数# 召回率 = 识别出的正确实体数 / 样本的实体数for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:self.stats_dict[key]["正确识别"] += len([ent for ent in pred_entities[key] if ent in true_entities[key]])self.stats_dict[key]["样本实体数"] += len(true_entities[key])self.stats_dict[key]["识别出实体数"] += len(pred_entities[key])returndef show_stats(self):F1_scores = []for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]:# 正确率 = 识别出的正确实体数 / 识别出的实体数# 召回率 = 识别出的正确实体数 / 样本的实体数precision = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["识别出实体数"])recall = self.stats_dict[key]["正确识别"] / (1e-5 + self.stats_dict[key]["样本实体数"])F1 = (2 * precision * recall) / (precision + recall + 1e-5)F1_scores.append(F1)self.logger.info("%s类实体,准确率:%f, 召回率: %f, F1: %f" % (key, precision, recall, F1))self.logger.info("Macro-F1: %f" % np.mean(F1_scores))correct_pred = sum([self.stats_dict[key]["正确识别"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])total_pred = sum([self.stats_dict[key]["识别出实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])true_enti = sum([self.stats_dict[key]["样本实体数"] for key in ["PERSON", "LOCATION", "TIME", "ORGANIZATION"]])micro_precision = correct_pred / (total_pred + 1e-5)micro_recall = correct_pred / (true_enti + 1e-5)micro_f1 = (2 * micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-5)self.logger.info("Micro-F1 %f" % micro_f1)self.logger.info("--------------------")return'''{"B-LOCATION": 0,"B-ORGANIZATION": 1,"B-PERSON": 2,"B-TIME": 3,"I-LOCATION": 4,"I-ORGANIZATION": 5,"I-PERSON": 6,"I-TIME": 7,"O": 8}'''def decode(self, sentence, labels):sentence = "$" + sentencelabels = "".join([str(x) for x in labels[:len(sentence)+1]])results = defaultdict(list)for location in re.finditer("(04+)", labels):s, e = location.span()results["LOCATION"].append(sentence[s:e])for location in re.finditer("(15+)", labels):s, e = location.span()results["ORGANIZATION"].append(sentence[s:e])for location in re.finditer("(26+)", labels):s, e = location.span()results["PERSON"].append(sentence[s:e])for location in re.finditer("(37+)", labels):s, e = location.span()results["TIME"].append(sentence[s:e])return results

五、主函数文件 main.py

1.代码运行流程

配置参数 → 创建模型目录 → 加载训练数据 → 初始化模型 → 设备检测:│├── GPU可用 → 迁移模型至GPU│└── GPU不可用 → 保持CPU模式→ 选择优化器 → 初始化评估器 → 进入训练循环:│├── 当前epoch → 训练模式 → 遍历数据批次:│                 ││                 ├── 清空梯度 → 数据迁移至GPU → 前向计算 → 分支判断:│                 │             ││                 │             ├── 启用CRF → 计算CRF损失 → 反向传播 → 参数更新│                 │             ││                 │             └── 禁用CRF → 计算交叉熵损失 → 反向传播 → 参数更新│                 ││                 └── 记录批次损失 → 周期中点打印日志│└── 计算epoch平均损失 → 验证集评估 → 保存当前模型权重

2.导入文件

# -*- coding: utf-8 -*-import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_data

3.日志配置

logging.basicConfig():配置日志系统的基础参数(一次性设置,应在首次日志调用前调用)

参数名类型是否必需默认值说明
filename字符串None日志输出文件名(若指定,日志写入文件而非控制台)
filemode字符串'a'文件打开模式(如'w'覆盖,'a'追加)
format字符串基础格式日志格式模板(如'%(asctime)s - %(levelname)s - %(message)s'
datefmt字符串时间格式(如'%Y-%m-%d %H:%M:%S'
level整数WARNING日志级别(如logging.INFOlogging.DEBUG
stream对象None指定日志输出流(如sys.stderr,与filename互斥)

logging.getLogger():获取或创建指定名称的日志记录器(Logger)。若nameNone,返回根日志记录器

参数名类型是否必需默认值说明
name字符串None日志记录器名称(分层结构,如'module.sub'
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

4.主函数 main  

Ⅰ、创建模型保存目录

os.path.isdir():检查指定路径是否为目录(文件夹)

参数名类型是否必需默认值说明
path字符串要检查的路径(绝对或相对)

os.mkdir():创建单个目录(若父目录不存在会抛出异常)

参数名类型是否必需默认值说明
path字符串要创建的目录路径
mode整数0o777目录权限(八进制格式,某些系统可能忽略此参数)
    #创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])

Ⅱ、加载训练数据

    #加载训练数据train_data = load_data(config["train_data_path"], config)

Ⅲ、加载模型

    #加载模型model = TorchModel(config)

Ⅳ、检查GPU并迁移模型

torch.cuda.is_available():检查系统是否满足 CUDA 环境要求

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)

cuda():将张量或模型移动到GPU显存,加速计算

参数类型必须说明示例
deviceint/str指定GPU设备(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否异步传输数据(默认False)tensor.cuda(non_blocking=True)
    # 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()

Ⅴ、加载优化器

    #加载优化器optimizer = choose_optimizer(config, model)

Ⅵ、加载评估器

    #加载效果测试类evaluator = Evaluator(config, model, logger)

Ⅶ、模型训练 ⭐

① Epoch循环控制

range():Python 内置函数,用于生成一个不可变的整数序列,​核心功能是为循环控制提供高效的数值迭代支持

参数名类型默认值说明
start整数0序列起始值(包含)。若省略,则默认从 0 开始。例如 range(3) 等价于 range(0,3)
stop整数必填序列结束值(不包含)。例如 range(2, 5) 生成 2,3,4
step整数1步长(正/负):
- ​正步长需满足 start < stop,否则无输出(如 range(5, 2) 无效)。
- ​负步长需满足 start > stop,例如 range(5, 0, -1) 生成 5,4,3,2,1
​**不能为 0**​(否则触发 ValueError
for epoch in range(config["epoch"]):epoch += 1
② 模型设置训练模式 

train_loss:计算当前批次的损失值,通常结合损失函数(如交叉熵、均方误差)使用

model.train():设置模型为训练模式,启用Dropout、BatchNorm等层的训练行为

参数类型默认值说明示例
modeboolTrue是否启用训练模式(True)或评估模式(False)model.train(True)

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
        model.train()logger.info("epoch %d begin" % epoch)train_loss = []

③ Batch数据遍历

enumerate():遍历可迭代对象时返回索引和元素,支持自定义起始索引

参数类型必须说明示例
iterableIterable可迭代对象(如列表、生成器)enumerate(["a", "b"])
startint索引起始值(默认0)enumerate(data, start=1)
        for index, batch_data in enumerate(train_data):

④ 梯度清零与设备切换

optimizer.zero_grad():清空模型参数的梯度,防止梯度累积

参数类型必须说明示例
set_to_nonebool是否将梯度置为None(高效但危险)optimizer.zero_grad(True)

cuda():将张量或模型移动到GPU显存,加速计算

参数类型必须说明示例
deviceint/str指定GPU设备(如0"cuda:0"tensor.cuda(device=0)
non_blockingbool是否异步传输数据(默认False)tensor.cuda(non_blocking=True)
            optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]

⑤ 前向传播与损失计算
            input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况loss = model(input_id, labels)

⑥ 反向传播与参数更新

loss.backward():反向传播计算梯度,基于损失值更新模型参数的.grad属性

参数类型必须说明示例
retain_graphbool是否保留计算图(用于多次反向传播)loss.backward(retain_graph=True)

optimizer.step():根据梯度更新模型参数,执行优化算法(如SGD、Adam)

参数类型必须说明示例
closureCallable重新计算损失的闭包函数(如LBFGS)optimizer.step(closure)
            loss.backward()optimizer.step()

⑦ 损失记录与日志输出

列表.append():在列表末尾添加元素,直接修改原列表

参数类型必须说明示例
objectAny要添加到列表末尾的元素train_loss.append(loss.item())

int():将字符串或浮点数转换为整数,支持进制转换

参数类型必须说明示例
xstr/float待转换的值(如字符串或浮点数)int("10", base=2)(输出2进制10=2)
baseint进制(默认10)

len():返回对象(如列表、字符串)的长度或元素个数

参数类型必须说明示例
objSequence/Collection可计算长度的对象(如列表、字符串)len([1, 2, 3])(返回3)

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
            train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)

⑧ Epoch评估与日志

item():从张量中提取标量值(仅当张量包含单个元素时可用)

列表.append():Python 列表(list)的内置方法,用于向列表的 ​末尾 添加一个元素。

参数名类型默认值说明
element任意类型要添加到列表末尾的元素。可以是单个值(如 42)、对象(如 [1, 2, 3])等。

logger.info():记录日志信息,输出训练过程中的关键状态

参数类型必须说明示例
msgstr日志消息(支持格式化字符串)logger.info("Epoch: %d", epoch)
*argsAny格式化参数(用于%占位符)
            train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)

⑨ 完整训练代码
    #训练for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况loss = model(input_id, labels)loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)

Ⅷ、模型保存

os.path.join():Python 中用于拼接路径的核心函数,其核心价值在于自动处理不同操作系统的路径分隔符,从而保证代码的跨平台兼容性

参数类型必填说明
path1字符串初始路径组件
*paths可变参数后续路径组件(可传多个)

torch.save():  PyTorch 中用于序列化保存模型、张量或字典等对象的核心函数,支持将数据持久化存储为 .pth 或 .pt 文件,便于后续加载和复用

参数名类型默认值说明
obj任意 PyTorch 对象必填待保存的对象,如模型、张量或字典。
fstr 或文件对象必填保存路径(如 'model.pth')或已打开的文件对象(需二进制写入模式 'wb'
pickle_protocolint2指定 pickle 协议版本(通常无需修改,高版本可能提升效率但需兼容性验证)
_use_new_zipfile_serializationboolTrue启用新版序列化格式(压缩率更高,推荐保持默认)
    model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)# torch.save(model.state_dict(), model_path)return model, train_data

5.调用模型预测

# -*- coding: utf-8 -*-import torch
import os
import random
import numpy as np
import logging
from config import Config
from model import TorchModel, choose_optimizer
from evaluate import Evaluator
from loader import load_datalogging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)"""
模型训练主程序
"""def main(config):#创建保存模型的目录if not os.path.isdir(config["model_path"]):os.mkdir(config["model_path"])#加载训练数据train_data = load_data(config["train_data_path"], config)#加载模型model = TorchModel(config)# 标识是否使用gpucuda_flag = torch.cuda.is_available()if cuda_flag:logger.info("gpu可以使用,迁移模型至gpu")model = model.cuda()#加载优化器optimizer = choose_optimizer(config, model)#加载效果测试类evaluator = Evaluator(config, model, logger)#训练for epoch in range(config["epoch"]):epoch += 1model.train()logger.info("epoch %d begin" % epoch)train_loss = []for index, batch_data in enumerate(train_data):optimizer.zero_grad()if cuda_flag:batch_data = [d.cuda() for d in batch_data]input_id, labels = batch_data   #输入变化时这里需要修改,比如多输入,多输出的情况loss = model(input_id, labels)loss.backward()optimizer.step()train_loss.append(loss.item())if index % int(len(train_data) / 2) == 0:logger.info("batch loss %f" % loss)logger.info("epoch average loss: %f" % np.mean(train_loss))evaluator.eval(epoch)# 保存模型model_path = os.path.join(config["model_path"], "epoch_%d.pth" % epoch)torch.save(model.state_dict(), model_path)return model, train_dataif __name__ == "__main__":model, train_data = main(Config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/72534.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 蓝牙音频软件栈实现分析

Linux 蓝牙音频软件栈实现分析 蓝牙协议栈简介蓝牙控制器探测BlueZ 插件系统及音频插件蓝牙协议栈简介 蓝牙协议栈是实现蓝牙通信功能的软件架构,它由多个层次组成,每一层负责特定的功能。蓝牙协议栈的设计遵循蓝牙标准 (由蓝牙技术联盟,Bluetooth SIG 定义),支持多种蓝牙…

JetBrains(全家桶: IDEA、WebStorm、GoLand、PyCharm) 2024.3+ 2025 版免费体验方案

JetBrains&#xff08;全家桶: IDEA、WebStorm、GoLand、PyCharm&#xff09; 2024.3 2025 版免费体验方案 前言 JetBrains IDE 是许多开发者的主力工具&#xff0c;但从 2024.02 版本起&#xff0c;JetBrains 调整了试用政策&#xff0c;新用户不再享有默认的 30 天免费试用…

Prosys OPC UA Gateway:实现 OPC Classic 与 OPC UA 无缝连接

在工业自动化的数字化转型中&#xff0c;设备与系统之间的高效通信至关重要。然而&#xff0c;许多企业仍依赖于基于 COM/DCOM 技术的 OPC 产品&#xff0c;这给与现代化的 OPC UA 架构的集成带来了挑战。 Prosys OPC UA Gateway 正是为解决这一问题而生&#xff0c;它作为一款…

基于银河麒麟系统ARM架构安装达梦数据库并配置主从模式

达梦数据库简要概述 达梦数据库&#xff08;DM Database&#xff09;是一款由武汉达梦公司开发的关系型数据库管理系统&#xff0c;支持多种高可用性和数据同步方案。在主从模式&#xff08;也称为 Master-Slave 或 Primary-Secondary 模式&#xff09;中&#xff0c;主要通过…

系统思考全球化落地

感谢加密货币公司Bybit的再次邀请&#xff0c;为全球团队分享系统思考课程&#xff01;虽然大家来自不同国家&#xff0c;线上学习的形式依然让大家充满热情与互动&#xff0c;思维的碰撞不断激发新的灵感。 尽管时间存在挑战&#xff0c;但我看到大家的讨论异常积极&#xff…

Figma的汉化

Figma的汉化插件有客户端版本与Chrome版本&#xff0c;大家可根据自己的需要进行选择。 下载插件 进入Figma软件汉化-Figma中文版下载-Figma中文社区使用客户端&#xff1a;直接下载客户端使用网页版&#xff1a;安装chrome浏览器汉化插件国外推荐前往chrome商店安装国内推荐下…

IDEA 一键完成:打包 + 推送 + 部署docker镜像

1、本方案要解决场景&#xff1f; 想直接通过本地 IDEA 将最新的代码部署到远程服务器上。 2、本方案适用于什么样的项目&#xff1f; 项目是一个 Spring Boot 的 Java 项目。项目用 maven 进行管理。项目的运行基于 docker 容器&#xff08;即项目将被打成 docker image&am…

SpringBoot 第一课(Ⅲ) 配置类注解

目录 一、PropertySource 二、ImportResource ①SpringConfig &#xff08;Spring框架全注解&#xff09; ②ImportResource注解实现 三、Bean 四、多配置文件 多Profile文件的使用 文件命名约定&#xff1a; 激活Profile&#xff1a; YAML文件支持多文档块&#xff…

深度解析React Native底层核心架构

React Native 工作原理深度解析 一、核心架构&#xff1a;三层异构协作体系 React Native 的跨平台能力源于其独特的 JS层-Shadow层-Native层 架构设计&#xff0c;三者在不同线程中协同工作&#xff1a; JS层 运行于JavaScriptCore&#xff08;iOS&#xff09;或Hermes&…

对话智能体的正确打开方式:解析主流AI聊天工具的核心能力与使用方式

一、人机对话的黄金法则 在与人工智能对话系统交互时&#xff0c;掌握以下七项核心原则可显著提升沟通效率&#xff1a;文末有教程分享地址 意图精准表达术 采用"背景需求限定条件"的结构化表达 示例优化&#xff1a;"请用Python编写一个网络爬虫&#xff08…

Xinference大模型配置介绍并通过git-lfs、hf-mirror安装

文章目录 一、Xinference开机服务systemd二、语言&#xff08;LLM&#xff09;模型2.1 配置介绍2.2 DeepSeek-R1-Distill-Qwen-32B&#xff08;大杯&#xff09;工具下载git-lfs&#xff08;可以绕过Hugging Face&#xff09; 2.3 DeepSeek-R1-Distill-Qwen-32B-Q4_K_M-GGUF&am…

MyBatis操纵数据库-XML实现(补充)

目录 一.多表查询二.MyBatis参数赋值(#{ }和${ })2.1 #{ }和${ }的使用2.2 #{ }和${ }的区别2.3 SQL注入2.3 ${ }的应用场景2.3.1 排序功能2.3.2 like查询 一.多表查询 多表查询的操作和单表查询基本相同&#xff0c;只需改变一下SQL语句&#xff0c;同时也要在实体类中创建出…

快速导出接口设计表——基于DOMParser的Swagger接口详情半自动化提取方法

作者声明&#xff1a;不想看作者声明的&#xff08;需要生成接口设计表的&#xff09;直接前往https://capujin.github.io/A2T/。 注&#xff1a;Github Pages生成的页面可能会出现访问不稳定&#xff0c;暂时没将源码上传至Github&#xff0c;如有需要&#xff0c;可联系我私…

AI-医学影像分割方法与流程

AI医学影像分割方法与流程–基于低场磁共振影像的病灶识别 – 作者:coder_fang AI框架&#xff1a;PaddleSeg 数据准备&#xff0c;使用MedicalLabelMe进行dcm文件标注&#xff0c;产生同名.json文件。 编写程序生成训练集图片&#xff0c;包括掩码图。 代码如下: def doC…

SGMEA: Structure-Guided Multimodal Entity Alignment

3 Method 3.1 Problem Definition 3.2 Framework Description 总体框架如图2所示&#xff0c;由三个主要部分组成&#xff1a;初始嵌入采集模块、结构引导模块和模态融合模块。 3.3 Initial Embedding Acquisition 3.3.1 Structural Embedding 3.3.2 Relation, Attribute, …

《基于超高频RFID的图书馆管理系统的设计与实现》开题报告

一、研究背景与意义 1.研究背景 随着信息化时代的到来&#xff0c;运用计算机科学技术实现图书馆的管理工作已成为优势。更加科学地管理图书馆会大大提高工作效率。我国的图书管理体系发展经历了三个阶段&#xff1a;传统图书管理模式、现代图书管理模式以及基于无线射频识别&…

[local-file-system]基于服务器磁盘的本地文件存储方案

[local-file-system]基于服务器磁盘的本地文件存储方案 仅提供后端方案 github 环境 JDK11linux/windows/mac 应用场景 适用于ToB业务&#xff0c;中小企业的单体服务&#xff0c;仅使用磁盘存储文件的解决方案 仅使用服务器磁盘存储 与业务实体相结合的文件存储方案&…

【蓝桥杯每日一题】3.16

&#x1f3dd;️专栏&#xff1a; 【蓝桥杯备篇】 &#x1f305;主页&#xff1a; f狐o狸x 目录 3.9 高精度算法 一、高精度加法 题目链接&#xff1a; 题目描述&#xff1a; 解题思路&#xff1a; 解题代码&#xff1a; 二、高精度减法 题目链接&#xff1a; 题目描述&…

vue 仿deepseek前端开发一个对话界面

后端&#xff1a;调用deepseek的api&#xff0c;所以返回数据格式和deepseek相同 {"model": "DeepSeek-R1-Distill-Qwen-1.5B", "choices": [{"index": 0, "delta": {"role": "assistant", "cont…

SpringMVC(五)拦截器

目录 拦截器基本概念 一 单个拦截器的执行 1 创建拦截器 2 SpringMVC配置&#xff0c;并指定拦截路径。 3 运行结果展示&#xff1a; 二 多个拦截器的执行顺序 三 拦截器与过滤器的区别 拦截器基本概念 SpringMVC内置拦截器机制&#xff0c;允许在请求被目标方法处理的…