diff --git a/benchmarks/bench_processors.py b/benchmarks/bench_processors.py
index db1e4a8f1..02ea52b79 100644
--- a/benchmarks/bench_processors.py
+++ b/benchmarks/bench_processors.py
@@ -37,15 +37,12 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
     logits: (4, 30,000 ) dtype=float
     input_ids shape: (4, 2048) dtype=int
     """
-    if array_library == "torch":
-        logits = torch.rand((4, num_tokens), dtype=torch.float)
-        input_ids = torch.randint(
-            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int
-        )
-    elif array_library == "torch_cuda":
-        logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda")
+    if array_library.startswith("torch"):
+        device = array_library.split("_")[1] if "_" in array_library else "cpu"
+
+        logits = torch.rand((4, num_tokens), dtype=torch.float, device=device)
         input_ids = torch.randint(
-            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda"
+            low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device=device
         )
     elif array_library == "numpy":
         logits = np.random.rand(4, num_tokens).astype(np.float32)
@@ -88,6 +85,8 @@ class LogitsProcessorPassthroughBenchmark:
         params += ["mlx"]
     if torch.cuda.is_available():
         params += ["torch_cuda"]
+    if torch.mps.is_available():
+        params += ["torch_mps"]
     if is_jax_allowed():
         params += ["jax"]
 
@@ -108,9 +107,10 @@ class LogitsProcessorStructuredBenchmark:
     array_libraries = ["torch", "numpy"]
     if is_mlx_lm_allowed():
         array_libraries += ["mlx"]
-    # PR TODO
    if torch.cuda.is_available():
         array_libraries += ["torch_cuda"]
+    if torch.mps.is_available():
+        array_libraries += ["torch_mps"]
 
     # accept very many or very few tokens, respectively
     patterns = [r"[^Z]*", "Z*"]
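
For context, a minimal standalone sketch of how the consolidated `torch` branch resolves the device from the array-library name (`"torch"` → `"cpu"`, `"torch_cuda"` → `"cuda"`, `"torch_mps"` → `"mps"`). The helper name `resolve_torch_device` is illustrative only and not part of the patch:

```python
import torch


def resolve_torch_device(array_library: str) -> str:
    # Illustrative helper mirroring the patched branch:
    # "torch" -> "cpu", "torch_cuda" -> "cuda", "torch_mps" -> "mps"
    return array_library.split("_")[1] if "_" in array_library else "cpu"


# Mirrors the patched get_mock_processor_inputs() for a CPU-only run;
# "torch_cuda" / "torch_mps" would only be selected when the corresponding
# backend reports itself as available in the benchmark params.
device = resolve_torch_device("torch")
logits = torch.rand((4, 30000), dtype=torch.float, device=device)
input_ids = torch.randint(
    low=0, high=30000, size=(4, 2048), dtype=torch.int, device=device
)
print(logits.shape, input_ids.shape, device)
# torch.Size([4, 30000]) torch.Size([4, 2048]) cpu
```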