Skip to content

Commit

Permalink
Merge pull request #803 from parea-ai/PAI-918-dspy-integration
Browse files Browse the repository at this point in the history
Pai 918 dspy integration
  • Loading branch information
joschkabraun committed Apr 26, 2024
2 parents 32f6370 + 01e2863 commit eadc3db
Show file tree
Hide file tree
Showing 6 changed files with 470 additions and 8 deletions.
24 changes: 20 additions & 4 deletions parea/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,33 @@ def wrap_anthropic_client(self, client: "Anthropic", integration: Optional[str]
if integration:
self._client.add_integration(integration)

def auto_trace_openai_clients(self) -> None:
def auto_trace_openai_clients(self, integration: Optional[str]) -> None:
import openai

openai._ModuleClient = patch_openai_client_classes(openai._ModuleClient, self)
openai.OpenAI = patch_openai_client_classes(openai.OpenAI, self)
openai.AsyncOpenAI = patch_openai_client_classes(openai.AsyncOpenAI, self)
openai.AzureOpenAI = patch_openai_client_classes(openai.AzureOpenAI, self)
openai.AsyncAzureOpenAI = patch_openai_client_classes(openai.AsyncAzureOpenAI, self)

if integration:
self._client.add_integration(integration)

def trace_dspy(self):
from parea.utils.trace_integrations.dspy import DSPyInstrumentor

try:
import openai

if openai.version.__version__.startswith("0."):
self.wrap_openai_client(openai, "dspy")
else:
self.auto_trace_openai_clients("dspy")
except ImportError:
pass

DSPyInstrumentor().instrument()

def integrate_with_sglang(self):
self.auto_trace_openai_clients()
self._client.add_integration("sglang")
Expand Down Expand Up @@ -377,9 +396,6 @@ async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceL
return structure_trace_logs_from_api(response.json())


_initialized_parea_wrapper = False


def patch_openai_client_classes(openai_client, parea_client: Parea):
"""Creates a subclass of the given openai_client to always wrap it with Parea at instantiation."""

Expand Down
146 changes: 146 additions & 0 deletions parea/cookbook/dspy/dspy_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os

import dspy
from dotenv import load_dotenv

from parea import Parea

load_dotenv()

p = Parea(api_key=os.getenv("PAREA_API_KEY"), project_name="testing")
p.trace_dspy()

gpt3_turbo = dspy.OpenAI(model="gpt-3.5-turbo-1106", max_tokens=300)
dspy.configure(lm=gpt3_turbo)
# print(gpt3_turbo("hello! this is a raw prompt to GPT-3.5."))

# Define a retrieval model server to send retrieval requests to
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts")

# Configure retrieval server internally
dspy.settings.configure(rm=colbertv2_wiki17_abstracts)


# Define a simple signature for basic question answering
class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""

question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to Predict module
generate_answer = dspy.Predict(BasicQA)

# Call the predictor on a particular input.
question = "What is the color of the sky?"
pred = generate_answer(question=question)

print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")
question = "What's something great about the ColBERT retrieval model ?!?abc"

# 1) Declare with a signature, and pass some config.
classify = dspy.ChainOfThought("question -> answer", n=1)

# 2) Call with input argument.
response = classify(question=question)

# 3) Access the outputs.
print(response.completions.answer)


# Define a simple signature for basic question answering
class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""

question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to ChainOfThought module
generate_answer = dspy.ChainOfThought(BasicQA)

# Call the predictor on a particular input.
question = "What is the color of the sky?12"
pred = generate_answer(question=question)

print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")


class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""

question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")


# Example completions generated by a model for reference
completions = [
dspy.Prediction(rationale=" I recall that during clear days, the sky often appears this colo12r", answer="blue"),
dspy.Prediction(rationale=" Based on common knowledge, I believe the sky is typically seen 12as this color", answer="green"),
dspy.Prediction(rationale=" From images and depictions in media, the sky is frequently42 represented with this hue", answer="blue"),
]

# Pass signature to MultiChainComparison module
compare_answers = dspy.MultiChainComparison(BasicQA)

# Call the MultiChainComparison on the completions
question = " What is the color of th e sky14?"
final_pred = compare_answers(completions, question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after comparison): {final_pred.answer}")
print(f"Final Rationale: {final_pred.rationale}")


# Define a simple signature for basic question answering
class GenerateAnswer(dspy.Signature):
"""Answer questions with short factoid answers."""

question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to ProgramOfThought Module
pot = dspy.ProgramOfThought(GenerateAnswer)

# Call the ProgramOfThought module on a particular input
question = "Sarah has 5 applez. She buys 123 more apples from the store. How many apples does Sarah have now?"
result = pot(question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after ProgramOfThought process): {result.answer}")


# Define a simple signature for basic question answering
class BasicQA(dspy.Signature):
"""Answer questions with short factoid answers."""

question = dspy.InputField()
answer = dspy.OutputField(desc="often between 1 and 5 words")


# Pass signature to ReAct module
react_module = dspy.ReAct(BasicQA, tools=[])

# Call the ReAct module on a particular input
question = "What is the color of the 2 skies?"
result = react_module(question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after ReAct process): {result.answer}")


query = "Where was the first FIFA World Cup held?12"


# Call the retriever on a particular query.
retrieve = dspy.Retrieve(k=3)
topK_passages = retrieve(query).passages

print(f"Top {retrieve.k} passages for question: {query} \n", "-" * 30, "\n")

for idx, passage in enumerate(topK_passages):
print(f"{idx+1}]", passage, "\n")
183 changes: 183 additions & 0 deletions parea/utils/trace_integrations/dspy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from typing import Any, Callable, Dict, Mapping, Optional, Tuple

from copy import copy, deepcopy

from wrapt import BoundFunctionWrapper, FunctionWrapper, wrap_object

from parea import trace

_DSPY_MODULE_NAME = "dspy"
_DSP_MODULE_NAME = "dsp"


class DSPyInstrumentor:

def instrument(self) -> None:
# Instrument LM (language model) calls
from dsp.modules.lm import LM
from dspy import Predict

language_model_classes = LM.__subclasses__()
for lm in language_model_classes:
wrap_object(
module=_DSP_MODULE_NAME,
name=lm.__name__ + ".basic_request",
factory=CopyableFunctionWrapper,
args=(_GeneralDSPyWrapper("request"),),
)

# Predict is a concrete (non-abstract) class that may be invoked
# directly, but DSPy also has subclasses of Predict that override the
# forward method. We instrument both the forward methods of the base
# class and all subclasses.
wrap_object(
module=_DSPY_MODULE_NAME,
name="Predict.forward",
factory=CopyableFunctionWrapper,
args=(_PredictForwardWrapper(),),
)

predict_subclasses = Predict.__subclasses__()
for predict_subclass in predict_subclasses:
wrap_object(
module=_DSPY_MODULE_NAME,
name=predict_subclass.__name__ + ".forward",
factory=CopyableFunctionWrapper,
args=(_PredictForwardWrapper(),),
)

wrap_object(
module=_DSPY_MODULE_NAME,
name="Retrieve.forward",
factory=CopyableFunctionWrapper,
args=(_GeneralDSPyWrapper("forward"),),
)

wrap_object(
module=_DSPY_MODULE_NAME,
# At this time, dspy.Module does not have an abstract forward
# method, but assumes that user-defined subclasses implement the
# forward method and invokes that method using __call__.
name="Module.__call__",
factory=CopyableFunctionWrapper,
args=(_GeneralDSPyWrapper("forward"),),
)

# At this time, there is no common parent class for retriever models as
# there is for language models. We instrument the retriever models on a
# case-by-case basis.
wrap_object(
module=_DSP_MODULE_NAME,
name="ColBERTv2.__call__",
factory=CopyableFunctionWrapper,
args=(_GeneralDSPyWrapper("__call__"),),
)


class CopyableBoundFunctionWrapper(BoundFunctionWrapper): # type: ignore
"""
A bound function wrapper that can be copied and deep-copied. When used to
wrap a class method, this allows the entire class to be copied and
deep-copied.
For reference, see
https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271
and
https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers
"""

def __copy__(self) -> "CopyableBoundFunctionWrapper":
return CopyableBoundFunctionWrapper(copy(self.__wrapped__), self._self_instance, self._self_wrapper)

def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableBoundFunctionWrapper":
return CopyableBoundFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_instance, self._self_wrapper)


class CopyableFunctionWrapper(FunctionWrapper): # type: ignore
"""
A function wrapper that can be copied and deep-copied. When used to wrap a
class method, this allows the entire class to be copied and deep-copied.
For reference, see
https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271
and
https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers
"""

__bound_function_wrapper__ = CopyableBoundFunctionWrapper

def __copy__(self) -> "CopyableFunctionWrapper":
return CopyableFunctionWrapper(copy(self.__wrapped__), self._self_wrapper)

def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableFunctionWrapper":
return CopyableFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_wrapper)


class _GeneralDSPyWrapper:
def __init__(self, method_name: str):
self._method_name = method_name

def __call__(
self,
wrapped: Callable[..., Any],
instance: Any,
args: Tuple[type, Any],
kwargs: Mapping[str, Any],
) -> Any:
span_name = instance.__class__.__name__ + "." + self._method_name
return trace(name=span_name)(wrapped)(*args, **kwargs)


class _PredictForwardWrapper:
"""
A wrapper for the Predict class to have a chain span for each prediction
"""

def __call__(
self,
wrapped: Callable[..., Any],
instance: Any,
args: Tuple[type, Any],
kwargs: Mapping[str, Any],
) -> Any:
from dspy import Predict

# At this time, subclasses of Predict override the base class' forward
# method and invoke the parent class' forward method from within the
# overridden method. The forward method for both Predict and its
# subclasses have been instrumented. To avoid creating duplicate spans
# for a single invocation, we don't create a span for the base class'
# forward method if the instance belongs to a proper subclass of Predict
# with an overridden forward method.
is_instance_of_predict_subclass = isinstance(instance, Predict) and (cls := instance.__class__) is not Predict
has_overridden_forward_method = getattr(cls, "forward", None) is not getattr(Predict, "forward", None)
wrapped_method_is_base_class_forward_method = wrapped.__qualname__ == Predict.forward.__qualname__
if is_instance_of_predict_subclass and has_overridden_forward_method and wrapped_method_is_base_class_forward_method:
return wrapped(*args, **kwargs)
else:
return trace(name=_get_predict_span_name(instance))(wrapped)(*args, **kwargs)


def _get_predict_span_name(instance: Any) -> str:
"""
Gets the name for the Predict span, which are the composition of a Predict
class or subclass and a user-defined signature. An example name would be
"Predict(UserDefinedSignature).forward".
"""
class_name = str(instance.__class__.__name__)
if (signature := getattr(instance, "signature", None)) and (signature_name := _get_signature_name(signature)):
return f"{class_name}({signature_name}).forward"
return f"{class_name}.forward"


def _get_signature_name(signature: Any) -> Optional[str]:
"""
A best-effort attempt to get the name of a signature.
"""
if (
# At the time of this writing, the __name__ attribute on signatures does
# not return the user-defined class name, but __qualname__ does.
qual_name := getattr(signature, "__qualname__", None)
) is None:
return None
return str(qual_name.split(".")[-1])
Loading

0 comments on commit eadc3db

Please sign in to comment.