add try-exception in forward thread
lvhan028 committed Sep 22, 2024
1 parent 3f3ddb8 commit 8c1ad02
Showing 3 changed files with 27 additions and 4 deletions.
1 change: 1 addition & 0 deletions lmdeploy/messages.py
@@ -301,6 +301,7 @@ class ResponseType(enum.Enum):
     SESSION_NOT_EXIST = enum.auto()
     HANDLER_NOT_EXIST = enum.auto()
     INPUT_LENGTH_ERROR = enum.auto()
+    INTERNAL_ENGINE_ERROR = enum.auto()
 
 
 @dataclass
2 changes: 1 addition & 1 deletion lmdeploy/serve/async_engine.py
@@ -503,7 +503,7 @@ async def generate(
             gen_config.temperature = 1.0
             gen_config.repetition_penalty = 1.0
         # set random if it is not set and sequence_start is True
-        if gen_config.random_seed is None and sequence_start:
+        elif gen_config.random_seed is None and sequence_start:
             gen_config.random_seed = random.getrandbits(64)
         if gen_config.n > 1:
             logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. "
28 changes: 25 additions & 3 deletions lmdeploy/turbomind/turbomind.py
@@ -358,7 +358,12 @@ def _forward_thread(self, inputs):
                                                             self.gpu_count)
 
         def _func():
-            output = self.model_inst.forward(inputs, instance_comm)
+            try:
+                output = self.model_inst.forward(inputs, instance_comm)
+            except Exception as e:
+                logger.error(f'Exception happened: {e}')
+                self.que.put((-1, None))
+                return
             self.que.put((True, output))
 
         self.executor = ThreadPoolExecutor(1)
@@ -372,7 +377,12 @@ def _async_forward_thread(self, inputs, que: LifoQueue):
                                                             self.gpu_count)
 
         def _func():
-            output = self.model_inst.forward(inputs, instance_comm)
+            try:
+                output = self.model_inst.forward(inputs, instance_comm)
+            except Exception as e:
+                logger.error(f'Exception happened: {e}')
+                self.que.put((-1, None))
+                return
             que.put((True, output))
 
         self.executor = ThreadPoolExecutor(1)
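The two hunks above share one pattern: the worker thread catches any exception raised by forward, logs it, and pushes a (-1, None) sentinel onto the result queue so the consumer is never left blocking on a queue that will never be filled. A minimal standalone sketch of that pattern follows; it is not part of the commit, and run_forward, forward_fn and the (status, payload) queue layout are illustrative assumptions.

import logging
from concurrent.futures import ThreadPoolExecutor
from queue import Queue

logger = logging.getLogger(__name__)

def run_forward(forward_fn, inputs, que: Queue) -> ThreadPoolExecutor:
    """Run forward_fn(inputs) in a single worker thread.

    On failure the thread puts a (-1, None) sentinel on `que` instead of
    dying silently; on success it puts (True, output).
    """
    def _func():
        try:
            output = forward_fn(inputs)
        except Exception as e:
            logger.error(f'Exception happened: {e}')
            que.put((-1, None))  # sentinel: tell the consumer the engine failed
            return
        que.put((True, output))

    executor = ThreadPoolExecutor(1)
    executor.submit(_func)
    return executor

Signalling failure through the same queue keeps the consumer's control flow in one place, rather than requiring it to poll the future for exceptions.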
@@ -653,6 +663,10 @@ async def async_stream_infer(self,
                 await asyncio.sleep(0.002)
 
             finish, tm_outputs = que.get()
+            if finish < 0:
+                yield EngineOutput()
+                self.executor.shutdown()
+                break
 
             outputs = _tm_dict_to_torch_dict(tm_outputs)
 
@@ -766,6 +780,10 @@ def stream_infer(self,
                 self.que.get()
 
             finish, tm_outputs = self.que.get()
+            if finish < 0:
+                yield EngineOutput()
+                self.executor.shutdown()
+                break
 
             outputs = _tm_dict_to_torch_dict(tm_outputs)
 
@@ -892,7 +910,9 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
         # start forward thread
         self._forward_thread(tm_inputs)
 
-        _, tm_outputs = self.que.get()
+        res, tm_outputs = self.que.get()
+        if res < 0:
+            return None
 
         outputs = _tm_dict_to_torch_dict(tm_outputs)
         logits = outputs['logits']
@@ -942,6 +962,8 @@ def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
                                   steps,
                                   sequence_start=(i == 0),
                                   sequence_end=(i == n_max_iter - 1))
+            if _logits is None:
+                return None
             _logits = _logits.to(device=device)
             logits.append(_logits)
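On the consumer side, each reader of the queue now checks for the negative sentinel: the streaming generators yield an empty EngineOutput and stop, decode returns None, and get_ppl bails out when decode returned None. A standalone sketch of that check follows; EngineOutput here is a stand-in dataclass rather than lmdeploy's real type, and get_result and the queue layout are assumptions carried over from the sketch above.

from dataclasses import dataclass, field
from queue import Queue
from typing import List, Optional

@dataclass
class EngineOutput:
    status: int = 0
    token_ids: List[int] = field(default_factory=list)

def get_result(que: Queue) -> Optional[EngineOutput]:
    """Read one (finish, payload) tuple; return None if the worker failed."""
    finish, payload = que.get()
    if finish < 0:  # the (-1, None) sentinel from the worker thread
        return None  # mirrors decode() returning None to its caller
    return EngineOutput(token_ids=payload)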
