Skip to content

Commit 5666993

Browse files
zoooo0820 and mitu626 authored
solve conflict (#7135)
Co-authored-by: wangyifei <mitu626@163.com>
1 parent 14676a3 commit 5666993

13 files changed

Lines changed: 1173 additions & 938 deletions

File tree

custom_ops/gpu_ops/beam_search_softmax.cu

Lines changed: 101 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,17 @@ namespace cub = hipcub;
2828
#include <sys/types.h>
2929
#include <unistd.h>
3030
#include <algorithm>
31-
#include "stdint.h"
31+
#include "cccl_compat.h" // CCCL 3.0 compatibility
3232
#include "helper.h"
33+
#include "stdint.h"
3334

3435
#define FLT_MAX 1e38
3536

3637
static constexpr int kBlockSizeForSmallBeamWidth = 256;
3738
static constexpr int kMaxVocabPartForStage1FastKernel = 128;
3839

39-
#define CASE_K(K) \
40-
case K: \
40+
#define CASE_K(K) \
41+
case K: \
4142
invokeTopKSoftMaxLauncher<T, 2 * K, GROUP>( \
4243
params, beam_group_idx, stream); \
4344
break
@@ -368,7 +369,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__
368369

369370
using KVPair = cub::KeyValuePair<int, T>;
370371
KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
371-
cub::ArgMax argmax;
372+
fd_cub_compat::ArgMax argmax;
372373

373374
T const *local_logits = logits + beam_batch_id * vocab_size;
374375
#pragma unroll 1
@@ -595,7 +596,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
595596
typename BlockReduceMD::TempStorage md;
596597
} smemReduceBuffer;
597598

598-
cub::ArgMax argmax;
599+
fd_cub_compat::ArgMax argmax;
599600
MD partial_md{-MAX_T_VAL, 0.0f};
600601
KVPair topKVPair{vocab_size - 1, -MAX_T_VAL};
601602

@@ -1336,24 +1337,25 @@ adding while op and without affecting the speed. Use a 'fake inplace' method
13361337
here. Not elegant but useful ︸_︸.
13371338
*****/
13381339

1339-
std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
1340-
const paddle::Tensor &seq_lens,
1341-
const paddle::Tensor &stop_flags, // inplace
1342-
const paddle::Tensor &end_ids,
1343-
const paddle::Tensor &step_ids,
1344-
const paddle::Tensor &max_dec_lens,
1345-
const paddle::Tensor &block_tables, // inplace
1346-
const paddle::Tensor &cum_scores, // inplace
1347-
const paddle::Tensor &beam_cache_ids, // inplace
1348-
const paddle::Tensor &beam_hyps, // inplace
1349-
const paddle::Tensor &beam_hyps_score, // inplace
1350-
const paddle::Tensor &beam_finished, // inplace
1351-
const paddle::Tensor &beam_width,
1352-
const paddle::Tensor &beam_group_num,
1353-
const paddle::Tensor &length_penalty,
1354-
const paddle::Tensor &diversity_penalty,
1355-
bool fuse_softmax,
1356-
bool early_stop) {
1340+
std::vector<paddle::Tensor> BeamSearchSoftmax(
1341+
const paddle::Tensor &logits,
1342+
const paddle::Tensor &seq_lens,
1343+
const paddle::Tensor &stop_flags, // inplace
1344+
const paddle::Tensor &end_ids,
1345+
const paddle::Tensor &step_ids,
1346+
const paddle::Tensor &max_dec_lens,
1347+
const paddle::Tensor &block_tables, // inplace
1348+
const paddle::Tensor &cum_scores, // inplace
1349+
const paddle::Tensor &beam_cache_ids, // inplace
1350+
const paddle::Tensor &beam_hyps, // inplace
1351+
const paddle::Tensor &beam_hyps_score, // inplace
1352+
const paddle::Tensor &beam_finished, // inplace
1353+
const paddle::Tensor &beam_width,
1354+
const paddle::Tensor &beam_group_num,
1355+
const paddle::Tensor &length_penalty,
1356+
const paddle::Tensor &diversity_penalty,
1357+
bool fuse_softmax,
1358+
bool early_stop) {
13571359
std::vector<int64_t> logits_shape = logits.shape();
13581360
// logits_shape
13591361
auto cu_stream = logits.stream();
@@ -1380,43 +1382,43 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
13801382
const int end_ids_len = end_ids.dims()[0];
13811383
const int beam_group_size = beam_width_scalar / beam_group_num_scalar;
13821384

1383-
auto next_tokens = paddle::full({logits_shape[0], 1}, 0, end_ids.type(),
1384-
paddle::GPUPlace());
1385+
auto next_tokens =
1386+
paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace());
13851387

1386-
auto parent_ids = paddle::full({logits_shape[0], 1}, 0, end_ids.type(),
1387-
paddle::GPUPlace());
1388+
auto parent_ids =
1389+
paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace());
13881390

1389-
auto cum_scores_ori = paddle::empty(cum_scores.shape(), logits.type(),
1390-
paddle::GPUPlace());
1391+
auto cum_scores_ori =
1392+
paddle::empty(cum_scores.shape(), logits.type(), paddle::GPUPlace());
13911393

1392-
auto beam_cache_ids_ori = paddle::empty(beam_cache_ids.shape(), end_ids.type(),
1393-
paddle::GPUPlace());
1394+
auto beam_cache_ids_ori =
1395+
paddle::empty(beam_cache_ids.shape(), end_ids.type(), paddle::GPUPlace());
13941396

1395-
auto block_tables_ori = paddle::empty(block_tables.shape(), end_ids.type(),
1396-
paddle::GPUPlace());
1397+
auto block_tables_ori =
1398+
paddle::empty(block_tables.shape(), end_ids.type(), paddle::GPUPlace());
13971399
cudaMemcpyAsync(cum_scores_ori.mutable_data<float>(),
13981400
cum_scores.data<float>(),
1399-
sizeof(float)*cum_scores.numel(),
1401+
sizeof(float) * cum_scores.numel(),
14001402
cudaMemcpyDeviceToDevice,
14011403
cu_stream);
14021404
cudaMemcpyAsync(beam_cache_ids_ori.mutable_data<int>(),
14031405
beam_cache_ids.data<int>(),
1404-
sizeof(int)*beam_cache_ids.numel(),
1406+
sizeof(int) * beam_cache_ids.numel(),
14051407
cudaMemcpyDeviceToDevice,
14061408
cu_stream);
14071409
cudaMemcpyAsync(block_tables_ori.mutable_data<int>(),
14081410
block_tables.data<int>(),
1409-
sizeof(int)*block_tables.numel(),
1411+
sizeof(int) * block_tables.numel(),
14101412
cudaMemcpyDeviceToDevice,
14111413
cu_stream);
14121414

14131415
const int tmp_size = batch_size * beam_group_size * beam_group_size * 2;
14141416

1415-
auto tmp_topk_id = paddle::full({tmp_size}, 0, end_ids.type(),
1416-
paddle::GPUPlace());
1417+
auto tmp_topk_id =
1418+
paddle::full({tmp_size}, 0, end_ids.type(), paddle::GPUPlace());
14171419

1418-
auto tmp_topk_val = paddle::full({tmp_size}, 0.0, logits.type(),
1419-
paddle::GPUPlace());
1420+
auto tmp_topk_val =
1421+
paddle::full({tmp_size}, 0.0, logits.type(), paddle::GPUPlace());
14201422

14211423
BeamSearchParams<float> params;
14221424
params.batch_size = batch_size;
@@ -1449,7 +1451,8 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
14491451
params.block_tables_out = const_cast<int *>(block_tables.data<int>());
14501452
params.cum_scores_out = const_cast<float *>(cum_scores.data<float>());
14511453
params.beam_hyps_out = const_cast<int *>(beam_hyps.data<int>());
1452-
params.beam_hyps_score_out = const_cast<float *>(beam_hyps_score.data<float>());
1454+
params.beam_hyps_score_out =
1455+
const_cast<float *>(beam_hyps_score.data<float>());
14531456
params.beam_finished = const_cast<bool *>(beam_finished.data<bool>());
14541457
params.stop_flags = const_cast<bool *>(stop_flags.data<bool>());
14551458

@@ -1470,8 +1473,8 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
14701473

14711474
const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size;
14721475

1473-
auto wsp_buffer_tensor = paddle::full({workspace_size}, 0, logits.type(),
1474-
paddle::GPUPlace());
1476+
auto wsp_buffer_tensor =
1477+
paddle::full({workspace_size}, 0, logits.type(), paddle::GPUPlace());
14751478

14761479
params.tmp_ids = reinterpret_cast<int *>(wsp_buffer_tensor.data<float>());
14771480
params.tmp_vals = wsp_buffer_tensor.data<float>() + tmp_id_val_size;
@@ -1480,66 +1483,76 @@ std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
14801483
for (int beam_group_idx = 0; beam_group_idx < beam_group_num_scalar;
14811484
++beam_group_idx) {
14821485
if (beam_group_num_scalar == 1) {
1483-
invokeTopkSoftMax<float, false>(
1484-
&params, beam_group_idx, cu_stream);
1486+
invokeTopkSoftMax<float, false>(&params, beam_group_idx, cu_stream);
14851487
} else {
1486-
invokeTopkSoftMax<float, true>(
1487-
&params, beam_group_idx, cu_stream);
1488+
invokeTopkSoftMax<float, true>(&params, beam_group_idx, cu_stream);
14881489
}
14891490
}
14901491
updateBeamSearchParams<float>(&params, cu_stream);
14911492
return {next_tokens, parent_ids};
14921493
}
14931494

14941495
std::vector<std::vector<int64_t>> BeamSearchSoftmaxShape(
1495-
const std::vector<int64_t> &logits,
1496-
const std::vector<int64_t> &seq_lens,
1497-
const std::vector<int64_t> &stop_flags, // inplace
1498-
const std::vector<int64_t> &end_ids,
1499-
const std::vector<int64_t> &step_ids,
1500-
const std::vector<int64_t> &max_dec_lens,
1501-
const std::vector<int64_t> &block_tables, // inplace
1502-
const std::vector<int64_t> &cum_scores, // inplace
1503-
const std::vector<int64_t> &beam_cache_ids, // inplace
1504-
const std::vector<int64_t> &beam_hyps, // inplace
1505-
const std::vector<int64_t> &beam_hyps_score, // inplace
1506-
const std::vector<int64_t> &beam_finished, // inplace
1507-
const std::vector<int64_t> &beam_width,
1508-
const std::vector<int64_t> &beam_group_num,
1509-
const std::vector<int64_t> &length_penalty,
1510-
const std::vector<int64_t> &diversity_penalty) {
1511-
std::vector<int64_t> next_tokens = {logits[0],1};
1512-
std::vector<int64_t> parent_ids = {logits[0],1};
1513-
return {next_tokens,parent_ids};
1496+
const std::vector<int64_t> &logits,
1497+
const std::vector<int64_t> &seq_lens,
1498+
const std::vector<int64_t> &stop_flags, // inplace
1499+
const std::vector<int64_t> &end_ids,
1500+
const std::vector<int64_t> &step_ids,
1501+
const std::vector<int64_t> &max_dec_lens,
1502+
const std::vector<int64_t> &block_tables, // inplace
1503+
const std::vector<int64_t> &cum_scores, // inplace
1504+
const std::vector<int64_t> &beam_cache_ids, // inplace
1505+
const std::vector<int64_t> &beam_hyps, // inplace
1506+
const std::vector<int64_t> &beam_hyps_score, // inplace
1507+
const std::vector<int64_t> &beam_finished, // inplace
1508+
const std::vector<int64_t> &beam_width,
1509+
const std::vector<int64_t> &beam_group_num,
1510+
const std::vector<int64_t> &length_penalty,
1511+
const std::vector<int64_t> &diversity_penalty) {
1512+
std::vector<int64_t> next_tokens = {logits[0], 1};
1513+
std::vector<int64_t> parent_ids = {logits[0], 1};
1514+
return {next_tokens, parent_ids};
15141515
}
15151516

15161517
std::vector<paddle::DataType> BeamSearchSoftmaxDtype(
1517-
const paddle::DataType &logits,
1518-
const paddle::DataType &seq_lens,
1519-
const paddle::DataType &stop_flags, // inplace
1520-
const paddle::DataType &end_ids,
1521-
const paddle::DataType &step_ids,
1522-
const paddle::DataType &max_dec_lens,
1523-
const paddle::DataType &block_tables, // inplace
1524-
const paddle::DataType &cum_scores, // inplace
1525-
const paddle::DataType &beam_cache_ids, // inplace
1526-
const paddle::DataType &beam_hyps, // inplace
1527-
const paddle::DataType &beam_hyps_score, // inplace
1528-
const paddle::DataType &beam_finished, // inplace
1529-
const paddle::DataType &beam_width,
1530-
const paddle::DataType &beam_group_num,
1531-
const paddle::DataType &length_penalty,
1532-
const paddle::DataType &diversity_penalty) {
1533-
return {paddle::DataType::INT32, paddle::DataType::INT32};
1518+
const paddle::DataType &logits,
1519+
const paddle::DataType &seq_lens,
1520+
const paddle::DataType &stop_flags, // inplace
1521+
const paddle::DataType &end_ids,
1522+
const paddle::DataType &step_ids,
1523+
const paddle::DataType &max_dec_lens,
1524+
const paddle::DataType &block_tables, // inplace
1525+
const paddle::DataType &cum_scores, // inplace
1526+
const paddle::DataType &beam_cache_ids, // inplace
1527+
const paddle::DataType &beam_hyps, // inplace
1528+
const paddle::DataType &beam_hyps_score, // inplace
1529+
const paddle::DataType &beam_finished, // inplace
1530+
const paddle::DataType &beam_width,
1531+
const paddle::DataType &beam_group_num,
1532+
const paddle::DataType &length_penalty,
1533+
const paddle::DataType &diversity_penalty) {
1534+
return {paddle::DataType::INT32, paddle::DataType::INT32};
15341535
}
15351536

15361537
PD_BUILD_STATIC_OP(beam_search_softmax)
1537-
.Inputs({"logits", "seq_lens", "stop_flags", "end_ids", "step_ids", "max_dec_lens", "block_tables"
1538-
, "cum_scores", "beam_cache_ids", "beam_hyps", "beam_hyps_score", "beam_finished"
1539-
, "beam_width", "beam_group_num", "length_penalty", "diversity_penalty"})
1538+
.Inputs({"logits",
1539+
"seq_lens",
1540+
"stop_flags",
1541+
"end_ids",
1542+
"step_ids",
1543+
"max_dec_lens",
1544+
"block_tables",
1545+
"cum_scores",
1546+
"beam_cache_ids",
1547+
"beam_hyps",
1548+
"beam_hyps_score",
1549+
"beam_finished",
1550+
"beam_width",
1551+
"beam_group_num",
1552+
"length_penalty",
1553+
"diversity_penalty"})
15401554
.Outputs({"next_tokens", "parent_ids"})
1541-
.Attrs({"fuse_softmax: bool",
1542-
"early_stop: bool"})
1555+
.Attrs({"fuse_softmax: bool", "early_stop: bool"})
15431556
.SetKernelFn(PD_KERNEL(BeamSearchSoftmax))
15441557
.SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape))
15451558
.SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype));

0 commit comments

Comments
 (0)