Skip to content

Commit

Permalink
fix(cohere): more prompt bugs (param sigs)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Apr 10, 2024
1 parent 524081f commit c266977
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion kani/engines/huggingface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
self.hyperparams = hyperparams

if device is None:
device = "cuda" if torch.has_cuda else "cpu"
device = "cuda" if torch.backends.cuda.is_built() else "cpu"
self.device = device
if self.model.device.type != self.device:
self.model.to(device)
Expand Down
8 changes: 6 additions & 2 deletions kani/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AIParamSchema:

def __init__(self, name: str, t: type, default, aiparam: Optional["AIParam"], inspect_param: inspect.Parameter):
self.name = name
self.type = t
self.type = t # will not include Annotated if present
self.default = default
self.aiparam = aiparam
self.inspect_param = inspect_param
Expand All @@ -39,7 +39,11 @@ def description(self):
return self.aiparam.desc if self.aiparam is not None else None

def __str__(self):
return str(self.inspect_param)
default = ""
if not self.required:
default = f" = {self.default!r}"
annotation = inspect.formatannotation(self.type)
return f"{self.name}: {annotation}{default}"


class JSONSchemaBuilder(pydantic.json_schema.GenerateJsonSchema):
Expand Down
2 changes: 1 addition & 1 deletion kani/prompts/impl/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def function_prompt(f: AIFunction) -> str:
doc_params = []
for param in params:
desc = f": {param.description}" if param.description else ""
doc_params.append(f"{param.name} ({param.type}){desc}")
doc_params.append(f"{param.name} ({inspect.formatannotation(param.type)}){desc}")
args += "\n ".join(doc_params)

# return
Expand Down

0 comments on commit c266977

Please sign in to comment.