6 changes: 5 additions & 1 deletion onnxruntime/core/providers/openvino/ov_interface.cc
@@ -361,7 +361,11 @@ void OVInferRequest::Infer() {
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
: OVInferRequest(std::move(infer_request)), target_device(device) {
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
if (gpu_or_npu) {

// Enable prefill_use_full_chat_history only for specific inputs and data types:
// the model must expose an "input_ids" tensor whose element type is not int64.
auto input_ids_opt = FindTensor("input_ids");
if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() != ov::element::i64) {
prefill_use_full_chat_history = true;
}
}
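
For context, the new guard assumes FindTensor returns an optional-like handle; below is a minimal standalone sketch of the same condition (the helper name ShouldUseFullChatHistory and the std::optional<ov::Tensor> return type are assumptions for illustration, not part of the change):

#include <optional>
#include <string>
#include <openvino/openvino.hpp>

// Hypothetical helper mirroring the guard above: enable full-chat-history
// prefill only on GPU/NPU, and only when an "input_ids" tensor exists whose
// element type is not int64.
static bool ShouldUseFullChatHistory(const std::string& device,
                                     const std::optional<ov::Tensor>& input_ids) {
  const bool gpu_or_npu = device.find("NPU") != std::string::npos ||
                          device.find("GPU") != std::string::npos;
  return gpu_or_npu && input_ids.has_value() &&
         input_ids->get_element_type() != ov::element::i64;
}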
32 changes: 21 additions & 11 deletions onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -72,6 +72,14 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
main_input_name = "input_ids";
}

if (ModelHasInputOutputNames(ov_model, "input_hidden_states")) {
main_input_name = "input_hidden_states";
}

if (ModelHasInputOutputNames(ov_model, "/model/embed_tokens/Gather_output_0")) {
main_input_name = "/model/embed_tokens/Gather_output_0";
}

auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];

auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
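
The added checks give main_input_name a priority order in which later matches override earlier ones, so /model/embed_tokens/Gather_output_0 wins when present. Equivalent in effect to the following candidate scan (a sketch only, reusing the ModelHasInputOutputNames predicate from this file and leaving the pre-existing default for main_input_name untouched):

// Candidates in the same order as the if-statements above; the last name
// present in the model wins because there is no early exit.
const char* candidates[] = {"input_ids", "input_hidden_states",
                            "/model/embed_tokens/Gather_output_0"};
for (const char* name : candidates) {
  if (ModelHasInputOutputNames(ov_model, name)) {
    main_input_name = name;
  }
}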
@@ -121,20 +121,22 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
std::vector<std::string> key_value_input_names;
std::vector<std::string> not_kv_inputs;
for (const ov::Output<ov::Node>& input : model->inputs()) {
auto& names = input.get_names();

bool found = false;
for (auto& name : names) {
if (name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(name);
const auto& params = model->get_parameters();
bool found = false;
for (auto i = 0; i < params.size(); i++) {
auto param_name = params.at(i)->output(0).get_any_name();
if (param_name.find("key_values") != std::string::npos) {
key_value_input_names.push_back(param_name);
found = true;
} else if (param_name.find("key") != std::string::npos) {
key_value_input_names.push_back(param_name);
found = true;
} else if (param_name.find("value") != std::string::npos) {
key_value_input_names.push_back(param_name);
found = true;
break;
}
}

if (!found) {
not_kv_inputs.push_back(input.get_any_name());
not_kv_inputs.push_back(param_name);
}
}

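
The reworked loop classifies parameters by their any-name instead of scanning every tensor name of every model input. The name test alone reduces to the predicate below (a minimal sketch with a hypothetical helper name; loop control such as the break in the "value" branch is intentionally omitted):

#include <string>

// A parameter counts as a KV-cache input if its name contains "key_values",
// "key", or "value" (e.g. "past_key_values.0.key"); everything else is
// collected as a regular, non-KV input.
static bool IsKeyValueInput(const std::string& param_name) {
  return param_name.find("key_values") != std::string::npos ||
         param_name.find("key") != std::string::npos ||
         param_name.find("value") != std::string::npos;
}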