|
| 1 | +import json |
| 2 | +import random |
| 3 | +from typing import Any, List, Union |
| 4 | + |
| 5 | +from steamship import Block, Steamship, SteamshipError, Task |
| 6 | +from steamship.agents.llms import OpenAI |
| 7 | +from steamship.agents.schema import AgentContext, Tool |
| 8 | +from steamship.agents.utils import get_llm, with_llm |
| 9 | +from steamship.utils.repl import ToolREPL |
| 10 | + |
| 11 | +DEFAULT_PROMPT = """INSTRUCTIONS: |
| 12 | +Generate a JSON object describing {table_description}. |
| 13 | +Always return a non-empty value for every field in the object. |
| 14 | +
|
| 15 | +FIELDS DESIRED: |
| 16 | +{fields_desired} |
| 17 | +
|
| 18 | +EXAMPLE OBJECTS: |
| 19 | +{example_objects} |
| 20 | +
|
| 21 | +NEW OBJECT: |
| 22 | +{new_object_prefix} |
| 23 | +""" |
| 24 | + |
| 25 | +DEFAULT_PLURAL_OBJECT_DESCRIPTION = "employees of a company" |
| 26 | +DEFAULT_OBJECT_KEYS = ["Name", "Age", "Gender"] |
| 27 | +DEFAULT_OBJECT_EXAMPLES = [ |
| 28 | + ["Bob", 30, "Male"], |
| 29 | + ["Susan", 32, "Female"], |
| 30 | + ["Zhenzhong", 40, "Male"], |
| 31 | + ["Luis", 32, "Male"], |
| 32 | + ["Roberta", 35, "Female"], |
| 33 | + ["Sofia", 30, "Female"], |
| 34 | +] |
| 35 | +DEFAULT_NEW_ROW_PREFIX_FIELDS = [] |
| 36 | + |
| 37 | + |
| 38 | +class JsonObjectGeneratorTool(Tool): |
| 39 | + """ |
| 40 | + Example tool to illustrate generating a new JSON object provided a set of examples. |
| 41 | +
|
| 42 | + This is useful as an example of how to generate a new structured object: |
| 43 | +
|
| 44 | + - A Person (e.g. name, gender, age) |
| 45 | + - A Proposed Podcast Episode (e.g. title, description, tags) |
| 46 | +
|
| 47 | + The tool takes no input at runtime: it's a true generator parameterized only at initializtion time. |
| 48 | +
|
| 49 | + The tool's parameterization is somewhat CSV-like in construction. |
| 50 | +
|
| 51 | + """ |
| 52 | + |
| 53 | + rewrite_prompt: str = DEFAULT_PROMPT |
| 54 | + """The prompt used to generate a new JSON object.""" |
| 55 | + |
| 56 | + plural_object_description: str = DEFAULT_PLURAL_OBJECT_DESCRIPTION |
| 57 | + """Plural description of the object. E.g. 'employees of a company' or 'people' or 'podcast episodes'""" |
| 58 | + |
| 59 | + object_keys: List[str] = DEFAULT_OBJECT_KEYS |
| 60 | + """The keys the JSON should have.""" |
| 61 | + |
| 62 | + example_rows: List[List[str]] = DEFAULT_OBJECT_EXAMPLES |
| 63 | + """List of example object values, aligned to the `object_keys` parameter.""" |
| 64 | + |
| 65 | + new_row_prefix_fields: List[str] = DEFAULT_NEW_ROW_PREFIX_FIELDS |
| 66 | + """Any fields that should be hard-coded for the new row. These must be grouped as the first set of fields.""" |
| 67 | + |
| 68 | + shuffle_example_rows: bool = True |
| 69 | + """Whether randomly shuffle example rows to induce a bit of variety even with low LLM temperature.""" |
| 70 | + |
| 71 | + validate_output_as_json: bool = True |
| 72 | + """Whether to validate that the output is actually JSON.""" |
| 73 | + |
| 74 | + name: str = "JsonObjectTool" |
| 75 | + human_description: str = "Generates a new JSON object." |
| 76 | + agent_description: str = "(set at initialization time)" |
| 77 | + |
| 78 | + def __init__(self, *args, **kwargs): |
| 79 | + super().__init__(*args, **kwargs) |
| 80 | + self.agent_description = ( |
| 81 | + f"Used to generate an instance of the {self.plural_object_description} table. " |
| 82 | + "Input: Anything. " |
| 83 | + f"Output A tab-separated row describing a new instance of the {self.plural_object_description} table." |
| 84 | + ) |
| 85 | + |
| 86 | + def kv_clause(self, key: str, value: str) -> str: |
| 87 | + """Return an escaped, JSON style key-value clause `"key": "value"`""" |
| 88 | + value = str(value).replace('"', '\\"') |
| 89 | + clause = f'"{key}": "{value}"' |
| 90 | + return clause |
| 91 | + |
| 92 | + def object_json(self, schema: List[str], values: List[str]): |
| 93 | + """Render a CSV-style header row and value list into a JSON object.""" |
| 94 | + clauses = [] |
| 95 | + for field, value in zip(schema, values): |
| 96 | + clauses.append(self.kv_clause(field, value)) |
| 97 | + |
| 98 | + return "{" + ", ".join(clauses) + "}" |
| 99 | + |
| 100 | + def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]: |
| 101 | + """Ignore tool input and generate a JSON object described by the tool's configuration. |
| 102 | +
|
| 103 | + Inputs |
| 104 | + ------ |
| 105 | + input: List[Block] |
| 106 | + A list of blocks that will be ignored. |
| 107 | + memory: AgentContext |
| 108 | + The active AgentContext. |
| 109 | +
|
| 110 | + Output |
| 111 | + ------ |
| 112 | + output: List[Blocks] |
| 113 | + A single block containing a new row of the table described by the tool's configuration. |
| 114 | + """ |
| 115 | + |
| 116 | + if self.shuffle_example_rows: |
| 117 | + # Shuffle the example rows to get a bit of variety even with low temperature. |
| 118 | + random.shuffle(self.example_rows) |
| 119 | + |
| 120 | + # Generate example JSON objects with a fixed key ordering. |
| 121 | + example_objects = [ |
| 122 | + self.object_json(self.object_keys, example_row) for example_row in self.example_rows |
| 123 | + ] |
| 124 | + example_objects_str = "\n".join(example_objects) |
| 125 | + |
| 126 | + # Generate the new row line. At a minimum it's the `{` character, but it may also hard-code a number of |
| 127 | + # fields that should be affixed rather than generated. |
| 128 | + new_object_prefix = "{" |
| 129 | + for i in range(len(self.new_row_prefix_fields)): |
| 130 | + clause = self.kv_clause(self.object_keys[i], self.new_row_prefix_fields[i]) |
| 131 | + new_object_prefix += f"{clause}, " |
| 132 | + |
| 133 | + # Compile the final generation prompt. |
| 134 | + prompt = self.rewrite_prompt.format( |
| 135 | + table_description=self.plural_object_description, |
| 136 | + fields_desired=", ".join(self.object_keys), |
| 137 | + example_objects=example_objects_str, |
| 138 | + new_object_prefix=new_object_prefix, |
| 139 | + ) |
| 140 | + |
| 141 | + # Perform the generation |
| 142 | + llm = get_llm(context) |
| 143 | + res = llm.complete(prompt, stop="}") |
| 144 | + |
| 145 | + # Make sure we only generated one block; anything else violates the assumptions of this code. |
| 146 | + blocks_emitted = len(res) |
| 147 | + if blocks_emitted != 1: |
| 148 | + raise SteamshipError(message=f"{blocks_emitted} blocks emitted; expecting 1.") |
| 149 | + |
| 150 | + # The output JSON is generation prefix row, plus the generated content, plus a final } character |
| 151 | + # The reason we have to add the final "}" character is because we used it for the stop character |
| 152 | + full_json = new_object_prefix + res[0].text + "}" |
| 153 | + |
| 154 | + if self.validate_output_as_json: |
| 155 | + try: |
| 156 | + json.loads(full_json) |
| 157 | + except BaseException: |
| 158 | + raise SteamshipError( |
| 159 | + message=f"Attempted to generate a JSON object, but did not generate valid JSON. Result: {full_json}" |
| 160 | + ) |
| 161 | + |
| 162 | + res[0].text = full_json |
| 163 | + return res |
| 164 | + |
| 165 | + |
| 166 | +if __name__ == "__main__": |
| 167 | + with Steamship.temporary_workspace() as client: |
| 168 | + ToolREPL(JsonObjectGeneratorTool()).run_with_client( |
| 169 | + client=client, context=with_llm(llm=OpenAI(client=client)) |
| 170 | + ) |
0 commit comments