VTK  9.6.20260516
vtkONNXInference.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
2// SPDX-License-Identifier: BSD-3-Clause
23#ifndef vtkONNXInference_h
24#define vtkONNXInference_h
25
26#include "vtkFiltersONNXModule.h" // For export macro
28
29#include "vtkDataObject.h" // For AttributeTypes
30
31#include <memory> // For std::unique_ptr
32#include <vector> // For std::vector
33
34VTK_ABI_NAMESPACE_BEGIN
// Forward declarations of two types in namespace Ort. Ort::Value is taken by
// reference in the private method signatures of vtkONNXInference later in this
// listing, which is why a declaration (rather than a full definition) suffices here.
// NOTE(review): AllocatorWithDefaultOptions is not referenced by any line visible
// in this listing — presumably it is used by code that is not shown; confirm.
36namespace Ort
37{
38class AllocatorWithDefaultOptions;
39class Value;
40}
41
// vtkONNXInference: algorithm class deriving from vtkPassInputTypeAlgorithm that
// exposes settings for an ONNX model (model file, input shape, input parameters,
// time step values, permutations) and private helpers that build an Ort::Value
// input tensor and run the model.
// NOTE(review): the leading integer on each line appears to be the line number of
// the rendered listing, and the gaps in that numbering presumably correspond to
// lines that were not carried over — confirm against the original header before
// relying on this listing as a complete declaration.
42class VTKFILTERSONNX_EXPORT vtkONNXInference : public vtkPassInputTypeAlgorithm
43{
44public:
    // Methods invoked by print to print information about the object including
    // superclasses.
46 void PrintSelf(ostream& os, vtkIndent indent) override;
47
49
51
    // Get/Set the path to the ONNX model and load it.
54 void SetModelFile(const std::string& file);
55 vtkGetMacro(ModelFile, std::string);
57
68
    // Time steps: set the whole list of time step values at once.
71 void SetTimeStepValues(const std::vector<double>& times);
72
    // Set a time value at a given index.
76 void SetTimeStepValue(vtkIdType idx, double timeStepValue);
77
83
89
91
    // Set/Get TimeStepIndex (initialized to -1 in the member declaration below).
    // NOTE(review): presumably an index into the input parameters at which the
    // time value is placed — the listing does not show this; confirm.
95 vtkSetMacro(TimeStepIndex, int);
96 vtkGetMacro(TimeStepIndex, int);
98
    // Input parameters: set the whole list of float parameters at once.
110 void SetInputParameters(const std::vector<float>& params);
111
    // Set an input parameter at a given index.
116 void SetInputParameter(vtkIdType idx, float InputParameter);
117
123
125
    // Set/Get the shape of the input. The shape is stored as int64_t elements;
    // the per-index overload takes an int for the element value.
130 void SetInputShape(const std::vector<int64_t>& shape);
131 void SetInputShape(vtkIdType idx, int shapeElement);
133 const std::vector<int64_t>& GetInputShape() const;
135
141
147
149
    // Set/Get the permutation between the VTK array and the model input.
157 void SetInputPermutation(const std::vector<int>& shape);
158 const std::vector<int>& GetInputPermutation() const;
160
162
    // Set/Get the permutation between the model output and a VTK array.
170 void SetOutputPermutation(const std::vector<int>& permutation);
171 const std::vector<int>& GetOutputPermutation() const;
173
175
    // Set/Get/toggle the FieldArrayInput flag (false by default, see member below).
    // NOTE(review): presumably selects GenerateInputTensorFromFieldArray over
    // GenerateInputTensorFromParameters — not shown in this listing; confirm.
179 vtkSetMacro(FieldArrayInput, bool);
180 vtkGetMacro(FieldArrayInput, bool);
181 vtkBooleanMacro(FieldArrayInput, bool);
183
185
    // Set/Get ProcessedFieldArrayName.
    // NOTE(review): presumably the name of the field array used as model input
    // when FieldArrayInput is on — confirm.
188 vtkSetMacro(ProcessedFieldArrayName, const std::string&);
189 vtkGetMacro(ProcessedFieldArrayName, const std::string&);
191
193
    // Set/Get OutputDimension (1 by default, see member below).
196 vtkSetMacro(OutputDimension, int);
197 vtkGetMacro(OutputDimension, int);
199
201
    // Set/Get ArrayAssociation (vtkDataObject::CELL by default, see member below).
205 vtkSetMacro(ArrayAssociation, int);
206 vtkGetMacro(ArrayAssociation, int);
208
209protected:
    // Defaulted destructor, declared in the protected section.
211 ~vtkONNXInference() override = default;
212
217
219
    // Execute the inference and add the resulting array on the given data object.
224 int ExecuteData(vtkDataObject* input, vtkDataObject* output, double timevalue);
225
226private:
    // Copy construction and copy assignment are deleted.
227 vtkONNXInference(const vtkONNXInference&) = delete;
228 void operator=(const vtkONNXInference&) = delete;
229
    // NOTE(review): presumably creates the ONNX runtime session for ModelFile and
    // is tracked by the Initialized flag — body not visible here; confirm.
234 void InitializeSession();
235
    // NOTE(review): presumably reports whether the filter should advertise
    // TimeStepValues as pipeline time steps — body not visible here; confirm.
241 bool ShouldGenerateTimeSteps();
242
    // Fills inputTensor from a parameter vector and a time value; returns a bool.
    // NOTE(review): the meaning of the return value (likely success) is not shown.
247 bool GenerateInputTensorFromParameters(
248 std::vector<float>& parameters, Ort::Value& inputTensor, double timeValue);
249
    // Fills inputTensor from the given dataset attributes; returns a bool.
    // NOTE(review): the meaning of the return value (likely success) is not shown.
254 bool GenerateInputTensorFromFieldArray(
255 Ort::Value& inputTensor, vtkDataSetAttributes* inAttributes);
256
    // Takes the input tensor and returns a vector of Ort::Value.
    // NOTE(review): presumably these are the model outputs — confirm.
261 std::vector<Ort::Value> RunModel(Ort::Value& inputTensor);
262
263 // Input related parameters
264 std::string ModelFile;
265 std::vector<int64_t> InputShape = { 0 };
266 std::vector<float> InputParameters;
267 std::vector<double> TimeStepValues;
268 int TimeStepIndex = -1;
269 bool FieldArrayInput = false;
270 std::string ProcessedFieldArrayName;
271 std::vector<int> InputPermutation;
272
273 // Output related parameters
274 int OutputDimension = 1;
275 std::vector<int> OutputPermutation;
276
277 int ArrayAssociation = vtkDataObject::CELL;
    // NOTE(review): presumably the float storage backing the input tensor — confirm.
278 std::vector<float> InputDataBuffer;
279
280 bool Initialized = false;
    // Owning pointer to vtkONNXInferenceInternals; presumably the internal class
    // that hides the ONNX members from this header — confirm.
281 std::unique_ptr<vtkONNXInferenceInternals> Internals;
282};
283VTK_ABI_NAMESPACE_END
284
285#endif // vtkONNXInference_h
general representation of visualization data
represent and manipulate attribute data in a dataset
a simple class to control print indentation
Definition vtkIndent.h:108
Store zero or more vtkInformation instances.
Store vtkAlgorithm input/output information.
void ClearInputParameters()
Clear the input parameters vector.
void ClearTimeStepValues()
Clear the time step values vector.
void SetInputPermutation(const std::vector< int > &shape)
Set/Get the permutation between the VTK array and the model input.
void SetInputShape(vtkIdType idx, int shapeElement)
Set/Get the shape of the input.
int RequestInformation(vtkInformation *, vtkInformationVector **, vtkInformationVector *) override
This is required to inform the pipeline of the time steps.
~vtkONNXInference() override=default
const std::vector< int > & GetOutputPermutation() const
Set/Get the permutation between the model output and a VTK array.
void SetInputShape(const std::vector< int64_t > &shape)
Set/Get the shape of the input.
int RequestData(vtkInformation *, vtkInformationVector **, vtkInformationVector *) override
This is called within ProcessRequest when a request asks the algorithm to do its work.
void SetInputParameters(const std::vector< float > &params)
Input Parameters.
void SetInputParameter(vtkIdType idx, float InputParameter)
Set an input parameter at a given index.
static vtkONNXInference * New()
void SetTimeStepValues(const std::vector< double > &times)
Time Steps.
void SetTimeStepValue(vtkIdType idx, double timeStepValue)
Set a time value at a given index.
void SetModelFile(const std::string &file)
Get/Set the path to the ONNX model and load it.
void ClearInputShape()
Clear the input shape vector.
int ExecuteData(vtkDataObject *input, vtkDataObject *output, double timevalue)
Execute the inference and add the resulting array on the given data object.
void SetNumberOfInputShapeElements(vtkIdType nb)
Set the number of input shape values.
void SetOutputPermutation(const std::vector< int > &permutation)
Set/Get the permutation between the model output and a VTK array.
void PrintSelf(ostream &os, vtkIndent indent) override
Methods invoked by print to print information about the object including superclasses.
const std::vector< int > & GetInputPermutation() const
Set/Get the permutation between the VTK array and the model input.
void SetInputShape(vtkIdType nb)
Set/Get the shape of the input.
void SetNumberOfTimeStepValues(vtkIdType nb)
Set the number of time step values.
const std::vector< int64_t > & GetInputShape() const
Set/Get the shape of the input.
VTK internal class for hiding ONNX members.
int vtkIdType
Definition vtkType.h:363