1616 * specific language governing permissions and limitations
1717 * under the License.
1818 */
19- #ifndef TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
20- #define TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
19+ #ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_
20+ #define TVM_RUNTIME_RELAX_VM_KV_STATE_H_
2121#include < tvm/runtime/device_api.h>
2222#include < tvm/runtime/logging.h>
2323#include < tvm/runtime/ndarray.h>
2424#include < tvm/runtime/registry.h>
2525
26+ #include " tvm/runtime/object.h"
27+
2628namespace tvm {
2729namespace runtime {
2830namespace relax_vm {
2931
30- /* !
31- * \brief The base class of attention KV cache for efficient
32- * k/v data management and attention computation.
33- */
34- class AttentionKVCache : public Object {
32+ /* ! \brief The base class of attention KV cache and rnn state. */
33+ class KVStateObj : public Object {
3534 public:
36- /* ! \brief Reset the KV cache . */
35+ /* ! \brief Reset the KV State . */
3736 virtual void Clear () = 0;
3837
3938 /* ************* Sequence Management **************/
4039
4140 /* !
42- * \brief Add a new sequence with empty K/V data in the cache.
41+ * \brief Add a new sequence with empty K/V state in the cache.
4342 * Check if the validity of the input sequence id.
4443 * \param seq_id The id of the new sequence to be added.
4544 * \throws Error if the given sequence id is not valid.
4645 */
4746 virtual void AddSequence (int64_t seq_id) = 0;
4847
4948 /* !
50- * \brief Remove a sequence and its K/V data from the KV cache.
49+ * \brief Remove a sequence and its K/V state from the KV cache.
5150 * \param seq_id The sequence to remove from cache.
5251 * \throws Error if the given sequence id is not valid.
5352 */
5453 virtual void RemoveSequence (int64_t seq_id) = 0;
5554
5655 /* !
57- * \brief Fork the K/V data of parent sequence to the child sequence.
58- * After the fork, the child sequence has K/V data of the parent
56+ * \brief Fork the K/V state of parent sequence to the child sequence.
57+ * After the fork, the child sequence has K/V state of the parent
5958 * sequence.
6059 * \param parent_seq_id The parent (source) of the fork.
6160 * \param child_seq_id The child (destination) of the fork.
@@ -73,18 +72,6 @@ class AttentionKVCache : public Object {
7372 */
7473 virtual void PopN (int64_t seq_id, int32_t n) = 0;
7574
76- /* ************* Raw Info Query **************/
77-
78- /* !
79- * \brief Get the number of available pages in the KV cache.
80- * When the underlying KV cache implementation is not
81- * paged KV cache, the function falls back to return the
82- * number of remaining size (in terms of number of tokens).
83- */
84- virtual int32_t GetNumAvailablePages () const = 0;
85-
86- /* ************* Attention **************/
87-
8875 /* !
8976 * \brief Mark the start of the forward function with the ids of
9077 * the sequences and the sequence length to forward for each
@@ -109,6 +96,34 @@ class AttentionKVCache : public Object {
10996 */
11097 virtual void EndForward () = 0;
11198
99+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic ;
100+ static constexpr const char * _type_key = " relax.vm.KVState" ;
101+ TVM_DECLARE_BASE_OBJECT_INFO (KVStateObj, Object)
102+ };
103+
104+ class KVState : public ObjectRef {
105+ public:
106+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (KVState, ObjectRef, KVStateObj);
107+ };
108+
109+ /* !
110+ * \brief The base class of attention KV cache for efficient
111+ * k/v data management and attention computation.
112+ */
113+ class AttentionKVCacheObj : public KVStateObj {
114+ public:
115+ /* ************* Raw Info Query **************/
116+
117+ /* !
118+ * \brief Get the number of available pages in the KV cache.
119+ * When the underlying KV cache implementation is not
120+ * paged KV cache, the function falls back to return the
121+ * number of remaining size (in terms of number of tokens).
122+ */
123+ virtual int32_t GetNumAvailablePages () const = 0;
124+
125+ /* ************* Attention **************/
126+
112127 /* !
113128 * \brief Compute attention with the given Q/K/V data at the specified
114129 * layer with regard to the previously reserved append lengths.
@@ -197,10 +212,63 @@ class AttentionKVCache : public Object {
197212 * \param v_data The V data to set in layout elaborated above.
198213 */
199214 virtual void DebugSetKV (int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) = 0;
215+
216+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic ;
217+ static constexpr const char * _type_key = " relax.vm.AttentionKVCache" ;
218+ TVM_DECLARE_BASE_OBJECT_INFO (AttentionKVCacheObj, KVStateObj);
219+ };
220+
221+ class AttentionKVCache : public KVState {
222+ public:
223+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (AttentionKVCache, KVState, AttentionKVCacheObj);
224+ };
225+
226+ /* !
227+ * \brief The base class of RNN State for efficient
228+ * State data management and attention computation.
229+ */
230+ class RNNStateObj : public KVStateObj {
231+ public:
232+ /* ************* Interaction **************/
233+ /* !
234+ * \brief Get the State data for the specified sequence.
235+ * \param layer_id The model layer where the state is set.
236+ * \param state_id The state id within the layer.
237+ * \param o_data The output data to be fetched.
238+ * \return The array of State data, each element corresponds to a state.
239+ * \throws Error if the given sequence id is not valid.
240+ */
241+ virtual void Get (int64_t layer_id, int64_t state_id, NDArray o_data) = 0;
242+
243+ /* !
244+ * \brief Set the State data for the specified sequence.
245+ * \param layer_id The model layer where the state is set.
246+ * \param state_id The state id within the layer.
247+ * \param data The data to be set.
248+ * \throws Error if the given sequence id is not valid.
249+ */
250+ virtual void Set (int64_t layer_id, int64_t state_id, NDArray data) = 0;
251+
252+ /* !
253+ * \brief Fetch the compact rnn state data of the given sequence.
254+ * \param layer_id The model layer where the state is set.
255+ * \param state_id The state id within the layer.
256+ * \param seq_id The sequence whose state data is to be fetched.
257+ */
258+ virtual NDArray DebugGet (int64_t layer_id, int64_t state_id, int64_t seq_id) = 0;
259+
260+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic ;
261+ static constexpr const char * _type_key = " relax.vm.RNNState" ;
262+ TVM_DECLARE_BASE_OBJECT_INFO (RNNStateObj, KVStateObj);
263+ };
264+
265+ class RNNState : public KVState {
266+ public:
267+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS (RNNState, KVState, RNNStateObj);
200268};
201269
202270} // namespace relax_vm
203271} // namespace runtime
204272} // namespace tvm
205273
206- #endif // TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
274+ #endif // TVM_RUNTIME_RELAX_VM_KV_STATE_H_
0 commit comments