Skip to content

Commit b01cfab

Browse files
authored
[Model] Removing unnecessary reshapes in get_logits (#2314)
1 parent ea391de commit b01cfab

File tree

6 files changed

+14
-34
lines changed

6 files changed

+14
-34
lines changed

cpp/serve/engine_actions/eagle_batch_draft.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,20 +116,16 @@ class EagleBatchDraftActionObj : public EngineActionObj {
116116
request_internal_ids);
117117
NDArray logits;
118118
if (models_[model_id]->CanGetLogits()) {
119-
logits = models_[model_id]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
120-
/*seq_len*/ 1);
119+
logits = models_[model_id]->GetLogits(hidden_states);
121120
} else {
122121
// - Use base model's head.
123-
logits =
124-
models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
122+
logits = models_[0]->GetLogits(hidden_states);
125123
}
126124
RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode");
127-
ICHECK_EQ(logits->ndim, 3);
125+
ICHECK_EQ(logits->ndim, 2);
128126
ICHECK_EQ(logits->shape[0], num_rsentries);
129-
ICHECK_EQ(logits->shape[1], 1);
130127

131128
// - Update logits.
132-
logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype);
133129
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);
134130

135131
// - Compute probability distributions.

cpp/serve/engine_actions/eagle_batch_verify.cc

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,16 +114,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
114114
RECORD_EVENT(trace_recorder_, request_ids, "start verify");
115115
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
116116
embeddings, request_internal_ids, verify_lengths);
117-
NDArray logits =
118-
models_[verify_model_id_]->GetLogits(hidden_states, 1, cum_verify_lengths[num_rsentries]);
117+
NDArray logits = models_[verify_model_id_]->GetLogits(hidden_states);
119118
RECORD_EVENT(trace_recorder_, request_ids, "finish verify");
120-
ICHECK_EQ(logits->ndim, 3);
121-
ICHECK_EQ(logits->shape[0], 1);
122-
ICHECK_EQ(logits->shape[1], cum_verify_lengths[num_rsentries]);
119+
ICHECK_EQ(logits->ndim, 2);
120+
ICHECK_EQ(logits->shape[0], cum_verify_lengths.back());
123121

124122
// - Update logits.
125-
logits =
126-
logits.CreateView({cum_verify_lengths[num_rsentries], logits->shape[2]}, logits->dtype);
127123
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates,
128124
request_ids, &cum_verify_lengths, &draft_output_tokens);
129125

@@ -273,19 +269,16 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
273269
fused_embedding_hidden_states, request_internal_ids);
274270

275271
if (models_[draft_model_id_]->CanGetLogits()) {
276-
logits = models_[draft_model_id_]->GetLogits(hidden_states, /*batch_size*/ num_rsentries,
277-
/*seq_len*/ 1);
272+
logits = models_[draft_model_id_]->GetLogits(hidden_states);
278273
} else {
279274
// - Use base model's head.
280-
logits = models_[0]->GetLogits(hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);
275+
logits = models_[0]->GetLogits(hidden_states);
281276
}
282277
RECORD_EVENT(trace_recorder_, request_ids, "finish proposal decode");
283-
ICHECK_EQ(logits->ndim, 3);
278+
ICHECK_EQ(logits->ndim, 2);
284279
ICHECK_EQ(logits->shape[0], num_rsentries);
285-
ICHECK_EQ(logits->shape[1], 1);
286280

287281
// - Update logits.
288-
logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype);
289282
logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);
290283

291284
// - Compute probability distributions.

cpp/serve/engine_actions/eagle_new_request_prefill.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
183183
hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(
184184
hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);
185185
// logits_for_sample: (b * s, v)
186-
logits_for_sample =
187-
models_[sample_model_id]->GetLogits(hidden_states_for_sample, 1, num_rsentries);
186+
logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample);
188187
// - Update logits.
189188
ICHECK(logits_for_sample.defined());
190189
Array<GenerationConfig> generation_cfg;
@@ -195,8 +194,6 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
195194
generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg);
196195
mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[sample_model_id]);
197196
}
198-
logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]},
199-
logits_for_sample->dtype);
200197
logit_processor_->InplaceUpdateLogits(logits_for_sample, generation_cfg,
201198
mstates_for_logitproc, request_ids);
202199

cpp/serve/model.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class ModelImpl : public ModelObj {
125125
return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined();
126126
}
127127

128-
NDArray GetLogits(const ObjectRef& hidden_states, int batch_size, int seq_len) final {
128+
NDArray GetLogits(const ObjectRef& hidden_states) final {
129129
NVTXScopedRange nvtx_scope("GetLogits");
130130
CHECK(ft_.get_logits_func_.defined()) << "`get_logits` function is not found in the model.";
131131

@@ -139,18 +139,14 @@ class ModelImpl : public ModelObj {
139139
if (trace_enabled_) {
140140
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
141141
}
142-
143142
NDArray logits{nullptr};
144143
if (ft_.use_disco) {
145144
logits = Downcast<DRef>(ret)->DebugGetFromRemote(0);
146145
} else {
147146
logits = Downcast<NDArray>(ret);
148147
}
149-
CHECK(logits.defined());
150148
// logits: (b * s, v)
151-
ICHECK_EQ(logits->ndim, 2);
152-
ICHECK_EQ(logits->shape[0], batch_size * seq_len);
153-
return logits.CreateView({batch_size, seq_len, logits->shape[1]}, logits->dtype);
149+
return logits;
154150
}
155151

156152
ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states,

cpp/serve/model.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,9 @@ class ModelObj : public Object {
135135
/*!
136136
* \brief Compute logits for last hidden_states.
137137
* \param last_hidden_states The last hidden_states to compute logits for.
138-
* \param batch_size The batch size of last_hidden_states
139-
* \param seq_len The length of tokens in last_hidden_states
140138
* \return The computed logits.
141139
*/
142-
virtual NDArray GetLogits(const ObjectRef& last_hidden_states, int batch_size, int seq_len) = 0;
140+
virtual NDArray GetLogits(const ObjectRef& last_hidden_states) = 0;
143141

144142
/*!
145143
* \brief Batch prefill function. Embedding in, logits out.

python/mlc_llm/model/llama/llama_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def get_default_spec(self):
352352
},
353353
},
354354
"get_logits": {
355-
"hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype),
355+
"hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype),
356356
"$": {
357357
"param_mode": "packed",
358358
"effect_mode": "none",

0 commit comments

Comments
 (0)