106 changes: 52 additions & 54 deletions lightllm/common/basemodel/basemodel.py
@@ -81,7 +81,7 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]

self._init_datatype()
self._init_config()
@@ -262,10 +262,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_seq_len = model_input.b_seq_len
if model_input.is_prefill:
if model_input.b_ready_cache_len is not None:
infer_state.b_ready_cache_len = model_input.b_ready_cache_len
else:
infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)
assert model_input.b_ready_cache_len is not None
infer_state.b_ready_cache_len = model_input.b_ready_cache_len

infer_state.multimodal_params = model_input.multimodal_params

@@ -337,14 +335,14 @@ def _prefill(
infer_state = self._create_inferstate(model_input)
init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
model_input.b_req_idx,
model_input.b_seq_len,
infer_state.b_ready_cache_len,
model_input.b_req_idx_cpu,
model_input.b_seq_len_cpu,
model_input.b_ready_cache_len_cpu,
model_input.max_len_in_batch,
infer_state.mem_index,
)

infer_state.init_some_extra_state(self, model_input.input_ids)
infer_state.init_some_extra_state(self, model_input)
return self._context_forward(model_input.input_ids, infer_state)

def _decode(
@@ -369,7 +367,7 @@ def _decode(
infer_state.b_seq_len,
infer_state.mem_index,
)
infer_state.init_some_extra_state(self, padded_model_input.input_ids)
infer_state.init_some_extra_state(self, padded_model_input)

if self.graph.need_capture(find_graph_batch_size):
infer_state.is_cuda_graph = True
@@ -390,7 +388,7 @@ def _decode(
infer_state.b_seq_len,
infer_state.mem_index,
)
infer_state.init_some_extra_state(self, model_input.input_ids)
infer_state.init_some_extra_state(self, model_input)
model_output = self._token_forward(model_input.input_ids, infer_state)

return model_output
@@ -482,7 +480,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
model_input0.max_len_in_batch,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, input_ids0)
infer_state0.init_some_extra_state(self, model_input0)

infer_state1 = self._create_inferstate(model_input1, 1)
init_req_to_token_indexes(
@@ -493,7 +491,7 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod
model_input1.max_len_in_batch,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, input_ids1)
infer_state1.init_some_extra_state(self, model_input1)

model_output0, model_output1 = self._overlap_tpsp_context_forward(
input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1
@@ -540,15 +538,15 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
infer_state0.b_seq_len,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
infer_state0.init_some_extra_state(self, padded_model_input0)
infer_state1 = self._create_inferstate(padded_model_input1, 1)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state1.b_req_idx,
infer_state1.b_seq_len,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
infer_state1.init_some_extra_state(self, padded_model_input1)

if self.graph.need_capture(find_graph_batch_size):
infer_state0.is_cuda_graph = True
@@ -578,15 +576,15 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
infer_state0.b_seq_len,
infer_state0.mem_index,
)
infer_state0.init_some_extra_state(self, model_input0.input_ids)
infer_state0.init_some_extra_state(self, model_input0)
infer_state1 = self._create_inferstate(model_input1, 1)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs,
infer_state1.b_req_idx,
infer_state1.b_seq_len,
infer_state1.mem_index,
)
infer_state1.init_some_extra_state(self, model_input1.input_ids)
infer_state1.init_some_extra_state(self, model_input1)

model_output0, model_output1 = self._overlap_tpsp_token_forward(
model_input0.input_ids, infer_state0, input_ids1=model_input1.input_ids, infer_state1=infer_state1
@@ -684,25 +682,25 @@ def _check_max_len_infer(self):
# Simulate a prefill at the maximum length and check whether it triggers an OOM
try:
logger.info("begin check max_len infer")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cpu")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
total_token_num = self.batch_max_tokens
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
max_len_in_batch=self.batch_max_tokens,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
input_ids_cpu=dummy_input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
b_mtp_index_cpu=b_mtp_index,
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
b_ready_cache_len_cpu=b_ready_cache_len,
)
model_output = self.forward(
model_input,
@@ -750,29 +748,29 @@ def _autotune_warmup(self):
self.layers_num = self.autotune_layers()
for input_len in tqdm(warmup_lengths, desc="warming up"):
try:
rand_gen = torch.Generator(device="cuda")
rand_gen = torch.Generator(device="cpu")
rand_gen.manual_seed(input_len)
dummy_input_ids = torch.randint(
0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
0, 10000, (input_len,), dtype=torch.int32, device="cpu", generator=rand_gen
)
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
b_seq_len[:] = input_len
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
total_token_num = input_len
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
max_len_in_batch=input_len,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
input_ids_cpu=dummy_input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
b_mtp_index_cpu=b_mtp_index,
is_prefill=True,
b_ready_cache_len=b_ready_cache_len,
b_ready_cache_len_cpu=b_ready_cache_len,
multimodal_params=[],
**self._gen_special_model_input(total_token_num),
)
@@ -807,27 +805,27 @@ def _init_padded_req(self):
# prefill init padding req.
prefill_input_len = 1
batch_size = 1
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
b_req_idx = torch.tensor(
[self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
mem_indexes = torch.tensor(
[self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cpu")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
total_token_num = prefill_input_len * batch_size
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
model_input = ModelInput(
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=prefill_input_len,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_mtp_index=b_mtp_index,
b_seq_len=b_seq_len,
b_ready_cache_len=b_ready_cache_len,
input_ids_cpu=dummy_input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_mtp_index_cpu=b_mtp_index,
b_seq_len_cpu=b_seq_len,
b_ready_cache_len_cpu=b_ready_cache_len,
is_prefill=True,
multimodal_params=[],
**self._gen_special_model_input(total_token_num),
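The pattern this file converges on — every dummy per-request tensor built on the CPU and handed to `ModelInput` through the new `*_cpu` fields, with no explicit `.cuda()` calls — can be summarized in a small sketch. Illustration only, assuming the constructor arguments visible in this diff; the `req_manager` / `mem_manager` objects are passed in rather than taken from a model instance.

```python
import torch

from lightllm.common.basemodel.batch_objs import ModelInput


def build_dummy_prefill_input(req_manager, mem_manager, input_len: int) -> ModelInput:
    # All per-request metadata is now created on the host; to_cuda() (or the
    # forward path) is responsible for moving it to the GPU later.
    dummy_input_ids = torch.ones(input_len, dtype=torch.int32, device="cpu")
    b_req_idx = torch.tensor([req_manager.alloc()], dtype=torch.int32, device="cpu")
    mem_indexes = mem_manager.alloc(input_len)  # no trailing .cuda() any more
    b_seq_len = torch.full((1,), input_len, dtype=torch.int32, device="cpu")
    b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
    b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
    return ModelInput(
        batch_size=1,
        total_token_num=input_len,
        max_len_in_batch=input_len,
        input_ids_cpu=dummy_input_ids,
        mem_indexes_cpu=mem_indexes,
        b_req_idx_cpu=b_req_idx,
        b_seq_len_cpu=b_seq_len,
        b_mtp_index_cpu=b_mtp_index,
        b_ready_cache_len_cpu=b_ready_cache_len,
        is_prefill=True,
    )
```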
32 changes: 21 additions & 11 deletions lightllm/common/basemodel/batch_objs.py
@@ -10,17 +10,22 @@ class ModelInput:
batch_size: int
total_token_num: int
max_len_in_batch: int
input_ids: torch.Tensor
b_req_idx: torch.Tensor
b_mtp_index: torch.Tensor
b_seq_len: torch.Tensor
input_ids: torch.Tensor = None
b_req_idx: torch.Tensor = None
b_mtp_index: torch.Tensor = None
b_seq_len: torch.Tensor = None
mem_indexes: torch.Tensor = None
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
multimodal_params: list = field(default_factory=list)

# CPU-side tensors
input_ids_cpu: torch.Tensor = None
b_req_idx_cpu: torch.Tensor = None
b_mtp_index_cpu: torch.Tensor = None
mem_indexes_cpu: torch.Tensor = None
b_seq_len_cpu: torch.Tensor = None
b_ready_cache_len_cpu: torch.Tensor = None
# Parameters used during the prefill phase; they are not consumed by the inference
# pass itself but by resource management outside of inference.
b_prefill_has_output_cpu: List[bool] = None # marks whether each prefill request produces output
@@ -33,15 +38,20 @@ class ModelInput:
deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def to_cuda(self):
if self.input_ids is not None:
self.input_ids = self.input_ids.cuda(non_blocking=True)
# input_ids may be absent; it can be obtained via req_to_token_indexs
if self.input_ids is None and self.input_ids_cpu is not None:
self.input_ids = self.input_ids_cpu.cuda(non_blocking=True)
if self.mem_indexes is None:
self.mem_indexes = self.mem_indexes_cpu.cuda(non_blocking=True)
self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
self.b_mtp_index = self.b_mtp_index.cuda(non_blocking=True)
if self.b_ready_cache_len is not None:
self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
if self.b_req_idx is None:
self.b_req_idx = self.b_req_idx_cpu.cuda(non_blocking=True)
if self.b_seq_len is None:
self.b_seq_len = self.b_seq_len_cpu.cuda(non_blocking=True)
# b_ready_cache_len only takes effect during the prefill phase
if self.b_ready_cache_len_cpu is not None:
self.b_ready_cache_len = self.b_ready_cache_len_cpu.cuda(non_blocking=True)
if self.b_mtp_index is None:
self.b_mtp_index = self.b_mtp_index_cpu.cuda(non_blocking=True)


@dataclass
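For reference, the contract of the reworked `to_cuda()` above is that each device-side field is filled lazily from its `*_cpu` counterpart, while an already-populated device tensor is left untouched. A minimal usage sketch, assuming only the constructor arguments shown in this diff are required:

```python
import torch

from lightllm.common.basemodel.batch_objs import ModelInput

# Decode-style input built purely from host tensors (illustration only).
model_input = ModelInput(
    batch_size=1,
    total_token_num=2,
    max_len_in_batch=2,
    input_ids_cpu=torch.ones(1, dtype=torch.int32),
    mem_indexes_cpu=torch.zeros(1, dtype=torch.int32),
    b_req_idx_cpu=torch.zeros(1, dtype=torch.int32),
    b_seq_len_cpu=torch.full((1,), 2, dtype=torch.int32),
    b_mtp_index_cpu=torch.zeros(1, dtype=torch.int32),
    is_prefill=False,
)

if torch.cuda.is_available():  # guard so the sketch also runs on CPU-only hosts
    model_input.to_cuda()
    # input_ids / mem_indexes / b_req_idx / b_seq_len / b_mtp_index now live on the GPU;
    # b_ready_cache_len stays None because no b_ready_cache_len_cpu was provided.
    assert model_input.input_ids.is_cuda
```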
36 changes: 19 additions & 17 deletions lightllm/common/basemodel/cuda_graph.py
@@ -195,25 +195,27 @@ def warmup(self, model):
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")

model_input = ModelInput(
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=max_len_in_batch,
input_ids=input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
b_mtp_index=b_mtp_index,
input_ids_cpu=input_ids,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
b_mtp_index_cpu=b_mtp_index,
is_prefill=False,
b_ready_cache_len_cpu=b_ready_cache_len,
**model._gen_special_model_input(batch_size),
)
model_output: ModelOutput = model.forward(model_input)
Expand Down Expand Up @@ -251,25 +253,25 @@ def warmup_overlap(self, model):
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cpu")
mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda()
b_req_idx = torch.tensor(
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
[model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
)
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cpu")
b_seq_len.fill_(seq_len)
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")

micro_batch = ModelInput(
is_prefill=False,
batch_size=batch_size,
total_token_num=total_token_num,
max_len_in_batch=max_len_in_batch,
input_ids=input_ids,
input_ids_cpu=input_ids,
b_mtp_index=b_mtp_index,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
mem_indexes_cpu=mem_indexes,
b_req_idx_cpu=b_req_idx,
b_seq_len_cpu=b_seq_len,
**model._gen_special_model_input(batch_size),
)
decode_batches.append(micro_batch)
6 changes: 2 additions & 4 deletions lightllm/common/basemodel/infer_struct.py
@@ -64,7 +64,7 @@ def __init__(self):
# as input; no other model or scenario uses it
self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
def init_some_extra_state(self, model, model_input: ModelInput):
if self.is_prefill:
(
self.b_q_seq_len,
@@ -75,9 +75,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
self.max_q_seq_len,
self.max_kv_seq_len,
) = gen_prefill_params(
input_token_num=input_ids.shape[0],
b_ready_cache_len=self.b_ready_cache_len,
b_seq_len=self.b_seq_len,
model_input,
)
self.b_start_loc = self.b1_cu_q_seq_len[0:-1]
else:
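Downstream infer-state classes that override `init_some_extra_state` need to follow the signature change from a bare `input_ids` tensor to the full `ModelInput`. A hedged sketch of an adapted override — the subclass name and the extra field are hypothetical, and the base class is assumed to be `InferStateInfo` from `infer_struct.py`:

```python
from lightllm.common.basemodel.batch_objs import ModelInput
from lightllm.common.basemodel.infer_struct import InferStateInfo


class MyBackendInferState(InferStateInfo):  # hypothetical subclass, illustration only
    def init_some_extra_state(self, model, model_input: ModelInput):
        # Keep the default prefill/decode bookkeeping from the base class.
        super().init_some_extra_state(model, model_input)
        # The token ids that used to arrive as the `input_ids` argument are
        # still reachable through the ModelInput object.
        input_ids = model_input.input_ids
        self.total_input_tokens = int(input_ids.shape[0])  # hypothetical extra field
```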