diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index d2bc15f77..24ac6d7e4 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -23,6 +23,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import math from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union @@ -110,6 +111,9 @@ def process_logits( allowed_tokens = self.guide.get_next_instruction(guide_state).tokens.to( mask.device, non_blocking=True ) + allowed_tokens = allowed_tokens[ + allowed_tokens < mask.shape[-1] + ] # filter out input ids exceeding the mask length allowed_tokens_batch.append(allowed_tokens) batch_indices.append( torch.full_like(allowed_tokens, i)