[Kernel] LoRA - Enable CUDAGraphs for V1 #14626

Changes to the config: LoRA settings now feed the compilation cache hash, and `torch.compile` is only disabled for LoRA on V0.

```diff
@@ -2287,9 +2287,14 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # LoRA is not compatible with `torch.compile` .
         factors: list[Any] = []
+        factors.append(self.max_lora_rank)
+        factors.append(self.max_loras)
+        factors.append(self.fully_sharded_loras)
+        factors.append(self.lora_dtype)
+        factors.append(self.lora_extra_vocab_size)
+        factors.append(self.long_lora_scaling_factors)
+        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
```
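
The new factors make the compilation hash sensitive to the LoRA settings, so a graph compiled under one LoRA configuration is not reused under another. A minimal, self-contained sketch of the idea (toy class and field names, not vLLM's actual config):

```python
# Hypothetical sketch: hashing config "factors" so that compiled artifacts
# are keyed on the settings that affect the captured graph.
import hashlib
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyLoRAConfig:
    max_lora_rank: int = 16
    max_loras: int = 1
    lora_dtype: Optional[str] = None

    def compute_hash(self) -> str:
        factors: list[Any] = [self.max_lora_rank, self.max_loras,
                              self.lora_dtype]
        return hashlib.md5(str(factors).encode()).hexdigest()


# Configs that differ in rank hash differently, so a graph compiled for one
# cannot be silently reused for the other.
assert (ToyLoRAConfig(max_lora_rank=16).compute_hash()
        != ToyLoRAConfig(max_lora_rank=64).compute_hash())
```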

```diff
@@ -3303,6 +3308,11 @@ def compute_hash(self) -> str:
             vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+            # LoRA creates static buffers based on max_num_batched_tokens.
+            # The tensor sizes and strides get captured in the torch.compile
+            # graph explicitly.
+            vllm_factors.append(
+                str(self.scheduler_config.max_num_batched_tokens))
         else:
             vllm_factors.append("None")
         if self.speculative_config:
```
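
The inline comment gives the reasoning: LoRA allocates static buffers sized by `max_num_batched_tokens`, and those sizes and strides are baked into the captured graph, so the value has to be part of the cache key. A rough sketch of that buffer pattern, with made-up names and sizes:

```python
# Illustrative only: a static index buffer allocated once at the maximum
# batched-token count, with per-step views carved out by torch.narrow.
import torch

MAX_NUM_BATCHED_TOKENS = 8192  # assumed value; set by the scheduler config

# Allocated once; its size and stride are what a captured graph "sees".
token_lora_indices = torch.zeros(MAX_NUM_BATCHED_TOKENS, dtype=torch.long)


def indices_for_step(num_tokens: int) -> torch.Tensor:
    # torch.narrow returns a view into the static buffer: no new allocation,
    # same underlying storage every step.
    return torch.narrow(token_lora_indices, 0, 0, num_tokens)


view = indices_for_step(256)
assert view.data_ptr() == token_lora_indices.data_ptr()
assert view.shape == (256,)
```

The `torch.narrow` calls added to the embedding layer and the Punica wrapper below follow this same pattern.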

```diff
@@ -3453,12 +3463,15 @@ def __post_init__(self):
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

-        if self.lora_config is not None and self.compilation_config.level !=\
-            CompilationLevel.NO_COMPILATION:
-            logger.warning("LoRA is not supported with `torch.compile` yet. "
-                           "Disabling `torch.compile`.")
+        if ((not envs.VLLM_USE_V1) and self.lora_config is not None
+                and self.compilation_config.level
+                != CompilationLevel.NO_COMPILATION):
+            logger.warning(
+                "LoRA for V0 is not supported with `torch.compile` yet. "
+                "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
```

Changes to the LoRA embedding layer's `forward`: indices are read from the statically sized `_embeddings_indices` buffer via `torch.narrow`, and the input token ids are no longer modified in place.

```diff
@@ -237,16 +237,19 @@ def set_lora(
         self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        embeddings_indices = self.punica_wrapper.embeddings_indices
-        indices = embeddings_indices[1].view_as(x)
+        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
+                                        1, 0)
+        embeddings_indices = torch.narrow(
+            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
+
+        indices = embeddings_indices[1]
         full_lora_a_embeddings = F.embedding(
             x + indices,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0].view_as(x)
-        full_output = self.base_layer.forward(
-            x.add_(indices * added_tokens_mask))
+        indices = embeddings_indices[0]
+        full_output = self.base_layer.forward(x +
+                                              (indices * added_tokens_mask))

         full_output_org = full_output
         if full_output.ndim == 3:
```
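
Presumably the point of the rewrite is to keep the forward pass capture-friendly: the mask becomes an explicit 0/1 tensor via `torch.where`, the indices come from a narrowed view of the persistent buffer, and the token ids are combined out of place instead of being mutated with `x.add_()`. A toy sketch of the new behavior, with invented vocab sizes and offsets (not the real layer):

```python
# Toy illustration of the out-of-place update shown in the diff above.
import torch

org_vocab_size = 10
# Token ids; ids >= org_vocab_size refer to LoRA-added vocabulary entries.
x = torch.tensor([1, 3, 12, 11])

# 0/1 mask instead of a bool tensor, as in the diff.
added_tokens_mask = torch.where(x > org_vocab_size - 1, 1, 0)

# Hypothetical per-token offsets that would normally come from the
# narrowed embeddings_indices buffer.
indices = torch.tensor([0, 0, 5, 5])

# Out of place: the caller's token ids are left untouched, unlike the old
# x.add_(...) path, which modified x in place.
shifted = x + indices * added_tokens_mask

assert torch.equal(x, torch.tensor([1, 3, 12, 11]))       # input unchanged
assert torch.equal(shifted, torch.tensor([1, 3, 17, 16]))
```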

Changes to the Punica wrapper: `_apply_bias` now receives `_token_lora_indices` narrowed to the current number of tokens.

```diff
@@ -254,7 +254,9 @@ def add_expand(self,
         y_org = y
         y = y.view(-1, y.shape[-1])
         if lora_bias_stacked is not None:
-            self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            self._apply_bias(token_lora_indices, y, output_slices,
                              lora_bias_stacked)

         if env.VLLM_USE_V1:
```

```diff
@@ -365,7 +367,9 @@ def add_lora_linear(self,
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
         if lora_bias_stacked is not None:
             assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            y = self._apply_bias(token_lora_indices, y, output_slices,
                                  lora_bias_stacked)

         if buffer is None:
```
Reviewer: Are you planning on keeping this eager?
Reviewer: Is this testing code that should be removed before the PR is ready?
Author: I intend to keep it. The CI test was running out of memory, which I assume is because of the CUDA graph capture. Also, that specific test doesn't actually run the model.