1010
1111void moe_permute (
1212 const torch::Tensor& input, // [n_token, hidden]
13- const torch::Tensor& topk_weights, // [n_token, topk]
14- torch::Tensor& topk_ids, // [n_token, topk]
13+ const torch::Tensor& topk_ids, // [n_token, topk]
1514 const torch::Tensor& token_expert_indices, // [n_token, topk]
1615 const std::optional<torch::Tensor>& expert_map, // [n_expert]
1716 int64_t n_expert, int64_t n_local_expert, int64_t topk,
1817 const std::optional<int64_t >& align_block_size,
19- torch::Tensor&
20- permuted_input, // [topk * n_token/align_block_size_m, hidden]
18+ torch::Tensor& permuted_input, // [permuted_size, hidden]
2119 torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
22- torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
20+ torch::Tensor& inv_permuted_idx, // [n_token, topk]
21+ torch::Tensor& permuted_idx, // [permute_size]
2322 torch::Tensor& m_indices) { // [align_expand_m]
24- TORCH_CHECK (topk_weights.scalar_type () == at::ScalarType::Float,
25- " topk_weights must be float32" );
2623 TORCH_CHECK (expert_first_token_offset.scalar_type () == at::ScalarType::Long,
2724 " expert_first_token_offset must be int64" );
2825 TORCH_CHECK (topk_ids.scalar_type () == at::ScalarType::Int,
2926 " topk_ids must be int32" );
3027 TORCH_CHECK (token_expert_indices.scalar_type () == at::ScalarType::Int,
3128 " token_expert_indices must be int32" );
32- TORCH_CHECK (src_row_id2dst_row_id_map .scalar_type () == at::ScalarType::Int,
33- " src_row_id2dst_row_id_map must be int32" );
29+ TORCH_CHECK (inv_permuted_idx .scalar_type () == at::ScalarType::Int,
30+ " inv_permuted_idx must be int32" );
3431 TORCH_CHECK (expert_first_token_offset.size (0 ) == n_local_expert + 1 ,
3532 " expert_first_token_offset shape != n_local_expert+1" )
36- TORCH_CHECK (
37- src_row_id2dst_row_id_map.sizes () == token_expert_indices.sizes (),
38- " token_expert_indices shape must be same as src_row_id2dst_row_id_map" );
33+ TORCH_CHECK (inv_permuted_idx.sizes () == token_expert_indices.sizes (),
34+ " token_expert_indices shape must be same as inv_permuted_idx" );
3935 auto n_token = input.sizes ()[0 ];
4036 auto n_hidden = input.sizes ()[1 ];
4137 auto align_block_size_value =
@@ -46,8 +42,9 @@ void moe_permute(
4642 auto sort_workspace = torch::empty (
4743 {sorter_size},
4844 torch::dtype (torch::kInt8 ).device (torch::kCUDA ).requires_grad (false ));
45+ auto copy_topk_ids = topk_ids.clone (); // copy topk_ids for preprocess
4946 auto permuted_experts_id = torch::empty_like (topk_ids);
50- auto dst_row_id2src_row_id_map = torch::empty_like (src_row_id2dst_row_id_map );
47+ auto sorted_row_idx = torch::empty_like (inv_permuted_idx );
5148 auto align_expert_first_token_offset =
5249 torch::zeros_like (expert_first_token_offset);
5350
@@ -67,24 +64,22 @@ void moe_permute(
6764 const int * expert_map_ptr = get_ptr<int >(expert_map.value ());
6865 valid_num_ptr =
6966 get_ptr<int64_t >(expert_first_token_offset) + n_local_expert;
70- preprocessTopkIdLauncher (get_ptr<int >(topk_ids ), n_token * topk,
67+ preprocessTopkIdLauncher (get_ptr<int >(copy_topk_ids ), n_token * topk,
7168 expert_map_ptr, n_expert, stream);
7269 }
7370 // expert sort topk expert id and scan expert id get expert_first_token_offset
74- sortAndScanExpert (get_ptr<int >(topk_ids), get_ptr<int >(token_expert_indices),
75- get_ptr<int >(permuted_experts_id),
76- get_ptr<int >(dst_row_id2src_row_id_map),
77- get_ptr<int64_t >(expert_first_token_offset), n_token,
78- n_expert, n_local_expert, topk, sorter,
79- get_ptr<int >(sort_workspace), stream);
71+ sortAndScanExpert (
72+ get_ptr<int >(copy_topk_ids), get_ptr<int >(token_expert_indices),
73+ get_ptr<int >(permuted_experts_id), get_ptr<int >(sorted_row_idx),
74+ get_ptr<int64_t >(expert_first_token_offset), n_token, n_expert,
75+ n_local_expert, topk, sorter, get_ptr<int >(sort_workspace), stream);
8076
8177 // dispatch expandInputRowsKernelLauncher
8278 MOE_DISPATCH (input.scalar_type (), [&] {
8379 expandInputRowsKernelLauncher<scalar_t >(
8480 get_ptr<scalar_t >(input), get_ptr<scalar_t >(permuted_input),
85- get_ptr<float >(topk_weights), get_ptr<int >(permuted_experts_id),
86- get_ptr<int >(dst_row_id2src_row_id_map),
87- get_ptr<int >(src_row_id2dst_row_id_map),
81+ get_ptr<int >(permuted_experts_id), get_ptr<int >(sorted_row_idx),
82+ get_ptr<int >(inv_permuted_idx), get_ptr<int >(permuted_idx),
8883 get_ptr<int64_t >(expert_first_token_offset), n_token, valid_num_ptr,
8984 n_hidden, topk, n_local_expert, align_block_size_value, stream);
9085 });
@@ -101,32 +96,34 @@ void moe_permute(
10196}
10297
// moe_unpermute, shown here as a diff hunk: lines marked "-" are the removed
// version and lines marked "+" are their replacement.
//
// In the replacement version the function:
//   * requires permuted_hidden_states and hidden_states to have the same
//     scalar type (TORCH_CHECK below),
//   * takes n_token and n_hidden from dimensions 0 and 1 of hidden_states,
//   * builds valid_ptr from the optional expert_first_token_offset (nullptr
//     when it is absent),
//   * dispatches finalizeMoeRoutingKernelLauncher on the scalar type of
//     hidden_states, passing the current CUDA stream.
//
// topk_weights is accessed through get_ptr<float> and inv_permuted_idx through
// get_ptr<int>; no dtype check for either is visible in this hunk.
// NOTE(review): confirm callers guarantee float32 / int32 for these tensors.
10398void moe_unpermute (
104- const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
105- const torch::Tensor& topk_weights, // [n_token, topk]
106- const torch::Tensor& topk_ids, // [n_token, topk]
107- const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
108- const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
109- int64_t n_expert, int64_t n_local_expert, int64_t topk,
99+ const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
100+ const torch::Tensor& topk_weights, // [n_token, topk]
101+ const torch::Tensor& inv_permuted_idx, // [n_token, topk]
102+ const std::optional< torch::Tensor>&
103+ expert_first_token_offset, // [n_local_expert+1]
104+ int64_t topk,
110105 torch::Tensor& hidden_states // [n_token, hidden]
111106) {
112- TORCH_CHECK (src_row_id2dst_row_id_map.sizes () == topk_ids.sizes (),
113- " topk_ids shape must be same as src_row_id2dst_row_id_map" );
114- TORCH_CHECK (topk_ids.scalar_type () == at::ScalarType::Int,
115- " topk_ids must be int32" );
// The only check kept by the change: input and output dtypes must match. The
// replacement message names the tensors that are actually compared.
116107 TORCH_CHECK (
117108 permuted_hidden_states.scalar_type () == hidden_states.scalar_type (),
118- " topk_ids dtype must be same as src_row_id2dst_row_id_map " );
109+ " permuted_hidden_states dtype must be same as hidden_states " );
119110 auto n_token = hidden_states.size (0 );
120111 auto n_hidden = hidden_states.size (1 );
121112 auto stream = at::cuda::getCurrentCUDAStream ().stream ();
122- const int64_t * valid_ptr =
123- get_ptr<int64_t >(expert_first_token_offset) + n_local_expert;
113+ 
// valid_ptr stays null when the caller supplies no expert_first_token_offset.
114+ int64_t const * valid_ptr = nullptr ;
115+ if (expert_first_token_offset.has_value ()) {
// n_local_expert is now derived from the tensor length (size(0) - 1) instead
// of being a parameter; valid_ptr is the get_ptr result advanced by that many
// int64 elements. Presumably this addresses the last entry of the offset
// table -- TODO confirm against get_ptr's definition.
// NOTE(review): the size(0) result is stored in a plain int here.
116+ int n_local_expert = expert_first_token_offset.value ().size (0 ) - 1 ;
117+ valid_ptr =
118+ get_ptr<int64_t >(expert_first_token_offset.value ()) + n_local_expert;
119+ }
120+ 
// Launcher is instantiated with the same scalar_t for both template
// arguments; scalar_t is presumably supplied by MOE_DISPATCH -- TODO confirm.
124121 MOE_DISPATCH (hidden_states.scalar_type (), [&] {
125122 finalizeMoeRoutingKernelLauncher<scalar_t , scalar_t >(
126123 get_ptr<scalar_t >(permuted_hidden_states),
127124 get_ptr<scalar_t >(hidden_states), get_ptr<float >(topk_weights),
128- get_ptr<int >(src_row_id2dst_row_id_map ), get_ptr< int >(topk_ids) ,
129- n_token, n_hidden, topk, valid_ptr, stream);
125+ get_ptr<int >(inv_permuted_idx ), n_token, n_hidden, topk, valid_ptr ,
126+ stream);
130127 });
131128}
132129
0 commit comments