From 5be579064357a52b3ec3942e70a2355a72f2c959 Mon Sep 17 00:00:00 2001 From: donglu Date: Fri, 20 Jun 2025 13:56:22 +0000 Subject: [PATCH 1/2] shared memory based object store Signed-off-by: donglu Address comments Signed-off-by: donglu --- .buildkite/test-pipeline.yaml | 2 + docs/configuration/optimization.md | 31 +- tests/distributed/test_shm_buffer.py | 177 +++++ tests/distributed/test_shm_storage.py | 328 +++++++++ tools/check_pickle_imports.py | 1 + vllm/config/__init__.py | 20 + .../shm_object_storage.py | 635 ++++++++++++++++++ vllm/engine/arg_utils.py | 17 +- vllm/envs.py | 7 + vllm/executor/uniproc_executor.py | 8 + vllm/multimodal/cache.py | 222 +++++- vllm/utils/__init__.py | 6 + vllm/v1/engine/core.py | 4 +- vllm/v1/executor/multiproc_executor.py | 28 +- vllm/v1/executor/utils.py | 25 + 15 files changed, 1488 insertions(+), 23 deletions(-) create mode 100644 tests/distributed/test_shm_buffer.py create mode 100644 tests/distributed/test_shm_storage.py create mode 100644 vllm/distributed/device_communicators/shm_object_storage.py create mode 100644 vllm/v1/executor/utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index adbd08b13f75..9b807140b4d9 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -789,6 +789,8 @@ steps: commands: - pytest -v -s distributed/test_comm_ops.py - pytest -v -s distributed/test_shm_broadcast.py + - pytest -v -s distributed/test_shm_buffer.py + - pytest -v -s distributed/test_shm_storage.py - label: 2 Node Tests (4 GPUs in total) # 16min timeout_in_minutes: 30 diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index c853fcf92941..8ea382181165 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -230,6 +230,20 @@ Multi-modal IPC caching is automatically enabled when there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes, to avoid repeatedly transferring the same multi-modal inputs between them. +#### Key-Replicated Cache + +By default, IPC caching uses a **key-replicated cache**, where cache keys exist +in both the API (`P0`) and engine core (`P1`) processes, but the actual cache +data resides only in `P1`. + +#### Shared Memory Cache + +When multiple worker processes are involved (e.g., when TP > 1), a +**shared-memory cache** is more efficient. This can be enabled by setting +`mm_processor_cache_type="shm"`. In this mode, cache keys are stored +on `P0`, while the cache data itself lives in shared memory accessible by all +processes. + ### Configuration You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). @@ -244,6 +258,12 @@ Examples: llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=8) +# Use a shared-memory based IPC cache +llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", + tensor_parallel_size=2, + mm_processor_cache_type="shm", + mm_processor_cache_gb=8) + # Disable the cache llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) @@ -253,11 +273,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: -| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. 
Memory |
-|-------------------|-------------|------------|------------|-------------|
-| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
-| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
-| ❌ | ❌ | N/A | N/A | `0` |
+| Cache Type | `P0` Cache | `P1` Cache | Shm Cache | Max. Memory | `mm_processor_cache_type` |
+|-------------------|------------|------------|-----------|-------------|-------------|
+| Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * api_server_count` | `lru` |
+| Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * data_parallel_size` | `lru` |
+| Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` | `shm` |
+| Disabled | N/A | N/A | N/A | `0` | N/A |
 
 K: Stores the hashes of multi-modal items
 V: Stores the processed tensor data of multi-modal items
diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py
new file mode 100644
index 000000000000..f939475ba952
--- /dev/null
+++ b/tests/distributed/test_shm_buffer.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import unittest
+
+from vllm.distributed.device_communicators.shm_object_storage import (
+    SingleWriterShmRingBuffer)
+
+
+class TestSingleWriterShmRingBuffer(unittest.TestCase):
+    """Test suite for the ring buffer implementation"""
+
+    def setUp(self):
+        """Set up test fixtures"""
+        self.buffer_size = 4096
+        self.ring_buffer = None
+
+    def tearDown(self):
+        """Clean up after tests"""
+        if self.ring_buffer:
+            del self.ring_buffer
+
+    def test_buffer_opening(self):
+        """Test opening an existing buffer"""
+        # First create a buffer
+        self.ring_buffer = SingleWriterShmRingBuffer(
+            data_buffer_size=self.buffer_size, create=True)
+
+        # Then open it with another instance (created outside the try block
+        # so the cleanup below cannot hit a NameError if the constructor
+        # raises)
+        reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
+        try:
+            self.assertFalse(reader_buffer.is_writer)
+            self.assertEqual(reader_buffer.shared_memory.name,
+                             self.ring_buffer.shared_memory.name)
+        finally:
+            del reader_buffer
+
+    def test_buffer_access(self):
+        """Test accessing allocated buffers"""
+        self.ring_buffer = SingleWriterShmRingBuffer(
+            data_buffer_size=self.buffer_size, create=True)
+
+        size = 100
+        address, monotonic_id = self.ring_buffer.allocate_buf(size)
+
+        # Write some test data
+        test_data = b"Hello, World!" 
* 7 # 91 bytes + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:len(test_data)] = test_data + + # Read it back + with self.ring_buffer.access_buf(address) as (data_buf2, metadata2): + read_data = bytes(data_buf2[0:len(test_data)]) + read_id = metadata2[0] + + self.assertEqual(read_data, test_data) + self.assertEqual(read_id, monotonic_id) + + def test_memory_error_on_full_buffer(self): + """Test that MemoryError is raised when buffer is full""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True) + + # Fill up the buffer + self.ring_buffer.allocate_buf(100) + self.ring_buffer.allocate_buf(80) # Total: 196 bytes used + + # This should fail + with self.assertRaises(MemoryError): + self.ring_buffer.allocate_buf(1) # Would exceed buffer capacity + + def test_allocation_and_free(self): + """Test allocation and freeing of buffers""" + small_buffer_size = 200 + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=small_buffer_size, create=True) + + size = 80 + # Write some data + test_data = b"Repeated test data" + for i in range(5): + address, monotonic_id = self.ring_buffer.allocate_buf(size) + with self.ring_buffer.access_buf(address) as (data_buf, metadata): + data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use + data_buf[4:len(test_data) + 4] = test_data + print(self.ring_buffer.metadata) + freed_ids = self.ring_buffer.free_buf(lambda *args: True) + print(f" Freed IDs: {freed_ids}") + self.assertEqual(freed_ids[0], i) + + def test_clear_buffer(self): + """Test clearing the buffer""" + self.ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=self.buffer_size, create=True) + + # Allocate some buffers + for _ in range(3): + self.ring_buffer.allocate_buf(100) + + # Clear the buffer + self.ring_buffer.clear() + + # Check that metadata is empty and IDs reset + self.assertEqual(len(self.ring_buffer.metadata), 0) + self.assertEqual(self.ring_buffer.monotonic_id_start, 0) + self.assertEqual(self.ring_buffer.monotonic_id_end, 0) + self.assertEqual(self.ring_buffer.data_buffer_start, 0) + self.assertEqual(self.ring_buffer.data_buffer_end, 0) + + +def main(): + """Main function demonstrating usage and running tests""" + print("=== SingleWriterShmRingBuffer Test Suite ===\n") + + # Run unit tests + print("Running unit tests...") + unittest.main(argv=[""], exit=False, verbosity=2) + + print("\n" + "=" * 50) + print("=== Manual Demo ===\n") + + # Manual demonstration + try: + print("Creating ring buffer...") + writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, + create=True) + reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle()) + + print(f"Buffer created with name: {writer_buffer.shared_memory.name}") + + # Allocate some buffers + print("\nAllocating buffers...") + address_array = [] + for i in range(3): + size = 100 + i * 50 + try: + writer_buffer.free_buf(lambda *args: True) + address, monotonic_id = writer_buffer.allocate_buf(size) + address_array.append((address, size, monotonic_id)) + + # Write some test data + with writer_buffer.access_buf(address) as (data_buf, metadata): + test_message = f"Test message {i}".encode() + data_buf[0:len(test_message)] = test_message + + except MemoryError as e: + print(f" Failed to allocate {size} bytes: {e}") + + print("\nBuffer state:") + print(f" Data buffer start: {writer_buffer.data_buffer_start}") + print(f" Data buffer end: {writer_buffer.data_buffer_end}") + print(f" Monotonic ID start: 
{writer_buffer.monotonic_id_start}") + print(f" Monotonic ID end: {writer_buffer.monotonic_id_end}") + print(f" Metadata entries: {len(writer_buffer.metadata)}") + + # Try to read back the data + print("\nReading back data...") + for address, size, monotonic_id in address_array: + with reader_buffer.access_buf(address) as (data_buf, metadata): + # Find null terminator or read first 50 chars + data_bytes = bytes(data_buf[0:size]) + message = data_bytes.decode() + print(f" ID {monotonic_id}: '{message}'") + + except Exception as e: + print(f"Demo error: {e}") + import traceback + + traceback.print_exc() + + print("\n=== Demo Complete ===") + + +if __name__ == "__main__": + main() diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py new file mode 100644 index 000000000000..908ed33dff6c --- /dev/null +++ b/tests/distributed/test_shm_storage.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +import random +import time +import unittest +from multiprocessing import Lock + +import torch + +# Assuming these are imported from your module +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) +from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, + MultiModalSharedField) + + +def _dummy_elem(modality: str, key: str, size: int): + return MultiModalFieldElem( + modality=modality, + key=key, + data=torch.empty((size, ), dtype=torch.int8), + field=MultiModalSharedField(1), + ) + + +def _dummy_item(modality: str, size_by_key: dict[str, int]): + return MultiModalKwargsItem.from_elems([ + _dummy_elem(modality, key, size) for key, size in size_by_key.items() + ]) + + +class TestSingleWriterShmObjectStorage(unittest.TestCase): + + def setUp(self): + """Set up test fixtures before each test method.""" + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + self.storage = SingleWriterShmObjectStorage( + max_object_size=1024 * 10, # 10KB max object + n_readers=2, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + def tearDown(self): + """Clean up after each test.""" + if self.storage: + del self.storage + + def test_minimal_put_get_cycle(self): + """Test basic put and get operations.""" + key = "test_key" + value = _dummy_item("text", {"field1": 10, "field2": 20}) + + # Put operation + address, monotonic_id = self.storage.put(key, value) + + # Verify key is in index + self.assertIn(key, self.storage.key_index) + self.assertEqual(self.storage.key_index[key], (address, monotonic_id)) + self.assertEqual(self.storage.id_index[monotonic_id], key) + + # Get operation + result = self.storage.get(address, monotonic_id) + + # Verify result + self.assertEqual(result, value) + + def test_put_same_key_twice(self): + """Test behavior when putting the same key multiple times.""" + key = "duplicate_key" + value1 = "first value" + value2 = "second value" + + # First put + address1, id1 = self.storage.put(key, value1) + retrieved1 = self.storage.get(address1, id1) + self.assertEqual(retrieved1, value1) + + # should raise an error on second put + with self.assertRaises(ValueError) as context: + self.storage.put(key, value2) + + self.assertIn("already exists in the storage", str(context.exception)) + + def test_large_object_rejection(self): + """Test that objects exceeding max_object_size 
are rejected.""" + # Create an object larger than max_object_size + large_data = "x" * (self.storage.max_object_size + 100) + + with self.assertRaises(ValueError) as context: + self.storage.put("large_key", large_data) + + self.assertIn("exceeds max object size", str(context.exception)) + + def test_buffer_overflow_and_cleanup(self): + """Test behavior when buffer fills up and needs cleanup.""" + # Fill up the buffer with many small objects + stored_items = [] + + try: + for i in range(1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # Verify that some items are still accessible + accessible_count = 0 + for key, original_value, address, monotonic_id in stored_items: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + + self.assertEqual(accessible_count, len(stored_items)) + + try: + for i in range(len(stored_items), 1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # Verify that some items are still accessibles + for key, original_value, address, monotonic_id in stored_items: + try: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + except ValueError as e: + print(f"Error retrieving {key}: {e}") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(accessible_count, len(stored_items)) + + def test_blocking_unread_object(self): + """Test behavior when buffer fills up and needs cleanup.""" + # Fill up the buffer with many small objects + stored_items = [] + + try: + for i in range(1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # read all items except the first one + # to simulate a blocking situation + accessible_count = 0 + for key, original_value, address, monotonic_id in stored_items[1:]: + for i in range(self.storage.n_readers): + retrieved = self.storage.get(address, monotonic_id) + if retrieved == original_value: + accessible_count += 1 + + self.assertEqual(accessible_count, len(stored_items) - 1) + + try: + key = f"item_{len(stored_items)}" + value = f"data_{len(stored_items)}" * 100 + address, monotonic_id = self.storage.put(key, value) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # read the first item + for i in range(self.storage.n_readers): + key, original_value, address, monotonic_id = stored_items[0] + retrieved = self.storage.get(address, monotonic_id) + self.assertEqual(retrieved, original_value) + + try: + for i in range(len(stored_items), 1000): # Try to store many items + key = f"item_{i}" + value = f"data_{i}" * 100 # Make it reasonably sized + address, monotonic_id = self.storage.put(key, value) + stored_items.append((key, value, address, 
monotonic_id)) + except MemoryError: + print(f"Buffer filled after {len(stored_items)} items") + + # some items from the first batch may still be accessible + self.assertGreaterEqual(len(stored_items), accessible_count + 10) + + def test_invalid_get_operations(self): + """Test various invalid get operations.""" + # Test with non-existent address + with self.assertRaises(ValueError): # Could be various exceptions + self.storage.get(99999, 1) + + # Store something first + address, monotonic_id = self.storage.put("test", "value") + + # Test with wrong monotonic_id + with self.assertRaises(ValueError) as context: + self.storage.get(address, monotonic_id + 100) + + self.assertIn("has been modified or is invalid", \ + str(context.exception)) + + def test_clear_storage(self): + """Test clearing the storage.""" + # Store some items + for i in range(5): + self.storage.put(f"item_{i}", f"value_{i}") + + # Clear the storage + self.storage.clear() + + # Verify that all indices are empty + self.assertEqual(len(self.storage.key_index), 0) + self.assertEqual(len(self.storage.id_index), 0) + self.assertEqual(len(self.storage.ring_buffer.metadata), 0) + + # Verify that new items can be added after clearing + address, monotonic_id = self.storage.put("new_item", "new_value") + self.assertIn("new_item", self.storage.key_index) + self.assertEqual((address, monotonic_id), (0, 0)) + + +# Reader process function +def reader_process(process_id, storage_handle, items_to_read): + """Reader process that connects to existing shared memory and reads data.""" + reader_storage = SingleWriterShmObjectStorage.create_from_handle( + storage_handle) + + print(f"Reader {process_id} started") + + errors = [] + + for key, original_value, address, monotonic_id in items_to_read: + time.sleep(random.random() / 100) + try: + # Read data from shared memory + retrieved_value = reader_storage.get(address, monotonic_id) + + # Verify data integrity + assert retrieved_value == original_value + print(f"Reader {process_id} retrieved {key}: {retrieved_value}") + except Exception as e: + errors.append((key, str(e), type(e).__name__)) + + +def run_multiprocess_example(): + """Run a minimal working example with real shared memory.""" + print("=== Minimal Object Storage Example ===") + + try: + # Create storage instance + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=1024 * 100, + create=True, # 10 MB buffer + ) + storage = SingleWriterShmObjectStorage( + max_object_size=1024, + n_readers=3, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=Lock(), + ) + + print(f"Created storage (writer: {storage.is_writer})") + + # Test basic data types + test_data = [ + ("user_data", { + "name": "Alice", + "age": 30, + "scores": [95, 87, 92] + }), + ("simple_string", "Hello, World!"), + ("number", 42), + ("list_data", [1, 2, 3, "four", 5.0]), + ] + + stored_items = [] + + # Store all data + for key, value in test_data: + print(f"Storing {key}: {value}") + address, monotonic_id = storage.put(key, value) + stored_items.append((key, value, address, monotonic_id)) + print(f" -> Stored at address {address}, ID {monotonic_id}") + + print("\n--- Retrieving Data ---") + processes = [] + handle = storage.handle() + # initialize lock for reader processes + handle.reader_lock = Lock() + for i in range(storage.n_readers): + p = multiprocessing.Process(target=reader_process, + args=(i, handle, stored_items)) + processes.append(p) + p.start() + + for p in processes: + p.join(timeout=10) + if p.is_alive(): + p.terminate() + p.join() + + 
except Exception as e: + print(f"Error in minimal example: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + # Run the minimal example first + run_multiprocess_example() + print("\n" + "=" * 50 + "\n") + + # Run the test suite + print("Running comprehensive test suite...") + unittest.main(verbosity=2, exit=False) diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index ad0ae45d1d46..fe717121db40 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -39,6 +39,7 @@ 'vllm/engine/multiprocessing/client.py', 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', + 'vllm/distributed/device_communicators/shm_object_storage.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', 'benchmarks/kernels/benchmark_lora.py', diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 24eaf2e360ab..ef4eaceb3d07 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -262,6 +262,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] MMEncoderTPMode = Literal["weights", "data"] +MMCacheType = Literal["shm", "lru"] class LogprobsMode(enum.Enum): @@ -450,6 +451,13 @@ class ModelConfig: `mm_processor_cache_gb * (api_server_count + data_parallel_size)`. Set to `0` to disable this cache completely (not recommended).""" + mm_processor_cache_type: MMCacheType = "lru" + """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`, + use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" + mm_shm_cache_max_object_size_mb: int = 128 + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. Only effective when `mm_processor_cache_type` is + `"shm"`.""" mm_encoder_tp_mode: MMEncoderTPMode = "weights" """Indicates how to optimize multi-modal encoder inference using tensor parallelism (TP). @@ -881,6 +889,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_processor_cache_type=self.mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=self. + mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, interleave_mm_strings=self.interleave_mm_strings, skip_mm_profiling=self.skip_mm_profiling, @@ -2443,6 +2454,15 @@ class MultiModalConfig: Set to `0` to disable this cache completely (not recommended). """ + mm_processor_cache_type: MMCacheType = "lru" + """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`, + use shared memory FIFO cache. If `lru`, use mirrored LRU cache.""" + + mm_shm_cache_max_object_size_mb: int = 128 + """Size limit (in MiB) for each object stored in the multi-modal processor + shared memory cache. 
Only effective when `mm_processor_cache_type` is + `"shm"`.""" + mm_encoder_tp_mode: MMEncoderTPMode = "weights" """ Indicates how to optimize multi-modal encoder inference using diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py new file mode 100644 index 000000000000..3fac104bda1e --- /dev/null +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -0,0 +1,635 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pickle +from abc import ABC, abstractmethod +from collections.abc import Iterable +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import chain +from multiprocessing import shared_memory +from multiprocessing.synchronize import Lock as LockType +from typing import Any, Callable, Optional, Union +from unittest.mock import patch + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SingleWriterShmRingBuffer: + """ + A single-writer, multiple-reader ring buffer implementation using shared + memory. This class provides a thread-safe ring buffer where one process + can write data while multiple processes/threads can read from it. + + Architecture: + - Uses shared memory for cross-process communication + - Maintains metadata for each allocated buffer chunk in the writer process + - Supports custom "is_free_fn" functions to determine when buffers can be + reused + - Each buffer chunk contains: [4-byte id][4-byte size][actual_data] + + Key Concepts: + - monotonic_id_start/end: Track the range of active buffer IDs + - data_buffer_start/end: Track the physical memory range in use + - Automatic wraparound when reaching buffer end + - Lazy garbage collection based on is_free_fn checks + + Example Usage Scenarios: + + Scenario 1: Simple Linear Allocation + ``` + Buffer size: 100 bytes + Initial state: [................................................. ] + ^start=end(0) + + After allocating 20 bytes (id=0): + [id:0|size:20|data........][...................................] + ^start(0) ^end(28) + + After allocating 30 bytes (id=1): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + ``` + + Scenario 2: Memory Reclamation + ``` + Before freeing (both buffers still in use): + [id:0|size:20|data........][id:1|size:30|data..............][..] + ^start(0) ^end(66) + + After id:0 is marked free by readers: + [FREED.................... ][id:1|size:30|data..............][..] + ^start(28) ^end(66) + + After both are freed: + [FREED..............................................][..] + ^start=end(66) + ``` + + Scenario 3: Wraparound Allocation (continuing from Scenario 2) + ``` + Starting from after memory reclamation in Scenario 2: + [FREED..............................................][..] + ^start=end(66) + + Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound: + [id:2|size:40|data........................][FREED.............][..] + ^end(148) ^start(66) + ``` + + Scenario 4: Error Handling - Out of Space + ``` + Starting from after wraparound allocation in Scenario 3: + [id:2|size:40|data........................][FREED.............][..] 
+ ^end(148) ^start(66) + + Trying to allocate 20 more bytes: + occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100) + -> Raises MemoryError: "Not enough space in the data buffer" + ``` + + Thread Safety: + - Single writer: Only one process/thread should write (allocate_buf) + - Multiple readers: Multiple processes/threads can read (access_buf) + - Reader synchronization handled by is_free_fn callback + - Writer handles garbage collection (free_buf) based on reader feedback + + Memory Layout per Buffer Chunk: + [4-byte monotonic_id][4-byte chunk_size][actual_data...] + ^metadata_start ^data_start + + The monotonic_id ensures data integrity - readers can verify they're + accessing the correct data even after buffer wraparound or reuse. + """ + + def __init__( + self, + data_buffer_size: int, + name: Optional[str] = None, + create: bool = False, + ): + self.data_buffer_size = data_buffer_size + self.is_writer = create + + self.ID_NBYTES = 4 + self.ID_MAX = 2**31 # exclusive, so 2**31 - 1 is the max value + self.SIZE_NBYTES = 4 + # 4 bytes for id, 4 bytes for buffer size + self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES + self.monotonic_id_end = 0 + self.monotonic_id_start = 0 + self.data_buffer_start = 0 + self.data_buffer_end = 0 + + if create: + # we are creating a buffer + self.metadata = { + self.monotonic_id_end: self.data_buffer_end + } # monotonic_id -> start address + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.data_buffer_size, name=name) + else: + # we are opening an existing buffer + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + self.shared_memory = shared_memory.SharedMemory(name=name) + # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa + # Some platforms allocate memory based on page size, + # so the shared memory block size may be larger or equal + # to the requested size. The size parameter is ignored + # when attaching to an existing block. + assert self.shared_memory.size >= self.data_buffer_size + + logger.debug("Shared memory created/opened with name: %s, size: %d", + self.shared_memory.name, self.data_buffer_size) + + def handle(self): + return ( + self.data_buffer_size, + self.shared_memory.name, + ) + + def clear(self) -> None: + """Clear the ring buffer.""" + assert self.is_writer, "Only the writer can clear the buffer." + self.metadata.clear() + self.monotonic_id_end = 0 + self.monotonic_id_start = 0 + self.data_buffer_start = 0 + self.data_buffer_end = 0 + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_writer: + self.shared_memory.unlink() + + def int2byte(self, integer: int) -> bytes: + """Convert an integer to bytes.""" + return integer.to_bytes(self.ID_NBYTES, "little", signed=True) + + def byte2int(self, byte_data: bytes) -> int: + """Convert bytes back to an integer.""" + return int.from_bytes(byte_data, "little", signed=True) + + def allocate_buf(self, size: int) -> tuple[int, int]: + ''' + Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory. + Memory layout: + [4-byte monotonic_id][4-byte size][buffer data...] + ''' + assert self.is_writer, "Only the writer can allocate buffers." 
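+        # Worked example (illustrative numbers, mirroring the scenarios in
+        # the class docstring): with data_buffer_size=100 and MD_SIZE=8,
+        # allocate_buf(20) reserves 28 bytes at offset 0, laid out as
+        # [4-byte id][4-byte size=28][20 data bytes]; a following
+        # allocate_buf(30) then reserves 38 bytes starting at offset 28.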
+        assert size > 0, "Size must be greater than 0"
+        size += self.MD_SIZE  # add metadata size to the buffer size
+        # restart from the beginning if the buffer does not have enough
+        # contiguous space at the end
+        buffer_end_reset = self.data_buffer_end % self.data_buffer_size
+        if buffer_end_reset + size > self.data_buffer_size:
+            buffer_end_reset = (self.data_buffer_end // self.data_buffer_size
+                                + 1) * self.data_buffer_size
+        else:  # no reset needed
+            buffer_end_reset = self.data_buffer_end
+
+        # check if we have enough space in the data buffer,
+        # i.e. whether the new end (self.data_buffer_end + size)
+        # would overrun the start of the data buffer
+        occupied_size_new = buffer_end_reset + size - self.data_buffer_start
+        if occupied_size_new > self.data_buffer_size:
+            raise MemoryError("Not enough space in the data buffer, "
+                              "try calling free_buf() to free up space")
+        self.data_buffer_end = buffer_end_reset
+
+        # first 4 bytes as the monotonic id
+        buf_idx = self.data_buffer_end % self.data_buffer_size
+        self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \
+            self.int2byte(self.monotonic_id_end)
+        # next 4 bytes as the size of the data buffer
+        self.shared_memory.buf[buf_idx + self.ID_NBYTES: \
+            buf_idx + self.MD_SIZE] = self.int2byte(size)
+
+        # record metadata
+        self.metadata[self.monotonic_id_end %
+                      self.ID_MAX] = self.data_buffer_end
+        # update buffer and monotonic id indices
+        current_buffer_end = self.data_buffer_end
+        current_id_end = self.monotonic_id_end
+        self.data_buffer_end += size
+        self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX
+        return current_buffer_end, current_id_end
+
+    @contextmanager
+    def access_buf(self, address: int):
+        buf_idx = address % self.data_buffer_size
+
+        # read metadata
+        metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE]
+        id = self.byte2int(metadata_buff[:self.ID_NBYTES])
+        size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE])
+
+        # yield the data buffer and metadata
+        data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx +
+                                           size]
+        with memoryview(data_buff) as data_view:
+            yield data_view, (id, size)
+
+    def free_buf(self,
+                 is_free_fn: Callable[[int, memoryview], bool],
+                 nbytes: Optional[int] = None) -> Iterable[int]:
+        '''
+        Free buffers in FIFO order. This is a no-op in shared memory,
+        but we need to keep track of the metadata.
+
+        If freed memory spreads across the end and start of the ring buffer,
+        the actual freed memory will be in two segments. In this case there
+        still might not be a contiguous space of `nbytes` available.
+
+        Args:
+            is_free_fn: Callback receiving (monotonic_id, buffer) that
+                returns True when the chunk may be reclaimed.
+            nbytes (int, optional): The number of bytes to try to free.
+                If None, frees up to the maximum size of the ring buffer.
+
+        Returns:
+            The monotonic ids of the freed buffers.
+        '''
+
+        assert self.is_writer, "Only the writer can free buffers."
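+        # Chunks are reclaimed strictly in FIFO (monotonic id) order: the
+        # loop below stops at the first chunk whose is_free_fn check fails,
+        # so a single unread chunk blocks reuse of everything allocated
+        # after it (see test_blocking_unread_object).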
+ logger.debug( + "Freeing up space in the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", + self.monotonic_id_start, self.monotonic_id_end) + monotonic_id_before = self.monotonic_id_start + # if nbytes is None, free up the maximum size of the ring buffer + if nbytes is None: + nbytes = self.data_buffer_size + freed_bytes = 0 + while self.monotonic_id_start in self.metadata and freed_bytes < nbytes: + address = self.metadata[self.monotonic_id_start] + with self.access_buf(address) as (data_buff, metadata): + if is_free_fn(self.monotonic_id_start, data_buff): + # check passed, we can free the buffer + del self.metadata[self.monotonic_id_start] + self.monotonic_id_start = ((self.monotonic_id_start + 1) % + self.ID_MAX) + self.data_buffer_start = address + freed_bytes += metadata[1] + else: + # there are still readers, we cannot free the buffer + break + + logger.debug( + "Freed %d bytes from the ring buffer, " + "monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes, + self.monotonic_id_start, self.monotonic_id_end) + + # buffer wrap around + if self.data_buffer_start >= self.data_buffer_size: + self.data_buffer_start -= self.data_buffer_size + self.data_buffer_end -= self.data_buffer_size + + monotonic_id_after = self.monotonic_id_start + # id wrap around + if monotonic_id_after >= monotonic_id_before: + return range(monotonic_id_before, monotonic_id_after) + else: + return chain(range(monotonic_id_before, self.ID_MAX), + range(0, monotonic_id_after)) + + +class ObjectSerde(ABC): + + @abstractmethod + def serialize(self, value: Any) -> tuple[Any, int, bytes, int]: + """Serialize an object to bytes.""" + raise NotImplementedError + + @abstractmethod + def deserialize(self, data: memoryview) -> Any: + """Deserialize bytes back to an object.""" + raise NotImplementedError + + +class MsgpackSerde(ObjectSerde): + + def __init__(self): + # Delayed import to avoid circular dependency + from vllm.multimodal.inputs import MultiModalKwargsItem + from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + + self.encoder = MsgpackEncoder() + self.tensor_decoder = MsgpackDecoder(torch.Tensor) + self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem) + self._mm_kwargs_item_cls = MultiModalKwargsItem + + def serialize( + self, + value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]: + len_arr = None + if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)): + type_name = type(value).__name__ + value = self.encoder.encode(value) + len_arr = [len(s) for s in value] + nbytes = sum(len_arr) + else: + value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL) + type_name = type(value).__name__ + nbytes = len(value) + + object_metadata = (type_name, nbytes, len_arr) + serialized_metadata = pickle.dumps(object_metadata, + protocol=pickle.HIGHEST_PROTOCOL) + return value, nbytes, serialized_metadata, len(serialized_metadata) + + def deserialize(self, data_view: memoryview) -> Any: + # pickle.loads do not read past the end of a pickled object + # within a large buffer, so we can skip storing the metadata size + type_name, nbytes, len_arr = pickle.loads(data_view) + serialized_data = bytearray(data_view[-nbytes:]) + + if type_name == torch.Tensor.__name__: + obj = [] + start_idx = 0 + for length in len_arr: + item_bytes = serialized_data[start_idx:start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.tensor_decoder.decode(obj) + elif type_name == self._mm_kwargs_item_cls.__name__: + obj = [] + start_idx = 0 + for length in len_arr: + 
item_bytes = serialized_data[start_idx:start_idx + length] + obj.append(item_bytes) + start_idx += length + obj = self.mm_decoder.decode(obj) + elif type_name == bytes.__name__: + obj = pickle.loads(serialized_data) + else: + raise ValueError( + f"Unsupported object type '{type_name}' in metadata") + + return obj + + +@dataclass +class ShmObjectStorageHandle: + max_object_size: int + n_readers: int + ring_buffer_handle: tuple[int, str] + serde_class: type[ObjectSerde] + reader_lock: Optional[LockType] + + +class SingleWriterShmObjectStorage: + """ + A single-writer, multiple-reader object storage system built on top of a + shared memory ring buffer. Provides key-value storage with automatic memory + management and cross-process serialization support. + + This storage system follows a FIFO (First-In-First-Out) eviction policy + where the oldest objects are automatically freed when memory runs low. + Memory is reclaimed based on reader reference counting - objects are only + freed when all readers have finished accessing them. + + Architecture: + - Single writer process can put(key, value) objects + - Multiple reader processes can get(address, monotonic_id) objects + - Built on SingleWriterShmRingBuffer for efficient shared memory management + - Thread-safe operations with reader synchronization via locks + + Key Features: + - FIFO Eviction: Oldest objects are evicted first when memory is full + - Reference Counting: Objects are only freed when no readers are + accessing them + - Duplicate Key Handling: Existing keys are not overwritten, just + re-referenced + - Customized Serialization: By default uses Msgpack for efficient + serialization of Python objects, but can be extended for custom types + - Cross-Process Safety: Uses shared memory with proper synchronization + - Automatic Cleanup: Garbage collection happens transparently during + allocation + + Memory Layout per Object: + [4-byte reference_count][metadata_size][serialized_object_data] + + Thread Safety: + - Writer operations (put, clear) are single-threaded by design + - Reader operations (get) are thread-safe with lock-based reference + counting + - Memory reclamation is handled exclusively by the writer process + """ + + def __init__( + self, + max_object_size: int, + n_readers: int, + ring_buffer: SingleWriterShmRingBuffer, + serde_class: type[ObjectSerde] = MsgpackSerde, + reader_lock: Optional[LockType] = None, + ): + """ + Initialize the object storage. + + Args: + max_object_size: Maximum size for a single object in bytes. + n_readers: Number of reader processes that can access the storage. + ring_buffer: The shared memory ring buffer for storing objects. + serde_class: Serializer/deserializer for objects. + reader_lock: Optional lock for synchronizing reader access. + Raises: + ValueError: If reader_lock is None for readers. 
+ """ + + self.max_object_size = max_object_size + self.n_readers = n_readers + self.serde_class = serde_class + self.ser_de = serde_class() + self.ring_buffer = ring_buffer + self.is_writer = self.ring_buffer.is_writer + + self.flag_bytes = 4 # for in-use flag + + if self.is_writer: + # Key-value mapping: key -> (address, monotonic_id) + self.key_index: dict[str, tuple[int, int]] = {} + # Reverse mapping: monotonic_id -> key + self.id_index: dict[int, str] = {} + # Writer flag to track in-use status: monotonic_id -> count + self.writer_flag: dict[int, int] = {} + else: + if reader_lock is None: + raise ValueError("Lock must be provided for readers.") + + self._reader_lock = reader_lock + + def clear(self) -> None: + """Clear the object storage.""" + if self.is_writer: + self.ring_buffer.clear() + self.key_index.clear() + self.id_index.clear() + self.writer_flag.clear() + logger.debug("Object storage cleared and reinitialized.") + + def copy_to_buffer( + self, + data: Union[bytes, list[bytes]], + data_bytes: int, + metadata: bytes, + md_bytes: int, + data_view: memoryview, + ) -> None: + data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata + if isinstance(data, bytes): + data_view[-data_bytes:] = data + elif isinstance(data, list): + start_idx = self.flag_bytes + md_bytes + for item_bytes in data: + item_size = len(item_bytes) + data_view[start_idx:start_idx + item_size] = item_bytes + start_idx += item_size + else: + raise ValueError( + f"Unsupported data type for serialization: {type(data)}") + + def increment_writer_flag(self, id: int) -> None: + """Set the in-use flag for the writer.""" + self.writer_flag[id] = self.writer_flag.get(id, 0) + 1 + + def increment_reader_flag(self, data_view: memoryview) -> None: + """Set the in-use flag for the reader.""" + # >0 for in-use flag + reader_count = self.ring_buffer.byte2int(data_view) + data_view[:] = self.ring_buffer.int2byte(reader_count + 1) + + def free_unused(self) -> None: + """Free unused buffers in the ring buffer.""" + # try to free up 2*max_object_size bytes of space in the ring buffer, + # since the buffer might be fragmented + freed_ids = self.ring_buffer.free_buf(self.default_is_free_check, + 2 * self.max_object_size) + # update the metadata after freeing up space + for freed_id in freed_ids: + key_to_free = self.id_index[freed_id] + del self.key_index[key_to_free] + del self.id_index[freed_id] + del self.writer_flag[freed_id] + + def is_cached(self, key: str) -> bool: + """ + Check if the object with the given key is cached. + """ + return key in self.key_index + + def get_cached(self, key: str) -> tuple[int, int]: + """ + Get the cached object by key if it exists. + """ + address, monotonic_id = self.key_index[key] + self.increment_writer_flag(monotonic_id) + return address, monotonic_id + + def put(self, key: str, value: Any) -> tuple[int, int]: + """ + Store a key-value pair in the object storage. + Attempts to free max_object_size bytes using FIFO order + when the ring buffer runs out of space during a put() operation. 
+
+        Args:
+            key: String key to identify the object
+            value: Any serializable Python object
+
+        Returns:
+            The (address, monotonic_id) pair under which the object was
+            stored; readers pass this pair to get().
+
+        Raises:
+            MemoryError: If there's not enough space in the buffer
+            ValueError: If the serialized object is too large
+            ValueError: If the key already exists in the storage
+        """
+        if key in self.key_index:
+            raise ValueError(f"Key '{key}' already exists in the storage.")
+
+        object_data, data_bytes, object_metadata, md_bytes = \
+            self.ser_de.serialize(value)
+        buffer_size = self.flag_bytes + data_bytes + md_bytes
+
+        # Sanity checks
+        if buffer_size > self.max_object_size:
+            raise ValueError(
+                f"Serialized object size ({buffer_size} bytes) exceeds "
+                f"max object size ({self.max_object_size} bytes)")
+
+        # Allocate new buffer
+        try:
+            address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
+        except MemoryError:
+            self.free_unused()
+            # try again after freeing up space
+            address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
+
+        # Write data to buffer
+        with self.ring_buffer.access_buf(address) as (data_view, metadata):
+            data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0)
+            self.copy_to_buffer(object_data, data_bytes, object_metadata,
+                                md_bytes, data_view)
+        self.increment_writer_flag(monotonic_id)
+
+        # Update key index
+        self.key_index[key] = (address, monotonic_id)
+        self.id_index[monotonic_id] = key
+        return address, monotonic_id
+
+    def get(self, address: int, monotonic_id: int) -> Any:
+        # Read data from buffer
+        with self.ring_buffer.access_buf(address) as (data_view, buf_metadata):
+            # check id from metadata
+            if buf_metadata[0] != monotonic_id:
+                raise ValueError(
+                    f"Data for address:id '{address}:{monotonic_id}'"
+                    " has been modified or is invalid.")
+
+            obj = self.ser_de.deserialize(data_view[self.flag_bytes:])
+
+            # increment the in-use flag to record this read
+            if self._reader_lock is not None:
+                with self._reader_lock:
+                    self.increment_reader_flag(data_view[:self.flag_bytes])
+            else:
+                # if self._reader_lock is None, it means we are the writer;
+                # in this case, we do not need to update the reader count
+                assert self.is_writer
+
+        return obj
+
+    def handle(self):
+        """Get handle for sharing across processes."""
+        return ShmObjectStorageHandle(
+            max_object_size=self.max_object_size,
+            n_readers=self.n_readers,
+            ring_buffer_handle=self.ring_buffer.handle(),
+            serde_class=self.serde_class,
+            reader_lock=self._reader_lock,
+        )
+
+    @staticmethod
+    def create_from_handle(
+            handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage":
+        logger.debug("Creating storage from handle: %s", handle)
+        ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
+        return SingleWriterShmObjectStorage(
+            max_object_size=handle.max_object_size,
+            n_readers=handle.n_readers,
+            ring_buffer=ring_buffer,
+            serde_class=handle.serde_class,
+            reader_lock=handle.reader_lock,
+        )
+
+    def default_is_free_check(self, id: int, buf: memoryview) -> bool:
+        """
+        Default is_free function. The first 4 bytes of a chunk hold its
+        reader count; the chunk is considered free once every reader has
+        consumed each reference the writer handed out. 
+ """ + reader_count = int.from_bytes(buf[0:4], "little", signed=True) + writer_count = self.writer_flag[id] + return reader_count >= writer_count * self.n_readers diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ba1543a8d5a3..52c9a8c30f1a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -27,8 +27,8 @@ DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ModelImpl, MultiModalConfig, + LoRAConfig, MambaDType, MMCacheType, MMEncoderTPMode, + ModelConfig, ModelDType, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, RunnerOption, SchedulerConfig, SchedulerPolicy, SpeculativeConfig, TaskOption, @@ -373,6 +373,10 @@ class EngineArgs: MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_type: Optional[MMCacheType] = \ + MultiModalConfig.mm_processor_cache_type + mm_shm_cache_max_object_size_mb: int = \ + MultiModalConfig.mm_shm_cache_max_object_size_mb mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling @@ -782,6 +786,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument("--disable-mm-preprocessor-cache", action="store_true", deprecated=True) + multimodal_group.add_argument( + "--mm-processor-cache-type", + **multimodal_kwargs["mm_processor_cache_type"]) + multimodal_group.add_argument( + "--mm-shm-cache-max-object-size-mb", + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"]) multimodal_group.add_argument( "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) multimodal_group.add_argument( @@ -998,6 +1008,9 @@ def create_model_config(self) -> ModelConfig: config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, + mm_processor_cache_type=self.mm_processor_cache_type, + mm_shm_cache_max_object_size_mb=self. + mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, override_pooler_config=self.override_pooler_config, logits_processor_pattern=self.logits_processor_pattern, diff --git a/vllm/envs.py b/vllm/envs.py index d7956a2adff6..108aa06778f4 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -173,6 +173,7 @@ VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True + VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" def get_default_cache_root(): @@ -1226,6 +1227,12 @@ def get_vllm_port() -> Optional[int]: # raw bytes. Defaults to True for backward compatibility. "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), + + # Name of the shared memory buffer used for object storage. + # Only effective when mm_config.mm_processor_cache_type == "shm". 
+ "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": + lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", + "VLLM_OBJECT_STORAGE_SHM_BUFFER"), } # --8<-- [end:env-vars-definition] diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index f45a94f3151b..5920744bd496 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +from multiprocessing import Lock from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -10,9 +11,12 @@ import vllm.envs as envs from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, run_method) from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -44,6 +48,8 @@ def _init_executor(self) -> None: distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = worker_receiver_cache_from_config( + self.vllm_config, MULTIMODAL_REGISTRY, Lock()) self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") self.collective_rpc("load_model") @@ -55,6 +61,8 @@ def collective_rpc(self, kwargs: Optional[Dict] = None) -> List[Any]: if kwargs is None: kwargs = {} + if self.mm_receiver_cache is not None and method == "execute_model": + get_and_update_mm_cache(self.mm_receiver_cache, args) answer = run_method(self.driver_worker, method, args, kwargs) return [answer] diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 35b743ed21d9..b9eeadccff83 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -3,19 +3,24 @@ import sys from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union +from multiprocessing.synchronize import Lock as LockType +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast import torch from typing_extensions import TypeAlias, override +from vllm.distributed.device_communicators.shm_object_storage import ( + MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer) +from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.logger import init_logger -from vllm.utils import GiB_bytes, LRUCache +from vllm.utils import GiB_bytes, LRUCache, MiB_bytes from vllm.utils.jsontree import (json_count_leaves, json_map_leaves, json_reduce_leaves) -from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem, - MultiModalKwargs, MultiModalKwargsItem, - MultiModalKwargsItems, NestedTensors) +from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec, + MultiModalFieldElem, MultiModalKwargs, + MultiModalKwargsItem, MultiModalKwargsItems, + NestedTensors) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -389,6 +394,106 @@ def clear_cache(self) -> None: self._cache.clear() +class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. 
+ + - If the item is not in the cache, store the data in shared memory. + """ + + def __init__(self, vllm_config: "VllmConfig") -> None: + super().__init__() + + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=True, # sender is the writer + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * + MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + ) + # cache (prompt_updates, modality) for P0 only + self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], + str]] = {} + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return self._shm_cache.is_cached(mm_hash) + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + + if self._shm_cache.is_cached(mm_hash): + address, monotonic_id = self._shm_cache.get_cached(mm_hash) + prompt_updates, modality = self._p0_cache[mm_hash] + return self.address_as_item(address, monotonic_id, + modality), prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + try: + address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) + # Try to remove dangling items if p0 cache is too large. + if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index): + self.remove_dangling_items() + self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality + address_item = self.address_as_item(address, monotonic_id, + mm_item[0].modality) + return address_item, mm_item[1] + except (ValueError, MemoryError) as e: + # put may fail if the object is too large or + # the cache is full. + # In this case we log the error and keep the original mm_input. 
+ logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash, + e) + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + self._p0_cache.clear() + + def remove_dangling_items(self) -> None: + """Remove items that are no longer in the shared memory cache.""" + cached_hashes = self._shm_cache.key_index.keys() + dangling_hashes = set(self._p0_cache.keys()) - cached_hashes + for mm_hash in dangling_hashes: + del self._p0_cache[mm_hash] + + def address_as_item(self, address: int, monotonic_id: int, + modality: str) -> MultiModalKwargsItem: + addr_elem = MultiModalFieldElem( + modality=modality, + key="address", + data=address, + field=MultiModalBatchedField(), + ) + id_elem = MultiModalFieldElem( + modality=modality, + key="monotonic_id", + data=monotonic_id, + field=MultiModalBatchedField(), + ) + mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem]) + return mm_item + + def _enable_processor_cache( model_config: "ModelConfig", mm_registry: "MultiModalRegistry", @@ -408,6 +513,17 @@ def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: return supports_ipc_cache +def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool: + """Whether the shared memory based cache should be enabled.""" + + if not _enable_ipc_cache(vllm_config): + return False + + mm_config = vllm_config.model_config.get_multimodal_config() + + return mm_config.mm_processor_cache_type == "shm" + + def processor_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", @@ -421,7 +537,9 @@ def processor_cache_from_config( if not _enable_ipc_cache(vllm_config): return MultiModalProcessorOnlyCache(model_config) - return MultiModalProcessorSenderCache(model_config) + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalProcessorSenderCache(model_config) + return ShmObjectStoreSenderCache(vllm_config) def processor_only_cache_from_config( @@ -491,11 +609,70 @@ def clear_cache(self) -> None: self._cache.clear() -def receiver_cache_from_config( +class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 Worker Process when IPC caching is enabled. + + How to update each item: + + - If the item has an address, replace the input with the cached item. + - If not, return the input. 
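+
+    Each worker process attaches to the sender's ring buffer as one of
+    `world_size` readers; reads only increment the shared reader count
+    (guarded by `shared_worker_lock`), while eviction is driven by the
+    `P0` writer.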
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + shared_worker_lock: Optional[LockType], + ) -> None: + super().__init__() + + assert shared_worker_lock is not None, \ + "shared_worker_lock must be provided for shm receiver cache" + self.world_size = vllm_config.parallel_config.world_size + mm_config = vllm_config.model_config.get_multimodal_config() + + ring_buffer = SingleWriterShmRingBuffer( + data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes), + name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME, + create=False, # Server is a reader + ) + self._shm_cache = SingleWriterShmObjectStorage( + max_object_size=mm_config.mm_shm_cache_max_object_size_mb * + MiB_bytes, + n_readers=self.world_size, + ring_buffer=ring_buffer, + serde_class=MsgpackSerde, + reader_lock=shared_worker_lock, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + assert mm_item is not None, f"Expected an address item for {mm_hash=}" + if "address" in mm_item: + address = cast(int, mm_item["address"].data) + monotonic_id = cast(int, mm_item["monotonic_id"].data) + return self._shm_cache.get(address, monotonic_id) + + return mm_item + + @override + def clear_cache(self) -> None: + self._shm_cache.clear() + + +def engine_receiver_cache_from_config( vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", ) -> Optional[BaseMultiModalReceiverCache]: - """Return a `BaseMultiModalReceiverCache`, if enabled.""" + """ + This is used in the engine process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="lru". + """ model_config = vllm_config.model_config if not _enable_processor_cache(model_config, mm_registry): @@ -504,4 +681,31 @@ def receiver_cache_from_config( if not _enable_ipc_cache(vllm_config): return None - return MultiModalReceiverCache(model_config) + if not _enable_mm_input_shm_cache(vllm_config): + return MultiModalReceiverCache(model_config) + + return None + + +def worker_receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", + shared_worker_lock: Optional[LockType] = None, +) -> Optional[BaseMultiModalReceiverCache]: + """ + This is used in the worker process. + Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and + mm_processor_cache_type=="shm". 
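+    Together with `engine_receiver_cache_from_config` above, at most one
+    of the two receiver caches is active for any given configuration.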
+ """ + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + if not _enable_mm_input_shm_cache(vllm_config): + return None + + return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index ea860ca10b3a..7984f1958ece 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -163,6 +163,12 @@ STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" +MB_bytes = 1_000_000 +"""The number of bytes in one megabyte (MB).""" + +MiB_bytes = 1 << 20 +"""The number of bytes in one mebibyte (MiB).""" + GB_bytes = 1_000_000_000 """The number of bytes in one gigabyte (GB).""" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b46ae72ccdf1..27ee146c4f6f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -23,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.cache import receiver_cache_from_config +from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -131,7 +131,7 @@ def __init__(self, self.use_spec_decode = vllm_config.speculative_config is not None self.mm_registry = mm_registry = MULTIMODAL_REGISTRY - self.mm_receiver_cache = receiver_cache_from_config( + self.mm_receiver_cache = engine_receiver_cache_from_config( vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index bcf6dda9c1e9..0bc416b79401 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -12,8 +12,10 @@ from dataclasses import dataclass from enum import Enum, auto from functools import partial +from multiprocessing import Lock from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess +from multiprocessing.synchronize import Lock as LockType from threading import Thread from typing import Any, Callable, Optional, Union, cast @@ -31,10 +33,13 @@ from vllm.executor.multiproc_worker_utils import ( set_multiprocessing_worker_envs) from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import worker_receiver_cache_from_config from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback +from vllm.v1.executor.utils import get_and_update_mm_cache from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) from vllm.worker.worker_base import WorkerWrapperBase @@ -81,6 +86,7 @@ def _init_executor(self) -> None: scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers + shared_worker_lock = Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: @@ -92,6 +98,7 @@ def _init_executor(self) -> None: rank=rank, distributed_init_method=distributed_init_method, input_shm_handle=scheduler_output_handle, + shared_worker_lock=shared_worker_lock, )) # Workers must be created before wait_for_ready to avoid @@ -380,6 +387,7 @@ def __init__( 
         rank: int,
         distributed_init_method: str,
         input_shm_handle: Handle,
+        shared_worker_lock: LockType,
     ):
         self.rank = rank
         wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
@@ -415,6 +423,10 @@ def __init__(
                 daemon=True,
                 name="WorkerAsyncOutputCopy")
             self.async_output_copy_thread.start()
+        
+        # Initialize multimodal receiver cache if needed
+        self.mm_receiver_cache = worker_receiver_cache_from_config(
+            vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock)
 
         # Initialize device
         self.worker.init_device()
@@ -428,11 +440,12 @@ def __init__(
 
     @staticmethod
     def make_worker_process(
-        vllm_config: VllmConfig,
-        local_rank: int,
-        rank: int,
-        distributed_init_method: str,
-        input_shm_handle,  # Receive SchedulerOutput
+            vllm_config: VllmConfig,
+            local_rank: int,
+            rank: int,
+            distributed_init_method: str,
+            input_shm_handle,  # Receive SchedulerOutput
+            shared_worker_lock: LockType,
     ) -> UnreadyWorkerProcHandle:
         context = get_mp_context()
         # (reader, writer)
@@ -449,6 +462,7 @@ def make_worker_process(
             "input_shm_handle": input_shm_handle,
             "ready_pipe": (reader, writer),
             "death_pipe": death_reader,
+            "shared_worker_lock": shared_worker_lock,
         }
         # Run EngineCore busy loop in background process.
         proc = context.Process(target=WorkerProc.worker_main,
@@ -646,6 +660,11 @@ def worker_busy_loop(self, cancel: Optional[threading.Event] = None):
                    func = getattr(self.worker, method)
                 elif isinstance(method, bytes):
                     func = partial(cloudpickle.loads(method), self.worker)
+                # Retrieve multimodal inputs from the shm cache if available.
+                # Use getattr: functools.partial objects have no __name__.
+                if self.mm_receiver_cache is not None and \
+                        getattr(func, "__name__", None) == "execute_model":
+                    get_and_update_mm_cache(self.mm_receiver_cache, args)
                 output = func(*args, **kwargs)
             except Exception as e:
                 # Notes have been introduced in python 3.11
diff --git a/vllm/v1/executor/utils.py b/vllm/v1/executor/utils.py
new file mode 100644
index 000000000000..e6c107692a5b
--- /dev/null
+++ b/vllm/v1/executor/utils.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.multimodal.cache import ShmObjectStoreReceiverCache
+from vllm.v1.core.sched.output import SchedulerOutput
+
+
+def get_and_update_mm_cache(
+    receiver_cache: ShmObjectStoreReceiverCache,
+    args: tuple[SchedulerOutput],
+) -> None:
+    """
+    For each MultiModalKwargsItem in the SchedulerOutput, fetch the full
+    item from the shared-memory cache as needed.
+
+    Args:
+        receiver_cache: The receiver cache to update.
+        args: Per the executor's collective_rpc call to execute_model, a
+            one-element tuple holding the SchedulerOutput.
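+
+    The update happens in place, so ``args`` can be forwarded to
+    ``execute_model`` unchanged afterwards. A sketch of the expected call
+    site (mirroring ``worker_busy_loop`` above)::
+
+        if mm_receiver_cache is not None and method == "execute_model":
+            get_and_update_mm_cache(mm_receiver_cache, args)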
+ """ + scheduler_output = args[0] + for request_data in scheduler_output.scheduled_new_reqs: + for i in range(len(request_data.mm_kwargs)): + mm_input = request_data.mm_kwargs[i] + request_data.mm_kwargs[i] = \ + receiver_cache.get_and_update_item(mm_input, None) From 6820d5818260b75683ecf7f461ec0d73ba4d62d1 Mon Sep 17 00:00:00 2001 From: donglu Date: Tue, 9 Sep 2025 14:34:04 +0000 Subject: [PATCH 2/2] address comments, fix CI Signed-off-by: donglu --- docs/configuration/optimization.md | 10 +++++----- tests/distributed/test_shm_buffer.py | 17 ++++++----------- tests/distributed/test_shm_storage.py | 3 +-- tests/multimodal/test_cache.py | 8 ++++---- tests/tensorizer_loader/conftest.py | 1 + vllm/executor/uniproc_executor.py | 2 ++ vllm/multimodal/cache.py | 6 ++---- vllm/v1/executor/multiproc_executor.py | 6 +++--- 8 files changed, 24 insertions(+), 29 deletions(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 8ea382181165..5807d787cf53 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -273,12 +273,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: -| Cache Type | `P0` Cache | `P1` Cache | Shm Cache | Max. Memory | mm_processor_cache_type | +| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. Memory | |-------------------|-------------|------------|------------|-------------|-------------| -| Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` | lru | -| Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` | lru | -| Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` | shm | -| Disabled | N/A | N/A | N/A | `0` | N/A | +| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` | +| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` | +| N/A | Disabled | N/A | N/A | N/A | `0` | K: Stores the hashes of multi-modal items V: Stores the processed tensor data of multi-modal items diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py index f939475ba952..f70028b87960 100644 --- a/tests/distributed/test_shm_buffer.py +++ b/tests/distributed/test_shm_buffer.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import traceback import unittest from vllm.distributed.device_communicators.shm_object_storage import ( @@ -26,15 +27,11 @@ def test_buffer_opening(self): self.ring_buffer = SingleWriterShmRingBuffer( data_buffer_size=self.buffer_size, create=True) - try: - # Then open it with another instance - reader_buffer = SingleWriterShmRingBuffer( - *self.ring_buffer.handle()) - self.assertFalse(reader_buffer.is_writer) - self.assertEqual(reader_buffer.shared_memory.name, - self.ring_buffer.shared_memory.name) - finally: - del reader_buffer + # Then open it with another instance + reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle()) + self.assertFalse(reader_buffer.is_writer) + self.assertEqual(reader_buffer.shared_memory.name, + self.ring_buffer.shared_memory.name) def test_buffer_access(self): """Test accessing allocated buffers""" @@ -166,8 +163,6 @@ def main(): except Exception as e: 
print(f"Demo error: {e}") - import traceback - traceback.print_exc() print("\n=== Demo Complete ===") diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py index 908ed33dff6c..03495222bc1b 100644 --- a/tests/distributed/test_shm_storage.py +++ b/tests/distributed/test_shm_storage.py @@ -4,6 +4,7 @@ import multiprocessing import random import time +import traceback import unittest from multiprocessing import Lock @@ -313,8 +314,6 @@ def run_multiprocess_example(): except Exception as e: print(f"Error in minimal example: {e}") - import traceback - traceback.print_exc() diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 44c05db2278f..3c61ee26e092 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -10,8 +10,8 @@ from vllm.multimodal.cache import (MultiModalCache, MultiModalProcessorCacheItem, MultiModalProcessorCacheItemMetadata, - processor_cache_from_config, - receiver_cache_from_config) + engine_receiver_cache_from_config, + processor_cache_from_config) from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, @@ -115,9 +115,9 @@ def _compare_caches( ): mm_registry = MultiModalRegistry() cache_0_p0 = processor_cache_from_config(config_0, mm_registry) - cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_0_p1 = engine_receiver_cache_from_config(config_0, mm_registry) cache_1_p0 = processor_cache_from_config(config_1, mm_registry) - cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + cache_1_p1 = engine_receiver_cache_from_config(config_1, mm_registry) cache_size_gb = max( config_0.model_config.mm_processor_cache_gb, diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index 18aa4c88c033..571dc2e0eb50 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -90,6 +90,7 @@ def _init_executor(self) -> None: distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = None self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 5920744bd496..2e456ecd6de0 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -136,6 +136,8 @@ def _init_executor(self) -> None: distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker, ) + self.mm_receiver_cache = worker_receiver_cache_from_config( + self.vllm_config, MULTIMODAL_REGISTRY, Lock()) self.collective_rpc("init_worker", args=([kwargs], )) self.collective_rpc("init_device") self.collective_rpc("load_model") diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index b9eeadccff83..31ae450f4c2f 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -622,12 +622,10 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): def __init__( self, vllm_config: "VllmConfig", - shared_worker_lock: Optional[LockType], + shared_worker_lock: LockType, ) -> None: super().__init__() - assert shared_worker_lock is not None, \ - "shared_worker_lock must be provided for shm receiver cache" self.world_size = vllm_config.parallel_config.world_size mm_config = vllm_config.model_config.get_multimodal_config() @@ -690,7 +688,7 @@ def engine_receiver_cache_from_config( def worker_receiver_cache_from_config( 
vllm_config: "VllmConfig", mm_registry: "MultiModalRegistry", - shared_worker_lock: Optional[LockType] = None, + shared_worker_lock: LockType, ) -> Optional[BaseMultiModalReceiverCache]: """ This is used in the worker process. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0bc416b79401..2f0e5aa383b1 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from enum import Enum, auto from functools import partial -from multiprocessing import Lock from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess from multiprocessing.synchronize import Lock as LockType @@ -86,7 +85,8 @@ def _init_executor(self) -> None: scheduler_output_handle = self.rpc_broadcast_mq.export_handle() # Create workers - shared_worker_lock = Lock() + context = get_mp_context() + shared_worker_lock = context.Lock() unready_workers: list[UnreadyWorkerProcHandle] = [] success = False try: @@ -423,7 +423,7 @@ def __init__( daemon=True, name="WorkerAsyncOutputCopy") self.async_output_copy_thread.start() - + # Initialize multimodal receiver cache if needed self.mm_receiver_cache = worker_receiver_cache_from_config( vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock)