From c266977465ca46a53def75298bdf34ab784219c6 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Wed, 10 Apr 2024 17:09:10 -0400 Subject: [PATCH] fix(cohere): more prompt bugs (param sigs) --- kani/engines/huggingface/base.py | 2 +- kani/json_schema.py | 8 ++++++-- kani/prompts/impl/cohere.py | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/kani/engines/huggingface/base.py b/kani/engines/huggingface/base.py index 4de28d0..40a9f61 100644 --- a/kani/engines/huggingface/base.py +++ b/kani/engines/huggingface/base.py @@ -76,7 +76,7 @@ def __init__( self.hyperparams = hyperparams if device is None: - device = "cuda" if torch.has_cuda else "cpu" + device = "cuda" if torch.backends.cuda.is_built() else "cpu" self.device = device if self.model.device.type != self.device: self.model.to(device) diff --git a/kani/json_schema.py b/kani/json_schema.py index e5b5140..6124676 100644 --- a/kani/json_schema.py +++ b/kani/json_schema.py @@ -20,7 +20,7 @@ class AIParamSchema: def __init__(self, name: str, t: type, default, aiparam: Optional["AIParam"], inspect_param: inspect.Parameter): self.name = name - self.type = t + self.type = t # will not include Annotated if present self.default = default self.aiparam = aiparam self.inspect_param = inspect_param @@ -39,7 +39,11 @@ def description(self): return self.aiparam.desc if self.aiparam is not None else None def __str__(self): - return str(self.inspect_param) + default = "" + if not self.required: + default = f" = {self.default!r}" + annotation = inspect.formatannotation(self.type) + return f"{self.name}: {annotation}{default}" class JSONSchemaBuilder(pydantic.json_schema.GenerateJsonSchema): diff --git a/kani/prompts/impl/cohere.py b/kani/prompts/impl/cohere.py index 2d2b210..9c0acdc 100644 --- a/kani/prompts/impl/cohere.py +++ b/kani/prompts/impl/cohere.py @@ -218,7 +218,7 @@ def function_prompt(f: AIFunction) -> str: doc_params = [] for param in params: desc = f": {param.description}" if param.description else "" - doc_params.append(f"{param.name} ({param.type}){desc}") + doc_params.append(f"{param.name} ({inspect.formatannotation(param.type)}){desc}") args += "\n ".join(doc_params) # return