diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py
index 9080d4cbe..0c3ae2ee7 100644
--- a/dspy/clients/base_lm.py
+++ b/dspy/clients/base_lm.py
@@ -3,6 +3,7 @@
 from dspy.dsp.utils import settings
 from dspy.utils.callback import with_callbacks
+from dspy.utils.inspect_history import pretty_print_history

 MAX_HISTORY_SIZE = 10_000
 GLOBAL_HISTORY = []
@@ -81,6 +82,9 @@ def _process_lm_response(self, response, prompt, messages, **kwargs):
         }
         self.history.append(entry)
         self.update_global_history(entry)
+        caller_modules = settings.caller_modules or []
+        for module in caller_modules:
+            module.history.append(entry)

         return outputs

     @with_callbacks
@@ -129,7 +133,7 @@ def copy(self, **kwargs):
         return new_instance

     def inspect_history(self, n: int = 1):
-        _inspect_history(self.history, n)
+        return pretty_print_history(self.history, n)

     def update_global_history(self, entry):
         if settings.disable_history:
@@ -141,61 +145,6 @@ def update_global_history(self, entry):
         GLOBAL_HISTORY.append(entry)


-def _green(text: str, end: str = "\n"):
-    return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
-
-
-def _red(text: str, end: str = "\n"):
-    return "\x1b[31m" + str(text) + "\x1b[0m" + end
-
-
-def _blue(text: str, end: str = "\n"):
-    return "\x1b[34m" + str(text) + "\x1b[0m" + end
-
-
-def _inspect_history(history, n: int = 1):
-    """Prints the last n prompts and their completions."""
-
-    for item in history[-n:]:
-        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
-        outputs = item["outputs"]
-        timestamp = item.get("timestamp", "Unknown time")
-
-        print("\n\n\n")
-        print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
-
-        for msg in messages:
-            print(_red(f"{msg['role'].capitalize()} message:"))
-            if isinstance(msg["content"], str):
-                print(msg["content"].strip())
-            else:
-                if isinstance(msg["content"], list):
-                    for c in msg["content"]:
-                        if c["type"] == "text":
-                            print(c["text"].strip())
-                        elif c["type"] == "image_url":
-                            image_str = ""
-                            if "base64" in c["image_url"].get("url", ""):
-                                len_base64 = len(c["image_url"]["url"].split("base64,")[1])
-                                image_str = (
-                                    f"<{c['image_url']['url'].split('base64,')[0]}base64,"
-                                    f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
-                                )
-                            else:
-                                image_str = f"<image_url: {c['image_url']['url']}>"
-                            print(_blue(image_str.strip()))
-            print("\n")
-
-        print(_red("Response:"))
-        print(_green(outputs[0].strip()))
-
-        if len(outputs) > 1:
-            choices_text = f" \t (and {len(outputs) - 1} other completions)"
-            print(_red(choices_text, end=""))
-
-        print("\n\n\n")
-
-
 def inspect_history(n: int = 1):
     """The global history shared across all LMs."""
-    return _inspect_history(GLOBAL_HISTORY, n)
+    return pretty_print_history(GLOBAL_HISTORY, n)
diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py
index 983927f33..264ac62f3 100644
--- a/dspy/dsp/utils/settings.py
+++ b/dspy/dsp/utils/settings.py
@@ -23,6 +23,7 @@
     track_usage=False,
     usage_tracker=None,
     caller_predict=None,
+    caller_modules=None,
     stream_listeners=[],
     provide_traceback=False,  # Whether to include traceback information in error logs.
     num_threads=8,  # Number of threads to use for parallel processing.
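Note on the base_lm.py hunk above: a single LM call now produces one entry dict that is appended, by reference, to the history of every module on the current call stack (settings.caller_modules). A minimal standalone sketch of the fan-out idea — the names below are hypothetical stand-ins, not dspy internals:

    # Sketch only: `active_modules` plays the role of settings.caller_modules,
    # the stack of modules participating in the current call.
    class FakeModule:
        def __init__(self):
            self.history = []

    def record(entry: dict, active_modules: list):
        # One shared entry object per LM call keeps memory flat
        # no matter how deeply modules are nested.
        for module in active_modules:
            module.history.append(entry)

    outer, inner = FakeModule(), FakeModule()
    record({"prompt": "hi", "outputs": ["hello"]}, [outer, inner])
    assert outer.history[0] is inner.history[0]  # same object, not a copy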
diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py
index 9f78ca804..4dd463619 100644
--- a/dspy/predict/predict.py
+++ b/dspy/predict/predict.py
@@ -17,10 +17,10 @@ class Predict(Module, Parameter):
     def __init__(self, signature, callbacks=None, **config):
+        super().__init__(callbacks=callbacks)
         self.stage = random.randbytes(8).hex()
         self.signature = ensure_signature(signature)
         self.config = config
-        self.callbacks = callbacks or []
         self.reset()

     def reset(self):
diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py
index b02c8e8da..b72b260ad 100644
--- a/dspy/primitives/program.py
+++ b/dspy/primitives/program.py
@@ -6,6 +6,7 @@
 from dspy.predict.parallel import Parallel
 from dspy.primitives.module import BaseModule
 from dspy.utils.callback import with_callbacks
+from dspy.utils.inspect_history import pretty_print_history
 from dspy.utils.usage_tracker import track_usage


@@ -20,26 +21,38 @@ def _base_init(self):
     def __init__(self, callbacks=None):
         self.callbacks = callbacks or []
         self._compiled = False
+        # LM calling history of the module.
+        self.history = []

     @with_callbacks
     def __call__(self, *args, **kwargs):
-        if settings.track_usage and settings.usage_tracker is None:
-            with track_usage() as usage_tracker:
-                output = self.forward(*args, **kwargs)
-            output.set_lm_usage(usage_tracker.get_total_tokens())
-            return output
+        caller_modules = settings.caller_modules or []
+        caller_modules = list(caller_modules)
+        caller_modules.append(self)

-        return self.forward(*args, **kwargs)
+        with settings.context(caller_modules=caller_modules):
+            if settings.track_usage and settings.usage_tracker is None:
+                with track_usage() as usage_tracker:
+                    output = self.forward(*args, **kwargs)
+                output.set_lm_usage(usage_tracker.get_total_tokens())
+                return output
+
+            return self.forward(*args, **kwargs)

     @with_callbacks
     async def acall(self, *args, **kwargs):
-        if settings.track_usage and settings.usage_tracker is None:
-            with track_usage() as usage_tracker:
-                output = await self.aforward(*args, **kwargs)
-                output.set_lm_usage(usage_tracker.get_total_tokens())
-                return output
+        caller_modules = settings.caller_modules or []
+        caller_modules = list(caller_modules)
+        caller_modules.append(self)

-        return await self.aforward(*args, **kwargs)
+        with settings.context(caller_modules=caller_modules):
+            if settings.track_usage and settings.usage_tracker is None:
+                with track_usage() as usage_tracker:
+                    output = await self.aforward(*args, **kwargs)
+                output.set_lm_usage(usage_tracker.get_total_tokens())
+                return output
+
+            return await self.aforward(*args, **kwargs)

     def named_predictors(self):
         from dspy.predict.predict import Predict
@@ -75,6 +88,9 @@ def map_named_predictors(self, func):
             set_attribute_by_name(self, name, func(predictor))
         return self

+    def inspect_history(self, n: int = 1):
+        return pretty_print_history(self.history, n)
+
     # def activate_assertions(self, handler=backtrack_handler, **handler_args):
     #     """
     #     Activates assertions for the module.
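With Module.__call__ and Module.acall pushing the module onto caller_modules before delegating to forward, every program and sub-module accumulates its own history, and the new Module.inspect_history prints it. A usage sketch — DummyLM keeps it runnable offline; any configured LM behaves the same way:

    import dspy
    from dspy.utils.dummies import DummyLM

    # Canned response matching the ChainOfThought output fields.
    dspy.settings.configure(lm=DummyLM([{"reasoning": "geography", "answer": "Paris"}]))

    program = dspy.ChainOfThought("question -> answer")
    program(question="What is the capital of France?")

    program.inspect_history(n=1)  # history of this module and everything beneath it
    dspy.inspect_history(n=1)     # global history shared across all LMs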
diff --git a/dspy/utils/__init__.py b/dspy/utils/__init__.py
index 84178e491..9c431d16b 100644
--- a/dspy/utils/__init__.py
+++ b/dspy/utils/__init__.py
@@ -1,6 +1,7 @@
 from dspy.utils.callback import BaseCallback, with_callbacks
 from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
 from dspy.streaming.messages import StatusMessageProvider, StatusMessage
+from dspy.utils.inspect_history import pretty_print_history

 import os
 import requests
@@ -27,4 +28,5 @@ def download(url):
     "dummy_rm",
     "StatusMessage",
     "StatusMessageProvider",
+    "pretty_print_history",
 ]
diff --git a/dspy/utils/inspect_history.py b/dspy/utils/inspect_history.py
new file mode 100644
index 000000000..16baefd30
--- /dev/null
+++ b/dspy/utils/inspect_history.py
@@ -0,0 +1,53 @@
+def _green(text: str, end: str = "\n"):
+    return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
+
+
+def _red(text: str, end: str = "\n"):
+    return "\x1b[31m" + str(text) + "\x1b[0m" + end
+
+
+def _blue(text: str, end: str = "\n"):
+    return "\x1b[34m" + str(text) + "\x1b[0m" + end
+
+
+def pretty_print_history(history, n: int = 1):
+    """Prints the last n prompts and their completions."""
+
+    for item in history[-n:]:
+        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
+        outputs = item["outputs"]
+        timestamp = item.get("timestamp", "Unknown time")
+
+        print("\n\n\n")
+        print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")
+
+        for msg in messages:
+            print(_red(f"{msg['role'].capitalize()} message:"))
+            if isinstance(msg["content"], str):
+                print(msg["content"].strip())
+            else:
+                if isinstance(msg["content"], list):
+                    for c in msg["content"]:
+                        if c["type"] == "text":
+                            print(c["text"].strip())
+                        elif c["type"] == "image_url":
+                            image_str = ""
+                            if "base64" in c["image_url"].get("url", ""):
+                                len_base64 = len(c["image_url"]["url"].split("base64,")[1])
+                                image_str = (
+                                    f"<{c['image_url']['url'].split('base64,')[0]}base64,"
+                                    f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
+                                )
+                            else:
+                                image_str = f"<image_url: {c['image_url']['url']}>"
+                            print(_blue(image_str.strip()))
+            print("\n")
+
+        print(_red("Response:"))
+        print(_green(outputs[0].strip()))
+
+        if len(outputs) > 1:
+            choices_text = f" \t (and {len(outputs) - 1} other completions)"
+            print(_red(choices_text, end=""))
+
+        print("\n\n\n")
diff --git a/tests/predict/test_parallel.py b/tests/predict/test_parallel.py
index da99e5757..06ed1234f 100644
--- a/tests/predict/test_parallel.py
+++ b/tests/predict/test_parallel.py
@@ -4,13 +4,15 @@


 def test_parallel_module():
-    lm = DummyLM([
-        {"output": "test output 1"},
-        {"output": "test output 2"},
-        {"output": "test output 3"},
-        {"output": "test output 4"},
-        {"output": "test output 5"},
-    ])
+    lm = DummyLM(
+        [
+            {"output": "test output 1"},
+            {"output": "test output 2"},
+            {"output": "test output 3"},
+            {"output": "test output 4"},
+            {"output": "test output 5"},
+        ]
+    )
     dspy.settings.configure(lm=lm)

     class MyModule(dspy.Module):
@@ -22,13 +24,15 @@ def __init__(self):
             self.parallel = dspy.Parallel(num_threads=2)

         def forward(self, input):
-            return self.parallel([
-                (self.predictor, input),
-                (self.predictor2, input),
-                (self.predictor, input),
-                (self.predictor2, input),
-                (self.predictor, input),
-            ])
+            return self.parallel(
+                [
+                    (self.predictor, input),
+                    (self.predictor2, input),
+                    (self.predictor, input),
+                    (self.predictor2, input),
+                    (self.predictor, input),
+                ]
+            )

     output = MyModule()(dspy.Example(input="test input").with_inputs("input"))

@@ -37,20 +41,24 @@ def forward(self, input):


 def test_batch_module():
-    lm = DummyLM([
-        {"output": "test output 1"},
-        {"output": "test output 2"},
-        {"output": "test output 3"},
-        {"output": "test output 4"},
-        {"output": "test output 5"},
-    ])
-    res_lm = DummyLM([
-        {"output": "test output 1", "reasoning": "test reasoning 1"},
-        {"output": "test output 2", "reasoning": "test reasoning 2"},
-        {"output": "test output 3", "reasoning": "test reasoning 3"},
-        {"output": "test output 4", "reasoning": "test reasoning 4"},
-        {"output": "test output 5", "reasoning": "test reasoning 5"},
-    ])
+    lm = DummyLM(
+        [
+            {"output": "test output 1"},
+            {"output": "test output 2"},
+            {"output": "test output 3"},
+            {"output": "test output 4"},
+            {"output": "test output 5"},
+        ]
+    )
+    res_lm = DummyLM(
+        [
+            {"output": "test output 1", "reasoning": "test reasoning 1"},
+            {"output": "test output 2", "reasoning": "test reasoning 2"},
+            {"output": "test output 3", "reasoning": "test reasoning 3"},
+            {"output": "test output 4", "reasoning": "test reasoning 4"},
+            {"output": "test output 5", "reasoning": "test reasoning 5"},
+        ]
+    )

     class MyModule(dspy.Module):
         def __init__(self):
@@ -61,11 +69,11 @@ def __init__(self):
             self.parallel = dspy.Parallel(num_threads=2)

         def forward(self, input):
-            dspy.settings.configure(lm=lm)
-            res1 = self.predictor.batch([input] * 5)
+            with dspy.context(lm=lm):
+                res1 = self.predictor.batch([input] * 5)

-            dspy.settings.configure(lm=res_lm)
-            res2 = self.predictor2.batch([input] * 5)
+            with dspy.context(lm=res_lm):
+                res2 = self.predictor2.batch([input] * 5)

             return (res1, res2)

@@ -83,13 +91,15 @@ def forward(self, input):


 def test_nested_parallel_module():
-    lm = DummyLM([
-        {"output": "test output 1"},
-        {"output": "test output 2"},
-        {"output": "test output 3"},
-        {"output": "test output 4"},
-        {"output": "test output 5"},
-    ])
+    lm = DummyLM(
+        [
+            {"output": "test output 1"},
+            {"output": "test output 2"},
+            {"output": "test output 3"},
+            {"output": "test output 4"},
+            {"output": "test output 5"},
+        ]
+    )
     dspy.settings.configure(lm=lm)

     class MyModule(dspy.Module):
@@ -101,14 +111,19 @@ def __init__(self):
             self.parallel = dspy.Parallel(num_threads=2)

         def forward(self, input):
-            return self.parallel([
-                (self.predictor, input),
-                (self.predictor2, input),
-                (self.parallel, [
-                    (self.predictor2, input),
-                    (self.predictor, input),
-                ]),
-            ])
+            return self.parallel(
+                [
+                    (self.predictor, input),
+                    (self.predictor2, input),
+                    (
+                        self.parallel,
+                        [
+                            (self.predictor2, input),
+                            (self.predictor, input),
+                        ],
+                    ),
+                ]
+            )

     output = MyModule()(dspy.Example(input="test input").with_inputs("input"))

@@ -120,13 +135,15 @@ def forward(self, input):


 def test_nested_batch_method():
-    lm = DummyLM([
-        {"output": "test output 1"},
-        {"output": "test output 2"},
-        {"output": "test output 3"},
-        {"output": "test output 4"},
-        {"output": "test output 5"},
-    ])
+    lm = DummyLM(
+        [
+            {"output": "test output 1"},
+            {"output": "test output 2"},
+            {"output": "test output 3"},
+            {"output": "test output 4"},
+            {"output": "test output 5"},
+        ]
+    )
     dspy.settings.configure(lm=lm)

     class MyModule(dspy.Module):
@@ -135,11 +152,15 @@ def __init__(self):
             self.predictor = dspy.Predict("input -> output")

         def forward(self, input):
-            res = self.predictor.batch([dspy.Example(input=input).with_inputs("input")]*2)
+            res = self.predictor.batch([dspy.Example(input=input).with_inputs("input")] * 2)

             return res

-    result = MyModule().batch([dspy.Example(input="test input").with_inputs("input")]*2)
+    result = MyModule().batch([dspy.Example(input="test input").with_inputs("input")] * 2)

-    assert {result[0][0].output, result[0][1].output, result[1][0].output, result[1][1].output} \
-        == {"test output 1", "test output 2", "test output 3", "test output 4"}
+    assert {result[0][0].output, result[0][1].output, result[1][0].output, result[1][1].output} == {
+        "test output 1",
+        "test output 2",
+        "test output 3",
+        "test output 4",
+    }
{"test output 1", "test output 2", "test output 3", "test output 4"} + assert {result[0][0].output, result[0][1].output, result[1][0].output, result[1][1].output} == { + "test output 1", + "test output 2", + "test output 3", + "test output 4", + } diff --git a/tests/primitives/test_module.py b/tests/primitives/test_module.py index 886213847..d6595503f 100644 --- a/tests/primitives/test_module.py +++ b/tests/primitives/test_module.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest import os +from litellm import ModelResponse, Message, Choices def test_deepcopy_basic(): @@ -243,3 +244,134 @@ def __call__(self, question: str) -> str: assert results[0].get_lm_usage().keys() == set(["openai/gpt-4o-mini"]) assert results[1].get_lm_usage().keys() == set(["openai/gpt-3.5-turbo"]) + + +def test_module_history(): + class MyProgram(dspy.Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.cot = dspy.ChainOfThought("question -> answer") + + def forward(self, question: str, **kwargs) -> str: + return self.cot(question=question) + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[ + Choices(message=Message(content="{'reasoning': 'Paris is the captial of France', 'answer': 'Paris'}")) + ], + model="openai/gpt-4o-mini", + ) + dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()) + program = MyProgram() + program(question="What is the capital of France?") + + # Second call only call the submodule. + program.cot(question="What is the capital of France?") + + # The LM history entity exists in all the ancestor callers. + assert len(program.history) == 1 + assert len(program.cot.history) == 2 + assert len(program.cot.predict.history) == 2 + + # The same history entity is shared across all the ancestor callers to reduce memory usage. + assert id(program.history[0]) == id(program.cot.history[0]) + + assert program.history[0]["outputs"] == ["{'reasoning': 'Paris is the captial of France', 'answer': 'Paris'}"] + + dspy.settings.configure(disable_history=True) + + program(question="What is the capital of France?") + # No history is recorded when history is disabled. + assert len(program.history) == 1 + assert len(program.cot.history) == 2 + assert len(program.cot.predict.history) == 2 + + dspy.settings.configure(disable_history=False) + + program(question="What is the capital of France?") + # History is recorded again when history is enabled. 
+ assert len(program.history) == 2 + assert len(program.cot.history) == 3 + assert len(program.cot.predict.history) == 3 + + +def test_module_history_with_concurrency(): + class MyProgram(dspy.Module): + def __init__(self): + super().__init__() + self.cot = dspy.ChainOfThought("question -> answer") + + def forward(self, question: str, **kwargs) -> str: + return self.cot(question=question) + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[Choices(message=Message(content="{'reasoning': 'N/A', 'answer': 'Holy crab!'}"))], + model="openai/gpt-4o-mini", + ) + dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()) + program = MyProgram() + + parallelizer = dspy.Parallel() + + parallelizer( + [ + (program, {"question": "What is the meaning of life?"}), + (program, {"question": "why did a chicken cross the kitchen?"}), + ] + ) + assert len(program.history) == 2 + assert len(program.cot.history) == 2 + assert len(program.cot.predict.history) == 2 + + +@pytest.mark.asyncio +async def test_module_history_async(): + class MyProgram(dspy.Module): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.cot = dspy.ChainOfThought("question -> answer") + + async def aforward(self, question: str, **kwargs) -> str: + return await self.cot.acall(question=question) + + with patch("litellm.acompletion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[ + Choices(message=Message(content="{'reasoning': 'Paris is the captial of France', 'answer': 'Paris'}")) + ], + model="openai/gpt-4o-mini", + ) + dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter()) + program = MyProgram() + await program.acall(question="What is the capital of France?") + + # Second call only call the submodule. + await program.cot.acall(question="What is the capital of France?") + + # The LM history entity exists in all the ancestor callers. + assert len(program.history) == 1 + assert len(program.cot.history) == 2 + assert len(program.cot.predict.history) == 2 + + # The same history entity is shared across all the ancestor callers to reduce memory usage. + assert id(program.history[0]) == id(program.cot.history[0]) + + assert program.history[0]["outputs"] == ["{'reasoning': 'Paris is the captial of France', 'answer': 'Paris'}"] + + dspy.settings.configure(disable_history=True) + + await program.acall(question="What is the capital of France?") + # No history is recorded when history is disabled. + assert len(program.history) == 1 + assert len(program.cot.history) == 2 + assert len(program.cot.predict.history) == 2 + + dspy.settings.configure(disable_history=False) + + await program.acall(question="What is the capital of France?") + # History is recorded again when history is enabled. + assert len(program.history) == 2 + assert len(program.cot.history) == 3 + assert len(program.cot.predict.history) == 3
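For reference, pretty_print_history only reads the entry fields these tests exercise (messages, outputs, and an optional timestamp), so it can also be driven by hand. A minimal sketch with a hand-built entry; the real entries appended by _process_lm_response carry additional metadata:

    from dspy.utils.inspect_history import pretty_print_history

    # Hand-built entry with just the fields pretty_print_history reads.
    entry = {
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "prompt": None,  # used only when "messages" is empty
        "outputs": ["4"],
        "timestamp": "2025-01-01T00:00:00",
    }
    pretty_print_history([entry], n=1)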