diff --git a/src/configs.js b/src/configs.js index c2eef326d..0862f7adf 100644 --- a/src/configs.js +++ b/src/configs.js @@ -264,6 +264,9 @@ function getNormalizedConfig(config) { */ export function getCacheShapes(config, options) { if (config.model_type === 'lfm2') { + const pkv_prefix = options?.prefix ?? 'past_key_values'; + const conv_prefix = pkv_prefix === 'present' ? 'present' : 'past'; + // Custom caching mechanism for LFM2 /** @type {Record} */ const cache_values = {}; @@ -274,10 +277,10 @@ export function getCacheShapes(config, options) { for (let i = 0; i < layer_types.length; ++i) { if (layer_types[i] === 'full_attention') { for (const kv of ['key', 'value']) { - cache_values[`past_key_values.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, head_dim]; + cache_values[`${pkv_prefix}.${i}.${kv}`] = [batch_size, num_key_value_heads, 0, head_dim]; } } else if (layer_types[i] === 'conv') { - cache_values[`past_conv.${i}`] = [batch_size, hidden_size, conv_L_cache]; + cache_values[`${conv_prefix}_conv.${i}`] = [batch_size, hidden_size, conv_L_cache]; } else { throw new Error(`Unsupported layer type: ${layer_types[i]}`); }