Use OpenAI API For Structured Generation (json, choice)
lapp0 authored and rlouf committed Sep 17, 2024
1 parent 2b1aed0 commit 5591950
Showing 5 changed files with 134 additions and 261 deletions.
27 changes: 19 additions & 8 deletions outlines/generate/choice.py
@@ -1,10 +1,12 @@
import json as pyjson
from functools import singledispatch
from typing import Callable, List

from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial

from .json import json
from .regex import regex


@@ -24,13 +26,22 @@ def choice(
def choice_openai(
    model: OpenAI, choices: List[str], sampler: Sampler = multinomial()
) -> Callable:
    if not isinstance(sampler, multinomial):
        raise NotImplementedError(
            r"The OpenAI API does not support any other sampling algorithm "
            + "than the multinomial sampler."
        )

    def generate_choice(prompt: str, max_tokens: int = 1):
        return model.generate_choice(prompt, choices, max_tokens)
    """
    Call OpenAI API with response_format of a dict:
    {"result": <one of choices>}
    """

    choices_schema = pyjson.dumps(
        {
            "type": "object",
            "properties": {"result": {"type": "string", "enum": choices}},
            "additionalProperties": False,
            "required": ["result"],
        }
    )
    generator = json(model, choices_schema, sampler)

    def generate_choice(*args, **kwargs):
        return generator(*args, **kwargs)["result"]

    return generate_choice
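
With this change, `generate.choice` on an OpenAI model no longer steers decoding with logit biases; it delegates to JSON-schema generation with an `enum` constraint and unwraps the `"result"` field. A minimal usage sketch (API key assumed configured; model name, prompt, and choices are illustrative):

from outlines import generate, models

model = models.openai("gpt-4o-mini")
classifier = generate.choice(model, ["Positive", "Negative"])

# The API is constrained to return {"result": <one of the choices>};
# the wrapper unwraps it, so `label` is just the chosen string.
label = classifier("Review: the pasta was excellent. Sentiment:")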
38 changes: 34 additions & 4 deletions outlines/generate/json.py
@@ -70,9 +70,39 @@ def json(

@json.register(OpenAI)
def json_openai(
    model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial()
    model, schema_object: Union[str, object], sampler: Sampler = multinomial()
):
    raise NotImplementedError(
        "Cannot use JSON Schema-structure generation with an OpenAI model "
        + "due to the limitations of the OpenAI API"
    )
    if not isinstance(sampler, multinomial):
        raise NotImplementedError(
            r"The OpenAI API does not support any other sampling algorithm "
            + "than the multinomial sampler."
        )

    if isinstance(schema_object, type(BaseModel)):
        schema = pyjson.dumps(schema_object.model_json_schema())
        format_sequence = lambda x: schema_object.parse_raw(x)
    elif isinstance(schema_object, str):
        schema = schema_object
        format_sequence = lambda x: pyjson.loads(x)
    else:
        raise ValueError(
            f"Cannot parse schema {schema_object}. The schema must be either "
            + "a Pydantic model or a string that contains the JSON "
            + "Schema specification"
        )

    # create copied, patched model with normalized json schema set
    generator = model.new_with_replacements(
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "default",
                "strict": True,
                "schema": pyjson.loads(schema),
            },
        }
    )

    generator.format_sequence = format_sequence

    return generator
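
The new `json_openai` path hands schema enforcement to OpenAI's structured-outputs `response_format` (with `"strict": True`), so no client-side token masking is involved. A minimal usage sketch with a Pydantic model (API key assumed configured; names and prompt are illustrative):

from pydantic import BaseModel
from outlines import generate, models

class Character(BaseModel):
    name: str
    age: int

model = models.openai("gpt-4o-mini")
generator = generate.json(model, Character)

# format_sequence parses the raw completion with Character.parse_raw,
# so the call returns a Character instance rather than a JSON string.
character = generator("Invent a fantasy character.")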
213 changes: 13 additions & 200 deletions outlines/models/openai.py
@@ -1,9 +1,8 @@
"""Integration with OpenAI's API."""
import copy
import functools
import warnings
from dataclasses import asdict, dataclass, field, replace
from itertools import zip_longest
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np

@@ -74,7 +73,6 @@ def __init__(
        self,
        client,
        config,
        tokenizer=None,
        system_prompt: Optional[str] = None,
    ):
        """Create an `OpenAI` instance.
@@ -89,13 +87,9 @@ def __init__(
        config
            An instance of `OpenAIConfig`. Can be useful to specify some
            parameters that cannot be set by calling this class' methods.
        tokenizer
            The tokenizer associated with the model the client connects to.
        """

        self.client = client
        self.tokenizer = tokenizer
        self.config = config

        # We count the total number of prompt and generated tokens as returned
@@ -104,6 +98,8 @@ def __init__(
        self.prompt_tokens = 0
        self.completion_tokens = 0

        self.format_sequence = lambda seq: seq

    def __call__(
        self,
        prompt: Union[str, List[str]],
@@ -152,107 +148,17 @@ def __call__(
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        return response
        return self.format_sequence(response)

    def stream(self, *args, **kwargs):
        raise NotImplementedError(
            "Streaming is currently not supported for the OpenAI API"
        )
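
Two small hooks make this delegation work: `__init__` now installs an identity `format_sequence`, and `__call__` pipes every response through it; `json_openai` overrides the hook with a parser. A toy illustration (values illustrative):

import json as pyjson

default_format = lambda seq: seq         # identity: plain generation is unchanged
json_format = lambda x: pyjson.loads(x)  # what json_openai installs for string schemas

raw = '{"result": "Positive"}'
assert default_format(raw) == raw
assert json_format(raw) == {"result": "Positive"}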

    def generate_choice(
        self,
        prompt: str,
        choices: List[str],
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """Call the OpenAI API to generate one of several choices.

        Parameters
        ----------
        prompt
            A string or list of strings that will be used to prompt the model
        choices
            The list of strings between which we ask the model to choose
        max_tokens
            The maximum number of tokens to generate
        system_prompt
            The content of the system message that precedes the user's prompt.
        """
        if self.tokenizer is None:
            raise ValueError(
                "You must initialize the `OpenAI` class with a tokenizer to use `outlines.generate.choice`"
            )

        config = replace(self.config, max_tokens=max_tokens)

        greedy = False
        decoded: List[str] = []
        encoded_choices_left: List[List[int]] = [
            self.tokenizer.encode(word) for word in choices
        ]

        while len(encoded_choices_left) > 0:
            max_tokens_left = max([len(tokens) for tokens in encoded_choices_left])
            transposed_choices_left: List[Set] = [
                {item for item in subset if item is not None}
                for subset in zip_longest(*encoded_choices_left)
            ]

            if not greedy:
                mask = build_optimistic_mask(transposed_choices_left)
            else:
                mask = {}
                for token in transposed_choices_left[0]:  # build greedy mask
                    mask[token] = 100

            if len(mask) == 0:
                break

            config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)

            response, prompt_tokens, completion_tokens = generate_chat(
                prompt, system_prompt, self.client, config
            )
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens

            encoded_response = self.tokenizer.encode(response)

            if encoded_response in encoded_choices_left:
                decoded.append(response)
                break
            else:
                (
                    encoded_response,
                    encoded_choices_left,
                ) = find_response_choices_intersection(
                    encoded_response, encoded_choices_left
                )

                if len(encoded_response) == 0:
                    greedy = True  # next iteration will be "greedy"
                    continue
                else:
                    decoded.append("".join(self.tokenizer.decode(encoded_response)))

                    if len(encoded_choices_left) == 1:  # only one choice left
                        choice_left = self.tokenizer.decode(encoded_choices_left[0])
                        decoded.append(choice_left)
                        break

                    greedy = False  # after each success, stay with (or switch to) "optimistic" approach

                prompt = prompt + "".join(decoded)

        choice = "".join(decoded)

        return choice

    def generate_json(self):
        """Call the OpenAI API to generate a JSON object."""
        raise NotImplementedError

    def new_with_replacements(self, **kwargs):
        new_instance = copy.copy(self)
        new_instance.config = replace(new_instance.config, **kwargs)
        return new_instance

    def __str__(self):
        return self.__class__.__name__ + " API"
@@ -313,81 +219,6 @@ async def call_api(prompt, system_prompt, config):
    return results, usage["prompt_tokens"], usage["completion_tokens"]


def find_longest_intersection(response: List[int], choice: List[int]) -> List[int]:
    """Find the longest intersection between the response and the choice."""
    for i, (token_r, token_c) in enumerate(zip_longest(response, choice)):
        if token_r != token_c:
            return response[:i]

    return response


def find_response_choices_intersection(
    response: List[int], choices: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
    """Find the longest intersection between the response and the different
    choices.

    Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices
    `[[1, 2], [1, 2, 3], [6, 7, 8]]` then the function will return `[1, 2, 3]`
    as the intersection, and `[[]]` as the list of choices left.

    Parameters
    ----------
    response
        The model's response
    choices
        The remaining possible choices

    Returns
    -------
    A tuple that contains the longest intersection between the response and the
    different choices, and the choices which start with this intersection, with the
    intersection removed.
    """
    max_len_prefix = 0
    choices_left = []
    longest_prefix = []
    for i, choice in enumerate(choices):
        # Find the longest intersection between the response and the choice.
        prefix = find_longest_intersection(response, choice)

        if len(prefix) > max_len_prefix:
            max_len_prefix = len(prefix)
            choices_left = [choice[len(prefix) :]]
            longest_prefix = prefix

        elif len(prefix) == max_len_prefix:
            choices_left.append(choice[len(prefix) :])

    return longest_prefix, choices_left


def build_optimistic_mask(
    transposed: List[Set[int]], max_mask_size: int = 300
) -> Dict[int, int]:
    """We build the largest mask possible.

    Tokens are added from left to right, so if the encoded choices are e.g.
    `[[1,2], [3,4]]`, `1` and `3` will be added before `2` and `4`.

    Parameters
    ----------
    transposed
        A list of lists that contain the nth token of each choice.
    """
    mask: Dict[int, int] = {}
    for tokens in transposed:
        for token in tokens:
            if len(mask) == max_mask_size:
                return mask
            mask[token] = 100

    return mask
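
For context on the deleted helpers: the old `generate_choice` loop constrained the chat API by biasing only tokens that could continue one of the remaining choices. A small standalone illustration of the mask construction with toy token ids (all values illustrative):

from itertools import zip_longest

# Toy encodings: "yes" -> [1, 2], "no" -> [3]
encoded_choices = [[1, 2], [3]]

# nth token of every remaining choice, None-padded by zip_longest then filtered
transposed = [
    {tok for tok in col if tok is not None}
    for col in zip_longest(*encoded_choices)
]  # -> [{1, 3}, {2}]

# build_optimistic_mask gave every candidate token a +100 logit bias,
# left positions first, so only choice tokens were likely to be emitted
mask = {token: 100 for col in transposed for token in col}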


def error_handler(api_call_fn: Callable) -> Callable:
    """Handle OpenAI API errors and missing API key."""

@@ -427,11 +258,10 @@ def openai_model(
    **openai_client_params,
):
    try:
        import tiktoken
        from openai import AsyncOpenAI
    except ImportError:
        raise ImportError(
            "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' OpenAI integration."
            "The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
        )

    if config is not None:
@@ -441,15 +271,7 @@

    client = AsyncOpenAI(**openai_client_params)

    try:
        tokenizer = tiktoken.encoding_for_model(model_name)
    except KeyError:
        warnings.warn(
            f"Could not find a tokenizer for model {model_name}. Using default cl100k_base."
        )
        tokenizer = tiktoken.get_encoding("cl100k_base")

    return OpenAI(client, config, tokenizer)
    return OpenAI(client, config)


def azure_openai(
@@ -459,11 +281,10 @@ def azure_openai(
    **azure_openai_client_params,
):
    try:
        import tiktoken
        from openai import AsyncAzureOpenAI
    except ImportError:
        raise ImportError(
            "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' Azure OpenAI integration."
            "The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
        )

    if config is not None:
@@ -473,12 +294,4 @@

    client = AsyncAzureOpenAI(**azure_openai_client_params)

    try:
        tokenizer = tiktoken.encoding_for_model(model_name or deployment_name)
    except KeyError:
        warnings.warn(
            f"Could not find a tokenizer for model {model_name or deployment_name}. Using default cl100k_base."
        )
        tokenizer = tiktoken.get_encoding("cl100k_base")

    return OpenAI(client, config, tokenizer)
    return OpenAI(client, config)
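
`new_with_replacements` is what lets `json_openai` attach a `response_format` without mutating a shared model instance: `copy.copy` duplicates the wrapper and `dataclasses.replace` rebuilds its config with the overrides. A rough sketch of the pattern in isolation (class and field names are illustrative, not the library's):

import copy
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class Config:
    temperature: float = 1.0
    response_format: Optional[dict] = None

class Client:
    def __init__(self, config: Config):
        self.config = config

    def new_with_replacements(self, **kwargs):
        # shallow-copy the wrapper, then rebuild its (frozen) config with overrides
        new_instance = copy.copy(self)
        new_instance.config = replace(new_instance.config, **kwargs)
        return new_instance

base = Client(Config())
patched = base.new_with_replacements(response_format={"type": "json_schema"})
assert base.config.response_format is None  # the original is untouched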
1 change: 0 additions & 1 deletion pyproject.toml
@@ -122,7 +122,6 @@ module = [
"pydantic.*",
"pytest",
"referencing.*",
"tiktoken.*",
"torch.*",
"transformers.*",
"llama_cpp",