diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py index 594e81a9b4..533b51e954 100644 --- a/lightllm/common/basemodel/attention/create_utils.py +++ b/lightllm/common/basemodel/attention/create_utils.py @@ -1,5 +1,6 @@ """Attention backend selection utilities.""" -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.common.basemodel.attention.paged_fa3.fp import PagedFa3AttBackend +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.log_utils import init_logger from lightllm.utils.backend_validator import validate from typing import Dict @@ -23,7 +24,7 @@ data_type_to_backend = { "None": { "triton": TritonAttBackend, - "fa3": Fa3AttBackend, + "fa3": PagedFa3AttBackend if get_page_size() > 1 else Fa3AttBackend, "flashinfer": FlashInferAttBackend, }, "int4kv": { diff --git a/lightllm/common/basemodel/attention/flashinfer/fp.py b/lightllm/common/basemodel/attention/flashinfer/fp.py index 91a004ec2e..6653918eeb 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp.py @@ -110,7 +110,7 @@ def _nomarl_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: self.backend: FlashInferAttBackend = self.backend # for typing - o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + o_tensor = alloc_func(q.shape, q.dtype, device=q.device) self.prefill_wrapper.run( q, (k.unsqueeze(1), v.unsqueeze(1)), diff --git a/lightllm/common/basemodel/attention/flashinfer/fp8.py b/lightllm/common/basemodel/attention/flashinfer/fp8.py index 76cf5622f2..851a2a5e90 100644 --- a/lightllm/common/basemodel/attention/flashinfer/fp8.py +++ b/lightllm/common/basemodel/attention/flashinfer/fp8.py @@ -46,7 +46,7 @@ def prefill_att( def _fp8_prefill_att( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty ) -> torch.Tensor: - 
o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + o_tensor = alloc_func(q.shape, q.dtype, device=q.device) k = k.unsqueeze(1).view(torch.float8_e4m3fn) v = v.unsqueeze(1).view(torch.float8_e4m3fn) layer_index = self.backend._find_layer_index(k=k, v=v, att_state=self) @@ -97,7 +97,7 @@ def _fp8_decode_att( v: torch.Tensor, alloc_func=torch.empty, ): - o_tensor = alloc_func(q.shape, q.dtype, device="cuda") + o_tensor = alloc_func(q.shape, q.dtype, device=q.device) k = k.unsqueeze(1).view(torch.float8_e4m3fn) v = v.unsqueeze(1).view(torch.float8_e4m3fn) diff --git a/lightllm/common/basemodel/attention/flashinfer/mla.py b/lightllm/common/basemodel/attention/flashinfer/mla.py index 84b44dc45a..4b5152b907 100644 --- a/lightllm/common/basemodel/attention/flashinfer/mla.py +++ b/lightllm/common/basemodel/attention/flashinfer/mla.py @@ -102,7 +102,7 @@ def _mla_prefill_att( ) -> torch.Tensor: self.backend: MlaFlashInferAttBackend = self.backend # for typing k_nope, k_rope = k - o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device="cuda") + o_tensor = alloc_func((q.shape[0], q.shape[1], v.shape[-1]), q.dtype, device=q.device) q_head_num = q.shape[1] k = torch.cat([k_nope, torch.repeat_interleave(k_rope, q_head_num, dim=-2)], dim=-1) self.prefill_wrapper.run(q, k, v, out=o_tensor) @@ -125,7 +125,7 @@ def init_state(self): self.kv_starts = self.infer_state.b1_cu_kv_seq_len - self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") + self.q_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device) if batch_size <= model.graph_max_batch_size and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch: self.kv_indices = self.backend.kv_indices_buffer[self.infer_state.microbatch_index][ : batch_size * self.backend.max_seq_length diff --git a/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py b/lightllm/common/basemodel/attention/nsa/flashmla_sparse.py index 673b5896d8..fd4bacdfa2 100644 --- 
import dataclasses
import torch
import triton
from ..base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from typing import Any


class PagedFa3AttBackend(BaseAttBackend):
    """Attention backend that addresses the KV cache through a page table.

    The page size comes from the ``page_size`` argument or, when that is
    falsy, from ``get_page_size()``.  The backend owns the buffers that are
    shared between attention states: two page-table buffers sized for the
    largest graph batch, pinned host buffers for sequence lengths, and the
    cached causal mask used by the NPU prefill kernel.
    """

    # Side length of the square causal mask handed to the NPU prefill kernel.
    _NPU_MASK_SIZE = 2048

    def __init__(self, model, page_size=None):
        super().__init__(model=model)
        self.page_size = page_size or get_page_size()
        # Allocate the shared page-table buffers eagerly.
        self.get_page_table_buffer()

    def get_page_table_buffer(self):
        """Return the two shared int32 page-table buffers, creating them once.

        Each buffer holds ``graph_max_batch_size * ceil(graph_max_len_in_batch
        / page_size)`` entries.  Decode states index the list with
        ``infer_state.microbatch_index``.
        """
        model = self.model
        if not hasattr(self, "_shared_page_table_buffer"):
            shared_len = model.graph_max_batch_size * triton.cdiv(model.graph_max_len_in_batch, self.page_size)
            self._shared_page_table_buffer = [
                torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
                torch.empty(shared_len, dtype=torch.int32).to(get_current_device_id()),
            ]
        return self._shared_page_table_buffer

    def get_decode_seq_len_cpu_buffers(self, min_len: int):
        """Return two pinned CPU int32 buffers with at least ``min_len`` elements.

        The buffers are allocated on first use and re-allocated only when the
        existing ones are shorter than ``min_len``; a fresh allocation holds
        ``max(min_len, graph_max_batch_size)`` elements.

        Returns:
            ``(q_buffer, kv_buffer)``.
        """
        model = self.model
        cap = max(min_len, model.graph_max_batch_size)
        if not hasattr(self, "_decode_seq_len_cpu_q") or self._decode_seq_len_cpu_q.shape[0] < min_len:
            self._decode_seq_len_cpu_q = torch.empty(cap, dtype=torch.int32, pin_memory=True)
            self._decode_seq_len_cpu_kv = torch.empty(cap, dtype=torch.int32, pin_memory=True)
        return self._decode_seq_len_cpu_q, self._decode_seq_len_cpu_kv

    def get_npu_causal_mask(self, device):
        """Return the int8 strict-upper-triangular mask for ``device``.

        The mask is built once and cached on the backend, so prefill states do
        not rebuild and re-upload it on every forward pass.
        """
        cached = getattr(self, "_npu_causal_mask", None)
        if cached is None or cached.device != device:
            size = self._NPU_MASK_SIZE
            cached = torch.triu(torch.ones((size, size), dtype=torch.int8), diagonal=1).to(device)
            self._npu_causal_mask = cached
        return cached

    def create_att_prefill_state(self, infer_state):
        """Create the prefill attention state bound to this backend."""
        return PagedFa3PrefillAttState(backend=self, infer_state=infer_state)

    def create_att_decode_state(self, infer_state):
        """Create the decode attention state bound to this backend."""
        return PagedFa3DecodeAttState(backend=self, infer_state=infer_state)


@dataclasses.dataclass
class PagedFa3PrefillAttState(BasePrefillAttState):
    """Per-batch prefill state: cumulative sequence lengths and a page table."""

    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    page_table: torch.Tensor = None
    # Causal mask for the NPU kernel; stays None on other devices.
    atten_mask: torch.Tensor = None

    def init_state(self):
        """Build the cumulative lengths and fill the page table for the batch."""
        device = self.infer_state.input_ids.device
        self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
        self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()
        table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size)
        self.page_table = torch.empty(
            (self.infer_state.batch_size, table_len),
            dtype=torch.int32,
            device=device,
        )
        page_table_copy(
            page_table=self.page_table,
            req_to_token_indexs=self.infer_state.req_manager.req_to_token_indexs,
            b_req_idx=self.infer_state.b_req_idx,
            page_size=self.backend.page_size,
        )
        # Only the NPU kernel reads the mask, so skip the allocation elsewhere
        # and reuse the backend-level copy instead of rebuilding it per state.
        if self.atten_mask is None and device.type == "npu":
            self.atten_mask = self.backend.get_npu_causal_mask(device)

    def prefill_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty):
        """Run prefill attention; ALiBi is not supported by this backend."""
        assert att_control.use_alibi is False
        return self._normal_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)

    def _normal_prefill_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty):
        """Dispatch prefill attention to the NPU fused kernel or to FA3.

        ``alloc_func`` is accepted for interface compatibility; neither branch
        uses it.
        """
        if att_control.use_sliding_window:
            window_size = att_control.sliding_window
        else:
            window_size = (-1, -1)

        if att_control.use_att_sink:
            sink_weight = att_control.sink_weight
        else:
            sink_weight = None

        sm_scale = 1.0 / (q.shape[-1] ** 0.5)

        if q.device.type == "npu":
            import torch_npu

            # NOTE(review): window_size and sink_weight are not forwarded to
            # the NPU kernel, so sliding-window / sink settings have no effect
            # on this path -- confirm this is intended.
            N_KV, HEAD_DIM = k.shape[-2:]
            # to (num_blocks, block_size, hidden_size)
            key = k.view(-1, self.backend.page_size, N_KV * HEAD_DIM)
            value = v.view(-1, self.backend.page_size, N_KV * HEAD_DIM)
            # Fall back to the backend cache if init_state did not set a mask.
            atten_mask = self.atten_mask
            if atten_mask is None:
                atten_mask = self.backend.get_npu_causal_mask(q.device)
            out = torch_npu.npu_fused_infer_attention_score(
                query=q,
                key=key,
                value=value,
                input_layout="TND",
                sparse_mode=3,
                atten_mask=atten_mask,
                scale=sm_scale,
                actual_seq_lengths=self.infer_state.b1_cu_q_seq_len_cpu,
                actual_seq_lengths_kv=self.infer_state.b_cu_kv_seq_len_cpu,
                num_heads=q.shape[-2],
                num_key_value_heads=N_KV,
                block_table=self.page_table,
                block_size=self.backend.page_size,
            )[0]
            return out
        else:
            return flash_attn_with_kvcache(
                q=q,
                k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
                v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
                page_table=self.page_table,
                cache_seqlens=self.infer_state.b_seq_len,
                cu_seqlens_q=self.cu_seqlens_q,
                cu_seqlens_k_new=self.cu_seqlens_k,
                max_seqlen_q=self.infer_state.max_q_seq_len,
                softmax_scale=sm_scale,
                causal=True,
                window_size=window_size,
                softcap=0.0,
                k_descale=None,
                v_descale=None,
                return_softmax_lse=False,
                sinks=sink_weight,
            )


@dataclasses.dataclass
class PagedFa3DecodeAttState(BaseDecodeAttState):
    """Per-batch decode state, including the MTP (multi-token) layout."""

    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    page_table: torch.Tensor = None
    # KV length per attention row (one row per MTP group when mtp_step > 0).
    b_att_seq_len: torch.Tensor = None
    decode_max_q_seq_len: int = None

    def init_state(self):
        """Build cumulative lengths and the page table for this decode step.

        With ``mtp_step > 0`` every ``mtp_step + 1`` consecutive batch rows are
        folded into one attention row whose query length is ``mtp_step + 1``;
        the last row of each group supplies the KV length and request index.
        """
        args_mtp_step = get_env_start_args().mtp_step
        mtp_size = args_mtp_step + 1
        # Validate before the batch size is divided by the group size.
        assert self.infer_state.batch_size % mtp_size == 0

        if args_mtp_step > 0:
            b_q_seq_len = torch.full(
                (self.infer_state.b_seq_len.shape[0] // mtp_size,),
                fill_value=mtp_size,
                dtype=torch.int32,
                device=self.infer_state.b_seq_len.device,
            )
            b_kv_seq_len = self.infer_state.b_seq_len[mtp_size - 1 :: mtp_size]
            b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
            self.cu_seqlens_q = b1_cu_q_seq_len.int()
            self.cu_seqlens_k = b1_cu_kv_seq_len.int()
        else:
            self.cu_seqlens_q = self.infer_state.b1_cu_q_seq_len.int()
            self.cu_seqlens_k = self.infer_state.b1_cu_kv_seq_len.int()

        att_batch_size = self.infer_state.batch_size // mtp_size
        model = self.backend.model
        table_len = triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size)
        if (
            self.infer_state.batch_size <= model.graph_max_batch_size
            and self.infer_state.max_kv_seq_len <= model.graph_max_len_in_batch
        ):
            # Within graph limits: take a view of the shared buffer so the
            # page table lives at the same address on every step.
            page_buffer = self.backend.get_page_table_buffer()
            shared_table_len = triton.cdiv(model.graph_max_len_in_batch, self.backend.page_size)
            self.page_table = page_buffer[self.infer_state.microbatch_index][
                : att_batch_size * shared_table_len
            ].reshape(att_batch_size, shared_table_len)
        else:
            self.page_table = torch.empty(
                (att_batch_size, table_len),
                dtype=torch.int32,
                device=self.infer_state.input_ids.device,
            )

        if args_mtp_step > 0:
            page_table_copy(
                page_table=self.page_table[:, :table_len],
                req_to_token_indexs=model.req_manager.req_to_token_indexs,
                b_req_idx=self.infer_state.b_req_idx[args_mtp_step::mtp_size],
                page_size=self.backend.page_size,
            )
            self.b_att_seq_len = self.infer_state.b_seq_len[args_mtp_step::mtp_size].contiguous()
            self.decode_max_q_seq_len = mtp_size
        else:
            page_table_copy(
                page_table=self.page_table[:, :table_len],
                req_to_token_indexs=model.req_manager.req_to_token_indexs,
                b_req_idx=self.infer_state.b_req_idx,
                page_size=self.backend.page_size,
            )
            self.b_att_seq_len = self.infer_state.b_seq_len
            self.decode_max_q_seq_len = 1

    def decode_att(self, q, k, v, att_control: AttControl = AttControl(), alloc_func=torch.empty):
        """Run decode attention; ALiBi is not supported by this backend."""
        assert att_control.use_alibi is False
        return self._normal_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)

    def _normal_decode_att(self, q, k, v, att_control: AttControl, alloc_func=torch.empty):
        """Dispatch decode attention to the NPU fused kernel or to FA3.

        On NPU the kernel writes into a freshly allocated ``output``.  While
        the current stream is being captured, the call is additionally wrapped
        in a graph task group and its arguments are registered through
        ``add_attn_params`` so that ``update_attn_params`` can re-issue it
        later with new sequence lengths.  ``alloc_func`` is not used.
        """
        if att_control.use_sliding_window:
            window_size = att_control.sliding_window
        else:
            window_size = (-1, -1)

        if att_control.use_att_sink:
            sink_weight = att_control.sink_weight
        else:
            sink_weight = None

        sm_scale = 1.0 / (q.shape[-1] ** 0.5)
        if q.device.type == "npu":
            import torch_npu

            # NOTE(review): window_size and sink_weight are not forwarded to
            # the NPU kernel on this path -- confirm this is intended.
            N_Q = q.shape[-2]
            N_KV, HEAD_DIM = k.shape[-2:]

            k = k.view(-1, self.backend.page_size, N_KV * HEAD_DIM)
            v = v.view(-1, self.backend.page_size, N_KV * HEAD_DIM)

            output = torch.empty_like(q)
            softmax_lse = torch.empty(1, dtype=torch.float16, device=q.device)

            # Arguments shared by every kernel invocation below; built once so
            # the workspace query, the captured call and the eager call cannot
            # drift apart.
            attn_kwargs = dict(
                query=q,
                key=k,
                value=v,
                input_layout="TND",
                scale=sm_scale,
                actual_seq_lengths=self.infer_state.b1_cu_q_seq_len_cpu,
                actual_seq_lengths_kv=self.infer_state.b_cu_kv_seq_len_cpu,
                num_heads=N_Q,
                num_key_value_heads=N_KV,
                block_table=self.page_table,
                block_size=self.backend.page_size,
            )

            if torch.npu.is_current_stream_capturing():
                from lightllm.common.basemodel.graph.acl_graph import add_attn_params, get_attn_params

                stream = torch.npu.current_stream()
                batch_size = self.infer_state.batch_size
                attn_params = get_attn_params()

                event = torch.npu.ExternalEvent()
                event.wait(stream)
                event.reset(stream)

                # One workspace per batch size, computed on first capture.
                workspace = attn_params.workspaces.get(batch_size, None)
                if workspace is None:
                    workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(**attn_kwargs)
                    attn_params.workspaces[batch_size] = workspace

                torch.npu.graph_task_group_begin(stream)
                torch_npu.npu_fused_infer_attention_score.out(
                    **attn_kwargs,
                    workspace=workspace,
                    out=[output, softmax_lse],
                )
                handle = torch.npu.graph_task_group_end(stream)

                # The tuple order must match the unpacking in update_attn_params.
                add_attn_params(
                    batch_size=batch_size,
                    event=event,
                    handle=handle,
                    attn_params=(
                        weak_ref_tensor(q),
                        weak_ref_tensor(k),
                        weak_ref_tensor(v),
                        sm_scale,
                        N_Q,
                        N_KV,
                        weak_ref_tensor(self.page_table),
                        self.backend.page_size,
                        weak_ref_tensor(output),
                        weak_ref_tensor(softmax_lse),
                    ),
                )
            else:
                torch_npu.npu_fused_infer_attention_score.out(
                    **attn_kwargs,
                    out=[output, softmax_lse],
                )

            return output
        else:
            return flash_attn_with_kvcache(
                q=q,
                k_cache=k.view(-1, self.backend.page_size, k.shape[1], k.shape[2]),
                v_cache=v.view(-1, self.backend.page_size, v.shape[1], v.shape[2]),
                page_table=self.page_table,
                cache_seqlens=self.b_att_seq_len,
                cu_seqlens_q=self.cu_seqlens_q,
                cu_seqlens_k_new=self.cu_seqlens_k,
                max_seqlen_q=self.decode_max_q_seq_len,
                softmax_scale=sm_scale,
                causal=True,
                window_size=window_size,
                softcap=0.0,
                k_descale=None,
                v_descale=None,
                return_softmax_lse=False,
                sinks=sink_weight,
            )


def update_attn_params(
    batch_size: int,
    actual_seq_lengths: Any,
    actual_seq_lengths_kv: Any,
    update_stream: Any,
):
    """Re-issue the recorded NPU attention calls with new sequence lengths.

    For every ``(handle, event, params)`` triple registered for ``batch_size``
    the fused attention kernel is issued again between
    ``graph_task_update_begin`` / ``graph_task_update_end`` on
    ``update_stream``, then the event is recorded on that stream.

    Args:
        batch_size: key under which handles, events, workspace and parameters
            were stored at capture time.
        actual_seq_lengths: query sequence lengths forwarded to the kernel.
        actual_seq_lengths_kv: KV sequence lengths forwarded to the kernel.
        update_stream: NPU stream on which the updates are issued.

    Raises:
        KeyError: presumably when nothing was recorded for ``batch_size`` --
            depends on the container types in ``acl_graph``; TODO confirm.
    """
    import torch_npu
    from lightllm.common.basemodel.graph.acl_graph import get_attn_params

    attn_params = get_attn_params()
    handles = attn_params.handles[batch_size]
    events = attn_params.events[batch_size]
    workspace = attn_params.workspaces[batch_size]
    params_list = attn_params.attn_params[batch_size]

    with torch.npu.stream(update_stream):
        for handle, event, attn_param in zip(handles, events, params_list):
            (q, k, v, sm_scale, N_Q, N_KV, page_table, block_size, output, softmax_lse) = attn_param
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu.npu_fused_infer_attention_score.out(
                q,
                k,
                v,
                input_layout="TND",
                scale=sm_scale,
                actual_seq_lengths=actual_seq_lengths,
                actual_seq_lengths_kv=actual_seq_lengths_kv,
                num_heads=N_Q,
                num_key_value_heads=N_KV,
                block_table=page_table,
                block_size=block_size,
                workspace=workspace,
                out=[output, softmax_lse],
            )
            torch.npu.graph_task_update_end(update_stream)
            event.record(update_stream)


def weak_ref_tensor(tensor: Any) -> Any:
    """Return a weak-reference tensor for ``tensor``; pass other values through.

    ``torch_npu`` is imported only when a tensor is actually given, so
    non-tensor values are returned without requiring the NPU runtime.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor

    import torch_npu

    return torch_npu._C._weak_ref_tensor(tensor)
alloc_func(q.shape, q.dtype) + out = alloc_func(q.shape, q.dtype, device=q.device) context_attention_fwd( q, k, diff --git a/lightllm/common/basemodel/attention_vit/xformers/fp.py b/lightllm/common/basemodel/attention_vit/xformers/fp.py index 361b5db050..32c8c5debf 100644 --- a/lightllm/common/basemodel/attention_vit/xformers/fp.py +++ b/lightllm/common/basemodel/attention_vit/xformers/fp.py @@ -1,13 +1,18 @@ +import inspect import torch import torch.nn.functional as F try: from xformers import ops as xformers_ops from xformers.ops import fmha + + _HAS_DEVICE_ARG = "device" in inspect.signature(fmha.BlockDiagonalMask.from_seqlens).parameters except ImportError: xformers_ops = None fmha = None + _HAS_DEVICE_ARG = False + from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend @@ -34,7 +39,13 @@ def _vit_att_fwd( if max_seqlen: assert max(seqlens) <= max_seqlen - attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device) + # In xformers, the API for BlockDiagonalMask.from_seqlens changed across versions: + # - v0.0.26 and earlier: no `device` argument is supported + # - v0.0.27 and later: a `device` argument was added + if not _HAS_DEVICE_ARG: + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + else: + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens, device=q.device) q_ = q.unsqueeze(0) # [1, T, H, D] k_ = k.unsqueeze(0) # [1, T, H, D] diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 05aaaadca8..860a3491ef 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -19,10 +19,12 @@ from lightllm.common.build_utils import repair_config from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.common.basemodel.cuda_graph import CudaGraph +from lightllm.common.basemodel.graph import 
DecodeGraph +from lightllm.common.basemodel.attention.paged_fa3.fp import update_attn_params from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -56,6 +58,10 @@ class TpPartBaseModel: def __init__(self, kvargs): self.args = get_env_start_args() + + self.platform_backend = get_backend() + self.target_device = self.platform_backend.runtime.target_device() + self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] self.max_total_token_num = kvargs["max_total_token_num"] @@ -129,12 +135,17 @@ def __init__(self, kvargs): logger.info(f"use prefill att backend1: {self.prefill_att_backend1.__class__.__name__}") logger.info(f"use decode att backend1: {self.decode_att_backend1.__class__.__name__}") + if not self.disable_cudagraph: + self.b1_cu_q_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32) + self.b_cu_kv_seq_len_cpu_ref = torch.zeros(self.graph_max_batch_size, dtype=torch.int32) + self.ref_initialized = False + self._autotune_warmup() self._init_padded_req() self._init_cudagraph() self._init_prefill_cuda_graph() self._check_max_len_infer() - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() set_model_init_status(True) return @@ -273,15 +284,15 @@ def _init_att_backend1(self): return def _init_cudagraph(self): - self.graph = ( - None - if self.disable_cudagraph - else CudaGraph( + if self.disable_cudagraph: + self.graph = None + else: + self.graph = DecodeGraph( + platform_backend=self.platform_backend.name, max_batch_size=self.graph_max_batch_size, max_len_in_batch=self.graph_max_len_in_batch, 
tp_world_size=self.tp_world_size_, ) - ) if self.graph is not None: if get_env_start_args().enable_decode_microbatch_overlap: self.graph.warmup_overlap(self) @@ -305,8 +316,8 @@ def _init_custom(self): @torch.no_grad() def forward(self, model_input: ModelInput): - model_input.to_cuda() - assert model_input.mem_indexes.is_cuda + model_input.to_device(self.target_device) + assert model_input.mem_indexes.device == self.target_device if model_input.is_prefill: return self._prefill(model_input) @@ -525,7 +536,7 @@ def _prefill( alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, ) - prefill_mem_indexes_ready_event = torch.cuda.Event() + prefill_mem_indexes_ready_event = self.platform_backend.runtime.create_event() prefill_mem_indexes_ready_event.record() infer_state.init_some_extra_state(self) @@ -575,11 +586,34 @@ def _decode( infer_state.init_some_extra_state(self) infer_state.init_att_state() + batch_size = infer_state.batch_size + b1_cu_q_seq_len_cpu_slice = self.b1_cu_q_seq_len_cpu_ref[:batch_size] + b_cu_kv_seq_len_cpu_slice = self.b_cu_kv_seq_len_cpu_ref[:batch_size] + if self.platform_backend.name == "ascend": + if not self.platform_backend.graph.is_capturing(): + need_update_attn_params = not self.ref_initialized or not torch.equal( + b_cu_kv_seq_len_cpu_slice, infer_state.b_cu_kv_seq_len_cpu + ) + if need_update_attn_params: + b1_cu_q_seq_len_cpu_slice.copy_(infer_state.b1_cu_q_seq_len_cpu) + b_cu_kv_seq_len_cpu_slice.copy_(infer_state.b_cu_kv_seq_len_cpu) + update_attn_params( + batch_size, + b1_cu_q_seq_len_cpu_slice, + b_cu_kv_seq_len_cpu_slice, + self.graph.update_stream, + ) + self.ref_initialized = True + if self.graph.need_capture(infer_batch_size): infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode(self._token_forward, infer_state) else: - model_output: ModelOutput = self.graph.replay(infer_state) + model_output: ModelOutput = self.graph.replay( + infer_state, + b1_cu_q_seq_len_cpu_slice, 
+ b_cu_kv_seq_len_cpu_slice, + ) model_output = self._create_unpad_decode_model_output(model_output, origin_batch_size=origin_batch_size) else: @@ -687,18 +721,18 @@ def _token_forward(self, infer_state: InferStateInfo): model_output.mtp_main_output_hiddens = input_embs.contiguous() # 在 cuda graph 模式下,输出需要转为 no ref tensor, 加强mem pool 的复用,降低显存的使用。 - if infer_state.is_cuda_graph: + if infer_state.is_cuda_graph and self.platform_backend.name != "ascend": model_output.to_no_ref_tensor() return model_output @torch.no_grad() def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput): - model_input0.to_cuda() - model_input1.to_cuda() + model_input0.to_device(self.target_device) + model_input1.to_device(self.target_device) - assert model_input0.mem_indexes.is_cuda - assert model_input1.mem_indexes.is_cuda + assert model_input0.mem_indexes.device == self.target_device + assert model_input1.mem_indexes.device == self.target_device assert self.args.enable_tpsp_mix_mode origin_handle_token_num0 = model_input0.total_token_num - model_input0.prefix_total_token_num @@ -741,7 +775,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod infer_state1.init_some_extra_state(self) infer_state1.init_att_state() - prefill_mem_indexes_ready_event = torch.cuda.Event() + prefill_mem_indexes_ready_event = self.platform_backend.runtime.create_event() prefill_mem_indexes_ready_event.record() model_output0, model_output1 = self._overlap_tpsp_context_forward(infer_state0, infer_state1=infer_state1) @@ -765,8 +799,8 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod @torch.no_grad() def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput): - model_input0.to_cuda() - model_input1.to_cuda() + model_input0.to_device(self.target_device) + model_input1.to_device(self.target_device) assert self.args.enable_tpsp_mix_mode if model_input0.input_ids is None: @@ -783,8 +817,8 @@ def 
microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode ) # TODO 动态 mtp fix assert model_input0.batch_size == model_input1.batch_size - assert model_input0.mem_indexes.is_cuda - assert model_input1.mem_indexes.is_cuda + assert model_input0.mem_indexes.device == self.target_device + assert model_input1.mem_indexes.device == self.target_device origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_kv_seq_len, model_input1.max_kv_seq_len) @@ -816,6 +850,10 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode infer_state1.init_some_extra_state(self) infer_state1.init_att_state() + batch_size = infer_state0.batch_size + b1_cu_q_seq_len_cpu_slice = self.b1_cu_q_seq_len_cpu_ref[:batch_size] + b_cu_kv_seq_len_cpu_slice = self.b_cu_kv_seq_len_cpu_ref[:batch_size] + if self.graph.need_capture(infer_batch_size): infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True @@ -828,6 +866,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode else: model_output0, model_output1 = self.graph.replay( infer_state0, + b1_cu_q_seq_len_cpu_slice, + b_cu_kv_seq_len_cpu_slice, infer_state1=infer_state1, ) @@ -949,7 +989,7 @@ def _overlap_tpsp_token_forward(self, infer_state: InferStateInfo, infer_state1: model_output.mtp_main_output_hiddens = input_embs.contiguous() model_output1.mtp_main_output_hiddens = input_embs1.contiguous() - if infer_state.is_cuda_graph: + if infer_state.is_cuda_graph and self.platform_backend.name != "ascend": model_output.to_no_ref_tensor() model_output1.to_no_ref_tensor() @@ -969,15 +1009,15 @@ def _check_max_len_infer(self): # 模拟最大长度进行 prefill,观察是否出现 OOM try: logger.info("begin check max_len infer") - dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda") - b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() - 
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") + dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device=self.target_device) + b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device=self.target_device) + mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).to(device=self.target_device) + b_seq_len = torch.ones(1, dtype=torch.int32, device=self.target_device) b_seq_len[:] = self.batch_max_tokens - b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device=self.target_device) + b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device=self.target_device) total_token_num = self.batch_max_tokens - b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(1, dtype=torch.int32, device=self.target_device) model_input = ModelInput( batch_size=1, total_token_num=total_token_num, @@ -1041,19 +1081,19 @@ def _autotune_warmup(self): self.layers_num = self.autotune_layers() for input_len in tqdm(warmup_lengths, desc="warming up"): try: - rand_gen = torch.Generator(device="cuda") + rand_gen = torch.Generator(device=self.target_device) rand_gen.manual_seed(input_len) dummy_input_ids = torch.randint( - 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen + 0, 10000, (input_len,), dtype=torch.int32, device=self.target_device, generator=rand_gen ) - b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() - b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") + b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device=self.target_device) + mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).to(device=self.target_device) + b_seq_len = torch.ones(1, dtype=torch.int32, 
device=self.target_device) b_seq_len[:] = input_len - b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device=self.target_device) + b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device=self.target_device) total_token_num = input_len - b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(1, dtype=torch.int32, device=self.target_device) model_input = ModelInput( batch_size=1, total_token_num=total_token_num, @@ -1082,14 +1122,14 @@ def _autotune_warmup(self): self.req_manager.free_all() self.mem_manager.free_all() gc.collect() - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() except Exception as e: logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}") logger.exception(str(e)) self.req_manager.free_all() self.mem_manager.free_all() gc.collect() - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() self.layers_num = layer_num_bak torch.distributed.barrier() Autotuner.end_autotune_warmup() @@ -1106,19 +1146,19 @@ def _init_padded_req(self): # prefill init padding req. 
prefill_input_len = 1 batch_size = 1 - dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda") + dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device=self.target_device) b_req_idx = torch.tensor( - [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device=self.target_device ) mem_indexes = torch.tensor( - [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device=self.target_device ) - b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.ones(batch_size, dtype=torch.int32, device=self.target_device) + b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device=self.target_device) b_q_seq_len = b_seq_len - b_ready_cache_len b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len total_token_num = prefill_input_len * batch_size - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=self.target_device) model_input = ModelInput( batch_size=batch_size, total_token_num=total_token_num, @@ -1149,7 +1189,7 @@ def _init_padded_req(self): del b_seq_len del b_ready_cache_len del model_output - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() return def _gen_special_model_input(self, token_num: int): @@ -1163,7 +1203,7 @@ def _gen_special_model_input(self, token_num: int): ) if is_mtp_draft_model: special_model_input["mtp_draft_input_hiddens"] = torch.randn( - token_num, self.config["hidden_size"], dtype=self.data_type, device="cuda" + token_num, self.config["hidden_size"], dtype=self.data_type, device=self.target_device ) 
else: special_model_input["mtp_draft_input_hiddens"] = None diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 758c0b5194..2bb6d82ba3 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -1,6 +1,6 @@ import torch from dataclasses import dataclass, field -from typing import Optional +from typing import Any, Optional from typing import List from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor @@ -47,28 +47,36 @@ class ModelInput: # 的 draft 模型的输入 mtp_draft_input_hiddens: Optional[torch.Tensor] = None - def to_cuda(self): + def to_device(self, device: torch.device): + """ from TpPartBaseModel.target_device """ + + def _to_device(t: torch.Tensor) -> torch.Tensor: + return t.to(device, non_blocking=True) + if self.input_ids is not None: - self.input_ids = self.input_ids.cuda(non_blocking=True) + self.input_ids = _to_device(self.input_ids) if self.mem_indexes is None: - self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True) - self.b_req_idx = self.b_req_idx.cuda(non_blocking=True) - self.b_seq_len = self.b_seq_len.cuda(non_blocking=True) - self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True) + self.mem_indexes = _to_device(self.mem_indexes_cpu) + + self.b_req_idx = _to_device(self.b_req_idx) + self.b_seq_len = _to_device(self.b_seq_len) + self.b_mtp_index = _to_device(self.b_mtp_index) + if self.b_ready_cache_len is not None: - self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True) + self.b_ready_cache_len = _to_device(self.b_ready_cache_len) if self.b_prefill_start_loc is not None: - self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True) + self.b_prefill_start_loc = _to_device(self.b_prefill_start_loc) + if not self.is_prefill and enable_diverse_mode_gqa_decode_fast_kernel(): batch_size = len(self.b_req_idx) if 
self.b_mark_shared_group is None: - self.b_mark_shared_group = torch.ones(size=(batch_size,), dtype=torch.int32, device="cuda") + self.b_mark_shared_group = torch.ones(size=(batch_size,), dtype=torch.int32, device=device) else: - self.b_mark_shared_group = self.b_mark_shared_group.cuda(non_blocking=True) + self.b_mark_shared_group = _to_device(self.b_mark_shared_group) if self.b_shared_seq_len is None: - self.b_shared_seq_len = torch.zeros(size=(batch_size,), dtype=torch.int32, device="cuda") + self.b_shared_seq_len = torch.zeros(size=(batch_size,), dtype=torch.int32, device=device) else: - self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True) + self.b_shared_seq_len = _to_device(self.b_shared_seq_len) def __post_init__(self): self.check_input() @@ -82,7 +90,7 @@ class ModelOutput: # 通用变量 logits: torch.Tensor # 用于判断 mem_indexes 是否成功写入 req manager 中的事件对象。 - prefill_mem_indexes_ready_event: torch.Event = None + prefill_mem_indexes_ready_event: Any = None # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输出变量。只在特殊的模型模式下才会具体使用和生效。 diff --git a/lightllm/common/basemodel/graph/__init__.py b/lightllm/common/basemodel/graph/__init__.py new file mode 100644 index 0000000000..4f6738d20f --- /dev/null +++ b/lightllm/common/basemodel/graph/__init__.py @@ -0,0 +1,14 @@ +from lightllm.common.basemodel.graph.acl_graph import AclGraph +from lightllm.common.basemodel.graph.base.decode_graph import DecodeGraph +from lightllm.common.basemodel.graph.cuda_graph import CudaGraph + +DECODE_GRAPH_MAP = { + "cuda": CudaGraph, + "musa": CudaGraph, + "ascend": AclGraph, + "maca": CudaGraph, +} + +DecodeGraph.PLATFORM_CLASS_MAP = DECODE_GRAPH_MAP + +__all__ = ["DecodeGraph"] diff --git a/lightllm/common/basemodel/graph/acl_graph.py b/lightllm/common/basemodel/graph/acl_graph.py new file mode 100644 index 0000000000..eae03e6f56 --- /dev/null +++ b/lightllm/common/basemodel/graph/acl_graph.py @@ -0,0 +1,97 @@ +import torch +from dataclasses import dataclass, field +from 
lightllm.common.basemodel.attention.paged_fa3.fp import update_attn_params +from lightllm.common.basemodel.batch_objs import ModelOutput +from lightllm.common.basemodel.graph.base.decode_graph import DecodeGraph +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.utils.log_utils import init_logger +from typing import Any, Optional + +logger = init_logger(__name__) + + +class AclGraph(DecodeGraph): + + def _init_decode_graph_extra(self): + init_attn_params(self.graph_batch_sizes) + self.update_stream = torch.npu.Stream() + + def _replay( + self, + infer_state: InferStateInfo, + b1_cu_q_seq_len_cpu: list[int], + b_cu_kv_seq_len_cpu: list[int], + ) -> ModelOutput: + graph_output = super()._replay(infer_state, b1_cu_q_seq_len_cpu, b_cu_kv_seq_len_cpu) + batch_size = infer_state.input_ids.shape[0] + update_attn_params( + batch_size, + b1_cu_q_seq_len_cpu, + b_cu_kv_seq_len_cpu.add_(1), + self.update_stream, + ) + return graph_output + + def _replay_overlap( + self, + infer_state: InferStateInfo, + infer_state1: InferStateInfo, + b1_cu_q_seq_len_cpu: list[int], + b_cu_kv_seq_len_cpu: list[int], + ): + graph_model_output, graph_model_output1 = super()._replay_overlap( + infer_state, infer_state1, b1_cu_q_seq_len_cpu, b_cu_kv_seq_len_cpu) + batch_size = infer_state.input_ids.shape[0] + update_attn_params( + batch_size, + b1_cu_q_seq_len_cpu, + b_cu_kv_seq_len_cpu.add_(1), + self.update_stream, + ) + return graph_model_output, graph_model_output1 + + def replay( + self, + infer_state: InferStateInfo, + b1_cu_q_seq_len_cpu: list[int], + b_cu_kv_seq_len_cpu: list[int], + infer_state1: Optional[InferStateInfo] = None, + ): + if self.enable_decode_microbatch_overlap: + return self._replay_overlap(infer_state, infer_state1, b1_cu_q_seq_len_cpu, b_cu_kv_seq_len_cpu) + assert infer_state1 is None + return self._replay(infer_state, b1_cu_q_seq_len_cpu, b_cu_kv_seq_len_cpu) + + +# Adapted from: 
https://github.com/vllm-project/vllm-ascend/blob/v0.11.0/vllm_ascend/compilation/acl_graph.py +@dataclass +class AclGraphParams: + handles: dict[int, list[Any]] = field(default_factory=dict) + events: dict[int, list[Any]] = field(default_factory=dict) + workspaces: dict[int, Any] = field(default_factory=dict) + attn_params: dict[int, list[tuple]] = field(default_factory=dict) + + +ATTN_PARAMS: Optional[AclGraphParams] = None + + +def init_attn_params(batch_sizes: list[int]): + global ATTN_PARAMS + ATTN_PARAMS = AclGraphParams( + handles={bs: [] for bs in batch_sizes}, + events={bs: [] for bs in batch_sizes}, + workspaces={bs: None for bs in batch_sizes}, + attn_params={bs: [] for bs in batch_sizes}, + ) + + +def get_attn_params(): + return ATTN_PARAMS + + +def add_attn_params(batch_size: int, event: Any, handle: Any, attn_params: tuple): + global ATTN_PARAMS + if ATTN_PARAMS is not None: + ATTN_PARAMS.handles[batch_size].append(handle) + ATTN_PARAMS.events[batch_size].append(event) + ATTN_PARAMS.attn_params[batch_size].append(attn_params) diff --git a/lightllm/common/basemodel/graph/base/__init__.py b/lightllm/common/basemodel/graph/base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/graph/base/decode_graph.py similarity index 63% rename from lightllm/common/basemodel/cuda_graph.py rename to lightllm/common/basemodel/graph/base/decode_graph.py index 782150661e..0429b2861b 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/graph/base/decode_graph.py @@ -1,103 +1,116 @@ -import os -import torch -import copy import bisect +import copy import triton -from typing import Optional +import torch +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from 
lightllm.utils.envs_utils import get_env_start_args -from lightllm.distributed import dist_group_manager -from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput -from .infer_struct import InferStateInfo - +from typing import Optional logger = init_logger(__name__) -class CudaGraph: - # CudaGraph forward pass for the decoding stage. +class DecodeGraph: + + PLATFORM_CLASS_MAP: dict[str, type["DecodeGraph"]] = {} - def __init__(self, max_batch_size=8, max_len_in_batch=8192, tp_world_size: int = 1): - self.graph = {} - self.tp_world_size = tp_world_size - self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None - self.args = get_env_start_args() - self.mtp_step = self.args.mtp_step + def __new__( + cls, + max_batch_size: int, + max_len_in_batch: int, + tp_world_size: int = 1, + platform_backend: str = "cuda", + ): + if cls is not DecodeGraph: + return object.__new__(cls) + impl_cls = cls.PLATFORM_CLASS_MAP[platform_backend] + return object.__new__(impl_cls) + + def __init__( + self, + max_batch_size: int, + max_len_in_batch: int, + tp_world_size: int = 1, + platform_backend: str = "cuda", + ): + self.hardware_platform = platform_backend + self._init_decode_graph(max_batch_size, max_len_in_batch, tp_world_size) + self._init_decode_graph_extra() + + def _init_decode_graph_extra(self): + pass + + def _init_decode_graph(self, max_batch_size: int, max_len_in_batch: int, tp_world_size: int): + self.graph: dict[int, tuple] = {} + + args = get_env_start_args() + mtp_step = args.mtp_step self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch - self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.enable_decode_microbatch_overlap = args.enable_decode_microbatch_overlap - # gen cuda graph batch_sizes - # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] - # and [graph_split_batch_size + graph_grow_step_size, - # if the mtp_step is not 0, then the 
batch_sizes will be multiply of (mtp_step + 1) + self.platform_backend = get_backend() + self.target_device = self.platform_backend.runtime.target_device() - graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1) - graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1) + self.mempool = self.platform_backend.graph.graph_pool_handle() - batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)] + graph_split_batch_size = args.graph_split_batch_size * (mtp_step + 1) + graph_grow_step_size = args.graph_grow_step_size * (mtp_step + 1) + batch_sizes = [i * (mtp_step + 1) for i in range(1, graph_split_batch_size + 1)] for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size): batch_sizes.append(_batch_size) - batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size])) batch_sizes.append(max_batch_size) batch_sizes.sort() - if self.args.enable_tpsp_mix_mode: - batch_sizes = [triton.cdiv(e, self.tp_world_size) * self.tp_world_size for e in batch_sizes] + + if args.enable_tpsp_mix_mode: + batch_sizes = [triton.cdiv(e, tp_world_size) * tp_world_size for e in batch_sizes] batch_sizes = list(set(batch_sizes)) batch_sizes.sort() - - self.cuda_graph_batch_sizes = batch_sizes + + self.graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size - logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") + logger.info(f"decode graph batch_sizes: {self.graph_batch_sizes}") - def can_run(self, batch_size, max_len_in_batch): + def can_run(self, batch_size: int, max_len_in_batch: int) -> bool: return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch - def need_capture(self, batch_size): + def need_capture(self, batch_size: int) -> bool: find_batch_size = self.find_closest_graph_batch_size(batch_size) if find_batch_size is not None: return find_batch_size not in self.graph else: 
assert False, "dead code" - def find_closest_graph_batch_size(self, batch_size): - index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) - if index < len(self.cuda_graph_batch_sizes): - find_batch_size = self.cuda_graph_batch_sizes[index] + def find_closest_graph_batch_size(self, batch_size: int) -> Optional[int]: + index = bisect.bisect_left(self.graph_batch_sizes, batch_size) + if index < len(self.graph_batch_sizes): + find_batch_size = self.graph_batch_sizes[index] return find_batch_size else: return None - def _capture_decode(self, decode_func, infer_state: InferStateInfo): - graph_obj = torch.cuda.CUDAGraph() - input_ids = infer_state.input_ids - batch_size = input_ids.shape[0] + def _capture_decode(self, decode_func, infer_state: InferStateInfo) -> ModelOutput: + graph_obj = self.platform_backend.graph.create_graph() + batch_size = infer_state.input_ids.shape[0] infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size # warmup - # 因为有些推理过程的代码,会通过判断infer_state中是否存在某些属性来在一层上 - # 做一些初始化的操作,后续层可以复用这些计算的结果,如 - # lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py - # 中做的一些操作,所以在 warmup 的时候,需要调用infer_state的copy函数做一个 - # 浅拷贝,不然后续传入到cuda graph捕获过程中后,infer_state因为提前拥有了这些属性, - # 导致不会重新初始化,这样捕获过程中会不能捕获这些临时添加到 infer_state 管理对象 - # 中的 tensor。 - for _ in range(1): - # 记录原始存在的变量 pure_para_set = set(vars(infer_state).keys()) - torch.cuda.synchronize() + self.platform_backend.runtime.synchronize() decode_func(copy.copy(infer_state)) - torch.cuda.synchronize() + self.platform_backend.runtime.synchronize() for param_name in set(vars(infer_state).keys()): if param_name not in pure_para_set: delattr(infer_state, param_name) - with torch.cuda.graph(graph_obj, pool=self.mempool): + with self.platform_backend.graph.graph(graph_obj, pool=self.mempool): model_output = decode_func(infer_state) self.graph[batch_size] = (graph_obj, infer_state, model_output) - graph_obj.replay() + return 
model_output def _capture_decode_overlap( @@ -105,39 +118,31 @@ def _capture_decode_overlap( decode_func, infer_state: InferStateInfo, infer_state1: InferStateInfo, - ): - graph_obj = torch.cuda.CUDAGraph() - input_ids = infer_state.input_ids - batch_size = input_ids.shape[0] + ) -> tuple[ModelOutput, ModelOutput]: + graph_obj = self.platform_backend.graph.create_graph() + batch_size = infer_state.input_ids.shape[0] infer_state.max_kv_seq_len = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size infer_state1.max_kv_seq_len = self.graph_max_len_in_batch infer_state1.total_token_num = self.graph_max_len_in_batch * batch_size # warmup for _ in range(1): - # 记录原始存在的变量 pure_para_set = set(vars(infer_state).keys()) pure_para_set1 = set(vars(infer_state1).keys()) - torch.cuda.synchronize() + self.platform_backend.runtime.synchronize() decode_func(copy.copy(infer_state), copy.copy(infer_state1)) - torch.cuda.synchronize() - for para_name in set(vars(infer_state).keys()): - if para_name not in pure_para_set: - delattr(infer_state, para_name) - for para_name in set(vars(infer_state1).keys()): - if para_name not in pure_para_set1: - delattr(infer_state1, para_name) - - with torch.cuda.graph(graph_obj, pool=self.mempool): + self.platform_backend.runtime.synchronize() + for param_name in set(vars(infer_state).keys()): + if param_name not in pure_para_set: + delattr(infer_state, param_name) + for param_name in set(vars(infer_state1).keys()): + if param_name not in pure_para_set1: + delattr(infer_state1, param_name) + + with self.platform_backend.graph.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(infer_state, infer_state1) - self.graph[batch_size] = ( - graph_obj, - infer_state, - infer_state1, - model_output, - model_output1, - ) - graph_obj.replay() + self.graph[batch_size] = (graph_obj, infer_state, infer_state1, model_output, model_output1) + return model_output, model_output1 def capture_decode( 
@@ -145,28 +150,32 @@ def capture_decode( decode_func, infer_state: InferStateInfo, infer_state1: Optional[InferStateInfo] = None, - ): - """ - Capture the cuda graph for the decoding stage. - input_ids1 and infer_state1 is used for the overlap. - """ + ) -> tuple[ModelOutput, ModelOutput]: if self.enable_decode_microbatch_overlap: return self._capture_decode_overlap(decode_func, infer_state, infer_state1) else: assert infer_state1 is None return self._capture_decode(decode_func, infer_state) - def _replay(self, infer_state: InferStateInfo): + def _replay( + self, + infer_state: InferStateInfo, + b1_cu_q_seq_len_cpu: list[int], + b_cu_kv_seq_len_cpu: list[int], + ) -> ModelOutput: batch_size = infer_state.input_ids.shape[0] graph_obj, graph_infer_state, graph_output = self.graph[batch_size] graph_infer_state.copy_for_cuda_graph(infer_state) - graph_obj.replay() + self.platform_backend.graph.replay_graph(graph_obj) + return graph_output def _replay_overlap( self, infer_state: InferStateInfo, infer_state1: InferStateInfo, + b1_cu_q_seq_len_cpu: list[int], + b_cu_kv_seq_len_cpu: list[int], ): batch_size = infer_state.input_ids.shape[0] ( @@ -178,37 +187,38 @@ def _replay_overlap( ) = self.graph[batch_size] graph_infer_state.copy_for_cuda_graph(infer_state) graph_infer_state1.copy_for_cuda_graph(infer_state1) - graph_obj.replay() + self.platform_backend.graph.replay_graph(graph_obj) + return graph_model_output, graph_model_output1 - def replay(self, infer_state, infer_state1=None): + def replay(self, infer_state, b1_cu_q_seq_len_cpu: list[int], b_cu_kv_seq_len_cpu: list[int], infer_state1=None): if self.enable_decode_microbatch_overlap: - return self._replay_overlap(infer_state, infer_state1) + return self._replay_overlap(infer_state, infer_state1, b1_cu_q_seq_len_cpu, b_cu_kv_seq_len_cpu) else: assert infer_state1 is None - return self._replay(infer_state) + return self._replay(infer_state, b1_cu_q_seq_len_cpu, b_cu_kv_seq_len_cpu) @torch.no_grad() def warmup(self, 
model): logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.") # for typing easy - from .basemodel import TpPartBaseModel + from lightllm.common.basemodel.basemodel import TpPartBaseModel model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: - seq_len = 2 + for batch_size in self.graph_batch_sizes[::-1]: + seq_len = self.graph_max_len_in_batch total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() + input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device=self.target_device) + mem_indexes = model.mem_manager.alloc(len(input_ids)).to(self.target_device) b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device=self.target_device ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device=self.target_device) b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=self.target_device) model_input = ModelInput( batch_size=batch_size, @@ -237,7 +247,7 @@ def warmup(self, model): for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() logger.info( f"Capture cudagraph success, batch_size <={self.max_batch_size} " @@ -248,25 +258,25 @@ def warmup(self, model): def warmup_overlap(self, model): logger.info("Begin capture overlap cudagraph, use the --disable_cudagraph to disable it.") # for typing easy - from 
.basemodel import TpPartBaseModel + from lightllm.common.basemodel.basemodel import TpPartBaseModel model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + for batch_size in self.graph_batch_sizes[::-1]: decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph - seq_len = 2 + seq_len = self.graph_max_len_in_batch total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch - input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() + input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device=self.target_device) + mem_indexes = model.mem_manager.alloc(len(input_ids)).to(self.target_device) b_req_idx = torch.tensor( - [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" + [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device=self.target_device ) - b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") + b_seq_len = torch.empty(batch_size, dtype=torch.int32, device=self.target_device) b_seq_len.fill_(seq_len) - b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device=self.target_device) micro_batch = ModelInput( is_prefill=False, @@ -288,7 +298,7 @@ def warmup_overlap(self, model): for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() _, _ = model.microbatch_overlap_decode(decode_batches[0], decode_batches[1]) @@ -301,7 +311,7 @@ def warmup_overlap(self, model): for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() 
logger.info( f"Capture overlap cudagraph success, batch_size <={self.max_batch_size} " diff --git a/lightllm/common/basemodel/graph/cuda_graph.py b/lightllm/common/basemodel/graph/cuda_graph.py new file mode 100644 index 0000000000..6e458a8651 --- /dev/null +++ b/lightllm/common/basemodel/graph/cuda_graph.py @@ -0,0 +1,5 @@ +from lightllm.common.basemodel.graph.base.decode_graph import DecodeGraph + +CudaGraph = DecodeGraph + +__all__ = ["CudaGraph"] diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 711484c835..01cbc93d43 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -5,7 +5,9 @@ from lightllm.common.req_manager import ReqManager from lightllm.distributed import CustomProcessGroup from typing import Tuple, Any, Optional, List -from .triton_kernel.gen_prefill_params import gen_prefill_params + +from lightllm.platform import get_backend +from .triton_kernel.gen_prefill_params import gen_prefill_params from .triton_kernel.gen_decode_params import gen_decode_params from .triton_kernel.multimodal_emb import mark_multimodal_obj from .batch_objs import ModelInput @@ -14,6 +16,32 @@ from .attention import BasePrefillAttState, BaseDecodeAttState +class SeqLenManager: + def __init__(self, max_batch: int): + self.max_batch = max_batch + + self.b1_cu_q_seq_len_cpu = torch.empty( + max_batch, dtype=torch.int32, device='cpu', pin_memory=True) + self.b_cu_kv_seq_len_cpu = torch.empty( + max_batch, dtype=torch.int32, device='cpu', pin_memory=True) + + self.b_cu_q_seq_len_list = None + self.b_cu_kv_seq_len_list = None + + def update(self, b1_cu_q_seq_len: torch.Tensor, b_cu_kv_seq_len: torch.Tensor): + n_q = b1_cu_q_seq_len.numel() - 1 + n_kv = b_cu_kv_seq_len.numel() + + self.b1_cu_q_seq_len_cpu[:n_q].copy_(b1_cu_q_seq_len[1:], non_blocking=False) + self.b_cu_kv_seq_len_cpu[:n_kv].copy_(b_cu_kv_seq_len, non_blocking=False) + + self.n_q = n_q + self.n_kv = n_kv + + def 
get_tensor_slices(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.b1_cu_q_seq_len_cpu[:self.n_q], self.b_cu_kv_seq_len_cpu[:self.n_kv] + + class InferStateInfo: """ 推理时用的信息结构体 @@ -68,7 +96,9 @@ def __init__(self): # b1 开头的tensor变量其shape为[batch_size + 1,] self.b_q_seq_len: torch.Tensor = None self.b1_cu_q_seq_len: torch.Tensor = None + self.b1_cu_q_seq_len_cpu: torch.Tensor = None self.b_kv_seq_len: torch.Tensor = None + self.b_kv_seq_len_cpu: torch.Tensor = None self.b1_cu_kv_seq_len: torch.Tensor = None self.position_ids: torch.Tensor = None self.max_q_seq_len: int = None @@ -100,6 +130,12 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None + self.platform_backend = get_backend() + + if self.platform_backend.name == "ascend": + args = get_env_start_args() + self.seq_len_manager = SeqLenManager(args.running_max_req_size + 1) + def init_some_extra_state(self, model): if self.is_prefill: ( @@ -123,6 +159,12 @@ def init_some_extra_state(self, model): self.position_ids, ) = gen_decode_params(self.b_seq_len) self.b_kv_start_loc = self.b1_cu_kv_seq_len[0:-1] + if self.platform_backend.name == "ascend": + self.seq_len_manager.update(self.b1_cu_q_seq_len, self.b_kv_seq_len) + self.b1_cu_q_seq_len_cpu, self.b_cu_kv_seq_len_cpu = self.seq_len_manager.get_tensor_slices() + else: + self.b1_cu_q_seq_len_cpu = [] + self.b_cu_kv_seq_len_cpu = [] def init_att_state(self): if self.is_prefill: @@ -172,8 +214,8 @@ def prepare_prefill_dp_balance(self): args = get_env_start_args() - dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32) - input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32) + dp_input_lens = torch.empty(size=(args.dp,), device=input_ids.device, dtype=torch.int32) + input_len = torch.empty(size=(1,), device=input_ids.device, dtype=torch.int32) input_len.fill_(len(input_ids)) dist.all_gather_into_tensor( output_tensor=dp_input_lens, @@ -302,7 +344,7 
@@ def _all_to_all_balance_get(self, data: torch.Tensor, change_state: bool = True) dest_data = g_cache_manager.alloc_tensor( shape=(handle_len * scale_size,), data_type=data.dtype, - device="cuda", + device=data.device, ) dist.all_to_all_single( output=dest_data.view(-1), @@ -332,7 +374,7 @@ def _all_to_all_unbalance_get(self, data: torch.Tensor, change_state: bool = Tru origin_data = g_cache_manager.alloc_tensor( shape=(origin_len * scale_size,), data_type=data.dtype, - device="cuda", + device=data.device, ) dist.all_to_all_single( output=origin_data.view(-1), @@ -348,19 +390,19 @@ def _all_to_all_unbalance_get(self, data: torch.Tensor, change_state: bool = Tru def prefill_cuda_graph_create_graph_obj(self): if not hasattr(self, "prefill_cuda_graph_exe_list"): self.prefill_cuda_graph_exe_list = [] - graph_obj = torch.cuda.CUDAGraph() - capture_graph = torch.cuda.graph(graph_obj, pool=self.mem_pool) + graph_obj = self.platform_backend.graph.create_graph() + capture_graph = self.platform_backend.graph.graph(graph_obj, pool=self.mem_pool) self.prefill_cuda_graph_exe_list.append((graph_obj, capture_graph)) return - def prefill_cuda_graph_get_current_capture_graph(self) -> torch.cuda.graph: + def prefill_cuda_graph_get_current_capture_graph(self) -> Any: assert len(self.prefill_cuda_graph_exe_list) > 0, "no cuda graph exe obj found" if isinstance(self.prefill_cuda_graph_exe_list[-1], tuple): return self.prefill_cuda_graph_exe_list[-1][1] else: return self.prefill_cuda_graph_exe_list[-2][1] - def prefill_cuda_graph_add_cpu_runnning_func(self, func, after_graph: torch.cuda.graph): + def prefill_cuda_graph_add_cpu_runnning_func(self, func, after_graph: Any): if not hasattr(self, "prefill_cuda_graph_exe_list"): self.prefill_cuda_graph_exe_list = [] if after_graph is None: @@ -377,7 +419,7 @@ def prefill_replay(self, new_infer_state: "InferStateInfo"): for func in self.prefill_cuda_graph_exe_list: if isinstance(func, tuple): graph_obj, _ = func - graph_obj.replay() + 
self.platform_backend.graph.replay_graph(graph_obj) else: func(new_infer_state) return diff --git a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py index 7889e8090e..86c24b8c6e 100644 --- a/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py +++ b/lightllm/common/basemodel/layer_infer/cache_tensor_manager.py @@ -45,17 +45,27 @@ def __init__(self): self.free_shape_dtype_to_bufs: Dict[Tuple, List[BufNode]] = collections.defaultdict(list) self.calcu_shape_cache: Dict[torch.Size, int] = {} self.changed_ptr: Set[int] = set() - from torch._C import _storage_Use_Count as use_count - - # use_count 函数可以用于获取有多少 tensor 真正引用了这片显存 tensor - self.use_count = use_count + # lazy init use_count to avoid get_backend() call error in constructor + self.use_count = None self.managed_total_tensor_bytes = 0 # 防止误用导致显存泄露,添加标记变量。 # 当使用者没有合法的调用 cache_env_in 和 cache_env_out 的时候 # 如果调用了alloc_tensor 接口,则退化为 torch.empty 申请方式。 self.cache_env_ok = False + def _ensure_use_count(self): + if self.use_count is not None: + return + from lightllm.platform import get_backend + + if get_backend().name == "ascend": + from torch_npu._C import _storage_Use_Count as use_count + else: + from torch._C import _storage_Use_Count as use_count + self.use_count = use_count + def cache_env_in(self): + self._ensure_use_count() self.managed_total_tensor_bytes = 0 setattr(torch.Tensor, "__del__", custom_del) self.cache_env_ok = True diff --git a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py index 04f8cda16b..407cb3a5b5 100644 --- a/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/pre_layer_infer_template.py @@ -1,4 +1,5 @@ import torch +from lightllm.platform import get_backend from ..pre_layer_infer import PreLayerInfer @@ -8,6 +9,8 @@ class 
PreLayerInferTpl(PreLayerInfer): def __init__(self, network_config): super().__init__(network_config) self.eps_ = 1e-5 + self.platform_backend = get_backend() + self.target_device = self.platform_backend.runtime.target_device() return def _norm(self, input, infer_state, layer_weight) -> torch.Tensor: diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 276b5856f9..bfcadc11e0 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -101,7 +101,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh def _context_attention_wrapper_run( self, q: torch.Tensor, cache_kv: torch.Tensor, infer_state: InferStateInfo, layer_weight ) -> torch.Tensor: - if torch.cuda.is_current_stream_capturing(): + if self.platform_backend.graph.is_capturing(): q = q.contiguous() cache_kv = cache_kv.contiguous() _q, _cache_kv = ( @@ -113,7 +113,7 @@ def _context_attention_wrapper_run( def get_o_shape_dtype_device(): # 在一个新的 graph 中尝试运行,并不是为了捕获图,是为了尝试得到 o 的形状等信息 - with torch.cuda.graph(cuda_graph=torch.cuda.CUDAGraph()): + with self.platform_backend.graph.graph(graph_obj=self.platform_backend.graph.create_graph()): __o = self._context_attention_kernel(_q, _cache_kv, infer_state, layer_weight) o_shape = __o.shape o_dtype = __o.dtype @@ -123,7 +123,7 @@ def get_o_shape_dtype_device(): import gc gc.collect() - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() return o_shape, o_dtype, o_device o_shape, o_dtype, o_device = get_o_shape_dtype_device() diff --git a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py index 53daffcddf..c2b742326f 100644 --- 
a/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py +++ b/lightllm/common/basemodel/layer_infer/transformer_layer_infer.py @@ -1,3 +1,4 @@ +from lightllm.platform import get_backend from .base_layer_infer import BaseLayerInfer @@ -8,4 +9,6 @@ def __init__(self, layer_num, network_config): super().__init__() self.layer_num_ = layer_num self.network_config_ = network_config + self.platform_backend = get_backend() + self.target_device = self.platform_backend.runtime.target_device() return diff --git a/lightllm/common/basemodel/layer_weights/base_layer_weight.py b/lightllm/common/basemodel/layer_weights/base_layer_weight.py index b1d992a7c4..d6814c4702 100644 --- a/lightllm/common/basemodel/layer_weights/base_layer_weight.py +++ b/lightllm/common/basemodel/layer_weights/base_layer_weight.py @@ -2,7 +2,8 @@ import numpy as np import threading from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight -from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size +from lightllm.platform import get_backend +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size class BaseLayerWeight: @@ -10,6 +11,8 @@ def __init__(self): self.tp_rank_ = get_current_rank_in_dp() self.tp_world_size_ = get_dp_world_size() self.lock = threading.Lock() + platform_backend = get_backend() + self.target_device = platform_backend.runtime.target_device() def load_hf_weights(self, weights): """ @@ -39,5 +42,5 @@ def verify_load(self): layer_num = None assert attr.verify_load(), f"Loading {attr_name} of layers {layer_num} fails." 
- def _cuda(self, cpu_tensor): - return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) + def _to_device(self, cpu_tensor: torch.Tensor) -> torch.Tensor: + return cpu_tensor.contiguous().to(device=self.target_device) diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad6..8616ee4ede 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -3,6 +3,7 @@ import gc from safetensors import safe_open from tqdm import tqdm +from lightllm.platform import get_backend import lightllm.utils.petrel_helper as utils from lightllm.utils.dist_utils import get_current_device_id @@ -11,7 +12,7 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay # fix bug for 多线程加载的时候,每个线程内部的cuda device 会切回 0, 修改后来保证不会出现bug import torch.distributed as dist - torch.cuda.set_device(get_current_device_id()) + get_backend().runtime.set_device(get_current_device_id()) if use_safetensors: weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py index 2013d55be0..42d1d89db4 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/att_sink_weight.py @@ -1,7 +1,6 @@ import torch from typing import Dict, Tuple from .base_weight import BaseWeightTpl -from lightllm.utils.dist_utils import get_current_device_id class TpAttSinkWeight(BaseWeightTpl): @@ -15,7 +14,7 @@ def __init__(self, all_q_head_num: int, weight_name: str, data_type): def _create_weight(self): self.weight = torch.empty( - (self._end_head_index - self._start_head_index,), dtype=self.data_type_, device="cuda" + (self._end_head_index - self._start_head_index,), 
dtype=self.data_type_, device=self.target_device ) self.weight.load_ok = False @@ -24,9 +23,8 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): return t_weight = weights[self.weight_name] - self.weight = ( - t_weight[self._start_head_index : self._end_head_index].to(self.data_type_).cuda(get_current_device_id()) - ) + self.weight = t_weight[self._start_head_index : self._end_head_index].to( + device=self.target_device, dtype=self.data_type_) self.weight.load_ok = True def verify_load(self): diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py index 714e7acf48..fdf84c33a9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py @@ -1,6 +1,7 @@ import torch from abc import ABC, abstractmethod from typing import Dict, Tuple +from lightllm.platform import get_backend from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp, get_current_device_id @@ -25,10 +26,12 @@ def verify_load(self) -> bool: class BaseWeightTpl(BaseWeight): def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: torch.dtype = None): super().__init__() + self.platform_backend = get_backend() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.device_id_ = get_current_device_id() self.data_type_ = data_type + self.target_device = self.platform_backend.runtime.target_device(self.device_id_) def load_hf_weights(self, weights): raise NotImplementedError("load_hf_weights must implement this method") diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py index d94a4c709b..6bbccaf90c 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/embedding_weight.py @@ -1,13 +1,12 @@ import torch import numpy as np from typing import Dict, Optional + from .base_weight import BaseWeightTpl -from .platform_op import PlatformAwareOp from lightllm.common.basemodel.triton_kernel.embedding import embedding as embedding_kernel -from lightllm.utils.dist_utils import get_dp_world_size, get_current_rank_in_dp -class EmbeddingWeight(BaseWeightTpl, PlatformAwareOp): +class EmbeddingWeight(BaseWeightTpl): def __init__(self, dim: int, vocab_size: int, weight_name: str, data_type: torch.dtype): super().__init__() self.dim = dim @@ -81,7 +80,14 @@ def _musa_forward( def __call__( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - return self._forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + return self.platform_backend.ops.embedding( + input_ids=input_ids, + weight=self.weight, + out=out, + alloc_func=alloc_func, + vob_start_id=self.tp_vocab_start_id, + vob_end_id=self.tp_vocab_end_id, + ) class LMHeadWeight(EmbeddingWeight): @@ -143,10 +149,15 @@ def _cuda_forward( return out def __call__(self, input: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty) -> torch.Tensor: - return self._forward(input=input, out=out, alloc_func=alloc_func) + return self.platform_backend.ops.lm_head( + input=input, + weight=self.weight, + out=out, + alloc_func=alloc_func, + ) -class NoTpPosEmbeddingWeight(BaseWeightTpl, PlatformAwareOp): +class NoTpPosEmbeddingWeight(BaseWeightTpl): def __init__(self, dim: int, max_position_embeddings: int, weight_name: str, data_type: torch.dtype): super().__init__() self.dim = dim @@ -206,4 +217,11 @@ def _cuda_forward( def __call__( self, input_ids: torch.Tensor, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - return 
self._forward(input_ids=input_ids, out=out, alloc_func=alloc_func) + return self.platform_backend.ops.embedding( + input_ids=input_ids, + weight=self.weight, + out=out, + alloc_func=alloc_func, + vob_start_id=0, + vob_end_id=self.max_position_embeddings, + ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..0b274fc075 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -79,8 +79,10 @@ def _init_redundancy_expert_params(self): self.redundancy_expert_num = get_redundancy_expert_num() self.redundancy_expert_ids = get_redundancy_expert_ids(self.layer_num_) self.auto_update_redundancy_expert: bool = get_env_start_args().auto_update_redundancy_expert - self.redundancy_expert_ids_tensor = torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cuda") - self.routed_expert_counter_tensor = torch.zeros((self.n_routed_experts,), dtype=torch.int64, device="cuda") + self.redundancy_expert_ids_tensor = torch.tensor( + self.redundancy_expert_ids, dtype=torch.int64, device=self.target_device) + self.routed_expert_counter_tensor = torch.zeros( + (self.n_routed_experts,), dtype=torch.int64, device=self.target_device) # TODO: find out the reason of failure of deepep when redundancy_expert_num is 1. assert self.redundancy_expert_num != 1, "redundancy_expert_num can not be 1 for some unknown hang of deepep." 
@@ -278,7 +280,7 @@ def _create_weight(self): self.e_score_correction_bias = torch.empty( (self.n_routed_experts,), dtype=self.data_type_, - device=f"cuda:{self.device_id_}", + device=self.target_device, ) self.w13, w13_param_list = self.quant_method.create_moe_weight( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py index 6ed0cef0b4..a96fb727d2 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gpt_oss_fused_moe_weight_tp.py @@ -102,7 +102,7 @@ def load_hf_weights(self, weights): scales=weights[self._down_scales_name], dtype=torch.bfloat16, )[:, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :] - self.w2 = (self._cuda(w2.transpose(1, 2)), None) + self.w2 = (self._to_device(w2.transpose(1, 2)), None) if ( weights.get(self._gate_up_blocks_name, None) is not None @@ -113,17 +113,17 @@ def load_hf_weights(self, weights): scales=weights[self._gate_up_scales_name], dtype=torch.bfloat16, )[:, :, self.split_inter_size * self.tp_rank_ * 2 : self.split_inter_size * (self.tp_rank_ + 1) * 2] - self.w1 = (self._cuda(w1.transpose(1, 2)), None) + self.w1 = (self._to_device(w1.transpose(1, 2)), None) if weights.get(self._gate_up_bias_name, None) is not None: w1_bias = weights[self._gate_up_bias_name][ :, self.split_inter_size * self.tp_rank_ * 2 : self.split_inter_size * (self.tp_rank_ + 1) * 2 ] - self.w1_bias = self._cuda(w1_bias) + self.w1_bias = self._to_device(w1_bias) if weights.get(self._down_bias_name, None) is not None: w2_bias = weights[self._down_bias_name] - self.w2_bias = self._cuda(w2_bias) + self.w2_bias = self._to_device(w2_bias) def verify_load(self): assert self.w1 is not None and self.w2 is not None @@ -188,9 +188,9 @@ def 
_convert_moe_packed_tensors( import math # Check if blocks and scales are on CPU, and move to GPU if so - if not blocks.is_cuda and torch.cuda.is_available(): - blocks = blocks.cuda() - scales = scales.cuda() + if blocks.device != self.target_device: + blocks = blocks.to(self.target_device) + scales = scales.to(self.target_device) scales = scales.to(torch.int32) - 127 # that's because 128=2**7 @@ -227,5 +227,5 @@ def _convert_moe_packed_tensors( del blocks, scales, lut return out.transpose(1, 2).contiguous() - def _cuda(self, cpu_tensor): - return cpu_tensor.contiguous().to(self.data_type_).cuda(get_current_device_id()) + def _to_device(self, cpu_tensor: torch.Tensor) -> torch.Tensor: + return cpu_tensor.contiguous().to(device=self.target_device) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index bdd86eb51e..1340b9d5fd 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -253,7 +253,7 @@ def prefilled_group_gemm( num_recv_tokens_per_expert = torch.tensor( num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" - ).cuda(non_blocking=True) + ).to(device=recv_topk_idx.device, non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 6391a10800..e467a60664 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -18,7 +18,7 @@ def create_workspace(self): marlin_make_workspace_new, ) - return 
marlin_make_workspace_new(torch.device("cuda"), 4) + return marlin_make_workspace_new(self.quant_method.target_device, 4) def _fused_experts( self, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..208ae81424 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -66,7 +66,7 @@ def _select_experts( end=self.n_routed_experts + self.num_fused_shared_experts, step=1, dtype=topk_ids.dtype, - device="cuda", + device=topk_ids.device, ) .view(1, self.num_fused_shared_experts) .repeat(topk_ids.shape[0], 1) @@ -74,7 +74,7 @@ def _select_experts( pad_topk_weights = torch.full( (topk_weights.shape[0], self.num_fused_shared_experts), fill_value=1.0, - device="cuda", + device=topk_weights.device, dtype=topk_weights.dtype, ) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 5021699143..cabcf73f7d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -92,7 +92,7 @@ def load_hf_weights(self, weights): def _create_weight(self): self.bias = None if self.bias_names is not None: - self.bias = torch.empty(sum(self.out_dims), dtype=self.data_type_).cuda(get_current_device_id()) + self.bias = torch.empty(sum(self.out_dims), dtype=self.data_type_).to(device=self.target_device) # bias_list shares storage with bias for each output shard self.bias_list = torch.split(self.bias, self.out_dims, dim=0) for sub_bias in self.bias_list: @@ -200,7 +200,7 @@ def __init__( return def _create_weight(self): - self.weight = torch.empty(self.dim0, self.dim1, self.dim2, 
dtype=self.data_type_).cuda(get_current_device_id()) + self.weight = torch.empty(self.dim0, self.dim1, self.dim2, dtype=self.data_type_).to(device=self.target_device) self.weight.load_ok = False return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index ee9d1923c3..4696347927 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -1,15 +1,15 @@ import torch from typing import Optional, Dict + from .base_weight import BaseWeightTpl -from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size +from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward -from .platform_op import PlatformAwareOp -class RMSNormWeight(BaseWeightTpl, PlatformAwareOp): +class RMSNormWeight(BaseWeightTpl): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype): super().__init__(tp_rank=0, tp_world_size=1) self.dim = dim @@ -69,7 +69,13 @@ def _musa_forward( def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + return self.platform_backend.ops.rms_norm( + input=input, + weight=self.weight, + eps=eps, + out=out, + alloc_func=alloc_func, + ) class GatedRMSNormWeight(RMSNormWeight): @@ -118,10 +124,17 @@ def __call__( out: Optional[torch.Tensor] = None, alloc_func=torch.empty, ) -> torch.Tensor: - return 
self._forward(input=input, gate_value=gate_value, eps=eps, out=out, alloc_func=alloc_func) + return self.platform_backend.ops.rms_norm( + input=input, + weight=self.weight, + eps=eps, + out=out, + alloc_func=alloc_func, + gate_value=gate_value, + ) -class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): +class LayerNormWeight(BaseWeightTpl): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): super().__init__(tp_rank=0, tp_world_size=1) self.dim = dim @@ -185,7 +198,14 @@ def _musa_forward( def __call__( self, input: torch.Tensor, eps: float, out: Optional[torch.Tensor] = None, alloc_func=torch.empty ) -> torch.Tensor: - return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) + return self.platform_backend.ops.layer_norm( + input=input, + weight=self.weight, + bias=self.bias, + eps=eps, + out=out, + alloc_func=alloc_func, + ) class TpRMSNormWeight(RMSNormWeight): @@ -246,7 +266,7 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): self.weight.load_ok = True -class QKRMSNORMWeight(BaseWeightTpl, PlatformAwareOp): +class QKRMSNORMWeight(BaseWeightTpl): def __init__(self, dim: int, q_weight_name: str, k_weight_name: str, data_type: torch.dtype): super().__init__(tp_rank=0, tp_world_size=1) self.dim = dim @@ -325,8 +345,14 @@ def __call__( q: torch.Tensor, k: torch.Tensor, eps: float, - ) -> None: - return self._forward(q=q, k=k, eps=eps) + ) -> tuple: + return self.platform_backend.ops.qk_rms_norm( + q=q, + k=k, + w_q=self.q_weight, + w_k=self.k_weight, + eps=eps, + ) class QKGEMMANormWeight(QKRMSNORMWeight): @@ -347,3 +373,13 @@ def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple # See https://github.com/huggingface/transformers/pull/29402 # So we need to set fp32_multiply to True here. 
return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps, fp32_multiply=True) + + def __call__(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple: + return self.platform_backend.ops.qk_rms_norm( + q=q, + k=k, + w_q=self.q_weight, + w_k=self.k_weight, + eps=eps, + fp32_multiply=True, + ) diff --git a/lightllm/common/basemodel/prefill_cuda_graph.py b/lightllm/common/basemodel/prefill_cuda_graph.py index 3c53a1b81c..fb7d4f1abb 100644 --- a/lightllm/common/basemodel/prefill_cuda_graph.py +++ b/lightllm/common/basemodel/prefill_cuda_graph.py @@ -3,29 +3,34 @@ import copy import bisect import triton -from typing import List, Tuple -from typing import Optional +from typing import List, Tuple, Optional +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor from lightllm.distributed import dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.common.basemodel.graph.base.decode_graph import DecodeGraph from .infer_struct import InferStateInfo -from .cuda_graph import CudaGraph logger = init_logger(__name__) class PrefillCudaGraph: - # CudaGraph forward pass for the decoding stage. 
- def __init__(self, decode_cuda_graph: CudaGraph, tp_world_size: int): + def __init__(self, decode_cuda_graph: Optional[DecodeGraph], tp_world_size: int): self.graph = {} self.tp_world_size = tp_world_size + if decode_cuda_graph is not None: + self.platform_backend = decode_cuda_graph.platform_backend + self.target_device = decode_cuda_graph.target_device self.mempool = decode_cuda_graph.mempool # prefill 和 decode 共享一个 mempool else: - self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None + self.platform_backend = get_backend() + self.target_device = self.platform_backend.runtime.target_device() + self.mempool = self.platform_backend.graph.graph_pool_handle() \ + if self.platform_backend.runtime.is_available() else None self.args = get_env_start_args() self.enable_prefill_microbatch_overlap = self.args.enable_prefill_microbatch_overlap @@ -165,14 +170,14 @@ def warmup(self, model): for handle_token_num in self.graph_handle_token_nums[::-1]: logger.info(f"Capture prefill cudagraph, handle_token_num: {handle_token_num}") total_token_num = handle_token_num - input_ids = torch.tensor([1 for _ in range(total_token_num)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor([model.req_manager.HOLD_REQUEST_ID], dtype=torch.int32, device="cuda") - b_seq_len = torch.empty(1, dtype=torch.int32, device="cuda") + input_ids = torch.tensor([1 for _ in range(total_token_num)], dtype=torch.int32, device=self.target_device) + mem_indexes = model.mem_manager.alloc(len(input_ids)).to(device=self.target_device, non_blocking=True) + b_req_idx = torch.tensor([model.req_manager.HOLD_REQUEST_ID], dtype=torch.int32, device=self.target_device) + b_seq_len = torch.empty(1, dtype=torch.int32, device=self.target_device) b_seq_len.fill_(total_token_num) - b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - 
b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(1, dtype=torch.int32, device=self.target_device) + b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device=self.target_device) + b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device=self.target_device) model_input = ModelInput( batch_size=1, @@ -206,7 +211,7 @@ def warmup(self, model): for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() logger.info( f"Capture repfill cudagraph success, token_num <={self.max_handle_token_num} " f"will infer with cudagraph." @@ -225,14 +230,14 @@ def warmup_overlap(self, model): for micro_batch_index in [0, 1]: # dummy prefill, capture the cudagraph total_token_num = handle_token_num - input_ids = torch.tensor([1 for _ in range(total_token_num)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() - b_req_idx = torch.tensor([model.req_manager.HOLD_REQUEST_ID], dtype=torch.int32, device="cuda") - b_seq_len = torch.empty(1, dtype=torch.int32, device="cuda") + input_ids = torch.tensor([1 for _ in range(total_token_num)], dtype=torch.int32, device=self.target_device) + mem_indexes = model.mem_manager.alloc(len(input_ids)).to(device=self.target_device, non_blocking=True) + b_req_idx = torch.tensor([model.req_manager.HOLD_REQUEST_ID], dtype=torch.int32, device=self.target_device) + b_seq_len = torch.empty(1, dtype=torch.int32, device=self.target_device) b_seq_len.fill_(total_token_num) - b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") - b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") - b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device="cuda") + b_mtp_index = torch.zeros(1, dtype=torch.int32, device=self.target_device) + b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device=self.target_device) + 
b_prefill_start_loc = torch.zeros(1, dtype=torch.int32, device=self.target_device) micro_batch = ModelInput( batch_size=1, @@ -260,7 +265,7 @@ def warmup_overlap(self, model): for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() _, _ = model.microbatch_overlap_prefill(prefill_batches[0], prefill_batches[1]) @@ -273,7 +278,7 @@ def warmup_overlap(self, model): for var_name, var_value in list(locals().items()): if isinstance(var_value, torch.Tensor): del locals()[var_name] - torch.cuda.empty_cache() + self.platform_backend.runtime.empty_cache() logger.info( f"Capture overlap cudagraph success, handle_token_num <={self.max_handle_token_num} " diff --git a/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py index ad526c0360..73b408a671 100644 --- a/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/alibi_att/token_flashattention_nopad.py @@ -171,7 +171,7 @@ def token_attention_fwd( batch_size = b_seq_len.shape[0] calcu_shape1 = (batch_size, head_num, k.shape[2]) - att_m_tensor = alloc_tensor_func((head_num, total_token_num), dtype=q.dtype, device="cuda") + att_m_tensor = alloc_tensor_func((head_num, total_token_num), dtype=q.dtype, device=q.device) token_att_fwd( q.view(calcu_shape1), k, att_m_tensor, alibi, req_to_tokens, b_req_idx, b_start_loc, b_seq_len, max_len_in_batch diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index e549298e3b..64e378c0f9 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -25,8 +25,8 @@ def gqa_token_decode_attention_flash_decoding( else: block_num = 32 - mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda") - mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device="cuda") + mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device=q.device) + mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device=q.device) flash_decode_stage1( q=q.view(calcu_shape1), diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py index 141587ff38..324c815d49 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_vsm.py @@ -2,6 +2,7 @@ import triton import triton.language as tl from lightllm.common.kernel_config import KernelConfigs +from lightllm.platform import get_backend from lightllm.utils.device_utils import calcu_kernel_best_vsm_count from frozendict import frozendict from functools import lru_cache @@ -382,7 +383,7 @@ def emstimate_stage1_vsm( q_head_num = q_head_num batch_size = b_req_idx.shape[0] kernel = _kernel_gqa_token_decode_attention_flash_decoding_vsm_stage1.warmup( - torch.empty([1], dtype=torch.int64, device="cuda"), + torch.empty([1], dtype=torch.int64, device=q.device), q, k, v, @@ -423,7 +424,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( sm_scale = 1.0 / (q_head_dim ** 0.5) if not run_config: - if torch.cuda.is_current_stream_capturing(): + if get_backend().graph.is_capturing(): avg_seq_len_in_batch = infer_state.max_kv_seq_len else: 
avg_seq_len_in_batch = infer_state.total_token_num // batch_size @@ -440,6 +441,7 @@ def gqa_token_decode_attention_flash_decoding_vsm( if out is None: out = alloc_tensor_func(q.shape, dtype=q.dtype, device=q.device) + device = q.device num_vsm = emstimate_stage1_vsm( q, k, @@ -450,9 +452,9 @@ def gqa_token_decode_attention_flash_decoding_vsm( torch.empty( [q_head_num, 0, q_head_dim], dtype=torch.float32, - device="cuda", + device=device, ), - torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda"), + torch.empty([q_head_num, 0], dtype=torch.float32, device=device), sm_scale, **run_config, ) @@ -463,14 +465,14 @@ def gqa_token_decode_attention_flash_decoding_vsm( 1, ], dtype=torch.int64, - device="cuda", + device=device, ) mid_o_batch_start_index = torch.empty( [ batch_size, ], dtype=torch.int64, - device="cuda", + device=device, ) _fwd_kernel_calcu_index_and_block_seq[(1,)]( infer_state.b_seq_len, @@ -493,12 +495,12 @@ def gqa_token_decode_attention_flash_decoding_vsm( q_head_dim, ], dtype=torch.float32, - device="cuda", + device=device, ) mid_o_logexpsum = torch.empty( [q_head_num, num_vsm * 4 + batch_size], dtype=torch.float32, - device="cuda", + device=device, ) gqa_token_decode_attention_flash_decoding_vsm_stage1( infer_state.decode_att_block_seq, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py index 9521364ba6..0454bd6728 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int4kv/ppl_int4kv_flash_decoding.py @@ -30,8 +30,8 @@ def token_decode_attention_flash_decoding( else: block_num = 32 - mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda") - mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=q.dtype, 
device="cuda") + mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device=q.device) + mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=q.dtype, device=q.device) from .int4kv_flash_decoding_stage1 import int4kv_flash_decode_stage1 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py index b40251b8f8..d84268b8bb 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/int8kv_flash_decoding_diverse.py @@ -1,6 +1,7 @@ # 为 diverse mode 定制设计的 int8kv flash decoding attention 实现,可以实现更高效的多样性采样 import torch from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.platform import get_backend from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 from .int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 @@ -18,10 +19,11 @@ def token_decode_attention_flash_decoding( alloc_tensor_func=torch.empty, shared_streams_dict={}, ): + platform_backend = get_backend() if "stream1" not in shared_streams_dict: - shared_streams_dict["stream1"] = torch.cuda.Stream() + shared_streams_dict["stream1"] = platform_backend.runtime.create_stream() if "stream2" not in shared_streams_dict: - shared_streams_dict["stream2"] = torch.cuda.Stream() + shared_streams_dict["stream2"] = platform_backend.runtime.create_stream() stream1 = shared_streams_dict["stream1"] stream2 = shared_streams_dict["stream2"] @@ -37,16 +39,16 @@ def token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], 
dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=torch.float32, device=q.device ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device=q.device ) - current_stream = torch.cuda.current_stream() + current_stream = platform_backend.runtime.current_stream() stream1.wait_stream(current_stream) - with torch.cuda.stream(stream1): + with platform_backend.runtime.stream(stream1): flash_decode_stage1( q=q.view(calcu_shape1), k=cache_k, @@ -64,7 +66,7 @@ def token_decode_attention_flash_decoding( max_batch_group_size=get_diverse_max_batch_shared_group_size(), ) stream2.wait_stream(current_stream) - with torch.cuda.stream(stream2): + with platform_backend.runtime.stream(stream2): flash_decode_stage2( q=q.view(calcu_shape1), k=cache_k, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py index b61e8eace2..e5bdf55c24 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py @@ -36,8 +36,8 @@ def token_decode_attention_flash_decoding( else: block_num = 32 - mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda") - mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device="cuda") + mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device=q.device) + mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device=q.device) flash_decode_stage1( q=q.view(calcu_shape1), diff --git 
a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py index 6c50fc3927..0b608381ee 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/mha/flash_decoding/flash_decoding.py @@ -14,10 +14,10 @@ def token_decode_attention_flash_decoding(q, infer_state, cache_k, cache_v, out= o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device=q.device ) mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 1], dtype=torch.float32, device=q.device ) flash_decode_stage1( diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index 5ba6d0beb6..eaeb492412 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -7,6 +7,7 @@ import math import torch.nn.functional as F +from lightllm.platform import get_backend from lightllm.utils.device_utils import is_tesla @@ -459,7 +460,10 @@ def _fwd_kernel_contiguous_kv( def context_attention_fwd_contiguous_kv( q, k, v, o, b_start_loc, b_kv_start_loc, b_seq_len, max_q_input_len, b_prompt_cache_len ): - BLOCK_M = 128 if not is_tesla() else 64 + if is_tesla() or get_backend().name == "maca": + BLOCK_M = 64 + else: + BLOCK_M = 128 # shape constraints 
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv diff --git a/lightllm/common/basemodel/triton_kernel/embedding.py b/lightllm/common/basemodel/triton_kernel/embedding.py index 8c88f9fd23..5b91c8f39b 100644 --- a/lightllm/common/basemodel/triton_kernel/embedding.py +++ b/lightllm/common/basemodel/triton_kernel/embedding.py @@ -1,6 +1,7 @@ import torch import triton import triton.language as tl +from typing import Optional @triton.jit @@ -65,6 +66,24 @@ def embedding(input_ids, weight: torch.Tensor, vob_start_id, vob_end_id, out: to ) +@torch.no_grad() +def npu_embedding( + input_ids: torch.Tensor, + wte_weight: torch.Tensor, + vob_start_id: int, + vob_end_id: int, + out: Optional[torch.Tensor] = None, +) -> None: + token_ids = input_ids - vob_start_id + + mask = (token_ids < 0) | (token_ids >= vob_end_id - vob_start_id) + token_ids = token_ids.masked_fill(mask, 0) + + emb = torch.nn.functional.embedding(token_ids, wte_weight) + res = emb.masked_fill_(mask.unsqueeze(-1), 0) + out.copy_(res) + + @torch.no_grad() def embedding_new(input_ids, weight, vob_start_id, vob_end_id): # out = self.alloc_tensor((N_CTX, DIM), dtype=torch.float32) diff --git a/lightllm/common/basemodel/triton_kernel/fa3_utils.py b/lightllm/common/basemodel/triton_kernel/fa3_utils.py index 0a524b63b6..37a432d913 100644 --- a/lightllm/common/basemodel/triton_kernel/fa3_utils.py +++ b/lightllm/common/basemodel/triton_kernel/fa3_utils.py @@ -1,3 +1,4 @@ +import torch import triton import triton.language as tl @@ -29,14 +30,32 @@ def page_table_copy_kernel( tl.store(page_table_ptr + output_pos, mem_index, mask=mask) +@torch.no_grad() +def paged_page_table_copy( + page_table: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_req_idx: torch.Tensor, + page_size: int, +) -> None: + num_pages = page_table.shape[1] + max_seq_len_k = num_pages * page_size + sampled = req_to_token_indexs[b_req_idx, :max_seq_len_k:page_size] + page_table.copy_(sampled // page_size) + + def 
page_table_copy( - page_table, # destination tensor [batch, seq] + page_table, # destination tensor [batch, seq] or [batch, num_pages] req_to_token_indexs, # source tensor [batch, seq] b_req_idx, # request index to copy from + page_size: int = 1, ): assert page_table.dim() == 2, "page_table should be 2D" assert req_to_token_indexs.dim() == 2, "req_to_token_indexs should be 2D" + if page_size > 1: + paged_page_table_copy(page_table, req_to_token_indexs, b_req_idx, page_size) + return + max_seq_len_k = page_table.shape[1] batch_size = page_table.size(0) BLOCK_SIZE = 128 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 638abbd6ca..495bedc707 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -359,7 +359,7 @@ def moe_align2(token_num_mul_topk_num: int, exports_token_num: torch.Tensor, blo max_num_tokens_padded = token_num_mul_topk_num + exports_token_num.shape[0] * (block_m - 1) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_m) # first is expert, second is m_index, third is token_start_index - mblocks_to_tuple_info = torch.empty((max_num_m_blocks, 3), dtype=torch.int32, device="cuda") + mblocks_to_tuple_info = torch.empty((max_num_m_blocks, 3), dtype=torch.int32, device=exports_token_num.device) expert_num = exports_token_num.shape[0] @@ -790,7 +790,7 @@ def grouped_matmul( if support_tma: # TMA descriptors require a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) + return torch.empty(size, device=token_inputs.device, dtype=torch.int8) triton.set_allocator(alloc_fn) @@ -943,6 +943,7 @@ def fused_experts_impl( hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype ) + device = hidden_states.device for chunk in 
range(triton.cdiv(num_tokens, CHUNK_SIZE)): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] @@ -955,9 +956,9 @@ def fused_experts_impl( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device="cuda") - expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device="cuda") - expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device="cuda") + expert_to_tokens = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.int32, device=device) + expert_to_weights = torch.empty((E, topk_num * tokens_in_chunk), dtype=torch.float32, device=device) + expert_to_token_num = torch.zeros((E,), dtype=torch.int32, device=device) moe_align_fused( expert_to_token_index=expert_to_tokens, expert_to_weight=expert_to_weights, diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..7ec63a1355 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -150,7 +150,7 @@ def fused_experts_impl( num_recv_tokens_per_expert = torch.tensor( num_recv_tokens_per_expert_list, dtype=torch.int32, pin_memory=True, device="cpu" - ).cuda(non_blocking=True) + ).to(device=hidden_states.device, non_blocking=True) expert_start_loc = torch.empty_like(num_recv_tokens_per_expert) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py index fb0323cd4b..7225fe41d1 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py +++ 
b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_topk.py @@ -225,9 +225,10 @@ def triton_grouped_topk( else: dtype = torch.float32 - scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") - out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") + device = hidden_states.device + scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device=device) + out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device=device) + out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device=device) assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 1c01cbd638..f514a7d68e 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -171,9 +171,10 @@ def select_experts( ######################################## warning ################################################## # here is used to match autotune feature, make topk_ids more random if Autotuner.is_autotune_warmup(): - rand_gen = torch.Generator(device="cuda") + device = hidden_states.device + rand_gen = torch.Generator(device=device) rand_gen.manual_seed(router_logits.shape[0]) - router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device="cuda") + router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device=device) _, topk_ids = torch.topk(router_logits, k=top_k, dim=1) return topk_weights, topk_ids diff --git a/lightllm/common/basemodel/triton_kernel/gather_token_id.py b/lightllm/common/basemodel/triton_kernel/gather_token_id.py index f8181d73c0..fc7780b9f6 100644 --- 
a/lightllm/common/basemodel/triton_kernel/gather_token_id.py +++ b/lightllm/common/basemodel/triton_kernel/gather_token_id.py @@ -28,6 +28,8 @@ def _fwd_kernel_scatter( if not HAS_OUT_IS_NONE: cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False) + # Mask must have boolean scalar type on NPU + cur_has_out = cur_has_out.to(tl.int1) if OLD_VERSION_TRITON: cur_has_out = cur_has_out != 0 tl.store( @@ -122,7 +124,7 @@ def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b output: (batch_size,) """ batch_size = b_req_idx.shape[0] - output = torch.empty(batch_size, dtype=req_to_next_token_ids.dtype, device="cuda") + output = torch.empty(batch_size, dtype=req_to_next_token_ids.dtype, device=b_req_idx.device) BLOCK = 256 grid = (triton.cdiv(batch_size, BLOCK),) num_warps = 1 diff --git a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py index 8f9172b552..2c23391135 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_prefill_params.py @@ -45,8 +45,8 @@ def gen_cumsum_pad0_tensor(b_q_seq_len: torch.Tensor, b_kv_seq_len: torch.Tensor assert b_q_seq_len.shape == b_kv_seq_len.shape assert b_q_seq_len.is_contiguous() - b1_cu_q_seq_len = torch.empty((b_q_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") - b1_cu_kv_seq_len = torch.empty((b_kv_seq_len.shape[0] + 1,), dtype=torch.int32, device="cuda") + b1_cu_q_seq_len = torch.empty((b_q_seq_len.shape[0] + 1,), dtype=torch.int32, device=b_q_seq_len.device) + b1_cu_kv_seq_len = torch.empty((b_kv_seq_len.shape[0] + 1,), dtype=torch.int32, device=b_kv_seq_len.device) _gen_cumsum_pad0_kernel[(1,)]( b_q_seq_len, b1_cu_q_seq_len, @@ -85,7 +85,7 @@ def _gen_prefill_position( @torch.no_grad() def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor): batch_size = b_ready_cache_len.shape[0] - 
position_ids = torch.empty((input_token_num,), dtype=torch.int32, device="cuda") + position_ids = torch.empty((input_token_num,), dtype=torch.int32, device=b_ready_cache_len.device) assert b_ready_cache_len.shape[0] == b_seq_len.shape[0] b_q_seq_len = b_seq_len - b_ready_cache_len b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_seq_len) diff --git a/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py b/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py index 6f14c6d5a5..f01fd7b818 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_sampling_params.py @@ -46,11 +46,12 @@ def gen_sampling_params(b_req_idx: torch.Tensor, req_sampling_params_manager): req_sampling_params_manager: ReqSamplingParamsManager = req_sampling_params_manager batch_size = b_req_idx.shape[0] - b_presence_penalty = torch.empty((batch_size,), dtype=torch.float32, device="cuda") - b_frequency_penalty = torch.empty((batch_size,), dtype=torch.float32, device="cuda") - b_repetition_penalty = torch.empty((batch_size,), dtype=torch.float32, device="cuda") - b_temperature = torch.empty((batch_size,), dtype=torch.float32, device="cuda") - b_exponential_decay_length_penalty = torch.empty((batch_size,), dtype=torch.float32, device="cuda") + device = b_req_idx.device + b_presence_penalty = torch.empty((batch_size,), dtype=torch.float32, device=device) + b_frequency_penalty = torch.empty((batch_size,), dtype=torch.float32, device=device) + b_repetition_penalty = torch.empty((batch_size,), dtype=torch.float32, device=device) + b_temperature = torch.empty((batch_size,), dtype=torch.float32, device=device) + b_exponential_decay_length_penalty = torch.empty((batch_size,), dtype=torch.float32, device=device) BLOCK = 256 @@ -123,6 +124,8 @@ def _token_id_counter_update_kernel( next_token_ids_ptr, mask_ptr, batch_size, + vocab_size, + num_req_rows, HAS_MASK: tl.constexpr, BLOCK: tl.constexpr, 
OLD_VERSION_TRITON: tl.constexpr, @@ -137,19 +140,19 @@ def _token_id_counter_update_kernel( if HAS_MASK: mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False) + # tt.atomic_rmw must be 1-bit inputs on NPU + add_mask = (loc_mask & mask) != 0 if OLD_VERSION_TRITON: mask = mask != 0 - tl.atomic_add( - req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, - 1, - mask=loc_mask & mask, - ) + update_mask = add_mask else: - tl.atomic_add( - req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, - 1, - mask=loc_mask, - ) + update_mask = loc_mask + + tl.atomic_add( + req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n, + 1, + mask=update_mask, + ) return @@ -161,6 +164,13 @@ def update_req_to_token_id_counter( mask: torch.Tensor = None, ): batch_size = b_req_idx.shape[0] + vocab_size = req_to_out_token_id_counter.shape[1] + num_req_rows = req_to_out_token_id_counter.shape[0] + if vocab_size <= 0 or num_req_rows <= 0: + raise RuntimeError( + f"invalid req_to_out_token_id_counter shape {tuple(req_to_out_token_id_counter.shape)}; " + "check get_vocab_size(model_dir) for this checkpoint (gemma3 config.json may omit vocab_size)" + ) BLOCK = 256 has_mask = mask is not None _token_id_counter_update_kernel[(triton.cdiv(batch_size, BLOCK),)]( @@ -171,6 +181,8 @@ def update_req_to_token_id_counter( next_token_ids_ptr=next_token_ids, mask_ptr=mask, batch_size=batch_size, + vocab_size=vocab_size, + num_req_rows=num_req_rows, HAS_MASK=has_mask, BLOCK=BLOCK, OLD_VERSION_TRITON=triton.__version__ < "3.2.0", diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py index 0fdc43ab9f..4c0b0775ed 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -551,7 +551,12 @@ def load_cpu_kv_to_gpu( move_token_num = 
gpu_mem_indexes.shape[0] cpu_page_indexes = page_indexes.view((cpu_page_num, 1)).tile((1, token_block_size)).view(-1) - cpu_mem_indexes = torch.arange(0, cpu_page_all_token_num, device="cuda", dtype=torch.int32) % token_block_size + cpu_mem_indexes = torch.arange( + 0, + cpu_page_all_token_num, + device=gpu_mem_indexes.device, + dtype=torch.int32, + ) % token_block_size cpu_page_indexes = cpu_page_indexes[-move_token_num:] cpu_mem_indexes = cpu_mem_indexes[-move_token_num:] diff --git a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py index 063181d995..582d8d3afe 100644 --- a/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/mla_att/decode_att/gqa_flash_decoding.py @@ -4,6 +4,7 @@ import triton import triton.language as tl from typing import List +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig from lightllm.utils.device_utils import get_device_sm_count @@ -25,7 +26,7 @@ def gqa_token_decode_attention_flash_decoding( calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: - if torch.cuda.is_current_stream_capturing(): + if get_backend().graph.is_capturing(): avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size @@ -46,9 +47,10 @@ def gqa_token_decode_attention_flash_decoding( o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out - fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda") - mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda") - mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda") + device = q_nope.device + fake_decode_att_block_seq = torch.empty([0], 
dtype=torch.int64, device=device) + mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device=device) + mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device=device) vsm_count = flash_decode_stage1( fake_decode_att_block_seq, @@ -72,14 +74,14 @@ def gqa_token_decode_attention_flash_decoding( 1, ], dtype=torch.int64, - device="cuda", + device=device, ) mid_o_batch_start_index = torch.empty( [ batch_size, ], dtype=torch.int64, - device="cuda", + device=device, ) _fwd_kernel_calcu_index_and_block_seq[(1,)]( infer_state.b_seq_len, @@ -95,8 +97,8 @@ def gqa_token_decode_attention_flash_decoding( infer_state.decode_att_block_seq = decode_att_block_seq infer_state.mid_o_batch_start_index = mid_o_batch_start_index - mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda") - mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda") + mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device=device) + mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device=device) flash_decode_stage1( infer_state.decode_att_block_seq, diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py index e2d4aea587..60109dc104 100644 --- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py +++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py @@ -116,6 +116,51 @@ def multimodal_emb( return +@torch.no_grad() +def npu_multimodal_emb( + out: torch.Tensor, + prompt_ids: torch.Tensor, + text_weight_embs: torch.Tensor, + embed_cache: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs_in_cache: torch.Tensor, + tp_text_start_token_id: int, + tp_text_end_token_id: int, + tp_world_size: int, +): + # text mask + text_mask = (prompt_ids >= 
tp_text_start_token_id) & (prompt_ids < tp_text_end_token_id) + if text_mask.any(): + text_ids = prompt_ids[text_mask] - tp_text_start_token_id + out[text_mask] = torch.nn.functional.embedding(text_ids, text_weight_embs) + # image mask + image_mask = torch.zeros_like(text_mask, dtype=torch.bool) + image_index = torch.zeros_like(prompt_ids, dtype=torch.long) + + for i in range(img_token_lens.shape[0]): + start_token = img_start_token_ids[i] + start_loc = img_start_locs_in_cache[i] + token_len = img_token_lens[i] + + mask = (prompt_ids >= start_token) & (prompt_ids < start_token + token_len) + image_mask |= mask + + rel = prompt_ids[mask] - start_token + image_index[mask] = start_loc + rel + + if image_mask.any(): + target_indices = image_index[image_mask].cpu() + if embed_cache.dim() == 3: + selected = embed_cache[target_indices, 0, :] + else: + selected = embed_cache[target_indices] + selected_npu = selected.to(out.device, dtype=out.dtype, non_blocking=True) + out[image_mask] = selected_npu / tp_world_size + + return out + + @triton.jit def _mark_multimodal_obj_need_kernel( obj_start_token_ids_ptr, diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py index 3881cfe4b8..26fa49bc0f 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_quant_kernel.py @@ -1,6 +1,7 @@ import torch import triton import triton.language as tl +from lightllm.utils.device_utils import get_target_device from lightllm.utils.dist_utils import get_current_device_id @@ -44,7 +45,7 @@ def mm_weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tenso def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]: assert x.is_contiguous(), "Input tensor must be contiguous" - x = x.cuda(get_current_device_id()) + x = 
x.to(device=get_target_device()) if x.dim() == 3: y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device) num_blocks_m = triton.cdiv(x.shape[1], block_size) diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_perchannel_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_perchannel_quant_kernel.py index ba11acc9e7..629cb3893d 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_perchannel_quant_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_perchannel_quant_kernel.py @@ -1,6 +1,7 @@ import torch import triton import triton.language as tl +from lightllm.utils.device_utils import get_target_device from lightllm.utils.dist_utils import get_current_device_id @@ -37,7 +38,7 @@ def mm_weight_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: def weight_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert x.is_contiguous(), "Input tensor must be contiguous" - x = x.cuda(get_current_device_id()) + x = x.to(device=get_target_device()) if x.dim() == 3: y_quant = torch.empty((x.shape[0], x.shape[1], x.shape[2]), dtype=torch.float8_e4m3fn, device=x.device) s_scales = torch.empty((x.shape[0], x.shape[1], 1), dtype=torch.float32, device=x.device) diff --git a/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_group_quant_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_group_quant_kernel.py index 0e49fa21f5..42522c81e6 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_group_quant_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_group_quant_kernel.py @@ -220,7 +220,7 @@ def scaled_mm_act_per_group_w_perchannel( if support_tma: # TMA descriptors require a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, 
device="cuda", dtype=torch.int8) + return torch.empty(size, device=A.device, dtype=torch.int8) triton.set_allocator(alloc_fn) diff --git a/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py index cb3a3d7316..d031715c8c 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/scaled_mm_per_token_kernel.py @@ -219,7 +219,7 @@ def scaled_mm_per_token( if support_tma: # TMA descriptors require a global memory allocation def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) + return torch.empty(size, device=A.device, dtype=torch.int8) triton.set_allocator(alloc_fn) diff --git a/lightllm/common/cpu_cache/creator.py b/lightllm/common/cpu_cache/creator.py index bdd6fdd772..baa5188f4b 100644 --- a/lightllm/common/cpu_cache/creator.py +++ b/lightllm/common/cpu_cache/creator.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Optional, Tuple from lightllm.utils.kv_cache_utils import attach_shm_kv_cache_ptr, create_shm_kv_cache_ptr, register_shm_ptr_to_pin +from lightllm.platform import get_backend @dataclass(frozen=True) @@ -33,12 +34,16 @@ def create_or_attach( if not pin_no_blocking: attach_handle.wait() - # 等待 device_ptr 被赋值 - while attach_handle.device_ptr is None: - time.sleep(0.01) + if get_backend().name == "ascend": + cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr) + assert shm_ptr == cpu_cache_tensor.data_ptr() + else: + # 等待 device_ptr 被赋值 + while attach_handle.device_ptr is None: + time.sleep(0.01) - cpu_cache_tensor = self._build_tensor_view(shm_ptr=attach_handle.device_ptr) - assert attach_handle.device_ptr == cpu_cache_tensor.data_ptr() + cpu_cache_tensor = self._build_tensor_view(shm_ptr=attach_handle.device_ptr) + assert attach_handle.device_ptr == 
cpu_cache_tensor.data_ptr() return cpu_cache_tensor, attach_handle else: cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index d49c8d7e73..d89bbeaa57 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -30,19 +30,19 @@ def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device=self.target_device) def alloc_kv_move_buffer(self, max_req_total_len): self.kv_move_buffer = torch.empty( - (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + (1, max_req_total_len + 8, self.head_num, self.head_dim), dtype=self.dtype, device=self.target_device ) - self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") + self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device=self.target_device) self.token_dim_size = self.kv_move_buffer.shape[-1] * self.kv_move_buffer.shape[-2] return def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: self.kv_move_buffer = torch.empty( - (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + (page_num, page_size, self.layer_num, self.head_num, self.head_dim), dtype=self.dtype, device=self.target_device ) self._buffer_mem_indexes_tensors = [ torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num) @@ -60,7 +60,7 @@ def write_mem_to_page_kv_move_buffer( cur_page = 
self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes - mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + mem_indexes_gpu = pin_mem_indexes.to(device=self.target_device, non_blocking=True) dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] mla_page_io( mem_indexes=mem_indexes_gpu, @@ -81,7 +81,7 @@ def read_page_kv_move_buffer_to_mem( cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes - mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + mem_indexes_gpu = pin_mem_indexes.to(device=self.target_device, non_blocking=True) dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] for mem in dp_mems: mla_page_io( @@ -173,7 +173,7 @@ def send_to_decode_node_p2p( mems_ptr = [] for i in range(0, len(mem_managers), len(mem_managers) // dp_size_in_node): mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") + mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device=self.target_device) self.mem_ptrs_dict[layer_index] = mems_ptr move_token_indexes = [] @@ -183,8 +183,8 @@ def send_to_decode_node_p2p( move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)]) - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device=self.target_device) + token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device=self.target_device) for layer_index in range(self.layer_num): move_buffer = 
self._get_kv_move_data_p2p( move_token_indexes, token_dp_indexes, layer_index, self.kv_move_buffer, dp_size_in_node @@ -226,7 +226,7 @@ def receive_from_prefill_node_p2p( mems_ptr = [] for i in range(0, len(mem_managers)): mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr()) - mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda") + mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device=self.target_device) self.mem_ptrs_dict[layer_index] = mems_ptr move_token_indexes = [] @@ -236,8 +236,8 @@ def receive_from_prefill_node_p2p( move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)]) - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") - token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device=self.target_device) + token_dp_indexes = torch.tensor(token_dp_indexes, dtype=torch.int32, device=self.target_device) token_num = len(move_token_indexes) move_size = self.token_dim_size * token_num diff --git a/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py index 7980ca2dd7..5d065c69b0 100755 --- a/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_static_per_head_quant_mem_manager.py @@ -32,7 +32,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False ) cfg = self._load_and_check_config() all_head_num = cfg["num_head"] - all_scales = torch.tensor(cfg["scales"], dtype=torch.float32, device="cuda").view(cfg["scales_shape"]) + all_scales = torch.tensor( + cfg["scales"], dtype=torch.float32, device=self.target_device).view(cfg["scales_shape"]) factor = (get_dp_world_size() * head_num) // 
all_head_num assert (get_dp_world_size() * head_num) % all_head_num == 0 @@ -46,7 +47,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False v_scales = all_scales[:, v_offset + start_head : v_offset + end_head].contiguous() self.scales = torch.cat((k_scales, v_scales), dim=-1) else: - self.scales = torch.ones((self.kv_buffer.shape[0], 2 * head_num), dtype=torch.float32, device="cuda") + self.scales = torch.ones( + (self.kv_buffer.shape[0], 2 * head_num), dtype=torch.float32, device=self.target_device) return def _load_and_check_config(self): diff --git a/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py b/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py index 06338808bb..1a5938cfc2 100755 --- a/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/fp8_static_per_tensor_quant_mem_manager.py @@ -30,9 +30,11 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False "will load kv quant calibration config" ) cfg = self._load_and_check_config() - self.scales = torch.tensor(cfg["scales"], dtype=torch.float32, device="cuda").view(cfg["scales_shape"]) + self.scales = torch.tensor( + cfg["scales"], dtype=torch.float32, device=self.target_device).view(cfg["scales_shape"]) else: - self.scales = torch.ones((self.kv_buffer.shape[0], 2), dtype=torch.float32, device="cuda") + self.scales = torch.ones( + (self.kv_buffer.shape[0], 2), dtype=torch.float32, device=self.target_device) self.cpu_scales = self.scales.detach().cpu().numpy() return diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0cc84db322..d54d5f931b 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -10,12 +10,12 @@ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, 
get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_page_size, get_unique_server_name, get_env_start_args from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.config_utils import get_num_key_value_heads from lightllm.common.kv_trans_kernel.nixl_kv_trans import page_io -from lightllm.utils.device_utils import kv_trans_use_p2p +from lightllm.utils.device_utils import get_target_device, kv_trans_use_p2p from lightllm.utils.shm_utils import create_or_link_shm from multiprocessing.reduction import ForkingPickler from filelock import FileLock @@ -35,8 +35,12 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.layer_num = layer_num self.always_copy = always_copy self.dtype = dtype + self.target_device = get_target_device() # profile the max total token num if the size is None self.profile_size(mem_fraction) + page_size = get_page_size() + if page_size > 1: + self.size = (self.size // page_size) * page_size self.mem_state = torch.arange( 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True @@ -89,7 +93,7 @@ def profile_size(self, mem_fraction): cell_size = self.get_cell_size() self.size = int(available_memory * 1024 ** 3 / cell_size) if world_size > 1: - tensor = torch.tensor(self.size, dtype=torch.int64, device=f"cuda:{get_current_device_id()}") + tensor = torch.tensor(self.size, dtype=torch.int64, device=self.target_device) dist.all_reduce(tensor, op=dist.ReduceOp.MIN) self.size = tensor.item() logger.info( @@ -104,7 +108,9 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 分配,内部实际也没有管理,这个token是预留来对一些特殊的运行模式,如多dp下,overlap microbatch # 等模式下 padding 
一些请求,使推理过程可以正常运行采用的,其索引值为size,存储在HOLD_TOKEN_MEMINDEX # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + self.kv_buffer = torch.empty((layer_num, alloc_size, 2 * head_num, head_dim), dtype=dtype, device=self.target_device) def alloc_kv_move_buffer(self, max_req_total_len): """ @@ -113,9 +119,9 @@ def alloc_kv_move_buffer(self, max_req_total_len): if isinstance(self, MemoryManager) and type(self) is not MemoryManager: raise NotImplementedError("subclass need reimpl this method") self.kv_move_buffer = torch.empty( - (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda" + (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device=self.target_device ) - self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda") + self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device=self.target_device) self.token_dim_size = self.kv_move_buffer.shape[-2] * self.kv_move_buffer.shape[-1] return @@ -125,7 +131,7 @@ def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir) self.kv_move_buffer = torch.empty( - (page_num, page_size, self.layer_num, 2 * num_kv_head, self.head_dim), dtype=self.dtype, device="cuda" + (page_num, page_size, self.layer_num, 2 * num_kv_head, self.head_dim), dtype=self.dtype, device=self.target_device ) self._buffer_mem_indexes_tensors = [ torch.empty((page_size,), dtype=torch.int64, device="cpu", pin_memory=True) for _ in range(page_num) @@ -143,7 +149,7 @@ def write_mem_to_page_kv_move_buffer( cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : 
len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes - mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + mem_indexes_gpu = pin_mem_indexes.to(device=self.target_device, non_blocking=True) repeat_count = dp_world_size * self.kv_buffer.shape[2] // self.kv_move_buffer.shape[3] dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] for tp_index in range(dp_world_size): @@ -172,10 +178,10 @@ def read_page_kv_move_buffer_to_mem( cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes - mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True) + mem_indexes_gpu = pin_mem_indexes.to(device=self.target_device, non_blocking=True) dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] - mem_indexes_gpu = torch.tensor(mem_indexes, dtype=torch.int64, device="cpu", pin_memory=True).cuda( - non_blocking=True + mem_indexes_gpu = torch.tensor(mem_indexes, dtype=torch.int64, device="cpu", pin_memory=True).to( + device=self.target_device, non_blocking=True ) for tp_index in range(dp_world_size): page_io( @@ -286,7 +292,7 @@ def send_to_decode_node_p2p( if task.move_kv_len != 0: move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :]) - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device=self.target_device) for i, mem in enumerate(mem_managers): for layer_index in range(mem.layer_num): move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer) @@ -318,7 +324,7 @@ def receive_from_prefill_node_p2p( if task.move_kv_len != 0: move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :]) - move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda") + move_token_indexes = 
torch.tensor(move_token_indexes, dtype=torch.int64, device=self.target_device) token_num = len(move_token_indexes) move_size = self.token_dim_size * token_num diff --git a/lightllm/common/kv_cache_mem_manager/mem_utils.py b/lightllm/common/kv_cache_mem_manager/mem_utils.py index 79ea448794..80ded0187b 100644 --- a/lightllm/common/kv_cache_mem_manager/mem_utils.py +++ b/lightllm/common/kv_cache_mem_manager/mem_utils.py @@ -10,6 +10,8 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.llm_utils import get_llm_model_class +from lightllm.common.kv_cache_mem_manager.npu_mem_manager import NPUMemoryManager +from lightllm.platform import get_backend from functools import lru_cache logger = init_logger(__name__) @@ -51,7 +53,10 @@ def select_mem_manager_class(): elif get_env_start_args().llm_kv_type == "fp8kv_spt": memory_manager_class = FP8StaticPerTensorQuantMemManager elif get_env_start_args().llm_kv_type == "None": - memory_manager_class = MemoryManager + if get_backend().name == "ascend": + memory_manager_class = NPUMemoryManager + else: + memory_manager_class = MemoryManager logger.info(f"Model kv cache using mem_manager class: {memory_manager_class}") return memory_manager_class diff --git a/lightllm/common/kv_cache_mem_manager/npu_mem_manager.py b/lightllm/common/kv_cache_mem_manager/npu_mem_manager.py new file mode 100644 index 0000000000..1417378356 --- /dev/null +++ b/lightllm/common/kv_cache_mem_manager/npu_mem_manager.py @@ -0,0 +1,130 @@ + +from lightllm.common.kv_cache_mem_manager.operator.base import BaseMemManagerOperator +import torch +from typing import Any, List, Tuple + +from lightllm.server.pd_io_struct import KVMoveTask +from lightllm.utils.envs_utils import get_page_size +from lightllm.utils.log_utils import init_logger + +from .mem_manager import MemoryManager + + +logger = init_logger(__name__) + + +class NPUOperator(BaseMemManagerOperator): + + def 
copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor): + kb, vb = self.mem_manager.k_buffer[layer_index], self.mem_manager.v_buffer[layer_index] + k_src, v_src = kv[:, : self.mem_manager.head_num, :], kv[:, self.mem_manager.head_num :, :] + assert kv.shape[0] == mem_index.shape[0], (kv.shape, mem_index.shape) + assert k_src.shape[1] == kb.shape[1] and k_src.shape[2] == kb.shape[2], (k_src.shape, kb.shape) + assert v_src.shape[1] == vb.shape[1] and v_src.shape[2] == vb.shape[2], (v_src.shape, vb.shape) + self.mem_manager.k_buffer[layer_index].index_copy_(0, mem_index, k_src) + self.mem_manager.v_buffer[layer_index].index_copy_(0, mem_index, v_src) + + +class NPUMemoryManager(MemoryManager): + operator_class = NPUOperator + + def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]: + return self.k_buffer[layer_index], self.v_buffer[layer_index] + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + page_size = get_page_size() + alloc_size = ((size // page_size) + 1) * page_size if page_size > 1 else size + 1 + logger.info(f"Total page blocks allocated: {alloc_size // page_size} for page_size: {page_size}") + self.k_buffer = torch.empty((layer_num, alloc_size, head_num, head_dim), dtype=dtype, device=self.target_device) + self.v_buffer = torch.empty((layer_num, alloc_size, head_num, head_dim), dtype=dtype, device=self.target_device) + self.kv_buffer = self.k_buffer + + def _free_buffers(self): + self.k_buffer = None + self.v_buffer = None + self.kv_buffer = None + + def get_index_kv_buffer(self, index): + return { + "kv_buffer": torch.cat([self.k_buffer[:, index], self.v_buffer[:, index]], dim=1), + } + + def load_index_kv_buffer(self, index, load_tensor_dict): + t = load_tensor_dict["kv_buffer"] + self.k_buffer[:, index].copy_(t[:, : self.head_num]) + self.v_buffer[:, index].copy_(t[:, self.head_num :]) + + def alloc_kv_move_buffer(self, max_req_total_len): + raise 
NotImplementedError("NPUMemoryManager does not support PD-separated alloc_kv_move_buffer") + + def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: + raise NotImplementedError("NPUMemoryManager does not support PD-separated alloc_paged_kv_move_buffer") + + def write_mem_to_page_kv_move_buffer( + self, + mem_indexes: List[int], + page_index: int, + dp_index: int, + mem_managers: List["MemoryManager"], + dp_world_size: int, + ): + raise NotImplementedError("NPUMemoryManager does not support PD-separated write_mem_to_page_kv_move_buffer") + + def read_page_kv_move_buffer_to_mem( + self, + mem_indexes: List[int], + page_index: int, + dp_index: int, + mem_managers: List["MemoryManager"], + dp_world_size: int, + ): + raise NotImplementedError("NPUMemoryManager does not support PD-separated read_page_kv_move_buffer_to_mem") + + def send_to_decode_node( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm, + ): + raise NotImplementedError("NPUMemoryManager does not support PD-separated send_to_decode_node") + + def receive_from_prefill_node( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm, + ): + raise NotImplementedError("NPUMemoryManager does not support PD-separated receive_from_prefill_node") + + def send_to_decode_node_p2p( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm, + ): + raise NotImplementedError("NPUMemoryManager does not support PD-separated send_to_decode_node_p2p") + + def receive_from_prefill_node_p2p( + self, + move_tasks: List[KVMoveTask], + mem_managers: List["MemoryManager"], + dp_size_in_node: int, + nccl_comm, + ): + raise NotImplementedError("NPUMemoryManager does not support PD-separated receive_from_prefill_node_p2p") + + def copy_kv_from_other_dp_ranks( + self, + mem_managers: List["MemoryManager"], + move_token_indexes: 
torch.Tensor, + token_dp_indexes: torch.Tensor, + mem_indexes: torch.Tensor, + dp_size_in_node: int, + rank_in_dp: int, + ): + raise NotImplementedError( + "NPUMemoryManager does not support copy_kv_from_other_dp_ranks (needs split-kv kernel)" + ) diff --git a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py index 109e813220..c4a56af6ad 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/linear_att.py +++ b/lightllm/common/kv_cache_mem_manager/operator/linear_att.py @@ -31,7 +31,8 @@ def load_cpu_cache_to_gpu( cpu_cache_client: "CpuKvCacheClient", req: "InferReq", ): - assert mem_indexes.is_cuda and page_indexes.is_cuda + assert mem_indexes.device == self.mem_manager.target_device + assert page_indexes.device == self.mem_manager.target_device args = get_env_start_args() assert triton.cdiv(len(mem_indexes), args.cpu_cache_token_page_size) == len(page_indexes) assert len(mem_indexes) % args.linear_att_hash_page_size == 0 @@ -66,7 +67,8 @@ def load_cpu_cache_to_gpu( # 将对应的小叶数据拷贝到临时的大页上,再从大页上拷贝到对应的运行态页面上 big_page_buffer_ids_cpu.append(mem_manager.CPU_CACHE_BIG_PAGE_LOAD_TEMP_BUFFER_ID) - big_page_buffer_ids_gpu = torch.tensor(big_page_buffer_ids_cpu, dtype=torch.int64, device="cpu").cuda( + big_page_buffer_ids_gpu = torch.tensor(big_page_buffer_ids_cpu, dtype=torch.int64, device="cpu").to( + device=self.mem_manager.target_device, non_blocking=True ) @@ -108,12 +110,14 @@ def offload_gpu_kv_to_cpu_cache( req: "InferReq", ): if not hasattr(self, "big_page_ids_buffer_store"): - self.big_page_ids_buffer_store = torch.empty((1024 * 1024 * 4,), dtype=torch.int64, device="cuda") + self.big_page_ids_buffer_store = torch.empty((1024 * 1024 * 4,), dtype=torch.int64, device=self.mem_manager.target_device) self.mem_indexes_buffer = torch.empty( - (get_env_start_args().max_req_total_len + 1024,), dtype=torch.int32, device="cuda" + (get_env_start_args().max_req_total_len + 1024,), dtype=torch.int32, 
device=self.mem_manager.target_device ) - assert mem_indexes.is_cuda and page_indexes.is_cuda and page_readies.is_cuda + assert mem_indexes.device == self.mem_manager.target_device + assert page_indexes.device == self.mem_manager.target_device + assert page_readies.device == self.mem_manager.target_device args = get_env_start_args() assert len(mem_indexes) % args.linear_att_hash_page_size == 0 assert triton.cdiv(len(mem_indexes), args.cpu_cache_token_page_size) == len(page_indexes) @@ -202,6 +206,8 @@ def copy_mem_to_mem(self, src_mem_index: torch.Tensor, dst_mem_index: torch.Tens from lightllm.common.basemodel.triton_kernel.kv_move import copy_kv_buffer_to_kv_buffer copy_kv_buffer_to_kv_buffer( - src_mem_index.cuda(non_blocking=True), dst_mem_index.cuda(non_blocking=True), self.mem_manager.kv_buffer + src_mem_index.to(device=self.mem_manager.target_device, non_blocking=True), + dst_mem_index.to(device=self.mem_manager.target_device, non_blocking=True), + self.mem_manager.kv_buffer, ) return diff --git a/lightllm/common/kv_cache_mem_manager/operator/normal.py b/lightllm/common/kv_cache_mem_manager/operator/normal.py index 3c53ace079..8800a539c9 100644 --- a/lightllm/common/kv_cache_mem_manager/operator/normal.py +++ b/lightllm/common/kv_cache_mem_manager/operator/normal.py @@ -104,7 +104,7 @@ def copy_kv_from_other_dp_ranks( # 一次性传输所有层 kv_trans_for_dp( - input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True), + input_mems=self.mem_ptrs_tensor.to(device=self.mem_manager.target_device, non_blocking=True), input_idx=move_token_indexes, input_dp_idx=token_dp_indexes, output=self.mem_manager.kv_buffer, diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py index 584877a1b3..bbad113143 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py @@ -29,10 +29,10 @@ def get_cell_size(self): def 
_init_buffers(self, size, dtype, head_num, head_dim, layer_num): self.kv_buffer = torch.empty( - (layer_num, size + 1, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda" + (layer_num, size + 1, 2 * head_num, head_dim // 2), dtype=torch.int8, device=self.target_device ) self.scale_buffer = torch.empty( - (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda" + (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device=self.target_device ) def _free_buffers(self): diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py index 994c676e9f..17bbe14fa2 100755 --- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py @@ -27,9 +27,9 @@ def get_cell_size(self): ) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") + self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device=self.target_device) self.scale_buffer = torch.empty( - (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda" + (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device=self.target_device ) def _free_buffers(self): diff --git a/lightllm/common/quantization/awq.py b/lightllm/common/quantization/awq.py index f3c7623975..d3e2ff0e67 100644 --- a/lightllm/common/quantization/awq.py +++ b/lightllm/common/quantization/awq.py @@ -3,7 +3,7 @@ from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack from lightllm.common.quantization.registry import QUANTMETHODS -from lightllm.utils.dist_utils import get_current_device_id +from lightllm.platform import get_backend from lightllm.utils.log_utils 
import init_logger logger = init_logger(__name__) @@ -114,11 +114,15 @@ def _create_weight( out_dim = sum(out_dims) group_size = self.hf_quantization_config["group_size"] expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + weight = torch.empty( + expert_prefix + (in_dim, out_dim // self.pack_factor), dtype=torch.int32 + ).to(device=self.target_device, non_blocking=True) + weight_scale = torch.empty( + expert_prefix + (in_dim // group_size, out_dim), dtype=dtype + ).to(device=self.target_device, non_blocking=True) weight_zero_point = torch.empty( expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 - ).cuda(device_id) + ).to(device=self.target_device, non_blocking=True) weight_out_dims = [_out_dim // self.pack_factor for _out_dim in out_dims] weight_scale_out_dims = out_dims weight_zero_point_out_dims = weight_out_dims @@ -144,9 +148,9 @@ def __init__(self): self.weight_scale_suffix = "scales" self.weight_zero_point_suffix = "qzeros" self.weight_suffix = "qweight" - self.g_idx = marlin_make_empty_g_idx(torch.device("cuda")) - self.g_idx_sort_indices = marlin_make_empty_g_idx(torch.device("cuda")) - self.workspace = marlin_make_workspace_new(torch.device("cuda")) + self.g_idx = marlin_make_empty_g_idx(self.target_device) + self.g_idx_sort_indices = marlin_make_empty_g_idx(self.target_device) + self.workspace = marlin_make_workspace_new(self.target_device) self.vllm_quant_type = TYPE_MAP[self.nbits] self.has_weight_scale = True self.has_weight_zero_point = True @@ -215,11 +219,13 @@ def _create_weight( expert_prefix = (num_experts,) if num_experts > 1 else () weight = torch.empty( expert_prefix + (in_dim // self.tile_size, out_dim * self.tile_size // self.pack_factor), dtype=torch.int32 - ).cuda(device_id) - 
weight_scale = torch.empty(expert_prefix + (in_dim // group_size, out_dim), dtype=dtype).cuda(device_id) + ).to(device=self.target_device, non_blocking=True) + weight_scale = torch.empty( + expert_prefix + (in_dim // group_size, out_dim), dtype=dtype + ).to(device=self.target_device, non_blocking=True) weight_zero_point = torch.empty( expert_prefix + (in_dim // group_size, out_dim // self.pack_factor), dtype=torch.int32 - ).cuda(device_id) + ).to(device=self.target_device, non_blocking=True) weight_out_dims = [_out_dim * self.tile_size // self.pack_factor for _out_dim in out_dims] weight_scale_out_dims = out_dims weight_zero_point_out_dims = [_out_dim // self.pack_factor for _out_dim in out_dims] @@ -239,9 +245,8 @@ def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: assert self.hf_quantization_config is not None, "hf_quantization_config is not set" if weight is None: return - device_id = get_current_device_id() repack_weight = vllm_ops.awq_marlin_repack( - weight.cuda(device_id), + weight.to(device=self.target_device, non_blocking=True), size_k=weight.shape[0], size_n=weight.shape[1] * self.pack_factor, num_bits=self.hf_quantization_config["bits"], @@ -255,9 +260,8 @@ def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack) if weight_scale is None: return group_size = self.hf_quantization_config["group_size"] - device_id = get_current_device_id() repack_weight_scale = marlin_permute_scales( - weight_scale.cuda(device_id), + weight_scale.to(device=self.target_device, non_blocking=True), size_k=weight_scale.shape[0] * group_size, size_n=weight_scale.shape[1], group_size=self.hf_quantization_config["group_size"], @@ -269,9 +273,8 @@ def load_weight_scale(self, weight_scale: torch.Tensor, weight_pack: WeightPack) def load_weight_zero_point(self, weight_zero_point: torch.Tensor, weight_pack: WeightPack) -> None: if weight_zero_point is None: return - device_id = get_current_device_id() repack_weight_zero_point = 
awq_to_marlin_zero_points( - weight_zero_point.cuda(device_id), + weight_zero_point.to(device=self.target_device, non_blocking=True), size_k=weight_zero_point.shape[0], size_n=weight_zero_point.shape[1] * self.pack_factor, num_bits=self.hf_quantization_config["bits"], @@ -290,7 +293,7 @@ def is_awq_marlin_compatible(quantization_config: dict[str, Any]): group_size = quantization_config.get("group_size") zero_point = quantization_config.get("zero_point") - if not torch.cuda.is_available(): + if not get_backend().runtime.is_available(): return False if quant_method != "awq": diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..5dff432aa3 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -61,8 +61,7 @@ def method_name(self): def quantize(self, weight: torch.Tensor, output: WeightPack): from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_quant_kernel import weight_quant - device = output.weight.device - weight, scale = weight_quant(weight.cuda(device), self.block_size) + weight, scale = weight_quant(weight.to(device=output.weight.device), self.block_size) output.weight.copy_(weight) output.weight_scale.copy_(scale) return @@ -111,10 +110,12 @@ def _create_weight( weight_scale_out_dim = sum(weight_scale_out_dims) weight_scale_in_dim = (in_dim + self.block_size - 1) // self.block_size expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight = torch.empty( + expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn + ).to(device=self.target_device, non_blocking=True) weight_scale = torch.empty( expert_prefix + (weight_scale_out_dim, weight_scale_in_dim), dtype=torch.float32 - ).cuda(device_id) + ).to(device=self.target_device, non_blocking=True) mm_param = WeightPack(weight=weight, weight_scale=weight_scale) mm_param_list = 
self._split_weight_pack( mm_param, diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index fa926ad6f0..7de5c4d2bf 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -39,7 +39,9 @@ def _create_weight( ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=dtype).cuda(device_id) + weight = torch.empty( + expert_prefix + (out_dim, in_dim), dtype=dtype + ).to(device=self.target_device, non_blocking=True) mm_param = WeightPack(weight=weight, weight_scale=None, weight_zero_point=None) # weight layout is (out_dim, in_dim), so the split dimension is -2. mm_param_list = self._split_weight_pack( diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 95d8d806f9..84412a6298 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,6 +1,7 @@ import torch from abc import ABC, abstractmethod from dataclasses import dataclass +from lightllm.utils.device_utils import get_target_device from lightllm.utils.dist_utils import get_current_device_id from typing import Optional, List, Tuple @@ -37,6 +38,7 @@ def __init__(self): # 一些量化模式需要用到的额外量化参数,如awq量化 self.hf_quantization_config = None + self.target_device = get_target_device(self.device_id_) @abstractmethod def quantize( @@ -63,9 +65,14 @@ def apply( def method_name(self): pass + def _ensure_target_device(self, device_id: int) -> None: + if device_id != self.device_id_: + raise ValueError(f"Device {device_id} is not the target device {self.device_id_}") + def create_weight( self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int ) -> Tuple[WeightPack, List[WeightPack]]: + self._ensure_target_device(device_id) return 
self._create_weight( out_dims=out_dims, in_dim=in_dim, @@ -76,6 +83,7 @@ def create_weight( def create_moe_weight( self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int ) -> Tuple[WeightPack, List[WeightPack]]: + self._ensure_target_device(device_id) return self._create_weight( out_dims=out_dims, in_dim=in_dim, diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 65ec6cd145..5f61d5e109 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -62,7 +62,7 @@ def __init__(self): self.has_weight_zero_point = False def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: - weight = weight.float().cuda(self.device_id_) + weight = weight.float().to(device=self.target_device) scale = weight.abs().max(dim=-1)[0] / 127 weight = weight / scale.reshape(-1, 1) weight = torch.round(weight.clamp(min=-127, max=127)).to(dtype=torch.int8) @@ -103,8 +103,8 @@ def _create_weight( ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.int8).to(device=self.target_device) + weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).to(device=self.target_device) mm_param = WeightPack(weight=weight, weight_scale=weight_scale) mm_param_list = self._split_weight_pack( mm_param, @@ -126,7 +126,7 @@ def __init__(self): def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: qweight, weight_scale = scaled_fp8_quant( - weight.cuda(self.device_id_), scale=None, use_per_token_if_dynamic=True + weight.to(device=self.target_device), scale=None, use_per_token_if_dynamic=True ) 
output.weight.copy_(qweight) output.weight_scale.copy_(weight_scale.view(-1)) @@ -167,8 +167,12 @@ def _create_weight( ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + weight = torch.empty( + expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn + ).to(device=self.target_device) + weight_scale = torch.empty( + expert_prefix + (out_dim,), dtype=torch.float32 + ).to(device=self.target_device) mm_param = WeightPack(weight=weight, weight_scale=weight_scale) mm_param_list = self._split_weight_pack( @@ -193,8 +197,7 @@ def __init__(self): def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_block_quant_kernel import weight_quant - device = output.weight.device - weight, scale = weight_quant(weight.cuda(device), self.block_size) + weight, scale = weight_quant(weight.to(device=output.weight.device), self.block_size) output.weight.copy_(weight) output.weight_scale.copy_(scale) return @@ -245,10 +248,11 @@ def _create_weight( ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) + weight = torch.empty( + expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).to(device=self.target_device) weight_scale = torch.empty( expert_prefix + (out_dim // self.block_size, in_dim // self.block_size), dtype=torch.float32 - ).cuda(device_id) + ).to(device=self.target_device) mm_param = WeightPack(weight=weight, weight_scale=weight_scale) weight_scale_out_dims = [_out_dim // 
self.block_size for _out_dim in out_dims] mm_param_list = self._split_weight_pack( diff --git a/lightllm/common/quantization/w8a8gx.py b/lightllm/common/quantization/w8a8gx.py index c25136697d..cd5f40dd60 100644 --- a/lightllm/common/quantization/w8a8gx.py +++ b/lightllm/common/quantization/w8a8gx.py @@ -49,7 +49,7 @@ def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: # per channel quantization for weight, per token group quantization for input activation from lightllm.common.basemodel.triton_kernel.quantization.fp8w8a8_perchannel_quant_kernel import weight_quant - qweight, weight_scale = weight_quant(weight.cuda(self.device_id_)) + qweight, weight_scale = weight_quant(weight.to(device=self.target_device)) output.weight.copy_(qweight) output.weight_scale.copy_(weight_scale.view(-1)) return @@ -123,8 +123,12 @@ def _create_weight( ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () - weight = torch.empty(expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn).cuda(device_id) - weight_scale = torch.empty(expert_prefix + (out_dim,), dtype=torch.float32).cuda(device_id) + weight = torch.empty( + expert_prefix + (out_dim, in_dim), dtype=torch.float8_e4m3fn + ).to(device=self.target_device) + weight_scale = torch.empty( + expert_prefix + (out_dim,), dtype=torch.float32 + ).to(device=self.target_device) mm_param = WeightPack(weight=weight, weight_scale=weight_scale) mm_param_list = self._split_weight_pack( diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..96a8d58d6c 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,13 +1,17 @@ import torch +import triton import collections + +from triton.backends import Backend from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.platform import get_backend from 
lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional, TYPE_CHECKING from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache @@ -62,29 +66,51 @@ def is_all_free(self): class ReqManager: def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryManager): + platform_backend = get_backend() + device = platform_backend.runtime.target_device() # 这里对最大请求数量的管理在默认上多申请了一个,主要是 index 为 max_request_num 代表 # 的这个请求管理 id, 主要是为了兼容 DP 运行模式下,让各个 DP 能 padding 到 DP 中最大 # 的那个batch size 进行运行,所有 padding 的请求都会使用预留的这个请求管理 id 进行处理 # 这样让 DP 的实现更为简化一些。 self.req_list = _ReqLinkedList(max_request_num) self.req_to_token_indexs = torch.zeros( - (max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda" + (max_request_num + 1, max_sequence_length), dtype=torch.int32, device=device ) self.mem_manager = mem_manager - self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) + self.req_sampling_params_manager = ReqSamplingParamsManager( + max_request_num, device=device, platform_backend=platform_backend) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num def alloc(self): return self.req_list.alloc() + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + + def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, 
b_ready_cache_len=None): + b_token_len = b_seq_len + if b_ready_cache_len is not None: + b_token_len = b_seq_len - b_ready_cache_len + b_token_len_cumsum = torch.cumsum(b_token_len, dim=0) + b_last_mem_index = mem_indices[b_token_len_cumsum - 1] + return b_last_mem_index + + def alloc_mem_indices( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1 and b_seq_len is not None: + return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index) + return self.mem_manager.alloc(need_size) + def free(self, free_req_indexes: List[int], free_token_index): for req_index in free_req_indexes: self.req_list.free(req_index) if self.req_list.is_all_free(): logger.debug(f"freed all request size {self.req_list.can_alloc_size}") - self.mem_manager.free(free_token_index) + self.mem_manager.free(self._expand_to_page_mem_indices(free_token_index)) def free_req(self, free_req_index: int): self.req_list.free(free_req_index) @@ -93,13 +119,73 @@ def free_req(self, free_req_index: int): return def free_token(self, free_token_index): - self.mem_manager.free(free_token_index) + self.mem_manager.free(self._expand_to_page_mem_indices(free_token_index)) return def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def _expand_to_page_mem_indices(self, free_token_index): + page_size = get_page_size() + if page_size > 1: + if isinstance(free_token_index, list): + free_token_index = torch.tensor(free_token_index, dtype=torch.int32) + base_indices = free_token_index[free_token_index % page_size == 0] + if len(base_indices) == 0: + return free_token_index + page_offsets = torch.arange(page_size, dtype=base_indices.dtype, device=base_indices.device) + return (base_indices[:, None] + page_offsets[None, :]).reshape(-1) + + return free_token_index + + def _expand_by_page_size(self, b_token_len, page_size): + b_page_len = triton.cdiv(b_token_len, 
page_size) + need_pages_num = int(b_page_len.sum().item()) + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, p_token_len + + def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index): + b_seq_len = b_seq_len.cpu() + if b_ready_cache_len is not None: + b_ready_cache_len = b_ready_cache_len.cpu() + b_token_len = b_seq_len - b_ready_cache_len + total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) + paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) + pages = paged_token_idxs.view(-1, page_size) + mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) + return pages[mask] + + assert b_last_mem_index is not None + b_last_mem_index = b_last_mem_index.cpu() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = int(need_new_page_mask.sum().item()) + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size) + token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] + mask = ~need_new_page_mask + if mask.any(): + token_idxs[mask] = b_last_mem_index[mask] + 1 + return token_idxs + + def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): + page_size = get_page_size() + if page_size == 1: + return 0 + + if b_ready_cache_len is not None: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = triton.cdiv(need_tokens_array, page_size) + need_new_pages = need_pages_array.sum() + else: + need_new_pages = ((b_seq_len - 1) % page_size == 0).sum() + return need_new_pages * page_size + class ReqSamplingParamsManager: """ @@ -109,25 
+195,33 @@ class ReqSamplingParamsManager: lightllm/server/router/model_infer/mode_backend/generic_post_process.py 文件中的使用方式。 """ - def __init__(self, max_request_num): + def __init__(self, max_request_num, device: torch.device, platform_backend: Backend): + self.target_device = device + self.platform_backend = platform_backend # mode ["cpu_counter", "pin_mem_counter", "gpu_counter"] self.penalty_counter_mode = get_env_start_args().penalty_counter_mode - self.vocab_size = get_vocab_size(get_env_start_args().model_dir) - self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") - self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") - self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") + model_dir = get_env_start_args().model_dir + self.vocab_size = get_vocab_size(model_dir) + if self.vocab_size <= 0: + raise RuntimeError( + f"invalid vocab_size={self.vocab_size} from model_dir={model_dir}; " + "cannot allocate penalty token counter (check config.json, AutoConfig, or tokenizer)" + ) + self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device=self.target_device) + self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device=self.target_device) + self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device=self.target_device) self.req_to_next_token_ids = torch.zeros( (max_request_num + 1, 8), dtype=torch.int64, - device="cuda", + device=self.target_device, ) self.req_to_exponential_decay_length_penalty = torch.zeros( - max_request_num + 1, dtype=torch.float32, device="cuda" + max_request_num + 1, dtype=torch.float32, device=self.target_device ) if self.penalty_counter_mode == "gpu_counter": self.req_to_out_token_id_counter = torch.zeros( - (max_request_num + 1, self.vocab_size), dtype=torch.int32, device="cuda" + (max_request_num + 1, 
self.vocab_size), dtype=torch.int32, device=self.target_device ) elif self.penalty_counter_mode == "pin_mem_counter": self.req_to_out_token_id_counter = torch.zeros( @@ -163,11 +257,11 @@ def init_req_sampling_params(self, req: "InferReq"): key="prompt_ids_for_penalty", data=req.shm_req.get_prompt_ids_numpy(), dtype=torch.int32, - ).cuda(non_blocking=True) + ).to(device=self.target_device, non_blocking=True) token_id_counter( prompt_ids=prompt_ids, out_token_id_counter=self.req_to_out_token_id_counter[req.req_idx] ) - torch.cuda.current_stream().synchronize() + self.platform_backend.runtime.current_stream().synchronize() return @@ -177,7 +271,10 @@ def update_reqs_out_token_counter_gpu( if self.penalty_counter_mode not in ["gpu_counter", "pin_mem_counter"]: return - assert b_req_idx.is_cuda and next_token_ids.is_cuda and b_req_idx.shape[0] == next_token_ids.shape[0] + assert b_req_idx.device == next_token_ids.device, \ + f"b_req_idx.device ({b_req_idx.device}) != next_token_ids.device ({next_token_ids.device})" + assert b_req_idx.shape[0] == next_token_ids.shape[0], \ + f"b_req_idx.shape[0] ({b_req_idx.shape[0]}) != next_token_ids.shape[0] ({next_token_ids.shape[0]})" update_req_to_token_id_counter( b_req_idx=b_req_idx, @@ -223,9 +320,9 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List["InferReq"]): ) return ( - p_token_ids_tensor.cuda(non_blocking=True), - p_token_counts_tensor.cuda(non_blocking=True), - p_cumsum_seq_len_tensor.cuda(non_blocking=True), + p_token_ids_tensor.to(device=self.target_device, non_blocking=True), + p_token_counts_tensor.to(device=self.target_device, non_blocking=True), + p_cumsum_seq_len_tensor.to(device=self.target_device, non_blocking=True), ) @@ -246,14 +343,14 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager, linear_con dtype=self.linear_config.conv_state_dtype, shape=self.linear_config.get_conv_state_shape(), layer_num=self.linear_config.linear_layer_num, - device="cuda", + 
device=self.target_device, ) self.req_to_ssm_state = LayerCache( size=(max_request_num + 1) * (self.mtp_step + 1), dtype=self.linear_config.ssm_state_dtype, shape=self.linear_config.get_ssm_state_shape(), layer_num=self.linear_config.linear_layer_num, - device="cuda", + device=self.target_device, ) return diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index c62a2572ff..c335ad05f2 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -9,6 +9,7 @@ from pathlib import Path from tqdm import tqdm from frozendict import frozendict +from lightllm.platform import get_backend from lightllm.utils.device_utils import get_current_device_name from lightllm.utils.log_utils import init_logger from typing import Callable, Optional, Union, List @@ -103,17 +104,9 @@ def __init__( run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)), mutates_args: List[str] = [], ): - self.configs_gen_func = configs_gen_func self.kernel_name = kernel_name - self.cache_dir = os.path.join( - Path(__file__).parent, - "autotune_kernel_configs", - get_triton_version(), - get_current_device_name(), - self.kernel_name, - ) - os.makedirs(self.cache_dir, exist_ok=True) + self._cache_dir = None self.fn = fn self.static_key_func = static_key_func self.run_key_func = run_key_func @@ -130,6 +123,7 @@ def __init__( ] self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()] self.mutates_args = mutates_args + self._platform_backend = None assert get_triton_autotune_level() in [ AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG, @@ -139,6 +133,25 @@ def __init__( ] return + @property + def platform_backend(self): + if self._platform_backend is None: + self._platform_backend = get_backend() + return self._platform_backend + + @property + def cache_dir(self) -> str: + if self._cache_dir is None: + self._cache_dir = os.path.join( + 
Path(__file__).parent, + "autotune_kernel_configs", + get_triton_version(), + get_current_device_name(), + self.kernel_name, + ) + os.makedirs(self._cache_dir, exist_ok=True) + return self._cache_dir + @torch.no_grad() def __call__(self, *args, **kwargs): if kwargs.get("run_config", None) is not None: @@ -253,19 +266,19 @@ def kernel_call(): # warmup kernel_call() - torch.cuda.current_stream().synchronize() - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, stream=torch.cuda.Stream()): + self.platform_backend.runtime.current_stream().synchronize() + g = self.platform_backend.graph.create_graph() + with self.platform_backend.graph.graph(g, stream=self.platform_backend.runtime.create_stream()): for _ in range(n_repeat): kernel_call() - torch.cuda.current_stream().synchronize() + self.platform_backend.runtime.current_stream().synchronize() state = _BenchmarkState() for i in range(n_retries): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + start_event = self.platform_backend.runtime.create_event(enable_timing=True) + end_event = self.platform_backend.runtime.create_event(enable_timing=True) start_event.record() - g.replay() + self.platform_backend.graph.replay_graph(g) end_event.record() end_event.synchronize() state.update(start_event.elapsed_time(end_event) / n_repeat) diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index f01f1c87f7..46ef5eecb8 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -23,6 +23,7 @@ import torch.distributed as dist from torch.distributed import ReduceOp, ProcessGroup from typing import List, Dict, Optional, Union +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import has_nvlink from lightllm.utils.envs_utils import ( @@ -56,13 +57,15 @@ def __init__(self): self.symm_mem_reduce = None 
self.flashinfer_reduce = None self.dp_world_size = get_dp_world_size() - self.device_group = create_new_group_for_current_dp("nccl") + dist_backend = get_backend().runtime.dist_backend + self.device_group = create_new_group_for_current_dp(dist_backend) if get_env_start_args().enable_dp_prefill_balance: - self.dp_prefill_balance_group = create_dp_special_inter_group("nccl") + self.dp_prefill_balance_group = create_dp_special_inter_group(dist_backend) else: self.dp_prefill_balance_group = None self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo") + self.backend_runtime = get_backend().runtime def _support_custom_allreduce(self) -> bool: return has_nvlink() and self.dp_world_size in [2, 4, 6, 8] @@ -73,7 +76,7 @@ def init_symm_mem_reduce(self) -> None: from .symm_mem_all_reduce import SymmMemAllreduce data_type = get_torch_dtype(get_env_start_args().data_type) - symm = SymmMemAllreduce(self.device_group, torch.cuda.current_device(), dtype=data_type) + symm = SymmMemAllreduce(self.device_group, self.backend_runtime.current_device(), dtype=data_type) if not symm.disabled: self.symm_mem_reduce = symm logger.info("Enable SymmMem ALLReduce.") @@ -84,7 +87,7 @@ def init_flashinfer_reduce(self) -> None: from .flashinfer_all_reduce import FlashInferAllReduce fi_cpu_group = create_new_group_for_current_dp("gloo") - fi = FlashInferAllReduce(fi_cpu_group, torch.cuda.current_device()) + fi = FlashInferAllReduce(fi_cpu_group, self.backend_runtime.current_device()) if not fi.disabled: self.flashinfer_reduce = fi logger.info("Enable FlashInfer ALLReduce.") diff --git a/lightllm/distributed/flashinfer_all_reduce.py b/lightllm/distributed/flashinfer_all_reduce.py index 27856d9ac7..457de04fd4 100644 --- a/lightllm/distributed/flashinfer_all_reduce.py +++ b/lightllm/distributed/flashinfer_all_reduce.py @@ -6,6 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup +from lightllm.platform import get_backend from 
lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -54,7 +55,7 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]) - self._ws_dtype = None self._ws_max_token_num = 0 - if not _FI_OK or not torch.cuda.is_available(): + if not _FI_OK or not get_backend().runtime.is_available(): return if isinstance(device, int): device = torch.device(f"cuda:{device}") diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index 1363f719a6..3bf9037a23 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -66,7 +66,7 @@ def _parse_config(self): assert self.n_head % self.tp_world_size_ == 0 tp_head_num = self.n_head // self.tp_world_size_ tmp_alibi = generate_alibi(self.n_head, dtype=torch.float32) - self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num : (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda() + self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num : (self.tp_rank_ + 1) * tp_head_num].contiguous().to(device=self.target_device) def _init_weight_names(self): self._q_weight_name = f"h.{self.layer_num_}.self_attention.q_proj.weight" diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index fa2dee444f..8265475da7 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -5,7 +5,7 @@ from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.deepseek2.triton_kernel.rotary_emb import 
rotary_emb_fwd as ds2_rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from functools import partial from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale @@ -161,11 +161,14 @@ def _get_qkv( layer_weight.kv_a_layernorm_( cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) - rotary_emb_fwd( - q_rope, - cache_kv[:, :, self.kv_lora_rank :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q_rope, + k=cache_kv[:, :, self.kv_lora_rank :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, + rotary_impl=ds2_rotary_emb_fwd, ) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) @@ -190,11 +193,14 @@ def _get_qkv( layer_weight.kv_a_layernorm_( cache_kv[:, :, : self.kv_lora_rank], eps=self.eps_, out=cache_kv[:, :, : self.kv_lora_rank] ) - rotary_emb_fwd( - q_rope, - cache_kv[:, :, self.kv_lora_rank :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q_rope, + k=cache_kv[:, :, self.kv_lora_rank :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, + rotary_impl=ds2_rotary_emb_fwd, ) return q, cache_kv diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 3eb09f9176..09fa5ae18c 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -102,8 +102,8 @@ def load_hf_weights(self, weights): # for deepseek_v3, the bmm operator is not quantized if self.quant_cfg.quantized_weight: kv_b_proj_ = weight_dequant( - kv_b_proj_.cuda(), - 
weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), + kv_b_proj_.to(device=self.target_device), + weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].to(device=self.target_device), ).cpu() k_b_proj_, v_b_proj_ = self._split_kv_b_proj(kv_b_proj_) weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = k_b_proj_ diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index e596eed97c..d0acca13e7 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -84,13 +84,13 @@ def _init_to_get_yarn_rotary(self): beta_fast = rope_scaling.get("beta_fast", 32.0) beta_slow = rope_scaling.get("beta_slow", 1.0) - pos_freqs = base ** (torch.arange(0, dim, 2).float().cuda() / dim) + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device=self.target_device) / dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scale * pos_freqs) low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) inv_freq_mask = ( - 1 - linear_ramp_mask(low, high, dim // 2).float().cuda() + 1 - linear_ramp_mask(low, high, dim // 2).float().to(device=self.target_device) ) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask @@ -100,10 +100,10 @@ def _init_to_get_yarn_rotary(self): # Build here to make `torch.jit.trace` work. 
max_seq_len_cached = max_position_embeddings - t = torch.arange(max_seq_len_cached, device="cuda", dtype=torch.float32) + t = torch.arange(max_seq_len_cached, device=self.target_device, dtype=torch.float32) freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation - self._cos_cached = (freqs.cos() * _mscale).to(self.data_type).cuda() - self._sin_cached = (freqs.sin() * _mscale).to(self.data_type).cuda() + self._cos_cached = (freqs.cos() * _mscale).to(self.data_type).to(device=self.target_device) + self._sin_cached = (freqs.sin() * _mscale).to(self.data_type).to(device=self.target_device) return diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py index b9be73e278..6e4b641192 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_fp8.py @@ -4,6 +4,7 @@ import triton import triton.language as tl from typing import List +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import get_device_sm_count @@ -32,7 +33,7 @@ def gqa_token_decode_attention_flash_decoding_fp8( calcu_shape2 = (batch_size, q_head_num, q_rope_dim) if not run_config: - if torch.cuda.is_current_stream_capturing(): + if get_backend().graph.is_capturing(): avg_seq_len_in_batch = max_kv_seq_len else: avg_seq_len_in_batch = infer_state.total_token_num // batch_size @@ -55,9 +56,10 @@ def gqa_token_decode_attention_flash_decoding_fp8( o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out - fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda") - mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda") - mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, 
device="cuda") + device = q_nope.device + fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device=device) + mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device=device) + mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device=device) vsm_count = flash_decode_stage1_fp8( fake_decode_att_block_seq, @@ -83,14 +85,14 @@ def gqa_token_decode_attention_flash_decoding_fp8( 1, ], dtype=torch.int64, - device="cuda", + device=device, ) mid_o_batch_start_index = torch.empty( [ batch_size, ], dtype=torch.int64, - device="cuda", + device=device, ) _fwd_kernel_calcu_index_and_block_seq[(1,)]( infer_state.b_seq_len, @@ -105,8 +107,8 @@ def gqa_token_decode_attention_flash_decoding_fp8( infer_state.decode_att_block_seq = decode_att_block_seq infer_state.mid_o_batch_start_index = mid_o_batch_start_index - mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda") - mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda") + mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device=device) + mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device=device) flash_decode_stage1_fp8( infer_state.decode_att_block_seq, diff --git a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py index 899531448b..66357b5aa3 100644 --- a/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek3_2/layer_infer/transformer_layer_infer.py @@ -4,13 +4,14 @@ from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek3_2.layer_weights.transformer_layer_weight import Deepseek3_2TransformerLayerWeight from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import 
rmsnorm_forward -from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd as ds2_rotary_emb_fwd from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.models.deepseek3_2.triton_kernel.act_quant import act_quant from lightllm.models.deepseek3_2.triton_kernel.destindex_copy_indexer_ks import destindex_copy_indexer_ks from lightllm.models.deepseek3_2.triton_kernel.extract_indexer_ks import extract_indexer_ks from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import all_gather_into_tensor +from lightllm.platform import get_backend class Deepseek3_2TransformerLayerInfer(Deepseek2TransformerLayerInfer): @@ -55,11 +56,14 @@ def _get_qkv( out=cache_kv[:, :, : self.kv_lora_rank], ) - rotary_emb_fwd( - q_rope, - cache_kv[:, :, self.kv_lora_rank :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q_rope, + k=cache_kv[:, :, self.kv_lora_rank :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, + rotary_impl=ds2_rotary_emb_fwd, ) return q, cache_kv @@ -270,13 +274,14 @@ def _get_q_k_bf16( k = layer_weight.k_norm_(k, eps=self.eps) # 为什么 indexer 和主模型用的q k 的 rotary的排布方式不一样,这不是脱裤子放屁麻。 - from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd - - rotary_emb_fwd( - q[:, :, : self.qk_rope_head_dim], - k[:, None, : self.qk_rope_head_dim], - infer_state.position_cos, - infer_state.position_sin, + get_backend().ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q[:, :, : self.qk_rope_head_dim], + k=k[:, None, : self.qk_rope_head_dim], + cos=infer_state.position_cos, + sin=infer_state.position_sin, + rotary_impl=ds2_rotary_emb_fwd, ) q = self._rotate_activation(q) diff --git a/lightllm/models/gemma3/gemma3_visual.py 
b/lightllm/models/gemma3/gemma3_visual.py index b2f7a6b779..2f06aae0cb 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -7,6 +7,7 @@ from typing import List, Union from safetensors import safe_open from io import BytesIO +from lightllm.models.visual_utils import VisualDeviceMixin from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.utils.log_utils import init_logger @@ -15,9 +16,13 @@ logger = init_logger(__name__) -class Gemma3VisionModel: - def __init__(self): - pass +class Gemma3VisionModel(VisualDeviceMixin): + + def _device_module_attrs(self): + return ("vision_tower", "avg_pool",) + + def _device_tensor_dict_attrs(self): + return ("projector_weights",) def load_model(self, weight_dir): config_file = os.path.join(weight_dir, "config.json") @@ -29,13 +34,13 @@ def load_model(self, weight_dir): else: assert False, "only hf format model is supported for Gemma3" + self.mm_tokens_per_image = int(config["mm_tokens_per_image"]) self.patches_per_image = int(config["vision_config"]["image_size"] // config["vision_config"]["patch_size"]) - self.tokens_per_side = int(config["mm_tokens_per_image"] ** 0.5) + self.tokens_per_side = int(self.mm_tokens_per_image**0.5) self.kernel_size = self.patches_per_image // self.tokens_per_side self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) self.vision_tower.requires_grad_(False) - self.device = torch.device("cpu") assert "model.mm_projector.linear" in self.projector_weights assert "model.mm_projector.norm" in self.projector_weights @@ -52,8 +57,15 @@ def load_hf_model(self, config, weight_dir): torch_dtype=torch.float16, ) self.vision_tower = model.vision_tower - model.multi_modal_projector = None - model.language_model = None + # Free projector/LLM memory. 
New transformers uses read-only properties on the + # wrapper; fall back to inner model.model (Gemma3Model) when setattr fails. + try: + model.multi_modal_projector = None + model.language_model = None + except AttributeError: + inner_model = getattr(model, "model", model) + inner_model.multi_modal_projector = None + inner_model.language_model = None # load projector weights self.projector_weights = {} @@ -70,12 +82,6 @@ def load_hf_model(self, config, weight_dir): k.replace("multi_modal_projector.mm_soft_emb_norm.weight", "model.mm_projector.norm") ] = d.get_tensor(k).to(torch.bfloat16) - def cuda(self): - self.vision_tower = self.vision_tower.cuda() - for k, v in self.projector_weights.items(): - self.projector_weights[k] = v.cuda() - return self - def gemma3_rms_norm(self, input, weight, eps: float = 1e-6): def _norm(x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) @@ -88,7 +94,7 @@ def _norm(x): # batch images infer def forward(self, x): - x = x.to(torch.bfloat16).cuda() + x = self.move_to_infer_device(x.to(torch.bfloat16)) x = self.vision_tower(x, output_hidden_states=True).last_hidden_state batch_size, _, seq_length = x.shape @@ -129,7 +135,7 @@ def encode(self, images: List[ImageItem]): else: raise Exception("Unsupport input types: {} for {}".format(type(img), img)) - cur_num = img_tensors[-1].shape[0] + cur_num = img_tensors[-1].shape[0] * self.mm_tokens_per_image valid_ids.append([valid_id, valid_id + cur_num]) valid_id += cur_num @@ -138,5 +144,6 @@ def encode(self, images: List[ImageItem]): img = torch.cat(img_tensors, dim=0) all_img_embeds = self.forward(img) + all_img_embeds = all_img_embeds.reshape(-1, all_img_embeds.shape[-1]) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py index ff8d899e8e..53827d2cb5 100644 --- a/lightllm/models/gemma3/layer_infer/pre_layer_infer.py +++ 
b/lightllm/models/gemma3/layer_infer/pre_layer_infer.py @@ -7,7 +7,10 @@ class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) - self.embed_scale = torch.tensor(network_config["hidden_size"] ** 0.5, dtype=torch.float32) + self.embed_scale = torch.tensor( + network_config["hidden_size"] ** 0.5, + device=self.target_device, + dtype=torch.float32) self.boi_token_index: int = 255_999 self.eoi_token_index: int = 256_000 return @@ -56,17 +59,17 @@ def context_forward(self, input_ids, infer_state, layer_weight): f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}" ) # each tp will fill the img embeds, should divide by world_size - img_start_token_ids = torch.tensor(img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True).cuda( - non_blocking=True - ) - img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, device="cpu", pin_memory=True).cuda( - non_blocking=True - ) + img_start_token_ids = torch.tensor( + img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True + ).to(device=self.target_device, non_blocking=True) + img_token_lens = torch.tensor( + img_token_lens, dtype=torch.long, device="cpu", pin_memory=True + ).to(device=self.target_device, non_blocking=True) img_start_locs_in_cache = torch.tensor( img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True - ).cuda(non_blocking=True) + ).to(device=self.target_device, non_blocking=True) - multimodal_emb( + self.platform_backend.ops.multimodal_emb( out=out, prompt_ids=input_ids, text_weight_embs=layer_weight.wte_weight_.weight, @@ -86,4 +89,4 @@ def context_forward(self, input_ids, infer_state, layer_weight): def token_forward(self, input_ids, infer_state, layer_weight): input_embedding = super().token_forward(input_ids, infer_state, layer_weight) input_dtype = input_embedding.dtype - return (input_embedding.float() * self.embed_scale.to(input_embedding.device).float()).to(input_dtype) + 
return (input_embedding.float() * self.embed_scale).to(input_dtype) diff --git a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py index 86f00cfbb4..9971939e29 100644 --- a/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma3/layer_infer/transformer_layer_infer.py @@ -4,7 +4,6 @@ from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): @@ -12,10 +11,8 @@ class Gemma3TransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) - self.tp_k_head_num_ = network_config["num_key_value_heads"] - self.tp_v_head_num_ = network_config["num_key_value_heads"] self.eps_ = 1e-6 - self.head_dim_ = 256 + self.head_dim_ = network_config.get("head_dim", self.head_dim_) self.sliding_window_pattern = 6 return @@ -44,18 +41,22 @@ def _get_qkv( is_sliding = bool((self.layer_num_ + 1) % self.sliding_window_pattern) if is_sliding: - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos_local.to(q.dtype), - infer_state.position_sin_local.to(q.dtype), + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, 0 : self.tp_k_head_num_, :], + cos=infer_state.position_cos_local.to(q.dtype), + sin=infer_state.position_sin_local.to(q.dtype), ) else: - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - 
infer_state.position_cos_global.to(q.dtype), - infer_state.position_sin_global.to(q.dtype), + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, 0 : self.tp_k_head_num_, :], + cos=infer_state.position_cos_global.to(q.dtype), + sin=infer_state.position_sin_global.to(q.dtype), ) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) diff --git a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py index 7ae0fbcca3..8e7bf27425 100644 --- a/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma3/layer_weights/pre_and_post_layer_weight.py @@ -1,5 +1,5 @@ from lightllm.common.basemodel import PreAndPostLayerWeight -from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, NoTpGEMMANormWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, NoTpGEMMANormWeight class Gemma3PreAndPostLayerWeight(PreAndPostLayerWeight): @@ -14,12 +14,17 @@ def __init__(self, data_type, network_config): weight_name="language_model.model.embed_tokens.weight", data_type=self.data_type_, ) - self.lm_head_weight_ = self.wte_weight_ + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="language_model.model.embed_tokens.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_, + ) self.final_norm_weight_ = NoTpGEMMANormWeight( dim=hidden_size, weight_name="language_model.model.norm.weight", data_type=self.data_type_, - bias_name=None, ) return diff --git a/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma3/layer_weights/transformer_layer_weight.py index a4340a17a7..2077014156 100644 --- 
def _init_weight_names(self):
    """Point every per-layer weight name at the Gemma3 checkpoint prefix.

    The parent class fills in names that start with ``model.layers.<n>.``.
    This method swaps that leading prefix for
    ``<_HF_LAYER_PREFIX><n>.`` on the inherited names and then sets the four
    Gemma3-specific norm names (q/k norm, pre/post feed-forward layernorm).
    """
    super()._init_weight_names()
    old = f"model.layers.{self.layer_num_}."
    new = f"{self._HF_LAYER_PREFIX}{self.layer_num_}."
    inherited_attrs = (
        "_q_weight_name",
        "_k_weight_name",
        "_v_weight_name",
        "_kv_weight_name",
        "_o_weight_name",
        "_gate_weight_name",
        "_up_weight_name",
        "_gate_up_weight_name",
        "_down_weight_name",
        "_att_norm_weight_name",
        "_ffn_norm_weight_name",
    )
    for attr in inherited_attrs:
        # A missing attribute still raises AttributeError so typos in the
        # list above are not hidden.
        name = getattr(self, attr)
        # Only rewrite a *leading* prefix: an unanchored str.replace would
        # also match inside an already-prefixed name (``new`` contains
        # ``old``) and produce "language_model.language_model....".
        # Names left as None by the parent are kept as-is.
        if name is not None and name.startswith(old):
            setattr(self, attr, new + name[len(old):])
    self._k_norm_weight_name = f"{new}self_attn.k_norm.weight"
    self._q_norm_weight_name = f"{new}self_attn.q_norm.weight"
    self._pre_feedforward_layernorm_name = f"{new}pre_feedforward_layernorm.weight"
    self._post_feedforward_layernorm_name = f"{new}post_feedforward_layernorm.weight"
dtype=torch.int64).float().cuda() / partial_head_dim) + 10000.0 ** (torch.arange(0, partial_head_dim, 2, dtype=torch.int64).float().to(device=self.target_device) / partial_head_dim) ) inv_freq_global = ( 1.0 - / (1000000.0 ** (torch.arange(0, partial_head_dim, 2, dtype=torch.int64).float().cuda() / partial_head_dim)) + / (1000000.0 ** (torch.arange(0, partial_head_dim, 2, dtype=torch.int64).float().to(device=self.target_device) / partial_head_dim)) / rope_scaling_factor ) # local default @@ -125,14 +125,14 @@ def _init_to_get_rotary(self, default_base=10000.0): freqs_global = torch.outer(t, inv_freq_global) freqs_local = torch.outer(t, inv_freq_local) - self._cos_cached = torch.cos(freqs_global).to(torch.float32).cuda() - self._sin_cached = torch.sin(freqs_global).to(torch.float32).cuda() + self._cos_cached = torch.cos(freqs_global).to(torch.float32).to(device=self.target_device) + self._sin_cached = torch.sin(freqs_global).to(torch.float32).to(device=self.target_device) - self._cos_cached_global = torch.cos(freqs_global).to(torch.float32).cuda() - self._sin_cached_global = torch.sin(freqs_global).to(torch.float32).cuda() + self._cos_cached_global = torch.cos(freqs_global).to(torch.float32).to(device=self.target_device) + self._sin_cached_global = torch.sin(freqs_global).to(torch.float32).to(device=self.target_device) - self._cos_cached_local = torch.cos(freqs_local).to(torch.float32).cuda() - self._sin_cached_local = torch.sin(freqs_local).to(torch.float32).cuda() + self._cos_cached_local = torch.cos(freqs_local).to(torch.float32).to(device=self.target_device) + self._sin_cached_local = torch.sin(freqs_local).to(torch.float32).to(device=self.target_device) return def _init_custom(self): diff --git a/lightllm/models/glm4_moe_lite/model.py b/lightllm/models/glm4_moe_lite/model.py index a8fe49ac5e..eb1a879520 100644 --- a/lightllm/models/glm4_moe_lite/model.py +++ b/lightllm/models/glm4_moe_lite/model.py @@ -48,5 +48,5 @@ def _init_glm4_standard_rotary(self): t = 
torch.arange(max_seq_len, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached = torch.sin(freqs).to(self.data_type).to(device=self.target_device) diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index 093ad2b5d1..d44da234ad 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ b/lightllm/models/internvl/internvl_visual.py @@ -8,6 +8,7 @@ from torchvision import transforms as T from torchvision.transforms.functional import InterpolationMode from transformers import AutoModel, AutoTokenizer +from lightllm.models.visual_utils import VisualDeviceMixin, default_infer_dtype from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO @@ -18,13 +19,13 @@ logger = init_logger(__name__) -class InternVLVisionModel: - def __init__(self): - pass +class InternVLVisionModel(VisualDeviceMixin): + + def _device_module_attrs(self): + return ("model",) def load_model(self, weight_dir): - assert torch.cuda.is_available() - self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 + self.dtype = default_infer_dtype() self.config = json.load(open(os.path.join(weight_dir, "config.json"))) # self.model = AutoModel.from_pretrained( # weight_dir, @@ -38,12 +39,9 @@ def load_model(self, weight_dir): self.model = InternVLChatModel.from_pretrained( weight_dir, config=cfg, torch_dtype=self.dtype, language_model="fake_language_model" ) - self.model.eval().cuda() + self.model.eval() self.load_image_func = get_load_image_func(weight_dir) - def cuda(self): - return self - def encode(self, images: List[ImageItem]): img_tensors = [] valid_ids = [] @@ -68,7 +66,7 
@@ def encode(self, images: List[ImageItem]): return None imgs = torch.cat(img_tensors, dim=0) - pixel_values = imgs.cuda().to(dtype=self.dtype) + pixel_values = self.move_to_infer_device(imgs, dtype=self.dtype) all_img_embeds = self.model.extract_feature(pixel_values) return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/llama/layer_infer/post_layer_infer.py b/lightllm/models/llama/layer_infer/post_layer_infer.py index 3d575975b5..bc9dfe9c97 100644 --- a/lightllm/models/llama/layer_infer/post_layer_infer.py +++ b/lightllm/models/llama/layer_infer/post_layer_infer.py @@ -35,13 +35,13 @@ def _slice_get_last_input(self, input_embdings: torch.Tensor, infer_state: Llama select_token_num += 1 last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) - last_input = self.alloc_tensor((select_token_num, embed_dim_), dtype=input_embdings.dtype) + last_input = self.alloc_tensor((select_token_num, embed_dim_), dtype=input_embdings.dtype, device=input_embdings.device) last_input[:, :] = input_embdings[last_index, :] return last_input, select_token_num if infer_state.is_prefill and not infer_state.return_all_prompt_logics: batch_size = infer_state.batch_size - last_input = self.alloc_tensor((batch_size, embed_dim_), dtype=input_embdings.dtype) + last_input = self.alloc_tensor((batch_size, embed_dim_), dtype=input_embdings.dtype, device=input_embdings.device) last_index = ( torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 ) @@ -63,6 +63,7 @@ def _token_forward( ): last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) input_embdings_dtype = input_embdings.dtype + input_embdings_device = input_embdings.device input_embdings = None last_input = self._norm(last_input, infer_state, layer_weight) last_input = last_input.permute(1, 0).view(-1, token_num) @@ -72,7 +73,8 @@ def _token_forward( if self.tp_world_size_ == 1: gather_data = logic_batch else: - gather_data = 
self.alloc_tensor((vocab_size, token_num), dtype=input_embdings_dtype) + gather_data = self.alloc_tensor( + (vocab_size, token_num), dtype=input_embdings_dtype, device=input_embdings_device) split_indexes = np.linspace(0, vocab_size, self.tp_world_size_ + 1, dtype=np.int64) all_gather( [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.tp_world_size_)], @@ -84,6 +86,7 @@ def _token_forward( ans_logics = self.alloc_tensor( (token_num, vocab_size), dtype=torch.float32, + device=input_embdings_device, ) ans_logics[:, :] = gather_data.permute(1, 0) gather_data = None diff --git a/lightllm/models/llama/layer_infer/pre_layer_infer.py b/lightllm/models/llama/layer_infer/pre_layer_infer.py index edeb764ec9..7442c889e6 100644 --- a/lightllm/models/llama/layer_infer/pre_layer_infer.py +++ b/lightllm/models/llama/layer_infer/pre_layer_infer.py @@ -15,7 +15,6 @@ def __init__(self, network_config): return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): - input_embdings = layer_weight.wte_weight_(input_ids=input_ids, alloc_func=self.alloc_tensor) if self.tp_world_size_ > 1: all_reduce(input_embdings, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 69acffaa4d..a52c962cfd 100644 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -3,8 +3,6 @@ import torch.distributed as dist from functools import partial from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from 
def npu_silu_and_mul_fwd(
    input: torch.Tensor,
    layout="blocked",
    limit=None,
    alpha=None,
) -> torch.Tensor:
    """Apply the gated SiLU activation to a fused gate/up projection output.

    Args:
        input: contiguous 2-D tensor whose last dimension holds both the gate
            and the up values.
        layout: ``"blocked"`` when the first half of the last dimension is the
            gate and the second half is up; ``"interleaved"`` when gate and up
            alternate (even columns gate, odd columns up).
        limit: optional clamp bound. Must be given together with ``alpha``.
        alpha: optional sigmoid scale. Must be given together with ``limit``.

    Returns:
        Tensor with the last dimension halved. Without ``limit``/``alpha`` the
        result is ``silu(gate) * up``; with them it is
        ``(clamp(up, -limit, limit) + 1) * gate_c * sigmoid(alpha * gate_c)``
        where ``gate_c = min(gate, limit)`` evaluated in float32.

    Raises:
        ValueError: if ``layout`` is not one of the two supported values.
    """
    assert input.is_contiguous()
    assert input.dim() == 2
    assert (limit is None and alpha is None) or (limit is not None and alpha is not None)
    half = input.shape[1] // 2

    if layout == "blocked":
        gate, up = input[:, :half], input[:, half:]
    elif layout == "interleaved":
        gate, up = input[:, 0::2], input[:, 1::2]
    else:
        raise ValueError(f"unknown layout: {layout}")

    if limit is not None:
        # Clamp the gate from above in float32 before the sigmoid, then cast
        # the activated gate back to the input dtype.
        gate_fp32 = torch.clamp(gate.float(), max=limit)
        gate_act = (torch.sigmoid(gate_fp32 * alpha) * gate_fp32).to(input.dtype)
        up_clip = torch.clamp(up, -limit, limit)
        return (up_clip + 1) * gate_act

    if layout == "blocked":
        # The fused kernel is handed the raw tensor, so it is only used for
        # the blocked layout.
        # NOTE(review): assumes npu_swiglu splits `dim` into first-half gate /
        # second-half up -- confirm against the torch_npu docs.
        import torch_npu

        return torch_npu.npu_swiglu(input, dim=-1)

    # Interleaved layout: the fused kernel would pair the wrong columns, so
    # compute the activation from the explicit gate/up slices instead.
    return torch.nn.functional.silu(gate) * up


def npu_ffn_fwd(
    input: torch.Tensor,
    layer_weight: "LlamaTransformerLayerWeight",
    embed_dim: int,
) -> torch.Tensor:
    """Run the gated feed-forward block: gate/up projection, SwiGLU, down projection.

    Args:
        input: activations; reshaped to ``(-1, embed_dim)`` before use.
        layer_weight: object exposing ``gate_up_proj`` and ``down_proj``, each
            with ``mm_param.weight`` and ``bias`` (``bias`` may be ``None``).
        embed_dim: size of the last dimension that ``input`` is viewed with.

    Returns:
        Output of the down projection.
    """
    import torch.nn.functional as F

    hidden = input.view(-1, embed_dim)

    # NOTE(review): F.linear computes ``x @ weight.T`` and therefore expects
    # weight as (out_features, in_features) -- confirm mm_param.weight is
    # stored in that layout on this backend.
    gate_up_proj = layer_weight.gate_up_proj
    # F.linear takes the bias as a Tensor (or None); it must not be wrapped in
    # a list the way a grouped-matmul style API would want it.
    up_gate_out = F.linear(hidden, gate_up_proj.mm_param.weight, bias=gate_up_proj.bias)

    ffn1_out = npu_silu_and_mul_fwd(up_gate_out)

    down_proj = layer_weight.down_proj
    return F.linear(ffn1_out, down_proj.mm_param.weight, bias=down_proj.bias)
alloc_func=self.alloc_tensor, + embed_dim=self.embed_dim_, + ) # # keep code # def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight)->torch.Tensor: diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index c104ebccc9..db1f9baa48 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -135,8 +135,8 @@ def _init_to_get_rotary(self, default_base=10000): ) freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached = torch.sin(freqs).to(self.data_type).to(device=self.target_device) return def _init_to_get_dynamic_ntk_rotary(self): @@ -148,16 +148,16 @@ def _init_to_get_dynamic_ntk_rotary(self): else: scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) max_seq_len = max(self.max_seq_length, max_position_embeddings) - self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda") - self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda") + self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device=self.target_device) + self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device=self.target_device) inv_freq = 1.0 / ( base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) ) t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached[0:max_position_embeddings, :] = 
torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).to(device=self.target_device) for seq_loc_index in range(max_position_embeddings, max_seq_len, 1): new_base = base * ( @@ -174,8 +174,8 @@ def _init_to_get_dynamic_ntk_rotary(self): dtype=torch.float32, ) freqs = torch.outer(t, inv_freq) - self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).to(device=self.target_device) return def _init_to_get_yarn_rotary(self): @@ -195,12 +195,12 @@ def _init_to_get_yarn_rotary(self): beta_fast = 32.0 beta_slow = 1.0 - pos_freqs = base ** (torch.arange(0, dim, 2).float().cuda() / dim) + pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device=self.target_device) / dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scale * pos_freqs) low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) inv_freq_mask = ( - 1 - linear_ramp_mask(low, high, dim // 2).float().cuda() + 1 - linear_ramp_mask(low, high, dim // 2).float().to(device=self.target_device) ) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask @@ -208,12 +208,12 @@ def _init_to_get_yarn_rotary(self): # Build here to make `torch.jit.trace` work. 
max_seq_len_cached = max_position_embeddings - t = torch.arange(max(max_seq_len_cached, self.max_seq_length), device="cuda", dtype=torch.float32) + t = torch.arange(max(max_seq_len_cached, self.max_seq_length), device=self.target_device, dtype=torch.float32) freqs = torch.einsum("i,j->ij", t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self._cos_cached = emb.cos().to(self.data_type).cuda() * mscale - self._sin_cached = emb.sin().to(self.data_type).cuda() * mscale + self._cos_cached = emb.cos().to(self.data_type).to(device=self.target_device) * mscale + self._sin_cached = emb.sin().to(self.data_type).to(device=self.target_device) * mscale return @@ -234,8 +234,8 @@ def _init_to_su_rotary(self): rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings)) max_seq_len = max(self.max_seq_length, max_position_embeddings) - self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda") - self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda") + self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device=self.target_device) + self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device=self.target_device) inv_freq = 1.0 / ( short_factor @@ -244,10 +244,10 @@ def _init_to_su_rotary(self): t = torch.arange(original_max_position_embeddings, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) self._cos_cached[0:original_max_position_embeddings, :] = ( - (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda() + (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).to(device=self.target_device) ) self._sin_cached[0:original_max_position_embeddings, :] = ( - (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda() + (torch.sin(freqs) * 
@torch.no_grad()
def npu_rotary_emb_fwd(
    *,
    is_prefill: bool,
    batch_size: int,
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    partial_rotary_factor: float = 1.0,
) -> None:
    """Apply rotary position embedding to ``q`` and ``k`` in place via torch_npu.

    When ``partial_rotary_factor`` is not 1.0 the call is forwarded to
    ``rotary_emb_fwd`` unchanged. Otherwise ``cos``/``sin`` are widened to the
    full head dimension if they only cover half of it, and one of two
    torch_npu kernels is used: ``npu_apply_rotary_pos_emb`` for non-prefill
    calls with a head dimension of 128, ``npu_rotary_mul`` for everything
    else. The rotated values are copied back into the storage of ``q``/``k``.
    """
    if partial_rotary_factor != 1.0:
        rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor)
        return

    rot_dim = q.shape[-1]
    if cos.shape[-1] != rot_dim:
        # Half-width tables are duplicated along the last dim to full width.
        assert cos.shape[-1] * 2 == rot_dim, (cos.shape, q.shape)
        assert sin.shape[-1] * 2 == rot_dim, (sin.shape, q.shape)
        cos, sin = (torch.cat((table, table), dim=-1) for table in (cos, sin))

    import torch_npu

    fused_decode = (not is_prefill) and rot_dim == 128
    if fused_decode:
        # Split the leading token dim into (batch_size, -1); the views alias
        # the caller's tensors, so copy_ below updates them in place.
        q_view = q.view(batch_size, -1, q.shape[-2], q.shape[-1])
        k_view = k.view(batch_size, -1, k.shape[-2], k.shape[-1])
        table_dim = k.shape[-1]
        cos_view = cos.view(batch_size, -1, 1, table_dim)
        sin_view = sin.view(batch_size, -1, 1, table_dim)
        q_rot, k_rot = torch_npu.npu_apply_rotary_pos_emb(q_view, k_view, cos_view, sin_view, 'BSND')
        q_view.copy_(q_rot)
        k_view.copy_(k_rot)
        return

    # Add a leading singleton dim to q/k, and leading plus head singleton dims
    # to the tables so they broadcast over heads.
    q_view = q.unsqueeze(0)
    k_view = k.unsqueeze(0)
    cos_view = cos.unsqueeze(1).unsqueeze(0)
    sin_view = sin.unsqueeze(1).unsqueeze(0)
    q_rot = torch_npu.npu_rotary_mul(q_view, cos_view, sin_view, rotary_mode="half")
    k_rot = torch_npu.npu_rotary_mul(k_view, cos_view, sin_view, rotary_mode="half")
    q_view.copy_(q_rot)
    k_view.copy_(k_rot)
get_shm_name_data from lightllm.utils.log_utils import init_logger @@ -14,9 +15,13 @@ logger = init_logger(__name__) -class LlavaVisionModel: - def __init__(self): - pass +class LlavaVisionModel(VisualDeviceMixin): + + def _device_module_attrs(self): + return ("vision_tower",) + + def _device_tensor_dict_attrs(self): + return ("projector_weights",) def load_model(self, weight_dir): config_file = os.path.join(weight_dir, "config.json") @@ -29,7 +34,6 @@ def load_model(self, weight_dir): self.load_bin_model(config, weight_dir) self.vision_tower.requires_grad_(False) - self.device = torch.device("cpu") assert "model.mm_projector.0.weight" in self.projector_weights assert "model.mm_projector.0.bias" in self.projector_weights @@ -93,15 +97,9 @@ def load_bin_model(self, config, weight_dir): if "model.mm_projector" in k: self.projector_weights[k] = v.half() - def cuda(self): - self.vision_tower = self.vision_tower.cuda() - for k, v in self.projector_weights.items(): - self.projector_weights[k] = v.cuda() - return self - # batch images infer def forward(self, x): - x = x.half().cuda() + x = self.move_to_infer_device(x.half()) x = self.vision_tower(x, output_hidden_states=True) x = x.hidden_states[self.select_layer] if self.select_feature == "patch" or self.select_feature == "default": diff --git a/lightllm/models/mistral/model.py b/lightllm/models/mistral/model.py index f09525c59f..f6db619485 100644 --- a/lightllm/models/mistral/model.py +++ b/lightllm/models/mistral/model.py @@ -75,6 +75,6 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached = torch.sin(freqs).to(self.data_type).to(device=self.target_device) 
return diff --git a/lightllm/models/mixtral/model.py b/lightllm/models/mixtral/model.py index 3c2d7b4e87..1ac6c38c82 100644 --- a/lightllm/models/mixtral/model.py +++ b/lightllm/models/mixtral/model.py @@ -81,6 +81,6 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached = torch.sin(freqs).to(self.data_type).to(device=self.target_device) return diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index fddb14cfe5..261c308f2f 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -1,5 +1,5 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd as phi3_rotary_emb_fwd from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -17,11 +17,14 @@ def _get_qkv(self, input_emb, infer_state: LlamaInferStateInfo, layer_weight: Ph cache_kv = layer_weight.kv_proj.mm(input_emb).view( -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ ) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, 
self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, 0 : self.tp_k_head_num_, :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, + rotary_impl=phi3_rotary_emb_fwd, ) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index 003188d088..2ccd079a42 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ import torch from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.qwen.layer_weights.transformer_layer_weight import QwenTransformerLayerWeight from lightllm.models.qwen.infer_struct import QwenInferStateInfo @@ -19,11 +18,13 @@ def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: Qwe -1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_ ) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, 0 : self.tp_k_head_num_, :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, ) if infer_state.logn_values is not None: q.mul_(infer_state.logn_values.view(-1, 1)) diff --git a/lightllm/models/qwen/model.py b/lightllm/models/qwen/model.py index e7b9c76492..18b3d65288 100644 --- a/lightllm/models/qwen/model.py +++ b/lightllm/models/qwen/model.py @@ -74,8 +74,8 @@ def _init_qwen_dynamic_ntk(self): t = torch.arange(total_seq_len_supported + 128 * 1024, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) - 
self._cos_cached.append(torch.cos(freqs).to(self.data_type).cuda()) - self._sin_cached.append(torch.sin(freqs).to(self.data_type).cuda()) + self._cos_cached.append(torch.cos(freqs).to(self.data_type).to(device=self.target_device)) + self._sin_cached.append(torch.sin(freqs).to(self.data_type).to(device=self.target_device)) self._cos_cached = torch.stack(self._cos_cached, dim=0).contiguous() self._sin_cached = torch.stack(self._sin_cached, dim=0).contiguous() @@ -87,5 +87,5 @@ def _init_qwen_logn_attn(self): logn_list = [ math.log(i, seq_len) if i > seq_len else 1 for i in range(1, total_seq_len_supported + 128 * 1024 + 1) ] - self.logn_tensor = torch.tensor(logn_list).cuda() + self.logn_tensor = torch.tensor(logn_list).to(device=self.target_device) return diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..525d54fd22 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -16,7 +16,7 @@ from lightllm.server.visualserver import get_vit_attn_backend from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton - +from lightllm.models.visual_utils import VisualDeviceMixin class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class Qwen2_5_VisionTransformerPretrainedModel(nn.Module): +class Qwen2_5_VisionTransformerPretrainedModel(VisualDeviceMixin, nn.Module): def __init__( self, kvargs, @@ -286,23 +286,22 @@ def get_window_index(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) - rotary_cos = rotary_cos.to("cuda", non_blocking=True) - rotary_sin = rotary_sin.to("cuda", 
non_blocking=True) + rotary_cos, rotary_sin = self.move_to_infer_device(rotary_cos, rotary_sin) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32 ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True) + cu_seqlens = self.move_to_infer_device(F.pad(cu_seqlens, (1, 0), value=0)) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() window_index, cu_window_seqlens = self.get_window_index(grid_thw) cu_window_seqlens = torch.tensor( cu_window_seqlens, - device=hidden_states.device, + device=self.infer_device, dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, ) - cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to("cuda", non_blocking=True) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item() seq_len, _ = hidden_states.size() @@ -401,8 +400,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_thw = grid_thw.to("cuda", non_blocking=True) + pixel_values, image_grid_thw = self.move_to_infer_device( + imgs, grid_thw, dtype=(self.data_type, None) + ) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) diff --git a/lightllm/models/qwen2_vl/infer_struct.py b/lightllm/models/qwen2_vl/infer_struct.py index 747be932d9..667fe5b267 100644 --- a/lightllm/models/qwen2_vl/infer_struct.py +++ b/lightllm/models/qwen2_vl/infer_struct.py @@ -14,6 +14,7 @@ def __init__(self): self.position_sin = None def init_some_extra_state(self, model): + self.target_device = model.target_device rope_scaling = model.config.get("rope_scaling", {}) self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) InferStateInfo.init_some_extra_state(self, model) @@ -63,11 +64,11 @@ def 
get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor: # 没有任何图片 if image_start_num == 0: return self.position_ids.unsqueeze(0).expand(3, -1).contiguous() - b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) - b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 - b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) - b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) - b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) + b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").to(device=self.target_device, non_blocking=True) + b_image_thwd = torch.tensor(b_image_thwd, device="cpu").to(device=self.target_device, non_blocking=True) # image_num x 4 + b_image_nums = torch.tensor(b_image_nums, device="cpu").to(device=self.target_device, non_blocking=True) + b_image_start_num = torch.tensor(b_image_start_num, device="cpu").to(device=self.target_device, non_blocking=True) + b_image_len = torch.tensor(b_image_len, device="cpu").to(device=self.target_device, non_blocking=True) position_ids = self.position_ids.unsqueeze(0).expand(3, -1).contiguous() get_mrope_position_triton( b_image_start_idx=b_image_start_idx, diff --git a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py index ae6861071e..4a4d4a73a6 100755 --- a/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py @@ -7,7 +7,7 @@ class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) mrope_section = network_config["rope_scaling"]["mrope_section"] - self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") + self.mrope_section = 
torch.tensor(mrope_section, dtype=torch.int32, device=self.target_device) def _get_qkv(self, input, infer_state, layer_weight): input = self._tpsp_allgather(input, infer_state) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 6076756043..64570b1bcc 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -25,6 +25,7 @@ from PIL import Image from typing import List from torchvision import transforms as T +from lightllm.models.visual_utils import VisualDeviceMixin from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO import torch.nn as nn @@ -64,8 +65,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size ) - # Use channels_last_3d to enable cuDNN optimized Conv3D path - hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) + if hidden_states.device.type == "cuda": + # Use channels_last_3d to enable cuDNN optimized Conv3D path + hidden_states = hidden_states.contiguous(memory_format=torch.channels_last_3d) + else: + hidden_states = hidden_states.contiguous() hidden_states = self.proj(hidden_states).view(-1, self.embed_dim) return hidden_states @@ -175,7 +179,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) return hidden_states -class Qwen2VisionTransformerPretrainedModel(nn.Module): +class Qwen2VisionTransformerPretrainedModel(VisualDeviceMixin, nn.Module): def __init__( self, kvargs, @@ -213,7 +217,7 @@ def __init__( ) head_dim = self.embed_dim // self.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [ @@ -285,15 +289,14 @@ def rot_pos_emb(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: 
torch.Tensor) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) - rotary_cos = rotary_cos.to("cuda", non_blocking=True) - rotary_sin = rotary_sin.to("cuda", non_blocking=True) + rotary_cos, rotary_sin = self.move_to_infer_device(rotary_cos, rotary_sin) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32 ) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - cu_seqlens = cu_seqlens.to("cuda", non_blocking=True) + cu_seqlens = self.move_to_infer_device(cu_seqlens) for blk in self.blocks: hidden_states = blk( hidden_states, @@ -333,8 +336,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_thw = grid_thw.to("cuda", non_blocking=True) + pixel_values, image_grid_thw = self.move_to_infer_device( + imgs, grid_thw, dtype=(self.data_type, None) + ) all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py index 8063856599..57e181296e 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -60,7 +60,7 @@ def rotary_kernel( def apply_rotary_pos_emb_triton( tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128 ) -> torch.Tensor: - assert tensor.is_cuda and cos.is_cuda and sin.is_cuda + # assert tensor.is_cuda and cos.is_cuda and sin.is_cuda assert cos.is_contiguous() and sin.is_contiguous() if tensor.ndim != 3: raise RuntimeError("tensor shape should be [L, H, D]") diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index bc313fe467..c26c0c6963 
100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -22,6 +22,8 @@ ) from torchvision.transforms.v2 import functional as F +from lightllm.platform import get_backend +from lightllm.utils.device_utils import get_target_device from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -77,6 +79,73 @@ def resize_image( return image +def _flatten_patches_qwen2vl_npu( + patches: torch.Tensor, + batch_size: int, + grid_t: int, + temporal_patch_size: int, + channel: int, + grid_h: int, + grid_w: int, + merge_size: int, + patch_size: int, +) -> torch.Tensor: + """ + Current NPU runtime doesn't support tensors with more than 8 dims. So we need to flatten the patches to 8D. + The original code is: + patches = ( + patches.view( + batch_size, + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + .contiguous() + ) + flatten_patches = patches.view( + batch_size, + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ) + """ + gh_ms = grid_h // merge_size + gw_ms = grid_w // merge_size + patches = patches.view( + batch_size, + grid_t, + temporal_patch_size, + channel, + gh_ms, + merge_size * patch_size, + gw_ms, + merge_size * patch_size, + ) + patches = patches.permute(0, 1, 4, 6, 5, 7, 3, 2).contiguous() + patches = patches.view( + batch_size * grid_t * gh_ms * gw_ms, + merge_size, + patch_size, + merge_size, + patch_size, + channel, + temporal_patch_size, + ) + patches = patches.permute(0, 1, 3, 2, 4, 5, 6).contiguous() + patches = patches.permute(0, 1, 2, 5, 6, 3, 4).contiguous() + return patches.view( + batch_size, + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size, + ) + + class Qwen2VLImageProcessor(BaseImageProcessorFast): def __init__( 
self, @@ -99,6 +168,10 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) + + self.platform_backend = get_backend() + self.target_device = get_target_device() + self.size = size self.do_resize = do_resize self.resample = resample @@ -177,13 +250,13 @@ def rescale_and_normalize( @torch.inference_mode() def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: try: - return self._preprocess_bydevice(image, device="cuda") + return self._preprocess_bydevice(image, device=self.target_device) except Exception as e: logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}") - torch.cuda.current_stream().synchronize() + self.platform_backend.runtime.current_stream().synchronize() return self._preprocess_bydevice(image, device="cpu") - def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]: + def _preprocess_bydevice(self, image, device) -> Tuple[torch.Tensor, torch.Tensor]: if image.mode != "RGB": image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) @@ -226,7 +299,7 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc processed_grids = {} for shape, stacked_images in grouped_images.items(): - stacked_images = stacked_images.to("cuda", non_blocking=True) + stacked_images = stacked_images.to(device=device, non_blocking=True) resized_height, resized_width = stacked_images.shape[-2:] @@ -249,28 +322,41 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc grid_t = grid_t // self.temporal_patch_size grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size - patches = ( - patches.view( - batch_size, - grid_t, - self.temporal_patch_size, - channel, - grid_h // self.merge_size, - self.merge_size, - self.patch_size, - grid_w // self.merge_size, - self.merge_size, - self.patch_size, + if self.platform_backend.name == "ascend": + flatten_patches = _flatten_patches_qwen2vl_npu( + patches=patches, + 
batch_size=batch_size, + grid_t=grid_t, + temporal_patch_size=self.temporal_patch_size, + channel=channel, + grid_h=grid_h, + grid_w=grid_w, + merge_size=self.merge_size, + patch_size=self.patch_size, + ) + else: + patches = ( + patches.view( + batch_size, + grid_t, + self.temporal_patch_size, + channel, + grid_h // self.merge_size, + self.merge_size, + self.patch_size, + grid_w // self.merge_size, + self.merge_size, + self.patch_size, + ) + .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) + .contiguous() ) - .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) - .contiguous() - ) - flatten_patches = patches.view( - batch_size, - grid_t * grid_h * grid_w, - channel * self.temporal_patch_size * self.patch_size * self.patch_size, - ) + flatten_patches = patches.view( + batch_size, + grid_t * grid_h * grid_w, + channel * self.temporal_patch_size * self.patch_size * self.patch_size, + ) processed_images_grouped[shape] = flatten_patches processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 3a66d506ca..47712925af 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -3,7 +3,6 @@ from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -31,11 +30,14 @@ def _get_qkv( eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], - 
infer_state.position_cos, - infer_state.position_sin, + + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, : self.tp_k_head_num_, :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, ) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py index afbd02a482..ee22461d9b 100644 --- a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -17,7 +17,7 @@ def __init__(self, layer_num, network_config): # Initialize mrope section from config rope_scaling = network_config.get("rope_scaling", {}) mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) - self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") + self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device=self.target_device) def _get_qkv( self, diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 54e4373652..4002fa2749 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -5,7 +5,6 @@ from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.envs_utils import get_env_start_args @@ -60,11 +59,13 @@ def _get_qkv( 
eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, : self.tp_k_head_num_, :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, ) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) diff --git a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py index 67fd49cd1f..fe4e204cc9 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/audio_process.py +++ b/lightllm/models/qwen3_omni_moe_thinker/audio_process.py @@ -188,10 +188,6 @@ def _preprocess( if return_tensors is not None: padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)).to( - device="cuda", dtype=torch.bfloat16 - ) - attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)).to( - device="cuda", dtype=torch.int32 - ) + input_features = torch.from_numpy(np.asarray(padded_inputs["input_features"], dtype=np.float32)) + attention_mask = torch.from_numpy(np.asarray(padded_inputs["attention_mask"], dtype=np.float32)) return input_features, attention_mask diff --git a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py index 1a05a752f3..321cf8dd75 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_omni_moe_thinker/layer_infer/transformer_layer_infer.py @@ -10,6 +10,5 @@ def __init__(self, 
layer_num, network_config): super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( - network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" - ) + network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device=self.target_device) return diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py index 4ad4300fc0..3638047cdf 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_audio.py @@ -14,6 +14,7 @@ from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd from lightllm.models.qwen3_omni_moe_thinker.audio_process import WhisperFeatureExtractor +from lightllm.models.visual_utils import VisualDeviceMixin QWEN3_OMNI_CONV_CHUNKSIZE = int(os.getenv("LIGHTLLM_QWEN3_OMNI_CONV_CHUNKSIZE", 200)) @@ -146,7 +147,7 @@ def forward(self, seqlen: int): return self.positional_embedding[:seqlen, :] -class Qwen3OmniMoeAudioEncoder(nn.Module): +class Qwen3OmniMoeAudioEncoder(VisualDeviceMixin, nn.Module): def __init__( self, kvargs, @@ -347,7 +348,14 @@ def encode(self, audio_items: List[AudioItem]): else: raise ValueError(f"cannot read audio which type is {type(item)}!") - input_features, feature_attention_mask = self.processor._preprocess(audio, return_attention_mask=True) + input_features, feature_attention_mask = self.processor._preprocess( + audio, return_attention_mask=True + ) + input_features, feature_attention_mask = self.move_to_infer_device( + input_features, + feature_attention_mask, + dtype=(self.data_type, torch.int32), + ) if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) input_features = input_features.permute(0, 2, 
1)[feature_attention_mask.bool()].permute(1, 0) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index 0276724749..1e4c92af52 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from transformers.activations import ACT2FN +from lightllm.models.visual_utils import VisualDeviceMixin from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor @@ -120,7 +121,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) return hidden_states -class Qwen3OmniMoeVisionTransformerPretrainedModel(nn.Module): +class Qwen3OmniMoeVisionTransformerPretrainedModel(VisualDeviceMixin, nn.Module): def __init__( self, kvargs, @@ -178,7 +179,7 @@ def __init__( self.num_grid_per_side = int(self.num_position_embeddings ** 0.5) head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [ @@ -348,13 +349,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) pos_embeds = self.fast_pos_embed_interpolate(grid_thw) hidden_states = hidden_states + pos_embeds rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) - rotary_cos = rotary_cos.to("cuda", non_blocking=True) - rotary_sin = rotary_sin.to("cuda", non_blocking=True) + rotary_cos, rotary_sin = self.move_to_infer_device(rotary_cos, rotary_sin) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32, ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True) + cu_seqlens = 
self.move_to_infer_device(F.pad(cu_seqlens, (1, 0), value=0)) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): @@ -402,8 +402,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_thw = grid_thw.to("cuda", non_blocking=True) + pixel_values, image_grid_thw = self.move_to_infer_device( + imgs, grid_thw, dtype=(self.data_type, None) + ) img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw) all_img_embeds_df, valid_ids = self.concat_img_embed_and_deepstack_features( img_embeds, deepstack_feature_lists, valid_ids diff --git a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py index 6be827ac0a..83b9c2a048 100644 --- a/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py @@ -1,7 +1,6 @@ import torch import torch.distributed as dist from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo -from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer from ..layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight @@ -49,16 +48,16 @@ def context_forward( # each tp will fill the img embeds, should divide by world_size infer_state.img_start_token_ids = torch.tensor( img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True - ).cuda(non_blocking=True) - infer_state.img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, device="cpu", pin_memory=True).cuda( - non_blocking=True - ) + ).to(device=self.target_device, non_blocking=True) + 
infer_state.img_token_lens = torch.tensor( + img_token_lens, dtype=torch.long, device="cpu", pin_memory=True + ).to(device=self.target_device, non_blocking=True) infer_state.img_start_locs_in_cache = torch.tensor( img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True - ).cuda(non_blocking=True) + ).to(device=self.target_device, non_blocking=True) infer_state.input_ids = input_ids - multimodal_emb( + self.platform_backend.ops.multimodal_emb( out=out, prompt_ids=input_ids, text_weight_embs=layer_weight.wte_weight_.weight, diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index 6567eb57cc..92875dd8aa 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -15,8 +15,7 @@ def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.mrope_section = torch.tensor( - network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" - ) + network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device=self.target_device) def _get_qkv( self, @@ -79,7 +78,7 @@ def _apply_deepstack_features_wrapper_run( infer_state: InferStateInfo, layer_num: int, ): - if torch.cuda.is_current_stream_capturing(): + if self.platform_backend.graph.is_capturing(): input_embeddings = input_embeddings.contiguous() _input_embeddings = tensor_to_no_ref_tensor(input_embeddings) pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bed8898115..dab280cb30 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -25,6 +25,7 @@ import torch.nn.functional as F from transformers.activations import ACT2FN +from 
lightllm.models.visual_utils import VisualDeviceMixin from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor @@ -116,7 +117,7 @@ def forward(self, hidden_states, cu_seqlens, max_seqlen, rotary_cos, rotary_sin) return hidden_states -class Qwen3VisionTransformerPretrainedModel(nn.Module): +class Qwen3VisionTransformerPretrainedModel(VisualDeviceMixin, nn.Module): def __init__( self, kvargs, @@ -161,7 +162,7 @@ def __init__( self.num_grid_per_side = int(self.num_position_embeddings ** 0.5) head_dim = self.hidden_size // self.num_heads - self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).cuda() + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList( [ @@ -343,13 +344,12 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) pos_embeds = self.fast_pos_embed_interpolate(grid_thw) hidden_states = hidden_states + pos_embeds rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw) - rotary_cos = rotary_cos.to("cuda", non_blocking=True) - rotary_sin = rotary_sin.to("cuda", non_blocking=True) + rotary_cos, rotary_sin = self.move_to_infer_device(rotary_cos, rotary_sin) cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( dim=0, dtype=torch.int32, ) - cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True) + cu_seqlens = self.move_to_infer_device(F.pad(cu_seqlens, (1, 0), value=0)) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): @@ -400,8 +400,9 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) - image_grid_thw = grid_thw.to("cuda", non_blocking=True) + 
pixel_values, image_grid_thw = self.move_to_infer_device( + imgs, grid_thw, dtype=(self.data_type, None) + ) img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw) all_img_embeds_df, valid_ids = self.concat_img_embed_and_deepstack_features( img_embeds, deepstack_feature_lists, valid_ids diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index edf1f8cecf..d90919693e 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -13,8 +13,7 @@ class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) self.mrope_section = torch.tensor( - network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device="cuda" - ) + network_config["rope_scaling"]["mrope_section"], dtype=torch.int32, device=self.target_device) def _get_qkv( self, @@ -77,7 +76,7 @@ def _apply_deepstack_features_wrapper_run( infer_state: InferStateInfo, layer_num: int, ): - if torch.cuda.is_current_stream_capturing(): + if self.platform_backend.graph.is_capturing(): input_embeddings = input_embeddings.contiguous() _input_embeddings = tensor_to_no_ref_tensor(input_embeddings) pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index bb48bfe49c..23a315a070 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -15,7 +15,6 @@ from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule from 
lightllm.distributed import all_reduce -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type from functools import partial @@ -182,11 +181,13 @@ def _get_qkv( eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, : self.tp_k_head_num_, :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, partial_rotary_factor=self.partial_rotary_factor, ) if infer_state.need_dp_prefill_balance: @@ -281,7 +282,7 @@ def _gdn_prefill_wrapper_run( infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight, ) -> Tuple[torch.Tensor, torch.Tensor]: - if torch.cuda.is_current_stream_capturing(): + if self.platform_backend.graph.is_capturing(): mixed_qkvzba = mixed_qkvzba.contiguous() _mixed_qkvzba = tensor_to_no_ref_tensor(mixed_qkvzba) pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 4a8ee80a46..166c9138c7 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -41,7 +41,7 @@ def __init__(self, kvargs) -> None: def _init_triton(self): def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor: - return torch.empty(size, device="cuda", dtype=torch.int8) + return torch.empty(size, device=self.target_device, dtype=torch.int8) # Set Triton allocator for TMA descriptors # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py diff --git 
a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py index b5b6cfc369..527202774d 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -19,13 +19,14 @@ from .index import prepare_chunk_indices from .op import make_tensor_descriptor from .utils import input_guard, is_amd, is_tma_supported +from lightllm.utils.device_utils import get_target_device def _ensure_triton_allocator(): """Ensure Triton has an allocator set for kernels requiring scratch memory.""" def alloc_fn(size: int, alignment: int, stream: Optional[int]): - return torch.empty(size, device="cuda", dtype=torch.int8) + return torch.empty(size, device=torch.device(get_target_device()), dtype=torch.int8) triton.set_allocator(alloc_fn) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..ec61cb3e32 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -4,7 +4,6 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce @@ -60,17 +59,17 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}" ) # each tp will fill the img embeds, should divide by world_size - img_start_token_ids = torch.tensor(img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True).cuda( - non_blocking=True - ) - img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, 
device="cpu", pin_memory=True).cuda( - non_blocking=True - ) + img_start_token_ids = torch.tensor( + img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True + ).to(device=self.target_device, non_blocking=True) + img_token_lens = torch.tensor( + img_token_lens, dtype=torch.long, device="cpu", pin_memory=True + ).to(device=self.target_device, non_blocking=True) img_start_locs_in_cache = torch.tensor( img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True - ).cuda(non_blocking=True) + ).to(device=self.target_device, non_blocking=True) - multimodal_emb( + self.platform_backend.ops.multimodal_emb( out=out, prompt_ids=input_ids, text_weight_embs=layer_weight.wte_weight_.weight, diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index 07a7412020..4b8cd9e134 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -18,6 +18,7 @@ from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode +from lightllm.models.visual_utils import VisualDeviceMixin def get_abs_pos(abs_pos, tgt_size): @@ -330,7 +331,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): return x -class QWenVisionTransformer(nn.Module): +class QWenVisionTransformer(VisualDeviceMixin, nn.Module): def __init__( self, image_size: int, diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index f34619b1f8..b1ed70302d 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -1,6 +1,5 @@ import torch from functools import partial -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.stablelm.layer_weights.transformer_layer_weight import StablelmTransformerLayerWeight from 
lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -25,12 +24,14 @@ def _get_qkv( cache_kv = layer_weight.kv_proj.mm( input.view(-1, self.embed_dim_), ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - self.partial_rotary_factor, + self.platform_backend.ops.rotary_emb( + is_prefill=infer_state.is_prefill, + batch_size=infer_state.batch_size, + q=q.view(-1, self.tp_q_head_num_, self.head_dim_), + k=cache_kv[:, 0 : self.tp_k_head_num_, :], + cos=infer_state.position_cos, + sin=infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, ) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) diff --git a/lightllm/models/starcoder2/model.py b/lightllm/models/starcoder2/model.py index 2b75459140..d8fcd00970 100644 --- a/lightllm/models/starcoder2/model.py +++ b/lightllm/models/starcoder2/model.py @@ -84,6 +84,6 @@ def _init_to_get_rotary(self, default_base=10000): t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) - self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached = torch.cos(freqs).to(self.data_type).to(device=self.target_device) + self._sin_cached = torch.sin(freqs).to(self.data_type).to(device=self.target_device) return diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 9deaf08575..dedc874774 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -14,6 +14,7 @@ from safetensors import safe_open from 
lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel +from lightllm.models.visual_utils import VisualDeviceMixin from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image @@ -149,7 +150,7 @@ def forward(self, image_features, input_embeddings): return hidden_states -class TarsierVisionTransformerPretrainedModel(nn.Module): +class TarsierVisionTransformerPretrainedModel(VisualDeviceMixin, nn.Module): def __init__( self, vision_config=None, @@ -165,8 +166,7 @@ def __init__( **kwargs, ): super().__init__() - self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config) - + self.vision_tower = Qwen2VisionTransformerPretrainedModel(kwargs, **vision_config) if projection_head == "Pixel_Shuffle": self.multi_modal_projector = PixelShuffleMultiModalProjector( image_newline_idx, @@ -255,7 +255,6 @@ def encode(self, images: List[ImageItem]): image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image=image_data) - pixel_values = pixel_values.to(dtype=torch.bfloat16) img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: @@ -273,8 +272,10 @@ def encode(self, images: List[ImageItem]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.cuda() - image_grid_thw = grid_thw.cuda() + infer_dtype = next(self.parameters()).dtype + pixel_values, image_grid_thw = self.move_to_infer_device( + imgs, grid_thw, dtype=(infer_dtype, None) + ) all_img_embeds = self.forward(pixel_values=pixel_values, image_grid_thw=image_grid_thw) diff --git a/lightllm/models/visual_utils.py b/lightllm/models/visual_utils.py new file mode 100644 index 0000000000..ceddb58252 --- /dev/null +++ b/lightllm/models/visual_utils.py @@ -0,0 +1,107 @@ +import torch +import 
import torch
import torch.nn as nn
from typing import List, Optional, Sequence, Union


def default_infer_dtype(device_id: Optional[int] = None) -> torch.dtype:
    """Return the default inference dtype for the target device.

    A CUDA target without bf16 support yields ``torch.float32``; every other
    case yields ``torch.bfloat16``.

    Args:
        device_id: Device index forwarded to ``get_target_device``.
    """
    # Imported locally so this module can be imported without the lightllm runtime.
    from lightllm.utils.device_utils import get_target_device

    device = get_target_device(device_id)
    if device.type == "cuda" and torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
    return torch.bfloat16


def _tensor_to_infer_device(
    tensor: torch.Tensor,
    device: torch.device,
    dtype: Optional[torch.dtype] = None,
    non_blocking: bool = True,
) -> torch.Tensor:
    """Move ``tensor`` to ``device``, casting to ``dtype`` when one is given."""
    if dtype is None:
        return tensor.to(device=device, non_blocking=non_blocking)
    # One fused call instead of a device move followed by a separate cast.
    return tensor.to(device=device, dtype=dtype, non_blocking=non_blocking)


def _resolve_move_dtypes(
    n: int,
    dtype: Optional[Union[torch.dtype, Sequence[Optional[torch.dtype]]]],
) -> List[Optional[torch.dtype]]:
    """Expand ``dtype`` into a list with one (possibly ``None``) entry per tensor.

    Args:
        n: Number of tensors being moved.
        dtype: ``None`` (no cast), a single dtype applied to every tensor, or a
            list/tuple of length ``n`` whose entries are dtypes or ``None``.

    Raises:
        ValueError: If a list/tuple is given whose length differs from ``n``.
        TypeError: If ``dtype`` (or one of its entries) has an unsupported type.
    """
    if dtype is None:
        return [None] * n
    if isinstance(dtype, torch.dtype):
        return [dtype] * n
    if isinstance(dtype, (list, tuple)):
        dtypes = list(dtype)
        if len(dtypes) != n:
            raise ValueError(f"dtype length {len(dtypes)} must match number of tensors ({n})")
        # Reject bad entries here rather than deep inside Tensor.to().
        for entry in dtypes:
            if entry is not None and not isinstance(entry, torch.dtype):
                raise TypeError(f"dtype entries must be torch.dtype or None, got {type(entry)!r}")
        return dtypes
    raise TypeError(f"dtype must be torch.dtype or a sequence, got {type(dtype)!r}")


class VisualDeviceMixin:
    """Device-placement helpers shared by visual/audio model wrappers.

    Works both for ``nn.Module`` subclasses (moved with ``Module.to``) and for
    plain classes, which declare what to move by overriding
    ``_device_module_attrs`` and ``_device_tensor_dict_attrs``.
    """

    # Index of the device this model runs on; filled in by setup_device().
    device_id: Optional[int] = None
    # Resolved torch.device for device_id; None until it has been assigned.
    target_device: Optional[torch.device] = None

    def setup_device(self, device_id: Optional[int] = None):
        """Resolve the target device and move this object's weights onto it.

        Args:
            device_id: Explicit device index; when ``None`` the value of
                ``get_current_device_id()`` is used.

        Returns:
            ``self``, so the call can be chained.
        """
        # Imported locally so this module can be imported without the lightllm runtime.
        from lightllm.utils.device_utils import get_target_device
        from lightllm.utils.dist_utils import get_current_device_id

        self.device_id = device_id if device_id is not None else get_current_device_id()
        self.target_device = get_target_device(self.device_id)
        if isinstance(self, nn.Module):
            self.to(device=self.target_device)
        else:
            self._setup_device_non_module()
        return self

    def _device_module_attrs(self) -> Sequence[str]:
        """ The attributes that are nn.Modules, e.g. vision_tower, audio, model """
        return ()

    def _device_tensor_dict_attrs(self) -> Sequence[str]:
        """ The attributes that are dict[str, Tensor], e.g. projector_weights """
        return ()

    def _move_module_attr(self, name: str) -> None:
        """Move attribute ``name`` to the target device if it is an ``nn.Module``."""
        mod = getattr(self, name, None)
        if isinstance(mod, nn.Module):
            setattr(self, name, mod.to(device=self.target_device))

    def _move_tensor_dict_attr(self, name: str) -> None:
        """Move every value of dict attribute ``name`` to the target device, in place."""
        weights = getattr(self, name, None)
        if isinstance(weights, dict):
            # Snapshot the items so the dict can be reassigned while looping.
            for k, v in list(weights.items()):
                weights[k] = v.to(device=self.target_device)

    def _setup_device_non_module(self):
        """Move the declared module and tensor-dict attributes of a non-Module object."""
        for name in self._device_module_attrs():
            self._move_module_attr(name)
        for name in self._device_tensor_dict_attrs():
            self._move_tensor_dict_attr(name)

    @property
    def infer_device(self) -> torch.device:
        """Device that inference inputs must live on.

        Prefers ``target_device``; for an ``nn.Module`` without one, falls back
        to the device of the first non-empty parameter.

        Raises:
            RuntimeError: If neither source yields a device.
        """
        if self.target_device is not None:
            return self.target_device
        if isinstance(self, nn.Module):
            for param in self.parameters():
                if param.numel() > 0:
                    return param.device
        raise RuntimeError(f"{type(self).__name__}: call setup_device() before inference")

    def move_to_infer_device(
        self,
        *tensors: torch.Tensor,
        dtype: Optional[Union[torch.dtype, Sequence[Optional[torch.dtype]]]] = None,
        non_blocking: bool = True,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Move one or more tensors to ``infer_device``, optionally casting them.

        Args:
            *tensors: Tensors to move; at least one is required.
            dtype: ``None``, one dtype for all tensors, or a per-tensor
                list/tuple (``None`` entries leave that tensor's dtype alone).
            non_blocking: Forwarded to ``Tensor.to``.

        Returns:
            The moved tensor when a single tensor was passed, otherwise a list
            of moved tensors in argument order.

        Raises:
            ValueError: If no tensor is passed or the dtype sequence length
                does not match the number of tensors.
        """
        if not tensors:
            raise ValueError("move_to_infer_device() requires at least one tensor")
        device = self.infer_device
        dtypes = _resolve_move_dtypes(len(tensors), dtype)
        if len(tensors) == 1:
            return _tensor_to_infer_device(tensors[0], device, dtype=dtypes[0], non_blocking=non_blocking)
        return [
            _tensor_to_infer_device(t, device, dtype=dt, non_blocking=non_blocking)
            for t, dt in zip(tensors, dtypes)
        ]
_create_weight(self): split_embed_dim = split_end - split_start # Pre-allocate memory for vision model weights - self.class_embedding = torch.empty((1, 1, split_embed_dim), dtype=self.data_type_).cuda() - self.position_embedding = torch.empty((1, self.num_positions, split_embed_dim), dtype=self.data_type_).cuda() + self.class_embedding = torch.empty( + (1, 1, split_embed_dim), dtype=self.data_type_ + ).to(device=self.target_device) + self.position_embedding = torch.empty( + (1, self.num_positions, split_embed_dim), dtype=self.data_type_ + ).to(device=self.target_device) self.patch_embedding_weight_ = torch.empty( (split_embed_dim, 3, self.patch_size, self.patch_size), dtype=self.data_type_ - ).cuda() - self.patch_embedding_bias_ = torch.empty(split_embed_dim, dtype=self.data_type_).cuda() + ).to(device=self.target_device) + self.patch_embedding_bias_ = torch.empty( + split_embed_dim, dtype=self.data_type_ + ).to(device=self.target_device) self.layernorm_weight_ = LayerNormWeight( dim=self.embed_dim * int(1 / self.downsample_ratio) ** 2, @@ -60,9 +66,8 @@ def _create_weight(self): ) return - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) + def _to_device(self, cpu_tensor): + return cpu_tensor.contiguous().to(device=self.target_device, dtype=self.data_type_) def _get_pos_embed(self, H, W): pos_embed = self.position_embedding[:, 1:, :] diff --git a/lightllm/models/vit/layer_weights/transformer_layer_weight.py b/lightllm/models/vit/layer_weights/transformer_layer_weight.py index 198b3022be..96ecea9275 100644 --- a/lightllm/models/vit/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/vit/layer_weights/transformer_layer_weight.py @@ -11,7 +11,6 @@ LayerNormWeight, TpRMSNormWeight, ) -from lightllm.utils.dist_utils import get_current_device_id class ViTTransformerLayerWeight(TransformerLayerWeight): @@ -19,9 +18,8 @@ def __init__(self, layer_num, data_type, 
network_config, quant_cfg=None): super().__init__(layer_num, data_type, network_config, quant_cfg) return - def _cuda(self, cpu_tensor): - device_id = get_current_device_id() - return cpu_tensor.contiguous().to(self.data_type_).cuda(device_id) + def _to_device(self, cpu_tensor: torch.Tensor) -> torch.Tensor: + return cpu_tensor.contiguous().to(device=self.target_device, dtype=self.data_type_) def _parse_config(self): self.padding_hidden_size = self.network_config_["padding_hidden_size"] @@ -195,11 +193,11 @@ def load_hf_weights(self, weights): if f"vision_model.encoder.layers.{self.layer_num_}.ls1" in weights: ls1 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls1"] - self.ls1 = self._cuda(ls1) + self.ls1 = self._to_device(ls1) if f"vision_model.encoder.layers.{self.layer_num_}.ls2" in weights: ls2 = weights[f"vision_model.encoder.layers.{self.layer_num_}.ls2"] - self.ls2 = self._cuda(ls2) + self.ls2 = self._to_device(ls2) self.use_ls = True return super().load_hf_weights(weights) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 13f8e2827f..935d2be2fa 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -7,6 +7,9 @@ from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight from lightllm.models.vit.layer_weights.hf_load_utils import load_hf_weights +from lightllm.models.visual_utils import VisualDeviceMixin +from lightllm.utils.device_utils import get_target_device +from lightllm.utils.dist_utils import get_current_device_id from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.common.build_utils import repair_config from lightllm.utils.log_utils import init_logger @@ -25,7 +28,7 @@ logger = init_logger(__name__) -class VisionTransformer: +class VisionTransformer(VisualDeviceMixin): # weight class pre_and_post_weight_class = 
ViTPreAndPostLayerWeight @@ -46,6 +49,10 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) + self.device_id = kvargs.get("device_id") + if self.device_id is None: + self.device_id = get_current_device_id() + self.target_device = get_target_device(self.device_id) self._init_datatype() self._init_config() @@ -66,7 +73,7 @@ def _check_max_len_infer(self): try: dummy_images = torch.randn( (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type - ).cuda() + ).to(device=self.infer_device) all_img_embeds = self.forward(dummy_images) del all_img_embeds logger.info(f"vit check max_len {self.max_batch_size} infer ok") @@ -191,12 +198,9 @@ def encode(self, images: List[ImageItem]): return None imgs = torch.cat(img_tensors, dim=0) - pixel_values = imgs.cuda().to(dtype=self.data_type) + pixel_values = self.move_to_infer_device(imgs, dtype=self.data_type) all_img_embeds = self.forward(pixel_values) return all_img_embeds.view(-1, all_img_embeds.shape[-1]), uuids, valid_ids - def cuda(self): - return self - def load_model(self, weight_dir): pass diff --git a/lightllm/models/whisper/whisper_audio.py b/lightllm/models/whisper/whisper_audio.py index 8a984d29a5..1c9913ad64 100644 --- a/lightllm/models/whisper/whisper_audio.py +++ b/lightllm/models/whisper/whisper_audio.py @@ -6,6 +6,7 @@ from typing import List, Union from safetensors.torch import load_file from transformers.processing_utils import ProcessorMixin +from lightllm.models.visual_utils import VisualDeviceMixin from lightllm.server.multimodal_params import AudioItem @@ -79,7 +80,14 @@ def get_prompt_ids(self, text: str, return_tensors="np"): return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors) -class WhisperAudioModel: +class WhisperAudioModel(VisualDeviceMixin): + + def _device_module_attrs(self): + return ("audio",) + 
+ def _device_tensor_dict_attrs(self): + return ("projector_weights",) + def __init__(self, kvargs): self.max_seconds = 30 self.sampling_rate = 16000 @@ -90,19 +98,11 @@ def __init__(self, kvargs): else: self.data_type = torch.float16 - def cuda(self): - self.audio = self.audio.cuda() - for k, v in self.projector_weights.items(): - self.projector_weights[k] = v.cuda() - self.device = torch.device("cuda") - return self - def load_model(self, weight_dir, config): self.audio_processor = WhisperProcessor.from_pretrained(weight_dir) from lightllm.models.whisper.modeling_whisper import WhisperEncoder, WhisperConfig self.audio = WhisperEncoder(WhisperConfig(**config["audio_config"])).to(self.data_type) - self.device = torch.device("cpu") self.projector_weights = {} self.load_weight(weight_dir) @@ -133,9 +133,11 @@ def load_weight(self, weight_dir): assert "mlp2.3.weight" in self.projector_weights def forward(self, audio_values, audio_lens_after_cnn): - audio_values = audio_values.to(self.data_type).to(device=self.device) + audio_values = self.move_to_infer_device(audio_values.to(self.data_type)) audio_values = audio_values.squeeze(1) - audio_lens_after_cnn = torch.tensor(audio_lens_after_cnn).cuda() + audio_lens_after_cnn = torch.tensor( + audio_lens_after_cnn, device=self.infer_device, dtype=torch.long + ) max_len_in_batch = torch.max(audio_lens_after_cnn).item() padding_mask = torch.ones([audio_values.size(0), max_len_in_batch]).to( diff --git a/lightllm/platform/__init__.py b/lightllm/platform/__init__.py new file mode 100644 index 0000000000..83b655dc8d --- /dev/null +++ b/lightllm/platform/__init__.py @@ -0,0 +1,30 @@ +from typing import Optional +from lightllm.platform.base.registry import Backend, get_platform_spec + +_backend: Optional[Backend] = None + + +def get_backend() -> Backend: + global _backend + + if _backend is not None: + return _backend + + from lightllm.platform.plugins import configure_op_plugins + from lightllm.utils.envs_utils import 
get_env_start_args + + configure_op_plugins() + + platform_name = get_env_start_args().hardware_platform + spec = get_platform_spec(platform_name) + + backend_cls = spec.backend_cls + _backend = backend_cls() + + if not _backend.runtime.is_available(): + raise RuntimeError(f"Backend {platform_name} is not available") + + return _backend + + +__all__ = ["get_backend"] diff --git a/lightllm/platform/backends/__init__.py b/lightllm/platform/backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/platform/backends/ascend.py b/lightllm/platform/backends/ascend.py new file mode 100644 index 0000000000..91c1d3ecfc --- /dev/null +++ b/lightllm/platform/backends/ascend.py @@ -0,0 +1,13 @@ +from lightllm.platform.base.ops import build_ops +from lightllm.platform.base.registry import Backend, register_platform +from lightllm.platform.graph.ascend import AscendGraphBackend +from lightllm.platform.runtime.ascend import AscendRuntime + + +@register_platform("ascend", op_fallback=("ascend",)) +class AscendBackend(Backend): + + def __init__(self) -> None: + self._runtime = AscendRuntime() + self._graph = AscendGraphBackend() + self._ops = build_ops(self.platform_name) diff --git a/lightllm/platform/backends/cuda.py b/lightllm/platform/backends/cuda.py new file mode 100644 index 0000000000..122f5f0e50 --- /dev/null +++ b/lightllm/platform/backends/cuda.py @@ -0,0 +1,13 @@ +from lightllm.platform.base.ops import build_ops +from lightllm.platform.base.registry import Backend, register_platform +from lightllm.platform.graph.cuda import CudaGraphBackend +from lightllm.platform.runtime.cuda import CudaRuntime + + +@register_platform("cuda", op_fallback=("cuda_like",)) +class CudaBackend(Backend): + + def __init__(self) -> None: + self._runtime = CudaRuntime() + self._graph = CudaGraphBackend() + self._ops = build_ops(self.platform_name) diff --git a/lightllm/platform/backends/maca.py b/lightllm/platform/backends/maca.py new file mode 100644 index 
0000000000..8eb39cdfee --- /dev/null +++ b/lightllm/platform/backends/maca.py @@ -0,0 +1,7 @@ +from lightllm.platform.backends.cuda import CudaBackend +from lightllm.platform.base.registry import register_platform + + +@register_platform("maca", op_fallback=("cuda_like",)) +class MacaBackend(CudaBackend): + pass diff --git a/lightllm/platform/backends/musa.py b/lightllm/platform/backends/musa.py new file mode 100644 index 0000000000..a01f94d6a0 --- /dev/null +++ b/lightllm/platform/backends/musa.py @@ -0,0 +1,7 @@ +from lightllm.platform.backends.cuda import CudaBackend +from lightllm.platform.base.registry import register_platform + + +@register_platform("musa", op_fallback=("cuda_like",)) +class MusaBackend(CudaBackend): + pass diff --git a/lightllm/platform/base/graph.py b/lightllm/platform/base/graph.py new file mode 100644 index 0000000000..e62c8fc8b7 --- /dev/null +++ b/lightllm/platform/base/graph.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from typing import Any, ContextManager, Optional + + +class BackendGraph(ABC): + + @abstractmethod + def create_graph(self) -> Any: + pass + + @abstractmethod + def graph(self, graph_obj: Any, pool: Optional[Any] = None, stream: Optional[Any] = None) -> ContextManager: + pass + + def replay_graph(self, graph_obj: Any) -> Any: + graph_obj.replay() + + @abstractmethod + def graph_pool_handle(self) -> Any: + pass + + @abstractmethod + def is_capturing(self) -> bool: + pass diff --git a/lightllm/platform/base/ops/__init__.py b/lightllm/platform/base/ops/__init__.py new file mode 100644 index 0000000000..5e2c6e6296 --- /dev/null +++ b/lightllm/platform/base/ops/__init__.py @@ -0,0 +1,4 @@ +from lightllm.platform.base.ops.base import OpsProtocol +from lightllm.platform.base.ops.runtime import build_ops, register_op + +__all__ = ["OpsProtocol", "build_ops", "register_op"] diff --git a/lightllm/platform/base/ops/base.py b/lightllm/platform/base/ops/base.py new file mode 100644 index 0000000000..d1c49d940e --- 
import inspect
import torch
from typing import Any, Callable, Optional, Protocol, Tuple, runtime_checkable


def get_protocol_op_names(protocol: type) -> tuple[str, ...]:
    """ Get the names of the public methods in the protocol. """
    # Only functions defined directly on the class are listed (class __dict__ order).
    return tuple(
        name
        for name, value in protocol.__dict__.items()
        if not name.startswith("_") and inspect.isfunction(value)
    )


@runtime_checkable
class OpsProtocol(Protocol):
    """Structural interface listing the operator entry points of a platform's op set.

    Each method is a signature-only stub (``...``); the implementations are
    supplied elsewhere.
    """

    def multimodal_emb(
        self,
        *,
        out: torch.Tensor,
        prompt_ids: torch.Tensor,
        text_weight_embs: torch.Tensor,
        embed_cache: torch.Tensor,
        img_token_lens: torch.Tensor,
        img_start_token_ids: torch.Tensor,
        img_start_locs_in_cache: torch.Tensor,
        tp_text_start_token_id: int,
        tp_text_end_token_id: int,
        tp_world_size: int,
    ) -> None: ...

    def offload_embed_tensor_to_cache(
        self,
        *,
        embed_tensor: torch.Tensor,
        cache_tensor: torch.Tensor,
        start_index_in_cache: int,
    ) -> None: ...

    def rotary_emb(
        self,
        *,
        is_prefill: bool,
        batch_size: int,
        q: torch.Tensor,
        k: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        partial_rotary_factor: float = 1.0,
        rotary_impl: Optional[Callable] = None,
    ) -> None: ...

    def ffn(
        self,
        *,
        input: torch.Tensor,
        layer_weight: Any,
        alloc_func: Callable,
        embed_dim: int,
    ) -> torch.Tensor: ...

    def embedding(
        self,
        *,
        input_ids: torch.Tensor,
        weight: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        alloc_func: Callable = torch.empty,
        vob_start_id: int = 0,
        vob_end_id: Optional[int] = None,
    ) -> torch.Tensor: ...

    def lm_head(
        self,
        *,
        input: torch.Tensor,
        weight: torch.Tensor,
        out: Optional[torch.Tensor] = None,
        alloc_func: Callable = torch.empty,
    ) -> torch.Tensor: ...

    def rms_norm(
        self,
        *,
        input: torch.Tensor,
        weight: torch.Tensor,
        eps: float,
        out: Optional[torch.Tensor] = None,
        alloc_func: Callable = torch.empty,
        gate_value: Optional[torch.Tensor] = None,
    ) -> torch.Tensor: ...

    def layer_norm(
        self,
        *,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        eps: float,
        out: Optional[torch.Tensor] = None,
        alloc_func: Callable = torch.empty,
    ) -> torch.Tensor: ...

    # NOTE(review): the only op declared without the keyword-only `*` marker;
    # confirm whether positional calls are intended before aligning it.
    def qk_rms_norm(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        w_q: torch.Tensor,
        w_k: torch.Tensor,
        eps: float,
        fp32_multiply: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]: ...


# op names for ops, the order is the same as the declaration order in the protocol
OP_NAMES: tuple[str, ...] = get_protocol_op_names(OpsProtocol)
import torch
from typing import Any, Callable, Optional, Tuple, TypedDict, Union

# (int, int, ...) or (("tensor_name", dim_index), ...)
OutShapeSpec = Union[Tuple[int, ...], Tuple[Tuple[str, int], ...]]
# torch.dtype or "tensor_name"
OutDtypeSpec = Union[torch.dtype, str]
# torch.device or "tensor_name"
OutDeviceSpec = Union[torch.device, str]


class AutoOutSpec(TypedDict, total=False):
    """Declarative description of how an op's ``out`` tensor is derived from its kwargs."""

    # kwarg name of the tensor whose shape/dtype/device are the defaults
    input_name: str
    out_shape: OutShapeSpec
    out_dtype: OutDtypeSpec
    out_device: OutDeviceSpec


def _devices_match(actual: torch.device, expected: torch.device) -> bool:
    """Compare devices, treating a missing index as "any index of that type".

    ``torch.device`` equality is index-sensitive, so ``"cuda"`` != ``"cuda:0"``
    and ``"cpu:0"`` != ``"cpu"``; only compare indices when both are set.
    """
    if actual.type != expected.type:
        return False
    if actual.index is None or expected.index is None:
        return True
    return actual.index == expected.index


def ensure_out(
    out: Optional[torch.Tensor],
    *,
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    alloc_func: Callable = torch.empty,
    contiguous: bool = True,
) -> torch.Tensor:
    """Return a valid output buffer: allocate one, or validate the caller's.

    Args:
        out: Caller-supplied buffer, or ``None`` to allocate.
        shape: Required shape.
        dtype: Required dtype.
        device: Required device (string or ``torch.device``).
        alloc_func: Called as ``alloc_func(shape, dtype=..., device=...)``
            when ``out`` is ``None``.
        contiguous: When true, a supplied ``out`` must be contiguous.

    Returns:
        The freshly allocated tensor, or ``out`` itself once validated.

    Raises:
        ValueError: If a supplied ``out`` has the wrong shape, dtype or device,
            or is non-contiguous while ``contiguous`` is true.
    """
    if out is None:
        return alloc_func(shape, dtype=dtype, device=device)

    if tuple(out.shape) != tuple(shape):
        raise ValueError(f"out.shape {tuple(out.shape)} != expected {tuple(shape)}")
    if out.dtype != dtype:
        raise ValueError(f"out.dtype {out.dtype} != expected {dtype}")
    if not _devices_match(out.device, torch.device(device)):
        raise ValueError(f"out.device {out.device} != expected {device}")
    if contiguous and not out.is_contiguous():
        raise ValueError("out must be contiguous")
    return out


def _is_literal_shape(spec: OutShapeSpec) -> bool:
    """Return True when every entry of ``spec`` is a plain int."""
    return all(isinstance(dim, int) for dim in spec)


def _resolve_shape(spec: OutShapeSpec, kwargs: dict) -> tuple[int, ...]:
    """Turn a shape spec into a concrete shape.

    A literal spec is returned as-is; otherwise each ``(name, dim)`` entry is
    read from ``kwargs[name].shape[dim]``.

    Raises:
        ValueError: If a referenced kwarg is missing or ``dim`` is out of range.
    """
    if _is_literal_shape(spec):
        return tuple(spec)
    dims: list[int] = []
    for name, dim in spec:
        if name not in kwargs:
            raise ValueError(
                f"out_shape references {name!r} but kwargs keys are {list(kwargs.keys())}"
            )
        tensor = kwargs[name]
        try:
            dims.append(tensor.shape[dim])
        except IndexError as exc:
            raise ValueError(
                f"out_shape ({name!r}, {dim}) is invalid for tensor shape {tuple(tensor.shape)}"
            ) from exc
    return tuple(dims)


def _resolve_dtype(spec: OutDtypeSpec, kwargs: dict) -> torch.dtype:
    """Return ``spec`` if it is a dtype, else the dtype of ``kwargs[spec]``.

    Raises:
        KeyError: If ``spec`` names a kwarg that was not passed.
    """
    if isinstance(spec, torch.dtype):
        return spec
    try:
        return kwargs[spec].dtype
    except KeyError:
        # Same exception type as a bare lookup, but with actionable context.
        raise KeyError(
            f"out_dtype references {spec!r} but kwargs keys are {list(kwargs.keys())}"
        ) from None


def _resolve_device(spec: OutDeviceSpec, kwargs: dict) -> torch.device:
    """Resolve a device spec.

    A ``torch.device`` is returned unchanged; a string naming a kwarg yields
    that kwarg's device; any other string is parsed by ``torch.device``.
    """
    if isinstance(spec, torch.device):
        return spec
    if spec in kwargs:
        return kwargs[spec].device
    return torch.device(spec)


def _is_out_fully_specified(config: AutoOutSpec) -> bool:
    """Return True when shape, dtype and device are all given explicitly."""
    return (
        config.get("out_shape") is not None
        and config.get("out_dtype") is not None
        and config.get("out_device") is not None
    )


def _get_base_tensor(config: AutoOutSpec, kwargs: dict) -> tuple[torch.Tensor, str]:
    """Fetch the tensor named by ``config["input_name"]`` from ``kwargs``.

    Returns:
        ``(tensor, input_name)``.

    Raises:
        ValueError: If the kwarg is missing or is not a ``torch.Tensor``.
    """
    input_name = config["input_name"]
    if input_name not in kwargs:
        raise ValueError(
            f"input_name '{input_name}' not found in kwargs, available keys: {list(kwargs.keys())}"
        )

    tensor = kwargs[input_name]
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f"kwargs['{input_name}'] must be a torch.Tensor, got {type(tensor)}")
    return tensor, input_name


def _resolve_out_spec(config: AutoOutSpec, kwargs: dict) -> tuple[tuple[int, ...], torch.dtype, torch.device]:
    """Compute the ``(shape, dtype, device)`` the output tensor must have.

    Each of ``out_shape`` / ``out_dtype`` / ``out_device`` is honoured when
    present; anything left out is taken from the ``input_name`` tensor.

    Raises:
        ValueError: If ``input_name`` is needed but absent from ``config``.
    """
    # If out_shape, out_dtype, and out_device are all specified, then input_name is optional
    if _is_out_fully_specified(config):
        shape = _resolve_shape(config["out_shape"], kwargs)
        dtype = _resolve_dtype(config["out_dtype"], kwargs)
        device = _resolve_device(config["out_device"], kwargs)
        return shape, dtype, device
    # If out_shape, out_dtype, and out_device are not all specified, then input_name is required
    if "input_name" not in config:
        raise ValueError(
            "input_name is required when out_shape, out_dtype, and out_device are not all specified"
        )
    base_tensor, _ = _get_base_tensor(config, kwargs)

    # Resolve each property on its own so an explicit out_dtype/out_device is
    # not dropped just because out_shape was left unspecified.
    out_shape = config.get("out_shape")
    dtype_spec = config.get("out_dtype")
    device_spec = config.get("out_device")
    shape = tuple(base_tensor.shape) if out_shape is None else _resolve_shape(out_shape, kwargs)
    dtype = base_tensor.dtype if dtype_spec is None else _resolve_dtype(dtype_spec, kwargs)
    device = base_tensor.device if device_spec is None else _resolve_device(device_spec, kwargs)
    return shape, dtype, device


def wrap_with_out(config: AutoOutSpec, impl: Callable) -> Callable:
    """Wrap ``impl`` so callers may omit ``out``.

    The returned keyword-only callable resolves the required output spec from
    ``config`` and the call's kwargs, allocates or validates ``out`` via
    ``ensure_out``, then calls ``impl(out=out, **kwargs)``. ``alloc_func`` is
    consumed by the wrapper and not forwarded to ``impl``.
    """

    def public(*, out: Optional[torch.Tensor] = None, alloc_func: Callable = torch.empty, **kwargs: Any):
        shape, dtype, device = _resolve_out_spec(config, kwargs)
        out = ensure_out(out, shape=shape, dtype=dtype, device=device, alloc_func=alloc_func)
        return impl(out=out, **kwargs)

    public.__name__ = impl.__name__
    public.__doc__ = impl.__doc__
    return public
OpRegistry() + + +# Helper function to validate tensor name in function parameters +def _require_tensor_param(op_name: str, param_name: str, parameters: dict[str, inspect.Parameter]) -> None: + if param_name not in parameters: + raise ValueError( + f"register_op({op_name!r}): tensor param {param_name!r} " + f"not found in function parameters {list(parameters)}" + ) + + +def _validate_auto_out_spec(op_name: str, out: AutoOutSpec, sig: inspect.Signature) -> None: + parameters = sig.parameters + # if out_shape, out_dtype, and out_device are all specified, then input_name is optional + fully_specified = ( + out.get("out_shape") is not None + and out.get("out_dtype") is not None + and out.get("out_device") is not None + ) + if not fully_specified: + if "input_name" not in out: + raise ValueError( + f"register_op({op_name!r}): 'input_name' is required unless " + "out_shape, out_dtype, and out_device are all specified" + ) + _require_tensor_param(op_name, out["input_name"], parameters) + + out_shape = out.get("out_shape") + if out_shape is not None: + if not isinstance(out_shape, tuple) or not out_shape: + raise ValueError(f"register_op({op_name!r}): 'out_shape' must be a non-empty tuple") + if not all(isinstance(dim, int) for dim in out_shape): + for item in out_shape: + if not isinstance(item, tuple) or len(item) != 2: + raise ValueError( + f"register_op({op_name!r}): invalid out_shape item {item!r}, " + "expected (tensor_name, dim_index)" + ) + name, dim = item + if not isinstance(name, str) or not isinstance(dim, int): + raise ValueError( + f"register_op({op_name!r}): invalid out_shape item {item!r}, " + "expected (tensor_name, dim_index)" + ) + _require_tensor_param(op_name, name, parameters) + + for key in ("out_dtype", "out_device"): + spec = out.get(key) + if isinstance(spec, str): + _require_tensor_param(op_name, spec, parameters) + + +def register_op( + impl_family: str, + *, + name: str | None = None, + out: AutoOutSpec | None = None, +) -> Callable[[F], F]: + 
+ def decorator(fn: F) -> F: + op_name = name or fn.__name__ + if out is not None: + _validate_auto_out_spec(op_name, out, inspect.signature(fn)) + impl: Callable = wrap_with_out(out, fn) + else: + impl = fn + op_registry.register(impl_family, op_name, impl) + return fn + + return decorator + + +def load_ops(fallback_chain: tuple[str, ...], extra_modules: tuple[str, ...] = ()) -> None: + import importlib + + # Load extra modules first + for module_name in extra_modules: + importlib.import_module(module_name) + + for module_name in get_op_modules_for_fallback(fallback_chain): + importlib.import_module(module_name) + + +class OpsView(OpsProtocol): + + def __init__(self, platform: str, *, fallback_chain: tuple[str, ...]) -> None: + self._platform = platform + + for op_name in OP_NAMES: + impl = None + for impl_family in fallback_chain: + impl = op_registry.get(impl_family, op_name) + if impl is not None: + break + if impl is None: + raise KeyError( + f"Op '{op_name}' is not registered for platform '{platform}' " + f"via fallback chain {fallback_chain}" + ) + setattr(self, op_name, impl) + + def __getattr__(self, name: str): + raise AttributeError(f"Op '{name}' is not registered for platform '{self._platform}'") + + +def build_ops(platform: str) -> OpsView: + plugin_config = get_ops_plugin_config() + fallback_chain = resolve_op_fallback(platform, plugin_config) + load_ops(fallback_chain, plugin_config.extra_modules) + _validate_extra_fallback_loaded(platform, plugin_config) + + return OpsView(platform, fallback_chain=fallback_chain) + + +def _validate_extra_fallback_loaded(platform: str, plugin_config) -> None: + if not plugin_config.extra_fallback: + return + + from lightllm.platform.base.registry import get_platform_spec + + missing: list[str] = [] + for family in plugin_config.extra_fallback: + if op_registry.has_impl_family(family): + continue + missing.append(family) + + if not missing: + return + + hints = _format_extra_fallback_hints(missing, plugin_config) + 
raise RuntimeError( + "External op impl families configured in extra_op_fallback did not register " + f"any @register_op implementations after loading: {missing}. " + f"Fallback chain for platform {platform!r} was " + f"{resolve_op_fallback(platform, plugin_config)!r}, but these families are empty " + f"so all ops silently fall back to {get_platform_spec(platform).op_fallback!r}. " + f"{hints}" + ) + + +def _format_extra_fallback_hints(missing: list[str], plugin_config) -> str: + hints: list[str] = [] + if plugin_config.extra_modules: + hints.append( + "Check that --extra_op_modules / plugin extra_modules import paths are correct " + "and @register_op impl_family names match extra_op_fallback." + ) + else: + hints.append( + "Add --extra_op_modules (scheme 1) or --extra_op_plugins (scheme 2)." + ) + try: + from importlib.metadata import entry_points as eps_fn + + available = sorted(ep.name for ep in eps_fn(group="lightllm.op_plugins")) + except Exception: + available = [] + if available: + hints.append(f"Installed op plugins: {available}.") + for family in missing: + if family.endswith("_plugin") or family == "example_plugin": + hints.append( + f"Family {family!r} looks like a pip plugin impl name; " + f"you likely want --extra_op_plugins example_op_plugin, " + f"not --extra_op_fallback {family}." + ) + return " ".join(hints) diff --git a/lightllm/platform/base/registry.py b/lightllm/platform/base/registry.py new file mode 100644 index 0000000000..7ed001f8f1 --- /dev/null +++ b/lightllm/platform/base/registry.py @@ -0,0 +1,111 @@ +import importlib.util +from abc import ABC +from dataclasses import dataclass +from typing import TYPE_CHECKING, Type + +from lightllm.platform.base.graph import BackendGraph +from lightllm.platform.base.runtime import BackendRuntime + +if TYPE_CHECKING: + from lightllm.platform.base.ops.base import OpsProtocol + +# For internal module registration, we use the prefix "lightllm.platform.ops." 
+OP_FAMILY_MODULES_PREFIX = "lightllm.platform.ops." + +PLATFORMS: dict[str, "PlatformSpec"] = {} + +_platforms_loaded = False + + +@dataclass(frozen=True) +class PlatformSpec: + name: str + backend_cls: Type["Backend"] + op_fallback: tuple[str, ...] + + +class Backend(ABC): + platform_name: str + _runtime: BackendRuntime + _graph: BackendGraph + _ops: "OpsProtocol" + + @property + def name(self) -> str: + return self.platform_name + + @property + def runtime(self) -> BackendRuntime: + return self._runtime + + @property + def graph(self) -> BackendGraph: + return self._graph + + @property + def ops(self) -> "OpsProtocol": + return self._ops + + +def register_platform(name: str, *, op_fallback: tuple[str, ...]): + def decorator(backend_cls: Type[Backend]) -> Type[Backend]: + if name in PLATFORMS: + raise ValueError(f"Platform already registered: {name}") + # set platform name to the backend class + backend_cls.platform_name = name + PLATFORMS[name] = PlatformSpec( + name=name, + backend_cls=backend_cls, + op_fallback=op_fallback, + ) + return backend_cls + + return decorator + + +def _ensure_platforms_registered() -> None: + global _platforms_loaded + if _platforms_loaded: + return + _platforms_loaded = True + + import importlib + import pkgutil + + import lightllm.platform.backends as backends_pkg + + for module_info in pkgutil.iter_modules(backends_pkg.__path__): + if module_info.name.startswith("_"): + continue + # To import the backend module to avoid adding it to __init__.py manually + importlib.import_module(f"{backends_pkg.__name__}.{module_info.name}") + + +def get_platform_spec(platform: str) -> PlatformSpec: + _ensure_platforms_registered() + try: + return PLATFORMS[platform] + except KeyError as exc: + raise RuntimeError(f"Platform is not configured: {platform}") from exc + + +def has_builtin_ops_module(family: str) -> bool: + module_name = f"{OP_FAMILY_MODULES_PREFIX}{family}" + return importlib.util.find_spec(module_name) is not None + + +def 
get_op_modules_for_fallback(fallback_chain: tuple[str, ...]) -> tuple[str, ...]: + modules: list[str] = [] + seen: set[str] = set() + for family in fallback_chain: + module_name = f"{OP_FAMILY_MODULES_PREFIX}{family}" + if module_name in seen: + continue + # External modules are not required to be imported by lightllm.platform.ops + if not has_builtin_ops_module(family): + continue + + modules.append(module_name) + seen.add(module_name) + + return tuple(modules) diff --git a/lightllm/platform/base/runtime.py b/lightllm/platform/base/runtime.py new file mode 100644 index 0000000000..e4991f554b --- /dev/null +++ b/lightllm/platform/base/runtime.py @@ -0,0 +1,120 @@ +import torch +import torch.distributed as dist +from abc import ABC, abstractmethod +from typing import Any, Optional, ContextManager, Tuple, Union + + +class BackendRuntime(ABC): + + @property + @abstractmethod + def device_type(self) -> str: + pass + + @property + @abstractmethod + def dist_backend(self) -> str: + pass + + @property + def dist_init_passes_device_id(self) -> bool: + return True + + def init_process_group( + self, + *, + host: str, + port: int, + rank: int, + world_size: int, + device_id: int, + ) -> torch.device: + target_device = self.target_device(device_id) + self.set_device(target_device) + kwargs = dict( + backend=self.dist_backend, + init_method=f"tcp://{host}:{port}", + rank=rank, + world_size=world_size, + ) + if self.dist_init_passes_device_id: + kwargs["device_id"] = target_device + dist.init_process_group(**kwargs) + return target_device + + @abstractmethod + def mem_get_info(self, device: Union[int, torch.device]) -> Tuple[int, int]: + pass + + @abstractmethod + def get_device_properties(self, device: Union[int, torch.device]) -> Any: + pass + + def target_device(self, device_id: Optional[int] = None) -> torch.device: + if device_id is None: + device_id = self.current_device() + return torch.device(self.device_type, device_id) + + @abstractmethod + def device_count(self) -> 
int: + pass + + @abstractmethod + def is_available(self) -> bool: + pass + + @abstractmethod + def current_device(self) -> int: + pass + + @abstractmethod + def get_device_name(self, device_id: Optional[int] = None) -> str: + pass + + def _parse(self, device: Union[int, str, torch.device]) -> torch.device: + if isinstance(device, torch.device): + _device = device + elif isinstance(device, int): + _device = torch.device(self.device_type, device) + elif isinstance(device, str): + _device = torch.device(device) + else: + raise ValueError(f"Invalid device: {device}") + + if _device.type != self.device_type: + raise ValueError( + f"Expected device type {self.device_type!r}, got {_device.type!r} ({_device})" + ) + return _device + + @abstractmethod + def set_device(self, device: Union[int, str, torch.device]) -> None: + pass + + @abstractmethod + def create_stream(self, **kwargs) -> Any: + pass + + @abstractmethod + def stream(self, stream: Optional[Any] = None) -> ContextManager: + pass + + @abstractmethod + def current_stream(self, device_id: Optional[int] = None) -> Any: + pass + + @abstractmethod + def create_event(self, **kwargs) -> torch.Event: + pass + + @abstractmethod + def synchronize(self) -> None: + pass + + @abstractmethod + def empty_cache(self) -> None: + pass + + @abstractmethod + def manual_seed_all(self, seed: int) -> None: + pass diff --git a/lightllm/platform/graph/ascend.py b/lightllm/platform/graph/ascend.py new file mode 100644 index 0000000000..5654deabe7 --- /dev/null +++ b/lightllm/platform/graph/ascend.py @@ -0,0 +1,18 @@ +import torch +from typing import Any, ContextManager, Optional +from lightllm.platform.base.graph import BackendGraph + + +class AscendGraphBackend(BackendGraph): + + def create_graph(self) -> Any: + return torch.npu.NPUGraph() + + def graph(self, graph_obj: Any, pool: Optional[Any] = None, stream: Optional[Any] = None) -> ContextManager: + return torch.npu.graph(graph_obj, pool=pool, stream=stream) + + def 
graph_pool_handle(self) -> Any: + return torch.npu.graph_pool_handle() + + def is_capturing(self) -> bool: + return torch.npu.is_current_stream_capturing() \ No newline at end of file diff --git a/lightllm/platform/graph/cuda.py b/lightllm/platform/graph/cuda.py new file mode 100644 index 0000000000..6bea7f1456 --- /dev/null +++ b/lightllm/platform/graph/cuda.py @@ -0,0 +1,18 @@ +import torch +from typing import Any, ContextManager, Optional +from lightllm.platform.base.graph import BackendGraph + + +class CudaGraphBackend(BackendGraph): + + def create_graph(self) -> Any: + return torch.cuda.CUDAGraph() + + def graph(self, graph_obj: Any, pool: Optional[Any] = None, stream: Optional[Any] = None) -> ContextManager: + return torch.cuda.graph(graph_obj, pool=pool, stream=stream) + + def graph_pool_handle(self) -> Any: + return torch.cuda.graph_pool_handle() + + def is_capturing(self) -> bool: + return torch.cuda.is_current_stream_capturing() diff --git a/lightllm/platform/ops/ascend.py b/lightllm/platform/ops/ascend.py new file mode 100644 index 0000000000..62ca4f50f4 --- /dev/null +++ b/lightllm/platform/ops/ascend.py @@ -0,0 +1,175 @@ +import torch +from typing import Any, Callable, Optional, Tuple + +from lightllm.common.basemodel.triton_kernel.embedding import npu_embedding +from lightllm.common.basemodel.triton_kernel.multimodal_emb import npu_multimodal_emb +from lightllm.models.llama.layer_infer.transformer_layer_infer import npu_ffn_fwd +from lightllm.models.llama.triton_kernel.rotary_emb import npu_rotary_emb_fwd +from lightllm.platform.base.ops import register_op +from lightllm.server.embed_cache.copy_to_cache import npu_offload_embed_tensor_to_cache + + +@register_op("ascend") +def multimodal_emb( + *, + out: torch.Tensor, + prompt_ids: torch.Tensor, + text_weight_embs: torch.Tensor, + embed_cache: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs_in_cache: torch.Tensor, + tp_text_start_token_id: int, + 
tp_text_end_token_id: int, + tp_world_size: int, +) -> None: + npu_multimodal_emb( + out=out, + prompt_ids=prompt_ids, + text_weight_embs=text_weight_embs, + embed_cache=embed_cache, + img_token_lens=img_token_lens, + img_start_token_ids=img_start_token_ids, + img_start_locs_in_cache=img_start_locs_in_cache, + tp_text_start_token_id=tp_text_start_token_id, + tp_text_end_token_id=tp_text_end_token_id, + tp_world_size=tp_world_size, + ) + + +@register_op("ascend") +def offload_embed_tensor_to_cache( + *, + embed_tensor: torch.Tensor, + cache_tensor: torch.Tensor, + start_index_in_cache: int, +) -> None: + npu_offload_embed_tensor_to_cache( + embed_tensor=embed_tensor, + cache_tensor=cache_tensor, + start_index_in_cache=start_index_in_cache, + ) + + +@register_op("ascend") +def rotary_emb( + *, + is_prefill: bool, + batch_size: int, + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + partial_rotary_factor: float = 1.0, + rotary_impl: Optional[Callable] = None, +) -> None: + impl = rotary_impl or npu_rotary_emb_fwd + impl( + is_prefill=is_prefill, + batch_size=batch_size, + q=q, + k=k, + cos=cos, + sin=sin, + partial_rotary_factor=partial_rotary_factor, + ) + + +@register_op("ascend") +def ffn( + *, + input: torch.Tensor, + layer_weight: Any, + alloc_func: Callable, + embed_dim: int, +) -> torch.Tensor: + return npu_ffn_fwd(input, layer_weight, embed_dim) + + +@register_op( + "ascend", + out={"input_name": "weight","out_shape": (("input_ids", 0), ("weight", 1))}, +) +def embedding( + *, + input_ids: torch.Tensor, + weight: torch.Tensor, + out: torch.Tensor, + vob_start_id: int, + vob_end_id: Optional[int] = None, +) -> torch.Tensor: + if vob_end_id is None: + vob_end_id = weight.shape[0] + npu_embedding(input_ids, weight, vob_start_id, vob_end_id, out) + return out + + +@register_op( + "ascend", + out={"input_name": "input", "out_shape": (("weight", 0), ("input", 1))}, +) +def lm_head( + *, + input: torch.Tensor, + weight: torch.Tensor, + 
out: torch.Tensor, +) -> torch.Tensor: + torch.mm(weight, input, out=out) + return out + + +@register_op("ascend", out={"input_name": "input"}) +def rms_norm( + *, + input: torch.Tensor, + weight: torch.Tensor, + eps: float, + out: torch.Tensor, + gate_value: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if gate_value is not None: + raise NotImplementedError("gate_value is not supported for rms_norm on ascend") + + import torch_npu + + _out = torch_npu.npu_rms_norm(input, weight, epsilon=eps)[0] + out.copy_(_out) + return out + + +@register_op("ascend", out={"input_name": "input"}) +def layer_norm( + *, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + out: torch.Tensor, +) -> torch.Tensor: + raise NotImplementedError("layer_norm is not supported on ascend") + + +@register_op("ascend") +def qk_rms_norm( + *, + q: torch.Tensor, + k: torch.Tensor, + w_q: torch.Tensor, + w_k: torch.Tensor, + eps: float, + fp32_multiply: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + import torch_npu + + head_dim_q = w_q.shape[0] + head_dim_k = w_k.shape[0] + flat_q = q.reshape(-1, head_dim_q) + flat_k = k.reshape(-1, head_dim_k) + _q = torch_npu.npu_rms_norm(flat_q, w_q, epsilon=eps)[0] + _k = torch_npu.npu_rms_norm(flat_k, w_k, epsilon=eps)[0] + _q = _q.view(q.shape) + _k = _k.view(k.shape) + q.copy_(_q) + k.copy_(_k) + + return q, k diff --git a/lightllm/platform/ops/cuda_like.py b/lightllm/platform/ops/cuda_like.py new file mode 100644 index 0000000000..51a1bfdb97 --- /dev/null +++ b/lightllm/platform/ops/cuda_like.py @@ -0,0 +1,173 @@ +import torch +from typing import Any, Callable, Optional, Tuple + +from lightllm.common.basemodel.triton_kernel.embedding import embedding as cuda_embedding +from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb as cuda_multimodal_emb +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from 
lightllm.common.basemodel.triton_kernel.norm.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.platform.base.ops import register_op +from lightllm.server.embed_cache.copy_to_cache import ( + offload_embed_tensor_to_cache as cuda_offload_embed_tensor_to_cache, +) + + +@register_op("cuda_like") +def multimodal_emb( + *, + out: torch.Tensor, + prompt_ids: torch.Tensor, + text_weight_embs: torch.Tensor, + embed_cache: torch.Tensor, + img_token_lens: torch.Tensor, + img_start_token_ids: torch.Tensor, + img_start_locs_in_cache: torch.Tensor, + tp_text_start_token_id: int, + tp_text_end_token_id: int, + tp_world_size: int, +) -> None: + cuda_multimodal_emb( + out=out, + prompt_ids=prompt_ids, + text_weight_embs=text_weight_embs, + embed_cache=embed_cache, + img_token_lens=img_token_lens, + img_start_token_ids=img_start_token_ids, + img_start_locs_in_cache=img_start_locs_in_cache, + tp_text_start_token_id=tp_text_start_token_id, + tp_text_end_token_id=tp_text_end_token_id, + tp_world_size=tp_world_size, + ) + + +@register_op("cuda_like") +def offload_embed_tensor_to_cache( + *, + embed_tensor: torch.Tensor, + cache_tensor: torch.Tensor, + start_index_in_cache: int, +) -> None: + cuda_offload_embed_tensor_to_cache( + embed_tensor=embed_tensor, + cache_tensor=cache_tensor, + start_index_in_cache=start_index_in_cache, + ) + + +@register_op("cuda_like") +def rotary_emb( + *, + is_prefill: bool, + batch_size: int, + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + partial_rotary_factor: float = 1.0, + rotary_impl: Optional[Callable] = None, +) -> None: + impl = rotary_impl or rotary_emb_fwd + impl(q=q, k=k, 
cos=cos, sin=sin, partial_rotary_factor=partial_rotary_factor) + + +@register_op("cuda_like") +def ffn( + *, + input: torch.Tensor, + layer_weight: Any, + alloc_func: Callable, + embed_dim: int, +) -> torch.Tensor: + input = input.view(-1, embed_dim) + up_gate_out = layer_weight.gate_up_proj.mm(input) + ffn1_out = alloc_func( + (input.size(0), up_gate_out.size(1) // 2), + dtype=input.dtype, + device=input.device, + ) + silu_and_mul_fwd(up_gate_out, ffn1_out) + return layer_weight.down_proj.mm(ffn1_out) + + +@register_op( + "cuda_like", + out={"input_name": "weight", "out_shape": (("input_ids", 0), ("weight", 1))}, +) +def embedding( + *, + input_ids: torch.Tensor, + weight: torch.Tensor, + out: torch.Tensor, + vob_start_id: int, + vob_end_id: Optional[int] = None, +) -> torch.Tensor: + if vob_end_id is None: + vob_end_id = weight.shape[0] + cuda_embedding( + input_ids=input_ids, + weight=weight, + vob_start_id=vob_start_id, + vob_end_id=vob_end_id, + out=out, + ) + return out + + +@register_op( + "cuda_like", + out={"input_name": "input", "out_shape": (("weight", 0), ("input", 1))}, +) +def lm_head( + *, + input: torch.Tensor, + weight: torch.Tensor, + out: torch.Tensor, +) -> torch.Tensor: + torch.mm(weight, input, out=out) + return out + + +@register_op("cuda_like", out={"input_name": "input"}) +def rms_norm( + *, + input: torch.Tensor, + weight: torch.Tensor, + eps: float, + out: torch.Tensor, + gate_value: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if gate_value is None: + rmsnorm_forward(x=input, weight=weight, eps=eps, out=out) + else: + gated_rmsnorm_forward(x=input, weight=weight, bias=None, eps=eps, z=gate_value, out=out) + return out + + +@register_op("cuda_like", out={"input_name": "input"}) +def layer_norm( + *, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + out: torch.Tensor, +) -> torch.Tensor: + out_ = layernorm_forward(x=input, weight=weight, bias=bias, eps=eps) + out.copy_(out_) + return out + + 
+@register_op("cuda_like") +def qk_rms_norm( + *, + q: torch.Tensor, + k: torch.Tensor, + w_q: torch.Tensor, + w_k: torch.Tensor, + eps: float, + fp32_multiply: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + return qk_rmsnorm_fused_forward(q=q, k=k, w_q=w_q, w_k=w_k, eps=eps, fp32_multiply=fp32_multiply) diff --git a/lightllm/platform/plugins.py b/lightllm/platform/plugins.py new file mode 100644 index 0000000000..1615b55fe7 --- /dev/null +++ b/lightllm/platform/plugins.py @@ -0,0 +1,214 @@ +from dataclasses import dataclass +from importlib.metadata import entry_points +from typing import Any, Callable, Iterable +from lightllm.utils.envs_utils import get_env_start_args + +OP_PLUGIN_ENTRY_GROUP = "lightllm.op_plugins" + + +@dataclass(frozen=True) +class OpsPluginConfig: + extra_fallback: tuple[str, ...] = () + extra_modules: tuple[str, ...] = () + + +_ops_plugin_config: OpsPluginConfig | None = None + + +def _parse_csv(value: str | None) -> tuple[str, ...]: + """ Parse a comma-separated string into a tuple of strings. 
""" + if not value: + return () + return tuple(item.strip() for item in value.split(",") if item.strip()) + + +def _normalize_tuple(values: Iterable[str] | None) -> tuple[str, ...]: + if not values: + return () + return tuple(value.strip() for value in values if value and value.strip()) + + +def merge_ops_plugin_configs(configs: Iterable[OpsPluginConfig]) -> OpsPluginConfig: + extra_fallback: list[str] = [] + extra_modules: list[str] = [] + seen_fallback: set[str] = set() + seen_modules: set[str] = set() + for config in configs: + for family in config.extra_fallback: + if family not in seen_fallback: + seen_fallback.add(family) + extra_fallback.append(family) + for module_name in config.extra_modules: + if module_name not in seen_modules: + seen_modules.add(module_name) + extra_modules.append(module_name) + + return OpsPluginConfig( + extra_fallback=tuple(extra_fallback), + extra_modules=tuple(extra_modules), + ) + + +def _coerce_plugin_config(value: Any) -> OpsPluginConfig: + if value is None: + return OpsPluginConfig() + if isinstance(value, OpsPluginConfig): + return value + if isinstance(value, dict): + return OpsPluginConfig( + extra_fallback=_normalize_tuple(value.get("extra_fallback")), + extra_modules=_normalize_tuple(value.get("extra_modules")), + ) + raise TypeError(f"Unsupported op plugin config type: {type(value)!r}") + + +def _iter_op_plugin_entry_points(): + eps = entry_points() + # For Python 3.10+ + if hasattr(eps, "select"): + yield from eps.select(group=OP_PLUGIN_ENTRY_GROUP) + return + # For Python 3.9 and below + yield from eps.get(OP_PLUGIN_ENTRY_GROUP, []) + + +def _load_entry_point_plugins(plugin_names: tuple[str, ...]) -> list[OpsPluginConfig]: + if not plugin_names: + return [] + + selected = set(plugin_names) + configs: list[OpsPluginConfig] = [] + loaded_names: set[str] = set() + + for entry_point in _iter_op_plugin_entry_points(): + if entry_point.name not in selected: + continue + # Load the plugin + register_fn: Callable[[], Any] = 
entry_point.load() + configs.append(_coerce_plugin_config(register_fn())) + loaded_names.add(entry_point.name) + + # Check if any plugins are missings + missing = selected - loaded_names + if missing: + available = sorted(entry_point.name for entry_point in _iter_op_plugin_entry_points()) + message = ( + f"Op plugin(s) not found in entry point group {OP_PLUGIN_ENTRY_GROUP!r}: " + f"{sorted(missing)}" + ) + if available: + message += f". Installed plugins: {available}" + else: + message += ( + ". No op plugins installed; register entry points in group " + f"{OP_PLUGIN_ENTRY_GROUP!r} and pip install -e your plugin package." + ) + raise RuntimeError(message) + + return configs + + +def _list_installed_op_plugin_names() -> tuple[str, ...]: + return tuple(sorted(entry_point.name for entry_point in _iter_op_plugin_entry_points())) + + +def _validate_direct_ops_config(config: OpsPluginConfig) -> None: + from lightllm.platform.base.registry import has_builtin_ops_module + + if config.extra_modules and not config.extra_fallback: + raise RuntimeError( + "--extra_op_modules requires --extra_op_fallback: external modules must " + "@register_op under impl family names listed in extra_op_fallback." + ) + + external_fallbacks = [ + family for family in config.extra_fallback if not has_builtin_ops_module(family) + ] + if not external_fallbacks: + return + + if config.extra_modules: + return + + hints: list[str] = [ + "External impl families need modules that call @register_op. " + "Use --extra_op_modules (scheme 1) or --extra_op_plugins (scheme 2)." + ] + installed = _list_installed_op_plugin_names() + if installed: + hints.append(f"Installed op plugins: {list(installed)}.") + for family in external_fallbacks: + if family.endswith("_plugin") or family == "example_plugin": + hints.append( + f"For family {family!r}, did you mean --extra_op_plugins example_op_plugin " + f"instead of --extra_op_fallback {family}?" 
+ ) + raise RuntimeError( + f"--extra_op_fallback includes external impl families {external_fallbacks} " + f"without --extra_op_modules; no ops will be loaded for them. " + + " ".join(hints) + ) + + +def _plugin_config_from_cli() -> OpsPluginConfig: + args = get_env_start_args() + return OpsPluginConfig( + extra_fallback=_parse_csv(getattr(args, "extra_op_fallback", None)), + extra_modules=_parse_csv(getattr(args, "extra_op_modules", None)), + ) + + +def _collect_op_plugin_names() -> tuple[str, ...]: + args = get_env_start_args() + return _parse_csv(getattr(args, "extra_op_plugins", None)) + + +def configure_op_plugins() -> OpsPluginConfig: + global _ops_plugin_config + + # Collect plugin names from CLI + plugin_names = _collect_op_plugin_names() + # Collect plugin config from CLI + direct_config = _plugin_config_from_cli() + # Check if there are plugins or direct config + has_plugins = bool(plugin_names) + has_direct = bool(direct_config.extra_fallback or direct_config.extra_modules) + # Check if both plugins and direct config are present + if has_plugins and has_direct: + raise RuntimeError( + "Op plugin configuration is ambiguous: use either " + "--extra_op_plugins or (--extra_op_fallback / --extra_op_modules), not both." 
+ ) + # Load plugins if present + if has_plugins: + _ops_plugin_config = merge_ops_plugin_configs(_load_entry_point_plugins(plugin_names)) + # Use direct config if present + elif has_direct: + _validate_direct_ops_config(direct_config) + _ops_plugin_config = direct_config + # Use default config if no plugins or direct config are present + else: + _ops_plugin_config = OpsPluginConfig() + + return _ops_plugin_config + + +def get_ops_plugin_config() -> OpsPluginConfig: + if _ops_plugin_config is None: + return OpsPluginConfig() + return _ops_plugin_config + + +def resolve_op_fallback(platform: str, plugin_config: OpsPluginConfig | None = None) -> tuple[str, ...]: + from lightllm.platform.base.registry import get_platform_spec + + config = plugin_config or get_ops_plugin_config() + merged: list[str] = [] + seen: set[str] = set() + for family in config.extra_fallback + get_platform_spec(platform).op_fallback: + if family in seen: + continue + seen.add(family) + merged.append(family) + + return tuple(merged) diff --git a/lightllm/platform/runtime/ascend.py b/lightllm/platform/runtime/ascend.py new file mode 100644 index 0000000000..f94529ea32 --- /dev/null +++ b/lightllm/platform/runtime/ascend.py @@ -0,0 +1,62 @@ +import torch +from typing import Any, ContextManager, Optional, Tuple, Union +from lightllm.platform.base.runtime import BackendRuntime + + +class AscendRuntime(BackendRuntime): + + @property + def device_type(self) -> str: + return "npu" + + @property + def dist_backend(self) -> str: + return "hccl" + + @property + def dist_init_passes_device_id(self) -> bool: + return False + + def mem_get_info(self, device: Union[int, torch.device]) -> Tuple[int, int]: + return torch.npu.mem_get_info(device) + + def get_device_properties(self, device: Union[int, torch.device]) -> Any: + return torch.npu.get_device_properties(device) + + def device_count(self) -> int: + return torch.npu.device_count() + + def is_available(self) -> bool: + return torch.npu.is_available() + + 
class CudaRuntime(BackendRuntime):
    """``BackendRuntime`` implementation that forwards every call to ``torch.cuda``."""

    @property
    def device_type(self) -> str:
        """Device type string of this runtime (``"cuda"``)."""
        return "cuda"

    @property
    def dist_backend(self) -> str:
        """Distributed backend name used with this runtime (``"nccl"``)."""
        return "nccl"

    @property
    def dist_init_passes_device_id(self) -> bool:
        # NOTE(review): presumably tells the distributed-init code whether a
        # device id should be passed along; the consumer is not visible in
        # this file -- confirm before changing the value.
        return False

    def mem_get_info(self, device: Union[int, torch.device]) -> Tuple[int, int]:
        """Return ``torch.cuda.mem_get_info(device)`` unchanged."""
        return torch.cuda.mem_get_info(device)

    def get_device_properties(self, device: Union[int, torch.device]) -> Any:
        """Return ``torch.cuda.get_device_properties(device)`` unchanged."""
        return torch.cuda.get_device_properties(device)

    def device_count(self) -> int:
        """Return ``torch.cuda.device_count()``."""
        return torch.cuda.device_count()

    def is_available(self) -> bool:
        """Return ``torch.cuda.is_available()``."""
        return torch.cuda.is_available()

    def current_device(self) -> int:
        """Return ``torch.cuda.current_device()``."""
        return torch.cuda.current_device()

    def get_device_name(self, device_id: Optional[int] = None) -> str:
        """Return the device name; ``None`` means the current device."""
        device_id = device_id if device_id is not None else self.current_device()
        return torch.cuda.get_device_name(device_id)

    def set_device(self, device: Union[int, str, torch.device]) -> None:
        """Make ``device`` the current CUDA device.

        ``self._parse`` is not defined in this class; it is presumably
        inherited from ``BackendRuntime`` and normalises ``device`` first.
        """
        torch.cuda.set_device(self._parse(device))

    def create_stream(self, **kwargs) -> Any:
        """Create a ``torch.cuda.Stream``, forwarding ``kwargs``."""
        return torch.cuda.Stream(**kwargs)

    def stream(self, stream: Optional[Any] = None) -> ContextManager:
        """Return the ``torch.cuda.stream`` context manager for ``stream``."""
        return torch.cuda.stream(stream)

    def current_stream(self, device_id: Optional[int] = None) -> Any:
        """Return the current stream; ``None`` means the current device."""
        device_id = device_id if device_id is not None else self.current_device()
        return torch.cuda.current_stream(device_id)

    def create_event(self, **kwargs) -> Any:
        """Create a ``torch.cuda.Event``, forwarding ``kwargs``.

        Annotated as ``Any`` like ``create_stream``: a ``torch.Event``
        annotation is evaluated when the class body runs and would raise
        ``AttributeError`` on PyTorch builds that do not expose it.
        """
        return torch.cuda.Event(**kwargs)

    def synchronize(self) -> None:
        """Call ``torch.cuda.synchronize()``."""
        torch.cuda.synchronize()

    def empty_cache(self) -> None:
        """Call ``torch.cuda.empty_cache()``."""
        torch.cuda.empty_cache()

    def manual_seed_all(self, seed: int) -> None:
        """Call ``torch.cuda.manual_seed_all(seed)``."""
        torch.cuda.manual_seed_all(seed)
+ Use with --extra_op_modules for local kernel overrides without a pip plugin package.""", + ) + parser.add_argument( + "--extra_op_modules", + type=str, + default=None, + help="""Comma-separated Python modules to import for external @register_op implementations.""", ) parser.add_argument( "--enable_torch_fallback", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8c6af128c8..6b513c7c8f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -9,7 +9,7 @@ from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger -from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name +from lightllm.utils.envs_utils import get_page_size, set_env_start_args, set_unique_server_name, get_unique_server_name from lightllm.utils.envs_utils import get_lightllm_gunicorn_keep_alive from .detokenization.manager import start_detokenization_process from .router.manager import start_router_process @@ -213,6 +213,21 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # page_size > 1 compatibility check + if get_page_size() > 1: + assert args.run_mode not in ( + "prefill", + "decode", + ), "page_size > 1 is not supported with RPyC PD split mode, please set PAGE_SIZE=1" + assert args.run_mode not in ( + "nixl_prefill", + "nixl_decode", + ), "page_size > 1 is not supported with NIXL PD split mode, please set PAGE_SIZE=1" + assert ( + not args.enable_dp_prefill_balance + ), "page_size > 1 is not supported with DP prefill balance, please set PAGE_SIZE=1" + assert not args.enable_cpu_cache, "page_size > 1 is not supported with CPU cache, please set PAGE_SIZE=1" + if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) diff --git 
a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index 82919856d9..1db6f91fdb 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -10,6 +10,7 @@ from rpyc.utils.classic import obtain from lightllm.models.whisper.whisper_audio import WhisperAudioModel from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_audio import Qwen3OmniMoeAudioEncoder +from lightllm.platform import get_backend from lightllm.server.multimodal_params import AudioItem from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_audio_distributed_env @@ -50,7 +51,7 @@ def exposed_init_model(self, kvargs): raise Exception(f"can not support {self.model_type} now") self.model.load_model(weight_dir, model_cfg) - self.model = self.model.cuda() + self.model.setup_device(device_id=self.device_id) self.model.check_long_audio_infer() self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) @@ -149,7 +150,7 @@ def _infer_worker(self): """ 与 visual _infer_worker 一致:推理后对每个 item 单独放入 store_queue,由 store 线程批处理再 commit。 """ - torch.cuda.set_device(self.device_id) + get_backend().runtime.set_device(self.device_id) while True: try: if self.tp_rank_id == 0: @@ -181,7 +182,7 @@ def _save_to_cpu_cache(self, all_embeds: List[torch.Tensor], audios: List[AudioI self.cpu_embed_cache_client.copy_to_cache( embed_tensor=_emb, start_index_in_cache=audio.start_index_in_embed_cache ) - audio.cuda_event = torch.cuda.Event() + audio.cuda_event = get_backend().runtime.create_event() audio.cuda_event.record() return diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 954daa50fe..a6e039c06c 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -196,3 +196,17 @@ class StartArgs: 
@torch.no_grad()
def npu_offload_embed_tensor_to_cache(
    embed_tensor: torch.Tensor,
    cache_tensor: torch.Tensor,
    start_index_in_cache: int,
):
    """Copy ``embed_tensor`` into ``cache_tensor`` starting at a row offset.

    A 2-D ``embed_tensor`` of shape ``(n, d)`` is first reshaped to
    ``(n, 1, d)``. The tensor is then moved to host memory with ``.cpu()``
    and written into ``cache_tensor[start : start + n]`` via ``copy_``.

    Args:
        embed_tensor: Source embeddings; dimension 0 is the row count.
        cache_tensor: Destination tensor, modified in place.
        start_index_in_cache: First destination row along dimension 0.
    """
    if embed_tensor.dim() == 2:
        rows, width = embed_tensor.shape
        embed_tensor = embed_tensor.reshape(rows, 1, width)

    stop_index = start_index_in_cache + embed_tensor.shape[0]
    host_copy = embed_tensor.cpu()
    destination = cache_tensor[start_index_in_cache:stop_index]
    destination.copy_(host_copy)
a/lightllm/server/embed_cache/embed_cache_client.py +++ b/lightllm/server/embed_cache/embed_cache_client.py @@ -4,8 +4,8 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.embed_utils import calcu_embed_cache_meta from lightllm.common.cpu_cache import CpuCacheCreator, CpuCacheTensorSpec +from lightllm.platform import get_backend from .allocator import MemoryBlock, MemoryManager -from .copy_to_cache import offload_embed_tensor_to_cache logger = init_logger(__name__) @@ -40,6 +40,7 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool, pin_shm: bool = pin=pin_shm, pin_no_blocking=False, ) + self.platform_backend = get_backend() return def alloc_indexes(self, token_num: int) -> Optional["MemoryBlock"]: @@ -50,7 +51,7 @@ def release_indexes(self, block: "MemoryBlock"): return def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int): - offload_embed_tensor_to_cache( + self.platform_backend.ops.offload_embed_tensor_to_cache( embed_tensor=embed_tensor, cache_tensor=self.cpu_embed_cache_tensor, start_index_in_cache=start_index_in_cache, @@ -62,7 +63,7 @@ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: # check for qwen3 vision embed tensor shape, use apply deepstack assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1] - offload_embed_tensor_to_cache( + self.platform_backend.ops.offload_embed_tensor_to_cache( embed_tensor=embed_tensor, cache_tensor=self.cpu_embed_cache_tensor, start_index_in_cache=start_index_in_cache, diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 0000000000..72fab6e274 --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,538 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +import collections +from typing 
class UniqueTimeIdGenerator:
    """Hands out strictly increasing integer ids, one per call."""

    def __init__(self):
        # Last id handed out; the first generate_time_id() call returns 1.
        self.counter = 0

    def generate_time_id(self):
        """Advance ``counter`` by one and return the new value."""
        next_id = self.counter + 1
        self.counter = next_id
        return next_id
def match(t1: torch.Tensor, t2: torch.Tensor) -> int:
    """Return the length of the common prefix of two tensors.

    Both tensors are flattened and compared element-wise over the length of
    the shorter one. The result is the index of the first differing element,
    or that shorter length when no element differs.
    """
    left = t1.flatten()
    right = t2.flatten()
    shared_len = min(left.size(0), right.size(0))

    differing = torch.nonzero(left[:shared_len] != right[:shared_len])
    if differing.numel() > 0:
        # nonzero() lists indices in ascending order, so row 0 is the first mismatch.
        return differing[0].item()
    return shared_len
torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _align_prefix_len(self, prefix_len: int) -> int: + if self.page_size <= 1: + return prefix_len + if prefix_len % self.page_size == 0: + return prefix_len + if self._page_size_is_power_of_2: + return prefix_len & ~self._page_size_mask + return (prefix_len // self.page_size) * self.page_size + + def _get_page_aligned_key(self, key, value=None, free_truncated=False): + aligned_len = len(key) + if aligned_len == 0: + return None, None + if self.page_size > 1 and aligned_len % self.page_size != 0: + aligned_len = self._align_prefix_len(aligned_len) + if free_truncated and aligned_len < len(key) and self.mem_manager is not None and value is not None: + truncated_value = value[aligned_len:] + if len(truncated_value) > 0: + base = truncated_value[0] - truncated_value[0] % self.page_size + full_page = torch.arange( + base, base + self.page_size, dtype=truncated_value.dtype, device=truncated_value.device + ) + self.mem_manager.free(full_page) + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else None, + ) + return key, value + + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + if value is None: + value = key + + assert len(key) == len(value) + key, value = self._get_page_aligned_key(key, value, free_truncated=True) + if key is None: + return 0, None + 
def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]:
    """Insert ``key``/``value`` below ``node`` without recursion.

    Calls ``_insert_helper_no_recursion`` repeatedly. A 4-tuple result
    ``(prefix_len, next_node, rest_key, rest_value)`` means "continue from
    ``next_node`` with the remaining key/value"; any other result is
    unpacked as ``(prefix_len, final_node)`` and ends the descent.

    Returns:
        ``(total_prefix_len, final_node)`` where ``total_prefix_len`` is the
        sum of the ``prefix_len`` values reported by every step.
    """
    pending = (node, key, value)
    visited = []
    total_prefix_len = 0
    final_node = None

    # At most one follow-up step is ever produced per iteration, so a single
    # "pending" slot is enough to drive the descent.
    while pending is not None:
        current, current_key, current_value = pending
        pending = None
        step = self._insert_helper_no_recursion(node=current, key=current_key, value=current_value)
        if len(step) == 4:
            step_prefix_len, next_node, rest_key, rest_value = step
            pending = (next_node, rest_key, rest_value)
        else:
            step_prefix_len, final_node = step
        total_prefix_len += step_prefix_len
        visited.append(current)

    # Refresh timestamps deepest-first; visited nodes that are leaves are
    # (re-)added to the eviction set after their time id changed.
    for touched in reversed(visited):
        touched.update_time()
        if touched.is_leaf():
            self.evict_tree_set.add(touched)

    assert final_node is not None

    return total_prefix_len, final_node
def match_prefix(self, key, update_refs=False):
    """Look up the longest cached prefix of ``key``.

    ``key`` is first passed through ``_get_page_aligned_key``; when that
    yields ``None`` nothing can match.

    Args:
        key: Non-empty token id sequence.
        update_refs: Forwarded to ``_match_prefix_helper``. When the match
            ends at the root node, ``dec_node_ref_counter`` is called on the
            root before returning.

    Returns:
        ``(node, matched_len, value)`` on a hit, where ``value`` is the
        concatenation of the collected value tensors (an empty CPU tensor of
        ``self._value_dtype`` when none were collected); ``(None, 0, None)``
        otherwise.
    """
    assert len(key) != 0
    aligned_key, _unused_value = self._get_page_aligned_key(key)
    if aligned_key is None:
        return None, 0, None

    matched_values = []
    hit_node = self._match_prefix_helper(self.root_node, aligned_key, matched_values, update_refs=update_refs)

    if hit_node != self.root_node:
        if len(matched_values) != 0:
            value = torch.concat(matched_values)
        else:
            value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
        return hit_node, len(value), value

    # The match stopped at the root: nothing usable was found.
    if update_refs:
        self.dec_node_ref_counter(self.root_node)
    return None, 0, None
ans_value_list: list, update_refs=False + ) -> TreeNode: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key)) + + ans_node = None + + while len(handle_stack) != 0: + node, key = handle_stack.popleft() + ans_tuple = self._match_prefix_helper_no_recursion( + node=node, key=key, ans_value_list=ans_value_list, update_refs=update_refs + ) + if isinstance(ans_tuple, tuple): + new_node, new_key = ans_tuple + handle_stack.append((new_node, new_key)) + else: + ans_node = ans_tuple + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + return ans_node + + def _match_prefix_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + if len(key) == 0: + return node + + child_key = node._compute_key(key) + if child_key not in node.children.keys(): + return node + else: + child = node.children[child_key] + prefix_len = match(key, child.token_id_key) + prefix_len = self._align_prefix_len(prefix_len) + if prefix_len == 0: + return node + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return (child, key[prefix_len:]) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if 
def evict(self, need_remove_tokens, evict_callback):
    """Evict nodes from ``evict_tree_set`` until enough tokens are released.

    Nodes are popped from the front of ``self.evict_tree_set`` one at a time.
    Each popped node's ``token_mem_index_value`` is passed to
    ``evict_callback``, the node is detached from its parent, and a parent
    that became a leaf is added to the eviction set.

    Args:
        need_remove_tokens: Minimum number of tokens to release.
        evict_callback: Called with each evicted node's
            ``token_mem_index_value``.

    Raises:
        AssertionError: If fewer than ``need_remove_tokens`` unreferenced
            tokens exist in the tree, or if a popped node is referenced,
            still has children, or is the root node.
    """
    total_tokens = self.tree_total_tokens_num.arr[0]
    refed_tokens = self.refed_tokens_num.arr[0]
    if total_tokens - refed_tokens < need_remove_tokens:
        # Raised explicitly (not via `assert False`) so that the guard is
        # still enforced when Python runs with -O, which strips asserts.
        raise AssertionError(
            f"""can not free tree tokens {need_remove_tokens},
                          tree_total_tokens_num {total_tokens},
                          refed_tokens_num {refed_tokens}"""
        )

    num_evicted = 0
    while num_evicted < need_remove_tokens:
        node = self.evict_tree_set.pop(0)
        evictable = node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node
        if not evictable:
            # Same reasoning as above: evicting a referenced or inner node
            # must never happen silently, even under -O.
            raise AssertionError("error evict tree node state")

        released = len(node.token_mem_index_value)
        num_evicted += released
        evict_callback(node.token_mem_index_value)
        self.tree_total_tokens_num.arr[0] -= released

        parent_node = node.parent
        parent_node.remove_child(node)
        if parent_node.is_leaf():
            self.evict_tree_set.add(parent_node)

    return
def assert_leafs_is_right(self):
    """Debug check over the unreferenced leaves in ``evict_tree_set``.

    For every leaf node with ``ref_counter == 0`` this asserts that
    ``self.mem_manager.mem_state`` equals 1 at each index stored in the
    node's ``token_mem_index_value``.

    Raises:
        AssertionError: If any of those indices has a state other than 1.
    """
    for node in self.evict_tree_set:
        if node.is_leaf() and node.ref_counter == 0:
            # Move the indices to the memory manager's device instead of
            # hard-coding CUDA, matching the same check in radix_cache.py so
            # the method also works on non-CUDA platforms.
            a = node.token_mem_index_value.to(device=self.mem_manager.target_device)
            assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a)
def free_radix_cache_to_get_enough_token(self, need_token_num):
    """Evict cached tokens so ``need_token_num`` tokens can be obtained.

    Nothing happens when ``self.mem_manager.can_use_mem_size`` already covers
    the request. Otherwise the shortfall is evicted from the tree via
    ``self.evict`` and the evicted indices are concatenated and handed to
    ``self.mem_manager.free`` in a single call.
    """
    assert self.mem_manager is not None
    if need_token_num <= self.mem_manager.can_use_mem_size:
        return

    shortfall = need_token_num - self.mem_manager.can_use_mem_size
    evicted_chunks = []
    # evict() reports each evicted node's memory indices through the callback.
    self.evict(shortfall, evicted_chunks.append)
    self.mem_manager.free(torch.concat(evicted_chunks))
    return
b/lightllm/server/router/model_infer/infer_batch.py index 7c19b5748e..31dfdb1013 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -38,8 +38,8 @@ class InferenceContext: vocab_size = None cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None - overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 - cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + overlap_stream: Any = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 + cpu_kv_cache_stream: Any = None # 用 cpu kv cache 操作的 stream is_linear_att_mixed_model: bool = False # 标记模型是否是full att 混合 linear att 的混合模型。 def register( @@ -66,20 +66,22 @@ def register( self.is_linear_att_mixed_model = isinstance(self.req_manager, ReqManagerForMamba) + self.backend_runtime = self.backend.backend_runtime + return def init_cpu_embed_cache_client(self): self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) return - def get_overlap_stream(self) -> torch.cuda.Stream: + def get_overlap_stream(self) -> Any: if self.overlap_stream is None: - self.overlap_stream = torch.cuda.Stream() + self.overlap_stream = self.backend_runtime.create_stream() return self.overlap_stream - def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: + def get_cpu_kv_cache_stream(self) -> Any: if self.cpu_kv_cache_stream is None: - self.cpu_kv_cache_stream = torch.cuda.Stream() + self.cpu_kv_cache_stream = self.backend_runtime.create_stream() return self.cpu_kv_cache_stream def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: @@ -373,7 +375,7 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L assert len(b_req_idx) == len(big_page_buffer_ids) big_page_buffer_ids = torch.tensor(big_page_buffer_ids, dtype=torch.int32, requires_grad=False, device="cpu") - big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True) + 
big_page_buffer_ids = big_page_buffer_ids.to(device=b_req_idx.device, non_blocking=True) from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer @@ -513,6 +515,7 @@ def __init__( self.shm_index = shm_index self.multimodal_params = multimodal_params self.vocab_size = vocab_size + self.last_kv_mem_index = -1 # 请求需要被暂停 self.wait_pause = False @@ -626,6 +629,7 @@ def _match_radix_cache(self): # 从 cpu 到 gpu 是流内阻塞操作 g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 + self.last_kv_mem_index = value_tensor[-1].item() if ready_cache_len > 0 else -1 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 self.shm_req.shm_cur_kv_len = self.cur_kv_len diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ca982ec0f0..a468a5f1ca 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -6,6 +6,7 @@ import torch.distributed as dist from typing import List, Tuple, Callable, Optional from transformers.configuration_utils import PretrainedConfig +from lightllm.platform import get_backend from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model @@ -17,10 +18,11 @@ from lightllm.common.linear_att_cache_manager import LinearAttCacheManager from lightllm.server.router.dynamic_prompt.linear_att_radix_cache import LinearAttPagedRadixCache from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify 
from lightllm.utils.dist_utils import init_distributed_env -from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.envs_utils import get_page_size, get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import AbortedReqCmd, StopStrMatchedReqCmd from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -81,7 +83,10 @@ def __init__(self) -> None: self._radix_tree_merge_counter: int = 0 self._enable_radix_tree_timer_merge: bool = enable_radix_tree_timer_merge() self._radix_tree_merge_update_delta: int = get_radix_tree_merge_update_delta() - pass + + @property + def backend_runtime(self): + return get_backend().runtime def init_model(self, kvargs): self.args: StartArgs = kvargs.get("args", None) @@ -199,11 +204,12 @@ def init_model(self, kvargs): linear_att_small_page_buffers=self.linear_att_cache_manager, ) else: - self.radix_cache = RadixCache( + radix_cache_class = PagedRadixCache if get_page_size() > 1 else RadixCache + self.radix_cache = radix_cache_class( unique_name=get_unique_server_name(), total_token_num=self.model.mem_manager.size, rank_in_node=self.rank_in_node, mem_manager=self.model.mem_manager, ) if "prompt_cache_kv_buffer" in model_cfg: @@ -220,26 +226,31 @@ def init_model(self, kvargs): vocab_size=self.model.vocab_size, ) + device = self.backend_runtime.target_device() + # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 if self.dp_size > 1: - self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - self.dp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) + self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device=device, requires_grad=False) + self.dp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device=device, requires_grad=False) self.dp_all_gather_tensor = torch.tensor( - [0 for _ in 
range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False + [0 for _ in range(self.global_world_size)], + dtype=torch.int32, + device=device, + requires_grad=False, ) # 用于协同读取 ShmObjsIOBuffer 中的请求信息的通信tensor和通信组对象。 - self.node_broadcast_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) - self.node_nccl_group = create_new_group_for_current_node("nccl") + self.node_broadcast_tensor = torch.tensor([0], dtype=torch.int32, device=device, requires_grad=False) + self.node_nccl_group = create_new_group_for_current_node(self.backend_runtime.dist_backend) # 用于在多节点tp模式下协同读取 ShmObjsIOBuffer 中的请求信息的通信tensor和通信组对象。 if self.is_multinode_tp: - self.multinode_tp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda") + self.multinode_tp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device=device) self.multinode_tp_all_gather_tensor = torch.tensor( - [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False + [0 for _ in range(self.global_world_size)], dtype=torch.int32, device=device, requires_grad=False ) self.multinode_tp_nccl_group = dist.new_group( - [rank for rank in range(self.global_world_size)], backend="nccl" + [rank for rank in range(self.global_world_size)], backend=self.backend_runtime.dist_backend ) if ( @@ -279,7 +290,7 @@ def init_dp_kv_shared(self): from lightllm.server.router.model_infer.mode_backend.dp_backend.dp_shared_kv_trans import DPKVSharedMoudle from lightllm.common.kv_cache_mem_manager import MemoryManager - torch.cuda.set_device(get_current_device_id()) + self.backend_runtime.set_device(get_current_device_id()) self.dp_kv_shared_module = DPKVSharedMoudle( max_req_num=self.args.running_max_req_size, @@ -481,7 +492,7 @@ def _read_nixl_trans_io_buffer_and_update_req_status(self): ) # to do 这个地方是否需要加流同步 req_to_next_token_ids[req.req_idx, 0:1].fill_(obj.first_gen_token_id) - torch.cuda.current_stream().synchronize() + 
self.backend_runtime.current_stream().synchronize() InferReqUpdatePack(req_obj=req, output_len=req.cur_output_len).handle( next_token_id=obj.first_gen_token_id, next_token_logprob=obj.first_gen_token_logprob, @@ -807,7 +818,7 @@ def _sample_and_scatter_token( if is_prefill: b_has_out = g_pin_mem_manager.gen_from_list( key="b_has_out", data=b_prefill_has_output_cpu, dtype=torch.bool - ).cuda(non_blocking=True) + ).to(device=logits.device, non_blocking=True) scatter_token( next_token_ids=next_token_ids, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 60045fab6c..3a9bf79bbe 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -51,7 +51,7 @@ def __init__(self) -> None: return def infer_loop(self): - torch.cuda.set_device(get_current_device_id()) + self.backend_runtime.set_device(get_current_device_id()) try: while True: event_pack = self.overlap_event_manager.get_overlap_event_pack() @@ -74,7 +74,7 @@ def infer_loop(self): if run_way.is_prefill(): # 进行一次流同步,保证 _try_read_new_reqs 中的一些算子操作,必然已经完成。 # 防止后续的推理流程读取到显存中可能存在错误的数据。 - g_infer_context.get_overlap_stream().wait_stream(torch.cuda.current_stream()) + g_infer_context.get_overlap_stream().wait_stream(self.backend_runtime.current_stream()) self.prefill( event_pack=event_pack, prefill_reqs=prefill_reqs, @@ -83,7 +83,7 @@ def infer_loop(self): elif run_way.is_decode(): # 进行一次流同步,保证 _try_read_new_reqs 中的一些算子操作,必然已经完成。 # 防止后续的推理流程读取到显存中可能存在错误的数据。 - g_infer_context.get_overlap_stream().wait_stream(torch.cuda.current_stream()) + g_infer_context.get_overlap_stream().wait_stream(self.backend_runtime.current_stream()) self.decode( event_pack=event_pack, decode_reqs=decode_reqs, @@ -107,7 +107,7 @@ def prefill_normal( ): # 第一阶段: 模型推理 model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, 
is_chuncked_mode=not self.disable_chunked_prefill) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, @@ -122,7 +122,7 @@ def prefill_normal( b_req_idx=model_input.b_req_idx, reqs=run_reqs, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() # 第二阶段 @@ -150,7 +150,7 @@ def decode_normal( decode_reqs: List[InferReq], ): model_input, run_reqs = prepare_decode_inputs(decode_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, @@ -160,7 +160,7 @@ def decode_normal( is_prefill=False, mask_func=self.decode_mask_func, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() # 第二阶段 @@ -188,7 +188,7 @@ def prefill_mtp( prefill_reqs: List[InferReq], ): model_input, run_reqs = prepare_prefill_inputs(prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( logits=model_output.logits, @@ -207,7 +207,7 @@ def prefill_mtp( b_req_idx=model_input.b_req_idx, reqs=run_reqs, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() # 第二阶段 @@ -241,7 +241,7 @@ def decode_mtp( """ model_input, run_reqs = prepare_decode_inputs(decode_reqs) - with 
torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): b_mtp_index_cpu = model_input.b_mtp_index model_output = self.model.forward(model_input) next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id) @@ -251,7 +251,7 @@ def decode_mtp( key="b_req_mtp_start_loc", data=b_req_mtp_start_loc, dtype=torch.int32, - ).cuda(non_blocking=True) + ).to(device=next_token_ids.device, non_blocking=True) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, @@ -266,7 +266,7 @@ def decode_mtp( key="mtp_accept_len", gpu_tensor=mtp_accept_len, ) - verify_event = torch.cuda.Event() + verify_event = self.backend_runtime.create_event() verify_event.record() next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem( @@ -287,7 +287,7 @@ def decode_mtp( next_token_ids=next_token_ids, mask=accepted_index == 1, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() # 第二阶段 @@ -389,7 +389,7 @@ def _draft_decode_eagle( g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_reqs * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(num_reqs * self.mtp_step) g_infer_state_lock.release() - eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) + eagle_mem_indexes = eagle_mem_indexes_cpu.to(device=next_token_ids.device, non_blocking=True) # share some inference info with the main model draft_model_input = main_model_input diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index d8de132cac..4e97da85a1 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ 
b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -36,7 +36,8 @@ def init_custom(self): self.sorted_tokens = SortedList( [(token_str, token_id) for token_str, token_id in vob_dict.items()], key=lambda x: x[0] ) - self.token_indexes = torch.tensor([e[1] for e in self.sorted_tokens], dtype=torch.int64, device="cuda") + self.token_indexes = torch.tensor( + [e[1] for e in self.sorted_tokens], dtype=torch.int64, device=self.model.target_device) return def _decode_mask_callback(self, run_reqs: List[InferReq], logits: torch.Tensor): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py index 696452b419..153b01dd66 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py @@ -4,6 +4,8 @@ import time from typing import Dict, List, Tuple, Optional, Union from rpyc.utils.classic import obtain + +from lightllm.platform import get_backend from .decode_impl import DecodeNode from lightllm.common.basemodel.infer_lock import acquire_lock_until_ready, release_acquired_lock, g_router_lock from .decode_task_cache import g_kv_move_task_cache, g_success_kv_move_task_cache @@ -23,7 +25,7 @@ def __init__(self, backend: DecodeNode) -> None: return def on_connect(self, conn): - torch.cuda.set_device(f"cuda:{self.device_id}") + get_backend().runtime.set_device(self.device_id) return def judge_token_is_ok(self, key_len, max_new_token): diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py index 
cdca638873..f5e7baade1 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py @@ -8,6 +8,7 @@ from torch.distributed import TCPStore from datetime import timedelta from typing import List, Dict, Union +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup @@ -43,7 +44,7 @@ def _handle_kvmove_task( move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id] ) logger.info(f"trans finished: {move_tasks[0].to_decode_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() + get_backend().runtime.synchronize() logger.info(f"trans cost time: {(time.time() - start)}, {move_tasks[0].to_decode_log_info()}") task_out_queue.put("ok") except BaseException as e: @@ -71,7 +72,7 @@ def _handle_prefill_join( result_list = [] def async_connect(): - torch.cuda.set_device(node_info.decode_device_id) + get_backend().runtime.set_device(node_info.decode_device_id) group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=False, store=store_client) comm = PyNcclCommunicator(group, node_info.decode_device_id) result_list.append(comm) @@ -107,7 +108,7 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp. 
) try: - torch.cuda.set_device(device_id) + get_backend().runtime.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py index 62609c4c91..b646b713d5 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/p2p_fix.py @@ -12,6 +12,8 @@ from torch.multiprocessing.reductions import storage_from_cache, shared_cache, StorageWeakRef from torch.multiprocessing.reductions import reduce_nested_tensor, reduce_sparse_tensor, rebuild_tensor +from lightllm.platform import get_backend + def p2p_fix_rebuild_cuda_tensor( tensor_cls, @@ -35,7 +37,7 @@ def p2p_fix_rebuild_cuda_tensor( # 得到的指针可能不是接收进程当前上下文设备可以访问的,所以在这里 # hack 修改了使用的 storage_device,这样后续tritonkernel同时 # 访问几张显卡上的数据,进行p2p操作就不会出问题了。 - storage_device = torch.cuda.current_device() + storage_device = get_backend().runtime.current_device() # If storage_handle is None, storage points to nullptr. 
if storage_handle is None or storage_size_bytes == 0: storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py index 1f2dd52c5a..14e0210f49 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_infer_rpyc.py @@ -3,6 +3,7 @@ import rpyc from typing import Dict, List, Tuple from rpyc.utils.classic import obtain +from lightllm.platform import get_backend from .prefill_impl import ChunckedPrefillForPrefillNode from lightllm.common.basemodel.infer_lock import g_router_lock, acquire_lock_until_ready, release_acquired_lock from .prefill_task_cache import g_kv_move_task_cache @@ -21,7 +22,7 @@ def __init__(self, backend: ChunckedPrefillForPrefillNode) -> None: return def on_connect(self, conn): - torch.cuda.set_device(f"cuda:{self.device_id}") + get_backend().runtime.set_device(self.device_id) return # pd 分离模式会使用的一些接口,用于做一些全局信息管理 diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py index a328e3e080..5ad02891fc 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_trans_process.py @@ -8,6 +8,7 @@ from torch.distributed import TCPStore from datetime import timedelta from typing import List, Dict, Union +from lightllm.platform import get_backend from lightllm.utils.log_utils import 
init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import KVMoveTask, PDTransJoinInfo, PDTransLeaveInfo, KVMoveTaskGroup @@ -42,7 +43,7 @@ def _handle_kvmove_task( else: cur_mem.send_to_decode_node(move_tasks, mem_managers, dp_size_in_node, connect_id_to_comm[connect_id]) logger.info(f"trans finished: {move_tasks[0].to_prefill_log_info()} move len: {total_move_kv_len}") - torch.cuda.synchronize() + get_backend().runtime.synchronize() logger.info( f"trans cost time: {(time.time() - start)}," f"move_total_kv_len: {total_move_kv_len}, {move_tasks[0].to_prefill_log_info()}" @@ -67,7 +68,7 @@ def _handle_decode_join( result_list = [] def async_connect(): - torch.cuda.set_device(node_info.prefill_device_id) + get_backend().runtime.set_device(node_info.prefill_device_id) group = StatelessP2PProcessGroup.create(src_id=src_id, dest_id=dest_id, is_server=True, store=store) comm = PyNcclCommunicator(group, node_info.prefill_device_id) result_list.append(comm) @@ -109,7 +110,7 @@ def _init_env( ) try: - torch.cuda.set_device(device_id) + get_backend().runtime.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) master_store = TCPStore( host_name=store_ip, port=store_port, is_master=True, use_libuv=True, timeout=timedelta(seconds=30) diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 5a179cb620..e0d04c3d0a 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -39,7 +39,7 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq group_reqs, is_chuncked_mode=not self.disable_chunked_prefill ) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = 
self.model.forward(model_input) logits = model_output.logits @@ -50,14 +50,14 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq b_req_idx = [req.req_idx for req in run_reqs] b_has_out = [model_input.b_prefill_has_output_cpu[i] for i in batch_idx] - batch_idx = g_pin_mem_manager.gen_from_list(key="batch_idx_", data=batch_idx, dtype=torch.int64).cuda( - non_blocking=True + batch_idx = g_pin_mem_manager.gen_from_list(key="batch_idx_", data=batch_idx, dtype=torch.int64).to( + device=self.backend_runtime.target_device(), non_blocking=True ) - b_req_idx = g_pin_mem_manager.gen_from_list(key="b_req_idx_", data=b_req_idx, dtype=torch.int32).cuda( - non_blocking=True + b_req_idx = g_pin_mem_manager.gen_from_list(key="b_req_idx_", data=b_req_idx, dtype=torch.int32).to( + device=self.backend_runtime.target_device(), non_blocking=True ) - b_has_out = g_pin_mem_manager.gen_from_list(key="b_has_out_", data=b_has_out, dtype=torch.bool).cuda( - non_blocking=True + b_has_out = g_pin_mem_manager.gen_from_list(key="b_has_out_", data=b_has_out, dtype=torch.bool).to( + device=self.backend_runtime.target_device(), non_blocking=True ) logits = logits[batch_idx] @@ -77,7 +77,7 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq next_token_ids=next_token_ids, next_token_logprobs=next_token_logprobs ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() # 第二阶段 @@ -163,7 +163,7 @@ def _diverse_pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: b pack = InferReqUpdatePack(req_obj=req_obj, output_len=pre_master_req_pack.output_len) update_func_objs.append(pack) - torch.cuda.current_stream().synchronize() + self.backend_runtime.current_stream().synchronize() return update_func_objs def _master_req_to_radix_cache(self, master_req: InferReq): diff --git a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py 
b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py index c83e8cd4a5..bfac284717 100644 --- a/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py @@ -85,7 +85,7 @@ def _init_reqs(self, reqs: List[Tuple]): return req_ids def infer_loop(self): - torch.cuda.set_device(get_current_device_id()) + self.backend_runtime.set_device(get_current_device_id()) try: while True: event_pack = self.overlap_event_manager.get_overlap_event_pack() @@ -116,7 +116,7 @@ def infer_loop(self): if run_way.is_prefill(): # 进行一次流同步,保证 _try_read_new_reqs 中的一些算子操作,必然已经完成。 # 防止后续的推理流程读取到显存中可能存在错误的数据。 - g_infer_context.get_overlap_stream().wait_stream(torch.cuda.current_stream()) + g_infer_context.get_overlap_stream().wait_stream(self.backend_runtime.current_stream()) self.prefill( event_pack=event_pack, prefill_reqs=prefill_reqs, @@ -125,7 +125,7 @@ def infer_loop(self): elif run_way.is_decode(): # 进行一次流同步,保证 _try_read_new_reqs 中的一些算子操作,必然已经完成。 # 防止后续的推理流程读取到显存中可能存在错误的数据。 - g_infer_context.get_overlap_stream().wait_stream(torch.cuda.current_stream()) + g_infer_context.get_overlap_stream().wait_stream(self.backend_runtime.current_stream()) self.decode( event_pack=event_pack, decode_reqs=decode_reqs, @@ -149,7 +149,7 @@ def prefill_normal( ): model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs) run_reqs_num = len(run_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( @@ -165,7 +165,7 @@ def prefill_normal( b_req_idx=model_input.b_req_idx[:run_reqs_num], reqs=run_reqs, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if run_reqs_num > 0: @@ -175,7 +175,7 @@ def prefill_normal( # 第三阶段 
event_pack.notify_forward_and_wait_post_handle() - sync_event.synchronize() + sync_event.synchronize() self._post_handle( run_reqs=run_reqs, next_token_ids=next_token_ids_cpu, @@ -196,7 +196,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(req_objs=decode_reqs) model_input: ModelInput = model_input run_reqs_num = len(run_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) if run_reqs_num > 0: _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( @@ -207,7 +207,7 @@ def decode_normal(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq is_prefill=False, mask_func=None, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if run_reqs_num > 0: @@ -244,7 +244,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer _, ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -277,7 +277,7 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer if g_infer_context.is_linear_att_mixed_model: g_infer_context.copy_linear_att_state_to_cache_buffer(b_req_idx=b_req_idx, reqs=run_reqs) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if (req_num0 + req_num1) > 0: @@ -317,7 +317,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe model_input0: ModelInput = model_input0 model_input1: ModelInput = 
model_input1 - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -340,7 +340,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe is_prefill=False, mask_func=None, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if (req_num0 + req_num1) > 0: @@ -371,7 +371,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] # main model prefill model_input, run_reqs, _ = padded_prepare_prefill_inputs(prefill_reqs) req_num = len(run_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output: ModelOutput = self.model.forward(model_input) b_has_out_cpu = model_input.b_prefill_has_output_cpu[0:req_num] logits = model_output.logits[0:req_num, :] @@ -401,7 +401,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq] if req_num > 0: g_infer_context.copy_linear_att_state_to_cache_buffer(b_req_idx=b_req_idx, reqs=run_reqs) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if req_num > 0: @@ -436,7 +436,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): b_mtp_index_cpu = model_input.b_mtp_index req_num = len(run_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) mtp_accept_len, b_req_mtp_start_loc, next_token_ids = None, None, None if req_num > 0: @@ -455,7 +455,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): key="b_req_mtp_start_loc", 
data=b_req_mtp_start_loc, dtype=torch.int32, - ).cuda(non_blocking=True) + ).to(device=self.backend_runtime.target_device(), non_blocking=True) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, @@ -471,7 +471,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): gpu_tensor=mtp_accept_len, ) - verify_event = torch.cuda.Event() + verify_event = self.backend_runtime.create_event() verify_event.record() eagle_mem_indexes_cpu = self._draft_decode_func( @@ -489,7 +489,7 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): mask=accepted_index == 1, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if req_num > 0: @@ -596,7 +596,7 @@ def _draft_decode_eagle( g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) g_infer_state_lock.release() - eagle_mem_indexes = eagle_mem_indexes_cpu.cuda(non_blocking=True) + eagle_mem_indexes = eagle_mem_indexes_cpu.to(device=next_token_ids.device, non_blocking=True) # process the draft model output for _step in range(self.mtp_step): @@ -644,7 +644,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I run_reqs1, _, ) = padded_overlap_prepare_prefill_inputs(prefill_reqs) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_prefill(model_input0, model_input1) logits0 = model_output0.logits logits1 = model_output1.logits @@ -708,7 +708,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I _b_req_idx = torch.cat((model_input0.b_req_idx[0:req_num0], model_input1.b_req_idx[0:req_num1]), dim=0) 
g_infer_context.copy_linear_att_state_to_cache_buffer(b_req_idx=_b_req_idx, reqs=run_reqs) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if req_num0 + req_num1 > 0: @@ -746,7 +746,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf all_next_token_ids = [] b_mtp_index_cpu0 = model_input0.b_mtp_index b_mtp_index_cpu1 = model_input1.b_mtp_index - with torch.cuda.stream(g_infer_context.get_overlap_stream()): + with self.backend_runtime.stream(g_infer_context.get_overlap_stream()): model_output0, model_output1 = self.model.microbatch_overlap_decode(model_input0, model_input1) logits0 = model_output0.logits @@ -771,7 +771,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf key="b_req_mtp_start_loc", data=b_req_mtp_start_loc, dtype=torch.int32, - ).cuda(non_blocking=True) + ).to(device=self.backend_runtime.target_device(), non_blocking=True) mtp_accept_len, accepted_index = self._verify_mtp_v2( new_next_token_ids=next_token_ids, @@ -788,7 +788,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf ) all_next_token_ids.append(next_token_ids) - verify_event = torch.cuda.Event() + verify_event = self.backend_runtime.create_event() verify_event.record() eagle_mem_indexes_cpu = self._draft_decode_overlap_func( @@ -810,7 +810,7 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf next_token_ids=next_token_ids, mask=accepted_index == 1, ) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() if req_num0 + req_num1 > 0: @@ -958,7 +958,7 @@ def _draft_decode_eagle_overlap( g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(real_req_num * self.mtp_step) eagle_mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(real_req_num * self.mtp_step) g_infer_state_lock.release() - eagle_mem_indexes = 
eagle_mem_indexes_cpu.cuda(non_blocking=True) + eagle_mem_indexes = eagle_mem_indexes_cpu.to(device=next_token_ids.device, non_blocking=True) eagle_mem_indexes0 = eagle_mem_indexes[0 : real_req_num0 * self.mtp_step] eagle_mem_indexes1 = eagle_mem_indexes[real_req_num0 * self.mtp_step : real_req_num * self.mtp_step] diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 03ac4cfb05..1645879970 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -32,6 +32,7 @@ def padded_prepare_prefill_inputs( b_req_idx = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] batch_multimodal_params = [] b_ready_cache_len = [] b_mtp_index = [] @@ -68,6 +69,7 @@ def padded_prepare_prefill_inputs( b_ready_cache_len.append(0) total_token_num += 1 prefix_total_token_num += 0 + b_last_mem_index.append(-1) batch_multimodal_params.append({"images": [], "audios": []}) max_kv_seq_len = max(b_seq_len) @@ -79,6 +81,7 @@ def padded_prepare_prefill_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu") b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len @@ -86,8 +89,18 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - 
padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + input_ids.shape[0] - padded_req_num, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len[: len(req_objs)], b_ready_cache_len[: len(req_objs)] + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() if padded_req_num > 0: @@ -140,6 +153,7 @@ def padded_prepare_decode_inputs( b_mtp_index = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] args_mtp_step = get_env_start_args().mtp_step batch_multimodal_params = [] for req in req_objs: @@ -162,7 +176,7 @@ def padded_prepare_decode_inputs( b_q_seq_len.append(1) b_mtp_index.append(step + 1) batch_multimodal_params.append(req.multimodal_params) - + b_last_mem_index.append(req.last_kv_mem_index) # padding fake req for decode for _ in range(padded_req_num): seq_len = 2 @@ -187,13 +201,23 @@ def padded_prepare_decode_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token padded_mem_indexes_num = padded_req_num * (args_mtp_step + 1) g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_mem_indexes_num) - mem_indexes = 
g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_mem_indexes_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + b_seq_len.shape[0] - padded_mem_indexes_num, b_seq_len[: len(b_last_mem_index)] + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0] - padded_mem_indexes_num, + b_seq_len[: len(b_last_mem_index)], + b_last_mem_index=b_last_mem_index, + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i].item() g_infer_state_lock.release() if padded_mem_indexes_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 41e89da9ab..4a9bf973fe 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -24,7 +24,9 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): skip_top_p, exist_req_use_random_seed, ) = _get_post_sample_tensors(reqs) - eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) + target_device = g_infer_context.backend_runtime.target_device() + + eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).to(device=target_device, non_blocking=True) sampling_params_manager = g_infer_context.req_manager.req_sampling_params_manager @@ -102,7 +104,7 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor probs_sum = torch.cumsum(probs_sort, dim=-1) probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 - probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0 + probs_sort[torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1)] = 
0.0 return probs_sort, probs_idx @@ -214,15 +216,16 @@ def _get_post_sample_tensors(reqs: List[InferReq]): key="cu_invalid_token_num", data=cu_invalid_token_num, dtype=torch.int32 ) + target_device = g_infer_context.backend_runtime.target_device() return ( - req_idxes_cpu.cuda(non_blocking=True), - temperatures_cpu.cuda(non_blocking=True), - top_ps_cpu.cuda(non_blocking=True), - top_ks_cpu.cuda(non_blocking=True), - length_penalty_param_cpu.cuda(non_blocking=True), - mask_eos_reqs_cpu.cuda(non_blocking=True), - invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, - cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + req_idxes_cpu.to(device=target_device, non_blocking=True), + temperatures_cpu.to(device=target_device, non_blocking=True), + top_ps_cpu.to(device=target_device, non_blocking=True), + top_ks_cpu.to(device=target_device, non_blocking=True), + length_penalty_param_cpu.to(device=target_device, non_blocking=True), + mask_eos_reqs_cpu.to(device=target_device, non_blocking=True), + invalid_token_ids_cpu.to(device=target_device, non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.to(device=target_device, non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, has_invalid_token_ids, skip_top_k, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index 4eb8c7e1e6..e67a0db284 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -18,6 +18,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_req_idx = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] batch_multimodal_params = [] b_ready_cache_len = [] b_mtp_index = [] @@ -26,6 +27,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: 
bool) -> for req in req_objs: run_reqs.append(req) batch_multimodal_params.append(req.multimodal_params) + b_last_mem_index.append(req.last_kv_mem_index) b_req_idx.append(req.req_idx) if is_chuncked_mode: @@ -57,6 +59,7 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu") b_q_seq_len = torch.tensor(b_q_seq_len, dtype=torch.int32, device="cpu") b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len @@ -64,8 +67,16 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) -> # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0], b_seq_len, b_ready_cache_len + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices(input_ids.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() model_input = ModelInput( @@ -97,6 +108,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index = [] b_seq_len = [] b_q_seq_len = [] + b_last_mem_index = [] multimodal_params = [] for 
req in req_objs: run_reqs.append(req) @@ -108,6 +120,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In total_token_num += seq_len b_mtp_index.append(0) multimodal_params.append(req.multimodal_params) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. for step in range(req.mtp_step): run_reqs.append(req) @@ -118,13 +131,14 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_mtp_index.append(step + 1) multimodal_params.append(req.multimodal_params) b_q_seq_len.append(1) - + b_last_mem_index.append(req.last_kv_mem_index) max_kv_seq_len = max(b_seq_len) max_q_seq_len = max(b_q_seq_len) b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") if enable_diverse_mode_gqa_decode_fast_kernel(): b_shared_seq_len, b_mark_shared_group = build_diverse_shared_group_infos(run_reqs=run_reqs) @@ -135,8 +149,13 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i].item() g_infer_state_lock.release() model_input = ModelInput( diff --git 
a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py index 5c923b5e60..8676f4dd06 100644 --- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -4,13 +4,13 @@ import dataclasses import bisect from functools import lru_cache -from typing import Optional, List, Deque +from typing import Any, Optional, List, Deque from collections import deque from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.utils.config_utils import is_linear_att_mixed_model from lightllm.utils.envs_utils import get_env_start_args from ..infer_batch import InferReq -from lightllm.utils.dist_utils import create_new_group_for_current_dp +from lightllm.utils.dist_utils import create_new_group_for_current_dp from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.utils.log_utils import init_logger @@ -24,16 +24,21 @@ def __init__(self, backend): from .base_backend import ModeBackend self.backend: ModeBackend = backend + + # get backend runtime and target device from backend parameter + self.backend_runtime = self.backend.backend_runtime + self.target_device = self.backend_runtime.target_device() + self.gloo_group = create_new_group_for_current_dp("gloo") self.filter_group = create_new_group_for_current_dp("gloo") - self.init_sync_group = create_new_group_for_current_dp("nccl") + self.init_sync_group = create_new_group_for_current_dp(self.backend_runtime.dist_backend) dist.barrier(group=self.init_sync_group) - self.offload_sync_group = create_new_group_for_current_dp("nccl") + self.offload_sync_group = create_new_group_for_current_dp(self.backend_runtime.dist_backend) dist.barrier(group=self.offload_sync_group) - 
self.offload_sync_tensor = torch.empty((1,), dtype=torch.int32, device="cuda") + self.offload_sync_tensor = torch.empty((1,), dtype=torch.int32, device=self.target_device) - self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device="cuda") - self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device="cuda") + self.page_index_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.int32, device=self.target_device) + self.page_ready_buffer = torch.empty((1024 * 1024 * 4,), dtype=torch.bool, device=self.target_device) self.cpu_cache_handle_queue: Deque[TransTask] = deque() self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False) @@ -108,10 +113,10 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): mem_manager = self.backend.model.mem_manager req_manager = self.backend.model.req_manager - mem_indexes_cuda = mem_indexes.cuda(non_blocking=True) - page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda( - non_blocking=True - ) + mem_indexes_cuda = mem_indexes.to(device=self.target_device, non_blocking=True) + page_indexes_cuda = torch.tensor( + need_pages, dtype=torch.int32, device="cpu" + ).to(device=mem_indexes_cuda.device, non_blocking=True) # 因为在支持 linear att 以后,所有的页面加载必须要按照 page页面的整数倍来做, # 不然可能导致页面数据不完整,导致无法从kv中恢复完整的 linear att状态,所以 # 这里需要进行pad操作,使操作的页面是完整的。 @@ -141,7 +146,7 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): req=req, ) - torch.cuda.current_stream().synchronize() + self.backend_runtime.current_stream().synchronize() if self.backend.is_master_in_dp: req.shm_req.shm_cur_kv_len = req.cur_kv_len @@ -213,9 +218,9 @@ def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> L return true_finished_reqs def _start_kv_cache_offload_task( - self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream + self, req: InferReq, cpu_kv_cache_stream: Any = None, ) -> Optional["TransTask"]: - with 
torch.cuda.stream(cpu_kv_cache_stream): + with self.backend_runtime.stream(cpu_kv_cache_stream): # 综合考虑后只对prompt做缓存管理,不包含decode内容,这里与radix cache不一致 token_hash_list = req.shm_req.token_hash_list.get_all() page_len_list = req.shm_req.token_hash_page_len_list.get_all() @@ -290,7 +295,7 @@ def _start_kv_cache_offload_task( if self.backend.dp_world_size > 1: dist.all_reduce(self.offload_sync_tensor, op=dist.ReduceOp.MAX, group=self.offload_sync_group) - sync_event = torch.cuda.Event() + sync_event = self.backend_runtime.create_event() sync_event.record() req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING trans_task = TransTask( @@ -364,4 +369,4 @@ class TransTask: page_indexes: torch.Tensor page_readies: torch.Tensor req_obj: InferReq - sync_event: torch.cuda.Event + sync_event: Any diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index b04cbb900a..4f50f99780 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -6,7 +6,8 @@ import collections import queue import pickle -from typing import List, Dict, Union, Deque, Optional +from typing import Any, List, Dict, Union, Deque, Optional +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import ( @@ -50,7 +51,7 @@ def _init_env( torch.backends.cudnn.enabled = False try: - torch.cuda.set_device(device_id) + get_backend().runtime.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") @@ -103,7 +104,8 @@ def __init__( kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( 
page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size ) - self.copy_cuda_stream = torch.cuda.Stream() + self.platform_backend = get_backend() + self.copy_cuda_stream = self.platform_backend.runtime.create_stream() self.transporter = NixlKVTransporter( node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer ) @@ -202,7 +204,7 @@ def dispatch_task_loop(self): def accept_peer_task_loop( self, ): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: if len(self.waiting_dict) == 0: time.sleep(0.001) @@ -278,7 +280,7 @@ def _check_tasks_time_out(self): @log_exception def read_peer_kv_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: page_index = self.page_index_queue.get() local_trans_task = self.read_peer_kv_queue.get() @@ -328,11 +330,11 @@ def update_task_status_loop( @log_exception def read_page_to_mems_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: trans_task: NIXLChunckedTransTask = self.ready_page_task_queue.get() # 将数据写回 mem manger - with torch.cuda.stream(stream=self.copy_cuda_stream): + with self.platform_backend.runtime.stream(stream=self.copy_cuda_stream): cur_mem = self.mem_managers[self.device_id] cur_mem.read_page_kv_move_buffer_to_mem( mem_indexes=trans_task.mem_indexes, @@ -341,7 +343,7 @@ def read_page_to_mems_loop(self): mem_managers=self.mem_managers, dp_world_size=self.dp_world_size, ) - sync_event = torch.cuda.Event() + sync_event = self.platform_backend.runtime.create_event() sync_event.record() self.success_queue.put((sync_event, trans_task)) @@ -349,11 +351,11 @@ def read_page_to_mems_loop(self): @log_exception def success_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: sync_event, trans_task = self.success_queue.get() 
trans_task: NIXLChunckedTransTask = trans_task - sync_event: Optional[torch.cuda.Event] = sync_event + sync_event: Optional[Any] = sync_event # 兼容传输kv 数量为0的时候, sync_event 为 None的情况。 if sync_event is not None: sync_event.synchronize() @@ -370,7 +372,7 @@ def success_loop(self): @log_exception def fail_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: trans_task: NIXLChunckedTransTask = self.failed_queue.get() diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 063ce5c6a9..679c97a893 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -6,7 +6,8 @@ import collections import queue import pickle -from typing import List, Dict, Union, Deque, Optional +from typing import Any, List, Dict, Union, Deque, Optional +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskRet @@ -43,7 +44,7 @@ def _init_env( torch.backends.cudnn.enabled = False try: - torch.cuda.set_device(device_id) + get_backend().runtime.set_device(device_id) graceful_registry(inspect.currentframe().f_code.co_name) task_out_queue.put("proc_start") @@ -93,7 +94,8 @@ def __init__( kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer( page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size ) - self.copy_cuda_stream = torch.cuda.Stream() + self.platform_backend = get_backend() + self.copy_cuda_stream = self.platform_backend.runtime.create_stream() self.transporter = NixlKVTransporter( 
node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer ) @@ -122,7 +124,7 @@ def __init__( @log_exception def recv_task_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: page_index = self.page_index_queue.get() @@ -138,12 +140,12 @@ def recv_task_loop(self): @log_exception def local_copy_kv_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: trans_task: NIXLChunckedTransTask = self.local_copy_kv_queue.get() # 将kv 数据拷贝到 page 上,然后传输给 decode node,让其进行读取。 - with torch.cuda.stream(stream=self.copy_cuda_stream): + with self.platform_backend.runtime.stream(stream=self.copy_cuda_stream): cur_mem = self.mem_managers[self.device_id] cur_mem.write_mem_to_page_kv_move_buffer( trans_task.mem_indexes, @@ -152,7 +154,7 @@ def local_copy_kv_loop(self): mem_managers=self.mem_managers, dp_world_size=self.dp_world_size, ) - sync_event = torch.cuda.Event() + sync_event = self.platform_backend.runtime.create_event() sync_event.record() self.notify_peer_read_kv_queue.put((sync_event, trans_task)) @@ -160,11 +162,11 @@ def local_copy_kv_loop(self): @log_exception def notify_peer_to_read_kv_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: sync_event, trans_task = self.notify_peer_read_kv_queue.get() trans_task: NIXLChunckedTransTask = trans_task - sync_event: torch.cuda.Event = sync_event + sync_event: Any = sync_event sync_event.synchronize() @@ -252,7 +254,7 @@ def _check_tasks_time_out(self): @log_exception def success_loop(self): - torch.cuda.set_device(self.device_id) + self.platform_backend.runtime.set_device(self.device_id) while True: trans_task: NIXLChunckedTransTask = self.success_queue.get() # 写回后,回收页面 @@ -267,7 +269,7 @@ def success_loop(self): @log_exception def fail_loop(self): - torch.cuda.set_device(self.device_id) + 
self.platform_backend.runtime.set_device(self.device_id) while True: trans_task: NIXLChunckedTransTask = self.failed_queue.get() diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 884b5930b0..5f62d93277 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -1,4 +1,5 @@ import uuid +from lightllm.utils.envs_utils import get_page_size import numpy as np from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue @@ -38,9 +39,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() with g_router_lock.obj: + page_size = get_page_size() + page_remaining = len(self.cache_len_list) * (page_size - 1) if page_size > 1 else 0 ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens + < self.max_total_tokens - page_remaining ) ok_req_num = len(self.cache_len_list) <= self.running_max_req_size diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 92ca2e3836..9edfe62ce7 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -14,6 +14,7 @@ from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer +from lightllm.platform import get_backend from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel @@ -72,6 +73,7 @@ def 
exposed_init_model(self, kvargs): "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], "max_batch_size": kvargs["max_batch_size"], + "device_id": self.device_id, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -110,7 +112,7 @@ def exposed_init_model(self, kvargs): raise Exception(f"can not support {self.model_type} now") self.model.load_model(weight_dir) - self.model = self.model.cuda() + self.model.setup_device(device_id=self.device_id) if not self.is_visual_only_mode: self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -236,7 +238,7 @@ def _infer_worker(self): """ 任务处理循环: 从队列中取出任务, 执行完成后通知调用者 """ - torch.cuda.set_device(self.device_id) + get_backend().runtime.set_device(self.device_id) while True: try: # 从队列获取任务, 阻塞等待 @@ -253,7 +255,6 @@ def _infer_worker(self): # 执行任务: 调用父类的forward方法处理图像 all_img_embeds, uuids, valid_ids = self._forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) if self.is_visual_only_mode: self._store_to_afs(all_img_embeds, valid_ids, images) @@ -272,7 +273,7 @@ def _store_to_cpu_cache(self, all_img_embeds, valid_ids, images): self.cpu_embed_cache_client.copy_vision_to_cache( embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache ) - cuda_event = torch.cuda.Event() + cuda_event = get_backend().runtime.create_event() cuda_event.record() image.cuda_event = cuda_event self.store_queue.put(image) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c64e8a912b..cb856666d3 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -254,25 +254,78 @@ def get_model_architectures(model_path: str): return "unknown_architecture" +def _get_vocab_size_from_autoconfig(model_path: str) -> Optional[int]: + try: + from transformers import AutoConfig + + cfg = 
AutoConfig.from_pretrained(model_path, trust_remote_code=True) + text_cfg = getattr(cfg, "text_config", None) + if text_cfg is not None: + vocab_size = getattr(text_cfg, "vocab_size", None) + if vocab_size is not None: + return int(vocab_size) + vocab_size = getattr(cfg, "vocab_size", None) + if vocab_size is not None: + return int(vocab_size) + except Exception as e: + logger.warning(f"cannot get vocab_size from AutoConfig for {model_path}: {e}") + return None + + +def _get_vocab_size_from_tokenizer(model_path: str) -> Optional[int]: + try: + from lightllm.server.tokenizer import get_tokenizer + + start_args = get_env_start_args() + tokenizer = get_tokenizer( + model_path, + tokenizer_mode=start_args.tokenizer_mode, + trust_remote_code=start_args.trust_remote_code, + ) + vocab_size = getattr(tokenizer, "vocab_size", None) + if vocab_size is not None: + return int(vocab_size) + inner = getattr(tokenizer, "tokenizer", None) + if inner is not None: + vocab_size = getattr(inner, "vocab_size", None) + if vocab_size is not None: + return int(vocab_size) + except Exception as e: + logger.warning(f"cannot get vocab_size from tokenizer for {model_path}: {e}") + return None + + def get_vocab_size(model_path: str): try: config_json = get_config_json(model_path) # qwen3-omini special if "thinker_config" in config_json: config_json = config_json["thinker_config"] - if "llm_config" in config_json: - vocab_size = int(config_json["llm_config"]["vocab_size"]) + if "llm_config" in config_json and config_json["llm_config"].get("vocab_size") is not None: + return int(config_json["llm_config"]["vocab_size"]) + + vocab_size = _get_config_llm_keyvalue(model_path=model_path, key_name=["vocab_size"]) + if vocab_size is not None: + return int(vocab_size) + + if "vocab_size" in config_json and config_json["vocab_size"] is not None: + return int(config_json["vocab_size"]) + + vocab_size = _get_vocab_size_from_autoconfig(model_path) + if vocab_size is not None: return vocab_size - elif 
"text_config" in config_json: - vocab_size = int(config_json["text_config"]["vocab_size"]) + + vocab_size = _get_vocab_size_from_tokenizer(model_path) + if vocab_size is not None: + logger.warning( + f"vocab_size missing in config.json for {model_path}, " + f"using tokenizer vocab_size={vocab_size}" + ) return vocab_size - vocab_size = config_json["vocab_size"] - if not isinstance(vocab_size, int): - vocab_size = int(vocab_size) - return vocab_size - except: - logger.error("can not get vocab_size from config.json, return 0") - return 0 + except Exception as e: + logger.error(f"can not get vocab_size from {model_path}: {e}") + logger.error(f"can not get vocab_size from config.json, return 0 (model_path={model_path})") + return 0 def get_dtype(model_path: str): diff --git a/lightllm/utils/cpu_cache_host_register.py b/lightllm/utils/cpu_cache_host_register.py new file mode 100644 index 0000000000..da8ed79cbd --- /dev/null +++ b/lightllm/utils/cpu_cache_host_register.py @@ -0,0 +1,119 @@ +import os +import ctypes +import torch +from typing import Any, Callable, Dict, Tuple +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.platform import get_backend +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +HostRegisterWorker = Callable[[int, Tuple[int ,int], Any], None] + +_REGISTRY: Dict[str, HostRegisterWorker] = {} + + +def register_host_register_worker(backend_name: str): + + def decorator(fn: HostRegisterWorker) -> HostRegisterWorker: + if backend_name in _REGISTRY: + raise ValueError(f"HostRegisterWorker {backend_name} already registered!") + _REGISTRY[backend_name] = fn + return fn + + return decorator + + +def get_host_register_worker(): + backend_name = get_backend().name + try: + return _REGISTRY[backend_name] + except KeyError: + raise RuntimeError(f"platform {backend_name} is not registered!") + + +@register_host_register_worker("cuda") +def _cuda_worker(shm_ptr: int, tasks: Tuple[int ,int], handle: 
Any): + cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") + cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + cuda.cudaHostRegister.restype = ctypes.c_int + cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] + cuda.cudaHostGetDevicePointer.restype = ctypes.c_int + + cudaHostRegisterFlag = 3 + + torch.cuda.set_device(get_current_device_id()) + # TODO 这个地方的分块注册是否具备合法性和合理性。 + for offset, seg_len in tasks: + ptr = ctypes.c_void_p(shm_ptr + offset) + r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) + if r != 0: + raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") + handle.task_count += 1 + + if handle.device_ptr is None: + # 提前获取对应的指针对象,避免在wait后再获取,照成过长的阻塞等待。 + device_ptr = ctypes.c_void_p() + host_ptr = ctypes.c_void_p(shm_ptr) + res = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) + if res != 0: + raise Exception(f"cudaHostGetDevicePointer failed with error code {res}") + + logger.info( + f"cudaHostGetDevicePointer success, host_ptr={host_ptr.value}, device_ptr={device_ptr.value}" + ) + handle.device_ptr = device_ptr.value + + handle.tasks_finished.set() + +@register_host_register_worker("ascend") +def _npu_worker(shm_ptr: int, tasks: Tuple[int ,int], handle: Any): + import acl + + acl.init() + ret = acl.rt.set_device(get_current_device_id()) + assert ret == 0, f"acl.rt.set_device failed with error code {ret}" + + ACL_HOST_REGISTER_MAPPED = 0 + for offset, seg_len in tasks: + ptr = shm_ptr + offset + res = acl.rt.host_register(ptr, seg_len, ACL_HOST_REGISTER_MAPPED) + assert res[1] == 0, f"acl.rt.host_register failed with error code {res}" + handle.task_count += 1 + + handle.tasks_finished.set() + + +@register_host_register_worker("maca") +def _metax_worker(shm_ptr: int, tasks: Tuple[int ,int], handle: Any): + mc = ctypes.CDLL(os.path.join(os.getenv("MACA_PATH", 
"/opt/maca"), "lib/libmcruntime.so")) + mc.mcHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + mc.mcHostRegister.restype = ctypes.c_int + mc.mcHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] + mc.mcHostGetDevicePointer.restype = ctypes.c_int + + cudaHostRegisterFlag = 3 + + torch.cuda.set_device(get_current_device_id()) + # TODO 这个地方的分块注册是否具备合法性和合理性。 + for offset, seg_len in tasks: + ptr = ctypes.c_void_p(shm_ptr + offset) + r = mc.mcHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) + if r != 0: + raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") + handle.task_count += 1 + + if handle.device_ptr is None: + # 提前获取对应的指针对象,避免在wait后再获取,照成过长的阻塞等待。 + device_ptr = ctypes.c_void_p() + host_ptr = ctypes.c_void_p(shm_ptr) + res = mc.mcHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) + if res != 0: + raise Exception(f"mcHostGetDevicePointer failed with error code {res}") + + logger.info( + f"mcHostGetDevicePointer success, host_ptr={host_ptr.value}, device_ptr={device_ptr.value}" + ) + handle.device_ptr = device_ptr.value + + handle.tasks_finished.set() diff --git a/lightllm/utils/custom_kernel_utis.py b/lightllm/utils/custom_kernel_utis.py index 9a7578a243..17735386e8 100644 --- a/lightllm/utils/custom_kernel_utis.py +++ b/lightllm/utils/custom_kernel_utis.py @@ -3,6 +3,8 @@ import triton.language as tl from typing import List +from lightllm.platform import get_backend + def custom_cat(tensors): """ @@ -12,7 +14,7 @@ def custom_cat(tensors): if not isinstance(tensors, (list, tuple)): raise ValueError("Input must be a list of tensors") - assert tensors[0].is_cuda and len(tensors[0].shape) == 1 + assert len(tensors[0].shape) == 1, f"tensors[0].shape: {tensors[0].shape}" sizes = [t.shape[0] for t in tensors] dest_size = sum(sizes) out_tensor = torch.empty((dest_size,), dtype=tensors[0].dtype, device="cpu", pin_memory=True) @@ -21,7 +23,7 @@ 
def custom_cat(tensors): for t, size in zip(tensors, sizes): out_tensor[start_loc : (start_loc + size)].copy_(t, non_blocking=True) start_loc += size - torch.cuda.current_stream().synchronize() + get_backend().runtime.current_stream().synchronize() return out_tensor diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 43b10ec88b..1b47ddad5b 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -6,6 +6,8 @@ from enum import Enum from typing import Optional from functools import lru_cache +from lightllm.platform import get_backend +from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -104,15 +106,17 @@ def is_nvidia(): def get_current_device_name(): if torch.cuda.is_available() or is_musa(): device = torch.cuda.current_device() - gpu_name = torch.cuda.get_device_name(device) + device_name = torch.cuda.get_device_name(device) # 4090 trans to 4090 D - if "4090" in gpu_name and "4090 D" not in gpu_name: - gpu_name = gpu_name.replace("4090", "4090 D") - - gpu_name = gpu_name.replace(" ", "_") - return gpu_name + if "4090" in device_name and "4090 D" not in device_name: + device_name = device_name.replace("4090", "4090 D") + elif hasattr(torch, "npu") and torch.npu.is_available(): + device = torch.npu.current_device() + device_name = torch.npu.get_device_name(device) else: - return None + return "unknown_device" + + return device_name.replace(" ", "_") @lru_cache(maxsize=None) @@ -120,7 +124,7 @@ def init_p2p(device_index): """ torch 调用跨卡的to操作后,triton编译的算子便能自动操作跨卡tensor。 """ - num_gpus = torch.cuda.device_count() + num_gpus = get_backend().runtime.device_count() tensor = torch.zeros((1,)) tensor = tensor.to(f"cuda:{device_index}") for j in range(num_gpus): @@ -350,3 +354,9 @@ def get_platform(platform_name: Optional[str] = None) -> Platform: if platform is None: raise ValueError(f"Unknown platform name: {platform_name}") 
return platform + + +def get_target_device(device_id: Optional[int] = None) -> torch.device: + if device_id is None: + device_id = get_current_device_id() + return get_backend().runtime.target_device(device_id) diff --git a/lightllm/utils/dist_check_utils.py b/lightllm/utils/dist_check_utils.py index e11da07c8c..302a668544 100644 --- a/lightllm/utils/dist_check_utils.py +++ b/lightllm/utils/dist_check_utils.py @@ -9,6 +9,8 @@ import threading from typing import TYPE_CHECKING, Callable +from lightllm.platform import get_backend +from lightllm.utils.device_utils import get_target_device from lightllm.utils.log_utils import init_logger if TYPE_CHECKING: @@ -64,29 +66,29 @@ def _flashinfer_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> N import torch import torch.distributed as dist - cuda_device = torch.device(f"cuda:{process_rank}") - torch.cuda.set_device(cuda_device) + target_device = get_target_device(process_rank) + get_backend().runtime.set_device(target_device) dist.init_process_group( - "nccl", + get_backend().runtime.dist_backend, init_method=f"tcp://127.0.0.1:{init_tcp_port}", world_size=2, rank=process_rank, - device_id=cuda_device, + device_id=target_device, ) try: gloo_process_group = dist.new_group([0, 1], backend="gloo") from lightllm.distributed.flashinfer_all_reduce import FlashInferAllReduce - flashinfer_all_reduce = FlashInferAllReduce(gloo_process_group, cuda_device) + flashinfer_all_reduce = FlashInferAllReduce(gloo_process_group, target_device) if flashinfer_all_reduce.disabled: raise RuntimeError("FlashInferAllReduce disabled") if process_rank == 0: - input_tensor = torch.zeros(2, 64, device=cuda_device, dtype=torch.bfloat16) + input_tensor = torch.zeros(2, 64, device=target_device, dtype=torch.bfloat16) else: - input_tensor = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + input_tensor = torch.ones(2, 64, device=target_device, dtype=torch.bfloat16) output_tensor = flashinfer_all_reduce.all_reduce(input_tensor) 
dist.barrier() - expected_reduced = torch.ones(2, 64, device=cuda_device, dtype=torch.bfloat16) + expected_reduced = torch.ones(2, 64, device=target_device, dtype=torch.bfloat16) if not torch.allclose(output_tensor, expected_reduced): raise RuntimeError("FlashInfer allreduce value mismatch") finally: @@ -102,29 +104,29 @@ def _symm_mem_two_gpu_check_worker(process_rank: int, init_tcp_port: int) -> Non import torch import torch.distributed as dist - cuda_device = torch.device(f"cuda:{process_rank}") - torch.cuda.set_device(cuda_device) + target_device = get_target_device(process_rank) + get_backend().runtime.set_device(target_device) dist.init_process_group( - "nccl", + get_backend().runtime.dist_backend, init_method=f"tcp://127.0.0.1:{init_tcp_port}", world_size=2, rank=process_rank, - device_id=cuda_device, + device_id=target_device, ) try: - nccl_process_group = dist.new_group([0, 1], backend="nccl") + nccl_process_group = dist.new_group([0, 1], backend=get_backend().runtime.dist_backend) from lightllm.distributed.symm_mem_all_reduce import SymmMemAllreduce - symm_mem_all_reduce = SymmMemAllreduce(nccl_process_group, cuda_device, dtype=torch.bfloat16) + symm_mem_all_reduce = SymmMemAllreduce(nccl_process_group, target_device, dtype=torch.bfloat16) if symm_mem_all_reduce.disabled: raise RuntimeError("SymmMemAllreduce disabled") if process_rank == 0: - activation_tensor = torch.zeros(8, 32, device=cuda_device, dtype=torch.bfloat16) + activation_tensor = torch.zeros(8, 32, device=target_device, dtype=torch.bfloat16) else: - activation_tensor = torch.ones(8, 32, device=cuda_device, dtype=torch.bfloat16) + activation_tensor = torch.ones(8, 32, device=target_device, dtype=torch.bfloat16) symm_mem_all_reduce.all_reduce(activation_tensor) dist.barrier() - expected_reduced = torch.ones(8, 32, device=cuda_device, dtype=torch.bfloat16) + expected_reduced = torch.ones(8, 32, device=target_device, dtype=torch.bfloat16) if not torch.allclose(activation_tensor, 
expected_reduced): raise RuntimeError("SymmMem allreduce value mismatch") finally: diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 5b9705ed0e..c6340c8406 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -2,6 +2,7 @@ import os import torch import requests +from lightllm.platform import get_backend # 规范 rank 的含义,在 llm 推理的相关代码中下述的 rank 的含义如下: # global_rank 全局 rank 序列id, 如两节点 8卡,会存在 0 - 15 16个global_rank @@ -54,6 +55,19 @@ def get_environ(environ_name): return value +def _setup_distributed(*, host: str, port: int, rank: int, world_size: int, device_id: int) -> None: + target_device = get_backend().runtime.init_process_group( + host=host, + port=port, + rank=rank, + world_size=world_size, + device_id=device_id, + ) + _a = torch.zeros([1], device=target_device) + dist.all_reduce(_a) + del _a + + def init_vision_distributed_env(kvargs): """ # kvargs = { @@ -79,18 +93,14 @@ def init_vision_distributed_env(kvargs): set_current_rank_in_dp(tp_rank_id) device_id = kvargs["device_id"] set_current_device_id(device_id) - torch.cuda.set_device(device_id) - dist.init_process_group( - "nccl", - init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', + + _setup_distributed( + host="127.0.0.1", + port=kvargs["visual_nccl_port"], rank=kvargs["tp_rank_id"], world_size=tp_world_size, - device_id=torch.device(f"cuda:{device_id}"), + device_id=device_id, ) - # warmup nccl communicator - _a = torch.zeros([1]).to(f"cuda:{device_id}") - dist.all_reduce(_a) - del _a def init_audio_distributed_env(kvargs): @@ -111,17 +121,14 @@ def init_audio_distributed_env(kvargs): set_current_rank_in_dp(tp_rank_id) device_id = kvargs["device_id"] set_current_device_id(device_id) - torch.cuda.set_device(device_id) - dist.init_process_group( - "nccl", - init_method=f'tcp://127.0.0.1:{kvargs["audio_nccl_port"]}', + + _setup_distributed( + host="127.0.0.1", + port=kvargs["audio_nccl_port"], rank=tp_rank_id, world_size=tp_world_size, - 
device_id=torch.device(f"cuda:{device_id}"), + device_id=device_id, ) - _a = torch.zeros([1]).to(f"cuda:{device_id}") - dist.all_reduce(_a) - del _a def init_distributed_env(kvargs): @@ -144,19 +151,14 @@ def init_distributed_env(kvargs): _init_nccl_env() device_id = kvargs["rank_id"] % get_node_world_size() set_current_device_id(device_id) - torch.cuda.set_device(device_id) - dist.init_process_group( - "nccl", - init_method=f'tcp://{kvargs["nccl_host"]}:{kvargs["nccl_port"]}', + + _setup_distributed( + host=kvargs["nccl_host"], + port=kvargs["nccl_port"], rank=kvargs["rank_id"], world_size=kvargs["world_size"], - device_id=torch.device(f"cuda:{device_id}"), + device_id=device_id, ) - # warmup nccl communicator - _a = torch.zeros([1]).to(f"cuda:{device_id}") - dist.all_reduce(_a) - del _a - def set_global_rank(global_rank: int): set_environ("LIGHTLLM_GLOBAL_RANK", global_rank) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 350507e897..95a07c3ffc 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -228,3 +228,8 @@ def get_added_mtp_kv_layer_num() -> int: @lru_cache(maxsize=None) def get_pd_split_max_new_tokens() -> int: return int(os.getenv("LIGHTLLM_PD_SPLIT_MAX_NEW_TOKENS", 2048)) + + +@lru_cache(maxsize=None) +def get_page_size(): + return int(os.getenv("PAGE_SIZE", 1)) diff --git a/lightllm/utils/infer_utils.py b/lightllm/utils/infer_utils.py index dadd96648d..c4b95793d1 100644 --- a/lightllm/utils/infer_utils.py +++ b/lightllm/utils/infer_utils.py @@ -5,6 +5,7 @@ from typing import Callable from lightllm.utils.log_utils import init_logger +from lightllm.platform import get_backend logger = init_logger(__name__) @@ -14,17 +15,18 @@ def mark_cost_time(func_name): def inner_func(func): def time_func(*args, **kwargs): + platform_backend = get_backend() if dist.get_rank() in [0, 1] and is_show_cost_time: - torch.cuda.synchronize() + platform_backend.runtime.synchronize() start_time = time.time() ans = 
func(*args, **kwargs) - torch.cuda.synchronize() + platform_backend.runtime.synchronize() logger.debug(f"{func_name} cost time: {(time.time() - start_time) * 1000}") return ans else: - torch.cuda.synchronize() + platform_backend.runtime.synchronize() ans = func(*args, **kwargs) - torch.cuda.synchronize() + platform_backend.runtime.synchronize() return ans return time_func @@ -36,14 +38,14 @@ def time_func(*args, **kwargs): def mark_start(key): - torch.cuda.synchronize() + get_backend().runtime.synchronize() global time_mark time_mark[key] = time.time() return def mark_end(key, print_min_cost=0.0): - torch.cuda.synchronize() + get_backend().runtime.synchronize() global time_mark cost_time = (time.time() - time_mark[key]) * 1000 if cost_time > print_min_cost: @@ -53,11 +55,11 @@ def mark_end(key, print_min_cost=0.0): def calculate_time(show=False, min_cost_ms=0.0): def wrapper(func): def inner_func(*args, **kwargs): - torch.cuda.synchronize() + get_backend().runtime.synchronize() if show: start_time = time.time() result = func(*args, **kwargs) - torch.cuda.synchronize() + get_backend().runtime.synchronize() if show: cost_time = (time.time() - start_time) * 1000 if cost_time > min_cost_ms: @@ -70,14 +72,14 @@ def inner_func(*args, **kwargs): def benchmark_time(func: Callable, *args, warmup: int = 1, repeat: int = 5, **kwargs) -> float: - torch.cuda.synchronize() + get_backend().runtime.synchronize() for _ in range(warmup): func(*args, **kwargs) - torch.cuda.synchronize() + get_backend().runtime.synchronize() start_time = time.time() for _ in range(repeat): func(*args, **kwargs) - torch.cuda.synchronize() + get_backend().runtime.synchronize() cost_time = (time.time() - start_time) * 1000 / repeat # unit: ms return cost_time @@ -89,14 +91,15 @@ def set_random_seed(seed: int) -> None: import numpy as np torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) + platform_backend = get_backend() + if 
platform_backend.runtime.is_available(): + platform_backend.runtime.manual_seed_all(seed) def post_empty_cache(func): def wrapper(*args, **kwargs): result = func(*args, **kwargs) - torch.cuda.empty_cache() + get_backend().runtime.empty_cache() return result return wrapper diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 55284b27f3..95472098c4 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -30,6 +30,8 @@ from lightllm.utils.auto_shm_cleanup import register_sysv_shm_for_cleanup from lightllm.utils.dist_utils import get_current_device_id from lightllm.common.linear_att_cache_manager.config_objs import LinearAttCacheConfig +from lightllm.platform import get_backend +from lightllm.utils.cpu_cache_host_register import get_host_register_worker logger = init_logger(__name__) @@ -257,40 +259,13 @@ def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle handle = AsyncRegistrationHandle(total_tasks=len(tasks)) - def _worker(): - cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") - cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] - cuda.cudaHostRegister.restype = ctypes.c_int - cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] - cuda.cudaHostGetDevicePointer.restype = ctypes.c_int - - cudaHostRegisterFlag = 3 - - torch.cuda.set_device(get_current_device_id()) - # TODO 这个地方的分块注册是否具备合法性和合理性。 - for offset, seg_len in tasks: - ptr = ctypes.c_void_p(shm_ptr + offset) - r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) - if r != 0: - raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") - handle.task_count += 1 - - if handle.device_ptr is None: - # 提前获取对应的指针对象,避免在wait后再获取,照成过长的阻塞等待。 - device_ptr = ctypes.c_void_p() - host_ptr = ctypes.c_void_p(shm_ptr) - res = 
cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) - if res != 0: - raise Exception(f"cudaHostGetDevicePointer failed with error code {res}") - - logger.info( - f"cudaHostGetDevicePointer success, host_ptr={host_ptr.value}, device_ptr={device_ptr.value}" - ) - handle.device_ptr = device_ptr.value - - handle.tasks_finished.set() - - th = threading.Thread(target=_worker, name=f"cpu_cache_register_{shm_ptr}", daemon=True) + platform_worker = get_host_register_worker() + th = threading.Thread( + target=platform_worker, + args=(shm_ptr, tasks, handle), + name=f"cpu_cache_register_{shm_ptr}", + daemon=True, + ) handle.thread = th th.start() return handle diff --git a/lightllm/utils/profile_max_tokens.py b/lightllm/utils/profile_max_tokens.py index e3a62b62ea..b7b76fc0a5 100644 --- a/lightllm/utils/profile_max_tokens.py +++ b/lightllm/utils/profile_max_tokens.py @@ -6,6 +6,7 @@ import argparse from lightllm.common.build_utils import repair_config from lightllm.utils.dist_utils import get_current_device_id +from lightllm.platform import get_backend data_type_dict = {"float32": 4, "float16": 2, "bfloat16": 2, "fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5} @@ -14,10 +15,12 @@ def get_available_gpu_memory(world_size): """ Get available memory. 
""" - torch.cuda.empty_cache() - free_gpu_memory, _ = torch.cuda.mem_get_info(get_current_device_id()) + backend_runtime = get_backend().runtime + backend_runtime.empty_cache() + free_gpu_memory, _ = backend_runtime.mem_get_info(get_current_device_id()) if world_size > 1: - tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(f"cuda:{get_current_device_id()}") + target_device = backend_runtime.target_device(get_current_device_id()) + tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(target_device) torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) free_gpu_memory = tensor.item() return free_gpu_memory / (1024 ** 3) @@ -27,7 +30,8 @@ def get_total_gpu_memory(): """ Get the total GPU memory of the machine """ - total_memory = torch.cuda.get_device_properties(0).total_memory + backend_runtime = get_backend().runtime + total_memory = backend_runtime.get_device_properties(0).total_memory return total_memory / (1024 ** 3) # Convert to GB diff --git a/lightllm/utils/tuning_utils.py b/lightllm/utils/tuning_utils.py index 93f482d976..e9045f8fbb 100644 --- a/lightllm/utils/tuning_utils.py +++ b/lightllm/utils/tuning_utils.py @@ -5,6 +5,7 @@ from multiprocessing.pool import Pool from multiprocessing.pool import util, worker from typing import Callable, Any, Dict, List +from lightllm.platform import get_backend from lightllm.utils.log_utils import init_logger from lightllm.utils.watchdog_utils import Watchdog @@ -33,7 +34,7 @@ def run_func(func, args): def mp_tuning(func, args: Dict[str, Any]): # 修复 pool 中的进程无法启动子进程进行 kernel tuning 的问题 Pool._repopulate_pool_static = fix_repopulate_pool_static - device_count = torch.cuda.device_count() + device_count = get_backend().runtime.device_count() with mp.Pool(processes=device_count) as pool: tasks = []