diff --git a/.travis.yml b/.travis.yml
index 8a74f94b..fcd3ceea 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,11 +2,8 @@ dist: xenial
 language: python
 python: 3.8
 
-install:
-- pip3 install sqlalchemy
-
 script:
-- make test
+- make all
 
 branches:
   except:
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 00000000..73024d18
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,15 @@
+FROM python:3.9-slim-buster
+
+# RUN apt install gcc libpq (no longer needed bc we use psycopg2-binary)
+
+COPY requirements.txt /tmp/
+RUN pip install -r /tmp/requirements.txt
+
+RUN mkdir -p /src
+COPY src/ /src/
+RUN pip install -e /src
+COPY tests/ /tests/
+
+WORKDIR /src
+ENV FLASK_APP=allocation/entrypoints/flask_app.py FLASK_DEBUG=1 PYTHONUNBUFFERED=1
+CMD flask run --host=0.0.0.0 --port=80
diff --git a/Makefile b/Makefile
index 77cbd229..6409e955 100644
--- a/Makefile
+++ b/Makefile
@@ -1,8 +1,32 @@
-test:
-	pytest --tb=short
+# these will speed up builds, for docker-compose >= 1.25
+export COMPOSE_DOCKER_CLI_BUILD=1
+export DOCKER_BUILDKIT=1
 
-watch-tests:
-	ls *.py | entr pytest --tb=short
+all: down build up test
+
+build:
+	docker-compose build
+
+up:
+	docker-compose up -d app
+
+down:
+	docker-compose down --remove-orphans
+
+test: up
+	docker-compose run --rm --no-deps --entrypoint=pytest app /tests/unit /tests/integration /tests/e2e
+
+unit-tests:
+	docker-compose run --rm --no-deps --entrypoint=pytest app /tests/unit
+
+integration-tests: up
+	docker-compose run --rm --no-deps --entrypoint=pytest app /tests/integration
+
+e2e-tests: up
+	docker-compose run --rm --no-deps --entrypoint=pytest app /tests/e2e
+
+logs:
+	docker-compose logs app | tail -100
 
 black:
 	black -l 86 $$(find * -name '*.py')
diff --git a/conftest.py b/conftest.py
deleted file mode 100644
index 9f7b74b0..00000000
--- a/conftest.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import pytest
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker, clear_mappers
-
-from orm import metadata, start_mappers
-
-
-@pytest.fixture
-def in_memory_db():
-    engine = create_engine("sqlite:///:memory:")
-    metadata.create_all(engine)
-    return engine
-
-
-@pytest.fixture
-def session(in_memory_db):
-    start_mappers()
-    yield sessionmaker(bind=in_memory_db)()
-    clear_mappers()
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 00000000..039400e9
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,29 @@
+version: "3"
+services:
+
+  app:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    depends_on:
+      - postgres
+    environment:
+      - DB_HOST=postgres
+      - DB_PASSWORD=abc123
+      - API_HOST=app
+      - PYTHONDONTWRITEBYTECODE=1
+    volumes:
+      - ./src:/src
+      - ./tests:/tests
+    ports:
+      - "5005:80"
+
+
+  postgres:
+    image: postgres:9.6
+    environment:
+      - POSTGRES_USER=allocation
+      - POSTGRES_PASSWORD=abc123
+    ports:
+      - "54321:5432"
+
diff --git a/mypy.ini b/mypy.ini
index ead5ef09..62194f35 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,9 +1,7 @@
 [mypy]
 ignore_missing_imports = False
+mypy_path = ./src
+check_untyped_defs = True
 
-[mypy-pytest.*]
+[mypy-pytest.*,sqlalchemy.*]
 ignore_missing_imports = True
-
-[mypy-sqlalchemy.*]
-ignore_missing_imports = True
-
diff --git a/orm.py b/orm.py
deleted file mode 100644
index 6a3e728e..00000000
--- a/orm.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from sqlalchemy import Table, MetaData, Column, Integer, String, Date
-from sqlalchemy.orm import mapper
-
-import model
-
-
-metadata = MetaData()
-
-order_lines = Table(
-    "order_lines",
-    metadata,
-    Column("orderid", String(255), primary_key=True),
-    Column("sku", String(255), primary_key=True),
-    Column("qty", Integer),
-)
-
-batches = Table(
-    "batches",
-    metadata,
-    Column("reference", String(255), primary_key=True),
-    Column("sku", String(255), primary_key=True),
-    Column("_purchased_qty", Integer),
-    Column("eta", Date),
-)
-
-
-def start_mappers():
-    mapper(model.OrderLine, order_lines)
-    mapper(model.Batch, batches)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..8c779254
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+# app
+sqlalchemy
+flask
+psycopg2-binary
+
+# tests
+pytest
+pytest-icdiff
+mypy
+requests
diff --git a/src/allocation/__init__.py b/src/allocation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/allocation/adapters/__init__.py b/src/allocation/adapters/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/allocation/adapters/orm.py b/src/allocation/adapters/orm.py
new file mode 100644
index 00000000..c189a739
--- /dev/null
+++ b/src/allocation/adapters/orm.py
@@ -0,0 +1,49 @@
+from sqlalchemy import Table, MetaData, Column, Integer, String, Date, ForeignKey
+from sqlalchemy.orm import mapper, relationship
+
+from allocation.domain import model
+
+
+metadata = MetaData()
+
+order_lines = Table(
+    "order_lines",
+    metadata,
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("sku", String(255)),
+    Column("qty", Integer, nullable=False),
+    Column("orderid", String(255)),
+)
+
+batches = Table(
+    "batches",
+    metadata,
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("reference", String(255)),
+    Column("sku", String(255)),
+    Column("_purchased_quantity", Integer, nullable=False),
+    Column("eta", Date, nullable=True),
+)
+
+allocations = Table(
+    "allocations",
+    metadata,
+    Column("id", Integer, primary_key=True, autoincrement=True),
+    Column("orderline_id", ForeignKey("order_lines.id")),
+    Column("batch_id", ForeignKey("batches.id")),
+)
+
+
+def start_mappers():
+    lines_mapper = mapper(model.OrderLine, order_lines)
+    mapper(
+        model.Batch,
+        batches,
+        properties={
+            "_allocations": relationship(
+                lines_mapper,
+                secondary=allocations,
+                collection_class=set,
+            )
+        },
+    )
diff --git a/src/allocation/adapters/repository.py b/src/allocation/adapters/repository.py
new file mode 100644
index 00000000..e1d9e9be
--- /dev/null
+++ b/src/allocation/adapters/repository.py
@@ -0,0 +1,30 @@
+import abc
+from allocation.domain import model
+
+
+class AbstractRepository(abc.ABC):
+    @abc.abstractmethod
+    def add(self, batch: model.Batch):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def get(self, reference) -> model.Batch:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def list(self):
+        raise NotImplementedError
+
+
+class SqlAlchemyRepository(AbstractRepository):
+    def __init__(self, session):
+        self.session = session
+
+    def add(self, batch):
+        self.session.add(batch)
+
+    def get(self, reference):
+        return self.session.query(model.Batch).filter_by(reference=reference).one()
+
+    def list(self):
+        return self.session.query(model.Batch).all()
diff --git a/src/allocation/config.py b/src/allocation/config.py
new file mode 100644
index 00000000..f3b55cc9
--- /dev/null
+++ b/src/allocation/config.py
@@ -0,0 +1,15 @@
+import os
+
+
+def get_postgres_uri():
+    host = os.environ.get("DB_HOST", "localhost")
+    port = 54321 if host == "localhost" else 5432
+    password = os.environ.get("DB_PASSWORD", "abc123")
+    user, db_name = "allocation", "allocation"
+    return f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+
+
+def get_api_url():
+    host = os.environ.get("API_HOST", "localhost")
+    port = 5005 if host == "localhost" else 80
+    return f"http://{host}:{port}"
diff --git a/src/allocation/domain/__init__.py b/src/allocation/domain/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model.py b/src/allocation/domain/model.py
similarity index 100%
rename from model.py
rename to src/allocation/domain/model.py
diff --git a/src/allocation/entrypoints/__init__.py b/src/allocation/entrypoints/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/allocation/entrypoints/flask_app.py b/src/allocation/entrypoints/flask_app.py
new file mode 100644
index 00000000..602a09c1
--- /dev/null
+++ b/src/allocation/entrypoints/flask_app.py
@@ -0,0 +1,41 @@
+from datetime import datetime
+from flask import Flask, request
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from allocation.domain import model
+from allocation.adapters import orm
+from allocation.service_layer import services, unit_of_work
+
+app = Flask(__name__)
+orm.start_mappers()
+
+
+@app.route("/add_batch", methods=["POST"])
+def add_batch():
+    eta = request.json["eta"]
+    if eta is not None:
+        eta = datetime.fromisoformat(eta).date()
+    services.add_batch(
+        request.json["ref"],
+        request.json["sku"],
+        request.json["qty"],
+        eta,
+        unit_of_work.SqlAlchemyUnitOfWork(),
+    )
+    return "OK", 201
+
+
+@app.route("/allocate", methods=["POST"])
+def allocate_endpoint():
+    try:
+        batchref = services.allocate(
+            request.json["orderid"],
+            request.json["sku"],
+            request.json["qty"],
+            unit_of_work.SqlAlchemyUnitOfWork(),
+        )
+    except (model.OutOfStock, services.InvalidSku) as e:
+        return {"message": str(e)}, 400
+
+    return {"batchref": batchref}, 201
diff --git a/src/allocation/service_layer/__init__.py b/src/allocation/service_layer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/allocation/service_layer/services.py b/src/allocation/service_layer/services.py
new file mode 100644
index 00000000..c5d46772
--- /dev/null
+++ b/src/allocation/service_layer/services.py
@@ -0,0 +1,38 @@
+from __future__ import annotations
+from typing import Optional
+from datetime import date
+
+from allocation.domain import model
+from allocation.domain.model import OrderLine
+from allocation.service_layer import unit_of_work
+
+
+class InvalidSku(Exception):
+    pass
+
+
+def is_valid_sku(sku, batches):
+    return sku in {b.sku for b in batches}
+
+
+def add_batch(
+    ref: str, sku: str, qty: int, eta: Optional[date],
+    uow: unit_of_work.AbstractUnitOfWork,
+):
+    with uow:
+        uow.batches.add(model.Batch(ref, sku, qty, eta))
+        uow.commit()
+
+
+def allocate(
+    orderid: str, sku: str, qty: int,
+    uow: unit_of_work.AbstractUnitOfWork,
+) -> str:
+    line = OrderLine(orderid, sku, qty)
+    with uow:
+        batches = uow.batches.list()
+        if not is_valid_sku(line.sku, batches):
+            raise InvalidSku(f"Invalid sku {line.sku}")
+        batchref = model.allocate(line, batches)
+        uow.commit()
+    return batchref
diff --git a/src/allocation/service_layer/unit_of_work.py b/src/allocation/service_layer/unit_of_work.py
new file mode 100644
index 00000000..bf9196c8
--- /dev/null
+++ b/src/allocation/service_layer/unit_of_work.py
@@ -0,0 +1,54 @@
+# pylint: disable=attribute-defined-outside-init
+from __future__ import annotations
+import abc
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm.session import Session
+
+from allocation import config
+from allocation.adapters import repository
+
+
+class AbstractUnitOfWork(abc.ABC):
+    batches: repository.AbstractRepository
+
+    def __enter__(self) -> AbstractUnitOfWork:
+        return self
+
+    def __exit__(self, *args):
+        self.rollback()
+
+    @abc.abstractmethod
+    def commit(self):
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def rollback(self):
+        raise NotImplementedError
+
+
+DEFAULT_SESSION_FACTORY = sessionmaker(
+    bind=create_engine(
+        config.get_postgres_uri(),
+    )
+)
+
+
+class SqlAlchemyUnitOfWork(AbstractUnitOfWork):
+    def __init__(self, session_factory=DEFAULT_SESSION_FACTORY):
+        self.session_factory = session_factory
+
+    def __enter__(self):
+        self.session = self.session_factory()  # type: Session
+        self.batches = repository.SqlAlchemyRepository(self.session)
+        return super().__enter__()
+
+    def __exit__(self, *args):
+        super().__exit__(*args)
+        self.session.close()
+
+    def commit(self):
+        self.session.commit()
+
+    def rollback(self):
+        self.session.rollback()
diff --git a/src/setup.py b/src/setup.py
new file mode 100644
index 00000000..b2b0839a
--- /dev/null
+++ b/src/setup.py
@@ -0,0 +1,7 @@
+from setuptools import setup
+
+setup(
+    name="allocation",
+    version="0.1",
+    packages=["allocation"],
+)
diff --git a/test_orm.py b/test_orm.py
deleted file mode 100644
index 9dc98719..00000000
--- a/test_orm.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import model
-from datetime import date
-
-
-def test_orderline_mapper_can_load_lines(session):
-    session.execute(
-        "INSERT INTO order_lines (orderid, sku, qty) VALUES "
-        '("order1", "RED-CHAIR", 12),'
-        '("order1", "RED-TABLE", 13),'
-        '("order2", "BLUE-LIPSTICK", 14)'
-    )
-    expected = [
-        model.OrderLine("order1", "RED-CHAIR", 12),
-        model.OrderLine("order1", "RED-TABLE", 13),
-        model.OrderLine("order2", "BLUE-LIPSTICK", 14),
-    ]
-    assert session.query(model.OrderLine).all() == expected
-
-
-def test_orderline_mapper_can_save_lines(session):
-    new_line = model.OrderLine("order1", "DECORATIVE-WIDGET", 12)
-    session.add(new_line)
-    session.commit()
-
-    rows = list(session.execute('SELECT orderid, sku, qty FROM "order_lines"'))
-    assert rows == [("order1", "DECORATIVE-WIDGET", 12)]
-
-
-def test_batches(session):
-    session.execute('INSERT INTO "batches" VALUES ("batch1", "sku1", 100, null)')
-    session.execute(
-        'INSERT INTO "batches" VALUES ("batch2", "sku2", 200, "2011-04-11")'
-    )
-    expected = [
-        model.Batch("batch1", "sku1", 100, eta=None),
-        model.Batch("batch2", "sku2", 200, eta=date(2011, 4, 11)),
-    ]
-
-    assert session.query(model.Batch).all() == expected
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..3dff0b83
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,75 @@
+# pylint: disable=redefined-outer-name
+import time
+from pathlib import Path
+
+import pytest
+import requests
+from requests.exceptions import ConnectionError
+from sqlalchemy.exc import OperationalError
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker, clear_mappers
+
+from allocation.adapters.orm import metadata, start_mappers
+from allocation import config
+
+
+@pytest.fixture
+def in_memory_db():
+    engine = create_engine("sqlite:///:memory:")
+    metadata.create_all(engine)
+    return engine
+
+
+@pytest.fixture
+def session_factory(in_memory_db):
+    start_mappers()
+    yield sessionmaker(bind=in_memory_db)
+    clear_mappers()
+
+
+@pytest.fixture
+def session(session_factory):
+    return session_factory()
+
+
+def wait_for_postgres_to_come_up(engine):
+    deadline = time.time() + 10
+    while time.time() < deadline:
+        try:
+            return engine.connect()
+        except OperationalError:
+            time.sleep(0.5)
+    pytest.fail("Postgres never came up")
+
+
+def wait_for_webapp_to_come_up():
+    deadline = time.time() + 10
+    url = config.get_api_url()
+    while time.time() < deadline:
+        try:
+            return requests.get(url)
+        except ConnectionError:
+            time.sleep(0.5)
+    pytest.fail("API never came up")
+
+
+@pytest.fixture(scope="session")
+def postgres_db():
+    engine = create_engine(config.get_postgres_uri())
+    wait_for_postgres_to_come_up(engine)
+    metadata.create_all(engine)
+    return engine
+
+
+@pytest.fixture
+def postgres_session(postgres_db):
+    start_mappers()
+    yield sessionmaker(bind=postgres_db)()
+    clear_mappers()
+
+
+@pytest.fixture
+def restart_api():
+    (Path(__file__).parent / "../src/allocation/entrypoints/flask_app.py").touch()
+    time.sleep(0.5)
+    wait_for_webapp_to_come_up()
diff --git a/tests/e2e/test_api.py b/tests/e2e/test_api.py
new file mode 100644
index 00000000..29b85761
--- /dev/null
+++ b/tests/e2e/test_api.py
@@ -0,0 +1,59 @@
+import uuid
+import pytest
+import requests
+
+from allocation import config
+
+
+def random_suffix():
+    return uuid.uuid4().hex[:6]
+
+
+def random_sku(name=""):
+    return f"sku-{name}-{random_suffix()}"
+
+
+def random_batchref(name=""):
+    return f"batch-{name}-{random_suffix()}"
+
+
+def random_orderid(name=""):
+    return f"order-{name}-{random_suffix()}"
+
+
+def post_to_add_batch(ref, sku, qty, eta):
+    url = config.get_api_url()
+    r = requests.post(
+        f"{url}/add_batch", json={"ref": ref, "sku": sku, "qty": qty, "eta": eta}
+    )
+    assert r.status_code == 201
+
+
+@pytest.mark.usefixtures("postgres_db")
+@pytest.mark.usefixtures("restart_api")
+def test_happy_path_returns_201_and_allocated_batch():
+    sku, othersku = random_sku(), random_sku("other")
+    earlybatch = random_batchref(1)
+    laterbatch = random_batchref(2)
+    otherbatch = random_batchref(3)
+    post_to_add_batch(laterbatch, sku, 100, "2011-01-02")
+    post_to_add_batch(earlybatch, sku, 100, "2011-01-01")
+    post_to_add_batch(otherbatch, othersku, 100, None)
+    data = {"orderid": random_orderid(), "sku": sku, "qty": 3}
+
+    url = config.get_api_url()
+    r = requests.post(f"{url}/allocate", json=data)
+
+    assert r.status_code == 201
+    assert r.json()["batchref"] == earlybatch
+
+
+@pytest.mark.usefixtures("postgres_db")
+@pytest.mark.usefixtures("restart_api")
+def test_unhappy_path_returns_400_and_error_message():
+    unknown_sku, orderid = random_sku(), random_orderid()
+    data = {"orderid": orderid, "sku": unknown_sku, "qty": 20}
+    url = config.get_api_url()
+    r = requests.post(f"{url}/allocate", json=data)
+    assert r.status_code == 400
+    assert r.json()["message"] == f"Invalid sku {unknown_sku}"
diff --git a/tests/integration/test_orm.py b/tests/integration/test_orm.py
new file mode 100644
index 00000000..db3a7a68
--- /dev/null
+++ b/tests/integration/test_orm.py
@@ -0,0 +1,89 @@
+from allocation.domain import model
+from datetime import date
+
+
+def test_orderline_mapper_can_load_lines(session):
+    session.execute(
+        "INSERT INTO order_lines (orderid, sku, qty) VALUES "
+        '("order1", "RED-CHAIR", 12),'
+        '("order1", "RED-TABLE", 13),'
+        '("order2", "BLUE-LIPSTICK", 14)'
+    )
+    expected = [
+        model.OrderLine("order1", "RED-CHAIR", 12),
+        model.OrderLine("order1", "RED-TABLE", 13),
+        model.OrderLine("order2", "BLUE-LIPSTICK", 14),
+    ]
+    assert session.query(model.OrderLine).all() == expected
+
+
+def test_orderline_mapper_can_save_lines(session):
+    new_line = model.OrderLine("order1", "DECORATIVE-WIDGET", 12)
+    session.add(new_line)
+    session.commit()
+
+    rows = list(session.execute('SELECT orderid, sku, qty FROM "order_lines"'))
+    assert rows == [("order1", "DECORATIVE-WIDGET", 12)]
+
+
+def test_retrieving_batches(session):
+    session.execute(
+        "INSERT INTO batches (reference, sku, _purchased_quantity, eta)"
+        ' VALUES ("batch1", "sku1", 100, null)'
+    )
+    session.execute(
+        "INSERT INTO batches (reference, sku, _purchased_quantity, eta)"
+        ' VALUES ("batch2", "sku2", 200, "2011-04-11")'
+    )
+    expected = [
+        model.Batch("batch1", "sku1", 100, eta=None),
+        model.Batch("batch2", "sku2", 200, eta=date(2011, 4, 11)),
+    ]
+
+    assert session.query(model.Batch).all() == expected
+
+
+def test_saving_batches(session):
+    batch = model.Batch("batch1", "sku1", 100, eta=None)
+    session.add(batch)
+    session.commit()
+    rows = session.execute(
+        'SELECT reference, sku, _purchased_quantity, eta FROM "batches"'
+    )
+    assert list(rows) == [("batch1", "sku1", 100, None)]
+
+
+def test_saving_allocations(session):
+    batch = model.Batch("batch1", "sku1", 100, eta=None)
+    line = model.OrderLine("order1", "sku1", 10)
+    batch.allocate(line)
+    session.add(batch)
+    session.commit()
+    rows = list(session.execute('SELECT orderline_id, batch_id FROM "allocations"'))
+    assert rows == [(batch.id, line.id)]
+
+
+def test_retrieving_allocations(session):
+    session.execute(
+        'INSERT INTO order_lines (orderid, sku, qty) VALUES ("order1", "sku1", 12)'
+    )
+    [[olid]] = session.execute(
+        "SELECT id FROM order_lines WHERE orderid=:orderid AND sku=:sku",
+        dict(orderid="order1", sku="sku1"),
+    )
+    session.execute(
+        "INSERT INTO batches (reference, sku, _purchased_quantity, eta)"
+        ' VALUES ("batch1", "sku1", 100, null)'
+    )
+    [[bid]] = session.execute(
+        "SELECT id FROM batches WHERE reference=:ref AND sku=:sku",
+        dict(ref="batch1", sku="sku1"),
+    )
+    session.execute(
+        "INSERT INTO allocations (orderline_id, batch_id) VALUES (:olid, :bid)",
+        dict(olid=olid, bid=bid),
+    )
+
+    batch = session.query(model.Batch).one()
+
+    assert batch._allocations == {model.OrderLine("order1", "sku1", 12)}
diff --git a/tests/integration/test_repository.py b/tests/integration/test_repository.py
new file mode 100644
index 00000000..0fa69ab3
--- /dev/null
+++ b/tests/integration/test_repository.py
@@ -0,0 +1,67 @@
+# pylint: disable=protected-access
+from allocation.domain import model
+from allocation.adapters import repository
+
+
+def test_repository_can_save_a_batch(session):
+    batch = model.Batch("batch1", "RUSTY-SOAPDISH", 100, eta=None)
model.Batch("batch1", "RUSTY-SOAPDISH", 100, eta=None) + + repo = repository.SqlAlchemyRepository(session) + repo.add(batch) + session.commit() + + rows = session.execute( + 'SELECT reference, sku, _purchased_quantity, eta FROM "batches"' + ) + assert list(rows) == [("batch1", "RUSTY-SOAPDISH", 100, None)] + + +def insert_order_line(session): + session.execute( + "INSERT INTO order_lines (orderid, sku, qty)" + ' VALUES ("order1", "GENERIC-SOFA", 12)' + ) + [[orderline_id]] = session.execute( + "SELECT id FROM order_lines WHERE orderid=:orderid AND sku=:sku", + dict(orderid="order1", sku="GENERIC-SOFA"), + ) + return orderline_id + + +def insert_batch(session, batch_id): + session.execute( + "INSERT INTO batches (reference, sku, _purchased_quantity, eta)" + ' VALUES (:batch_id, "GENERIC-SOFA", 100, null)', + dict(batch_id=batch_id), + ) + [[batch_id]] = session.execute( + 'SELECT id FROM batches WHERE reference=:batch_id AND sku="GENERIC-SOFA"', + dict(batch_id=batch_id), + ) + return batch_id + + +def insert_allocation(session, orderline_id, batch_id): + session.execute( + "INSERT INTO allocations (orderline_id, batch_id)" + " VALUES (:orderline_id, :batch_id)", + dict(orderline_id=orderline_id, batch_id=batch_id), + ) + + +def test_repository_can_retrieve_a_batch_with_allocations(session): + orderline_id = insert_order_line(session) + batch1_id = insert_batch(session, "batch1") + insert_batch(session, "batch2") + insert_allocation(session, orderline_id, batch1_id) + + repo = repository.SqlAlchemyRepository(session) + retrieved = repo.get("batch1") + + expected = model.Batch("batch1", "GENERIC-SOFA", 100, eta=None) + assert retrieved == expected # Batch.__eq__ only compares reference + assert retrieved.sku == expected.sku + assert retrieved._purchased_quantity == expected._purchased_quantity + assert retrieved._allocations == { + model.OrderLine("order1", "GENERIC-SOFA", 12), + } diff --git a/tests/integration/test_uow.py b/tests/integration/test_uow.py new file mode 100644 index 00000000..3887e3ca --- /dev/null +++ b/tests/integration/test_uow.py @@ -0,0 +1,65 @@ +import pytest +from allocation.domain import model +from allocation.service_layer import unit_of_work + + +def insert_batch(session, ref, sku, qty, eta): + session.execute( + "INSERT INTO batches (reference, sku, _purchased_quantity, eta)" + " VALUES (:ref, :sku, :qty, :eta)", + dict(ref=ref, sku=sku, qty=qty, eta=eta), + ) + + +def get_allocated_batch_ref(session, orderid, sku): + [[orderlineid]] = session.execute( + "SELECT id FROM order_lines WHERE orderid=:orderid AND sku=:sku", + dict(orderid=orderid, sku=sku), + ) + [[batchref]] = session.execute( + "SELECT b.reference FROM allocations JOIN batches AS b ON batch_id = b.id" + " WHERE orderline_id=:orderlineid", + dict(orderlineid=orderlineid), + ) + return batchref + + +def test_uow_can_retrieve_a_batch_and_allocate_to_it(session_factory): + session = session_factory() + insert_batch(session, "batch1", "HIPSTER-WORKBENCH", 100, None) + session.commit() + + uow = unit_of_work.SqlAlchemyUnitOfWork(session_factory) + with uow: + batch = uow.batches.get(reference="batch1") + line = model.OrderLine("o1", "HIPSTER-WORKBENCH", 10) + batch.allocate(line) + uow.commit() + + batchref = get_allocated_batch_ref(session, "o1", "HIPSTER-WORKBENCH") + assert batchref == "batch1" + + +def test_rolls_back_uncommitted_work_by_default(session_factory): + uow = unit_of_work.SqlAlchemyUnitOfWork(session_factory) + with uow: + insert_batch(uow.session, "batch1", "MEDIUM-PLINTH", 100, None) + + 
+    new_session = session_factory()
+    rows = list(new_session.execute('SELECT * FROM "batches"'))
+    assert rows == []
+
+
+def test_rolls_back_on_error(session_factory):
+    class MyException(Exception):
+        pass
+
+    uow = unit_of_work.SqlAlchemyUnitOfWork(session_factory)
+    with pytest.raises(MyException):
+        with uow:
+            insert_batch(uow.session, "batch1", "LARGE-FORK", 100, None)
+            raise MyException()
+
+    new_session = session_factory()
+    rows = list(new_session.execute('SELECT * FROM "batches"'))
+    assert rows == []
diff --git a/tests/pytest.ini b/tests/pytest.ini
new file mode 100644
index 00000000..bbd083ac
--- /dev/null
+++ b/tests/pytest.ini
@@ -0,0 +1,2 @@
+[pytest]
+addopts = --tb=short
diff --git a/test_allocate.py b/tests/unit/test_allocate.py
similarity index 95%
rename from test_allocate.py
rename to tests/unit/test_allocate.py
index 7e189307..48dcfe5c 100644
--- a/test_allocate.py
+++ b/tests/unit/test_allocate.py
@@ -1,6 +1,6 @@
 from datetime import date, timedelta
 import pytest
-from model import allocate, OrderLine, Batch, OutOfStock
+from allocation.domain.model import allocate, OrderLine, Batch, OutOfStock
 
 today = date.today()
 tomorrow = today + timedelta(days=1)
diff --git a/test_batches.py b/tests/unit/test_batches.py
similarity index 97%
rename from test_batches.py
rename to tests/unit/test_batches.py
index 62288330..8735f41e 100644
--- a/test_batches.py
+++ b/tests/unit/test_batches.py
@@ -1,5 +1,5 @@
 from datetime import date
-from model import Batch, OrderLine
+from allocation.domain.model import Batch, OrderLine
 
 
 def test_allocating_to_a_batch_reduces_the_available_quantity():
diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py
new file mode 100644
index 00000000..091dbb2c
--- /dev/null
+++ b/tests/unit/test_services.py
@@ -0,0 +1,58 @@
+import pytest
+from allocation.adapters import repository
+from allocation.service_layer import services, unit_of_work
+
+
+class FakeRepository(repository.AbstractRepository):
+    def __init__(self, batches):
+        self._batches = set(batches)
+
+    def add(self, batch):
+        self._batches.add(batch)
+
+    def get(self, reference):
+        return next(b for b in self._batches if b.reference == reference)
+
+    def list(self):
+        return list(self._batches)
+
+
+class FakeUnitOfWork(unit_of_work.AbstractUnitOfWork):
+    def __init__(self):
+        self.batches = FakeRepository([])
+        self.committed = False
+
+    def commit(self):
+        self.committed = True
+
+    def rollback(self):
+        pass
+
+
+def test_add_batch():
+    uow = FakeUnitOfWork()
+    services.add_batch("b1", "CRUNCHY-ARMCHAIR", 100, None, uow)
+    assert uow.batches.get("b1") is not None
+    assert uow.committed
+
+
+def test_allocate_returns_allocation():
+    uow = FakeUnitOfWork()
+    services.add_batch("batch1", "COMPLICATED-LAMP", 100, None, uow)
+    result = services.allocate("o1", "COMPLICATED-LAMP", 10, uow)
+    assert result == "batch1"
+
+
+def test_allocate_errors_for_invalid_sku():
+    uow = FakeUnitOfWork()
+    services.add_batch("b1", "AREALSKU", 100, None, uow)
+
+    with pytest.raises(services.InvalidSku, match="Invalid sku NONEXISTENTSKU"):
+        services.allocate("o1", "NONEXISTENTSKU", 10, uow)
+
+
+def test_allocate_commits():
+    uow = FakeUnitOfWork()
+    services.add_batch("b1", "OMINOUS-MIRROR", 100, None, uow)
+    services.allocate("o1", "OMINOUS-MIRROR", 10, uow)
+    assert uow.committed