@@ -84,27 +84,21 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m)
8484 .def (" check_gen_transfer_status" , &BaseCacheTransceiver::checkGenTransferStatus)
8585 .def (" check_gen_transfer_complete" , &BaseCacheTransceiver::checkGenTransferComplete);
8686
87- nb::enum_<tb::CacheTransceiver::CommType>(m, " CommType" )
88- .value (" UNKNOWN" , tb::CacheTransceiver::CommType::UNKNOWN)
89- .value (" MPI" , tb::CacheTransceiver::CommType::MPI)
90- .value (" UCX" , tb::CacheTransceiver::CommType::UCX)
91- .value (" NIXL" , tb::CacheTransceiver::CommType::NIXL);
92-
9387 nb::enum_<executor::kv_cache::CacheState::AttentionType>(m, " AttentionType" )
9488 .value (" DEFAULT" , executor::kv_cache::CacheState::AttentionType::kDEFAULT )
9589 .value (" MLA" , executor::kv_cache::CacheState::AttentionType::kMLA );
9690
9791 nb::class_<tb::CacheTransceiver, tb::BaseCacheTransceiver>(m, " CacheTransceiver" )
98- .def (nb::init<tb::kv_cache_manager::BaseKVCacheManager*, tb::CacheTransceiver::CommType ,
99- std::vector<SizeType32>, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType,
100- executor::kv_cache::CacheState::AttentionType, std::optional<executor::CacheTransceiverConfig>>(),
101- nb::arg (" cache_manager" ), nb::arg (" comm_type " ), nb::arg ( " num_kv_heads_per_layer" ), nb::arg (" size_per_head" ),
92+ .def (nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::vector<SizeType32>, SizeType32, SizeType32 ,
93+ runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType ,
94+ std::optional<executor::CacheTransceiverConfig>>(),
95+ nb::arg (" cache_manager" ), nb::arg (" num_kv_heads_per_layer" ), nb::arg (" size_per_head" ),
10296 nb::arg (" tokens_per_block" ), nb::arg (" world_config" ), nb::arg (" dtype" ), nb::arg (" attention_type" ),
10397 nb::arg (" cache_transceiver_config" ) = std::nullopt );
10498
10599 nb::class_<tb::kv_cache_manager::CacheTransBufferManager>(m, " CacheTransBufferManager" )
106100 .def (nb::init<tb::kv_cache_manager::BaseKVCacheManager*, std::optional<size_t >>(), nb::arg (" cache_manager" ),
107101 nb::arg (" max_num_tokens" ) = std::nullopt )
108102 .def_static (" pre_alloc_buffer_size" , &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize,
109- nb::arg (" max_num_tokens " ) = std:: nullopt );
103+ nb::arg (" cache_size_bytes_per_token_per_window " ), nb::arg ( " cache_transceiver_config " ) = nb::none () );
110104}
0 commit comments