Skip to content

Commit a2b3d3d

Browse files
kevinch-nvrajeevsrao
authored andcommitted
Add switch for batch agnostic mode in NMS plugin
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
1 parent 0953f2f commit a2b3d3d

6 files changed

Lines changed: 31 additions & 13 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
- Algorithms optimization for NMS kernels and ROIAlign kernel
1414
- Fix invalid cuda config issue when bs is larger than 32
1515
- Fix issues found on Jetson NANO
16+
- Add switch for batch-agnostic mode in NMS plugin
1617

1718
### Removed
1819
- Removed fcplugin from demoBERT to improve latency

include/NvInferPluginUtils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ enum class CodeTypeSSD : int32_t
177177
//! \param inputOrder Specifies the order of inputs {loc_data, conf_data, priorbox_data}.
178178
//! \param confSigmoid Set to true to calculate sigmoid of confidence scores.
179179
//! \param isNormalized Set to true if bounding box data is normalized by the network.
180+
//! \param isBatchAgnostic Defaults to true. Set to false if prior boxes are unique per batch
180181
//!
181182
struct DetectionOutputParameters
182183
{
@@ -187,6 +188,7 @@ struct DetectionOutputParameters
187188
int32_t inputOrder[3];
188189
bool confSigmoid;
189190
bool isNormalized;
191+
bool isBatchAgnostic{true};
190192
};
191193

192194
//!

plugin/common/kernels/decodeBBoxes.cu

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ __launch_bounds__(nthds_per_cta)
8787
const bool clip_bbox,
8888
const T_BBOX* loc_data,
8989
const T_BBOX* prior_data,
90-
T_BBOX* bbox_data)
90+
T_BBOX* bbox_data,
91+
const bool batch_agnostic)
9192
{
9293
for (int index = blockIdx.x * nthds_per_cta + threadIdx.x;
9394
index < nthreads;
@@ -113,7 +114,7 @@ __launch_bounds__(nthds_per_cta)
113114
// do not assume each images' anchor boxes are identical
114115
// e.g., in FasterRCNN, priors are ROIs from proposal layer and are different
115116
// for each image.
116-
const int pi = (batch * 2 * num_priors + d) * 4;
117+
const int pi = batch_agnostic ? d * 4 : (batch * 2 * num_priors + d) * 4;
117118
// Index to the right variances corresponding to the current bounding box
118119
const int vi = pi + num_priors * 4;
119120
// Encoding method: CodeTypeSSD::CORNER
@@ -296,15 +297,16 @@ pluginStatus_t decodeBBoxes_gpu(
296297
const bool clip_bbox,
297298
const void* loc_data,
298299
const void* prior_data,
299-
void* bbox_data)
300+
void* bbox_data,
301+
const bool batch_agnostic)
300302
{
301303
const int BS = 512;
302304
const int GS = (nthreads + BS - 1) / BS;
303305
decodeBBoxes_kernel<T_BBOX, BS><<<GS, BS, 0, stream>>>(nthreads, code_type, variance_encoded_in_target,
304306
num_priors, share_location, num_loc_classes,
305307
background_label_id, clip_bbox,
306308
(const T_BBOX*) loc_data, (const T_BBOX*) prior_data,
307-
(T_BBOX*) bbox_data);
309+
(T_BBOX*) bbox_data, batch_agnostic);
308310
CSC(cudaGetLastError(), STATUS_FAILURE);
309311
return STATUS_SUCCESS;
310312
}
@@ -321,7 +323,8 @@ typedef pluginStatus_t (*dbbFunc)(cudaStream_t,
321323
const bool,
322324
const void*,
323325
const void*,
324-
void*);
326+
void*,
327+
const bool);
325328

326329
struct dbbLaunchConfig
327330
{
@@ -361,7 +364,8 @@ pluginStatus_t decodeBBoxes(
361364
const DataType DT_BBOX,
362365
const void* loc_data,
363366
const void* prior_data,
364-
void* bbox_data)
367+
void* bbox_data,
368+
const bool batch_agnostic)
365369
{
366370
dbbLaunchConfig lc = dbbLaunchConfig(DT_BBOX);
367371
for (unsigned i = 0; i < dbbLCOptions.size(); ++i)
@@ -380,7 +384,8 @@ pluginStatus_t decodeBBoxes(
380384
clip_bbox,
381385
loc_data,
382386
prior_data,
383-
bbox_data);
387+
bbox_data,
388+
batch_agnostic);
384389
}
385390
}
386391
return STATUS_BAD_PARAM;

plugin/common/kernels/detectionForward.cu

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ pluginStatus_t detectionInference(
4141
void* workspace,
4242
bool isNormalized,
4343
bool confSigmoid,
44-
int scoreBits)
44+
int scoreBits,
45+
const bool isBatchAgnostic)
4546
{
4647
// Batch size * number bbox per sample * 4 = total number of bounding boxes * 4
4748
const int locCount = N * C1;
@@ -70,7 +71,8 @@ pluginStatus_t detectionInference(
7071
DT_BBOX,
7172
locData,
7273
priorData,
73-
bboxDataRaw);
74+
bboxDataRaw,
75+
isBatchAgnostic);
7476

7577
ASSERT_FAILURE(status == STATUS_SUCCESS);
7678

@@ -246,7 +248,8 @@ namespace plugin
246248
void* workspace,
247249
bool isNormalized,
248250
bool confSigmoid,
249-
int scoreBits)
251+
int scoreBits,
252+
const bool isBatchAgnostic)
250253
{
251254
// Batch size * number bbox per sample * 4 = total number of bounding boxes * 4
252255
const int locCount = N * C1;
@@ -275,7 +278,8 @@ namespace plugin
275278
DT_BBOX,
276279
locData,
277280
priorData,
278-
bboxDataRaw);
281+
bboxDataRaw,
282+
isBatchAgnostic);
279283

280284
ASSERT_FAILURE(status == STATUS_SUCCESS);
281285

plugin/common/kernels/kernel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pluginStatus_t detectionInference(cudaStream_t stream, int N, int C1, int C2, bo
4343
bool varianceEncodedInTarget, int backgroundLabelId, int numPredsPerClass, int numClasses, int topK, int keepTopK,
4444
float confidenceThreshold, float nmsThreshold, CodeTypeSSD codeType, DataType DT_BBOX, const void* locData,
4545
const void* priorData, DataType DT_SCORE, const void* confData, void* keepCount, void* topDetections,
46-
void* workspace, bool isNormalized = true, bool confSigmoid = false, int scoreBits = 16);
46+
void* workspace, bool isNormalized = true, bool confSigmoid = false, int scoreBits = 16, const bool isBatchAgnostic = true);
4747

4848
pluginStatus_t nmsInference(cudaStream_t stream, int N, int boxesSize, int scoresSize, bool shareLocation,
4949
int backgroundLabelId, int numPredsPerClass, int numClasses, int topK, int keepTopK, float scoreThreshold,
@@ -84,7 +84,7 @@ size_t detectionForwardPostNMSSize(int N, int numClasses, int topK);
8484

8585
pluginStatus_t decodeBBoxes(cudaStream_t stream, int nthreads, CodeTypeSSD code_type, bool variance_encoded_in_target,
8686
int num_priors, bool share_location, int num_loc_classes, int background_label_id, bool clip_bbox, DataType DT_BBOX,
87-
const void* loc_data, const void* prior_data, void* bbox_data);
87+
const void* loc_data, const void* prior_data, void* bbox_data, const bool batch_agnostic);
8888

8989
size_t normalizePluginWorkspaceSize(bool acrossSpatial, int C, int H, int W);
9090

plugin/nmsPlugin/nmsPlugin.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,7 @@ NMSBasePluginCreator::NMSBasePluginCreator() noexcept
567567
mPluginAttributes.emplace_back(PluginField("isNormalized", nullptr, PluginFieldType::kINT32, 1));
568568
mPluginAttributes.emplace_back(PluginField("codeType", nullptr, PluginFieldType::kINT32, 1));
569569
mPluginAttributes.emplace_back(PluginField("scoreBits", nullptr, PluginFieldType::kINT32, 1));
570+
mPluginAttributes.emplace_back(PluginField("isBatchAgnostic", nullptr, PluginFieldType::kINT32, 1));
570571
mFC.nbFields = mPluginAttributes.size();
571572
mFC.fields = mPluginAttributes.data();
572573
}
@@ -684,6 +685,11 @@ IPluginV2Ext* NMSPluginCreator::createPlugin(const char* name, const PluginField
684685
ASSERT(fields[i].type == PluginFieldType::kINT32);
685686
mScoreBits = *(static_cast<const int32_t*>(fields[i].data));
686687
}
688+
else if (!strcmp(attrName, "isBatchAgnostic"))
689+
{
690+
ASSERT(fields[i].type == PluginFieldType::kINT32);
691+
params.isBatchAgnostic = static_cast<int>(*(static_cast<const int*>(fields[i].data)));
692+
}
687693
}
688694

689695
DetectionOutput* obj = new DetectionOutput(params);

0 commit comments

Comments
 (0)