55#include " permute_unpermute_kernels/dispatch.h"
66#include " core/registration.h"
77
8+ // moe_permute kernels require at least CUDA 12.0
9+ #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
10+
811void moe_permute (
912 const torch::Tensor& input, // [n_token, hidden]
1013 const torch::Tensor& topk_weights, // [n_token, topk]
@@ -127,7 +130,45 @@ void moe_unpermute(
127130 });
128131}
129132
133+ #else
134+
135+ void moe_permute (const torch::Tensor& input, const torch::Tensor& topk_weights,
136+ torch::Tensor& topk_ids,
137+ const torch::Tensor& token_expert_indicies,
138+ const std::optional<torch::Tensor>& expert_map,
139+ int64_t n_expert, int64_t n_local_expert, int64_t topk,
140+ const std::optional<int64_t >& align_block_size,
141+ torch::Tensor& permuted_input,
142+ torch::Tensor& expert_first_token_offset,
143+ torch::Tensor& src_row_id2dst_row_id_map,
144+ torch::Tensor& m_indices) {
145+ TORCH_CHECK (false , " moe_unpermute is not supported on CUDA < 12.0" );
146+ }
147+
148+ void moe_unpermute (const torch::Tensor& input,
149+ const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
150+ const torch::Tensor& token_expert_indicies,
151+ const std::optional<torch::Tensor>& expert_map,
152+ int64_t n_expert, int64_t n_local_expert, int64_t topk,
153+ const std::optional<int64_t >& align_block_size,
154+ torch::Tensor& permuted_input,
155+ torch::Tensor& expert_first_token_offset,
156+ torch::Tensor& src_row_id2dst_row_id_map,
157+ torch::Tensor& m_indices) {
158+ TORCH_CHECK (false , " moe_unpermute is not supported on CUDA < 12.0" );
159+ }
160+
161+ #endif
162+
// Reports whether this translation unit was built with the real
// moe_permute / moe_unpermute implementations rather than the failing
// fallbacks.
//
// The result is decided by the preprocessor, using the same condition that
// guards the kernels earlier in this file: true when CUDA_VERSION is defined
// and is at least 12000, false in every other case (including CUDA_VERSION
// being undefined).
bool moe_permute_unpermute_supported() {
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
  return true;
#else
  return false;
#endif
}
170+
130171TORCH_LIBRARY_IMPL_EXPAND (TORCH_EXTENSION_NAME, CUDA, m) {
131172 m.impl (" moe_permute" , &moe_permute);
132173 m.impl (" moe_unpermute" , &moe_unpermute);
133- }
174+ }
0 commit comments