Conversation

@louie-tsai (Contributor) commented May 10, 2025

[Current Implementation]
To get good performance for Tensor Parallel or Pipeline Parallel on CPU, users need to bind the OMP threads to CPU cores properly, to avoid the performance degradation caused by multiple threads running on the same CPU core.

Here are the current run instructions for CPU OMP threads binding:
OMP_NUM_THREADS=32 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63|64-95|96-127" python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --device cpu -tp=4 --distributed-executor-backend mp

[Problem Statement]
However, CPU ids can differ across OSes, and different CPU SKUs have different numbers of CPUs.
Moreover, users may also need to check cpus_allowed_list so that they bind only to the allowed CPU cores.
This requires users to inspect their environment first and set the binding accordingly.
In some cases, such as cluster deployments on Kubernetes, users won't know the CPU ids before deployment, so it is hard to write deployment scripts such as k8s YAML files correctly.

[Proposed Solution]
Therefore, we introduce a new feature that automatically binds the CPU OMP threads of each rank to the CPU ids of an allowed NUMA node, according to cpus_allowed_list.
There is then no need to set VLLM_CPU_OMP_THREADS_BIND for a particular environment; the CPU worker sets it automatically.
The new run instruction looks like the one below, and it also makes Tensor Parallel easier to support in k8s deployment environments.
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --device cpu -tp=4 --distributed-executor-backend mp

Overall, it binds each rank to one allowed NUMA node, so 4 NUMA nodes are used for tp=4 pp=1, and 2 NUMA nodes are used for tp=2 pp=1.
If the current environment only allows 2 NUMA nodes but more ranks are requested, an error is returned.
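
As a rough sketch of this mapping (assuming py-libnuma's info module and psutil, as used in this PR; the helper name is illustrative, not the actual vLLM code):

import psutil
from numa import info

def numa_cpus_for_rank(rank: int, world_size: int) -> list[int]:
    allowed = set(psutil.Process().cpu_affinity())  # cpus_allowed_list
    # Keep only NUMA nodes that contribute at least one allowed CPU.
    nodes = [sorted(set(info.node_to_cpus(n)) & allowed)
             for n in range(info.get_num_configured_nodes())]
    nodes = [cpus for cpus in nodes if cpus]
    if world_size > len(nodes):
        raise RuntimeError(f"world size {world_size} exceeds "
                           f"allowed NUMA nodes ({len(nodes)})")
    return nodes[rank]  # each rank gets one allowed NUMA node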

[Related Environment variables]

VLLM_CPU_OMP_THREADS_BIND: when set to auto, the OpenMP threads of each rank are bound to the CPU cores of one NUMA node. Default value is auto.
VLLM_CPU_NUM_OF_RESERVED_CPU: specifies the number of CPU cores that are not dedicated to the OpenMP threads of each rank. This variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to auto. Default value is 0.
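
For example, to keep one core per rank out of the OpenMP thread pool (an illustrative invocation, following the command above):
VLLM_CPU_NUM_OF_RESERVED_CPU=1 python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --device cpu -tp=4 --distributed-executor-backend mp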

[More Details]
As shown in the diagram below, even when the vLLM server gets no input for VLLM_CPU_OMP_THREADS_BIND and sets omp_cpuids to all, the CPU worker automatically overwrites local_omp_cpuid according to the current system NUMA configuration.
[diagram: CPU worker overwriting local_omp_cpuid based on the system NUMA configuration]

@github-actions bot

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label May 10, 2025
@xuechendi (Contributor)

@bigPYJ1151 please help to review

@bigPYJ1151 (Member) left a comment:

@louie-tsai Thanks for your good work, and apologies for the late response. I have some suggestions, please take a look.

BTW, can you help update the CPU document for this PR as well? You can find it at docs/source/getting_started/installation/cpu.md, thanks :)

logger.info("[ERROR] NO AUTO OMP Bind support because request world size: %d is more than allowed numa_size: %d",
world_size, len(node_to_cpus))
else:
rank_to_cpus=str(node_to_cpus[self.rank][0]) + '-' + str(node_to_cpus[self.rank][cpu_count_per_numa - 1 - num_of_no_bind_cpu])
@askervin commented May 22, 2025:

Building the rank_to_cpus string assumes that the CPU ids in node_to_cpus[self.rank] are contiguous. But it could contain CPUs 0-39,240-279; if so, rank_to_cpus would cover CPUs 0-279. The safest option would be to construct the CPU list without assuming anything about how CPUs are numbered inside a NUMA node. (There's a lot of variation between platforms.)
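
To illustrate the hazard with a hypothetical node whose CPU ids are non-contiguous (a sketch, not the PR's code):

node_cpus = list(range(0, 40)) + list(range(240, 280))

# A range string built from the first and last ids silently includes 40-239:
bad = f"{node_cpus[0]}-{node_cpus[-1]}"     # "0-279" -- wrong
# Safer: enumerate the ids explicitly, with no contiguity assumption:
good = ",".join(str(c) for c in node_cpus)  # "0,1,...,39,240,...,279"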

@louie-tsai (Contributor, PR author) replied:

@askervin
We used info.node_to_cpus(i) to get the list of CPU ids in a NUMA node, so we don't assume that CPU ids are contiguous within a node.

For CPUs 0-39,240-279, we only get node_to_cpus = [0-39], since 240-279 are also in NUMA node 0 when HT is on.

@louie-tsai (Contributor, PR author) replied:

@askervin
Changed the code to address non-contiguous CPU ids in a NUMA node; hope it addresses your feedback.

cpu_count = psutil.cpu_count(logical=False)
cpus_allow_list = psutil.Process().cpu_affinity()
numa_size = info.get_num_configured_nodes()
cpu_count_per_numa = cpu_count // numa_size

@askervin commented:

Some systems include NUMA nodes with CPUs and other NUMA nodes without CPUs. CPU-less NUMA nodes might be, for instance, HBM (high-bandwidth memory), PMEM (persistent memory), or CXL memory blocks.

I'd suggest calculating cpu_count_per_numa by dividing the number of CPUs by the number of NUMA nodes that contain CPUs.

Or, to be even more accurate, calculate node_to_cpus first, so that

node_to_cpus[node] = intersection of cpus on the node and the set of allowed CPUs

...and skip all nodes where the intersection is empty.

After this, if all node_to_cpus entries contain the same number of CPUs, that is your cpu_count_per_numa. Then your algorithm will work even if the allowed CPUs are equal-sized chunks from separate NUMA nodes. I think this would be very nice.

If the CPU sets in node_to_cpus are of different sizes, it would be fine to print a warning about not doing auto affinity for ranks. A minimal sketch of this construction follows below.
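
A minimal sketch of that construction (assuming py-libnuma's info module and psutil; variable names are illustrative):

import psutil
from numa import info

cpus_allow_list = set(psutil.Process().cpu_affinity())

# Allowed CPUs per NUMA node, skipping nodes with no allowed CPUs
# (this also skips CPU-less HBM/PMEM/CXL nodes).
node_to_cpus = []
for node in range(info.get_num_configured_nodes()):
    cpus = set(info.node_to_cpus(node)) & cpus_allow_list
    if cpus:
        node_to_cpus.append(sorted(cpus))

sizes = {len(cpus) for cpus in node_to_cpus}
if len(sizes) == 1:
    cpu_count_per_numa = sizes.pop()
else:
    print("warning: unequal allowed CPU counts per node; skipping auto affinity")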

@louie-tsai (Contributor, PR author) replied:

@askervin
Thanks for your detailed input. cpu_count_per_numa is only used to get the index of the last CPU in a node; it is not used to find node_to_cpus.
The main logic that finds the CPU ids for each node is in the code below, and it should address your request "node_to_cpus[node] = intersection of cpus on the node and the set of allowed CPUs".

[screenshot: the code that computes node_to_cpus]

Also, info.node_to_cpus(i) should cover the CPU ids of one NUMA node only. In general, we should have the same number of CPUs per node (not the case for GNR with SNC=3).
I assume that info.node_to_cpus(i) will also return CPU-less NUMA nodes. Please correct me if I have misunderstood.

@askervin commented:

Let's use a two-socket Xeon MAX CPU as an example. It has high-bandwidth memory (HBM) integrated directly into the CPU package, and DDR5 memory DIMMs next to each socket.

Assume that this system is configured to expose 4 NUMA nodes. (Depending on configuration, it could have even 16 nodes in real life.) Assume that nodes 0 and 2 include both CPUs and memory (DDR5), while NUMA nodes 1 and 3 include only memory (HBM) and no CPUs.

info.get_num_configured_nodes() returns the number of all NUMA nodes, that is, 4. This would be numa_size. Let's assume that

info.node_to_cpus(0) returns [0, 1, ..., 30, 31, 64, 65, ..., 94, 95]

info.node_to_cpus(1) returns []

info.node_to_cpus(2) returns [32, 33, ..., 62, 63, 96, 97, ..., 126, 127]

info.node_to_cpus(3) returns []

This gives you node_to_cpus = [[0-31, 64-95], [], [32-63, 96-127], []]. And cpu_count_per_numa = cpu_count / numa_size = 64 / 4 = 16.

Now, in the case of rank=0,

rank_to_cpus=str(node_to_cpus[self.rank][0]) + '-' + str(node_to_cpus[self.rank][cpu_count_per_numa - 1 - num_of_reserved_cpu])

gives "0-16". rank=1 crashes on IndexError.

And this example is still kind of easy: the CPU ids of thread0 and thread1 within each NUMA node are contiguous. But there are platforms where they are not. The code should work even if
info.node_to_cpus(0) returns [0, 2, 4, 6]
info.node_to_cpus(1) returns [1, 3, 5, 7]

NUMA ordering and CPU numbering can be quite exciting, so let's not make any assumptions about them.

So what I'm suggesting is:

  • Do not require set(info.node_to_cpus(i)).issubset(cpus_allow_list). This requirement is very hard to satisfy, because it means that the resource policy managing the server has been able to allocate every single CPU of a NUMA node. But there are many other containers (other vLLM containers, databases, whatever servers) that might hold exclusive CPUs from every node, and therefore cpus_allow_list may not include all CPUs from any NUMA node. Yet if it includes 16 CPUs from one NUMA node and 16 CPUs from another NUMA node, this automatic optimization could still work.

  • So what do you think: instead of this requirement, could you construct node_to_cpus like node_to_cpus.append(set(info.node_to_cpus(i)).intersection(cpus_allow_list)) whenever the intersection is not empty? This would automatically solve the problem of CPU-less NUMA nodes (the IndexError I mentioned above). And it would enable running nicely on GNR with SNC3, as long as the resource policy allocates an equal number of CPUs from every node.

@louie-tsai (Contributor, PR author) replied:

One quick input:
we won't have node_to_cpus = [[0-31, 64-95], [], [32-63, 96-127], []], because [] is not a subset of cpus_allow_list.

I used intersection according to your input in the PR, and I saw the same test output for issubset and intersection, as in the diagram below.
[screenshot: identical test output for issubset and intersection]

I actually don't understand the difference between using intersection and using issubset, but they both look ok to me.
Let me know whether the new changes work for you.

Since I don't have an empty NUMA node, I couldn't test empty NUMA nodes with our code.
However, the code below might still have issues with an empty NUMA node after changing to intersection:
numa_size = info.get_num_configured_nodes(); cpu_count_per_numa = cpu_count // numa_size

@askervin commented:

> because [] is not a subset of cpus_allow_list

An empty set is a subset of any set, including an empty set.
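
A quick check in Python confirms this:

print(set().issubset({1, 2, 3}))  # True
print(set().issubset(set()))      # True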

> I actually don't understand the difference between using intersection and using issubset, but they both look ok to me.

These are strong indications that this logic should be written as its own function so that you can write unit tests for it. That will give everyone an understanding of how the logic behaves on different systems with different CPU id numberings and different allowed-CPU sets. And it will be very helpful for ensuring that the logic keeps working in the future, when new developers touch this part of the code. Without unit tests they will very easily break it, because they will quite likely change the code with only the platforms they know in mind.

Maybe something like:

def auto_cpu_binding(world_size, rank, node_cpus, cpus_allowed, cpus_reserved):
    """
    Bind the process to a set of CPUs based on its rank and the available CPUs on the node.

    Args:
        world_size (int): number of processes.
        rank (int): rank of the current process.
        node_cpus (list of sets): node_cpus[i] contains CPUs on node i, can be empty.
        cpus_allowed (set): set of CPUs allowed for binding.
        cpus_reserved (int): number of CPUs per rank reserved for other purposes.

    Returns a pair:
        - set of CPUs that the process should be bound to, or empty if no binding is possible.
        - message indicating the binding decision.
    """
    if world_size <= 0:
        return set(), "invalid world size"
    if rank < 0:
        return set(), "invalid rank"
    if cpus_reserved < 0:
        return set(), "invalid reserved CPUs"
    if rank >= world_size:
        return set(), "rank exceeds world size"

    nodes_with_cpus = [cpus & cpus_allowed for cpus in node_cpus if cpus & cpus_allowed]
    if len(nodes_with_cpus) != world_size:
        return set(), f"number of nodes with allowed CPUs ({len(nodes_with_cpus)}) does not match world size ({world_size})"

    min_cpus_per_node = min(len(cpus) for cpus in nodes_with_cpus)
    if min_cpus_per_node - cpus_reserved <= 0:
        return set(), f"not enough CPUs per node ({min_cpus_per_node}) for reserved CPUs ({cpus_reserved}) and at least one CPU per rank"

    bind_cpus = sorted(nodes_with_cpus[rank])[:min_cpus_per_node - cpus_reserved]
    return set(bind_cpus), "binding rank per node"

# This shows how to play with auto_cpu_binding without having
# to test the function in more exotic hardware setups. This also works
# as a non-trivial case for unit tests:
# - asymmetric NUMA node sizes
# - asymmetric cpus_allowed on different NUMAs
# - exotic CPU numbering (still fully realistic)
if __name__ == "__main__":
    world_size = 2
    node_cpus = [{0, 2, 4, 6, 8, 10}, set(), {1, 3, 5, 7, 9}, set()]
    cpus_allowed = {2, 4, 6, 1, 5, 7, 9}
    reserved_cpus = 1

    for rank in range(world_size):
        cpus = auto_cpu_binding(world_size, rank, node_cpus, cpus_allowed, reserved_cpus)
        print(f"Rank {rank} of {world_size-1}: CPUs: {cpus[0]}, Message: {cpus[1]}")

@louie-tsai (Contributor, PR author) replied:

Moved the implementation into a separate function accordingly; also handled the non-contiguous CPU-id cases.

@louie-tsai louie-tsai requested review from askervin and bigPYJ1151 May 23, 2025 01:27
@louie-tsai louie-tsai changed the title [WIP] automatically bind CPU OMP Threads of a rank to CPU ids of a NUMA node. Automatically bind CPU OMP Threads of a rank to CPU ids of a NUMA node. May 23, 2025
@bigPYJ1151 (Member) left a comment:

We are close to being ready.

Please add a note about the newly added env vars to the CPU doc. And fix the code-style checks; you can refer to this for auto-format and lint.

@mergify bot commented May 23, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @louie-tsai.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 23, 2025
@mergify mergify bot removed the needs-rebase label May 26, 2025
@louie-tsai louie-tsai force-pushed the auto_binding branch 4 times, most recently from 50a3d08 to 6ff056a Compare May 28, 2025 05:11
@louie-tsai louie-tsai requested a review from hmellor as a code owner May 28, 2025 05:20
@louie-tsai (Contributor, PR author) replied:

> We are close to being ready.
>
> Please add a note about the newly added env vars to the CPU doc. And fix the code-style checks; you can refer to this for auto-format and lint.

@bigPYJ1151 updated the cpu.md accordingly

@mergify mergify bot added the documentation Improvements or additions to documentation label May 28, 2025
@louie-tsai louie-tsai requested a review from bigPYJ1151 May 28, 2025 05:29
@louie-tsai louie-tsai force-pushed the auto_binding branch 5 times, most recently from 706b5de to 9295509 Compare May 29, 2025 22:02
@bigPYJ1151 (Member)

Hi @DarkLight1337 @Isotr0py, this PR has been through some review rounds and looks good to me. Can you help take a look at it? Thanks!

@Isotr0py (Member) left a comment:

(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:15 [cpu_worker.py:145] auto thread-binding list: 0,1,2,3,4,5,6,7,8,9,10,11
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:145] auto thread-binding list: 12,13,14,15,16,17,18,19,20,21,22,23
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51] OMP threads binding of Process 1186205:
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186205, core 12
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186674, core 13
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186675, core 14
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186676, core 15
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186677, core 16
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186678, core 17
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186679, core 18
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186680, core 19
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186681, core 20
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186682, core 21
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186683, core 22
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51]  OMP tid: 1186684, core 23
(VllmWorker rank=1 pid=1186205) INFO 06-05 16:59:15 [cpu_worker.py:51] 
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51] OMP threads binding of Process 1186204:
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186204, core 0
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186720, core 1
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186721, core 2
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186722, core 3
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186724, core 4
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186725, core 5
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186726, core 6
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186727, core 7
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186728, core 8
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186729, core 9
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186730, core 10
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]  OMP tid: 1186731, core 11
(VllmWorker rank=0 pid=1186204) INFO 06-05 16:59:16 [cpu_worker.py:51]

Given the auto bind results locally, this PR LGTM! Thanks for this improvement!

rank_to_cpus = self.local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size = self.vllm_config.parallel_config.world_size
from importlib import util
@Isotr0py (Member) commented:

I think there is no need to lazy-import importlib.

@louie-tsai (Contributor, PR author) replied:

Do we want to remove the numa and psutil checks, assuming they are installed via cpu.txt?

@Isotr0py (Member) replied:

I prefer to keep the numa and psutil checks here, since we don't have them installed on macOS.

I meant we can directly import importlib.util at the top level :)
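
A sketch of the resulting pattern (the find_spec-based guard is an assumption about how the check is written, not the PR's exact code):

from importlib import util  # top-level import, no lazy import needed

# Skip NUMA auto-binding on platforms (e.g. macOS) without these packages.
if util.find_spec("numa") is not None and util.find_spec("psutil") is not None:
    ...  # proceed with automatic OMP thread binding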

@louie-tsai (Contributor, PR author) replied:

Thanks for the input. Addressed it accordingly.

rank_to_cpus = self.local_omp_cpuid
# Setup OpenMP thread affinity based on NUMA nodes automatically
world_size = self.vllm_config.parallel_config.world_size
from importlib import util
@Isotr0py (Member) commented:
ditto.

@xuechendi (Contributor)

Hi @DarkLight1337 @bigPYJ1151, if the PR looks good to you, please help to approve.

Signed-off-by: Tsai, Louie <[email protected]>
@Isotr0py Isotr0py enabled auto-merge (squash) June 10, 2025 03:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 10, 2025
@Isotr0py Isotr0py merged commit 9368cc9 into vllm-project:main Jun 10, 2025
73 checks passed

os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

# Share the cpusets list among ranks by spawning process instead
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
A Collaborator commented:

I think we set this twice; it duplicates L199. This may be due to the rebase not auto-merging this change.
