Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "kv_state.h"

#include <utility>

namespace tvm {
namespace runtime {
namespace relax_vm {

// Register Object Type
TVM_REGISTER_OBJECT_TYPE(KVStateObj);
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
TVM_REGISTER_OBJECT_TYPE(RNNStateObj);

// KV State base methods
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method<KVState>(&KVStateObj::Clear);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence")
.set_body_method<KVState>(&KVStateObj::AddSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence")
.set_body_method<KVState>(&KVStateObj::RemoveSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
.set_body_method<KVState>(&KVStateObj::ForkSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
.set_body_method<KVState>(&KVStateObj::BeginForward);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
.set_body_method<KVState>(&KVStateObj::EndForward);

// Attention KV Cache methods
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
NDArray v_data, NDArray o_data) {
kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
NullOpt, std::move(o_data), attn_score_scaling_factor);
});
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
attn_score_scaling_factor);
});

// RNN State methods
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
.set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) {
state->Set(layer_id, state_id, data);
return state;
});
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get")
.set_body_method<RNNState>(&RNNStateObj::DebugGet);

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
118 changes: 93 additions & 25 deletions src/runtime/relax_vm/kv_cache.h → src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,45 @@
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
#define TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
#ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_
#define TVM_RUNTIME_RELAX_VM_KV_STATE_H_
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include "tvm/runtime/object.h"

namespace tvm {
namespace runtime {
namespace relax_vm {

/*!
* \brief The base class of attention KV cache for efficient
* k/v data management and attention computation.
*/
class AttentionKVCache : public Object {
/*! \brief The base class of attention KV cache and rnn state. */
class KVStateObj : public Object {
public:
/*! \brief Reset the KV cache. */
/*! \brief Reset the KV State. */
virtual void Clear() = 0;

/************** Sequence Management **************/

/*!
* \brief Add a new sequence with empty K/V data in the cache.
* \brief Add a new sequence with empty K/V state in the cache.
* Check if the validity of the input sequence id.
* \param seq_id The id of the new sequence to be added.
* \throws Error if the given sequence id is not valid.
*/
virtual void AddSequence(int64_t seq_id) = 0;

/*!
* \brief Remove a sequence and its K/V data from the KV cache.
* \brief Remove a sequence and its K/V state from the KV cache.
* \param seq_id The sequence to remove from cache.
* \throws Error if the given sequence id is not valid.
*/
virtual void RemoveSequence(int64_t seq_id) = 0;

/*!
* \brief Fork the K/V data of parent sequence to the child sequence.
* After the fork, the child sequence has K/V data of the parent
* \brief Fork the K/V state of parent sequence to the child sequence.
* After the fork, the child sequence has K/V state of the parent
* sequence.
* \param parent_seq_id The parent (source) of the fork.
* \param child_seq_id The child (destination) of the fork.
Expand All @@ -73,18 +72,6 @@ class AttentionKVCache : public Object {
*/
virtual void PopN(int64_t seq_id, int32_t n) = 0;

/************** Raw Info Query **************/

/*!
* \brief Get the number of available pages in the KV cache.
* When the underlying KV cache implementation is not
* paged KV cache, the function falls back to return the
* number of remaining size (in terms of number of tokens).
*/
virtual int32_t GetNumAvailablePages() const = 0;

/************** Attention **************/

/*!
* \brief Mark the start of the forward function with the ids of
* the sequences and the sequence length to forward for each
Expand All @@ -109,6 +96,34 @@ class AttentionKVCache : public Object {
*/
virtual void EndForward() = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.KVState";
TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object)
};

class KVState : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(KVState, ObjectRef, KVStateObj);
};

/*!
* \brief The base class of attention KV cache for efficient
* k/v data management and attention computation.
*/
class AttentionKVCacheObj : public KVStateObj {
public:
/************** Raw Info Query **************/

/*!
* \brief Get the number of available pages in the KV cache.
* When the underlying KV cache implementation is not
* paged KV cache, the function falls back to return the
* number of remaining size (in terms of number of tokens).
*/
virtual int32_t GetNumAvailablePages() const = 0;

/************** Attention **************/

/*!
* \brief Compute attention with the given Q/K/V data at the specified
* layer with regard to the previously reserved append lengths.
Expand Down Expand Up @@ -197,10 +212,63 @@ class AttentionKVCache : public Object {
* \param v_data The V data to set in layout elaborated above.
*/
virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj);
};

class AttentionKVCache : public KVState {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, KVState, AttentionKVCacheObj);
};

/*!
* \brief The base class of RNN State for efficient
* State data management and attention computation.
*/
class RNNStateObj : public KVStateObj {
public:
/************** Interaction **************/
/*!
* \brief Get the State data for the specified sequence.
* \param layer_id The model layer where the state is set.
* \param state_id The state id within the layer.
* \param o_data The output data to be fetched.
* \return The array of State data, each element corresponds to a state.
* \throws Error if the given sequence id is not valid.
*/
virtual void Get(int64_t layer_id, int64_t state_id, NDArray o_data) = 0;

/*!
* \brief Set the State data for the specified sequence.
* \param layer_id The model layer where the state is set.
* \param state_id The state id within the layer.
* \param data The data to be set.
* \throws Error if the given sequence id is not valid.
*/
virtual void Set(int64_t layer_id, int64_t state_id, NDArray data) = 0;

/*!
* \brief Fetch the compact rnn state data of the given sequence.
* \param layer_id The model layer where the state is set.
* \param state_id The state id within the layer.
* \param seq_id The sequence whose state data is to be fetched.
*/
virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.RNNState";
TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj);
};

class RNNState : public KVState {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RNNState, KVState, RNNStateObj);
};

} // namespace relax_vm
} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
#endif // TVM_RUNTIME_RELAX_VM_KV_STATE_H_
11 changes: 6 additions & 5 deletions src/runtime/relax_vm/lm_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace relax_vm {
/*!
* \brief An object representing an attention kv cache.
*/
class AttentionKVCacheObj : public Object {
class AttentionKVCacheLegacyObj : public Object {
public:
/*!
* \brief Underlying support data.
Expand Down Expand Up @@ -227,7 +227,7 @@ class AttentionKVCacheObj : public Object {

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object);
};

/*! \brief reference to closure. */
Expand All @@ -239,7 +239,7 @@ class AttentionKVCacheLegacy : public ObjectRef {
*/
static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple reserve_shape,
int init_fill_count) {
auto n = make_object<AttentionKVCacheObj>();
auto n = make_object<AttentionKVCacheLegacyObj>();
n->data = NDArray::Empty(reserve_shape, init_data->dtype, init_data->device);
n->fill_count = 0;
n->Append(init_data);
Expand All @@ -250,10 +250,11 @@ class AttentionKVCacheLegacy : public ObjectRef {
return AttentionKVCacheLegacy(n);
}

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef, AttentionKVCacheObj);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
AttentionKVCacheLegacyObj);
};

TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj);

//-------------------------------------------------
// Register runtime functions
Expand Down
Loading