Skip to content

Commit

Permalink
fix: fix default value access for prompt runners
Browse files Browse the repository at this point in the history
call field.get_default which will call .default_factory if passed in
instead of just assuming .default is there
  • Loading branch information
jonasi committed Mar 29, 2024
1 parent ad3a5e1 commit 40e15d7
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions langdspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,12 @@ def __init__(self, n_jobs=1, **kwargs):
self.kwargs = {**kwargs, 'trained_state': self.trained_state}
for field_name, field in self.__fields__.items():
if issubclass(field.type_, PromptRunner):
self.prompt_runners.append((field_name, field.default))
field_value = field.get_default()
self.prompt_runners.append((field_name, field_value))

field.default.set_model_kwargs(self.kwargs)
field_value.set_model_kwargs(self.kwargs)
# Necessary since pydantic creates a new version of the object
setattr(self, field_name, field.default)
setattr(self, field_name, field_value)


def save(self, filepath):
Expand Down

0 comments on commit 40e15d7

Please sign in to comment.