- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 10.8k
[Kernel] FP8 support for MoE kernel / Mixtral #4244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| self.ws = nn.Parameter(ws, requires_grad=False) | ||
| self.w2s = nn.Parameter(w2s, requires_grad=False) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we define cuda device for the original Parameter, should we do the same here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ws will inherit the device from self.ws.data through torch.empty_like, so I don't think we need to specify a CUDA device here (and we do want to remove it from the original Parameter going forward too)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
btw, the scaled quantization kernel shows 2-3x speedup over the PyTorch op implementation in FP8 linear method on my L4, so we should expect some improvements if we use this op in the FP8 linear method. I'll send a follow-up PR after this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice job, looks good to me!
| # We accumulate along the K dimension. | ||
| accumulator += tl.dot(a, b) | ||
| if use_fp8: | ||
| accumulator = tl.dot(a, b, acc=accumulator, allow_tf32=True) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: allow_tf is deprecated and it is recommended to use input_precision="tf32"
Also it might be worth noting if this is either: required for fp8 to work or just a performance optimization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the input_precision="tf32" interface is not part of the triton 2.2.0 release yet (which is the version we are using -- triton is pinned to 2.2.0 via the pytorch version we are pinning).
I tried without this flag and can't measure a performance difference, so will remove it. It seems it is the default for devices with tensor cores anyways (https://triton-lang.org/main/python-api/generated/triton.language.dot.html#triton.language.dot). Removing the flag will make us robust against the allow_tf32 vs. input_precision interface change as well. Once we have upgraded trition, it might be worth looking at triton-lang/triton#3234
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! Approving after a light pass given @comaniac and @mgoin's approval.
One question:
It implements dynamic per-tensor scaling (see #4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints.
Doesn't this approach still depend on the choice of scaling factor as not all datasets will have the same activation patterns for a given layer? Perhaps the d(quality)/d(per-tensor scaling factor) is small versus other methods, but there is still some impact on quality by the chosen per-tensor scaling factor.
| So in a way it will select the best scaling factor possible at runtime (assuming per-tensor and not subdividing the tensor further for scaling), since it will use the actual activations on the current batch in the forward pass to compute the scaling factor (vs. using some scaling that was computed on an offline dataset). The resulting fp8 tensor will always have 448.0 as its maximum and you can't quantize a given tensor better than that. | 
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
| if use_fp8: | ||
| accumulator = tl.dot(a, b, acc=accumulator) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your great work! I have a question here, why does fp8 need to do this, is there any difference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great question, I'm doing it because it supports the fp8 fast accumulation, see
The downside is that you can't do rescaling before adding it to the accumulator
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing #3208
It implements dynamic per-tensor scaling (see #4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the
quantization="fp8"argument. You can try out the PR like this:Performance: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in #3954). With this PR, the results are as follows:
Accuracy: The accuracy with this PR on MMLU on
mistralai/Mixtral-8x7B-v0.1is as follows:this compares favorably with the fp16 results which are
Happy hacking!
PR Checklist (Click to Expand)
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
[Bugfix]for bug fixes.[CI/Build]for build or continuous integration improvements.[Doc]for documentation fixes and improvements.[Model]for adding a new model or improving an existing model. Model name should appear in the title.[Frontend]For changes on the vLLM frontend (e.g., OpenAI API server,LLMclass, etc.)[Kernel]for changes affecting CUDA kernels or other compute kernels.[Core]for changes in the core vLLM logic (e.g.,LLMEngine,AsyncLLMEngine,Scheduler, etc.)[Hardware][Vendor]for hardware-specific changes. Vendor name should appear in the prefix (e.g.,[Hardware][AMD]).[Misc]for PRs that do not fit the above categories. Please use this sparingly.Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR need to meet the following code quality standards:
format.shto format your code.docs/source/if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with
rfc-requiredand might not go through the PR.What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
action-requiredlabel on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!