From 47f415424997c660917f6374a6a5f56f89c13f53 Mon Sep 17 00:00:00 2001 From: Jason Lubken Date: Tue, 28 Dec 2021 14:22:47 -0500 Subject: [PATCH] Remove accomodation for old model pkl that are not dictionaries with metadata. --- src/dsdk/model.py | 20 ++++++++++++++------ test/test_dsdk.py | 12 +++++------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/dsdk/model.py b/src/dsdk/model.py index 46534a6..146377e 100644 --- a/src/dsdk/model.py +++ b/src/dsdk/model.py @@ -5,6 +5,7 @@ from abc import ABC from contextlib import contextmanager +from json import dumps from logging import getLogger from typing import TYPE_CHECKING, Any, Dict, Generator, Optional @@ -22,10 +23,18 @@ BaseMixin = ABC -class Model: # pylint: disable=too-few-public-methods +class Model: """Model.""" YAML = "!model" + INIT = dumps( + { + "key": "Model.__init__", + "name": "%s", + "path": "%s", + "version": "%s", + } + ) @classmethod def as_yaml_type(cls, tag: Optional[str] = None) -> None: @@ -41,12 +50,11 @@ def as_yaml_type(cls, tag: Optional[str] = None) -> None: def _yaml_init(cls, loader, node): """Yaml init.""" path = loader.construct_scalar(node) - pkl = load_pickle_file(path) - if pkl.__class__ is dict: - pkl = cls(path=path, **pkl) - else: - pkl.path = path + d = load_pickle_file(path) + assert d.__class__ is dict + pkl = cls(path=path, **d) assert isinstance(pkl, Model) + logger.info(cls.INIT, pkl.name, pkl.path, pkl.version) return pkl @classmethod diff --git a/test/test_dsdk.py b/test/test_dsdk.py index 1ac72b2..102b7d3 100644 --- a/test/test_dsdk.py +++ b/test/test_dsdk.py @@ -81,7 +81,7 @@ def __init__(self, **kwargs): ext: .sql path: ./assets/mssql username: mssql -model: !model ./test/model.pkl +model: !model ./test/0.0.1.pkl postgres: !postgres database: test host: 0.0.0.0 @@ -103,7 +103,7 @@ def __init__(self, **kwargs): as_of: null duration: null gold: null -model: !model ./test/model.pkl +model: !model ./test/0.0.1.pkl mssql: !mssql database: test host: 0.0.0.0 @@ -134,7 +134,7 @@ def build( ) -> Tuple[Callable, Dict[str, Any], str]: """Build from parameters.""" cls.yaml_types() - model = Model(name="test", path="./test/model.pkl", version="0.0.1-rc.1") + model = Model(name="test", path="./test/0.0.1.pkl", version="0.0.1") mssql = Mssql( database="test", host="0.0.0.0", @@ -168,10 +168,8 @@ def deserialize( expected: str = EXPECTED, ) -> Tuple[Callable, Dict[str, Any], str]: """Build from yaml.""" - pickle_file = "./test/model.pkl" - dump_pickle_file( - Model(name="test", path=pickle_file, version="0.0.1"), pickle_file - ) + pickle_file = "./test/0.0.1.pkl" + dump_pickle_file({"name": "test", "version": "0.0.1"}, pickle_file) env = {"POSTGRES_PASSWORD": "oops!", "MSSQL_PASSWORD": "oops!"} return (