-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathinference_helper.h
143 lines (116 loc) · 4.25 KB
/
inference_helper.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#ifndef VSTRT_INFERENCE_HELPER_H_
#define VSTRT_INFERENCE_HELPER_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include <VSHelper.h>
#include "cuda_helper.h"
#include "trt_utils.h"
// Geometry of the source planes and of the tiles cut from them, as consumed
// by inference().
struct InputInfo {
    int width, height;      // frame extent, in samples
    int pitch;              // distance between consecutive rows, in bytes
    int bytes_per_sample;   // size of a single sample, in bytes
    int tile_w, tile_h;     // tile extent, in input samples
};
// Memory layout of every destination plane filled by inference():
// `pitch` is the row stride and `bytes_per_sample` the size of one sample,
// both measured in bytes.
struct OutputInfo {
    int pitch, bytes_per_sample;
};
// Complete tiling geometry handed to inference().
struct IOInfo {
    InputInfo in;           // source frame and tile layout
    OutputInfo out;         // destination plane layout
    // Output tile size = input tile size multiplied by these factors.
    int w_scale, h_scale;
    // Tiles advance by tile_w - 2 * overlap_w (resp. tile_h - 2 * overlap_h),
    // so neighbouring tiles share 2 * overlap input samples along that axis.
    int overlap_w, overlap_h;
};
static inline
std::optional<ErrorMessage> inference(
const InferenceInstance & instance,
int device_id,
bool use_cuda_graph,
const IOInfo & info,
const std::vector<const uint8_t *> & src_ptrs,
const std::vector<uint8_t *> & dst_ptrs
) noexcept {
const auto set_error = [](const ErrorMessage & error_message) {
return error_message;
};
checkError(cudaSetDevice(device_id));
int src_tile_w_bytes = info.in.tile_w * info.in.bytes_per_sample;
int src_tile_bytes = info.in.tile_h * info.in.tile_w * info.in.bytes_per_sample;
int dst_tile_w = info.in.tile_w * info.w_scale;
int dst_tile_h = info.in.tile_h * info.h_scale;
int dst_tile_w_bytes = dst_tile_w * info.out.bytes_per_sample;
int dst_tile_bytes = dst_tile_h * dst_tile_w * info.out.bytes_per_sample;
int step_w = info.in.tile_w - 2 * info.overlap_w;
int step_h = info.in.tile_h - 2 * info.overlap_h;
int y = 0;
while (true) {
int y_crop_start = (y == 0) ? 0 : info.overlap_h;
int y_crop_end = (y == info.in.height - info.in.tile_h) ? 0 : info.overlap_h;
int x = 0;
while (true) {
int x_crop_start = (x == 0) ? 0 : info.overlap_w;
int x_crop_end = (x == info.in.width - info.in.tile_w) ? 0 : info.overlap_w;
{
uint8_t * h_data = instance.src.h_data.data;
for (const uint8_t * _src_ptr : src_ptrs) {
const uint8_t * src_ptr { _src_ptr +
y * info.in.pitch + x * info.in.bytes_per_sample
};
vs_bitblt(
h_data, src_tile_w_bytes,
src_ptr, info.in.pitch,
static_cast<size_t>(src_tile_w_bytes),
static_cast<size_t>(info.in.tile_h)
);
h_data += src_tile_bytes;
}
}
if (use_cuda_graph) {
checkError(cudaGraphLaunch(instance.graphexec, instance.stream));
} else {
auto result = enqueue(
instance.src, instance.dst,
instance.exec_context, instance.stream
);
if (result.has_value()) {
return set_error(result.value());
}
}
checkError(cudaStreamSynchronize(instance.stream));
{
const uint8_t * h_data = instance.dst.h_data.data;
for (uint8_t * _dst_ptr : dst_ptrs) {
uint8_t * dst_ptr { _dst_ptr +
info.h_scale * y * info.out.pitch + info.w_scale * x * info.out.bytes_per_sample
};
vs_bitblt(
dst_ptr + (y_crop_start * info.out.pitch + x_crop_start * info.out.bytes_per_sample),
info.out.pitch,
h_data + (y_crop_start * dst_tile_w_bytes + x_crop_start * info.out.bytes_per_sample),
dst_tile_w_bytes,
static_cast<size_t>(dst_tile_w_bytes - (x_crop_start + x_crop_end) * info.out.bytes_per_sample),
static_cast<size_t>(dst_tile_h - (y_crop_start + y_crop_end))
);
h_data += dst_tile_bytes;
}
}
if (x + info.in.tile_w == info.in.width) {
break;
}
x = std::min(x + step_w, info.in.width - info.in.tile_w);
}
if (y + info.in.tile_h == info.in.height) {
break;
}
y = std::min(y + step_h, info.in.height - info.in.tile_h);
}
return {};
}
#endif // VSTRT_INFERENCE_HELPER_H_