Skip to content

Commit

Permalink
Remove accomodation for old model pkl that are not dictionaries with …
Browse files Browse the repository at this point in the history
…metadata.
  • Loading branch information
jlubken committed Dec 28, 2021
1 parent 3f3eb35 commit 47f4154
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
20 changes: 14 additions & 6 deletions src/dsdk/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from abc import ABC
from contextlib import contextmanager
from json import dumps
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional

Expand All @@ -22,10 +23,18 @@
BaseMixin = ABC


class Model: # pylint: disable=too-few-public-methods
class Model:
"""Model."""

YAML = "!model"
INIT = dumps(
{
"key": "Model.__init__",
"name": "%s",
"path": "%s",
"version": "%s",
}
)

@classmethod
def as_yaml_type(cls, tag: Optional[str] = None) -> None:
Expand All @@ -41,12 +50,11 @@ def as_yaml_type(cls, tag: Optional[str] = None) -> None:
def _yaml_init(cls, loader, node):
"""Yaml init."""
path = loader.construct_scalar(node)
pkl = load_pickle_file(path)
if pkl.__class__ is dict:
pkl = cls(path=path, **pkl)
else:
pkl.path = path
d = load_pickle_file(path)
assert d.__class__ is dict
pkl = cls(path=path, **d)
assert isinstance(pkl, Model)
logger.info(cls.INIT, pkl.name, pkl.path, pkl.version)
return pkl

@classmethod
Expand Down
12 changes: 5 additions & 7 deletions test/test_dsdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self, **kwargs):
ext: .sql
path: ./assets/mssql
username: mssql
model: !model ./test/model.pkl
model: !model ./test/0.0.1.pkl
postgres: !postgres
database: test
host: 0.0.0.0
Expand All @@ -103,7 +103,7 @@ def __init__(self, **kwargs):
as_of: null
duration: null
gold: null
model: !model ./test/model.pkl
model: !model ./test/0.0.1.pkl
mssql: !mssql
database: test
host: 0.0.0.0
Expand Down Expand Up @@ -134,7 +134,7 @@ def build(
) -> Tuple[Callable, Dict[str, Any], str]:
"""Build from parameters."""
cls.yaml_types()
model = Model(name="test", path="./test/model.pkl", version="0.0.1-rc.1")
model = Model(name="test", path="./test/0.0.1.pkl", version="0.0.1")
mssql = Mssql(
database="test",
host="0.0.0.0",
Expand Down Expand Up @@ -168,10 +168,8 @@ def deserialize(
expected: str = EXPECTED,
) -> Tuple[Callable, Dict[str, Any], str]:
"""Build from yaml."""
pickle_file = "./test/model.pkl"
dump_pickle_file(
Model(name="test", path=pickle_file, version="0.0.1"), pickle_file
)
pickle_file = "./test/0.0.1.pkl"
dump_pickle_file({"name": "test", "version": "0.0.1"}, pickle_file)

env = {"POSTGRES_PASSWORD": "oops!", "MSSQL_PASSWORD": "oops!"}
return (
Expand Down

0 comments on commit 47f4154

Please sign in to comment.