Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions cpp/serve/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class ModelImpl : public ModelObj {
}

NDArray BatchDecode(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids) final {
NVTXScopedRange nvtx_scope("BatchDecode");
NVTXScopedRange nvtx_scope("BatchDecode num_seqs=" + std::to_string(seq_ids.size()));
int num_sequence = seq_ids.size();

CHECK(ft_.decode_func_.defined())
Expand Down Expand Up @@ -395,7 +395,8 @@ class ModelImpl : public ModelObj {

ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states_dref_or_nd,
const std::vector<int64_t>& seq_ids) final {
NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden");
NVTXScopedRange nvtx_scope("BatchDecodeToLastHidden num_seqs=" +
std::to_string(seq_ids.size()));
int num_sequence = seq_ids.size();

CHECK(ft_.decode_to_last_hidden_func_.defined())
Expand Down Expand Up @@ -443,7 +444,6 @@ class ModelImpl : public ModelObj {

NDArray BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) final {
NVTXScopedRange nvtx_scope("BatchVerify");
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
Expand All @@ -452,6 +452,8 @@ class ModelImpl : public ModelObj {
total_length += lengths[i];
}

NVTXScopedRange nvtx_scope("BatchVerify num_tokens=" + std::to_string(total_length));

CHECK(ft_.verify_func_.defined())
<< "`verify_with_embed` function is not found in the model. Please make sure the model is "
"compiled with flag `--sep-embed` and `--enable-batching`";
Expand Down Expand Up @@ -504,14 +506,15 @@ class ModelImpl : public ModelObj {
ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings,
const std::vector<int64_t>& seq_ids,
const std::vector<int>& lengths) final {
NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden");
CHECK(!seq_ids.empty());
CHECK_EQ(seq_ids.size(), lengths.size());
int num_sequences = seq_ids.size();
int total_length = 0;
for (int i = 0; i < num_sequences; ++i) {
total_length += lengths[i];
}
NVTXScopedRange nvtx_scope("BatchVerifyToLastHidden num_tokens=" +
std::to_string(total_length));

CHECK(ft_.verify_to_last_hidden_func_.defined())
<< "`batch_verify_to_last_hidden_states` function is not found in the model.";
Expand Down