forked from onnx/onnx-tensorrt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: onnx2trt_utils.cpp
1710 lines (1570 loc) · 66.6 KB
/
onnx2trt_utils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#include "onnx2trt_utils.hpp"
#include "OnnxAttrs.hpp"
#include "ShapeTensor.hpp"
#include <set>
namespace onnx2trt
{
//! Adds an IActivationLayer of type `op` on inputs[0] (weights are first turned into a constant tensor).
//! When `alpha` / `beta` are non-null they are forwarded to setAlpha / setBeta; otherwise the layer keeps
//! its own defaults. INT32 and BOOL inputs are rejected with kUNSUPPORTED_NODE.
NodeImportResult activationHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
    std::vector<TensorOrWeights>& inputs, nvinfer1::ActivationType op, float* alpha, float* beta)
{
    nvinfer1::ITensor& input = convertToTensor(inputs.at(0), ctx);
    ASSERT(input.getType() != nvinfer1::DataType::kINT32 && input.getType() != nvinfer1::DataType::kBOOL
            && "TensorRT does not support activations on INT32 or BOOL inputs!",
        ErrorCode::kUNSUPPORTED_NODE);

    nvinfer1::IActivationLayer* activation = ctx->network()->addActivation(input, op);
    if (alpha != nullptr)
    {
        activation->setAlpha(*alpha);
    }
    if (beta != nullptr)
    {
        activation->setBeta(*beta);
    }
    nvinfer1::ITensor* output = activation->getOutput(0);
    return {{output}};
}
//! Clamps `input` element-wise to [-clip, clip] using a kCLIP activation layer.
//! If `clip` is negative (or NaN) no layer is added and `input` is returned unchanged.
nvinfer1::ITensor* addClip(IImporterContext* ctx, nvinfer1::ITensor* input, float clip)
{
    if (clip >= 0.f)
    {
        nvinfer1::IActivationLayer* layer = ctx->network()->addActivation(*input, nvinfer1::ActivationType::kCLIP);
        // kCLIP uses alpha as the lower bound and beta as the upper bound.
        layer->setAlpha(-clip);
        layer->setBeta(clip);
        return layer->getOutput(0);
    }
    return input;
}
//! Shared importer for ArgMin / ArgMax, implemented with a TopK layer (k = 1) whose index output is returned
//! as INT32.
//! - The `axis` attribute (default 0) is normalized against the rank of the ONNX input *before* any expansion,
//!   so that negative axes refer to the dimensions the model author meant.
//! - 1D inputs are temporarily unsqueezed to 2D ([N] -> [N, 1]) and squeezed back afterwards.
//! - With keepdims = 0 (default 1) the reduced axis is squeezed from the result.
//! INT32 inputs are rejected with kUNSUPPORTED_NODE.
NodeImportResult argMinMaxHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node,
    std::vector<TensorOrWeights>& inputs, nvinfer1::TopKOperation op)
{
    nvinfer1::ITensor* tensorPtr = &convertToTensor(inputs.at(0), ctx);
    ASSERT(tensorPtr->getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE);

    // Get attributes.
    OnnxAttrs attrs(node, ctx);
    const int keepdims = attrs.get("keepdims", 1);
    int axis = attrs.get("axis", 0);

    // Normalize the axis against the original (ONNX) rank. Doing this after the 1D -> 2D expansion below
    // would map axis = -1 of a 1D input onto the appended unit dimension, making TopK always return index 0.
    const int onnxRank = tensorPtr->getDimensions().nbDims;
    TRT_CHECK(convertAxis(axis, onnxRank));

    // Support 1D argMin/argMax: expand dims from 1D to 2D by appending a unit dimension.
    const bool needToExpandDims = (onnxRank == 1);
    if (needToExpandDims)
    {
        std::vector<int> axes{1};
        tensorPtr = unsqueezeTensor(ctx, *tensorPtr, axes);
        ASSERT(tensorPtr, ErrorCode::kUNSUPPORTED_NODE);
    }

    // Insert a TopK layer with k set to 1. The appended dimension sits after `axis`, so `axis` is still valid.
    const uint32_t axisMask = 1U << axis;
    nvinfer1::ITopKLayer* layer = ctx->network()->addTopK(*tensorPtr, op, 1, axisMask);
    ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);

    // We don't care about the TopK values, just the indices.
    nvinfer1::ITensor* indices = layer->getOutput(1);
    indices->setType(nvinfer1::DataType::kINT32);

    // Squeeze back to 1D if applicable.
    if (needToExpandDims)
    {
        std::vector<int> axes{1};
        indices = squeezeTensor(ctx, *indices, axes);
        ASSERT(indices, ErrorCode::kUNSUPPORTED_NODE);
    }

    // The default behavior of the TopK layer is to keepdims; otherwise squeeze the reduced axis.
    if (!keepdims)
    {
        std::vector<int> axes{axis};
        indices = squeezeTensor(ctx, *indices, axes);
        ASSERT(indices, ErrorCode::kUNSUPPORTED_NODE);
    }
    return {{indices}};
}
//! If t has rank less than nbDims, reshape it to have nbDims by prepending ones to its dimensions
//! (t is replaced by the reshaped tensor).
//! Returns kUNSUPPORTED_NODE for models below opset 7, or if t has rank greater than nbDims.
static Status broadcastTensor(IImporterContext* ctx, nvinfer1::ITensor*& t, const int nbDims)
{
    ASSERT(ctx->getOpsetVersion() >= 7 && "Pre-opset 7 broadcasting is unsupported in this version of the ONNX parser", ErrorCode::kUNSUPPORTED_NODE);
    const nvinfer1::Dims inputDims = t->getDimensions();
    const int nbInputDims = inputDims.nbDims;
    // Checked at runtime: a plain assert() compiles away under NDEBUG and would let callers continue with a
    // tensor of the wrong rank.
    ASSERT(nbInputDims <= nbDims && "Cannot broadcast a tensor to a lower rank", ErrorCode::kUNSUPPORTED_NODE);
    if (nbInputDims < nbDims)
    {
        // New shape = [1, ..., 1 (nbDims - nbInputDims times), original dims...], computed from the runtime shape.
        nvinfer1::IShuffleLayer* reshape = addShuffle(ctx, *t, concat(ctx, fillShapeVector(nbDims - nbInputDims, 1), shapeOf(ctx, *t)));
        ASSERT(reshape, ErrorCode::kUNSUPPORTED_NODE);
        t = reshape->getOutput(0);
    }
    return Status::success();
}
//! Makes t1 and t2 the same rank by broadcasting the lower-rank one (prepending unit dimensions).
//! Tensors that already share a rank are left untouched and success is returned without further checks.
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2)
{
    const int rank1 = t1->getDimensions().nbDims;
    const int rank2 = t2->getDimensions().nbDims;
    if (rank1 < rank2)
    {
        return broadcastTensor(ctx, t1, rank2);
    }
    if (rank2 < rank1)
    {
        return broadcastTensor(ctx, t2, rank1);
    }
    return Status::success();
}
//! Broadcasts t1, t2 and t3 (in that order) to the highest rank among them, prepending unit dimensions.
//! Stops at and returns the first failing status.
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2, nvinfer1::ITensor*& t3)
{
    const int targetRank = std::max(
        t1->getDimensions().nbDims, std::max(t2->getDimensions().nbDims, t3->getDimensions().nbDims));
    for (nvinfer1::ITensor** slot : {&t1, &t2, &t3})
    {
        TRT_CHECK(broadcastTensor(ctx, *slot, targetRank));
    }
    return Status::success();
}
//! Returns true if TensorRT's linear resize can handle these scale factors.
//! Linear resize supports resampling at most the last three dimensions, so every scale factor before
//! the last three must be exactly 1. With three or fewer scales the answer is always true.
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors)
{
    if (scaleSize <= 3)
    {
        return true;
    }
    const float* const leadingEnd = scaleFactors + (scaleSize - 3);
    return std::all_of(scaleFactors, leadingEnd, [](float scale) { return scale == 1; });
}
//! Broadcasts `constant` to the runtime shape held in the 1D shape tensor `shape`.
//! The constant is reshaped to [1, ..., 1] (one entry per element of `shape`, so it presumably must hold a
//! single value). It is then sliced with all-zero start, size = `shape` and all-zero strides, so every
//! output element reads that same value.
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape)
{
    // Output rank = number of elements in the 1D shape tensor.
    const int rank = shape->getDimensions().d[0];

    // Slice will not work if constant does not have the same rank as start/size/strides,
    // so first reshape it to all-ones dimensions of that rank.
    nvinfer1::Dims onesShape{rank};
    std::fill(onesShape.d, onesShape.d + rank, 1);
    nvinfer1::IShuffleLayer* unsqueeze = ctx->network()->addShuffle(*constant);
    unsqueeze->setReshapeDimensions(onesShape);
    nvinfer1::ITensor* unsqueezed = unsqueeze->getOutput(0);

    // Zero strides make every output coordinate map back to the single input element.
    nvinfer1::Dims zeroStrides{rank};
    std::fill(zeroStrides.d, zeroStrides.d + rank, 0);
    nvinfer1::ISliceLayer* broadcast
        = ctx->network()->addSlice(*unsqueezed, nvinfer1::Dims{}, nvinfer1::Dims{}, zeroStrides);

    // Dynamic inputs: 1 = start (INT32 zeros of shape [rank]), 2 = size (the requested shape).
    std::vector<int> zeroStarts(rank, 0);
    nvinfer1::ITensor* startTensor
        = addConstant(ctx, zeroStarts, ::ONNX_NAMESPACE::TensorProto_DataType_INT32, nvinfer1::Dims{1, rank})->getOutput(0);
    broadcast->setInput(1, *startTensor);
    broadcast->setInput(2, *shape);
    return broadcast->getOutput(0);
}
//! Normalizes `axis` in place: a negative value counts from the end (axis + nbDims).
//! Returns kUNSUPPORTED_NODE unless the normalized axis lies in [0, nbDims).
Status convertAxis(int& axis, int nbDims)
{
    axis = (axis < 0) ? axis + nbDims : axis;
    ASSERT(axis >= 0 && axis < nbDims, ErrorCode::kUNSUPPORTED_NODE);
    return Status::success();
}
bool convertDtype(int32_t onnx_dtype, nvinfer1::DataType* trt_dtype)
{
switch (onnx_dtype)
{
case ::ONNX_NAMESPACE::TensorProto::FLOAT: *trt_dtype = nvinfer1::DataType::kFLOAT; break;
case ::ONNX_NAMESPACE::TensorProto::INT8: *trt_dtype = nvinfer1::DataType::kINT8; break;
case ::ONNX_NAMESPACE::TensorProto::FLOAT16: *trt_dtype = nvinfer1::DataType::kHALF; break;
case ::ONNX_NAMESPACE::TensorProto::BOOL: *trt_dtype = nvinfer1::DataType::kBOOL; break;
case ::ONNX_NAMESPACE::TensorProto::INT32:
*trt_dtype = nvinfer1::DataType::kINT32;
break;
// See convertOnnxWeights for sanity check if all values can be safetly downcasted to INT32
case ::ONNX_NAMESPACE::TensorProto::INT64: *trt_dtype = nvinfer1::DataType::kINT32; break;
default:
std::cerr << "Unsupported ONNX data type: " << getDtypeName(onnx_dtype) << " (" << std::to_string(onnx_dtype)
<< ")" << std::endl;
return false;
}
return true;
}
//! Narrows `volume(shape)` INT64 values into a new INT32 temporary-weights buffer allocated through ctx,
//! clamping anything outside [INT32_MIN, INT32_MAX]. Logs a downcast warning once per process (static
//! flag, not synchronized), a verbose message per clamped value, and a warning if any value was clamped.
//! Returns the INT32 buffer.
int32_t* convertINT64(const int64_t* weightValues, nvinfer1::Dims shape, IImporterContext* ctx)
{
    static bool logged = false;
    if (!logged)
    {
        LOG_WARNING(
            "Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. "
            "Attempting to cast down to INT32.");
        logged = true;
    }

    const size_t nbWeights = volume(shape);
    auto* const int32Weights
        = reinterpret_cast<int32_t*>(ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, shape).values);
    const int64_t lowest = static_cast<int64_t>(INT32_MIN);
    const int64_t highest = static_cast<int64_t>(INT32_MAX);

    bool anyClamped = false;
    for (size_t idx = 0; idx < nbWeights; ++idx)
    {
        const int64_t value = weightValues[idx];
        // For in-range values the clamp is a no-op, so one expression covers both cases.
        int32Weights[idx] = static_cast<int32_t>(std::max(std::min(value, highest), lowest));
        const bool inRange = (value >= lowest) && (value <= highest);
        if (!inRange)
        {
            LOG_VERBOSE("Weight at index " << idx << ": " << value
                                           << " is out of range. Clamping to: " << int32Weights[idx]);
            anyClamped = true;
        }
    }

    if (anyClamped)
    {
        LOG_WARNING("One or more weights outside the range of INT32 was clamped");
    }
    return int32Weights;
}
//! Copies `volume(shape)` values from an ONNX int32_data field into a new temporary-weights buffer of ONNX type
//! `onnxdtype` (allocated through ctx), converting each element with static_cast<DataType>.
//! convertOnnxWeights uses it for FLOAT16 (uint16_t bit patterns), INT8 (int8_t) and BOOL (uint8_t).
template <typename DataType>
DataType* convertINT32Data(const int32_t* weightValues, nvinfer1::Dims shape, int32_t onnxdtype, IImporterContext* ctx)
{
    const size_t count = volume(shape);
    auto* const converted = reinterpret_cast<DataType*>(ctx->createTempWeights(onnxdtype, shape).values);
    const int32_t* source = weightValues;
    for (DataType* dest = converted; dest != converted + count; ++dest, ++source)
    {
        *dest = static_cast<DataType>(*source);
    }
    return converted;
}
//! Converts an ONNX "pads" list laid out as [begin_0, ..., begin_{n-1}, end_0, ..., end_{n-1}] into 2D
//! begin/end paddings covering the last two dimensions.
//! Returns false (outputs untouched or partially written) if any padding outside the last two dimensions is
//! non-zero, or if the list cannot describe 2D padding (fewer than 4 entries or an odd length).
bool convertOnnxPadding(const std::vector<int64_t>& onnxPadding, nvinfer1::Dims2* begPadding, nvinfer1::Dims2* endPadding)
{
    const size_t size = onnxPadding.size();
    // Guard the unsigned arithmetic below: with size < 4, `half - 2` and `size - 2` wrap around and the
    // loops / indexing would read far out of bounds.
    if (size < 4 || size % 2 != 0)
    {
        return false;
    }
    const size_t half = size / 2;

    // Leading begin paddings must be zero.
    for (size_t i = 0; i < half - 2; i++)
    {
        if (onnxPadding[i] != 0)
        {
            return false;
        }
    }
    begPadding->d[0] = onnxPadding[half - 2];
    begPadding->d[1] = onnxPadding[half - 1];

    // Leading end paddings must be zero.
    for (size_t i = half; i < size - 2; i++)
    {
        if (onnxPadding[i] != 0)
        {
            return false;
        }
    }
    endPadding->d[0] = onnxPadding[size - 2];
    endPadding->d[1] = onnxPadding[size - 1];
    return true;
}
//! Converts an ONNX initializer (TensorProto) into ShapedWeights stored in *weights.
//! - Empty tensors become empty FLOAT weights.
//! - INT64 data is narrowed to INT32 (convertINT64).
//! - FLOAT16 / INT8 / BOOL values found in int32_data are repacked to their native width (convertINT32Data).
//! - INT32 / FLOAT / raw_data buffers are referenced in place (not copied), so they must outlive the weights.
//! Returns false (with an error log) for unsupported data types, ranks above nvinfer1::Dims::MAX_DIMS, data
//! buffers too small for the declared shape, or a final size mismatch.
bool convertOnnxWeights(
    const ::ONNX_NAMESPACE::TensorProto& onnxTensor, onnx2trt::ShapedWeights* weights, IImporterContext* ctx)
{
    // Pass through for optional (empty) initializers for unused attributes.
    if (isOnnxTensorEmpty(onnxTensor))
    {
        auto empty = onnx2trt::ShapedWeights::empty(::ONNX_NAMESPACE::TensorProto::FLOAT);
        *weights = empty;
        return true;
    }

    // nvinfer1::Dims has a fixed-capacity d[] array: copying more dims than it holds would write past its end.
    const int maxDims = nvinfer1::Dims::MAX_DIMS;
    if (onnxTensor.dims().size() > maxDims)
    {
        LOG_ERROR("Initializer " << onnxTensor.name() << " has " << onnxTensor.dims().size()
                                 << " dimensions, but at most " << maxDims << " are supported");
        return false;
    }
    nvinfer1::Dims shape;
    shape.nbDims = onnxTensor.dims().size();
    std::copy(onnxTensor.dims().begin(), onnxTensor.dims().end(), shape.d);

    // convertINT64 / convertINT32Data read volume(shape) elements from their source unconditionally, so a
    // malformed initializer whose data is shorter than its declared shape must be rejected *before* converting.
    // The byte-size sanity check at the end only runs after the conversion has already read the buffer.
    const auto nbElements = volume(shape);
    const auto coversShape = [&](size_t available) {
        if (nbElements >= 0 && static_cast<size_t>(nbElements) <= available)
        {
            return true;
        }
        LOG_ERROR("Initializer " << onnxTensor.name() << " provides " << available
                                 << " elements, but its shape requires " << nbElements);
        return false;
    };

    auto onnxDtype = onnxTensor.data_type();
    void* dataPtr{nullptr}; // TODO: See if can make const*
    size_t nbytes{0};
    if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT64)
    {
        if (onnxTensor.raw_data().size() > 0)
        {
            if (!coversShape(onnxTensor.raw_data().size() / sizeof(int64_t)))
            {
                return false;
            }
            dataPtr = convertINT64(reinterpret_cast<const int64_t*>(onnxTensor.raw_data().data()), shape, ctx);
            // Narrowing INT64 -> INT32 halves the byte count.
            nbytes = onnxTensor.raw_data().size() / 2;
        }
        else if (onnxTensor.int64_data().size() > 0)
        {
            if (!coversShape(static_cast<size_t>(onnxTensor.int64_data().size())))
            {
                return false;
            }
            dataPtr = convertINT64(onnxTensor.int64_data().data(), shape, ctx);
            nbytes = onnxTensor.int64_data().size() * sizeof(int32_t);
        }
        onnxDtype = ::ONNX_NAMESPACE::TensorProto::INT32;
    }
    // Check for supported types that can be found in the int32_data field in the TensorProto
    // https://github.com/onnx/onnx/blob/master/onnx/onnx.proto#L382-L387
    else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT32 || onnxDtype == ::ONNX_NAMESPACE::TensorProto::FLOAT16
        || onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT8 || onnxDtype == ::ONNX_NAMESPACE::TensorProto::BOOL)
    {
        if (onnxTensor.raw_data().size() > 0)
        {
            dataPtr = const_cast<char*>(onnxTensor.raw_data().data());
            nbytes = onnxTensor.raw_data().size();
        }
        else
        {
            // INT32 data is referenced in place; the other types are converted element by element and therefore
            // must provide at least volume(shape) entries.
            if (onnxDtype != ::ONNX_NAMESPACE::TensorProto::INT32
                && !coversShape(static_cast<size_t>(onnxTensor.int32_data().size())))
            {
                return false;
            }
            switch (onnxDtype)
            {
            case ::ONNX_NAMESPACE::TensorProto::INT32:
                dataPtr = const_cast<int32_t*>(onnxTensor.int32_data().data());
                break;
            // According to the ONNX proto spec, fp16 values are bit-wise converted to uint16_t when serialized into the protobuf.
            case ::ONNX_NAMESPACE::TensorProto::FLOAT16:
                dataPtr = convertINT32Data<uint16_t>(onnxTensor.int32_data().data(), shape, onnxDtype, ctx);
                break;
            case ::ONNX_NAMESPACE::TensorProto::INT8:
                dataPtr = convertINT32Data<int8_t>(onnxTensor.int32_data().data(), shape, onnxDtype, ctx);
                break;
            case ::ONNX_NAMESPACE::TensorProto::BOOL:
                dataPtr = convertINT32Data<uint8_t>(onnxTensor.int32_data().data(), shape, onnxDtype, ctx);
                break;
            }
            nbytes = onnxTensor.int32_data().size() * getDtypeSize(onnxDtype);
        }
    }
    else if (onnxDtype == ::ONNX_NAMESPACE::TensorProto::FLOAT)
    {
        if (onnxTensor.raw_data().size() > 0)
        {
            dataPtr = const_cast<char*>(onnxTensor.raw_data().data());
            nbytes = onnxTensor.raw_data().size();
        }
        else
        {
            dataPtr = const_cast<float*>(onnxTensor.float_data().data());
            nbytes = onnxTensor.float_data().size() * sizeof(float);
        }
    }
    else
    {
        LOG_ERROR("Found unsupported datatype (" << onnxDtype << ") when importing initializer: " << onnxTensor.name());
        return false;
    }

    onnx2trt::ShapedWeights trt_weights(onnxDtype, dataPtr, shape);
    // Sanity check that weights were converted properly
    if (trt_weights.size_bytes() != nbytes)
    {
        LOG_ERROR("Size mismatch when importing initializer: " << onnxTensor.name() << ". Expected size: " << nbytes << " , actual size: " << trt_weights.size_bytes());
        return false;
    }
    *weights = trt_weights;
    return true;
}
//! Reshapes a single-element tensor to a 0-D (scalar) tensor.
//! A tensor that is already 0-D is returned as-is; a tensor whose volume is not 1 yields nullptr (with a
//! verbose log).
nvinfer1::ITensor* convertToScalar(IImporterContext* ctx, nvinfer1::ITensor* inpTensor)
{
    const nvinfer1::Dims dims = inpTensor->getDimensions();
    if (dims.nbDims == 0)
    {
        return inpTensor;
    }
    const auto tensorVolume = volume(dims);
    if (tensorVolume == 1)
    {
        nvinfer1::IShuffleLayer* toScalar = ctx->network()->addShuffle(*inpTensor);
        toScalar->setReshapeDimensions(nvinfer1::Dims{0});
        return toScalar->getOutput(0);
    }
    LOG_VERBOSE("Cannot convert tensor to scalar. Note: Tensor dimensions were: "
        << dims << ", with volume: " << tensorVolume);
    return nullptr;
}
//! Returns `input` as an ITensor. Tensors are returned as-is; weights are added to the network as a constant
//! layer. BOOL weights are first widened to an INT32 constant and then cast back to kBOOL with an identity
//! layer, since TensorRT does not take boolean weights directly.
nvinfer1::ITensor& convertToTensor(TensorOrWeights& input, IImporterContext* ctx)
{
    if (input.is_tensor())
    {
        return input.tensor();
    }
    // Handle non-tensor indices input by adding a new constant layer to the network.
    ShapedWeights& weights = input.weights();
    if (weights.type != ::ONNX_NAMESPACE::TensorProto::BOOL)
    {
        return *(ctx->network()->addConstant(weights.shape, weights)->getOutput(0));
    }

    // TRT doesn't natively handle boolean weights: create an INT32 copy, then cast it back to bool within TRT.
    // BOOL weights hold one byte per element (getDtypeSize(BOOL) == 1; int32_data is repacked through
    // convertINT32Data<uint8_t>), so each element must be widened individually. Copying
    // count * sizeof(int) bytes would read past the end of the source buffer.
    ShapedWeights convertedWeights = ctx->createTempWeights(::ONNX_NAMESPACE::TensorProto::INT32, weights.shape);
    const auto* boolValues = static_cast<const uint8_t*>(weights.values);
    auto* intValues = static_cast<int32_t*>(convertedWeights.values);
    const size_t count = weights.count();
    for (size_t i = 0; i < count; ++i)
    {
        intValues[i] = (boolValues[i] != 0) ? 1 : 0;
    }
    auto* boolTensor = ctx->network()->addConstant(convertedWeights.shape, convertedWeights)->getOutput(0);
    auto* castLayer = ctx->network()->addIdentity(*boolTensor);
    castLayer->setOutputType(0, nvinfer1::DataType::kBOOL);
    return *(castLayer->getOutput(0));
}
//! Converts a single-element tensor or weights object into a 0-D tensor.
//! Tensors are delegated to convertToScalar(ctx, tensor). Weights become a 0-D constant layer; weights whose
//! volume is not 1 yield nullptr (with a verbose log).
nvinfer1::ITensor* convertToScalar(TensorOrWeights& input, IImporterContext* ctx)
{
    if (input.is_tensor())
    {
        return convertToScalar(ctx, &input.tensor());
    }
    ShapedWeights& weights = input.weights();
    const auto weightsVolume = volume(weights.shape);
    if (weightsVolume == 1)
    {
        return ctx->network()->addConstant(nvinfer1::Dims{0, {0}}, weights)->getOutput(0);
    }
    LOG_VERBOSE("Cannot convert weights to scalar. Note: Tensor dimensions were: "
        << weights.shape << ", with volume: " << weightsVolume);
    return nullptr;
}
//! Converts an ONNXIFI tensor descriptor into ShapedWeights stored in *weights.
//! FLOAT32 / FLOAT16 / INT32 buffers are referenced in place; INT64 buffers are narrowed to INT32
//! (convertINT64). A 0-D descriptor is represented as shape [1].
//! Returns false for unsupported data types, ranks above nvinfer1::Dims::MAX_DIMS, or an inconsistent size.
bool convertWeightDescriptor(
    onnxTensorDescriptorV1 const& desc, onnx2trt::ShapedWeights* weights, IImporterContext* ctx)
{
    // nvinfer1::Dims holds at most MAX_DIMS dimensions; copying more would write past the end of shape.d.
    const int maxDims = nvinfer1::Dims::MAX_DIMS;
    if (desc.dimensions > static_cast<uint32_t>(maxDims))
    {
        return false;
    }

    nvinfer1::Dims shape;
    shape.nbDims = desc.dimensions;
    // Special case for scalars
    if (shape.nbDims == 0)
    {
        shape.nbDims = 1;
        shape.d[0] = 1;
    }
    else
    {
        std::copy(desc.shape, desc.shape + desc.dimensions, shape.d);
    }

    size_t element_count = 1;
    for (int i = 0; i < shape.nbDims; ++i)
    {
        element_count *= shape.d[i];
    }

    void* dataPtr = (void*) (desc.buffer);
    size_t nbytes{0};
    int32_t dtype{0};
    if (desc.dataType == ONNXIFI_DATATYPE_FLOAT32)
    {
        dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT;
        nbytes = element_count * sizeof(float);
    }
    else if (desc.dataType == ONNXIFI_DATATYPE_FLOAT16)
    {
        dtype = ::ONNX_NAMESPACE::TensorProto::FLOAT16;
        // Half precision: two bytes per element.
        nbytes = element_count * sizeof(uint16_t);
    }
    else if (desc.dataType == ONNXIFI_DATATYPE_INT32)
    {
        dtype = ::ONNX_NAMESPACE::TensorProto::INT32;
        nbytes = element_count * sizeof(int32_t);
    }
    else if (desc.dataType == ONNXIFI_DATATYPE_INT64)
    {
        dataPtr = convertINT64(reinterpret_cast<const int64_t*>(desc.buffer), shape, ctx);
        dtype = ::ONNX_NAMESPACE::TensorProto::INT32;
        nbytes = element_count * sizeof(int32_t);
    }
    else
    {
        // Unsupported format
        return false;
    }

    onnx2trt::ShapedWeights trt_weights(dtype, dataPtr, shape);
    // Checked at runtime instead of with assert() so release builds also reject an inconsistent conversion.
    if (trt_weights.size_bytes() != nbytes)
    {
        return false;
    }
    *weights = trt_weights;
    return true;
}
//! Returns ceil(n / d) for any sign of n and non-zero d, without the overflow risk of (n + d - 1) / d.
//! Results for n >= 1, d >= 1 are unchanged from the former (n - 1) / d + 1 formula, which returned 1 for
//! n == 0 (and was wrong for most n <= 0). As with built-in division, d == 0 or INT_MIN / -1 is undefined.
int divCeil(int n, int d)
{
    const int quotient = n / d; // truncates toward zero
    // Truncation already rounds a negative inexact quotient up; only a positive inexact quotient
    // (n and d of the same sign, non-zero remainder) needs +1.
    const bool roundUp = (n % d != 0) && ((n > 0) == (d > 0));
    return quotient + (roundUp ? 1 : 0);
}
//! Returns whether the input types are acceptable for the elementwise operation `op`:
//! - AND / OR / XOR require every input to be BOOL.
//! - Arithmetic and comparison ops reject BOOL inputs.
//! - POW rejects BOOL and INT32 inputs.
//! Operations not listed are always accepted.
bool elementwiseCheck(const std::vector<TensorOrWeights>& inputs, const nvinfer1::ElementWiseOperation op)
{
    const auto isBoolInput = [](const TensorOrWeights& input) { return input.isBool(); };
    const auto isBoolOrInt32Input = [](const TensorOrWeights& input) { return input.isBool() || input.isInt32(); };
    switch (op)
    {
    // These operations only support boolean inputs
    case nvinfer1::ElementWiseOperation::kAND:
    case nvinfer1::ElementWiseOperation::kOR:
    case nvinfer1::ElementWiseOperation::kXOR:
        return std::all_of(inputs.begin(), inputs.end(), isBoolInput);
    // These operations do not support boolean types
    case nvinfer1::ElementWiseOperation::kDIV:
    case nvinfer1::ElementWiseOperation::kEQUAL:
    case nvinfer1::ElementWiseOperation::kFLOOR_DIV:
    case nvinfer1::ElementWiseOperation::kGREATER:
    case nvinfer1::ElementWiseOperation::kLESS:
    case nvinfer1::ElementWiseOperation::kMAX:
    case nvinfer1::ElementWiseOperation::kMIN:
    case nvinfer1::ElementWiseOperation::kPROD:
    case nvinfer1::ElementWiseOperation::kSUB:
    case nvinfer1::ElementWiseOperation::kSUM:
        return std::none_of(inputs.begin(), inputs.end(), isBoolInput);
    // Pow does not support bool or INT32 types
    case nvinfer1::ElementWiseOperation::kPOW:
        return std::none_of(inputs.begin(), inputs.end(), isBoolOrInt32Input);
    }
    return true;
}
//! Applies `binary_op` left to right across all inputs ((in0 op in1) op in2 ...), after broadcasting every input
//! to the highest input rank by prepending unit dimensions. A single input is passed through an identity layer.
//! Fails with kINVALID_NODE for no inputs, and with kUNSUPPORTED_NODE when elementwiseCheck rejects the input
//! types or a broadcast/layer creation fails.
NodeImportResult elementwiseHelper(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node,
    std::vector<TensorOrWeights>& inputs, nvinfer1::ElementWiseOperation binary_op)
{
    ASSERT(!inputs.empty(), ErrorCode::kINVALID_NODE);
    ASSERT(elementwiseCheck(inputs, binary_op), ErrorCode::kUNSUPPORTED_NODE);

    // Iterate by reference: copying each TensorOrWeights is unnecessary.
    int maxNbDims = -1;
    for (auto& input : inputs)
    {
        maxNbDims = std::max(maxNbDims, input.shape().nbDims);
    }

    std::vector<nvinfer1::ITensor*> inputTensors;
    inputTensors.reserve(inputs.size());
    for (auto& input : inputs)
    {
        auto* tensor_ptr = &convertToTensor(input, ctx);
        // Broadcast all input tensors to size of maxNbDims. broadcastTensor is only called when a rank change is
        // needed, so its opset check does not reject same-rank inputs, and its Status is propagated instead of
        // being discarded.
        if (tensor_ptr->getDimensions().nbDims < maxNbDims)
        {
            TRT_CHECK(broadcastTensor(ctx, tensor_ptr, maxNbDims));
        }
        ASSERT(tensor_ptr->getDimensions().nbDims == maxNbDims && "Failed to broadcast tensors elementwise!",
            ErrorCode::kUNSUPPORTED_NODE);
        inputTensors.push_back(tensor_ptr);
    }

    // Use the first tensor input as the base for the elementwise operation
    nvinfer1::ITensor* combined = inputTensors.at(0);
    if (inputTensors.size() == 1)
    {
        // Note: Single input must be wrapped in identity to avoid messing up network outputs
        return {{identity(ctx, combined)}};
    }
    for (size_t i = 1; i < inputTensors.size(); ++i)
    {
        nvinfer1::ITensor* tensor = inputTensors.at(i);
        ASSERT(tensor->getDimensions().nbDims == combined->getDimensions().nbDims, ErrorCode::kUNSUPPORTED_NODE);
        auto* layer = ctx->network()->addElementWise(*combined, *tensor, binary_op);
        ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
        combined = layer->getOutput(0);
    }
    return {{combined}};
}
//! Reshapes `tensor` to 2D: [product of dims before `axis`, product of dims from `axis` on].
//! Both extents are computed from shapeOf(tensor) with ShapeTensor arithmetic and applied through a shuffle layer.
nvinfer1::ITensor* flattenTensor(IImporterContext* ctx, nvinfer1::ITensor& tensor, int axis)
{
    const ShapeTensor inputShape = shapeOf(ctx, tensor);
    const ShapeTensor outerExtent = product(ctx, inputShape, 0, axis, 1);
    const ShapeTensor innerExtent = product(ctx, inputShape, axis, inputShape.size, 1);
    const ShapeTensor flatShape = concat(ctx, outerExtent, innerExtent);
    return addShuffle(ctx, tensor, flatShape)->getOutput(0);
}
//! Gathers entry `dim` of `shapeTensor` along axis 0.
//! The index is an INT32 constant created by addConstantScalar with dimensions `shape`.
nvinfer1::ITensor* gatherDimension(IImporterContext* ctx, nvinfer1::ITensor* shapeTensor, int dim, nvinfer1::Dims shape)
{
    nvinfer1::ITensor* indexTensor
        = addConstantScalar(ctx, dim, ::ONNX_NAMESPACE::TensorProto_DataType_INT32, shape)->getOutput(0);
    auto* gatherLayer = ctx->network()->addGather(*shapeTensor, *indexTensor, 0);
    return gatherLayer->getOutput(0);
}
// Helper function to generate padding values for convTranspose.
// For each of the nbSpatialDims spatial dimensions (input_dims is read from index 2 onward):
// - kEXPLICIT_ROUND_DOWN (auto_pad NOTSET / VALID): input padding is explicit, so output_padding is set to
//   whatever is needed to reach output_shape.
// - any other mode (treated as SAME_UPPER / SAME_LOWER): output padding is explicit, so the total input padding
//   is derived and split between beg_padding and end_padding. SAME_UPPER puts the larger share at the
//   beginning; otherwise it goes at the end.
void generatePadding(nvinfer1::Dims input_dims, nvinfer1::Dims output_shape, nvinfer1::Dims kernel_size,
    nvinfer1::Dims strides, nvinfer1::Dims dilations, const int nbSpatialDims, nvinfer1::Dims& beg_padding,
    nvinfer1::Dims& end_padding, nvinfer1::Dims& output_padding, nvinfer1::PaddingMode paddingMode)
{
    const bool explicitInputPadding = (paddingMode == nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
    for (int i = 0; i < nbSpatialDims; ++i)
    {
        // Transposed-convolution output extent before input or output padding is applied.
        const auto unpaddedOutput
            = (input_dims.d[2 + i] - 1) * strides.d[i] + (kernel_size.d[i] - 1) * dilations.d[i] + 1;
        if (explicitInputPadding)
        {
            const auto expectedOutput = unpaddedOutput - beg_padding.d[i] - end_padding.d[i];
            output_padding.d[i] = output_shape.d[i] - expectedOutput;
            continue;
        }
        const auto totalPadding = unpaddedOutput + output_padding.d[i] - output_shape.d[i];
        const auto half = totalPadding / 2;
        if (paddingMode == nvinfer1::PaddingMode::kSAME_UPPER)
        {
            beg_padding.d[i] = totalPadding - half;
            end_padding.d[i] = half;
        }
        else
        {
            beg_padding.d[i] = half;
            end_padding.d[i] = totalPadding - half;
        }
    }
}
//! Default `alpha` value for each TensorRT activation type; throws std::runtime_error for an unknown type.
float getActivationDefaultAlpha(nvinfer1::ActivationType type)
{
    switch (type)
    {
    case nvinfer1::ActivationType::kRELU:
    case nvinfer1::ActivationType::kSIGMOID:
    case nvinfer1::ActivationType::kTANH:
    case nvinfer1::ActivationType::kSOFTSIGN:
    case nvinfer1::ActivationType::kSOFTPLUS:
    case nvinfer1::ActivationType::kCLIP:
        return 0.f;
    case nvinfer1::ActivationType::kELU:
    case nvinfer1::ActivationType::kSCALED_TANH:
    case nvinfer1::ActivationType::kTHRESHOLDED_RELU:
        return 1.0f;
    case nvinfer1::ActivationType::kLEAKY_RELU: return 0.01f;
    case nvinfer1::ActivationType::kSELU: return 1.67326319217681884765625f;
    case nvinfer1::ActivationType::kHARD_SIGMOID: return 0.2f;
    }
    throw std::runtime_error{"Unrecognized activation type"};
}
//! Default `beta` value for each TensorRT activation type; throws std::runtime_error for an unknown type.
float getActivationDefaultBeta(nvinfer1::ActivationType type)
{
    switch (type)
    {
    case nvinfer1::ActivationType::kRELU:
    case nvinfer1::ActivationType::kSIGMOID:
    case nvinfer1::ActivationType::kTANH:
    case nvinfer1::ActivationType::kLEAKY_RELU:
    case nvinfer1::ActivationType::kELU:
    case nvinfer1::ActivationType::kSOFTSIGN:
    case nvinfer1::ActivationType::kSOFTPLUS:
    case nvinfer1::ActivationType::kCLIP:
    case nvinfer1::ActivationType::kTHRESHOLDED_RELU:
        return 0.f;
    case nvinfer1::ActivationType::kSELU: return 1.05070102214813232421875f;
    case nvinfer1::ActivationType::kHARD_SIGMOID: return 0.5f;
    case nvinfer1::ActivationType::kSCALED_TANH: return 1.0f;
    }
    throw std::runtime_error{"Unrecognized activation type"};
}
//! Returns the extent of dimension `axis` of `inpTensor` as an INT32 tensor with dimensions `shape`.
//! Static extents (>= 0) become a constant; dynamic extents (< 0) are gathered from an IShapeLayer output.
//! NOTE(review): `axis` is not bounds-checked here; callers presumably pass an already-normalized axis.
nvinfer1::ITensor* getAxisLength(IImporterContext* ctx, nvinfer1::ITensor* inpTensor, int axis, nvinfer1::Dims shape)
{
    const int d = inpTensor->getDimensions().d[axis];
    if (d < 0)
    {
        // Dynamic dimension: read it from the runtime shape.
        nvinfer1::ITensor* inpShape = ctx->network()->addShape(*inpTensor)->getOutput(0);
        return gatherDimension(ctx, inpShape, axis, shape);
    }
    // Fast path for static dims.
    return addConstantScalar(ctx, d, ::ONNX_NAMESPACE::TensorProto_DataType_INT32, shape)->getOutput(0);
}
// Forward declaration only; the definition is not visible in this part of the file.
// NOTE(review): the parameter names suggest the usual convolution output-extent formula — confirm against the definition.
int getConvOutputSize(int input_size, int filter_size, int stride, int dilation_rate, int total_padding);
//! Returns a printable name for an ONNX TensorProto data type, or "<UNKNOWN>" for values not listed here.
//! The result points to a string literal and must not be freed.
const char* getDtypeName(int32_t onnxDtype)
{
    const char* name = "<UNKNOWN>";
    switch (onnxDtype)
    {
    case ::ONNX_NAMESPACE::TensorProto::FLOAT: name = "FLOAT"; break;
    case ::ONNX_NAMESPACE::TensorProto::UINT8: name = "UINT8"; break;
    case ::ONNX_NAMESPACE::TensorProto::INT8: name = "INT8"; break;
    case ::ONNX_NAMESPACE::TensorProto::UINT16: name = "UINT16"; break;
    case ::ONNX_NAMESPACE::TensorProto::INT16: name = "INT16"; break;
    case ::ONNX_NAMESPACE::TensorProto::INT32: name = "INT32"; break;
    case ::ONNX_NAMESPACE::TensorProto::INT64: name = "INT64"; break;
    case ::ONNX_NAMESPACE::TensorProto::STRING: name = "STRING"; break;
    case ::ONNX_NAMESPACE::TensorProto::BOOL: name = "BOOL"; break;
    case ::ONNX_NAMESPACE::TensorProto::FLOAT16: name = "FLOAT16"; break;
    case ::ONNX_NAMESPACE::TensorProto::DOUBLE: name = "DOUBLE"; break;
    case ::ONNX_NAMESPACE::TensorProto::UINT32: name = "UINT32"; break;
    case ::ONNX_NAMESPACE::TensorProto::UINT64: name = "UINT64"; break;
    case ::ONNX_NAMESPACE::TensorProto::COMPLEX64: name = "COMPLEX64"; break;
    case ::ONNX_NAMESPACE::TensorProto::COMPLEX128: name = "COMPLEX128"; break;
    default: break;
    }
    return name;
}
//! Returns the size in bytes of one element of the given ONNX data type as it is laid out in raw_data /
//! converted weight buffers, or -1 for types not listed (e.g. STRING).
int getDtypeSize(int32_t onnxDtype)
{
    switch (onnxDtype)
    {
    case ::ONNX_NAMESPACE::TensorProto::FLOAT16: return 2;
    case ::ONNX_NAMESPACE::TensorProto::FLOAT: return 4;
    case ::ONNX_NAMESPACE::TensorProto::DOUBLE: return 8;
    case ::ONNX_NAMESPACE::TensorProto::COMPLEX64: return 8;
    case ::ONNX_NAMESPACE::TensorProto::COMPLEX128: return 16;
    case ::ONNX_NAMESPACE::TensorProto::UINT8: return 1;
    case ::ONNX_NAMESPACE::TensorProto::INT8: return 1;
    case ::ONNX_NAMESPACE::TensorProto::UINT16: return 2;
    case ::ONNX_NAMESPACE::TensorProto::INT16: return 2;
    case ::ONNX_NAMESPACE::TensorProto::UINT32:
        return 4;
    // Booleans occupy one int32 each only in TensorProto.int32_data; in raw_data and in the weights produced by
    // convertINT32Data<uint8_t> they take one byte per element, which is the size reported here.
    case ::ONNX_NAMESPACE::TensorProto::BOOL: return 1;
    case ::ONNX_NAMESPACE::TensorProto::INT32: return 4;
    case ::ONNX_NAMESPACE::TensorProto::UINT64: return 8;
    case ::ONNX_NAMESPACE::TensorProto::INT64: return 8;
    default: return -1;
    }
}
//! Reads convolution / pooling attributes of `onnx_node` into TensorRT parameters.
//! The number of spatial dimensions is taken from `kernel_size->nbDims`.
//! - kernel_shape and strides are written via setAttr. dilations are written only when `dilations` is non-null,
//!   and output_padding (ConvTranspose) only when `output_padding` is non-null.
//! - count_include_pad == 1 clears `count_exclude_padding`; any other value sets it.
//! - auto_pad other than SAME_LOWER / SAME_UPPER: "pads" (if present) is split into begin/end halves.
//!   `paddingMode` is kEXPLICIT_ROUND_UP when `poolingCeilMode` is set (else kEXPLICIT_ROUND_DOWN), unless
//!   auto_pad names EXPLICIT_ROUND_UP / CAFFE_ROUND_DOWN / CAFFE_ROUND_UP.
//! - auto_pad SAME_LOWER / SAME_UPPER: `paddingMode` becomes kSAME_LOWER / kSAME_UPPER and the begin/end
//!   paddings are not written; "pads" must then be absent (debug assert).
void getKernelParams(IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& onnx_node, nvinfer1::Dims* kernel_size,
    nvinfer1::Dims* strides, nvinfer1::Dims* beg_padding, nvinfer1::Dims* end_padding,
    nvinfer1::PaddingMode& paddingMode, bool& count_exclude_padding, nvinfer1::Dims* dilations,
    nvinfer1::Dims* output_padding, const bool poolingCeilMode)
{
    const int nbSpatialDims = kernel_size->nbDims;
    OnnxAttrs attrs(onnx_node, ctx);
    if (attrs.count("kernel_shape"))
    {
        auto const* onnx_kernel_size = attrs.at("kernel_shape");
        setAttr(kernel_size, onnx_kernel_size, nbSpatialDims, 1);
    }
    if (attrs.count("strides"))
    {
        auto const* onnx_strides = attrs.at("strides");
        setAttr(strides, onnx_strides, nbSpatialDims, 1);
    }
    if (dilations && attrs.count("dilations"))
    {
        auto const* onnx_dilations = attrs.at("dilations");
        setAttr(dilations, onnx_dilations, nbSpatialDims, 1);
    }
    if (attrs.count("count_include_pad"))
    {
        auto const* include_pad = attrs.at("count_include_pad");
        const int val = include_pad->i();
        count_exclude_padding = (val != 1);
    }
    // For ConvTranspose Layer. Guarded like `dilations`: callers without output padding may pass nullptr.
    if (output_padding && attrs.count("output_padding"))
    {
        *output_padding = attrs.get<nvinfer1::Dims>("output_padding");
    }

    paddingMode = poolingCeilMode ? nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP : nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN;
    auto onnx_auto_pad = attrs.get("auto_pad", std::string("NOTSET"));
    if (onnx_auto_pad != "SAME_LOWER" && onnx_auto_pad != "SAME_UPPER")
    {
        if (attrs.count("pads"))
        {
            // ONNX layout: [begin_0, ..., begin_{ndim-1}, end_0, ..., end_{ndim-1}].
            auto onnx_padding = attrs.get<std::vector<int>>("pads");
            int ndim = onnx_padding.size() / 2;
            for (int i = 0; i < nbSpatialDims; ++i)
            {
                if (i < ndim)
                {
                    beg_padding->d[i] = onnx_padding.at(i);
                    end_padding->d[i] = onnx_padding.at(i + ndim);
                }
                else
                {
                    beg_padding->d[i] = 0;
                    end_padding->d[i] = 0;
                }
            }
        }
        if (onnx_auto_pad != "VALID" && onnx_auto_pad != "NOTSET")
        {
            if (onnx_auto_pad == "EXPLICIT_ROUND_UP")
            {
                paddingMode = nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP;
            }
            else if (onnx_auto_pad == "CAFFE_ROUND_DOWN")
            {
                paddingMode = nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN;
            }
            else if (onnx_auto_pad == "CAFFE_ROUND_UP")
            {
                paddingMode = nvinfer1::PaddingMode::kCAFFE_ROUND_UP;
            }
        }
    }
    else
    {
        // If auto_pad is SAME_LOWER or SAME_UPPER, input padding should be calculated
        // "pads" attribute should not be specified
        assert(!attrs.count("pads"));
        // Note: ONNX is always NCHW ordering
        if (onnx_auto_pad == "SAME_LOWER")
        {
            paddingMode = nvinfer1::PaddingMode::kSAME_LOWER;
        }
        else if (onnx_auto_pad == "SAME_UPPER")
        {
            paddingMode = nvinfer1::PaddingMode::kSAME_UPPER;
        }
        else
        {
            throw std::invalid_argument("Unexpected auto_pad value: " + onnx_auto_pad);
        }
    }
}
// Adds a reduce layer that collapses every axis of `tensor` except axes 0 and 1 (N and C)
// using the reduction `op`.
// Reduced axes are kept as size-1 dimensions (keepDimensions = true), so the output has the
// same rank as the input.
// Returns the reduce layer's first output.
nvinfer1::ITensor* globalPoolingHelper(IImporterContext* ctx, nvinfer1::ITensor& tensor, nvinfer1::ReduceOperation op)
{
    int const rank = tensor.getDimensions().nbDims;

    // Bit i of the mask selects axis i for reduction.
    // Start at axis 2 so that N (axis 0) and C (axis 1) are never reduced.
    uint32_t reduceAxes = 0;
    for (int axis = 2; axis < rank; ++axis)
    {
        reduceAxes |= 1U << axis;
    }

    auto* const reduceLayer = ctx->network()->addReduce(tensor, op, reduceAxes, /*keepDimensions=*/true);
    return reduceLayer->getOutput(0);
}
// Looks up a plugin creator by `pluginName` / `pluginVersion` in TensorRT's global plugin registry.
// If one is found, asks it to build a plugin instance named `nodeName` from `pluginFields`.
//
// Returns nullptr when no matching creator is registered. Otherwise it returns whatever the
// creator's createPlugin() returns.
// `ctx` is not used by this function.
//
// NOTE(review): the namespace argument below is the literal string "ONNXTRT_NAMESPACE", not the
// expansion of a macro. Confirm this matches the namespace the target plugins are registered under.
nvinfer1::IPluginV2* importPluginFromRegistry(IImporterContext* ctx, const std::string& pluginName,
    const std::string& pluginVersion, const std::string& nodeName,
    const std::vector<nvinfer1::PluginField>& pluginFields)
{
    auto* const registry = getPluginRegistry();
    auto* const creator
        = registry->getPluginCreator(pluginName.c_str(), pluginVersion.c_str(), "ONNXTRT_NAMESPACE");
    if (creator == nullptr)
    {
        return nullptr;
    }

    // The collection only borrows `pluginFields`. The vector must outlive the createPlugin() call,
    // which it does because it is a caller-owned reference.
    nvinfer1::PluginFieldCollection fieldCollection{};
    fieldCollection.nbFields = static_cast<decltype(fieldCollection.nbFields)>(pluginFields.size());
    fieldCollection.fields = pluginFields.data();

    return creator->createPlugin(nodeName.c_str(), &fieldCollection);
}
// Returns true if any of the first `shape.nbDims` extents of `shape` is negative.
// A negative extent is how a dynamic (not-yet-known) dimension is represented.
// A rank-0 shape is never dynamic.
bool isDynamic(const nvinfer1::Dims& shape)
{
    for (int axis = 0; axis < shape.nbDims; ++axis)
    {
        if (shape.d[axis] < 0)
        {
            return true;
        }
    }
    return false;
}
// Returns true when `onnxTensor` carries no inline payload. That means `raw_data` is empty and
// every typed repeated field (double/float/int32/int64/string/uint64) is empty as well.
//
// NOTE(review): this check looks only at the inline data fields. It does not consult
// data_location / external_data, so a tensor whose payload is stored externally would also be
// reported as empty here. Confirm callers expect that.
bool isOnnxTensorEmpty(const ::ONNX_NAMESPACE::TensorProto& onnxTensor)
{
    // Serialized byte payload.
    const bool hasRawData = !onnxTensor.raw_data().empty();

    // Typed payloads stored in the repeated fields.
    const bool hasTypedData = !onnxTensor.double_data().empty() || !onnxTensor.float_data().empty()
        || !onnxTensor.int32_data().empty() || !onnxTensor.int64_data().empty()
        || !onnxTensor.string_data().empty() || !onnxTensor.uint64_data().empty();

    return !(hasRawData || hasTypedData);
}
// Decides whether applying `perm` to a tensor with dimensions `shape` needs a real transpose.
//
// Output axis `dst` reads from input axis `perm.order[dst]`. Axes of extent 1 are ignored,
// since moving them does not reorder any elements.
// Among the remaining ("significant") axes, the result is false only when their source indices
// appear in increasing order. In that case the permutation preserves element order, and a
// reshape is presumably sufficient.
//
// Any negative extent is treated as dynamic and forces a transpose, because its size is unknown.
// This uses the same `< 0` convention as isDynamic(); the previous code only caught exactly -1.
//
// Assumes perm.order[0 .. shape.nbDims) are valid indices into shape.d. This is not validated here.
bool isTransposeRequired(nvinfer1::Dims const& shape, nvinfer1::Permutation const& perm)
{
    int const ndim = shape.nbDims;
    int prevSignificantSrc = 0;
    for (int dst = 0; dst < ndim; ++dst)
    {
        int const src = perm.order[dst];
        int const extent = shape.d[src];
        if (extent == 1)
        {
            // Unit axes can move freely without changing the element order.
            continue;
        }
        if (extent < 0)
        {
            // Dynamic dimension: we cannot prove it is 1, so a transpose is required.
            return true;
        }
        if (src < prevSignificantSrc)
        {
            // Two non-unit axes swap relative order, so the elements must be shuffled.
            return true;
        }
        prevSignificantSrc = src;
    }
    return false;
}
NodeImportResult lstmLegacyImporter(
IImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector<TensorOrWeights>& inputs)
{
// Input
nvinfer1::ITensor& raw_input = convertToTensor(inputs.at(0), ctx);
ASSERT(3 == raw_input.getDimensions().nbDims && "Input tensor must be 3 dimensional", ErrorCode::kINVALID_NODE);
ASSERT((raw_input.getType() == nvinfer1::DataType::kFLOAT || raw_input.getType() == nvinfer1::DataType::kHALF)
&& "Only fp16 and fp32 inputs are supported",
ErrorCode::kUNSUPPORTED_NODE);
const nvinfer1::DataType input_type = raw_input.getType();
const int32_t max_seq_len = raw_input.getDimensions().d[0];
const int32_t batch_size = raw_input.getDimensions().d[1];
// Attributes
OnnxAttrs attrs(node, ctx);
const std::string direction_str = attrs.get<std::string>("direction", "forward");
ASSERT((direction_str == "forward" || direction_str == "bidirectional") && "Reverse LSTM unsupported",
ErrorCode::kUNSUPPORTED_NODE);
const nvinfer1::RNNDirection direction
= (direction_str == "forward") ? nvinfer1::RNNDirection::kUNIDIRECTION : nvinfer1::RNNDirection::kBIDIRECTION;
const int num_directions = (direction_str == "forward") ? 1 : 2;
// There are three distinct uses of an activation function within the LSTM equations
// One for the input/forget/output gates, one for the cell state, and one for the output
// RNNv2 only supports the default choice for each, listed here (duplicated for bidirectional)
std::vector<std::string> default_activations = {"Sigmoid", "Tanh", "Tanh"};
if (num_directions == 2)
{
default_activations.insert(default_activations.end(), {"Sigmoid", "Tanh", "Tanh"});
}
const std::vector<std::string> activations
= attrs.get<std::vector<std::string>>("activations", default_activations);
ASSERT(activations == default_activations && "Nonstandard activations within LSTM unsupported",
ErrorCode::kUNSUPPORTED_NODE);
const float clip = attrs.get<float>("clip", 0.0f);
ASSERT(clip == 0.0f && "Clipping unsupported", ErrorCode::kUNSUPPORTED_NODE);
const int32_t hidden_size = attrs.get<int>("hidden_size");
ASSERT(hidden_size > 0, ErrorCode::kINVALID_NODE);
const int32_t input_forget = attrs.get<int>("input_forget", 0);
ASSERT(0 == input_forget && "Coupled input/forget unsupported", ErrorCode::kUNSUPPORTED_NODE);
// Optional Inputs
bool has_bias = false;
nvinfer1::ITensor* sequence_lens = nullptr;
nvinfer1::ITensor* initial_h = nullptr;
nvinfer1::ITensor* initial_c = nullptr;
for (int i = 3; i < node.input_size(); i++)
{
const std::string& input_name = node.input(i);
if (input_name == "B")
{
has_bias = true;
}
else if (input_name == "sequence_lens")
{
sequence_lens = &(convertToTensor(inputs.at(i), ctx));
ASSERT(sequence_lens && sequence_lens->getType() == nvinfer1::DataType::kINT32
&& "Failed to process sequence_lens (sequence_lens must be int32)",
ErrorCode::kINVALID_NODE);
}
else if (input_name == "initial_h" || input_name == "initial_c")
{
nvinfer1::ITensor* output = nullptr;
if (inputs.at(i).is_weights())