SourceXtractorPlusPlus 0.19.2
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxModel.h
Go to the documentation of this file.
1/*
2 * OnnxModel.h
3 *
4 * Created on: Feb 16, 2021
5 * Author: mschefer
6 */
7
8#ifndef _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
9#define _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
10
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
18
19namespace SourceXtractor {
20
21class OnnxModel {
22public:
23
24 explicit OnnxModel(const std::string& model_path);
25
26 template<typename T, typename U>
27 void run(std::vector<T>& input_data, std::vector<U>& output_data) const {
28 Ort::RunOptions run_options;
29 auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
30
31 // Allocate memory
33 input_shape[0] = 1;
34 size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
35
37 output_shape[0] = 1;
38 size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
39
40 // Check input and output size are OK
41 if (input_data.size() < input_size || output_data.size() < output_size) {
42 throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
43 }
44
45 // Setup input/output tensors
46 auto input_tensor = Ort::Value::CreateTensor<T>(
47 mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size());
48 auto output_tensor = Ort::Value::CreateTensor<U>(
49 mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
50
51 // Run the model
52 const char *input_name = m_input_names[0].c_str();
53 const char *output_name = m_output_name.c_str();
54
55 m_session->Run(run_options, &input_name, &input_tensor, 1, &output_name, &output_tensor, 1);
56 }
57
58 template<typename T, typename U>
59 void runMultiInput(std::map<std::string, std::vector<T>>& input_data, std::vector<U>& output_data) const {
60 Ort::RunOptions run_options;
61 auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
62
63 std::vector<const char *> input_names;
64 std::vector<Ort::Value> input_tensors;
65
66 int inputs_nb = m_input_names.size();
67 for (int i=0; i<inputs_nb; i++) {
68 input_names.emplace_back(m_input_names[i].c_str());
69
70 // Allocate memory
72 input_shape[0] = 1;
73 size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
74
75 // Check input size is OK
76 if (input_data[m_input_names[i]].size() < input_size) {
77 throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
78 }
79
80 input_tensors.emplace_back(Ort::Value::CreateTensor<T>(
81 mem_info, input_data[m_input_names[i]].data(), input_data[m_input_names[i]].size(),
82 input_shape.data(), input_shape.size()));
83 }
84
85 // Output name and shape
86 const char *output_name = m_output_name.c_str();
88 output_shape[0] = 1;
89
90 // Setup output tensor
91 size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
92
93 // Check output and output size are OK
94 if (output_data.size() < output_size) {
95 throw Elements::Exception() << "OnnxModel: Insufficient buffer size ";
96 }
97
98 auto output_tensor = Ort::Value::CreateTensor<U>(
99 mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
100
101 // Run the model
102 m_session->Run(run_options, &input_names[0], &input_tensors[0], inputs_nb, &output_name, &output_tensor, 1);
103 }
104
105
106 ONNXTensorElementDataType getInputType() const {
107 return m_input_types[0];
108 }
109
110 ONNXTensorElementDataType getOutputType() const {
111 return m_output_type;
112 }
113
115 return m_input_shapes[0];
116 }
117
119 return m_output_shape;
120 }
121
123 return m_domain_name;
124 }
125
127 return m_graph_name;
128 }
129
131 return m_input_names[0];
132 }
133
135 return m_output_name;
136 }
137
139 return m_model_path;
140 }
141
142 size_t getInputNb() const {
143 return m_input_names.size();
144 }
145
146 size_t getOutputNb() const {
147 return 1U;
148 }
149
150private:
156 ONNXTensorElementDataType m_output_type;
161};
162
163}
164
165
166#endif /* _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_ */
T accumulate(T... args)
T begin(T... args)
T c_str(T... args)
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:27
ONNXTensorElementDataType getInputType() const
Definition: OnnxModel.h:106
ONNXTensorElementDataType getOutputType() const
Definition: OnnxModel.h:110
std::vector< ONNXTensorElementDataType > m_input_types
Input type.
Definition: OnnxModel.h:155
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:160
std::string getGraphName() const
Definition: OnnxModel.h:126
std::string getDomain() const
Definition: OnnxModel.h:122
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:154
size_t getOutputNb() const
Definition: OnnxModel.h:146
const std::vector< std::int64_t > & getOutputShape() const
Definition: OnnxModel.h:118
std::string getOutputName() const
Definition: OnnxModel.h:134
ONNXTensorElementDataType m_output_type
Output type.
Definition: OnnxModel.h:156
std::vector< std::string > m_input_names
Input tensor name.
Definition: OnnxModel.h:153
std::string getInputName() const
Definition: OnnxModel.h:130
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:158
std::string m_graph_name
graph name
Definition: OnnxModel.h:152
void runMultiInput(std::map< std::string, std::vector< T > > &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:59
std::string m_domain_name
domain name
Definition: OnnxModel.h:151
const std::vector< std::int64_t > & getInputShape() const
Definition: OnnxModel.h:114
std::string m_model_path
Path to the ONNX model.
Definition: OnnxModel.h:159
size_t getInputNb() const
Definition: OnnxModel.h:142
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shape.
Definition: OnnxModel.h:157
std::string getModelPath() const
Definition: OnnxModel.h:138
T data(T... args)
T emplace_back(T... args)
T end(T... args)
T size(T... args)