Per module lm history #8199

Merged · 6 commits · May 22, 2025
68 changes: 6 additions & 62 deletions dspy/clients/base_lm.py
@@ -3,6 +3,7 @@

from dspy.dsp.utils import settings
from dspy.utils.callback import with_callbacks
from dspy.utils.inspect_history import pretty_print_history

MAX_HISTORY_SIZE = 10_000
GLOBAL_HISTORY = []
@@ -81,6 +82,9 @@ def _process_lm_response(self, response, prompt, messages, **kwargs):
}
self.history.append(entry)
self.update_global_history(entry)
caller_modules = settings.caller_modules or []
for module in caller_modules:
module.history.append(entry)
return outputs

@with_callbacks
@@ -129,7 +133,7 @@ def copy(self, **kwargs):
return new_instance

def inspect_history(self, n: int = 1):
_inspect_history(self.history, n)
return inspect_history(self.history, n)

def update_global_history(self, entry):
if settings.disable_history:
@@ -141,66 +145,6 @@ def update_global_history(self, entry):
GLOBAL_HISTORY.append(entry)


def _green(text: str, end: str = "\n"):
return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end


def _red(text: str, end: str = "\n"):
return "\x1b[31m" + str(text) + "\x1b[0m" + end


def _blue(text: str, end: str = "\n"):
return "\x1b[34m" + str(text) + "\x1b[0m" + end


def _inspect_history(history, n: int = 1):
"""Prints the last n prompts and their completions."""

for item in history[-n:]:
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
outputs = item["outputs"]
timestamp = item.get("timestamp", "Unknown time")

print("\n\n\n")
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

for msg in messages:
print(_red(f"{msg['role'].capitalize()} message:"))
if isinstance(msg["content"], str):
print(msg["content"].strip())
else:
if isinstance(msg["content"], list):
for c in msg["content"]:
if c["type"] == "text":
print(c["text"].strip())
elif c["type"] == "image_url":
image_str = ""
if "base64" in c["image_url"].get("url", ""):
len_base64 = len(c["image_url"]["url"].split("base64,")[1])
image_str = (
f"<{c['image_url']['url'].split('base64,')[0]}base64,"
f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
)
else:
image_str = f"<image_url: {c['image_url']['url']}>"
print(_blue(image_str.strip()))
elif c["type"] == "input_audio":
audio_format = c["input_audio"]["format"]
len_audio = len(c["input_audio"]["data"])
audio_str = f"<audio format='{audio_format}' base64-encoded, length={len_audio}>"
print(_blue(audio_str.strip()))
print("\n")

print(_red("Response:"))
print(_green(outputs[0].strip()))

if len(outputs) > 1:
choices_text = f" \t (and {len(outputs) - 1} other completions)"
print(_red(choices_text, end=""))

print("\n\n\n")


def inspect_history(n: int = 1):
"""The global history shared across all LMs."""
return _inspect_history(GLOBAL_HISTORY, n)
return pretty_print_history(GLOBAL_HISTORY, n)
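
For context, a hedged sketch of how the two entry points behave after this change: LM.inspect_history prints that LM's own calls, while the module-level dspy.inspect_history prints the global history via pretty_print_history. The model name and question are illustrative, and a configured provider key is assumed.

import dspy

lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm)

cot = dspy.ChainOfThought("question -> answer")
cot(question="What is 2 + 2?")

lm.inspect_history(n=1)    # this LM's own calls
dspy.inspect_history(n=1)  # the global history, printed via pretty_print_history

# With this PR, the calling module records the entry as well:
# len(cot.history) == 1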
1 change: 1 addition & 0 deletions dspy/dsp/utils/settings.py
@@ -22,6 +22,7 @@
track_usage=False,
usage_tracker=None,
caller_predict=None,
caller_modules=None,
stream_listeners=[],
provide_traceback=False, # Whether to include traceback information in error logs.
num_threads=8, # Number of threads to use for parallel processing.
2 changes: 1 addition & 1 deletion dspy/predict/predict.py
@@ -17,10 +17,10 @@

class Predict(Module, Parameter):
def __init__(self, signature, callbacks=None, **config):
super().__init__(callbacks=callbacks)
okhat (Collaborator), May 22, 2025:

Btw typical custom DSPy programs don't call super().__init__(...).

Will that be an issue?

Collaborator (author) replied:

Nothing will crash, but such programs will lose some of the perks of dspy.Module. We can follow PyTorch's lead and recommend calling super().__init__() when customizing modules: https://docs.pytorch.org/tutorials/beginner/examples_nn/polynomial_module.html#pytorch-custom-nn-modules
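
For example, a minimal custom module following that recommendation (the signature string is illustrative):

import dspy

class MyModule(dspy.Module):
    def __init__(self):
        super().__init__()  # sets up callbacks, _compiled, and the new per-module history
        self.predict = dspy.Predict("question -> answer")

    def forward(self, question):
        return self.predict(question=question)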

self.stage = random.randbytes(8).hex()
self.signature = ensure_signature(signature)
self.config = config
self.callbacks = callbacks or []
self.reset()

def reset(self):
35 changes: 25 additions & 10 deletions dspy/primitives/program.py
@@ -6,6 +6,7 @@
from dspy.predict.parallel import Parallel
from dspy.primitives.module import BaseModule
from dspy.utils.callback import with_callbacks
from dspy.utils.inspect_history import pretty_print_history
from dspy.utils.usage_tracker import track_usage


@@ -20,26 +21,38 @@ def _base_init(self):
def __init__(self, callbacks=None):
self.callbacks = callbacks or []
self._compiled = False
# LM calling history of the module.
self.history = []
Collaborator:

Cool! Does this work if there are multiple LMs used in the module for different sub-modules?

Collaborator (author) replied:

Yes! Per-module history works the same in a multi-LM situation.

Just tested with the following code, with no issues:

import dspy


class MyProgram(dspy.Module):
    def __init__(self):
        super().__init__()
        self.cot = dspy.ChainOfThought("question -> answer")
        self.cot2 = dspy.ChainOfThought("question, answer -> judgement")
        self.cot2.predict.lm = dspy.LM("openai/gpt-4o", cache=False)

    def forward(self, question: str, **kwargs) -> str:
        answer = self.cot(question=question).answer
        return self.cot2(question=question, answer=answer)


dspy.settings.configure(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.JSONAdapter())
program = MyProgram()
program(question="What is the capital of France?")

print(program.history)


@with_callbacks
def __call__(self, *args, **kwargs):
if settings.track_usage and settings.usage_tracker is None:
with track_usage() as usage_tracker:
output = self.forward(*args, **kwargs)
caller_modules = settings.caller_modules or []
caller_modules = list(caller_modules)
caller_modules.append(self)

with settings.context(caller_modules=caller_modules):
if settings.track_usage and settings.usage_tracker is None:
with track_usage() as usage_tracker:
output = self.forward(*args, **kwargs)
output.set_lm_usage(usage_tracker.get_total_tokens())
return output

return self.forward(*args, **kwargs)
return self.forward(*args, **kwargs)

@with_callbacks
async def acall(self, *args, **kwargs):
if settings.track_usage and settings.usage_tracker is None:
with track_usage() as usage_tracker:
output = await self.aforward(*args, **kwargs)
output.set_lm_usage(usage_tracker.get_total_tokens())
return output
caller_modules = settings.caller_modules or []
caller_modules = list(caller_modules)
caller_modules.append(self)

with settings.context(caller_modules=caller_modules):
if settings.track_usage and settings.usage_tracker is None:
with track_usage() as usage_tracker:
output = await self.aforward(*args, **kwargs)
output.set_lm_usage(usage_tracker.get_total_tokens())
return output

return await self.aforward(*args, **kwargs)
return await self.aforward(*args, **kwargs)

def named_predictors(self):
from dspy.predict.predict import Predict
@@ -75,6 +88,8 @@ def map_named_predictors(self, func):
set_attribute_by_name(self, name, func(predictor))
return self

def inspect_history(self, n: int = 1):
return pretty_print_history(self.history, n)

def batch(
self,
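To make the control flow concrete: each __call__ copies settings.caller_modules, appends self, and re-enters settings.context, so by the time BaseLM._process_lm_response runs, the caller stack holds every enclosing module and the history entry is appended to each of them. A hedged sketch of the resulting behavior (class names and the question are illustrative; assumes an LM is configured globally):

import dspy

class Inner(dspy.Module):
    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict("question -> answer")

    def forward(self, question):
        return self.predict(question=question)

class Outer(dspy.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, question):
        return self.inner(question=question)

outer = Outer()
outer(question="Why is the sky blue?")

# The single LM call is recorded at every level of the caller stack:
# len(outer.history) == 1 and len(outer.inner.history) == 1
outer.inner.inspect_history(n=1)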
2 changes: 2 additions & 0 deletions dspy/utils/__init__.py
@@ -6,6 +6,7 @@
from dspy.utils import exceptions
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
from dspy.utils.inspect_history import pretty_print_history


def download(url):
@@ -30,4 +31,5 @@ def download(url):
"dummy_rm",
"StatusMessage",
"StatusMessageProvider",
"pretty_print_history",
]
53 changes: 53 additions & 0 deletions dspy/utils/inspect_history.py
@@ -0,0 +1,53 @@
def _green(text: str, end: str = "\n"):
return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end


def _red(text: str, end: str = "\n"):
return "\x1b[31m" + str(text) + "\x1b[0m" + end


def _blue(text: str, end: str = "\n"):
return "\x1b[34m" + str(text) + "\x1b[0m" + end


def pretty_print_history(history, n: int = 1):
"""Prints the last n prompts and their completions."""

for item in history[-n:]:
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
outputs = item["outputs"]
timestamp = item.get("timestamp", "Unknown time")

print("\n\n\n")
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

for msg in messages:
print(_red(f"{msg['role'].capitalize()} message:"))
if isinstance(msg["content"], str):
print(msg["content"].strip())
else:
if isinstance(msg["content"], list):
for c in msg["content"]:
if c["type"] == "text":
print(c["text"].strip())
elif c["type"] == "image_url":
image_str = ""
if "base64" in c["image_url"].get("url", ""):
len_base64 = len(c["image_url"]["url"].split("base64,")[1])
image_str = (
f"<{c['image_url']['url'].split('base64,')[0]}base64,"
f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
)
else:
image_str = f"<image_url: {c['image_url']['url']}>"
print(_blue(image_str.strip()))
print("\n")

print(_red("Response:"))
print(_green(outputs[0].strip()))

if len(outputs) > 1:
choices_text = f" \t (and {len(outputs) - 1} other completions)"
print(_red(choices_text, end=""))

print("\n\n\n")
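
Since pretty_print_history only reads the "messages" (or "prompt" fallback), "outputs", and optional "timestamp" keys of each entry, it can be exercised standalone. A hedged sketch with a hand-built entry (the contents are illustrative):

from dspy.utils.inspect_history import pretty_print_history

entry = {
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "outputs": ["Paris."],
    "timestamp": "2025-05-22 00:00:00",
}
pretty_print_history([entry], n=1)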
8 changes: 4 additions & 4 deletions tests/predict/test_parallel.py
@@ -68,11 +68,11 @@ def __init__(self):
self.parallel = dspy.Parallel(num_threads=2)

def forward(self, input):
dspy.settings.configure(lm=lm)
res1 = self.predictor.batch([input] * 5)
with dspy.context(lm=lm):
res1 = self.predictor.batch([input] * 5)

dspy.settings.configure(lm=res_lm)
res2 = self.predictor2.batch([input] * 5)
with dspy.context(lm=res_lm):
res2 = self.predictor2.batch([input] * 5)

return (res1, res2)
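
The switch from dspy.settings.configure to dspy.context matters here: configure mutates process-wide settings mid-forward, while dspy.context scopes the LM override to the with-block, which is the safer pattern around the worker threads dspy.Parallel spawns. A minimal sketch of the pattern (model names are illustrative):

import dspy

base_lm = dspy.LM("openai/gpt-4o-mini")
other_lm = dspy.LM("openai/gpt-4o")
dspy.settings.configure(lm=base_lm)

predict = dspy.Predict("question -> answer")

with dspy.context(lm=other_lm):
    predict(question="...")  # served by other_lm inside the block

predict(question="...")  # back on base_lm afterwards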
