@WoosukKwon
Collaborator

The compute_dtype argument of the MoE Triton kernel was incorrectly hardcoded to tl.float16.

@WoosukKwon WoosukKwon requested a review from pcmoritz April 30, 2024 01:30
Collaborator

@pcmoritz pcmoritz left a comment

Thanks a lot for fixing this! The bug was introduced in #4244 (very sorry!)

I tested your fix locally and it works. I also tested the 0.4.1 release (which still has the hardcoded tl.float16), and that works too.

@pcmoritz
Collaborator

pcmoritz commented Apr 30, 2024

I think it works in 0.4.1 because Triton seems to track the dtype of a pointer and do the necessary conversion on store. See
https://triton-lang.org/main/python-api/generated/triton.language.store.html#triton.language.store: the value is "implicitly broadcast to pointer.shape and typecast to pointer.dtype.element_ty". (That's also why I didn't catch it when I was testing #4244.)
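
To make the effect of that implicit typecast concrete, here is a rough pure-Python sketch (not the vLLM kernel, and not Triton code). It assumes the kernel accumulates in float32, casts the accumulator to compute_type, and that the store then casts to the output pointer's element type, with out-of-range values becoming infinity as in IEEE-style conversion. The helper names `to_float16`, `to_bfloat16`, and `store_as_bf16` are hypothetical and exist only for this illustration.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE half precision; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the fp16 range,
        # so emulate the saturating-to-infinity conversion by hand.
        return float("inf") if x > 0 else float("-inf")


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16: keep the top 16 bits of the float32 encoding,
    rounding to nearest even. NaN inputs are not handled specially."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def store_as_bf16(acc: float, compute_type: str) -> float:
    """Cast the accumulator to compute_type, then to the bf16 output type."""
    if compute_type == "float16":      # the hardcoded (buggy) setting
        acc = to_float16(acc)
    elif compute_type == "bfloat16":   # matches the output tensor's dtype
        acc = to_bfloat16(acc)
    else:
        raise ValueError(f"unsupported compute_type: {compute_type}")
    return to_bfloat16(acc)            # implicit typecast on store


small = 3.0       # representable in both fp16 and bf16
large = 70000.0   # above the fp16 maximum of 65504, fine for bf16

print(store_as_bf16(small, "float16"), store_as_bf16(small, "bfloat16"))
print(store_as_bf16(large, "float16"), store_as_bf16(large, "bfloat16"))
```

Under these assumptions, values inside the fp16 range come out the same (or nearly the same) either way, which would be consistent with 0.4.1 appearing to work. A value above the fp16 maximum would be lost in the float16 intermediate before the store ever converts it to bfloat16.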

@pcmoritz pcmoritz enabled auto-merge (squash) April 30, 2024 03:07
@WoosukKwon WoosukKwon disabled auto-merge April 30, 2024 05:05
@WoosukKwon WoosukKwon merged commit fa32207 into main Apr 30, 2024
@WoosukKwon WoosukKwon deleted the fix-compute-type branch April 30, 2024 05:05
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
2 participants