diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc
index 7723ce0a6c7f7..99e310185e9e4 100644
--- a/onnxruntime/core/providers/openvino/ov_interface.cc
+++ b/onnxruntime/core/providers/openvino/ov_interface.cc
@@ -361,7 +361,11 @@ void OVInferRequest::Infer() {
 StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
     : OVInferRequest(std::move(infer_request)), target_device(device) {
   bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
-  if (gpu_or_npu) {
+
+  // Only enable prefill_use_full_chat_history when an "input_ids" tensor exists and its
+  // element type is not int64: this path applies only to those specific inputs/data types.
+  auto input_ids_opt = FindTensor("input_ids");
+  if (gpu_or_npu && input_ids_opt.has_value() && input_ids_opt->get_element_type() != ov::element::i64) {
     prefill_use_full_chat_history = true;
   }
 }
diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
index b48b0efde7ab6..f86d2d54fc381 100644
--- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
+++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc
@@ -72,6 +72,14 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
     main_input_name = "input_ids";
   }
 
+  if (ModelHasInputOutputNames(ov_model, "input_hidden_states")) {
+    main_input_name = "input_hidden_states";
+  }
+
+  if (ModelHasInputOutputNames(ov_model, "/model/embed_tokens/Gather_output_0")) {
+    main_input_name = "/model/embed_tokens/Gather_output_0";
+  }
+
   auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
 
   auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
@@ -121,20 +129,22 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
 
 void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
   std::vector<std::string> key_value_input_names;
   std::vector<std::string> not_kv_inputs;
-  for (const ov::Output<ov::Node>& input : model->inputs()) {
-    auto& names = input.get_names();
-
-    bool found = false;
-    for (auto& name : names) {
-      if (name.find("key_values") != std::string::npos) {
-        key_value_input_names.push_back(name);
+  const auto& params = model->get_parameters();
+  for (size_t i = 0; i < params.size(); i++) {
+    bool found = false;  // per-iteration reset, so inputs after a KV hit are still classified
+    auto param_name = params.at(i)->output(0).get_any_name();
+    if (param_name.find("key_values") != std::string::npos) {
+      key_value_input_names.push_back(param_name);
+      found = true;
+    } else if (param_name.find("key") != std::string::npos) {
+      key_value_input_names.push_back(param_name);
+      found = true;
+    } else if (param_name.find("value") != std::string::npos) {
+      key_value_input_names.push_back(param_name);
         found = true;
-        break;
-      }
     }
-    if (!found) {
-      not_kv_inputs.push_back(input.get_any_name());
+    if (!found) not_kv_inputs.push_back(param_name);  // record only non-KV inputs
     }
   }