
Conversation

@levunet
Contributor

@levunet levunet commented Sep 9, 2025

Purpose

I've confirmed that, when using the harmony library, we pass data to the model in a different format than the one the model itself outputs. This PR is therefore a temporary fix for that mismatch, which causes the model to get confused and stop responding when /v1/chat/completions is used with stream and tools.

before:
<|start|>assistant to=functions.project_tree<|channel|>commentary json<|message|>{"max_depth": 10}<|end|>

after:
<|start|>assistant<|channel|>commentary to=functions.project_tree json<|message|>{"max_depth": 10}<|call|>

Reference material: https://cookbook.openai.com/articles/openai-harmony#preambles
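
For illustration, here is a minimal sketch (a hypothetical helper, not the actual vLLM code) that builds the corrected "after" layout shown above:

```python
# Hypothetical helper, only to illustrate the corrected layout: the recipient
# and content type sit in the channel section, and the tool call ends with
# <|call|> instead of <|end|>.
def render_tool_call(tool_name: str, arguments_json: str) -> str:
    return (
        "<|start|>assistant"
        f"<|channel|>commentary to=functions.{tool_name} json"
        f"<|message|>{arguments_json}<|call|>"
    )

print(render_tool_call("project_tree", '{"max_depth": 10}'))
# <|start|>assistant<|channel|>commentary to=functions.project_tree json<|message|>{"max_depth": 10}<|call|>
```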

Test Plan

Test Result

@levunet levunet requested a review from aarnphm as a code owner September 9, 2025 01:00
@github-actions

github-actions bot commented Sep 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the frontend and gpt-oss (Related to GPT-OSS models) labels Sep 9, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request aims to fix an issue where gpt-oss models stop responding during frequent tool usage by correcting the prompt format. The changes adjust how tool call messages are constructed and rendered. My review identified a critical bug in the new tool call formatting which omits a required part of the syntax, and a high-severity performance issue from redundant encoding operations inside a loop. Addressing these points will help ensure the fix is both correct and efficient.

@levunet levunet force-pushed the feat/gpt-oss-bugfix branch 3 times, most recently from 37bcd77 to ff9b425 Compare September 9, 2025 01:13
@mergify

mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @levunet.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@levunet levunet force-pushed the feat/gpt-oss-bugfix branch from ff9b425 to 3e08066 Compare September 9, 2025 02:55
@mergify mergify bot removed the needs-rebase label Sep 9, 2025
@levunet levunet changed the title [gpt-oss] Fix issue: responses stop with frequent tool usage in /v1/chat/completions [gpt-oss] Fix: No response when using stream tools Sep 9, 2025
@levunet levunet changed the title [gpt-oss] Fix: No response when using stream tools [gpt-oss] Fix: No response when using stream & tools Sep 9, 2025
@levunet levunet force-pushed the feat/gpt-oss-bugfix branch from 3e08066 to 65ec373 Compare September 9, 2025 04:17
@mergify

mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @levunet.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@levunet levunet force-pushed the feat/gpt-oss-bugfix branch 4 times, most recently from 31cb0d7 to 14ea651 Compare September 9, 2025 09:19
Collaborator

This doesn't seem to be a valid channel?

Contributor Author

@levunet levunet Sep 9, 2025

The harmony library function doesn't add this data properly, so I temporarily added it to the channel here. If it isn't added, the data is passed to gpt-oss in the 'before' format, and as that content accumulates, errors appear in the output structure.

use:
with_recipient(f"functions.{name}")

data:
<|start|>assistant to=functions.{name}<|channel|>
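
For context, a minimal sketch of how such a tool-call message is built with the openai_harmony Python bindings (method names as shown in the harmony cookbook; the exact API may differ by version). with_recipient() renders the recipient in the role section of the header, i.e. the 'before' layout above:

```python
# Sketch: rendering an assistant tool-call message with openai_harmony.
# with_recipient() places "to=functions.project_tree" in the role section of
# the header (the "before" layout), not in the channel section.
from openai_harmony import (
    Conversation,
    HarmonyEncodingName,
    Message,
    Role,
    load_harmony_encoding,
)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

tool_call = (
    Message.from_role_and_content(Role.ASSISTANT, '{"max_depth": 10}')
    .with_recipient("functions.project_tree")
    .with_channel("commentary")
    .with_content_type("json")
)

tokens = encoding.render_conversation_for_completion(
    Conversation.from_messages([tool_call]), Role.ASSISTANT
)
print(encoding.decode(tokens))
```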

Contributor Author

@levunet levunet Sep 9, 2025

As errors accumulate, output like the following is produced:

gpt-oss failure examples:
"<|start|>assistant<|channel|>commentary to=functions.name>{}<|call|>"
"<|start|>assistant<|channel|>analysis to=functions.file_read <|constrain|>json<|message|>{}<|call|>"

Collaborator

This sounds more like an issue in the harmony library. Could you raise an issue there?

Collaborator

Could you add a unit test for the encoding result?

@heheda12345
Collaborator

Also CC @yeqcharlotte


This will evaluate to false for chats with no assistant messages yet, preventing "Calls to these tools must go to the commentary channel: 'functions'." from being added to the system prompt for requests that define custom tools in the functions namespace. This leads to the same issue of the model putting tool calls in the analysis channel, often using incorrect arguments or forgetting to define a tool recipient.

I think you want to check for a developer message that includes a functions namespace definition instead.

Contributor Author

As you suggested, I modified it to check tool usage when msg.author.role == Role.DEVELOPER, so it can be verified even without an assistant message. Thank you!

@levunet levunet force-pushed the feat/gpt-oss-bugfix branch 2 times, most recently from 79ccd98 to 1978a5d Compare September 11, 2025 02:02
@levunet levunet requested a review from IsaacRe September 11, 2025 02:26
@levunet levunet force-pushed the feat/gpt-oss-bugfix branch from 1978a5d to 97ff3d0 Compare September 12, 2025 04:40
Collaborator

@yeqcharlotte yeqcharlotte left a comment

Thanks for the fix @levunet! Could you share some example prompts that reproduce your errors and show the results before and after the fix?

Also, do you have any tool eval results with AIME or GPQA?

cc: @alecsolder


@levunet levunet force-pushed the feat/gpt-oss-bugfix branch 2 times, most recently from 5aa2ee2 to 7d3b847 Compare September 12, 2025 05:53
@mergify mergify bot removed the needs-rebase label Sep 12, 2025
@levunet
Contributor Author

levunet commented Sep 12, 2025

@yeqcharlotte I've completed the rebase and fixed some bugs. I'll request another review after creating the unit tests.
I tried to contact the harmony team, but their GitHub repository seems to be barely maintained. So I had no choice but to work on it in vllm...

@levunet levunet force-pushed the feat/gpt-oss-bugfix branch 2 times, most recently from fa9222b to 19a14e7 Compare September 12, 2025 06:59
@levunet
Contributor Author

levunet commented Sep 12, 2025

[screenshot] Even though the harmony-format data is being input correctly, errors are still occasionally occurring during consecutive tool usage, so I'm looking for the cause of this issue.

@levunet levunet force-pushed the feat/gpt-oss-bugfix branch from 19a14e7 to b49586a Compare September 12, 2025 08:24
@levunet
Contributor Author

levunet commented Sep 12, 2025

[screenshot]

Fixed a bug where the commentary channel was missing (Invalid Channel) because the with_custom_tools value was absent when fetching the system message. (This resolved the 'analysis to=functions.' issue during initial tool execution.)

However, a new problem has emerged: while writing '<|start|>assistant<|channel|>commentary to=functions.search_pattern' for a tool call, the model suddenly proceeds with ordinary inference instead.

This commit addresses an issue where the harmony library passes data to the model
in a different format than what the model outputs, causing the model to become
confused and stop responding when using /v1/chat/completions with stream and tools.

The fix updates the message format to match the model's expected output:
- Move recipient info from assistant start tag to channel tag
- Change content type from 'json' to '<|constrain|>json'
- Replace <|end|> token with <|call|> token for tool calls

This is a temporary fix until the underlying format mismatch is properly resolved.

Signed-off-by: kyt <[email protected]>
@levunet levunet force-pushed the feat/gpt-oss-bugfix branch from b49586a to c285452 Compare September 12, 2025 09:13
@levunet
Contributor Author

levunet commented Sep 12, 2025

I tried to add unit tests but couldn't, because it's difficult to verify errors that only occur when the model uses tools creatively depending on its temperature setting. My current PR can't be considered a complete fix for the gpt-oss tool-usage bug. Still, if you lower the temperature to around 0.6, it definitely uses tools better than before, though I don't think the code is good.

@yeqcharlotte What do you think would be the best way to handle this PR..?

@alecsolder
Contributor

alecsolder commented Sep 12, 2025

Hey @levunet, a lot of the things in here are things I have noticed as well

Header element ordering

I have noticed that the harmony library renders elements in headers in a different order than the model outputs as well. They mention it in one place in their documentation here

> The recipient might be defined in the role or channel section of the header.

The Harmony library is opinionated and always renders them in one order, which also happens to be different than how the model usually outputs the headers. But from my testing that has not impacted tool calling consistency.

Unexpected tokens in harmony header

> Even though the harmony-format data is being input correctly, errors are still occasionally occurring during consecutive tool usage, so I'm looking for the cause of this issue.

I'd have to look again for specifics in the rust code, but the Harmony library parses headers in a very specific way (generally):

  • It takes things off the front for Role
  • It takes things off the end for content_type
  • It then splits on whitespace and picks the first item for channel and the last item for recipient; the issue here is that there was an extra whitespace-separated token, so it picked up 100 as the recipient
  • That leaves one more item from the split, so it throws an error (see the sketch after this list)
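
The real parsing lives in the harmony Rust code; the following is only a rough Python sketch of the failure mode described above, not the actual implementation:

```python
# Rough sketch (not the harmony implementation) of the header-splitting
# behavior described above: first item -> channel, last item -> recipient,
# anything left over is an error.
def parse_header(header: str) -> dict:
    parts = header.split()
    channel = parts[0]
    recipient = parts[-1] if len(parts) > 1 else None
    leftover = parts[1:-1]
    if leftover:
        raise ValueError(f"unexpected tokens in header: {leftover}")
    return {"channel": channel, "recipient": recipient}

# Well-formed header: channel plus recipient only.
print(parse_header("commentary to=functions.search_pattern"))

# A stray token (e.g. a hallucinated '100') becomes the recipient and the real
# recipient is left over, which raises an error.
try:
    parse_header("commentary to=functions.search_pattern 100")
except ValueError as err:
    print(err)
```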

I have seen this behavior before, and it was actually not due to tokenization/formatting; it was caused by an issue in the specific attention backend being used to run the model. If you can provide information on the configuration and hardware you're using when you see this error, that can help in case it is the same one! It was only an issue on certain hardware types.

Overall changes

I do not believe vLLM should adjust render_for_completion functionality for Harmony; it is actually pretty complicated in their library and handles things like removing certain reasoning messages between generations, which would need to be re-implemented in Python. If there are true bugs in their renderer, they should probably be raised as an issue and PR against the harmony repo.

I do think the following change is good though, and it will help function calling on the chat completions API route in general: https://github.com/vllm-project/vllm/pull/24473/files#diff-f3135631994e5e8f63fff36f5fb493f404a7c253c004183613a007548156e558R1573

I'm hoping that most of the issues you're seeing are caused by hardware issues, because there should be flags which can resolve them.

@levunet levunet closed this Sep 12, 2025
@levunet
Contributor Author

levunet commented Sep 12, 2025

> Hey @levunet, a lot of the things in here are things I have noticed as well […]

@alecsolder
Thank you, I feel like I'm learning a lot thanks to your efforts. First, I apologize for not sharing the detailed content of my tests earlier. This is a problem I discovered while developing a feature that uses tool calls to explore multiple source files and documents before answering; the issue occurred occasionally when exploring more than 10 source files and documents, causing mistakes. Here is information about my environment. I used python3 collect_env.py, but I'm not sure whether this information is sufficient!

==============================
System Info

OS : Ubuntu 22.04.5 LTS (x86_64)
GCC version : (Ubuntu 10.5.0-1ubuntu1~22.04.2) 10.5.0
Clang version : Could not collect
CMake version : version 3.22.1
Libc version : glibc-2.35

==============================
PyTorch Info

PyTorch version : 2.8.0+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A

==============================
Python Environment

Python version : 3.12.11 (main, Jun 4 2025, 08:56:18) [GCC 11.4.0] (64-bit runtime)
Python platform : Linux-6.8.0-79-generic-x86_64-with-glibc2.35

==============================
CUDA / GPU Info

Is CUDA available : True
CUDA runtime version : 12.8.93
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration : GPU 0: NVIDIA GeForce RTX 5090
Nvidia driver version : 575.51.02
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True

==============================
CPU Info

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 9800X3D 8-Core Processor
CPU family: 26
Model: 68
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
CPU max MHz: 5269.0000
CPU min MHz: 600.0000
BogoMIPS: 9381.75
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d
Virtualization: AMD-V
L1d cache: 384 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 8 MiB (8 instances)
L3 cache: 96 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

==============================
Versions of relevant libraries

[pip3] efficientnet_pytorch==0.7.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] open_clip_torch==2.32.0
[pip3] pytorch-lightning==2.5.2
[pip3] pyzmq==27.0.2
[pip3] segmentation_models_pytorch==0.4.0
[pip3] sentence-transformers==3.2.1
[pip3] terratorch==1.0.2
[pip3] torch==2.8.0+cu128
[pip3] torchaudio==2.8.0+cu128
[pip3] torchgeo==0.7.0
[pip3] torchmetrics==1.7.4
[pip3] torchvision==0.23.0+cu128
[pip3] transformers==4.55.2
[pip3] transformers-stream-generator==0.0.5
[pip3] triton==3.4.0
[pip3] tritonclient==2.51.0
[pip3] vector-quantize-pytorch==1.21.2
[conda] Could not collect

==============================
vLLM Info

ROCM Version : Could not collect
vLLM Version : 0.10.2rc2.dev138+g147437f11 (git sha: 147437f11)
vLLM Build Flags:
CUDA Archs: 7.0 7.5 8.0 8.9 9.0 10.0 12.0; ROCm: Disabled
GPU Topology:
GPU0 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X 0-15 0 N/A

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

==============================
Environment Variables

NVIDIA_VISIBLE_DEVICES=all
NVIDIA_REQUIRE_CUDA=cuda>=12.8 brand=unknown,driver>=470,driver<471 brand=grid,driver>=470,driver<471 brand=tesla,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=vapps,driver>=470,driver<471 brand=vpc,driver>=470,driver<471 brand=vcs,driver>=470,driver<471 brand=vws,driver>=470,driver<471 brand=cloudgaming,driver>=470,driver<471 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=560,driver<561 brand=grid,driver>=560,driver<561 brand=tesla,driver>=560,driver<561 brand=nvidia,driver>=560,driver<561 brand=quadro,driver>=560,driver<561 brand=quadrortx,driver>=560,driver<561 brand=nvidiartx,driver>=560,driver<561 brand=vapps,driver>=560,driver<561 brand=vpc,driver>=560,driver<561 brand=vcs,driver>=560,driver<561 brand=vws,driver>=560,driver<561 brand=cloudgaming,driver>=560,driver<561 brand=unknown,driver>=565,driver<566 brand=grid,driver>=565,driver<566 brand=tesla,driver>=565,driver<566 brand=nvidia,driver>=565,driver<566 brand=quadro,driver>=565,driver<566 brand=quadrortx,driver>=565,driver<566 brand=nvidiartx,driver>=565,driver<566 brand=vapps,driver>=565,driver<566 brand=vpc,driver>=565,driver<566 brand=vcs,driver>=565,driver<566 brand=vws,driver>=565,driver<566 brand=cloudgaming,driver>=565,driver<566
CUDA_CACHE_DISABLE=0
TORCH_CUDA_ARCH_LIST=7.0 7.5 8.0 8.9 9.0 10.0 12.0
NCCL_VERSION=2.25.1-1
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVCC_THREADS=4
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=12.8.1
CUDA_VISIBLE_DEVICES=0
CUDA_VISIBLE_DEVICES=0
MAX_JOBS=4
VLLM_TARGET_DEVICE=cuda
LD_LIBRARY_PATH=/usr/local/cuda/lib64
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

@levunet levunet deleted the feat/gpt-oss-bugfix branch September 12, 2025 17:46
@alecsolder
Contributor

alecsolder commented Sep 12, 2025

Thank you! If you want, you can try running with

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1

If you still get the same issue, then someone else may know more!

Also, I do think that the with_custom_tools change is needed for the chat completions endpoint, so if you want to make a PR just for that, I think it makes sense!

This one: https://github.com/vllm-project/vllm/pull/24473/files#diff-f3135631994e5e8f63fff36f5fb493f404a7c253c004183613a007548156e558R1573

@levunet
Contributor Author

levunet commented Sep 13, 2025

@alecsolder
Yes, I added that content to the PR here: #24768! Thank you

@levunet
Contributor Author

levunet commented Sep 13, 2025

@alecsolder

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1

I tried using that command, but unfortunately it didn't work for me. I will try to find the cause of the problem once again. T_T
[screenshot]

@levunet
Contributor Author

levunet commented Sep 15, 2025

@alecsolder
To check whether this is a local issue on my end, I tested with vLLM version 0.10.2 on RunPod, but the same symptoms appear on an H100 GPU as well. I'm sharing the example files I used.

vllm option "--model openai/gpt-oss-20b --max-model-len 60000 --gpu-memory-utilization 0.6 --tool-call-parser openai --reasoning-parser openai_gptoss --enable-auto-tool-choice --max-num-batched-tokens 512 --host 0.0.0.0 --port 8000"

gpt-oss_test.py
messages.txt
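
The attached gpt-oss_test.py is not reproduced here; below is only a minimal sketch of the kind of streaming tool-call request involved, assuming the server started with the options above on localhost:8000 and an illustrative file_read tool (not the actual test script or tool schema):

```python
# Minimal sketch of a streaming chat.completions request with tools against
# the vLLM server launched with the options above. The tool schema is
# illustrative only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "file_read",
        "description": "Read a file from the workspace.",
        "parameters": {
            "type": "object",
            "properties": {"path": {"type": "string"}},
            "required": ["path"],
        },
    },
}]

stream = client.chat.completions.create(
    model="openai/gpt-oss-20b",
    messages=[{"role": "user", "content": "Read README.md and summarize it."}],
    tools=tools,
    stream=True,
)

for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        print("tool_call delta:", delta.tool_calls)
    elif delta.content:
        print(delta.content, end="")
```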

@levunet
Contributor Author

levunet commented Sep 15, 2025

The 'tool_calls not found' error occurs with approximately 50% probability.

@levunet
Contributor Author

levunet commented Sep 16, 2025

@alecsolder
I have completely fixed the issue by modifying the harmony lib. I will submit the details as a PR later.


Labels

frontend, gpt-oss (Related to GPT-OSS models)
