diff --git a/outlines/prompts.py b/outlines/prompts.py index a7824451a..40c6a4012 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -1,3 +1,4 @@ +import dis import functools import inspect import json @@ -42,6 +43,26 @@ def __str__(self): return self.template +def _template_variable_assigned(fn: Callable) -> Optional[str]: + instruction_set = list(dis.Bytecode(fn)) + if ( + len(instruction_set) >= 3 + and instruction_set[2].opname == "STORE_FAST" + and instruction_set[2].argrepr == "_" + ): + return instruction_set[1].argval + elif any( + instr + for instr in instruction_set + if instr.opname == "STORE_FAST" and instr.argrepr == "_" + ): + raise ValueError( + "Can only use simple string literals to `_` and it must be the only instruction in the function." + ) + else: + return None + + def prompt(fn: Callable) -> Prompt: """Decorate a function that contains a prompt template. @@ -79,17 +100,24 @@ def prompt(fn: Callable) -> Prompt: """ - signature = inspect.signature(fn) - - # The docstring contains the template that will be rendered to be used - # as a prompt to the language model. - docstring = fn.__doc__ - if docstring is None: - raise TypeError("Could not find a template in the function's docstring.") + # Either the docstring or `_` variable contains the template + # that will be rendered to be used as a prompt to the language model. + potential_template = _template_variable_assigned(fn) + if potential_template is None: + docstring = fn.__doc__ + if docstring is None: + raise TypeError( + "Could not find a template in the function's docstring or assigned to `_`" + ) + else: + potential_template = docstring - template = cast(str, docstring) + template = cast(str, potential_template) + signature = inspect.signature(fn) - return Prompt(template, signature) + prompt_instance = Prompt(template, signature) + functools.update_wrapper(prompt_instance, fn) + return prompt_instance def render(template: str, **values: Optional[Dict[str, Any]]) -> str: diff --git a/tests/test_prompts.py b/tests/test_prompts.py index a0433c0e5..f46d5837e 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -148,6 +148,39 @@ def test_kwarg_tpl(var, other_var="other"): assert p == "test and test" +def test_prompt_template_keyword(): + import pydoc + + @outlines.prompt + def test_tpl_keyword(variable): + _ = """{{variable}} test""" + + assert test_tpl_keyword.template == "{{variable}} test" + assert test_tpl_keyword.parameters == ["variable"] + + @outlines.prompt + def test_tpl_keyword_w_docstring(variable): + """custom docstring""" + _ = """{{variable}} test""" + + assert "custom docstring" in pydoc.render_doc(test_tpl_keyword_w_docstring) + + @outlines.prompt + def test_tpl_keyword_single_quotes(variable): + _ = "{{variable}} test" + + assert test_tpl_keyword_single_quotes.template == "{{variable}} test" + assert test_tpl_keyword_single_quotes.parameters == ["variable"] + + +def test_prompt_bad_template_keyword(): + with pytest.raises(ValueError, match="_"): + + @outlines.prompt + def test_tpl_keyword(variable): + _ = f"{'test'}" + + def test_no_prompt(): with pytest.raises(TypeError, match="template"):