File tree Expand file tree Collapse file tree 2 files changed +18
-17
lines changed
cpp/tensorrt_llm/executor Expand file tree Collapse file tree 2 files changed +18
-17
lines changed Original file line number Diff line number Diff line change @@ -1205,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig)
12051205 totalSize += su::serializedSize (kvCacheConfig.getSecondaryOffloadMinPriority ());
12061206 totalSize += su::serializedSize (kvCacheConfig.getEventBufferMaxSize ());
12071207 totalSize += su::serializedSize (kvCacheConfig.getUseUvm ());
1208+ totalSize += su::serializedSize (kvCacheConfig.getAttentionDpEventsGatherPeriodMs ());
12081209 return totalSize;
12091210}
12101211
Original file line number Diff line number Diff line change @@ -307,6 +307,22 @@ struct get_variant_alternative_type
307307 }
308308};
309309
310+ template <typename T>
311+ T deserialize (std::istream& is);
312+
313+ // Helper function to deserialize variant by index using template recursion
314+ template <typename T, std::size_t ... Is>
315+ T deserializeVariantByIndex (std::istream& is, std::size_t index, std::index_sequence<Is...> /* indices*/ )
316+ {
317+ T result;
318+ bool found = ((Is == index ? (result = deserialize<std::variant_alternative_t <Is, T>>(is), true ) : false ) || ...);
319+ if (!found)
320+ {
321+ TLLM_THROW (" Invalid variant index during deserialization: " + std::to_string (index));
322+ }
323+ return result;
324+ }
325+
310326// Deserialize
311327template <typename T>
312328T deserialize (std::istream& is)
@@ -595,23 +611,7 @@ T deserialize(std::istream& is)
595611 std::size_t index = 0 ;
596612 is.read (reinterpret_cast <char *>(&index), sizeof (index));
597613
598- // TODO: Is there a better way to implement this?
599- T data;
600- if (index == 0 )
601- {
602- using U = std::variant_alternative_t <0 , T>;
603- data = deserialize<U>(is);
604- }
605- else if (index == 1 )
606- {
607- using U = std::variant_alternative_t <1 , T>;
608- data = deserialize<U>(is);
609- }
610- else
611- {
612- TLLM_THROW (" Serialization of variant of size > 2 is not supported." );
613- }
614- return data;
614+ return deserializeVariantByIndex<T>(is, index, std::make_index_sequence<std::variant_size_v<T>>{});
615615 }
616616 else
617617 {
You can’t perform that action at this time.
0 commit comments