文章目录
- 尝试1:强行设置dropout层train mode为False
- 尝试2:找到onnx模型中的dropout, train mode设置为False
- 尝试3:直接删除dropout层,连接其输入输出
- 结语
最近训练模型使用了tinyvit,性能挺强的:
但是导出onnx时,会提示dropout层的train mode被设置为True了。
UserWarning: ONNX export mode is set to TrainingMode.EVAL, but operator 'dropout' is set to train=True. Exporting with train=True.
这个警告如果只是使用onnxruntime去推理的话,可以不用处理,但是如果使用openvino则会在转换模型时失败。因为导出的onnx中出现了Dropout层,一般的推理框架是不支持推理的时候用dropout的。
尝试1:强行设置dropout层train mode为False
for m in torch_model.modules():if isinstance(m, torch.nn.Dropout):m.training = False
问题依旧
尝试2:找到onnx模型中的dropout, train mode设置为False
做这个尝试的本意是先设置为False, 再用onnx-simplify去优化一把,理论上会把dropout层去掉。
# 遍历模型的所有Dropout节点, 找到所有的training mode节点名称
training_mode_inputs=[]
for node in model.graph.node:if node.op_type == 'Dropout':# 获取Dropout节点的training_mode输入(假设是最后一个输入)training_mode_input = node.input[-1]# 检查这个输入是否指向之前找到的值为True的常量节点training_mode_inputs.append(training_mode_input)# 遍历所有初始化器
for initializer in model.graph.initializer:# 检查初始化器是否是我们要找的training_mode输入if initializer.name in training_mode_inputs:# 假设这个初始化器是一个布尔值,我们将其修改为False# 注意:ONNX中的布尔值是以int64类型存储的,0表示False,1表示True# initializer.data_type = onnx.TensorProto.INT64initializer.int64_data[:] = [0] # 修改为False
from onnx import helper
new_initializers = []for initializer in model.graph.initializer:if initializer.name in training_mode_inputs:# 创建一个新的TensorProto对象,值为Falsenew_initializer = helper.make_tensor(name=initializer.name, # 保持原来的名称data_type=onnx.TensorProto.BOOL,dims=initializer.dims, # 保持原来的维度vals=[0] # 设置值为False(在ONNX中用0表示))new_initializers.append(new_initializer)else:new_initializers.append(initializer)# 替换原来的初始化器列表
# Clear existing initializers
model.graph.ClearField('initializer')
# Add the new initializers
model.graph.initializer.extend(new_initializers)
理想很丰满,现实很骨感···并没有发生什么变化
尝试3:直接删除dropout层,连接其输入输出
dropout层在推理的时候也没什么用,直接删除,然后连接上原dropout的输入输出层就好了
import onnx
from onnx import helper# 加载模型
onnx_model = onnx.load(model_path)
graph = onnx_model.graph# 找到 Dropout 层
nodes_to_remove = [node for node in graph.node if node.op_type == 'Dropout']# 删除 Dropout 层并重新连接
for node in nodes_to_remove:input_name = node.input[0]output_name = node.output[0]# 找到所有使用 Dropout 输出作为输入的节点for next_node in graph.node:for i, input_name in enumerate(next_node.input):if input_name == node.output[0]:next_node.input[i] = node.input[0]# 从图中移除 Dropout 节点graph.node.remove(node)# 保存修改后的模型
# check if the model is valid
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'tinyvit_11m_sim_replace.onnx')
成功了,模型的dropout层都被删除了。
结语
虽然尝试了好几种方式···不过这些具体的代码我基本都是问的copilot,不得不说代码助手减轻了好多工作。