98 changes: 48 additions & 50 deletions lightllm/common/basemodel/basemodel.py
@@ -81,7 +81,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]

self._init_datatype()
self._init_config()
@@ -262,10 +262,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_seq_len = model_input.b_seq_len
if model_input.is_prefill:
if model_input.b_ready_cache_len is not None:
infer_state.b_ready_cache_len = model_input.b_ready_cache_len
else:
infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)
assert model_input.b_ready_cache_len is not None
infer_state.b_ready_cache_len = model_input.b_ready_cache_len

infer_state.multimodal_params = model_input.multimodal_params

@@ -337,14 +335,14 @@ def _prefill(
infer_state = self._create_inferstate(model_input)
init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
model_input.b_req_idx,
model_input.b_seq_len,
infer_state.b_ready_cache_len,
model_input.b_req_idx_cpu,
model_input.b_seq_len_cpu,
model_input.b_ready_cache_len_cpu,
model_input.max_len_in_batch,
infer_state.mem_index,
)

infer_state.init_some_extra_state(self, model_input.input_ids)
infer_state.init_some_extra_state(self, model_input)
return self._context_forward(model_input.input_ids, infer_state)

def _decode(
@@ -369,7 +367,7 @@ def _decode(
infer_state.b_seq_len,
infer_state.mem_index,
)
infer_state.init_some_extra_state(self, padded_model_input.input_ids)
infer_state.init_some_extra_state(self, padded_model_input)

if self.graph.need_capture(find_graph_batch_size):
infer_state.is_cuda_graph = True
@@ -390,7 +388,7 @@ def _decode(
infer_state.b_seq_len,
infer_state.mem_index,
)
infer_state.init_some_extra_state(self, model_input.input_ids)
infer_state.init_some_extra_state(self, model_input)
model_output = self._token_forward(model_input.input_ids, infer_state)

return model_output
@@ -540,15 +538,15 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
infer_state0.b_seq_len,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
infer_state0.init_some_extra_state(self, padded_model_input0)
infer_state1 = self._create_inferstate(padded_model_input1, 1)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state1.b_req_idx,
infer_state1.b_seq_len,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
infer_state1.init_some_extra_state(self, padded_model_input1)

if self.graph.need_capture(find_graph_batch_size):
infer_state0.is_cuda_graph = True
@@ -684,25 +682,25 @@ def _check_max_len_infer(self):
# Simulate a prefill at the maximum length and check whether it triggers an OOM
try:
logger.info("begin check max_len infer")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cpu")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
total_token_num = self.batch_max_tokens
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
max_len_in_batch=self.batch_max_tokens,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
input_ids_cpu=dummy_input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
b_mtp_index_cpu=b_mtp_index,
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
b_ready_cache_len_cpu=b_ready_cache_len,
)
model_output = self.forward(
model_input,
@@ -750,29 +748,29 @@ def _autotune_warmup(self):
self.layers_num = self.autotune_layers()
for input_len in tqdm(warmup_lengths, desc="warming up"):
try:
rand_gen = torch.Generator(device="cuda")
rand_gen = torch.Generator(device="cpu")
rand_gen.manual_seed(input_len)
dummy_input_ids = torch.randint(
0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
0, 10000, (input_len,), dtype=torch.int32, device="cpu", generator=rand_gen
)
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
b_seq_len[:] = input_len
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
total_token_num = input_len
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
max_len_in_batch=input_len,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
input_ids_cpu=dummy_input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
b_mtp_index_cpu=b_mtp_index,
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
b_ready_cache_len_cpu=b_ready_cache_len,
multimodal_params=[],
**self._gen_special_model_input(total_token_num),
)
@@ -807,27 +805,27 @@ def _init_padded_req(self):
# prefill init padding req.
prefill_input_len = 1
batch_size = 1
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
b_req_idx = torch.tensor(
[self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
mem_indexes = torch.tensor(
[self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cpu")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
total_token_num = prefill_input_len * batch_size
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
model_input = ModelInput(
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=prefill_input_len,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
b_seq_len=b_seq_len,
b_ready_cache_len=b_ready_cache_len,
input_ids_cpu=dummy_input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_mtp_index_cpu=b_mtp_index,
b_seq_len_cpu=b_seq_len,
b_ready_cache_len_cpu=b_ready_cache_len,
is_prefill=True,
multimodal_params=[],
**self._gen_special_model_input(total_token_num),
32 changes: 21 additions & 11 deletions lightllm/common/basemodel/batch_objs.py
@@ -10,17 +10,22 @@ class ModelInput:
batch_size: int
total_token_num: int
max_len_in_batch: int
input_ids: torch.Tensor
b_req_idx: torch.Tensor
b_mtp_index: torch.Tensor
b_seq_len: torch.Tensor
input_ids: torch.Tensor = None
b_req_idx: torch.Tensor = None
b_mtp_index: torch.Tensor = None
b_seq_len: torch.Tensor = None
mem_indexes: torch.Tensor = None
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
multimodal_params: list = field(default_factory=list)

# CPU-side tensors
input_ids_cpu: torch.Tensor = None
b_req_idx_cpu: torch.Tensor = None
b_mtp_index_cpu: torch.Tensor = None
mem_indexes_cpu: torch.Tensor = None
b_seq_len_cpu: torch.Tensor = None
b_ready_cache_len_cpu: torch.Tensor = None
# Parameters used in the prefill stage; they are not consumed by the forward pass
# itself, but by resource management outside of inference.
b_prefill_has_output_cpu: List[bool] = None # marks whether each prefill request produces output
@@ -33,15 +38,20 @@ class ModelInput:
deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def to_cuda(self):
if self.input_ids is not None:
self.input_ids = self.input_ids.cuda(non_blocking=True)
# input_ids may be absent; it can be recovered via req_to_token_indexs
if self.input_ids is None and self.input_ids_cpu is not None:
self.input_ids = self.input_ids_cpu.cuda(non_blocking=True)
if self.mem_indexes is None:
self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
if self.b_ready_cache_len is not None:
self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
if self.b_req_idx is None:
self.b_req_idx = self.b_req_idx_cpu.cuda(non_blocking=True)
if self.b_seq_len is None:
self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)
# b_ready_cache_len only takes effect during the prefill stage
if self.b_ready_cache_len_cpu is not None:
self.b_ready_cache_len = self.b_ready_cache_len_cpu.cuda(non_blocking=True)
if self.b_mtp_index is None:
self.b_mtp_index = self.b_mtp_index_cpu.cuda(non_blocking=True)


@dataclass
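To make the split between the new `*_cpu` fields and their CUDA counterparts easier to follow, here is a minimal, self-contained sketch of the intended pattern: build the batch metadata on CPU, then materialize the CUDA copies once inside `to_cuda()`. The dataclass below is a stripped-down stand-in for `ModelInput`, not the real class, and keeps only a couple of fields for illustration.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MiniModelInput:
    """Stripped-down stand-in for lightllm's ModelInput (illustration only)."""

    batch_size: int
    # CPU-side tensors, filled in by the scheduler / warmup code.
    input_ids_cpu: Optional[torch.Tensor] = None
    b_seq_len_cpu: Optional[torch.Tensor] = None
    # CUDA-side tensors, created lazily.
    input_ids: Optional[torch.Tensor] = None
    b_seq_len: Optional[torch.Tensor] = None

    def to_cuda(self) -> None:
        # Only materialize the CUDA copies that are still missing.
        if self.input_ids is None and self.input_ids_cpu is not None:
            self.input_ids = self.input_ids_cpu.cuda(non_blocking=True)
        if self.b_seq_len is None and self.b_seq_len_cpu is not None:
            self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)


if torch.cuda.is_available():
    batch = MiniModelInput(
        batch_size=1,
        input_ids_cpu=torch.ones(8, dtype=torch.int32, device="cpu"),
        b_seq_len_cpu=torch.tensor([8], dtype=torch.int32, device="cpu"),
    )
    batch.to_cuda()
    print(batch.input_ids.device, batch.b_seq_len.device)  # cuda:0 cuda:0
```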
36 changes: 19 additions & 17 deletions lightllm/common/basemodel/cuda_graph.py
@@ -195,25 +195,27 @@ def warmup(self, model):
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")

model_input = ModelInput(
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=max_len_in_batch,
input_ids=input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
input_ids_cpu=input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
b_mtp_index_cpu=b_mtp_index,
is_prefill=False,
b_ready_cache_len_cpu=b_ready_cache_len,
**model._gen_special_model_input(batch_size),
)
model_output: ModelOutput = model.forward(model_input)
@@ -251,25 +253,25 @@ def warmup_overlap(self, model):
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")

micro_batch = ModelInput(
is_prefill=False,
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=max_len_in_batch,
input_ids=input_ids,
input_ids_cpu=input_ids,
b_mtp_index=b_mtp_index,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
**model._gen_special_model_input(batch_size),
)
decode_batches.append(micro_batch)
6 changes: 2 additions & 4 deletions lightllm/common/basemodel/infer_struct.py
@@ -64,7 +64,7 @@ def __init__(self):
# used as input there; it is not used by any other model or scenario
self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
def init_some_extra_state(self, model, model_input: ModelInput):
if self.is_prefill:
(
self.b_q_seq_len,
@@ -75,9 +75,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.max_q_seq_len,
self.max_kv_seq_len,
) = gen_prefill_params(
input_token_num=input_ids.shape[0],
b_ready_cache_len=self.b_ready_cache_len,
b_seq_len=self.b_seq_len,
model_input,
)
self.b_start_loc = self.b1_cu_q_seq_len[0:-1]
else:
15 changes: 12 additions & 3 deletions lightllm/common/basemodel/triton_kernel/gen_prefill_params.py
@@ -3,6 +3,8 @@
import triton
import triton.language as tl

from lightllm.common.basemodel.batch_objs import ModelInput


@triton.jit
def _gen_cumsum_pad0_kernel(
@@ -80,7 +82,14 @@ def _gen_prefill_position(


@torch.no_grad()
def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor):
def gen_prefill_params(model_input: ModelInput):
# input_token_num: int, b_ready_cache_len: torch.Tensor, b_seq_len: torch.Tensor):
input_token_num = model_input.input_ids.shape[0]
b_seq_len = model_input.b_seq_len
b_ready_cache_len = model_input.b_ready_cache_len
b_seq_len_cpu = model_input.b_seq_len_cpu
b_ready_cache_len_cpu = model_input.b_ready_cache_len_cpu

batch_size = b_ready_cache_len.shape[0]
position_ids = torch.empty((input_token_num,), dtype=torch.int32, device="cuda")
assert b_ready_cache_len.shape[0] == b_seq_len.shape[0]
@@ -99,6 +108,6 @@ def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_
num_stages=1,
)
b_kv_seq_len = b_seq_len
max_q_seq_len = b_q_seq_len.max().item()
max_kv_seq_len = b_kv_seq_len.max().item()
max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max()
max_kv_seq_len = b_seq_len_cpu.max()
Comment on lines +111 to +112

critical

The .max() method on a PyTorch tensor returns a 0-dimensional tensor, not a Python scalar. Downstream code that uses max_q_seq_len and max_kv_seq_len (e.g., for creating new tensors with torch.empty) expects an integer and will fail with a TypeError. You should add .item() to convert the 0-dim tensors to scalars. Since these operations are on CPU tensors, calling .item() will not cause a device synchronization.

Suggested change
max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max()
max_kv_seq_len = b_seq_len_cpu.max()
max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max().item()
max_kv_seq_len = b_seq_len_cpu.max().item()

return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
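
Picking up the review comment above: a minimal standalone sketch (not part of the PR; the tensor values are made up) of the difference between `.max()` and `.max().item()` on CPU tensors.

```python
import torch

# Hypothetical per-request lengths, mirroring b_seq_len_cpu / b_ready_cache_len_cpu.
b_seq_len_cpu = torch.tensor([5, 9, 3], dtype=torch.int32)
b_ready_cache_len_cpu = torch.tensor([1, 2, 0], dtype=torch.int32)

max_q_seq_len = (b_seq_len_cpu - b_ready_cache_len_cpu).max()  # tensor(7, dtype=torch.int32), a 0-dim tensor
max_kv_seq_len = b_seq_len_cpu.max().item()                    # 9, a plain Python int

print(type(max_q_seq_len))   # <class 'torch.Tensor'>
print(type(max_kv_seq_len))  # <class 'int'>

# Consumers that expect a real int (shape arguments, Python-level size arithmetic, etc.)
# may reject or mishandle the 0-dim tensor, so converting with .item() is the safer
# choice; on CPU tensors it does not force a device synchronization.
```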