-
Notifications
You must be signed in to change notification settings - Fork 12
/
demo.py
102 lines (81 loc) · 2.22 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from superjsonmode.integrations.openai import StructuredOpenAIModel
from openai import OpenAI
from pydantic import BaseModel
import time
import json
class color:
    """ANSI escape sequences used to style the demo's terminal output.

    CYAN and BOLD switch a style on; END resets all attributes, which is why
    every styled span printed by this script is closed with END.
    """

    CYAN = "\033[96m"  # SGR 96: bright cyan foreground
    BOLD = "\033[1m"   # SGR 1: bold / increased intensity
    END = "\033[0m"    # SGR 0: reset all attributes
# --- Baseline: ask gpt-3.5-turbo for JSON with a plain prompt ---------------
print(
    "\n"
    + color.BOLD
    + "Generating JSON naively with OpenAI gpt-3.5-turbo..."
    + color.END
)
print("-------------------------------------------")

# Source text both runs extract character details from. The Super JSON Mode
# run further down reuses this name, so it must stay at module level.
prompt = """Luke Skywalker is a famous character."""

default_prompt = f"""{prompt}
Based on the prompt above, generate a JSON blob with the following keys: "name", "genre", "age", "race", "occupation", "best_friend", and "home_planet".
"""
print(default_prompt)

# Build the client before starting the timer. The Super JSON Mode run
# constructs its model before its timer starts, so both timings must cover
# only the generation request to be comparable.
client = OpenAI()

# perf_counter() is monotonic and high-resolution. time.time() follows the
# wall clock and can jump, which makes it unsuitable for measuring durations.
start = time.perf_counter()
completion = client.chat.completions.create(
    model="gpt-3.5-turbo", messages=[{"role": "user", "content": default_prompt}]
)
print(
    "Total time: "
    + color.CYAN
    + color.BOLD
    + f"{time.perf_counter() - start} seconds"
    + color.END
)
print("-------------------------------------------")
# The raw reply: free-form text that the prompt merely *asks* to be JSON.
print(completion.choices[0].message.content)
print("-------------------------------------------\n")
# --- Same model, now driven through Super JSON Mode -------------------------
print(f"\n{color.BOLD}Testing same OpenAI model with Super JSON Mode...{color.END}")
print("-------------------------------------------")
# Created here, before the timer of the following section starts, so its
# set-up cost is not part of the reported generation time.
model = StructuredOpenAIModel()
class Character(BaseModel):
    """Schema of the character details Super JSON Mode is asked to fill in.

    Passed as ``schema=`` to ``StructuredOpenAIModel.generate``. Its seven
    fields mirror the keys the naive prompt asks gpt-3.5-turbo to produce.
    """

    name: str
    genre: str
    age: int
    race: str
    occupation: str
    best_friend: str
    home_planet: str
# Per-field extraction prompt. This is NOT an f-string: {prompt}, {type} and
# {key} stay literal placeholders, presumably filled per schema field by
# StructuredOpenAIModel.generate (confirm against the superjsonmode docs).
prompt_template = """{prompt}
Please fill in the following information about this character for this key. Keep it succinct. It should be a {type}.
{key}: """

# Monotonic clock for measuring the duration; `model` was already built
# above, so only generation is timed.
start = time.perf_counter()
output = model.generate(
    prompt,
    extraction_prompt_template=prompt_template,
    schema=Character,
    # NOTE(review): presumably how many field requests are issued together.
    # 7 equals the number of Character fields; keep in sync with the schema.
    batch_size=7,
    # Stop sequence; presumably ends each field's answer at a blank line.
    stop=["\n\n"],
    temperature=0,
)
print(
    "Total time: "
    + color.CYAN
    + color.BOLD
    + f"{time.perf_counter() - start} seconds"
    + color.END
)
print("-------------------------------------------")
print(json.dumps(output, indent=2))
# Sample output from one run (that run reported ~0.409 s; timings vary).
# In that run "age" came back as the string "23" even though Character.age
# is an int, so don't assume values are coerced to the schema's types.
# {
#   "name": "Luke Skywalker",
#   "genre": "Science fiction",
#   "age": "23",
#   "race": "Human",
#   "occupation": "Jedi Knight",
#   "best_friend": "Han Solo",
#   "home_planet": "Tatooine"
# }
print("-------------------------------------------")