Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ ENV CUDA_HOME=/usr/local/cuda \

RUN if [ "${ENABLE_CACHE}" = "1" ]; then \
apt-get update && apt-get install -y libboost-dev && rm -rf /var/lib/apt/lists/*; \
LIGHTMEM_REF=5900baf92d85ef4dbda6124093506b0af906011a; \
LIGHTMEM_REF=9f9817b0ec6ae7055dea0542a63f66de2685ed90; \
pip install --no-deps -v "git+https://github.com/ModelTC/LightMem.git@${LIGHTMEM_REF}#egg=light_mem"; \
fi

Expand Down
45 changes: 35 additions & 10 deletions lightllm/server/multi_level_kv_cache/disk_cache_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
)
raise ImportError("LightMem library is required for disk cache functionality") from e

TASK_WAIT_TIMEOUT_S = 60.0


@dataclass
class _PagePayload:
Expand All @@ -43,11 +45,11 @@ def __init__(
assert disk_cache_storage_size > 0
storage_size = int(disk_cache_storage_size * (1024 ** 3))
# num_shard与KVCACHE_MAX_BLOCK_SIZE相关,KVCACHE_MAX_BLOCK_SIZE默认64MB前提下,
# num_shard设置32, 能使disk cache的容量利用率达到90%,继续增大num_shard会导致容量利用率下降
num_shard = 32
num_worker = 48
# 读写同时进行时,分配16线程用来写,32线程用来读
max_concurrent_write_tasks = 16
# num_shard设置8, 能使disk cache的容量利用率达到90%,继续增大num_shard会导致容量利用率下降
num_shard = 8
num_worker = 24
# 读写同时进行时,分配8线程用来写,16线程用来读
max_concurrent_write_tasks = 8
Comment on lines 47 to +52

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comments on lines 45-46 and 49 are now outdated and misleading because num_shard, num_worker, and max_concurrent_write_tasks have been updated. Please update the comments to reflect the new values (e.g., allocating 8 threads for writing and 16 threads for reading).

Suggested change
# num_shard与KVCACHE_MAX_BLOCK_SIZE相关,KVCACHE_MAX_BLOCK_SIZE默认64MB前提下,
# num_shard设置32, 能使disk cache的容量利用率达到90%,继续增大num_shard会导致容量利用率下降
num_shard = 32
num_worker = 48
num_shard = 8
num_worker = 24
# 读写同时进行时,分配16线程用来写,32线程用来读
max_concurrent_write_tasks = 16
max_concurrent_write_tasks = 8
# num_shard与KVCACHE_MAX_BLOCK_SIZE相关
num_shard = 8
num_worker = 24
# 读写同时进行时,分配8线程用来写,16线程用来读
max_concurrent_write_tasks = 8


cache_dir = disk_cache_dir
if not cache_dir:
Expand Down Expand Up @@ -78,6 +80,24 @@ def __init__(
def _prepare_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.flatten(1).view(dtype=torch.uint8)

def _wait_task(self, task, cond_name: str) -> bool:
    """Poll a task predicate until it holds or the wait times out.

    Args:
        task: Task object exposing a zero-argument method named ``cond_name``.
        cond_name: Name of the predicate method on ``task`` to poll.

    Returns:
        True as soon as the predicate returns a truthy value. False if
        ``TASK_WAIT_TIMEOUT_S`` seconds pass first; in that case the task is
        handed to ``self.service.abort`` before returning.
    """
    is_done = getattr(task, cond_name)
    give_up_at = time.monotonic() + TASK_WAIT_TIMEOUT_S
    while True:
        if is_done():
            return True
        if time.monotonic() < give_up_at:
            # Still within the time budget: back off briefly, then poll again.
            time.sleep(0.001)
            continue
        logger.error(
            "disk cache task '%s' wait timeout after %.1fs, aborting task to avoid hang",
            cond_name,
            TASK_WAIT_TIMEOUT_S,
        )
        try:
            self.service.abort(task)
        except Exception as abort_err:
            # Best effort: a failed abort is only logged; the caller still gets False.
            logger.error("disk cache abort task failed: %s", abort_err)
        return False

def run(self) -> None:
while True:
time.sleep(0.1)
Expand Down Expand Up @@ -121,9 +141,16 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
query_result = self.service.query(hashs)
if not all(query_result):
# 限制写入并发量,给读取操作留资源
throttle_deadline = time.monotonic() + TASK_WAIT_TIMEOUT_S
while (
self.service.active_threads("r") and self.service.active_threads("w") >= self.max_concurrent_write_tasks
):
if time.monotonic() >= throttle_deadline:
logger.error(
"disk cache write throttle wait timeout after %.1fs, proceeding to submit",
TASK_WAIT_TIMEOUT_S,
)
break
time.sleep(0.001)

task = self.service.create(hash_128s=hashs, kv_page_indexer=kv_indexer, mode="w")
Expand All @@ -133,9 +160,7 @@ def _persist_pages_to_disk(self, payloads: List[_PagePayload]) -> None:
self.cpu_cache_client.deref_pages(page_list=task.page_already_list)
self.cpu_cache_client.lock.release()

# 数据安全即可结束等待,无需写入完成
while not task.data_safe():
time.sleep(0.001)
self._wait_task(task, "data_safe")

# 释放剩余需要写入的页面
remining_indexes = list(set(page_indexes) - set(task.page_already_list))
Expand Down Expand Up @@ -181,6 +206,6 @@ def load_pages(self, hashs: List[int], page_indexes: List[int], start_pos: int =

kv_indexer = torch.tensor(page_indexes, dtype=torch.int32, device="cpu")
task = self.service.create(hash_128s=hashs, kv_page_indexer=kv_indexer, mode="r", start_pos=start_pos)
while not task.ready():
time.sleep(0.001)
if not self._wait_task(task, "ready"):
return False
return all(state == PyState.Finished for state in task.state())
Loading
Loading