Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GuideLogitsProcessor for MPS device #1306

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions benchmarks/bench_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,12 @@ def get_mock_processor_inputs(array_library, num_tokens=30000):
logits: (4, 30,000 ) dtype=float
input_ids shape: (4, 2048) dtype=int
"""
if array_library == "torch":
logits = torch.rand((4, num_tokens), dtype=torch.float)
input_ids = torch.randint(
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int
)
elif array_library == "torch_cuda":
logits = torch.rand((4, num_tokens), dtype=torch.float, device="cuda")
if array_library.startswith("torch"):
device = array_library.split("_")[1] if "_" in array_library else "cpu"

logits = torch.rand((4, num_tokens), dtype=torch.float, device=device)
input_ids = torch.randint(
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device="cuda"
low=0, high=num_tokens, size=(4, 2048), dtype=torch.int, device=device
)
elif array_library == "numpy":
logits = np.random.rand(4, num_tokens).astype(np.float32)
Expand Down Expand Up @@ -88,6 +85,8 @@ class LogitsProcessorPassthroughBenchmark:
params += ["mlx"]
if torch.cuda.is_available():
params += ["torch_cuda"]
if torch.mps.is_available():
params += ["torch_mps"]
if is_jax_allowed():
params += ["jax"]

Expand All @@ -108,9 +107,10 @@ class LogitsProcessorStructuredBenchmark:
array_libraries = ["torch", "numpy"]
if is_mlx_lm_allowed():
array_libraries += ["mlx"]
# PR TODO
if torch.cuda.is_available():
array_libraries += ["torch_cuda"]
if torch.mps.is_available():
array_libraries += ["torch_mps"]

# accept very many or very few tokens, respectively
patterns = [r"[^Z]*", "Z*"]
Expand Down
11 changes: 4 additions & 7 deletions outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,19 @@ def process_logits(

sequence_states.append(self._guide_states[curr_state_key])

mask = torch.ones_like(logits, dtype=torch.bool)

allowed_tokens_batch = []
batch_indices = []
for i, guide_state in enumerate(sequence_states):
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to(
mask.device, non_blocking=True
)
allowed_tokens = self.guide.get_next_instruction(guide_state).tokens
allowed_tokens_batch.append(allowed_tokens)
batch_indices.append(
torch.full_like(allowed_tokens, i)
) # Store batch index for each allowed token

allowed_tokens_concat = torch.cat(allowed_tokens_batch)
batch_indices_concat = torch.cat(batch_indices)
allowed_tokens_concat = torch.cat(allowed_tokens_batch).to(logits.device)
batch_indices_concat = torch.cat(batch_indices).to(logits.device)

mask = torch.ones_like(logits, dtype=torch.bool)
mask[batch_indices_concat, allowed_tokens_concat] = False
logits.masked_fill_(mask, float("-inf"))

Expand Down
Loading