VTK  9.6.20260111
vtkONNXInference.h
Go to the documentation of this file.
// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause
23#ifndef vtkONNXInference_h
24#define vtkONNXInference_h
25
#include "vtkFiltersONNXModule.h"      // For export macro
#include "vtkPassInputTypeAlgorithm.h" // For parent class

#include "vtkDataObject.h" // For AttributeTypes

#include <cstdint> // For int64_t
#include <memory>  // For std::unique_ptr
#include <string>  // For std::string
#include <vector>  // For std::vector
33
34VTK_ABI_NAMESPACE_BEGIN
36namespace Ort
37{
38class AllocatorWithDefaultOptions;
39class Value;
40}
41
42class VTKFILTERSONNX_EXPORT vtkONNXInference : public vtkPassInputTypeAlgorithm
43{
44public:
46 void PrintSelf(ostream& os, vtkIndent indent) override;
47
49
51
54 void SetModelFile(const std::string& file);
55 vtkGetMacro(ModelFile, std::string);
57
68
71 void SetTimeStepValues(const std::vector<double>& times);
72
76 void SetTimeStepValue(vtkIdType idx, double timeStepValue);
77
83
89
91
95 vtkSetMacro(TimeStepIndex, int);
96 vtkGetMacro(TimeStepIndex, int);
98
110 void SetInputParameters(const std::vector<float>& params);
111
116 void SetInputParameter(vtkIdType idx, float InputParameter);
117
123
125
130 void SetInputShape(const std::vector<int64_t>& shape);
131 void SetInputShape(vtkIdType idx, int shapeElement);
133 const std::vector<int64_t>& GetInputShape() const;
135
141
147
149
153 vtkSetMacro(FieldArrayInput, bool);
154 vtkGetMacro(FieldArrayInput, bool);
155 vtkBooleanMacro(FieldArrayInput, bool);
157
159
162 vtkSetMacro(ProcessedFieldArrayName, const std::string&);
163 vtkGetMacro(ProcessedFieldArrayName, const std::string&);
165
167
170 vtkSetMacro(OutputDimension, int);
171 vtkGetMacro(OutputDimension, int);
173
175
179 vtkSetMacro(ArrayAssociation, int);
180 vtkGetMacro(ArrayAssociation, int);
182
183protected:
185 ~vtkONNXInference() override = default;
186
191
193
198 int ExecuteData(vtkDataObject* input, vtkDataObject* output, double timevalue);
199
200private:
201 vtkONNXInference(const vtkONNXInference&) = delete;
202 void operator=(const vtkONNXInference&) = delete;
203
208 void InitializeSession();
209
215 bool ShouldGenerateTimeSteps();
216
221 bool GenerateInputTensorFromParameters(
222 std::vector<float>& parameters, Ort::Value& inputTensor, double timeValue);
223
228 bool GenerateInputTensorFromFieldArray(
229 Ort::Value& inputTensor, vtkDataSetAttributes* inAttributes);
230
235 std::vector<Ort::Value> RunModel(Ort::Value& inputTensor);
236
237 // Input related parameters
238 std::string ModelFile;
239 std::vector<int64_t> InputShape = { 0 };
240 std::vector<float> InputParameters;
241 std::vector<double> TimeStepValues;
242 int TimeStepIndex = -1;
243 bool FieldArrayInput = false;
244 std::string ProcessedFieldArrayName;
245
246 // Output related parameters
247 int OutputDimension = 1;
248
249 int ArrayAssociation = vtkDataObject::CELL;
250
251 bool Initialized = false;
252 std::unique_ptr<vtkONNXInferenceInternals> Internals;
253};
254VTK_ABI_NAMESPACE_END
255
256#endif // vtkONNXInference_h
general representation of visualization data
represent and manipulate attribute data in a dataset
a simple class to control print indentation
Definition vtkIndent.h:108
Store zero or more vtkInformation instances.
Store vtkAlgorithm input/output information.
void ClearInputParameters()
Clear the input parameters vector.
void ClearTimeStepValues()
Clear the time step values vector.
void SetInputShape(vtkIdType idx, int shapeElement)
Set/Get the shape of the input.
int RequestInformation(vtkInformation *, vtkInformationVector **, vtkInformationVector *) override
This is required to inform the pipeline of the time steps.
~vtkONNXInference() override=default
void SetInputShape(const std::vector< int64_t > &shape)
Set/Get the shape of the input.
int RequestData(vtkInformation *, vtkInformationVector **, vtkInformationVector *) override
This is called within ProcessRequest when a request asks the algorithm to do its work.
void SetInputParameters(const std::vector< float > &params)
Input Parameters.
void SetInputParameter(vtkIdType idx, float InputParameter)
Set an input parameter at a given index.
static vtkONNXInference * New()
void SetTimeStepValues(const std::vector< double > &times)
Time Steps.
void SetTimeStepValue(vtkIdType idx, double timeStepValue)
Set a time value at a given index.
void SetModelFile(const std::string &file)
Get/Set the path to the ONNX model and load it.
void ClearInputShape()
Clear the input shape vector.
int ExecuteData(vtkDataObject *input, vtkDataObject *output, double timevalue)
Execute the inference and add the resulting array to the given data object.
void SetNumberOfInputShapeElements(vtkIdType nb)
Set the number of input shape values.
void PrintSelf(ostream &os, vtkIndent indent) override
Methods invoked by print to print information about the object including superclasses.
void SetInputShape(vtkIdType nb)
Set/Get the shape of the input.
void SetNumberOfTimeStepValues(vtkIdType nb)
Set the number of time step values.
const std::vector< int64_t > & GetInputShape() const
Set/Get the shape of the input.
VTK internal class for hiding ONNX members.
int vtkIdType
Definition vtkType.h:368