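Implementing an LSTM prediction model in PyTorch and running inference on the exported ONNX model from C++

Implementing the RNN (LSTM) model in PyTorch

Code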

import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_size, output_size, out_channels, num_layers, device):
        super(LSTM, self).__init__()
        self.device = device
        self.input_size = input_size
        # The hidden size is set equal to the input size.
        self.hidden_size = input_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.lstm = nn.LSTM(input_size=self.input_size,
                            hidden_size=self.hidden_size,
                            num_layers=self.num_layers,
                            batch_first=True)
        self.out_channels = out_channels
        # Note: self.fc is created here but never called in forward().
        self.fc = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, x):
        # Zero initial hidden and cell states: (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        out, _ = self.lstm(x, (h0, c0))
        if self.out_channels == 1:
            # Keep only the last time step: (batch, hidden_size)
            out = out[:, -1, :]
            return out
        return out


batch_size = 20
input_size = 10
output_size = 10
num_layers = 2
out_channels = 1

model = LSTM(input_size, output_size, out_channels, num_layers, "cpu")
model.eval()

input_names = ["input"]
output_names = ["output"]

# With batch_first=True the input is (batch, seq_len, features); both trailing dims are 10 here.
x = torch.randn((batch_size, input_size, output_size))
print(x.shape)
y = model(x)
print(y.shape)

# Export to ONNX with a dynamic batch dimension (axis 0) on input and output.
torch.onnx.export(model, x, 'LSTM.onnx', verbose=True,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes={'input': [0], 'output': [0]})

# Reload the exported file and validate it.
import onnx
model = onnx.load("LSTM.onnx")
print("load model done.")
onnx.checker.check_model(model)
print(onnx.helper.printable_graph(model.graph))
print("check model done.")
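The export call above passes dynamic_axes as lists of axis indices, so the exporter generates the axis names itself (the graph shows input_dynamic_axes_1) and emits the "No names were found" warning. torch.onnx.export also accepts a dict that maps each axis index to a name. The following is a minimal sketch of building that dict; the helper name named_dynamic_axes and the label "batch_size" are illustrative assumptions, not part of the original script.

```python
def named_dynamic_axes(tensor_names, axis=0, label="batch_size"):
    """Build a dynamic_axes mapping of the form {tensor_name: {axis: label}}.

    Hypothetical helper: it only assembles the dict, it does not call torch.
    """
    return {name: {axis: label} for name in tensor_names}


dynamic_axes = named_dynamic_axes(["input", "output"])
print(dynamic_axes)

# The dict would replace the list form in the export call, for example:
# torch.onnx.export(model, x, "LSTM.onnx", input_names=["input"],
#                   output_names=["output"], dynamic_axes=dynamic_axes)
```

Naming the axis should only change the symbolic dimension name stored in the model; the separate LSTM warning about exporting with a batch size other than 1 is unaffected by it.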

Output

torch.Size([20, 10, 10])
torch.Size([20, 10])
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input input
  "No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:2041: UserWarning: No names were found for specified dynamic axes of provided input.Automatically generated names will be applied to each dynamic axes of input output
  "No names were found for specified dynamic axes of provided input."
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py:4322: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
  + "or define the initial states (h0/c0) as inputs of the model. "
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:688: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  graph, params_dict, GLOBALS.export_onnx_opset_version
/home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/onnx/utils.py:1179: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
  graph, params_dict, GLOBALS.export_onnx_opset_version
Exported graph: graph(%input : Float(*, 10, 10, strides=[100, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_193 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_194 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_195 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_213 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_214 : Float(1, 40, 10, strides=[400, 10, 1], requires_grad=0, device=cpu),
      %onnx::LSTM_215 : Float(1, 80, strides=[80, 1], requires_grad=0, device=cpu)):
  %/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%input), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name="/Constant_1"](), scope: __main__.LSTM::
  %onnx::Unsqueeze_18 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
  %/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_18), scope: __main__.LSTM::
  %/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={10}, onnx_name="/Constant_2"](), scope: __main__.LSTM::
  %/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/ConstantOfShape_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name="/ConstantOfShape"](%/Concat_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/Cast_output_0 : Float(*, *, *, strides=[200, 10, 1], requires_grad=0, device=cpu) = onnx::Cast[to=1, onnx_name="/Cast"](%/ConstantOfShape_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:23:0
  %/lstm/Transpose_output_0 : Float(10, *, 10, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose"](%input), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %onnx::LSTM_26 : Tensor? = prim::Constant(), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_1_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_1"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_2_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_2"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice"](%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_3_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_3"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_4_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_4"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_5_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_5"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_1_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_1"](%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/LSTM_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM"](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %onnx::LSTM_26, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_6_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_6"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Squeeze_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze"](%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_7_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_7"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_8_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_8"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_9_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_9"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_2_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_2"](%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_10_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="/lstm/Constant_10"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_11_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_11"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_12_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={2}, onnx_name="/lstm/Constant_12"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Slice_3_output_0 : Float(*, *, *, device=cpu) = onnx::Slice[onnx_name="/lstm/Slice_3"](%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/LSTM_1_output_0 : Float(10, 1, *, 10, device=cpu), %/lstm/LSTM_1_output_1 : Float(1, *, 10, device=cpu), %/lstm/LSTM_1_output_2 : Float(1, *, 10, device=cpu) = onnx::LSTM[hidden_size=10, onnx_name="/lstm/LSTM_1"](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %onnx::LSTM_26, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Constant_13_output_0 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={1}, onnx_name="/lstm/Constant_13"](), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Squeeze_1_output_0 : Float(10, *, 10, device=cpu) = onnx::Squeeze[onnx_name="/lstm/Squeeze_1"](%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/lstm/Transpose_1_output_0 : Float(*, 10, 10, strides=[10, 200, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name="/lstm/Transpose_1"](%/lstm/Squeeze_1_output_0), scope: __main__.LSTM::/torch.nn.modules.rnn.LSTM::lstm # /home/ubuntu/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py:775:0
  %/Constant_3_output_0 : Long(device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_3"](), scope: __main__.LSTM::
  %output : Float(*, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gather[axis=1, onnx_name="/Gather_1"](%/lstm/Transpose_1_output_0, %/Constant_3_output_0), scope: __main__.LSTM:: # /zengli/20230320/ao/test/test_onnx_lstm.py:29:0
  return (%output)

load model done.
graph torch_jit (
  %input[FLOAT, input_dynamic_axes_1x10x10]
) initializers (
  %onnx::LSTM_193[FLOAT, 1x40x10]
  %onnx::LSTM_194[FLOAT, 1x40x10]
  %onnx::LSTM_195[FLOAT, 1x80]
  %onnx::LSTM_213[FLOAT, 1x40x10]
  %onnx::LSTM_214[FLOAT, 1x40x10]
  %onnx::LSTM_215[FLOAT, 1x80]
) {
  %/Shape_output_0 = Shape(%input)
  %/Constant_output_0 = Constant[value = <Scalar Tensor []>]()
  %/Gather_output_0 = Gather[axis = 0](%/Shape_output_0, %/Constant_output_0)
  %/Constant_1_output_0 = Constant[value = <Tensor>]()
  %onnx::Unsqueeze_18 = Constant[value = <Tensor>]()
  %/Unsqueeze_output_0 = Unsqueeze(%/Gather_output_0, %onnx::Unsqueeze_18)
  %/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/Concat_output_0 = Concat[axis = 0](%/Constant_1_output_0, %/Unsqueeze_output_0, %/Constant_2_output_0)
  %/ConstantOfShape_output_0 = ConstantOfShape[value = <Tensor>](%/Concat_output_0)
  %/Cast_output_0 = Cast[to = 1](%/ConstantOfShape_output_0)
  %/lstm/Transpose_output_0 = Transpose[perm = [1, 0, 2]](%input)
  %/lstm/Constant_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_1_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_2_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_1_output_0, %/lstm/Constant_2_output_0, %/lstm/Constant_output_0)
  %/lstm/Constant_3_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_4_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_5_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_1_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_4_output_0, %/lstm/Constant_5_output_0, %/lstm/Constant_3_output_0)
  %/lstm/LSTM_output_0, %/lstm/LSTM_output_1, %/lstm/LSTM_output_2 = LSTM[hidden_size = 10](%/lstm/Transpose_output_0, %onnx::LSTM_193, %onnx::LSTM_194, %onnx::LSTM_195, %, %/lstm/Slice_output_0, %/lstm/Slice_1_output_0)
  %/lstm/Constant_6_output_0 = Constant[value = <Tensor>]()
  %/lstm/Squeeze_output_0 = Squeeze(%/lstm/LSTM_output_0, %/lstm/Constant_6_output_0)
  %/lstm/Constant_7_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_8_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_9_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_2_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_8_output_0, %/lstm/Constant_9_output_0, %/lstm/Constant_7_output_0)
  %/lstm/Constant_10_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_11_output_0 = Constant[value = <Tensor>]()
  %/lstm/Constant_12_output_0 = Constant[value = <Tensor>]()
  %/lstm/Slice_3_output_0 = Slice(%/Cast_output_0, %/lstm/Constant_11_output_0, %/lstm/Constant_12_output_0, %/lstm/Constant_10_output_0)
  %/lstm/LSTM_1_output_0, %/lstm/LSTM_1_output_1, %/lstm/LSTM_1_output_2 = LSTM[hidden_size = 10](%/lstm/Squeeze_output_0, %onnx::LSTM_213, %onnx::LSTM_214, %onnx::LSTM_215, %, %/lstm/Slice_2_output_0, %/lstm/Slice_3_output_0)
  %/lstm/Constant_13_output_0 = Constant[value = <Tensor>]()
  %/lstm/Squeeze_1_output_0 = Squeeze(%/lstm/LSTM_1_output_0, %/lstm/Constant_13_output_0)
  %/lstm/Transpose_1_output_0 = Transpose[perm = [1, 0, 2]](%/lstm/Squeeze_1_output_0)
  %/Constant_3_output_0 = Constant[value = <Scalar Tensor []>]()
  %output = Gather[axis = 1](%/lstm/Transpose_1_output_0, %/Constant_3_output_0)
  return %output
}
check model done.
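The initializer shapes in the graph follow from the ONNX LSTM operator's layout: the input weights W are [num_directions, 4*hidden_size, input_size], the recurrent weights R are [num_directions, 4*hidden_size, hidden_size], and the bias B concatenates the input and recurrent biases into [num_directions, 8*hidden_size]. The two LSTM nodes correspond to num_layers = 2, and no node for self.fc appears because forward() never calls it. A small sketch that reproduces the shapes (the function name onnx_lstm_param_shapes is an illustrative assumption):

```python
def onnx_lstm_param_shapes(input_size, hidden_size, num_directions=1):
    """Return the shapes of the W, R and B inputs of one ONNX LSTM node."""
    gates = 4  # input, output, forget and cell gates
    w = (num_directions, gates * hidden_size, input_size)   # input weights
    r = (num_directions, gates * hidden_size, hidden_size)  # recurrent weights
    b = (num_directions, 2 * gates * hidden_size)           # input bias + recurrent bias
    return w, r, b


# Values used in this article: input_size = hidden_size = 10, unidirectional.
w, r, b = onnx_lstm_param_shapes(input_size=10, hidden_size=10)
print(w, r, b)  # (1, 40, 10) (1, 40, 10) (1, 80)
```

These match %onnx::LSTM_193/194/195 for the first layer and %onnx::LSTM_213/214/215 for the second, whose input size equals the hidden size of the first.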

Calling the ONNX model from C++

Implementation

#include <array>
#include <cwchar>
#include <iostream>
#include <vector>
#include <onnxruntime_cxx_api.h>

std::vector<float> testOnnxLSTM(std::vector<std::vector<std::vector<float>>>& inputs)
{
    // Set the level to VERBOSE to see in the console whether the CPU or the GPU is used
    //Ort::Env env(ORT_LOGGING_LEVEL_VERBOSE, "test");
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "Default");
    Ort::SessionOptions session_options;
    session_options.SetIntraOpNumThreads(5); // run ops on five threads for speed
    // The second argument is the GPU device_id (0); with this line commented out, inference runs on the CPU
    //OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

#ifdef _WIN32
    const wchar_t* model_path = L"C:\\Users\\xxx\\Desktop\\LSTM.onnx";
#else
    const char* model_path = "C:\\Users\\xxx\\Desktop\\LSTM.onnx";
#endif
    wprintf(L"%s\n", model_path);

    Ort::Session session(env, model_path, session_options);
    // Names must match input_names / output_names used at export time
    const char* input_names[] = { "input" };
    const char* output_names[] = { "output" };

    const int input_size = 10;
    const int output_size = 10;
    const int batch_size = 1;
    const int seq_len = 10;
    std::array<float, batch_size * seq_len * input_size> input_matrix;
    // Zero-initialized so that a failed Run() does not return uninitialized values
    std::array<float, batch_size * output_size> output_matrix{};
    std::array<int64_t, 3> input_shape{ batch_size, seq_len, input_size };
    std::array<int64_t, 2> output_shape{ batch_size, output_size };

    // Copy the nested vectors into a contiguous row-major buffer
    for (int i = 0; i < batch_size; i++)
        for (int j = 0; j < seq_len; j++)
            for (int k = 0; k < input_size; k++)
                input_matrix[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k];

    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_matrix.data(), input_matrix.size(), input_shape.data(), input_shape.size());

    try
    {
        // The output tensor wraps output_matrix, so Run() writes the result directly into it
        Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_matrix.data(), output_matrix.size(), output_shape.data(), output_shape.size());
        session.Run(Ort::RunOptions{ nullptr }, input_names, &input_tensor, 1, output_names, &output_tensor, 1);
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << std::endl;
    }

    std::cout << "get data from LSTM onnx: \n";
    std::vector<float> ret;
    for (int i = 0; i < output_size; i++) {
        ret.emplace_back(output_matrix[i]);
        std::cout << ret[i] << "\t";
    }
    std::cout << "\n";
    return ret;
}
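The triple loop in testOnnxLSTM copies the nested [batch][seq_len][input_size] vectors into one contiguous buffer in row-major order, which is the layout the tensor created over input_matrix expects for the shape {batch_size, seq_len, input_size}. The indexing can be sketched in Python as follows; flatten_row_major and the tiny sample are illustrative assumptions, not part of the original program.

```python
def flatten_row_major(inputs):
    """Flatten a nested [batch][seq_len][input_size] list into a row-major buffer."""
    batch_size = len(inputs)
    seq_len = len(inputs[0])
    input_size = len(inputs[0][0])
    flat = [0.0] * (batch_size * seq_len * input_size)
    for i in range(batch_size):
        for j in range(seq_len):
            for k in range(input_size):
                # Same index expression as in the C++ function
                flat[i * seq_len * input_size + j * input_size + k] = inputs[i][j][k]
    return flat, (batch_size, seq_len, input_size)


sample = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]  # shape (1, 2, 3)
flat, shape = flatten_row_major(sample)
print(flat, shape)
```

The last dimension varies fastest, so element (i, j, k) always lands at offset i * seq_len * input_size + j * input_size + k.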

Calling code

    // Build a 1 x 10 x 10 input where element [j][k] = k * j / 20
    std::vector<std::vector<std::vector<float>>> data;
    for (int i = 0; i < 1; i++) {
        std::vector<std::vector<float>> t1;
        for (int j = 0; j < 10; j++) {
            std::vector<float> t2;
            for (int k = 0; k < 10; k++) {
                t2.push_back(1.0 * k * j / 20);
            }
            t1.push_back(t2);
        }
        data.push_back(t1);
    }

    // Print the input
    for (auto& i : data) {
        for (auto& j : i) {
            for (auto& k : j) {
                std::cout << k << "\t";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }

    auto ret = testOnnxLSTM(data);
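To cross-check the C++ result from Python, the same input can be rebuilt there and fed to the exported model. The sketch below only regenerates the input grid with the rule used by the calling code, element [j][k] = k * j / 20; the function name make_test_input is an illustrative assumption, and actually running the model in Python is left out.

```python
def make_test_input(batch_size=1, seq_len=10, input_size=10):
    """Rebuild the nested input used by the C++ caller: value = k * j / 20."""
    return [
        [[k * j / 20 for k in range(input_size)] for j in range(seq_len)]
        for _ in range(batch_size)
    ]


data = make_test_input()
for row in data[0]:
    # %g-style formatting is close to the default std::cout float formatting
    print("\t".join(f"{v:g}" for v in row))
```

The printed grid should correspond to the input table shown in the test results, which makes it easy to confirm that both sides feed the model the same values.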

Test results

0       0       0       0       0       0       0       0       0       0
0       0.05    0.1     0.15    0.2     0.25    0.3     0.35    0.4     0.45
0       0.1     0.2     0.3     0.4     0.5     0.6     0.7     0.8     0.9
0       0.15    0.3     0.45    0.6     0.75    0.9     1.05    1.2     1.35
0       0.2     0.4     0.6     0.8     1       1.2     1.4     1.6     1.8
0       0.25    0.5     0.75    1       1.25    1.5     1.75    2       2.25
0       0.3     0.6     0.9     1.2     1.5     1.8     2.1     2.4     2.7
0       0.35    0.7     1.05    1.4     1.75    2.1     2.45    2.8     3.15
0       0.4     0.8     1.2     1.6     2       2.4     2.8     3.2     3.6
0       0.45    0.9     1.35    1.8     2.25    2.7     3.15    3.6     4.05

C:\Users\xxx\Desktop\LSTM.onnx
get data from LSTM onnx:
0.000401703 0.00102207 0.0011015 -0.000503412 -0.000911839 -0.0011367 -0.000309185 0.000591398 -0.000362981 -4.81475e-05
