
[RFC]: Support fast inplace model update by shared IPC buffer #24163

@weixiao-huang

Description


Motivation.

In colocated RL, training and inference occupy GPU memory in turn. After training finishes, the trainer should send the model weights to the inference engine and update its model in place.

Currently, vLLM does not support an in-place model update. SGLang serializes tensors into IPC handles in the training engine and rebuilds tensors from those handles in the inference engine. This incurs overhead from repeatedly serializing and deserializing tensors into IPC handles and from sending large pickled payloads to the inference engines.

This RFC proposes rebuilding only one GPU tensor from an IPC handle per device and sharing data through it, which makes model updates much faster than serializing and deserializing tensors into IPC handles for every request.

Proposed Change.

Design

A model update flow consists of multiple update requests. vLLM exposes an HTTP endpoint, /v1/update-weights-from-ipc, to handle each request.

The first update request of a flow carries an extra field, handles. When vLLM receives handles, it rebuilds them into a shared buffer tensor and saves it as a worker attribute. vLLM uses this shared buffer tensor to receive data from the trainer.

The trainer copies tensor data into this shared channel and sends an update-weights-from-ipc request to vLLM; on receiving the request, vLLM reads the data from the shared channel and updates the weights in place.

The last update request carries an extra field end=True, which tells vLLM to drop the shared buffer tensor rebuilt from IPC so its GPU memory can be released. This completes a full update flow.
A single update request is defined below:

from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class UpdateWeightFromIPCRequest:
    # Tensor metadata as (name, dtype, shape) tuples.
    named_tensors: list[tuple[str, torch.dtype, torch.Size]]
    # Keyed by device_uuid; each worker looks up its own entry via
    # `current_platform.get_device_uuid(self.device.index)` in vLLM.
    # Each value is a serialized IPC handle; vLLM rebuilds the GPU tensor
    # with `func, args = handle` followed by `func(*args)`.
    # A non-None `handles` marks the first request of the current update
    # flow: vLLM should rebuild the GPU tensor and keep it as a shared buffer.
    handles: dict[str, tuple[Callable, tuple]] | None
    # Start offset of `named_tensors` within the shared IPC buffer.
    offset: int
    # Whether this is the last request in the current update flow.
    end: bool
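
For reference, here is a minimal trainer-side sketch of how such a handle could be produced. This is our assumption rather than part of the RFC; reduce_tensor from torch.multiprocessing.reductions yields exactly the (func, args) form described above.

import torch
from torch.multiprocessing.reductions import reduce_tensor

# Hypothetical trainer-side setup: allocate one byte buffer on the local
# GPU and serialize it into a (func, args) IPC handle as described above.
buffer = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")  # 1 GiB
handle = reduce_tensor(buffer)  # picklable (rebuild_func, args) pair
handles = {device_uuid: handle}  # device_uuid obtained out of band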

The update implementation in the worker can be sketched as follows:

class Worker(WorkerBase):
    ...
    def update_weights_from_ipc(
        self,
        named_tensors: list[tuple[str, torch.dtype, torch.Size]],
        handles: dict[str, tuple[Callable, tuple]] | None,
        offset: int,
        end: bool,
    ):
        device_id = self.device.index
        BUF_ATTR_NAME = '_shared_ipc_buffer'
        buffer: torch.Tensor
        if handles is not None:
            # First request of the flow: rebuild the shared buffer from the
            # IPC handle for this worker's device and cache it.
            buffer = rebuild_ipc(handles[self.device_uuid], device_id)
            assert buffer.dtype == torch.uint8
            setattr(self, BUF_ATTR_NAME, buffer)
        else:
            assert hasattr(self, BUF_ATTR_NAME)
            buffer = getattr(self, BUF_ATTR_NAME)
            assert buffer is not None
        # Slice each tensor out of the shared byte buffer without copying.
        weights = []
        for name, dtype, shape in named_tensors:
            if isinstance(shape, (list, tuple)):
                shape = torch.Size(shape)
            assert isinstance(shape, torch.Size)
            size = dtype.itemsize * shape.numel()
            tensor = buffer[offset:offset + size].view(dtype=dtype).view(shape)
            weights.append((name, tensor))
            offset += size
        self.model_runner.model.load_weights(weights=weights)
        del weights
        if end:
            # Last request: finalize the weights and drop the shared buffer
            # so its GPU memory can be released.
            process_weights_after_loading(self.model_runner.model,
                                          self.model_config, self.device)
            if hasattr(self, BUF_ATTR_NAME):
                delattr(self, BUF_ATTR_NAME)
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
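
The rebuild_ipc helper above is not spelled out in this RFC; a minimal sketch, assuming the handle is the (func, args) pair produced on the trainer side:

import torch
from typing import Callable


def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int) -> torch.Tensor:
    # The handle is a (func, args) pair, e.g. as produced by
    # torch.multiprocessing.reductions.reduce_tensor; calling func(*args)
    # maps the trainer's GPU allocation into this process. In the colocated
    # setting, trainer and worker share the same device.
    func, args = handle
    with torch.cuda.device(device_id):
        return func(*args)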

Practice

Because weights are copied into a shared buffer, it is straightforward for training clients to build a pipeline that accelerates the weight update. In the Kimi-K2 report, we implemented a two-stage pipeline as shown below.

[Figure: two-stage pipeline overlapping broadcast into one half of the shared IPC buffer with weight update from the other half]

The trainer broadcasts data into one half of the shared IPC buffer while the weight update consumes the other half, so the two operations run in parallel.
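
As a concrete illustration, here is a minimal client-side sketch of such a double-buffered loop. The function name, the JSON serialization of dtype/shape, and the assumption that handles is pre-serialized into a JSON-safe form are ours, not part of the RFC.

from concurrent.futures import ThreadPoolExecutor

import requests
import torch


def update_weights_pipelined(named_tensors, buffer, url, handles):
    # Hypothetical double-buffered client loop: copy a weight into one half
    # of the shared buffer while vLLM is still loading from the other half.
    # `named_tensors` is [(name, tensor), ...]; `handles` is assumed to be
    # already serialized for transport.
    half = buffer.numel() // 2
    pool = ThreadPoolExecutor(max_workers=1)
    inflight = None  # in-flight request for the half vLLM is consuming
    for i, (name, tensor) in enumerate(named_tensors):
        offset = (i % 2) * half  # ping-pong between the two halves
        flat = tensor.detach().contiguous().view(torch.uint8).view(-1)
        buffer[offset:offset + flat.numel()].copy_(flat)
        torch.cuda.synchronize()  # the copy must land before vLLM reads it
        if inflight is not None:
            inflight.result()  # the other half is now free to reuse
        payload = {
            "named_tensors": [(name, str(tensor.dtype), list(tensor.shape))],
            "handles": handles if i == 0 else None,  # first request only
            "offset": offset,
            "end": i == len(named_tensors) - 1,
        }
        inflight = pool.submit(requests.post, url, json=payload)
    if inflight is not None:
        inflight.result()
    pool.shutdown()

In practice a broadcast from the trainer ranks would replace the simple copy_ here, but the ping-pong structure is the same.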

Using the /v1/update-weights-from-ipc interface in vLLM together with the client-side pipeline, updating all 1T parameters of Kimi-K2 in vLLM takes under 20 seconds when deployed across thousands of GPU devices.

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

