VTK  9.5.20251103
vtkONNXInference.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
2// SPDX-License-Identifier: BSD-3-Clause
23#ifndef vtkONNXInference_h
24#define vtkONNXInference_h
25
26#include "vtkFiltersONNXModule.h" // For export macro
28
29#include "vtkDataObject.h" // for AttributeTypes
30
31#include <memory> // For std::unique_ptr
32#include <vector> // For std::vector
33
34VTK_ABI_NAMESPACE_BEGIN
// Forward declarations of the Ort types this header refers to, so that
// Ort::Value can be named in the RunModel() declaration below without an Ort
// header being included here (presumably the ONNX Runtime C++ API - confirm).
// NOTE(review): these declarations come after VTK_ABI_NAMESPACE_BEGIN; if that
// macro opens a namespace, they would declare a nested Ort rather than ::Ort -
// confirm this placement is intended.
36namespace Ort
37{
38class AllocatorWithDefaultOptions;
39class Value;
40}
41
// vtkONNXInference: described on this page as "Infer an ONNX model."
// Derives from vtkPassInputTypeAlgorithm (superclass for algorithms that
// produce output of the same type as input).
//
// NOTE(review): this numbered listing has gaps; members that the page's
// tooltips mention (New(), RequestInformation(), RequestData(),
// SetNumberOfInputParameters(), ClearInputParameters(),
// SetNumberOfTimeStepValues(), ClearTimeStepValues()) are not shown here.
// NOTE(review): std::string is used below but <string> is not among the
// visible includes - presumably pulled in transitively; confirm.
42class VTKFILTERSONNX_EXPORT vtkONNXInference : public vtkPassInputTypeAlgorithm
43{
44public:
// Print information about the object, including superclasses.
46 void PrintSelf(ostream& os, vtkIndent indent) override;
47
49
51
// Get/Set the path to the ONNX model. The page documents the setter as also
// loading the model, which is presumably why it is hand-declared rather than
// generated by vtkSetMacro - confirm in the implementation.
54 void SetModelFile(const std::string& file);
55 vtkGetMacro(ModelFile, std::string);
57
68
// Time steps: set the whole list of time step values at once.
71 void SetTimeStepValues(const std::vector<double>& times);
72
// Set a time value at a given index.
76 void SetTimeStepValue(vtkIdType idx, double timeStepValue);
77
83
89
// Set TimeStepIndex (the member defaults to -1).
// NOTE(review): what the index selects is not visible in this listing;
// presumably a position in the input or time-step vectors - confirm.
94 vtkSetMacro(TimeStepIndex, int);
95
// Get TimeStepIndex.
100 vtkGetMacro(TimeStepIndex, int);
102
112
// Input parameters: set the whole list of input parameters at once.
115 void SetInputParameters(const std::vector<float>& params);
116
// Set an input parameter at a given index.
121 void SetInputParameter(vtkIdType idx, float InputParameter);
122
128
// Get InputSize.
// NOTE(review): the accessor is declared with int while the InputSize member
// is int64_t, so the value is narrowed on return - confirm this is intended.
132 vtkGetMacro(InputSize, int);
133
140
142
// Set/Get OutputDimension (the member defaults to 1).
145 vtkSetMacro(OutputDimension, int);
146 vtkGetMacro(OutputDimension, int);
148
150
// Set/Get ArrayAssociation (the member defaults to vtkDataObject::CELL).
// The include of vtkDataObject.h is annotated "for AttributeTypes", so the
// value is presumably one of those attribute types - confirm.
154 vtkSetMacro(ArrayAssociation, int);
155 vtkGetMacro(ArrayAssociation, int);
157
158protected:
// Defaulted destructor; Internals is held through std::unique_ptr.
160 ~vtkONNXInference() override = default;
161
166
168
// Execute the inference and add the resulting array on the given data object.
// timevalue: presumably the time value the inference is run for - confirm.
// The meaning of the int return value is not visible in this listing.
173 int ExecuteData(vtkDataObject* input, vtkDataObject* output, double timevalue);
174
175private:
// Copy construction and copy assignment are disabled.
176 vtkONNXInference(const vtkONNXInference&) = delete;
177 void operator=(const vtkONNXInference&) = delete;
178
// Private helper; judging by the name it sets up the inference session
// (see the Initialized flag below) - confirm in the implementation.
183 void InitializeSession();
184
// Private predicate; presumably reports whether time steps should be
// generated for the pipeline - confirm in the implementation.
190 bool ShouldGenerateTimeSteps();
191
// Takes an Ort::Value by non-const reference and returns a vector of
// Ort::Value; presumably runs the loaded model on inputTensor - confirm.
196 std::vector<Ort::Value> RunModel(Ort::Value& inputTensor);
197
198 // Input related parameters
199 std::string ModelFile;
200 int64_t InputSize = 0;
201 std::vector<float> InputParameters;
202 std::vector<double> TimeStepValues;
203 int TimeStepIndex = -1;
204
205 // Output related parameters
206 int OutputDimension = 1;
207
208 int ArrayAssociation = vtkDataObject::CELL;
209
// Initialized starts false; presumably set once the session is ready - confirm.
210 bool Initialized = false;
// vtkONNXInferenceInternals is described on this page as the
// "VTK internal class for hiding ONNX members."
211 std::unique_ptr<vtkONNXInferenceInternals> Internals;
212};
213VTK_ABI_NAMESPACE_END
214
215#endif // vtkONNXInference_h
vtkDataObject: general representation of visualization data
vtkIndent: a simple class to control print indentation
Definition vtkIndent.h:108
vtkInformationVector: Store zero or more vtkInformation instances.
vtkInformation: Store vtkAlgorithm input/output information.
vtkONNXInference: Infer an ONNX model.
void ClearInputParameters()
Clear the input parameters vector.
void ClearTimeStepValues()
Clear the time step values vector.
int RequestInformation(vtkInformation *, vtkInformationVector **, vtkInformationVector *) override
This is required to inform the pipeline of the time steps.
~vtkONNXInference() override=default
int RequestData(vtkInformation *, vtkInformationVector **, vtkInformationVector *) override
This is called within ProcessRequest when a request asks the algorithm to do its work.
void SetInputParameters(const std::vector< float > &params)
Input Parameters.
void SetInputParameter(vtkIdType idx, float InputParameter)
Set an input parameter at a given index.
static vtkONNXInference * New()
void SetTimeStepValues(const std::vector< double > &times)
Time Steps.
void SetTimeStepValue(vtkIdType idx, double timeStepValue)
Set a time value at a given index.
void SetModelFile(const std::string &file)
Get/Set the path to the ONNX model and load it.
void SetNumberOfInputParameters(vtkIdType nb)
Set the number of input parameters.
int ExecuteData(vtkDataObject *input, vtkDataObject *output, double timevalue)
Execute the inference and add the resulting array on the given data object.
void PrintSelf(ostream &os, vtkIndent indent) override
Methods invoked by print to print information about the object including superclasses.
void SetNumberOfTimeStepValues(vtkIdType nb)
Set the number of time step values.
vtkPassInputTypeAlgorithm: Superclass for algorithms that produce output of the same type as input.
vtkONNXInferenceInternals: VTK internal class for hiding ONNX members.
int vtkIdType
Definition vtkType.h:367