
Conversation

@biswapanda
Contributor

@biswapanda biswapanda commented Oct 29, 2025

Purpose

This PR fixes a bug in the KV cache block storage system where the code was incorrectly trying to access request.lora_request.id instead of the correct request.lora_request.adapter_id property.

Root Cause:
The LoRARequest class in vllm/lora/request.py does not have an id field. Instead, it has:

  • lora_int_id: The actual ID field
  • adapter_id: A property that returns lora_int_id
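
For illustration, a minimal sketch of that shape (simplified; not the actual vllm/lora/request.py source, which has more fields):

```python
from dataclasses import dataclass


@dataclass
class LoRARequest:
    """Simplified stand-in for vllm.lora.request.LoRARequest."""

    lora_name: str
    lora_int_id: int  # the actual ID field
    lora_path: str = ""

    @property
    def adapter_id(self) -> int:
        # `adapter_id` is only an alias for lora_int_id; there is no `id` attribute.
        return self.lora_int_id
```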

Sample Error:

2025-10-29T07:10:51.703694Z ERROR core.run_engine_core: EngineCore encountered a fatal error.
Traceback (most recent call last):
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 701, in run_engine_core
    engine_core.run_busy_loop()
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 728, in run_busy_loop
    self._process_engine_step()
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 754, in _process_engine_step
    outputs, model_executed = self.step_fn()
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/engine/core.py", line 283, in step
    scheduler_output = self.scheduler.schedule()
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/core/sched/scheduler.py", line 255, in schedule
    new_blocks = self.kv_cache_manager.allocate_slots(
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/core/kv_cache_manager.py", line 302, in allocate_slots
    self.coordinator.cache_blocks(request, num_tokens_to_cache)
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/core/kv_cache_coordinator.py", line 129, in cache_blocks
    manager.cache_blocks(request, num_computed_tokens)
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/core/single_type_kv_cache_manager.py", line 145, in cache_blocks
    self.block_pool.cache_full_blocks(
  File "/home/biswaranjanp/dev/.venv/lib/python3.10/site-packages/vllm/v1/core/block_pool.py", line 252, in cache_full_blocks
    lora_id=request.lora_request.id
AttributeError: 'LoRARequest' object has no attribute 'id'
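
Continuing the sketch above, the failing access versus what this PR changes it to (a hypothetical standalone demo, not the actual diff in block_pool.py):

```python
lora_request = LoRARequest(lora_name="test_lora", lora_int_id=42, lora_path="/test/path")

# Before this PR, block_pool.py read `.id`, which does not exist:
try:
    lora_request.id
except AttributeError as e:
    print(e)  # 'LoRARequest' object has no attribute 'id'

# After this PR, it reads the `adapter_id` property instead:
print(lora_request.adapter_id)  # 42
```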

Test Plan

Added a unit test, test_kv_cache_events_with_lora, in test_prefix_caching.py.

pytest -v tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora

Test Result

After

tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[2] PASSED                                              [ 33%]
tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[3] PASSED                                              [ 66%]
tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[10] PASSED                                             [100%]

Before

======================================================= test session starts =======================================================
platform linux -- Python 3.12.3, pytest-8.3.5, pluggy-1.5.0 -- /home/biswaranjanp/dev/oss-vllm/.venv/bin/python
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/home/biswaranjanp/dev/oss-vllm/.hypothesis/examples'))
rootdir: /home/biswaranjanp/dev/oss-vllm
configfile: pyproject.toml
plugins: timeout-2.3.1, cov-6.3.0, buildkite-test-collector-0.1.9, rerunfailures-14.0, forked-1.6.0, schemathesis-3.39.15, hypothesis-6.131.0, shard-0.1.2, asyncio-0.24.0, mock-3.14.0, anyio-4.6.2.post1, subtests-0.14.1, hydra-core-1.3.2
asyncio: mode=Mode.STRICT, default_loop_scope=None
collected 3 items                                                                                                                 
Running 3 items in this shard: tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[2], tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[3], tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[10]

tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[2] FAILED                                              [ 33%]
tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[3] FAILED                                              [ 66%]
tests/v1/core/test_prefix_caching.py::test_kv_cache_events_with_lora[10] FAILED                                             [100%]

============================================================ FAILURES =============================================================
________________________________________________ test_kv_cache_events_with_lora[2] ________________________________________________

blocks_to_cache = 2

    @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
    def test_kv_cache_events_with_lora(blocks_to_cache: int):
        """Test BlockStored events contain correct lora_id when using LoRA requests."""
        block_size = 16
        num_blocks = blocks_to_cache + 1
    
        # Create KVCacheManager with events enabled
        manager = KVCacheManager(
            make_kv_cache_config(block_size, num_blocks),
            max_model_len=8192,
            enable_caching=True,
            enable_kv_cache_events=True,
        )
    
        # Test with LoRA request
        lora_request = LoRARequest(
            lora_name="test_lora", lora_int_id=42, lora_path="/test/path"
        )
    
        num_tokens = block_size * blocks_to_cache
        req_with_lora = make_request(
            "lora_req",
            list(range(num_tokens)),
            block_size,
            sha256,
            lora_request=lora_request,
        )
    
        # Allocate slots and get events
>       _ = manager.allocate_slots(req_with_lora, num_tokens)

tests/v1/core/test_prefix_caching.py:1371: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
vllm/v1/core/kv_cache_manager.py:331: in allocate_slots
    self.coordinator.cache_blocks(request, num_tokens_to_cache)
vllm/v1/core/kv_cache_coordinator.py:139: in cache_blocks
    manager.cache_blocks(request, num_computed_tokens)
vllm/v1/core/single_type_kv_cache_manager.py:160: in cache_blocks
    self.block_pool.cache_full_blocks(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <vllm.v1.core.block_pool.BlockPool object at 0x775f9089bc80>, request = <vllm.v1.request.Request object at 0x775f9089bc20>
blocks = [KVCacheBlock(block_id=1, ref_cnt=1, _block_hash=b'lH@mg\x7f\x9d\xe7\x00\xe7\xd6u\x99w\x82\xcc\xf5M\x8b\x0b(\xe8\x01\x...xddx{\xa9\x0e\xec\x0ejQf\xb1\xca\xa5\xafr\x177L\xd3\xb8w\x00\x00\x00\x00', prev_free_block=None, next_free_block=None)]
num_cached_blocks = 0, num_full_blocks = 2, block_size = 16, kv_cache_group_id = 0

    def cache_full_blocks(
        self,
        request: Request,
        blocks: list[KVCacheBlock],
        num_cached_blocks: int,
        num_full_blocks: int,
        block_size: int,
        kv_cache_group_id: int,
    ) -> None:
        """Cache a list of full blocks for prefix caching.
        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it updates the
        metadata for each block and caching it in the
        `cached_block_hash_to_block`.
        The block hashes values are computed by the Request object immediately
        when it is created and when new tokens are appended.
    
        Args:
            request: The request to cache the blocks.
            blocks: All blocks in the request.
            num_cached_blocks: The number of blocks that are already cached.
            num_full_blocks: The number of blocks that are full and should
                be cached after this function.
            block_size: Number of tokens in each block.
            kv_cache_group_id: The id of the KV cache group.
        """
        if num_cached_blocks >= num_full_blocks:
            return
        new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
        assert len(request.block_hashes) >= num_full_blocks
        new_block_hashes = request.block_hashes[num_cached_blocks:]
    
        new_hashes: list[ExternalBlockHash] | None = (
            [] if self.enable_kv_cache_events else None
        )
        for i, blk in enumerate(new_full_blocks):
            assert blk.block_hash is None
            block_hash = new_block_hashes[i]
    
            # Update and added the full block to the cache.
            block_hash_with_group_id = make_block_hash_with_group_id(
                block_hash, kv_cache_group_id
            )
            blk.block_hash = block_hash_with_group_id
            self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk)
            if new_hashes is not None:
                new_hashes.append(maybe_convert_block_hash(block_hash))
    
        if self.enable_kv_cache_events:
            if num_cached_blocks == 0:
                parent_block_hash: ExternalBlockHash | None = None
            else:
                parent_block = blocks[num_cached_blocks - 1]
                assert parent_block.block_hash is not None
                parent_block_hash = maybe_convert_block_hash(
                    get_block_hash(parent_block.block_hash)
                )
    
            self.kv_event_queue.append(
                BlockStored(
                    block_hashes=new_hashes,
                    parent_block_hash=parent_block_hash,
                    token_ids=request.all_token_ids[
                        num_cached_blocks * block_size : num_full_blocks * block_size
                    ],
                    block_size=block_size,
>                   lora_id=request.lora_request.id if request.lora_request else None,
                    medium=MEDIUM_GPU,
                )
            )
E           AttributeError: 'LoRARequest' object has no attribute 'id'

vllm/v1/core/block_pool.py:262: AttributeError
________________________________________________ test_kv_cache_events_with_lora[3] ________________________________________________

(same failure as [2]: AttributeError: 'LoRARequest' object has no attribute 'id' at vllm/v1/core/block_pool.py:262)

_______________________________________________ test_kv_cache_events_with_lora[10] ________________________________________________

(same failure as [2]: AttributeError: 'LoRARequest' object has no attribute 'id' at vllm/v1/core/block_pool.py:262)

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug in vllm/v1/core/block_pool.py. The code was attempting to access request.lora_request.id, but the LoRARequest object does not have an id attribute. The change replaces this with request.lora_request.adapter_id, which is the correct attribute. This prevents a runtime AttributeError when KV cache events are enabled for LoRA requests. The fix is accurate and addresses the issue effectively. I find no further issues with this change.

Member

@markmc markmc left a comment


Thank you, the fix is correct 👍

It looks like the bug has existed since the code was merged in #16750, so it appears we have no test coverage for KV event publishing with LoRA. It would be great to get that gap plugged.

@markmc markmc added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 29, 2025
@biswapanda
Contributor Author

Agreed, this could have been caught with some tests.
I'll look into existing LoRA test cases and create a separate PR for KV event publishing tests for LoRA / base models.

@markmc
Member

markmc commented Oct 30, 2025

Agreed, this could have been caught with some tests. I'll look into existing LoRA test cases and create a separate PR for KV event publishing tests for LoRA / base models.

Probably test_kv_cache_events() in test_prefix_caching.py is the right starting place - extend make_request() to accept a lora_request parameter and then check that the BlockStored event contains the correct lora_id
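
Roughly, the added check could look like this (a sketch only: it reuses the test_prefix_caching.py helpers visible in the traceback above, and assumes the manager exposes a take_events()-style method for draining queued KV cache events):

```python
block_size = 16
manager = KVCacheManager(
    make_kv_cache_config(block_size, 3),
    max_model_len=8192,
    enable_caching=True,
    enable_kv_cache_events=True,
)
lora_request = LoRARequest(lora_name="test_lora", lora_int_id=42, lora_path="/test/path")
req = make_request(
    "lora_req", list(range(2 * block_size)), block_size, sha256, lora_request=lora_request
)
manager.allocate_slots(req, 2 * block_size)

# take_events() is an assumption here -- whatever API drains the queued events.
events = manager.take_events()
stored = [e for e in events if isinstance(e, BlockStored)]
assert stored, "expected at least one BlockStored event"
for event in stored:
    assert event.lora_id == lora_request.adapter_id  # i.e. 42
```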

Signed-off-by: Biswa Panda <[email protected]>
@biswapanda
Contributor Author

Thanks @markmc, I've added a unit test and results. PTAL

@markmc
Member

markmc commented Oct 31, 2025

Thanks @markmc, I've added a unit test and results. PTAL

lgtm, thanks!

@biswapanda
Contributor Author

Thank you @markmc - when will this PR be merged? What are the next steps?
Asking because this is my first contribution to vLLM and I couldn't find this detail in the docs.

@markmc
Member

markmc commented Oct 31, 2025

Thank you @markmc - when will this PR be merged? What are the next steps? Asking because this is my first contribution to vLLM and I couldn't find this detail in the docs.

Next step is for a committer to review and approve the PR, perhaps @jeejeelee

@jeejeelee jeejeelee merged commit 1bf43ae into vllm-project:main Nov 3, 2025
47 checks passed
zhaozuy pushed a commit to zhaozuy/vllm that referenced this pull request Nov 4, 2025
git-jxj pushed a commit to git-jxj/vllm that referenced this pull request Nov 4, 2025
omerpaz95 pushed a commit to omerpaz95/vllm that referenced this pull request Nov 4, 2025
seungduk-yanolja pushed a commit to seungduk-yanolja/vllm that referenced this pull request Nov 4, 2025
juliendenize pushed a commit to juliendenize/vllm that referenced this pull request Nov 6, 2025
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025