From b00f88451d1018cea0c4f59d120746f2b13783c3 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Wed, 25 Jun 2025 20:04:23 +0000
Subject: [PATCH] Fix Mistral tool-parser regex for nested JSON

The arguments group of the function-name regex is lazy, so the trailing
lookahead decides where a call ends. Only end the match at the end of the
input or right before the next `name{` call: a "}" followed by a comma or
by whitespace also appears inside nested (or pretty-printed) JSON and must
not truncate the arguments.

Signed-off-by: mgoin
---
 .../language/generation/test_mistral.py       | 51 +++++++++++++++++++
 .../tool_parsers/mistral_tool_parser.py       |  9 +++-
 2 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py
index bdd857ff5062..c70698ede37a 100644
--- a/tests/models/language/generation/test_mistral.py
+++ b/tests/models/language/generation/test_mistral.py
@@ -10,6 +10,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall, MistralToolParser)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 
 from ...utils import check_logprobs_close
 
@@ -318,3 +319,53 @@ def test_mistral_guided_decoding(
                             schema=SAMPLE_JSON_SCHEMA)
     except jsonschema.exceptions.ValidationError:
         pytest.fail("Generated response is not valid with JSON schema")
+
+
+def test_mistral_function_call_nested_json():
+    """Ensure that the function-name regex captures the entire outer-most
+    JSON block, including nested braces."""
+
+    # Create a minimal stub tokenizer that provides the few attributes the
+    # parser accesses (`version` and `get_vocab`).
+    class _StubMistralTokenizer(MistralTokenizer):
+        version = 11  # Satisfy the version check
+
+        def __init__(self):
+            pass
+
+        @staticmethod
+        def get_vocab():
+            # Provide the special TOOL_CALLS token expected by the parser.
+            return {"[TOOL_CALLS]": 0}
+
+    tokenizer = _StubMistralTokenizer()
+    parser = MistralToolParser(tokenizer)
+
+    # The nested object precedes another key, so a "}," is not the end.
+    args_dict = {
+        "city": "Dallas",
+        "state": "TX",
+        "sub_dict": {
+            "foo": "bar",
+            "inner": {
+                "x": 1,
+                "y": 2
+            }
+        },
+        "unit": "fahrenheit",
+    }
+
+    model_output = (
+        f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}")
+
+    parsed = parser.extract_tool_calls(model_output, None)
+
+    # Assertions: the tool call is detected and the full nested JSON is parsed
+    # without truncation.
+    assert parsed.tools_called
+
+    assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
+    assert parsed.tool_calls[0].function.name == "get_current_weather"
+    assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
+    # No additional content outside the tool call should be returned.
+    assert parsed.content is None
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index ab1cfd4b6eab..c0691f122904 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -77,8 +77,13 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.bot_token_id = self.vocab.get(self.bot_token)
         self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
         if _is_fn_name_regex_support(self.model_tokenizer):
-            self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
-                                            re.DOTALL)
+            # The arguments group is lazy, so the lookahead decides where
+            # a call ends: only at end of input or right before the next
+            # `name{` call. A bare "," or whitespace after "}" also occurs
+            # inside nested/pretty-printed JSON and must not end the match.
+            self.fn_name_regex = re.compile(
+                r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})'
+                r'(?=[\s,]*(?:$|[a-zA-Z0-9_-]+\{))', re.DOTALL)
         else:
             self.fn_name_regex = None
 