Per module lm history #8199
Changes from all commits: 5a60019, 0a299a3, de634b6, 858c3b1, b04a0d9, 743289c
```diff
@@ -6,6 +6,7 @@
 from dspy.predict.parallel import Parallel
 from dspy.primitives.module import BaseModule
 from dspy.utils.callback import with_callbacks
+from dspy.utils.inspect_history import pretty_print_history
 from dspy.utils.usage_tracker import track_usage
```
```diff
@@ -20,26 +21,38 @@ def _base_init(self):
     def __init__(self, callbacks=None):
         self.callbacks = callbacks or []
         self._compiled = False
+        # LM calling history of the module.
+        self.history = []
```

Review comment on `self.history = []`: Cool! Does this work if there are multiple LMs used in the module for different sub-modules?

Reply: Yes! The per-module history works the same in a multi-LM setup. Just tested with the following code and no issue:
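The test snippet itself is not reproduced above, so here is a minimal sketch of the kind of multi-LM program being described; the module, signature strings, and model names are illustrative assumptions, not the author's actual test code.

```python
import dspy

# Two different LMs for two sub-modules (model names are illustrative).
lm_small = dspy.LM("openai/gpt-4o-mini")
lm_large = dspy.LM("openai/gpt-4o")


class MultiLMProgram(dspy.Module):
    def __init__(self):
        super().__init__()
        self.summarize = dspy.Predict("document -> summary")
        self.answer = dspy.Predict("summary, question -> answer")
        # Pin each sub-module to its own LM.
        self.summarize.set_lm(lm_small)
        self.answer.set_lm(lm_large)

    def forward(self, document, question):
        summary = self.summarize(document=document).summary
        return self.answer(summary=summary, question=question)


program = MultiLMProgram()
program(document="DSPy is a framework for programming language models.",
        question="What is DSPy?")

# Each module keeps its own LM call history, regardless of which LM served it.
program.inspect_history(n=2)          # both calls made under this program
program.summarize.inspect_history()   # only the call routed to lm_small
program.answer.inspect_history()      # only the call routed to lm_large
```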
```diff
 
     @with_callbacks
     def __call__(self, *args, **kwargs):
-        if settings.track_usage and settings.usage_tracker is None:
-            with track_usage() as usage_tracker:
-                output = self.forward(*args, **kwargs)
-                output.set_lm_usage(usage_tracker.get_total_tokens())
-                return output
+        caller_modules = settings.caller_modules or []
+        caller_modules = list(caller_modules)
+        caller_modules.append(self)
+
+        with settings.context(caller_modules=caller_modules):
+            if settings.track_usage and settings.usage_tracker is None:
+                with track_usage() as usage_tracker:
+                    output = self.forward(*args, **kwargs)
+                    output.set_lm_usage(usage_tracker.get_total_tokens())
+                    return output
 
-        return self.forward(*args, **kwargs)
+            return self.forward(*args, **kwargs)
 
     @with_callbacks
     async def acall(self, *args, **kwargs):
-        if settings.track_usage and settings.usage_tracker is None:
-            with track_usage() as usage_tracker:
-                output = await self.aforward(*args, **kwargs)
-                output.set_lm_usage(usage_tracker.get_total_tokens())
-                return output
+        caller_modules = settings.caller_modules or []
+        caller_modules = list(caller_modules)
+        caller_modules.append(self)
+
+        with settings.context(caller_modules=caller_modules):
+            if settings.track_usage and settings.usage_tracker is None:
+                with track_usage() as usage_tracker:
+                    output = await self.aforward(*args, **kwargs)
+                    output.set_lm_usage(usage_tracker.get_total_tokens())
+                    return output
 
-        return await self.aforward(*args, **kwargs)
+            return await self.aforward(*args, **kwargs)
 
     def named_predictors(self):
         from dspy.predict.predict import Predict
```
```diff
@@ -75,6 +88,8 @@ def map_named_predictors(self, func):
             set_attribute_by_name(self, name, func(predictor))
         return self
 
+    def inspect_history(self, n: int = 1):
+        return pretty_print_history(self.history, n)
+
     def batch(
         self,
```
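As a usage sketch (the module type, signature string, and model name below are illustrative), the new method prints only the calls recorded on that module's own history:

```python
import dspy

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # illustrative model name

qa = dspy.ChainOfThought("question -> answer")
qa(question="What does per-module history add?")

# Unlike the global dspy.inspect_history(), this prints only the LM calls
# recorded on `qa` itself.
qa.inspect_history(n=1)
```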
New file `dspy/utils/inspect_history.py` (from the import added above; @@ -0,0 +1,53 @@):
```python
def _green(text: str, end: str = "\n"):
    return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end


def _red(text: str, end: str = "\n"):
    return "\x1b[31m" + str(text) + "\x1b[0m" + end


def _blue(text: str, end: str = "\n"):
    return "\x1b[34m" + str(text) + "\x1b[0m" + end


def pretty_print_history(history, n: int = 1):
    """Prints the last n prompts and their completions."""

    for item in history[-n:]:
        messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
        outputs = item["outputs"]
        timestamp = item.get("timestamp", "Unknown time")

        print("\n\n\n")
        print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

        for msg in messages:
            print(_red(f"{msg['role'].capitalize()} message:"))
            if isinstance(msg["content"], str):
                print(msg["content"].strip())
            else:
                if isinstance(msg["content"], list):
                    for c in msg["content"]:
                        if c["type"] == "text":
                            print(c["text"].strip())
                        elif c["type"] == "image_url":
                            image_str = ""
                            if "base64" in c["image_url"].get("url", ""):
                                len_base64 = len(c["image_url"]["url"].split("base64,")[1])
                                image_str = (
                                    f"<{c['image_url']['url'].split('base64,')[0]}base64,"
                                    f"<IMAGE BASE 64 ENCODED({len_base64!s})>"
                                )
                            else:
                                image_str = f"<image_url: {c['image_url']['url']}>"
                            print(_blue(image_str.strip()))
            print("\n")

        print(_red("Response:"))
        print(_green(outputs[0].strip()))

        if len(outputs) > 1:
            choices_text = f" \t (and {len(outputs) - 1} other completions)"
            print(_red(choices_text, end=""))

    print("\n\n\n")
```
Review comment: Btw, typical custom DSPy programs don't call `super().__init__(...)`. Will that be an issue?
Reply: Nothing will crash, but they will lose some perks from `dspy.Module`. We can follow what PyTorch does and recommend calling `super().__init__()` when customizing modules: https://docs.pytorch.org/tutorials/beginner/examples_nn/polynomial_module.html#pytorch-custom-nn-modules
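A minimal sketch of that recommended pattern, mirroring the PyTorch convention linked above (the class name and signature string are illustrative):

```python
import dspy


class MyProgram(dspy.Module):
    def __init__(self):
        # super().__init__() sets up the per-module state used by this PR:
        # callbacks, the _compiled flag, and the LM call history list.
        super().__init__()
        self.generate = dspy.ChainOfThought("question -> answer")

    def forward(self, question):
        return self.generate(question=question)
```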