diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7feee8b..5321aba 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -14,9 +14,9 @@ jobs: - 3.11 services: redis: - image: redis + image: mongo ports: - - 6379:6379 + - 27017:27017 steps: - uses: actions/checkout@v3 @@ -38,5 +38,5 @@ jobs: run: | pipenv run python -m unittest tests/*.py env: - REDIS_HOST: localhost - REDIS_PORT: 6379 + MONGO_HOST: localhost + MONGO_PORT: 27017 diff --git a/README.md b/README.md index 63f0faf..73f1619 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Micromodel -Static and runtime dictionary validation. +Static and runtime dictionary validation (with MongoDB support). ## Install @@ -14,7 +14,7 @@ We had a HUGE Python code base which was using `pydantic` to provide a validatio We then decided to make this validation in-loco using a more vanilla approach with only `TypedDict`s. Now our dictionaries containing MongoDB documents are consistently dicts that match with the static typing. -## Usage +## Usage (validation only) ```python import typing @@ -80,6 +80,32 @@ print(result) print(m.cast({})) ``` +## Usage (with MongoDB) + +```python +import os +import typing +from micromodel import model +from pymongo import MongoClient + +db = MongoClient(os.getenv('MONGODB_URI')).get_default_database() + +Animal = typing.TypedDict('Animal', { + 'name': str, + 'specie': list[typing.Literal[ + 'dog', + 'cat', + 'bird' + ]] +}) + +m = model(Animal, coll=db['animals']) +m.insert_one({ + 'name': 'thor', + 'specie': 'dog' +}) +``` + ## License This library is [MIT licensed](https://github.com/capsulbrasil/normalize-json/tree/master/LICENSE). diff --git a/src/micromodel/micromodel.py b/src/micromodel/micromodel.py index df66f64..821efd6 100644 --- a/src/micromodel/micromodel.py +++ b/src/micromodel/micromodel.py @@ -1,21 +1,85 @@ import typing import types from abc import ABCMeta +from pymongo import ReturnDocument +from pymongo.collection import Collection +from pymongo.results import UpdateResult T = typing.TypeVar('T') ValidationOptions = typing.TypedDict('ValidationOptions', { + 'allow_missing': typing.NotRequired[bool], 'allow_extraneous': typing.NotRequired[bool] }) class Model(typing.Generic[T]): - def __init__(self, model_type: typing.Callable[[typing.Any], T], ct: dict[str, typing.Any] = {}): + coll: Collection + + def __init__(self, model_type: typing.Callable[[typing.Any], T], ct: dict[str, typing.Any] = {}, coll: Collection | None = None): self.model_type = model_type self.ct = ct + if coll: + self.coll = coll def cast(self, target: T | dict[str, typing.Any]): return typing.cast(T, target) def validate(self, target: T, options: ValidationOptions = {}): return validate(self.model_type, typing.cast(typing.Any, target), options, self.ct) + def find(self, *args: typing.Any, **kwargs: typing.Any): + result = self.coll.find(*args, **kwargs) + return typing.cast(typing.Generator[T, None, None], result) + + def find_one(self, *args: typing.Any, **kwargs: typing.Any): + result = self.coll.find_one(*args, **kwargs) + if not result: + return None + return self.cast(result) + + def insert_one(self, what: T, *args: typing.Any, **kwargs: typing.Any): + what = self.validate(what) + result = self.coll.insert_one(typing.cast(typing.Any, what), *args, **kwargs) + return result + + def _update(self, value: typing.Any, query_fields: list[str], ret: bool = True): + new = { + k: v + for k, v in value.items() + if k not in [ + '_id', + *query_fields + ] + } + + search = { + '$and': [ + { f: value[f] } + for f in query_fields + if f in value + ] + } + + if ret: + return self.coll.find_one_and_update( + search, + { '$set': new }, + return_document=ReturnDocument.AFTER, + upsert=True + ) + else: + return self.coll.update_one( + search, + { '$set': new }, + upsert=True + ) + + def update(self, value: typing.Any, query_fields: list[str]): + result = self._update(value, query_fields, ret=True) + return typing.cast(UpdateResult, result) + + def upsert(self, value: typing.Any, query_fields: list[str]): + result = self._update(value, query_fields, ret=True) + return typing.cast(T, result) + + def raise_missing_key(k: int | str): raise TypeError('missing key: %s' % k) @@ -25,12 +89,12 @@ def raise_extraneous_key(k: int | str): def raise_type_error(k: int | str, args: str, v: typing.Any): raise TypeError('incorrect type for %s: expected %s, got %s' % (k, args, v)) -def unwrap_type(obj: dict[int | str, typing.Any] | list[typing.Any], k: int | str, v: typing.Any, ct: dict[str, typing.Any] = {}): +def unwrap_type(obj: dict[int | str, typing.Any] | list[typing.Any], k: int | str, v: typing.Any, options: ValidationOptions = {}, ct: dict[str, typing.Any] = {}): origin = typing.get_origin(v) args = typing.get_args(v) if (isinstance(obj, dict) and not k in obj) or (isinstance(obj, list) and int(k) > len(obj)): - if types.NoneType not in args: + if types.NoneType not in args and not options.get('allow_missing'): raise_missing_key(k) return @@ -41,11 +105,11 @@ def unwrap_type(obj: dict[int | str, typing.Any] | list[typing.Any], k: int | st match origin: case _ if origin == list: for i in range(len(value)): - unwrap_type(value, i, args[0], ct) + unwrap_type(value, i, args[0], options, ct) case _ if origin == tuple: for i in range(len(value)): - unwrap_type(value, i, args[i], ct) + unwrap_type(value, i, args[i], options, ct) case typing.Literal: if value not in args: @@ -54,7 +118,7 @@ def unwrap_type(obj: dict[int | str, typing.Any] | list[typing.Any], k: int | st case types.UnionType: for candidate in args: if isinstance(candidate(), type(value)): - unwrap_type(obj, k, candidate, ct) + unwrap_type(obj, k, candidate, options, ct) break else: raise_type_error(k, str(args), type(value)) @@ -84,7 +148,7 @@ def validate(model_type: typing.Callable[[typing.Any], T], target: dict[str, typ obj[k] = v for k, v in hints.items(): - obj[k] = unwrap_type(obj, k, v, ct) + obj[k] = unwrap_type(obj, k, v, options, ct) return typing.cast(T, obj) @@ -92,6 +156,6 @@ def get_hints(model_type: ABCMeta): hints = typing.get_type_hints(model_type) return hints -def model(model_type: typing.Callable[[typing.Any], T], ct: dict[str, typing.Any] = {}) -> Model[T]: - return Model(model_type, ct) +def model(model_type: typing.Callable[[typing.Any], T], ct: dict[str, typing.Any] = {}, coll: Collection | None = None) -> Model[T]: + return Model(model_type, ct, coll) diff --git a/tests/mongodb.py b/tests/mongodb.py new file mode 100644 index 0000000..016e746 --- /dev/null +++ b/tests/mongodb.py @@ -0,0 +1,57 @@ +import os +import typing +from unittest import TestCase +from src.micromodel import model +from pymongo import MongoClient + +client = MongoClient('mongodb://%s:%s/test' % ( + os.getenv('MONGODB_HOST'), + int(os.getenv('MONGODB_PORT', '0')) +)) + +db = client.get_default_database() + +Animal = typing.TypedDict('Animal', { + 'name': str, + 'specie': typing.Literal[ + 'dog', + 'bird' + ] +}) + +db.drop_collection('animals') +m = model(Animal, coll=db['animals']) + +class TestMongodb(TestCase): + def test_object_equality(self): + m.insert_one({ + 'name': 'thor', + 'specie': 'dog' + }) + + result = m.find_one({ + 'name': 'thor' + }) + + if not result: + raise ValueError() + + self.assertEqual(result['name'], 'thor') + self.assertEqual(result['specie'], 'dog') + + + def test_upsert(self): + m.upsert({ + 'name': 'thor', + 'specie': 'bird' + }, ['name']) + + result = m.find_one({ + 'name': 'thor' + }) + + if not result: + raise ValueError() + + self.assertEqual(result['name'], 'thor') + self.assertEqual(result['specie'], 'bird')