Skip to content

Commit

Permalink
Refactor run_program_batch to preserve original behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
openhands-agent committed Jan 3, 2025
1 parent 21e9e63 commit f8bc3c8
Showing 1 changed file with 54 additions and 55 deletions.
109 changes: 54 additions & 55 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def run_program_batch(
default_sampling_para,
num_threads,
progress_bar,
generator_style=False,
):
if hasattr(backend, "endpoint"):
backend = backend.endpoint
Expand All @@ -109,63 +110,61 @@ def run_program_batch(
num_threads = max(96, multiprocessing.cpu_count() * 16)
num_threads = min(num_threads, len(batch_arguments))

if num_threads == 1:
rets = []
if progress_bar:
for arguments in tqdm.tqdm(batch_arguments):
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
def generate_results():
if num_threads == 1:
iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments
for arguments in iterator:
yield run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
else:
for arguments in batch_arguments:
rets.append(
run_program(
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
else:
if progress_bar:
pbar = tqdm.tqdm(total=len(batch_arguments))

with ThreadPoolExecutor(num_threads) as executor:
futures = []
for arguments in batch_arguments:
futures.append(
executor.submit(
run_program,
program,
backend,
(),
arguments,
default_sampling_para,
False,
True,
)
)
if progress_bar:
futures[-1].add_done_callback(lambda _: pbar.update())

rets = [f.result() for f in futures]
rets[-1].sync()

if progress_bar:
pbar.close()

return rets
pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None

# Process in chunks to avoid overwhelming ThreadPoolExecutor
# Otherwise, ThreadPoolExecutor.submit will block after adding a certain number of tasks,
# so we will never reach "yield" until all tasks are done, which defeats the purpose of the generator style
chunk_size = len(batch_arguments) if not generator_style else 200

with ThreadPoolExecutor(num_threads) as executor:
for chunk_start in range(0, len(batch_arguments), chunk_size):
chunk_end = min(chunk_start + chunk_size, len(batch_arguments))
chunk_futures = []

# Submit chunk of tasks
for i in range(chunk_start, chunk_end):
future = executor.submit(
run_program,
program,
backend,
(),
batch_arguments[i],
default_sampling_para,
False,
True,
)
if pbar:
future.add_done_callback(lambda _: pbar.update())
chunk_futures.append(future)

# Yield results from this chunk as they complete
for future in chunk_futures:
yield future.result()

if pbar:
pbar.close()

results = generate_results()
if not generator_style:
results = list(results)
if results: # Only sync if we have results
results[-1].sync()
return results


def cache_program(program, backend):
Expand Down

0 comments on commit f8bc3c8

Please sign in to comment.