@@ -28,16 +28,17 @@ namespace cub = hipcub;
2828#include < sys/types.h>
2929#include < unistd.h>
3030#include < algorithm>
31- #include " stdint .h"
31+ #include " cccl_compat .h" // CCCL 3.0 compatibility
3232#include " helper.h"
33+ #include " stdint.h"
3334
3435#define FLT_MAX 1e38
3536
3637static constexpr int kBlockSizeForSmallBeamWidth = 256 ;
3738static constexpr int kMaxVocabPartForStage1FastKernel = 128 ;
3839
39- #define CASE_K (K ) \
40- case K: \
40+ #define CASE_K (K ) \
41+ case K: \
4142 invokeTopKSoftMaxLauncher<T, 2 * K, GROUP>( \
4243 params, beam_group_idx, stream); \
4344 break
@@ -368,7 +369,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
368369
369370 using KVPair = cub::KeyValuePair<int , T>;
370371 KVPair topKVPairPartial{vocab_size - 1 , -MAX_T_VAL};
371- cub ::ArgMax argmax;
372+ fd_cub_compat ::ArgMax argmax;
372373
373374 T const *local_logits = logits + beam_batch_id * vocab_size;
374375#pragma unroll 1
@@ -595,7 +596,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
595596 typename BlockReduceMD::TempStorage md;
596597 } smemReduceBuffer;
597598
598- cub ::ArgMax argmax;
599+ fd_cub_compat ::ArgMax argmax;
599600 MD partial_md{-MAX_T_VAL, 0 .0f };
600601 KVPair topKVPair{vocab_size - 1 , -MAX_T_VAL};
601602
@@ -1336,24 +1337,25 @@ adding while op and without affecting the speed. Use a 'fake inplace' method
13361337here. Not elegant but useful ︸_︸.
13371338*****/
13381339
1339- std::vector<paddle::Tensor> BeamSearchSoftmax (const paddle::Tensor &logits,
1340- const paddle::Tensor &seq_lens,
1341- const paddle::Tensor &stop_flags, // inplace
1342- const paddle::Tensor &end_ids,
1343- const paddle::Tensor &step_ids,
1344- const paddle::Tensor &max_dec_lens,
1345- const paddle::Tensor &block_tables, // inplace
1346- const paddle::Tensor &cum_scores, // inplace
1347- const paddle::Tensor &beam_cache_ids, // inplace
1348- const paddle::Tensor &beam_hyps, // inplace
1349- const paddle::Tensor &beam_hyps_score, // inplace
1350- const paddle::Tensor &beam_finished, // inplace
1351- const paddle::Tensor &beam_width,
1352- const paddle::Tensor &beam_group_num,
1353- const paddle::Tensor &length_penalty,
1354- const paddle::Tensor &diversity_penalty,
1355- bool fuse_softmax,
1356- bool early_stop) {
1340+ std::vector<paddle::Tensor> BeamSearchSoftmax (
1341+ const paddle::Tensor &logits,
1342+ const paddle::Tensor &seq_lens,
1343+ const paddle::Tensor &stop_flags, // inplace
1344+ const paddle::Tensor &end_ids,
1345+ const paddle::Tensor &step_ids,
1346+ const paddle::Tensor &max_dec_lens,
1347+ const paddle::Tensor &block_tables, // inplace
1348+ const paddle::Tensor &cum_scores, // inplace
1349+ const paddle::Tensor &beam_cache_ids, // inplace
1350+ const paddle::Tensor &beam_hyps, // inplace
1351+ const paddle::Tensor &beam_hyps_score, // inplace
1352+ const paddle::Tensor &beam_finished, // inplace
1353+ const paddle::Tensor &beam_width,
1354+ const paddle::Tensor &beam_group_num,
1355+ const paddle::Tensor &length_penalty,
1356+ const paddle::Tensor &diversity_penalty,
1357+ bool fuse_softmax,
1358+ bool early_stop) {
13571359 std::vector<int64_t > logits_shape = logits.shape ();
13581360 // logits_shape
13591361 auto cu_stream = logits.stream ();
@@ -1380,43 +1382,43 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
13801382 const int end_ids_len = end_ids.dims ()[0 ];
13811383 const int beam_group_size = beam_width_scalar / beam_group_num_scalar;
13821384
1383- auto next_tokens = paddle::full ({logits_shape[ 0 ], 1 }, 0 , end_ids. type (),
1384- paddle::GPUPlace ());
1385+ auto next_tokens =
1386+ paddle::full ({logits_shape[ 0 ], 1 }, 0 , end_ids. type (), paddle::GPUPlace ());
13851387
1386- auto parent_ids = paddle::full ({logits_shape[ 0 ], 1 }, 0 , end_ids. type (),
1387- paddle::GPUPlace ());
1388+ auto parent_ids =
1389+ paddle::full ({logits_shape[ 0 ], 1 }, 0 , end_ids. type (), paddle::GPUPlace ());
13881390
1389- auto cum_scores_ori = paddle::empty (cum_scores. shape (), logits. type (),
1390- paddle::GPUPlace ());
1391+ auto cum_scores_ori =
1392+ paddle::empty (cum_scores. shape (), logits. type (), paddle::GPUPlace ());
13911393
1392- auto beam_cache_ids_ori = paddle::empty (beam_cache_ids. shape (), end_ids. type (),
1393- paddle::GPUPlace ());
1394+ auto beam_cache_ids_ori =
1395+ paddle::empty (beam_cache_ids. shape (), end_ids. type (), paddle::GPUPlace ());
13941396
1395- auto block_tables_ori = paddle::empty (block_tables. shape (), end_ids. type (),
1396- paddle::GPUPlace ());
1397+ auto block_tables_ori =
1398+ paddle::empty (block_tables. shape (), end_ids. type (), paddle::GPUPlace ());
13971399 cudaMemcpyAsync (cum_scores_ori.mutable_data <float >(),
13981400 cum_scores.data <float >(),
1399- sizeof (float )* cum_scores.numel (),
1401+ sizeof (float ) * cum_scores.numel (),
14001402 cudaMemcpyDeviceToDevice,
14011403 cu_stream);
14021404 cudaMemcpyAsync (beam_cache_ids_ori.mutable_data <int >(),
14031405 beam_cache_ids.data <int >(),
1404- sizeof (int )* beam_cache_ids.numel (),
1406+ sizeof (int ) * beam_cache_ids.numel (),
14051407 cudaMemcpyDeviceToDevice,
14061408 cu_stream);
14071409 cudaMemcpyAsync (block_tables_ori.mutable_data <int >(),
14081410 block_tables.data <int >(),
1409- sizeof (int )* block_tables.numel (),
1411+ sizeof (int ) * block_tables.numel (),
14101412 cudaMemcpyDeviceToDevice,
14111413 cu_stream);
14121414
14131415 const int tmp_size = batch_size * beam_group_size * beam_group_size * 2 ;
14141416
1415- auto tmp_topk_id = paddle::full ({tmp_size}, 0 , end_ids. type (),
1416- paddle::GPUPlace ());
1417+ auto tmp_topk_id =
1418+ paddle::full ({tmp_size}, 0 , end_ids. type (), paddle::GPUPlace ());
14171419
1418- auto tmp_topk_val = paddle::full ({tmp_size}, 0.0 , logits. type (),
1419- paddle::GPUPlace ());
1420+ auto tmp_topk_val =
1421+ paddle::full ({tmp_size}, 0.0 , logits. type (), paddle::GPUPlace ());
14201422
14211423 BeamSearchParams<float > params;
14221424 params.batch_size = batch_size;
@@ -1449,7 +1451,8 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
14491451 params.block_tables_out = const_cast <int *>(block_tables.data <int >());
14501452 params.cum_scores_out = const_cast <float *>(cum_scores.data <float >());
14511453 params.beam_hyps_out = const_cast <int *>(beam_hyps.data <int >());
1452- params.beam_hyps_score_out = const_cast <float *>(beam_hyps_score.data <float >());
1454+ params.beam_hyps_score_out =
1455+ const_cast <float *>(beam_hyps_score.data <float >());
14531456 params.beam_finished = const_cast <bool *>(beam_finished.data <bool >());
14541457 params.stop_flags = const_cast <bool *>(stop_flags.data <bool >());
14551458
@@ -1470,8 +1473,8 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
14701473
14711474 const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size;
14721475
1473- auto wsp_buffer_tensor = paddle::full ({workspace_size}, 0 , logits. type (),
1474- paddle::GPUPlace ());
1476+ auto wsp_buffer_tensor =
1477+ paddle::full ({workspace_size}, 0 , logits. type (), paddle::GPUPlace ());
14751478
14761479 params.tmp_ids = reinterpret_cast <int *>(wsp_buffer_tensor.data <float >());
14771480 params.tmp_vals = wsp_buffer_tensor.data <float >() + tmp_id_val_size;
@@ -1480,66 +1483,76 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
14801483 for (int beam_group_idx = 0 ; beam_group_idx < beam_group_num_scalar;
14811484 ++beam_group_idx) {
14821485 if (beam_group_num_scalar == 1 ) {
1483- invokeTopkSoftMax<float , false >(
1484- &params, beam_group_idx, cu_stream);
1486+ invokeTopkSoftMax<float , false >(&params, beam_group_idx, cu_stream);
14851487 } else {
1486- invokeTopkSoftMax<float , true >(
1487- &params, beam_group_idx, cu_stream);
1488+ invokeTopkSoftMax<float , true >(&params, beam_group_idx, cu_stream);
14881489 }
14891490 }
14901491 updateBeamSearchParams<float >(&params, cu_stream);
14911492 return {next_tokens, parent_ids};
14921493}
14931494
14941495std::vector<std::vector<int64_t >> BeamSearchSoftmaxShape (
1495- const std::vector<int64_t > &logits,
1496- const std::vector<int64_t > &seq_lens,
1497- const std::vector<int64_t > &stop_flags, // inplace
1498- const std::vector<int64_t > &end_ids,
1499- const std::vector<int64_t > &step_ids,
1500- const std::vector<int64_t > &max_dec_lens,
1501- const std::vector<int64_t > &block_tables, // inplace
1502- const std::vector<int64_t > &cum_scores, // inplace
1503- const std::vector<int64_t > &beam_cache_ids, // inplace
1504- const std::vector<int64_t > &beam_hyps, // inplace
1505- const std::vector<int64_t > &beam_hyps_score, // inplace
1506- const std::vector<int64_t > &beam_finished, // inplace
1507- const std::vector<int64_t > &beam_width,
1508- const std::vector<int64_t > &beam_group_num,
1509- const std::vector<int64_t > &length_penalty,
1510- const std::vector<int64_t > &diversity_penalty) {
1511- std::vector<int64_t > next_tokens = {logits[0 ],1 };
1512- std::vector<int64_t > parent_ids = {logits[0 ],1 };
1513- return {next_tokens,parent_ids};
1496+ const std::vector<int64_t > &logits,
1497+ const std::vector<int64_t > &seq_lens,
1498+ const std::vector<int64_t > &stop_flags, // inplace
1499+ const std::vector<int64_t > &end_ids,
1500+ const std::vector<int64_t > &step_ids,
1501+ const std::vector<int64_t > &max_dec_lens,
1502+ const std::vector<int64_t > &block_tables, // inplace
1503+ const std::vector<int64_t > &cum_scores, // inplace
1504+ const std::vector<int64_t > &beam_cache_ids, // inplace
1505+ const std::vector<int64_t > &beam_hyps, // inplace
1506+ const std::vector<int64_t > &beam_hyps_score, // inplace
1507+ const std::vector<int64_t > &beam_finished, // inplace
1508+ const std::vector<int64_t > &beam_width,
1509+ const std::vector<int64_t > &beam_group_num,
1510+ const std::vector<int64_t > &length_penalty,
1511+ const std::vector<int64_t > &diversity_penalty) {
1512+ std::vector<int64_t > next_tokens = {logits[0 ], 1 };
1513+ std::vector<int64_t > parent_ids = {logits[0 ], 1 };
1514+ return {next_tokens, parent_ids};
15141515}
15151516
15161517std::vector<paddle::DataType> BeamSearchSoftmaxDtype (
1517- const paddle::DataType &logits,
1518- const paddle::DataType &seq_lens,
1519- const paddle::DataType &stop_flags, // inplace
1520- const paddle::DataType &end_ids,
1521- const paddle::DataType &step_ids,
1522- const paddle::DataType &max_dec_lens,
1523- const paddle::DataType &block_tables, // inplace
1524- const paddle::DataType &cum_scores, // inplace
1525- const paddle::DataType &beam_cache_ids, // inplace
1526- const paddle::DataType &beam_hyps, // inplace
1527- const paddle::DataType &beam_hyps_score, // inplace
1528- const paddle::DataType &beam_finished, // inplace
1529- const paddle::DataType &beam_width,
1530- const paddle::DataType &beam_group_num,
1531- const paddle::DataType &length_penalty,
1532- const paddle::DataType &diversity_penalty) {
1533- return {paddle::DataType::INT32, paddle::DataType::INT32};
1518+ const paddle::DataType &logits,
1519+ const paddle::DataType &seq_lens,
1520+ const paddle::DataType &stop_flags, // inplace
1521+ const paddle::DataType &end_ids,
1522+ const paddle::DataType &step_ids,
1523+ const paddle::DataType &max_dec_lens,
1524+ const paddle::DataType &block_tables, // inplace
1525+ const paddle::DataType &cum_scores, // inplace
1526+ const paddle::DataType &beam_cache_ids, // inplace
1527+ const paddle::DataType &beam_hyps, // inplace
1528+ const paddle::DataType &beam_hyps_score, // inplace
1529+ const paddle::DataType &beam_finished, // inplace
1530+ const paddle::DataType &beam_width,
1531+ const paddle::DataType &beam_group_num,
1532+ const paddle::DataType &length_penalty,
1533+ const paddle::DataType &diversity_penalty) {
1534+ return {paddle::DataType::INT32, paddle::DataType::INT32};
15341535}
15351536
15361537PD_BUILD_STATIC_OP (beam_search_softmax)
1537- .Inputs({" logits" , " seq_lens" , " stop_flags" , " end_ids" , " step_ids" , " max_dec_lens" , " block_tables"
1538- , " cum_scores" , " beam_cache_ids" , " beam_hyps" , " beam_hyps_score" , " beam_finished"
1539- , " beam_width" , " beam_group_num" , " length_penalty" , " diversity_penalty" })
1538+ .Inputs({" logits" ,
1539+ " seq_lens" ,
1540+ " stop_flags" ,
1541+ " end_ids" ,
1542+ " step_ids" ,
1543+ " max_dec_lens" ,
1544+ " block_tables" ,
1545+ " cum_scores" ,
1546+ " beam_cache_ids" ,
1547+ " beam_hyps" ,
1548+ " beam_hyps_score" ,
1549+ " beam_finished" ,
1550+ " beam_width" ,
1551+ " beam_group_num" ,
1552+ " length_penalty" ,
1553+ " diversity_penalty" })
15401554 .Outputs({" next_tokens" , " parent_ids" })
1541- .Attrs({" fuse_softmax: bool" ,
1542- " early_stop: bool" })
1555+ .Attrs({" fuse_softmax: bool" , " early_stop: bool" })
15431556 .SetKernelFn(PD_KERNEL(BeamSearchSoftmax))
15441557 .SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape))
15451558 .SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype));
0 commit comments