Skip to content

Commit 3ef478b

Browse files
author
Siyuan Feng
authored
[Relax][Runtime] RNNState for Space State Models (#16568)
* [Relax][Runtime] RNNState for Space State Models This commit adds the RNNState class to the Relax VM, similar to the PagedKVCache, for space state models like RWKV and mamba * refactor
1 parent d91fe45 commit 3ef478b

File tree

6 files changed

+947
-52
lines changed

6 files changed

+947
-52
lines changed

src/runtime/relax_vm/kv_state.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "kv_state.h"
21+
22+
#include <utility>
23+
24+
namespace tvm {
25+
namespace runtime {
26+
namespace relax_vm {
27+
28+
// Register Object Type
29+
TVM_REGISTER_OBJECT_TYPE(KVStateObj);
30+
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
31+
TVM_REGISTER_OBJECT_TYPE(RNNStateObj);
32+
33+
// KV State base methods
34+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method<KVState>(&KVStateObj::Clear);
35+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence")
36+
.set_body_method<KVState>(&KVStateObj::AddSequence);
37+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence")
38+
.set_body_method<KVState>(&KVStateObj::RemoveSequence);
39+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
40+
.set_body_method<KVState>(&KVStateObj::ForkSequence);
41+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
42+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
43+
.set_body_method<KVState>(&KVStateObj::BeginForward);
44+
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
45+
.set_body_method<KVState>(&KVStateObj::EndForward);
46+
47+
// Attention KV Cache methods
48+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
49+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
50+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
51+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
52+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
53+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
54+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention")
55+
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
56+
double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
57+
NDArray v_data, NDArray o_data) {
58+
kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
59+
NullOpt, std::move(o_data), attn_score_scaling_factor);
60+
});
61+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
62+
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
63+
double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
64+
kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
65+
attn_score_scaling_factor);
66+
});
67+
68+
// RNN State methods
69+
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
70+
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
71+
.set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) {
72+
state->Set(layer_id, state_id, data);
73+
return state;
74+
});
75+
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get")
76+
.set_body_method<RNNState>(&RNNStateObj::DebugGet);
77+
78+
} // namespace relax_vm
79+
} // namespace runtime
80+
} // namespace tvm

src/runtime/relax_vm/kv_cache.h renamed to src/runtime/relax_vm/kv_state.h

Lines changed: 93 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,46 +16,45 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#ifndef TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
20-
#define TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
19+
#ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_
20+
#define TVM_RUNTIME_RELAX_VM_KV_STATE_H_
2121
#include <tvm/runtime/device_api.h>
2222
#include <tvm/runtime/logging.h>
2323
#include <tvm/runtime/ndarray.h>
2424
#include <tvm/runtime/registry.h>
2525

26+
#include "tvm/runtime/object.h"
27+
2628
namespace tvm {
2729
namespace runtime {
2830
namespace relax_vm {
2931

30-
/*!
31-
* \brief The base class of attention KV cache for efficient
32-
* k/v data management and attention computation.
33-
*/
34-
class AttentionKVCache : public Object {
32+
/*! \brief The base class of attention KV cache and rnn state. */
33+
class KVStateObj : public Object {
3534
public:
36-
/*! \brief Reset the KV cache. */
35+
/*! \brief Reset the KV State. */
3736
virtual void Clear() = 0;
3837

3938
/************** Sequence Management **************/
4039

4140
/*!
42-
* \brief Add a new sequence with empty K/V data in the cache.
41+
* \brief Add a new sequence with empty K/V state in the cache.
4342
* Check if the validity of the input sequence id.
4443
* \param seq_id The id of the new sequence to be added.
4544
* \throws Error if the given sequence id is not valid.
4645
*/
4746
virtual void AddSequence(int64_t seq_id) = 0;
4847

4948
/*!
50-
* \brief Remove a sequence and its K/V data from the KV cache.
49+
* \brief Remove a sequence and its K/V state from the KV cache.
5150
* \param seq_id The sequence to remove from cache.
5251
* \throws Error if the given sequence id is not valid.
5352
*/
5453
virtual void RemoveSequence(int64_t seq_id) = 0;
5554

5655
/*!
57-
* \brief Fork the K/V data of parent sequence to the child sequence.
58-
* After the fork, the child sequence has K/V data of the parent
56+
* \brief Fork the K/V state of parent sequence to the child sequence.
57+
* After the fork, the child sequence has K/V state of the parent
5958
* sequence.
6059
* \param parent_seq_id The parent (source) of the fork.
6160
* \param child_seq_id The child (destination) of the fork.
@@ -73,18 +72,6 @@ class AttentionKVCache : public Object {
7372
*/
7473
virtual void PopN(int64_t seq_id, int32_t n) = 0;
7574

76-
/************** Raw Info Query **************/
77-
78-
/*!
79-
* \brief Get the number of available pages in the KV cache.
80-
* When the underlying KV cache implementation is not
81-
* paged KV cache, the function falls back to return the
82-
* number of remaining size (in terms of number of tokens).
83-
*/
84-
virtual int32_t GetNumAvailablePages() const = 0;
85-
86-
/************** Attention **************/
87-
8875
/*!
8976
* \brief Mark the start of the forward function with the ids of
9077
* the sequences and the sequence length to forward for each
@@ -109,6 +96,34 @@ class AttentionKVCache : public Object {
10996
*/
11097
virtual void EndForward() = 0;
11198

99+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
100+
static constexpr const char* _type_key = "relax.vm.KVState";
101+
TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object)
102+
};
103+
104+
class KVState : public ObjectRef {
105+
public:
106+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(KVState, ObjectRef, KVStateObj);
107+
};
108+
109+
/*!
110+
* \brief The base class of attention KV cache for efficient
111+
* k/v data management and attention computation.
112+
*/
113+
class AttentionKVCacheObj : public KVStateObj {
114+
public:
115+
/************** Raw Info Query **************/
116+
117+
/*!
118+
* \brief Get the number of available pages in the KV cache.
119+
* When the underlying KV cache implementation is not
120+
* paged KV cache, the function falls back to return the
121+
* number of remaining size (in terms of number of tokens).
122+
*/
123+
virtual int32_t GetNumAvailablePages() const = 0;
124+
125+
/************** Attention **************/
126+
112127
/*!
113128
* \brief Compute attention with the given Q/K/V data at the specified
114129
* layer with regard to the previously reserved append lengths.
@@ -197,10 +212,63 @@ class AttentionKVCache : public Object {
197212
* \param v_data The V data to set in layout elaborated above.
198213
*/
199214
virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) = 0;
215+
216+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
217+
static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
218+
TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj);
219+
};
220+
221+
class AttentionKVCache : public KVState {
222+
public:
223+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, KVState, AttentionKVCacheObj);
224+
};
225+
226+
/*!
227+
* \brief The base class of RNN State for efficient
228+
* State data management and attention computation.
229+
*/
230+
class RNNStateObj : public KVStateObj {
231+
public:
232+
/************** Interaction **************/
233+
/*!
234+
* \brief Get the State data for the specified sequence.
235+
* \param layer_id The model layer where the state is set.
236+
* \param state_id The state id within the layer.
237+
* \param o_data The output data to be fetched.
238+
* \return The array of State data, each element corresponds to a state.
239+
* \throws Error if the given sequence id is not valid.
240+
*/
241+
virtual void Get(int64_t layer_id, int64_t state_id, NDArray o_data) = 0;
242+
243+
/*!
244+
* \brief Set the State data for the specified sequence.
245+
* \param layer_id The model layer where the state is set.
246+
* \param state_id The state id within the layer.
247+
* \param data The data to be set.
248+
* \throws Error if the given sequence id is not valid.
249+
*/
250+
virtual void Set(int64_t layer_id, int64_t state_id, NDArray data) = 0;
251+
252+
/*!
253+
* \brief Fetch the compact rnn state data of the given sequence.
254+
* \param layer_id The model layer where the state is set.
255+
* \param state_id The state id within the layer.
256+
* \param seq_id The sequence whose state data is to be fetched.
257+
*/
258+
virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) = 0;
259+
260+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
261+
static constexpr const char* _type_key = "relax.vm.RNNState";
262+
TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj);
263+
};
264+
265+
class RNNState : public KVState {
266+
public:
267+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RNNState, KVState, RNNStateObj);
200268
};
201269

202270
} // namespace relax_vm
203271
} // namespace runtime
204272
} // namespace tvm
205273

206-
#endif // TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
274+
#endif // TVM_RUNTIME_RELAX_VM_KV_STATE_H_

src/runtime/relax_vm/lm_support.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ namespace relax_vm {
5959
/*!
6060
* \brief An object representing an attention kv cache.
6161
*/
62-
class AttentionKVCacheObj : public Object {
62+
class AttentionKVCacheLegacyObj : public Object {
6363
public:
6464
/*!
6565
* \brief Underlying support data.
@@ -227,7 +227,7 @@ class AttentionKVCacheObj : public Object {
227227

228228
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
229229
static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
230-
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
230+
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object);
231231
};
232232

233233
/*! \brief reference to closure. */
@@ -239,7 +239,7 @@ class AttentionKVCacheLegacy : public ObjectRef {
239239
*/
240240
static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple reserve_shape,
241241
int init_fill_count) {
242-
auto n = make_object<AttentionKVCacheObj>();
242+
auto n = make_object<AttentionKVCacheLegacyObj>();
243243
n->data = NDArray::Empty(reserve_shape, init_data->dtype, init_data->device);
244244
n->fill_count = 0;
245245
n->Append(init_data);
@@ -250,10 +250,11 @@ class AttentionKVCacheLegacy : public ObjectRef {
250250
return AttentionKVCacheLegacy(n);
251251
}
252252

253-
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef, AttentionKVCacheObj);
253+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
254+
AttentionKVCacheLegacyObj);
254255
};
255256

256-
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
257+
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj);
257258

258259
//-------------------------------------------------
259260
// Register runtime functions

0 commit comments

Comments
 (0)