Skip to content

Per module lm history #8199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 6 additions & 57 deletions dspy/clients/base_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dspy.dsp.utils import settings
from dspy.utils.callback import with_callbacks
from dspy.utils.inspect_history import pretty_print_history

MAX_HISTORY_SIZE = 10_000
GLOBAL_HISTORY = []
Expand Down Expand Up @@ -81,6 +82,9 @@ def _process_lm_response(self, response, prompt, messages, **kwargs):
}
self.history.append(entry)
self.update_global_history(entry)
caller_modules = settings.caller_modules or []
for module in caller_modules:
module.history.append(entry)
return outputs

@with_callbacks
Expand Down Expand Up @@ -129,7 +133,7 @@ def copy(self, **kwargs):
return new_instance

def inspect_history(self, n: int = 1):
    """Print the last ``n`` entries of this LM's own call history.

    Args:
        n: Number of most recent history entries to print.
    """
    # Delegate to the shared printer. The module-level ``inspect_history`` in
    # this file only accepts ``n``, so calling it with a history list would
    # raise a TypeError.
    return pretty_print_history(self.history, n)

def update_global_history(self, entry):
if settings.disable_history:
Expand All @@ -141,61 +145,6 @@ def update_global_history(self, entry):
GLOBAL_HISTORY.append(entry)


def _green(text: str, end: str = "\n"):
    """Return ``text`` (leading whitespace removed) in ANSI green, followed by ``end``."""
    body = str(text).lstrip()
    return "".join(("\x1b[32m", body, "\x1b[0m", end))


def _red(text: str, end: str = "\n"):
    """Return ``text`` wrapped in ANSI red escape codes, followed by ``end``."""
    pieces = ("\x1b[31m", str(text), "\x1b[0m", end)
    return "".join(pieces)


def _blue(text: str, end: str = "\n"):
    """Return ``text`` wrapped in ANSI blue escape codes, followed by ``end``."""
    pieces = ("\x1b[34m", str(text), "\x1b[0m", end)
    return "".join(pieces)


def _inspect_history(history, n: int = 1):
    """Print the last ``n`` prompts and their completions.

    Args:
        history: Sequence of history entries. Each entry is a mapping with
            ``"messages"`` and ``"outputs"`` keys, an optional ``"timestamp"``,
            and a ``"prompt"`` that is used only when ``"messages"`` is empty.
        n: Number of most recent entries to print. Zero or a negative value
            prints no entries.
    """
    # ``history[-0:]`` is the whole list, so a non-positive ``n`` has to be
    # handled explicitly rather than passed straight into the slice.
    recent = history[-n:] if n > 0 else []

    for item in recent:
        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
        outputs = item["outputs"]
        timestamp = item.get("timestamp", "Unknown time")

        print("\n\n\n")
        print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

        for msg in messages:
            print(_red(f"{msg['role'].capitalize()} message:"))
            content = msg["content"]
            if isinstance(content, str):
                print(content.strip())
            elif isinstance(content, list):
                for c in content:
                    if c["type"] == "text":
                        print(c["text"].strip())
                    elif c["type"] == "image_url":
                        url = c["image_url"].get("url", "")
                        # Test for the exact separator that is split on, so a
                        # URL merely containing "base64" cannot cause an
                        # IndexError.
                        parts = url.split("base64,")
                        if len(parts) > 1:
                            image_str = (
                                f"<{parts[0]}base64,"
                                f"<IMAGE BASE 64 ENCODED({len(parts[1])!s})>"
                            )
                        else:
                            image_str = f"<image_url: {url}>"
                        print(_blue(image_str.strip()))
            print("\n")

        print(_red("Response:"))
        if outputs:
            first = outputs[0]
            # Non-string completions are rendered with ``str`` instead of
            # failing on ``.strip()``.
            first_text = first if isinstance(first, str) else str(first)
            print(_green(first_text.strip()))

            if len(outputs) > 1:
                choices_text = f" \t (and {len(outputs) - 1} other completions)"
                print(_red(choices_text, end=""))

    print("\n\n\n")


def inspect_history(n: int = 1):
    """Print the last ``n`` entries of the global history shared across all LMs.

    Args:
        n: Number of most recent entries of ``GLOBAL_HISTORY`` to print.
    """
    # Single return through the shared printer; a second return after it would
    # be unreachable.
    return pretty_print_history(GLOBAL_HISTORY, n)
1 change: 1 addition & 0 deletions dspy/dsp/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
track_usage=False,
usage_tracker=None,
caller_predict=None,
caller_modules=None,
stream_listeners=[],
provide_traceback=False, # Whether to include traceback information in error logs.
num_threads=8, # Number of threads to use for parallel processing.
Expand Down
2 changes: 1 addition & 1 deletion dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

class Predict(Module, Parameter):
def __init__(self, signature, callbacks=None, **config):
    """Set up the predictor's signature and config, then reset its state.

    Args:
        signature: Passed through ``ensure_signature`` and stored as
            ``self.signature``.
        callbacks: Forwarded to the base-class initializer.
        **config: Extra keyword arguments, stored unchanged as ``self.config``.
    """
    # The base initializer is given ``callbacks``; repeating the
    # ``self.callbacks = callbacks or []`` assignment here would be redundant
    # (assumes the base class is the Module whose __init__ stores it — confirm).
    super().__init__(callbacks=callbacks)
    self.stage = random.randbytes(8).hex()
    self.signature = ensure_signature(signature)
    self.config = config
    # reset() is called last, once the fields above are assigned.
    self.reset()

def reset(self):
Expand Down
36 changes: 26 additions & 10 deletions dspy/primitives/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dspy.predict.parallel import Parallel
from dspy.primitives.module import BaseModule
from dspy.utils.callback import with_callbacks
from dspy.utils.inspect_history import pretty_print_history
from dspy.utils.usage_tracker import track_usage


Expand All @@ -20,26 +21,38 @@ def _base_init(self):
def __init__(self, callbacks=None):
    """Initialize callbacks, the compiled flag and an empty LM-call history.

    Args:
        callbacks: Optional list of callbacks; any falsy value becomes a new
            empty list.
    """
    self.callbacks = callbacks if callbacks else []
    self._compiled = False
    # Record of LM calls made on behalf of this module.
    self.history = []

@with_callbacks
def __call__(self, *args, **kwargs):
    """Run ``forward`` with this module appended to ``settings.caller_modules``.

    When ``settings.track_usage`` is set and no usage tracker is active, the
    call is wrapped in ``track_usage()`` and the total token usage is attached
    to the output via ``set_lm_usage``.

    Args:
        *args: Positional arguments forwarded to ``forward``.
        **kwargs: Keyword arguments forwarded to ``forward``.

    Returns:
        Whatever ``forward`` returns.
    """
    # Build a fresh list so the list currently held in settings is never
    # mutated by this call.
    caller_modules = list(settings.caller_modules or [])
    caller_modules.append(self)

    with settings.context(caller_modules=caller_modules):
        if settings.track_usage and settings.usage_tracker is None:
            with track_usage() as usage_tracker:
                output = self.forward(*args, **kwargs)
                output.set_lm_usage(usage_tracker.get_total_tokens())
                return output

        return self.forward(*args, **kwargs)

@with_callbacks
async def acall(self, *args, **kwargs):
    """Await ``aforward`` with this module appended to ``settings.caller_modules``.

    When ``settings.track_usage`` is set and no usage tracker is active, the
    call is wrapped in ``track_usage()`` and the total token usage is attached
    to the output via ``set_lm_usage``.

    Args:
        *args: Positional arguments forwarded to ``aforward``.
        **kwargs: Keyword arguments forwarded to ``aforward``.

    Returns:
        Whatever ``aforward`` returns.
    """
    # Build a fresh list so the list currently held in settings is never
    # mutated by this call.
    caller_modules = list(settings.caller_modules or [])
    caller_modules.append(self)

    with settings.context(caller_modules=caller_modules):
        if settings.track_usage and settings.usage_tracker is None:
            with track_usage() as usage_tracker:
                output = await self.aforward(*args, **kwargs)
                output.set_lm_usage(usage_tracker.get_total_tokens())
                return output

        return await self.aforward(*args, **kwargs)

def named_predictors(self):
from dspy.predict.predict import Predict
Expand Down Expand Up @@ -75,6 +88,9 @@ def map_named_predictors(self, func):
set_attribute_by_name(self, name, func(predictor))
return self

def inspect_history(self, n: int = 1):
    """Pretty-print the last ``n`` entries stored in this module's ``history``."""
    recorded_calls = self.history
    return pretty_print_history(recorded_calls, n)

# def activate_assertions(self, handler=backtrack_handler, **handler_args):
# """
# Activates assertions for the module.
Expand Down
2 changes: 2 additions & 0 deletions dspy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
from dspy.streaming.messages import StatusMessageProvider, StatusMessage
from dspy.utils.inspect_history import pretty_print_history

import os
import requests
Expand All @@ -27,4 +28,5 @@ def download(url):
"dummy_rm",
"StatusMessage",
"StatusMessageProvider",
"pretty_print_history",
]
53 changes: 53 additions & 0 deletions dspy/utils/inspect_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
def _green(text: str, end: str = "\n"):
    """Return ``text`` (leading whitespace removed) in ANSI green, followed by ``end``."""
    body = str(text).lstrip()
    return "".join(("\x1b[32m", body, "\x1b[0m", end))


def _red(text: str, end: str = "\n"):
    """Return ``text`` wrapped in ANSI red escape codes, followed by ``end``."""
    pieces = ("\x1b[31m", str(text), "\x1b[0m", end)
    return "".join(pieces)


def _blue(text: str, end: str = "\n"):
    """Return ``text`` wrapped in ANSI blue escape codes, followed by ``end``."""
    pieces = ("\x1b[34m", str(text), "\x1b[0m", end)
    return "".join(pieces)


def pretty_print_history(history, n: int = 1):
    """Print the last ``n`` prompts and their completions.

    Args:
        history: Sequence of history entries. Each entry is a mapping with
            ``"messages"`` and ``"outputs"`` keys, an optional ``"timestamp"``,
            and a ``"prompt"`` that is used only when ``"messages"`` is empty.
        n: Number of most recent entries to print. Zero or a negative value
            prints no entries.
    """
    # ``history[-0:]`` is the whole list, so a non-positive ``n`` has to be
    # handled explicitly rather than passed straight into the slice.
    recent = history[-n:] if n > 0 else []

    for item in recent:
        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
        outputs = item["outputs"]
        timestamp = item.get("timestamp", "Unknown time")

        print("\n\n\n")
        print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

        for msg in messages:
            print(_red(f"{msg['role'].capitalize()} message:"))
            content = msg["content"]
            if isinstance(content, str):
                print(content.strip())
            elif isinstance(content, list):
                for c in content:
                    if c["type"] == "text":
                        print(c["text"].strip())
                    elif c["type"] == "image_url":
                        url = c["image_url"].get("url", "")
                        # Test for the exact separator that is split on, so a
                        # URL merely containing "base64" cannot cause an
                        # IndexError.
                        parts = url.split("base64,")
                        if len(parts) > 1:
                            image_str = (
                                f"<{parts[0]}base64,"
                                f"<IMAGE BASE 64 ENCODED({len(parts[1])!s})>"
                            )
                        else:
                            image_str = f"<image_url: {url}>"
                        print(_blue(image_str.strip()))
            print("\n")

        print(_red("Response:"))
        if outputs:
            first = outputs[0]
            # Non-string completions are rendered with ``str`` instead of
            # failing on ``.strip()``.
            first_text = first if isinstance(first, str) else str(first)
            print(_green(first_text.strip()))

            if len(outputs) > 1:
                choices_text = f" \t (and {len(outputs) - 1} other completions)"
                print(_red(choices_text, end=""))

    print("\n\n\n")
Loading