Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions cpp/llm_chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ class LLMChat {
/*! \brief reset the runtime stats. */
void ResetRuntimeStats() {
this->prefill_total_tokens = 0;
this->decode_total_tokens = 0;
this->decode_total_tokens = -1;
this->embed_total_time = 0;
this->prefill_total_time = 0;
this->decode_total_time = 0;
Expand Down Expand Up @@ -1031,8 +1031,8 @@ class LLMChat {
int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config);

auto tend = std::chrono::high_resolution_clock::now();

this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
if (this->decode_total_tokens >= 0)
this->decode_total_time += static_cast<double>((tend - tstart).count()) / 1e9;
this->decode_total_tokens += 1;
this->ProcessNextToken(next_token, generation_config);
}
Expand Down Expand Up @@ -1223,14 +1223,16 @@ class LLMChat {
if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_frequency_penalty);
this->UpdateLogitsOrProbOnGPUSync(logits_on_device);
if (gen_temperature >= 1e-6f) {
this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
}
} else if (gen_repetition_penalty != 1.0f) {
this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty);
this->UpdateLogitsOrProbOnGPUSync(logits_on_device);
if (gen_temperature >= 1e-6f) {
this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature);
this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
}
} else {
if (gen_temperature < 1e-6f) {
Expand Down Expand Up @@ -1505,6 +1507,12 @@ class LLMChat {
TVMSynchronize(device_.device_type, device_.device_id, nullptr);
}

/*!
 * \brief Copy the contents of logits_on_cpu_ into the given array, then
 *        synchronize on device_ before returning.
 * \param logits_or_prob Destination array that receives the data held in
 *        logits_on_cpu_ (presumably a device-resident buffer -- confirm
 *        against callers).
 */
void UpdateLogitsOrProbOnGPUSync(NDArray logits_or_prob) {
  // Overwrite the destination with the data currently held in logits_on_cpu_.
  logits_or_prob.CopyFrom(logits_on_cpu_);

  // Synchronize on device_ (null stream handle) so callers see a finished copy.
  const auto& target_device = device_;
  TVMSynchronize(target_device.device_type, target_device.device_id, nullptr);
}

// Clear kv cache
void ResetKVCache() {
ft_.reset_kv_cache_func_(kv_cache_);
Expand Down Expand Up @@ -1547,7 +1555,7 @@ class LLMChat {
double decode_total_time = 0;
double sample_total_time = 0;
double prefill_total_time = 0;
int64_t decode_total_tokens = 0;
int64_t decode_total_tokens = -1;
int64_t prefill_total_tokens = 0;
//----------------------------
// Conversation
Expand Down