-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #803 from parea-ai/PAI-918-dspy-integration
Pai 918 dspy integration
- Loading branch information
Showing
6 changed files
with
470 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import os | ||
|
||
import dspy | ||
from dotenv import load_dotenv | ||
|
||
from parea import Parea | ||
|
||
load_dotenv() | ||
|
||
p = Parea(api_key=os.getenv("PAREA_API_KEY"), project_name="testing") | ||
p.trace_dspy() | ||
|
||
gpt3_turbo = dspy.OpenAI(model="gpt-3.5-turbo-1106", max_tokens=300) | ||
dspy.configure(lm=gpt3_turbo) | ||
# print(gpt3_turbo("hello! this is a raw prompt to GPT-3.5.")) | ||
|
||
# Define a retrieval model server to send retrieval requests to | ||
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url="http://20.102.90.50:2017/wiki17_abstracts") | ||
|
||
# Configure retrieval server internally | ||
dspy.settings.configure(rm=colbertv2_wiki17_abstracts) | ||
|
||
|
||
# Define a simple signature for basic question answering | ||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
# Pass signature to Predict module | ||
generate_answer = dspy.Predict(BasicQA) | ||
|
||
# Call the predictor on a particular input. | ||
question = "What is the color of the sky?" | ||
pred = generate_answer(question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Predicted Answer: {pred.answer}") | ||
question = "What's something great about the ColBERT retrieval model ?!?abc" | ||
|
||
# 1) Declare with a signature, and pass some config. | ||
classify = dspy.ChainOfThought("question -> answer", n=1) | ||
|
||
# 2) Call with input argument. | ||
response = classify(question=question) | ||
|
||
# 3) Access the outputs. | ||
print(response.completions.answer) | ||
|
||
|
||
# Define a simple signature for basic question answering | ||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
# Pass signature to ChainOfThought module | ||
generate_answer = dspy.ChainOfThought(BasicQA) | ||
|
||
# Call the predictor on a particular input. | ||
question = "What is the color of the sky?12" | ||
pred = generate_answer(question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Predicted Answer: {pred.answer}") | ||
|
||
|
||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
# Example completions generated by a model for reference | ||
completions = [ | ||
dspy.Prediction(rationale=" I recall that during clear days, the sky often appears this colo12r", answer="blue"), | ||
dspy.Prediction(rationale=" Based on common knowledge, I believe the sky is typically seen 12as this color", answer="green"), | ||
dspy.Prediction(rationale=" From images and depictions in media, the sky is frequently42 represented with this hue", answer="blue"), | ||
] | ||
|
||
# Pass signature to MultiChainComparison module | ||
compare_answers = dspy.MultiChainComparison(BasicQA) | ||
|
||
# Call the MultiChainComparison on the completions | ||
question = " What is the color of th e sky14?" | ||
final_pred = compare_answers(completions, question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Final Predicted Answer (after comparison): {final_pred.answer}") | ||
print(f"Final Rationale: {final_pred.rationale}") | ||
|
||
|
||
# Define a simple signature for basic question answering | ||
class GenerateAnswer(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
# Pass signature to ProgramOfThought Module | ||
pot = dspy.ProgramOfThought(GenerateAnswer) | ||
|
||
# Call the ProgramOfThought module on a particular input | ||
question = "Sarah has 5 applez. She buys 123 more apples from the store. How many apples does Sarah have now?" | ||
result = pot(question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Final Predicted Answer (after ProgramOfThought process): {result.answer}") | ||
|
||
|
||
# Define a simple signature for basic question answering | ||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
# Pass signature to ReAct module | ||
react_module = dspy.ReAct(BasicQA, tools=[]) | ||
|
||
# Call the ReAct module on a particular input | ||
question = "What is the color of the 2 skies?" | ||
result = react_module(question=question) | ||
|
||
print(f"Question: {question}") | ||
print(f"Final Predicted Answer (after ReAct process): {result.answer}") | ||
|
||
|
||
query = "Where was the first FIFA World Cup held?12" | ||
|
||
|
||
# Call the retriever on a particular query. | ||
retrieve = dspy.Retrieve(k=3) | ||
topK_passages = retrieve(query).passages | ||
|
||
print(f"Top {retrieve.k} passages for question: {query} \n", "-" * 30, "\n") | ||
|
||
for idx, passage in enumerate(topK_passages): | ||
print(f"{idx+1}]", passage, "\n") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
from typing import Any, Callable, Dict, Mapping, Optional, Tuple | ||
|
||
from copy import copy, deepcopy | ||
|
||
from wrapt import BoundFunctionWrapper, FunctionWrapper, wrap_object | ||
|
||
from parea import trace | ||
|
||
_DSPY_MODULE_NAME = "dspy" | ||
_DSP_MODULE_NAME = "dsp" | ||
|
||
|
||
class DSPyInstrumentor: | ||
|
||
def instrument(self) -> None: | ||
# Instrument LM (language model) calls | ||
from dsp.modules.lm import LM | ||
from dspy import Predict | ||
|
||
language_model_classes = LM.__subclasses__() | ||
for lm in language_model_classes: | ||
wrap_object( | ||
module=_DSP_MODULE_NAME, | ||
name=lm.__name__ + ".basic_request", | ||
factory=CopyableFunctionWrapper, | ||
args=(_GeneralDSPyWrapper("request"),), | ||
) | ||
|
||
# Predict is a concrete (non-abstract) class that may be invoked | ||
# directly, but DSPy also has subclasses of Predict that override the | ||
# forward method. We instrument both the forward methods of the base | ||
# class and all subclasses. | ||
wrap_object( | ||
module=_DSPY_MODULE_NAME, | ||
name="Predict.forward", | ||
factory=CopyableFunctionWrapper, | ||
args=(_PredictForwardWrapper(),), | ||
) | ||
|
||
predict_subclasses = Predict.__subclasses__() | ||
for predict_subclass in predict_subclasses: | ||
wrap_object( | ||
module=_DSPY_MODULE_NAME, | ||
name=predict_subclass.__name__ + ".forward", | ||
factory=CopyableFunctionWrapper, | ||
args=(_PredictForwardWrapper(),), | ||
) | ||
|
||
wrap_object( | ||
module=_DSPY_MODULE_NAME, | ||
name="Retrieve.forward", | ||
factory=CopyableFunctionWrapper, | ||
args=(_GeneralDSPyWrapper("forward"),), | ||
) | ||
|
||
wrap_object( | ||
module=_DSPY_MODULE_NAME, | ||
# At this time, dspy.Module does not have an abstract forward | ||
# method, but assumes that user-defined subclasses implement the | ||
# forward method and invokes that method using __call__. | ||
name="Module.__call__", | ||
factory=CopyableFunctionWrapper, | ||
args=(_GeneralDSPyWrapper("forward"),), | ||
) | ||
|
||
# At this time, there is no common parent class for retriever models as | ||
# there is for language models. We instrument the retriever models on a | ||
# case-by-case basis. | ||
wrap_object( | ||
module=_DSP_MODULE_NAME, | ||
name="ColBERTv2.__call__", | ||
factory=CopyableFunctionWrapper, | ||
args=(_GeneralDSPyWrapper("__call__"),), | ||
) | ||
|
||
|
||
class CopyableBoundFunctionWrapper(BoundFunctionWrapper): # type: ignore | ||
""" | ||
A bound function wrapper that can be copied and deep-copied. When used to | ||
wrap a class method, this allows the entire class to be copied and | ||
deep-copied. | ||
For reference, see | ||
https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271 | ||
and | ||
https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers | ||
""" | ||
|
||
def __copy__(self) -> "CopyableBoundFunctionWrapper": | ||
return CopyableBoundFunctionWrapper(copy(self.__wrapped__), self._self_instance, self._self_wrapper) | ||
|
||
def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableBoundFunctionWrapper": | ||
return CopyableBoundFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_instance, self._self_wrapper) | ||
|
||
|
||
class CopyableFunctionWrapper(FunctionWrapper): # type: ignore | ||
""" | ||
A function wrapper that can be copied and deep-copied. When used to wrap a | ||
class method, this allows the entire class to be copied and deep-copied. | ||
For reference, see | ||
https://github.com/GrahamDumpleton/wrapt/issues/86#issuecomment-426161271 | ||
and | ||
https://wrapt.readthedocs.io/en/master/wrappers.html#custom-function-wrappers | ||
""" | ||
|
||
__bound_function_wrapper__ = CopyableBoundFunctionWrapper | ||
|
||
def __copy__(self) -> "CopyableFunctionWrapper": | ||
return CopyableFunctionWrapper(copy(self.__wrapped__), self._self_wrapper) | ||
|
||
def __deepcopy__(self, memo: Dict[Any, Any]) -> "CopyableFunctionWrapper": | ||
return CopyableFunctionWrapper(deepcopy(self.__wrapped__, memo), self._self_wrapper) | ||
|
||
|
||
class _GeneralDSPyWrapper: | ||
def __init__(self, method_name: str): | ||
self._method_name = method_name | ||
|
||
def __call__( | ||
self, | ||
wrapped: Callable[..., Any], | ||
instance: Any, | ||
args: Tuple[type, Any], | ||
kwargs: Mapping[str, Any], | ||
) -> Any: | ||
span_name = instance.__class__.__name__ + "." + self._method_name | ||
return trace(name=span_name)(wrapped)(*args, **kwargs) | ||
|
||
|
||
class _PredictForwardWrapper: | ||
""" | ||
A wrapper for the Predict class to have a chain span for each prediction | ||
""" | ||
|
||
def __call__( | ||
self, | ||
wrapped: Callable[..., Any], | ||
instance: Any, | ||
args: Tuple[type, Any], | ||
kwargs: Mapping[str, Any], | ||
) -> Any: | ||
from dspy import Predict | ||
|
||
# At this time, subclasses of Predict override the base class' forward | ||
# method and invoke the parent class' forward method from within the | ||
# overridden method. The forward method for both Predict and its | ||
# subclasses have been instrumented. To avoid creating duplicate spans | ||
# for a single invocation, we don't create a span for the base class' | ||
# forward method if the instance belongs to a proper subclass of Predict | ||
# with an overridden forward method. | ||
is_instance_of_predict_subclass = isinstance(instance, Predict) and (cls := instance.__class__) is not Predict | ||
has_overridden_forward_method = getattr(cls, "forward", None) is not getattr(Predict, "forward", None) | ||
wrapped_method_is_base_class_forward_method = wrapped.__qualname__ == Predict.forward.__qualname__ | ||
if is_instance_of_predict_subclass and has_overridden_forward_method and wrapped_method_is_base_class_forward_method: | ||
return wrapped(*args, **kwargs) | ||
else: | ||
return trace(name=_get_predict_span_name(instance))(wrapped)(*args, **kwargs) | ||
|
||
|
||
def _get_predict_span_name(instance: Any) -> str: | ||
""" | ||
Gets the name for the Predict span, which are the composition of a Predict | ||
class or subclass and a user-defined signature. An example name would be | ||
"Predict(UserDefinedSignature).forward". | ||
""" | ||
class_name = str(instance.__class__.__name__) | ||
if (signature := getattr(instance, "signature", None)) and (signature_name := _get_signature_name(signature)): | ||
return f"{class_name}({signature_name}).forward" | ||
return f"{class_name}.forward" | ||
|
||
|
||
def _get_signature_name(signature: Any) -> Optional[str]: | ||
""" | ||
A best-effort attempt to get the name of a signature. | ||
""" | ||
if ( | ||
# At the time of this writing, the __name__ attribute on signatures does | ||
# not return the user-defined class name, but __qualname__ does. | ||
qual_name := getattr(signature, "__qualname__", None) | ||
) is None: | ||
return None | ||
return str(qual_name.split(".")[-1]) |
Oops, something went wrong.