Skip to content

Conversation

@Jeffwan
Copy link
Contributor

@Jeffwan Jeffwan commented Jul 19, 2024

Address #6275 #3308 #5491

Add the two methods in the apiserver to allow user to load/unload lora adapters explicitly.

  • Update the serving_engine.py and add load and unload with request checking
  • Update api_server.py and add api endpoint and request object
  • Override LoraRequest __eq__ and __hash__. We used to use lora_init_id as the identifier. but the tricky problem we meet is user may not able to persist the id on the client side every time, using model name is more straightforward. what's more, there's no need to load multiple same lora adapters.

My proposal is to use lora_name as the identifier, the other good thing is we can use consistent names across instances. For example, let's say we have 2-3 instances, we want all of them load a specific adapter and unload a specific adapter. Using name is more consistent in such scenarios.

Note: Community already have similar PRs like #3446. I tried to contact the PR author but it's kind of stale. The current PR covers more edge cases and verified in our environment.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@Jeffwan Jeffwan marked this pull request as draft July 19, 2024 07:46
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch from 6f4cefb to 1020409 Compare July 24, 2024 18:27
@Jeffwan Jeffwan changed the title WIP: [Frontend] Support load and unload LoRA in api server [Frontend] Support load and unload LoRA in api server Jul 24, 2024
@Jeffwan Jeffwan changed the title [Frontend] Support load and unload LoRA in api server [Core] Support load and unload LoRA in api server Jul 24, 2024
@Jeffwan Jeffwan marked this pull request as ready for review July 24, 2024 18:29
@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch from 1020409 to 140fbcf Compare July 24, 2024 18:41
@Jeffwan
Copy link
Contributor Author

Jeffwan commented Jul 24, 2024

/cc @Yard1 @simon-mo I am adding more tests for this change and please help review the major change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should mark the dataclass as frozen if we want to use hash

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah. good catch. I will add @dataclass(frozen=True) then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vllm/lora/request.py:37: error: Property "lora_path" defined in "LoRARequest" is read-only  [misc]
vllm/lora/request.py:72: error: Property "lora_path" defined in "LoRARequest" is read-only  [misc]

I meet some issues by adding frozen, the problem is we have compatibility support for deprecated field. At this moment, there's no lora_name modifications, I think it should be safe. Once we remove the deprecated field, we can add frozen back

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should allow operators to disable those methods and have them disabled by default; having user specify the path to the adapter themselves is inherently unsafe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense. let me add some feature gate for it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a feature flag now and leave a TODO here. I think once it's enabled, it's still has the risk. The key problem is if we want to support admin operations, then we have to distinguish the users. This is something not supported yet. In our system, we distinguish the user at proxy level. If the vLLM is public to users directly, then that's a problem

@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch from dca2f01 to 0f70009 Compare August 3, 2024 23:04
@jgordley
Copy link
Contributor

Hey @Jeffwan, are you still planning on finishing up this branch? LMK if I can help in any way, this is a great feature

@Jeffwan
Copy link
Contributor Author

Jeffwan commented Aug 13, 2024

@jgordley Yeah, I am still working on it. I was off last week and I will spend some time fixing the UTs. Thanks for offering the help, I will let you know if I meet problems.

Copy link

@kfswain kfswain Aug 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could become duplicated with multiple edits:
Example [1,2,3,4] -> remove 3 -> [1,2,4] -> add new adapter -> [1,2,4,4]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a known issue and another PR will fix this issue. Even we change to use a global counter, user can still pass the int id and override it, which means the counter may generate duplicate id next time. Technically, we should always use model name from user and id should be managed by the engine.

@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch 3 times, most recently from c559c9b to e427c0b Compare August 28, 2024 14:01
@Jeffwan
Copy link
Contributor Author

Jeffwan commented Aug 28, 2024

@Yard1 can you help review the updated commits? there're some test failures due to connection timeout and I am retrying those at the same time

@simon-mo
Copy link
Collaborator

@jeejeelee can you help review this one? thank you!

@jeejeelee
Copy link
Collaborator

@jeejeelee can you help review this one? thank you!

If there's no rush, I'll take some time to look at it this weekend.

@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch 2 times, most recently from 98d9a4c to 9246d65 Compare August 30, 2024 02:35
@jeejeelee
Copy link
Collaborator

@Jeffwan Thank you very much for your great work. I've left some comments,thanks~

@simon-mo
Copy link
Collaborator

simon-mo commented Sep 3, 2024

@jeejeelee you might need to submit your review in github, as it is not showing here.

@Jeffwan
Copy link
Contributor Author

Jeffwan commented Sep 4, 2024

I will address all the comments by tomorrow

@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch from 9246d65 to 072f946 Compare September 4, 2024 04:58
@Jeffwan Jeffwan force-pushed the jiaxin/runtime_load_unload branch from 072f946 to 7ae737a Compare September 4, 2024 21:57
@Jeffwan
Copy link
Contributor Author

Jeffwan commented Sep 4, 2024

image
All the code related CI failures have been fixed, current failures are from AMD test suite and primary reason is the docker hub rate limit.

@Jeffwan
Copy link
Contributor Author

Jeffwan commented Sep 4, 2024

@jeejeelee Hi, I address all the comments and please take another look

@Jeffwan
Copy link
Contributor Author

Jeffwan commented Sep 5, 2024

/ready

@jeejeelee
Copy link
Collaborator

@Jeffwan Thank you very much for your work again. This is indeed a useful feature. After making a few more NIT changes, I will give the green light. cc @simon-mo

@Jeffwan
Copy link
Contributor Author

Jeffwan commented Sep 5, 2024

@jeejeelee doc update looks good to me. I accepted the suggestions

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamp given @jeejeelee's comment. thank @Jeffwan for the hardware and thank @jeejeelee for the shepherding!

@simon-mo simon-mo merged commit db3bf7c into vllm-project:main Sep 6, 2024
@Jeffwan Jeffwan deleted the jiaxin/runtime_load_unload branch September 6, 2024 03:51

self.served_model_names = served_model_names

self.lora_id_counter = AtomicCounter(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Jeffwan this is not required here. asyncio operations all happen in the same thread. Can change this to be a simple int field.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill yeah, that makes sense. Let me file a follow up PR to improve it.

dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Sep 12, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants