VTK  9.6.20260516
vtkONNXInternalUtils.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
2// SPDX-License-Identifier: BSD-3-Clause
9
10#ifndef vtkONNXInternalUtils_h
11#define vtkONNXInternalUtils_h
12
13#include "vtkSMPTools.h"
14
15#include <algorithm>
16#include <numeric>
17#include <onnxruntime_cxx_api.h>
18#include <vector>
19
21{
22
27inline Ort::Value RawToTensor(float* data, const std::vector<int64_t>& shape)
28{
29 int64_t numberElements = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<>());
30
31 Ort::MemoryInfo memInfo =
32 Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
33
34 return Ort::Value::CreateTensor<float>(memInfo, data, numberElements, shape.data(), shape.size());
35}
36
/**
 * This checks if a sequence actually represents a permutation.
 *
 * A sequence of size n is a permutation when it contains every integer of [0, n) exactly once,
 * i.e. when it is a rearrangement of the identity sequence 0, 1, ..., n - 1.
 *
 * @param permutation candidate sequence of axis indices
 * @return true if `permutation` is a rearrangement of 0..n-1 (an empty sequence qualifies)
 */
inline bool IsPermutation(const std::vector<int>& permutation)
{
  std::vector<int> identity(permutation.size());
  std::iota(identity.begin(), identity.end(), 0);
  // The candidate must be compared against the identity; both ranges have the same length.
  return std::is_permutation(permutation.begin(), permutation.end(), identity.begin());
}
46
/**
 * Computes the inverse of the input permutation.
 *
 * The result `inv` satisfies inv[permutation[i]] == i for every i, so composing the two
 * restores the original order.
 *
 * @param permutation a valid permutation (see IsPermutation); values outside [0, size) are
 *                    not checked and lead to out-of-bounds writes
 * @return the inverse permutation, of the same size as the input
 */
inline std::vector<int> InversePermutation(const std::vector<int>& permutation)
{
  std::vector<int> inversePermutation(permutation.size());
  for (size_t i = 0; i < permutation.size(); ++i)
  {
    // Explicit narrowing: indices of a permutation of `int` values fit in `int`.
    inversePermutation[permutation[i]] = static_cast<int>(i);
  }
  return inversePermutation;
}
60
65inline void Permute(
66 float* data, const std::vector<int64_t>& outputShape, const std::vector<int>& permutation)
67{
68 const size_t nDim = outputShape.size();
69 int64_t numElements =
70 std::accumulate(outputShape.begin(), outputShape.end(), 1LL, std::multiplies<>());
71
72 // Compute intput shape
73 std::vector<int> inversePermutation = InversePermutation(permutation);
74 std::vector<int64_t> inputShape(outputShape.size());
75 for (size_t i = 0; i < nDim; ++i)
76 {
77 inputShape[i] = outputShape[inversePermutation[i]];
78 }
79
80 // Compute input/output memory strides
81 auto computeStrides = [nDim](const std::vector<int64_t>& shape)
82 {
83 std::vector<int64_t> strides(nDim);
84 strides[nDim - 1] = 1;
85 for (int i = static_cast<int>(nDim) - 2; i >= 0; --i)
86 {
87 strides[i] = strides[i + 1] * shape[i + 1];
88 }
89 return strides;
90 };
91
92 std::vector<int64_t> inputStrides = computeStrides(inputShape);
93 std::vector<int64_t> outputStrides = computeStrides(outputShape);
94
95 // Permutation loop
96 std::vector<float> buffer(numElements);
97 vtkSMPTools::For(0, numElements,
98 [&](int64_t begin, int64_t end)
99 {
100 std::vector<int64_t> inputCoords(nDim);
101 std::vector<int64_t> outputCoords(nDim);
102
103 for (int64_t inputIndex = begin; inputIndex < end; ++inputIndex)
104 {
105 int tmpInputIndex = inputIndex;
106 // Input shape coords
107 for (size_t i = 0; i < nDim; ++i)
108 {
109 inputCoords[i] = tmpInputIndex / inputStrides[i];
110 tmpInputIndex %= inputStrides[i];
111 }
112
113 // Apply permutation
114 for (size_t i = 0; i < nDim; ++i)
115 {
116 outputCoords[i] = inputCoords[permutation[i]];
117 }
118
119 // Output shape coords
120 int outputIndex = 0;
121 for (size_t i = 0; i < nDim; ++i)
122 {
123 outputIndex += outputCoords[i] * outputStrides[i];
124 }
125
126 buffer[outputIndex] = data[inputIndex];
127 }
128 });
129
130 std::copy(buffer.begin(), buffer.end(), data);
131}
132
133} // namespace vtkONNXInternalUtils
134
135#endif
static void For(vtkIdType first, vtkIdType last, vtkIdType grain, Functor &f)
Execute a for operation in parallel.
void Permute(float *data, const std::vector< int64_t > &outputShape, const std::vector< int > &permutation)
This reorders the memory pointed to by data so that it matches the layout defined by outputShape and permutation.
bool IsPermutation(const std::vector< int > &permutation)
This checks if a sequence actually represents a permutation.
Ort::Value RawToTensor(float *data, const std::vector< int64_t > &shape)
Wraps a raw float buffer into an ONNX Runtime tensor.
std::vector< int > InversePermutation(const std::vector< int > &permutation)
Computes the inverse of the input permutation, in other words the permutation you need to apply after...