Skip to content

Commit

Permalink
Migration Flask to FastAPI
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdae committed Nov 30, 2023
1 parent 5b42810 commit f6e5276
Show file tree
Hide file tree
Showing 18 changed files with 690 additions and 719 deletions.
668 changes: 279 additions & 389 deletions poetry.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@ packages = [{include = "world_boss"}]

[tool.poetry.dependencies]
python = "^3.10"
flask = {extras = ["async"], version = "^2.2.2"}
mypy = "^0.991"
flask-sqlalchemy = "2.5.1"
flask-migrate = "^4.0.0"
psycopg2 = "^2.9.5"
redis = "^4.3.5"
gunicorn = "^20.1.0"
types-flask-migrate = "^4.0.0.0"
types-redis = "^4.3.21.6"
boto3 = "^1.26.22"
ethereum-kms-signer = "^0.1.6"
Expand All @@ -29,12 +25,16 @@ celery-types = "^0.14.0"
pydantic = {extras = ["dotenv"], version = "^1.10.2"}
bencodex = "^1.0.1"
gql = {version = "^3.5.0b0", allow-prereleases = true}
sentry-sdk = {extras = ["flask"], version = "^1.33.1"}
sentry-sdk = {extras = ["fastapi"], version = "^1.35.0"}
fastapi = "^0.104.1"
sqlalchemy = "^2.0.23"
alembic = "^1.12.1"
uvicorn = "^0.24.0.post1"
python-multipart = "^0.0.6"


[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
pytest-flask-sqlalchemy = "^1.1.0"
pytest-postgresql = "^4.1.1"
pytest-dotenv = "^0.5.2"
psycopg-binary = "^3.1.4"
Expand Down
5 changes: 3 additions & 2 deletions tests/api_mock_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from unittest.mock import MagicMock

Expand All @@ -10,7 +11,7 @@

@pytest.fixture()
def non_mocked_hosts() -> list:
return ["9c-main-full-state.nine-chronicles.com"]
return ["9c-main-full-state.nine-chronicles.com", "testserver"]


@pytest.mark.parametrize("has_header", [True, False])
Expand Down Expand Up @@ -95,7 +96,7 @@ def test_prepare_transfer_assets(
},
)
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task: AsyncResult = AsyncResult(task_id)
task.get(timeout=30)
assert task.state == "SUCCESS"
Expand Down
57 changes: 31 additions & 26 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import pytest
from celery.result import AsyncResult
from flask.testing import FlaskClient
from pytest_httpx import HTTPXMock
from starlette.testclient import TestClient

from world_boss.app.cache import cache_exists, set_to_cache
from world_boss.app.data_provider import DATA_PROVIDER_URLS
Expand All @@ -17,6 +17,11 @@
from world_boss.app.models import Transaction, WorldBossReward, WorldBossRewardAmount


@pytest.fixture()
def non_mocked_hosts() -> list:
return ["testserver"]


def test_raid_rewards_404(fx_test_client, redisdb, fx_session):
req = fx_test_client.get("/raid/1/test/rewards")
assert req.status_code == 404
Expand Down Expand Up @@ -58,7 +63,7 @@ def test_raid_rewards(fx_test_client, fx_session, redis_proc, caching: bool):
set_to_cache(cache_key, json.dumps(reward.as_dict()), timedelta(seconds=1))
req = fx_test_client.get(f"/raid/{raid_id}/{avatar_address}/rewards")
assert req.status_code == 200
assert req.json == reward.as_dict()
assert json.loads(req.json()) == reward.as_dict()
if caching:
time.sleep(2)
assert not cache_exists(cache_key)
Expand All @@ -81,7 +86,7 @@ def test_count_total_users(
"/raid/list/count", data={"text": 1, "channel_id": "channel_id"}
)
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task: AsyncResult = AsyncResult(task_id)
task.get(timeout=30)
assert task.state == "SUCCESS"
Expand Down Expand Up @@ -131,7 +136,7 @@ def test_generate_ranking_rewards_csv(
"/raid/rewards/list", data={"text": "1 1 1", "channel_id": "channel_id"}
)
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task: AsyncResult = AsyncResult(task_id)
task.get(timeout=30)
assert task.state == "SUCCESS"
Expand All @@ -154,15 +159,15 @@ def test_next_tx_nonce(
tx.signer = "0xCFCd6565287314FF70e4C4CF309dB701C43eA5bD"
tx.payload = "payload"
fx_session.add(tx)
fx_session.flush()
fx_session.commit()
with unittest.mock.patch(
"world_boss.app.api.client.chat_postMessage"
) as m, unittest.mock.patch(
"world_boss.app.slack.verifier.is_valid_request", return_value=True
):
req = fx_test_client.post("/nonce", data={"channel_id": "channel_id"})
assert req.status_code == 200
assert req.json == 200
assert req.json() == 200
m.assert_called_once_with(channel="channel_id", text="next tx nonce: 2")


Expand Down Expand Up @@ -205,7 +210,7 @@ def test_prepare_reward_assets(fx_test_client, celery_session_worker, fx_session
"/prepare-reward-assets", data={"channel_id": "channel_id", "text": "3"}
)
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task = AsyncResult(task_id)
task.get(timeout=30)
assert task.state == "SUCCESS"
Expand Down Expand Up @@ -244,7 +249,7 @@ def test_stage_transactions(
"/stage-transaction", data={"channel_id": "channel_id", "text": text}
)
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task: AsyncResult = AsyncResult(task_id)
task.get(timeout=30)
assert m.call_count == len(HEADLESS_URLS[network_type]) * len(fx_transactions)
Expand Down Expand Up @@ -285,7 +290,7 @@ def test_transaction_result(
"/transaction-result", data={"channel_id": "channel_id", "text": text}
)
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task: AsyncResult = AsyncResult(task_id)
task.get(timeout=30)
assert task.state == "SUCCESS"
Expand All @@ -296,7 +301,7 @@ def test_transaction_result(
assert kwargs["title"] == "world_boss_tx_result"
assert "world_boss_tx_result" in kwargs["filename"]
for tx in fx_session.query(Transaction):
assert tx.tx_result == "SUCCESS"
assert tx.tx_result == "INVALID"


def test_check_balance(fx_session, fx_test_client, celery_session_worker):
Expand Down Expand Up @@ -331,7 +336,7 @@ def test_check_balance(fx_session, fx_test_client, celery_session_worker):
):
req = fx_test_client.post("/balance", data={"channel_id": "channel_id"})
assert req.status_code == 200
task_id = req.json["task_id"]
task_id = json.loads(req.json())["task_id"]
task: AsyncResult = AsyncResult(task_id)
task.get(timeout=30)
assert m.call_count == 2
Expand All @@ -340,35 +345,35 @@ def test_check_balance(fx_session, fx_test_client, celery_session_worker):


@pytest.mark.parametrize(
"url",
"url, text",
[
"/raid/list/count",
"/raid/rewards/list",
"/raid/prepare",
"/nonce",
"/prepare-reward-assets",
"/stage-transaction",
"/transaction-result",
"/balance",
("/raid/list/count", "1"),
("/raid/rewards/list", "1 1 1"),
("/raid/prepare", "1 1"),
("/nonce", None),
("/prepare-reward-assets", "1"),
("/stage-transaction", "main"),
("/transaction-result", "main"),
("/balance", None),
],
)
def test_slack_auth(fx_test_client, url: str):
req = fx_test_client.post(url)
def test_slack_auth(fx_test_client, url: str, text: str):
req = fx_test_client.post(url, data={"channel_id": "channel_id", "text": text})
assert req.status_code == 403


def test_ping(fx_test_client: FlaskClient):
def test_ping(fx_test_client: TestClient):
req = fx_test_client.get("/ping")
assert req.status_code == 200
assert req.json == {"message": "pong"}
assert json.loads(req.json()) == {"message": "pong"}

mocked_session = MagicMock()
mocked_session.side_effect = TimeoutError()

with unittest.mock.patch(
"world_boss.app.api.db.session.execute", side_effect=mocked_session
"world_boss.app.api.text", side_effect=mocked_session
) as m:
req = fx_test_client.get("/ping")
m.assert_called_once()
assert req.status_code == 503
assert req.json == {"message": "database connection failed"}
assert json.loads(req.json()) == {"message": "database connection failed"}
40 changes: 19 additions & 21 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import pytest
import sqlalchemy as sa
from flask import Flask
from flask.testing import FlaskClient
from fastapi import FastAPI
from pytest_postgresql.janitor import DatabaseJanitor
from pytest_redis import factories # type: ignore
from starlette.testclient import TestClient

from world_boss.app.config import config
from world_boss.app.models import Transaction
from world_boss.app.orm import Base, SessionLocal, engine
from world_boss.app.stubs import RewardDictionary
from world_boss.app.tasks import celery
from world_boss.wsgi import create_app
Expand Down Expand Up @@ -36,49 +37,46 @@ def database():


@pytest.fixture(scope="session")
def fx_app(database) -> Flask:
def fx_app(database) -> FastAPI:
fx_app = create_app()
ctx = fx_app.app_context()
ctx.push()
fx_app.dependency_overrides["get_db"] = fx_session
return fx_app


@pytest.fixture
def fx_session(fx_app):
def fx_session(fx_app) -> typing.Generator:
"""
Provide the transactional fixtures with access to the database via a Flask-SQLAlchemy
database connection.
"""
fx_db = fx_app.extensions["sqlalchemy"].db
fx_db.session.rollback()
fx_db.drop_all()
fx_db.session.commit()
fx_db.create_all()
return fx_db.session
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
db = SessionLocal()
try:
yield db
finally:
db.close()


@pytest.fixture()
def fx_test_client(fx_app: Flask) -> FlaskClient:
fx_app.testing = True
return fx_app.test_client()
def fx_test_client(fx_app: FastAPI) -> TestClient:
return TestClient(fx_app)


@pytest.fixture(scope="session")
def celery_config(fx_app: Flask, redis_proc):
def celery_config(redis_proc):
conf = {
"broker_url": fx_app.config["CELERY_BROKER_URL"],
"result_backend": fx_app.config["CELERY_RESULT_BACKEND"],
"broker_url": config.celery_broker_url,
"result_backend": config.celery_result_backend,
}
conf.update(fx_app.config)
return conf


@pytest.fixture(scope="session")
def celery_parameters(fx_app):
class TestTask(celery.Task):
def __call__(self, *args, **kwargs):
with fx_app.app_context():
return self.run(*args, **kwargs)
return self.run(*args, **kwargs)

return {"task_cls": TestTask}

Expand Down
13 changes: 7 additions & 6 deletions tests/kms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_transfer_assets(fx_session) -> None:
[recipient],
"test",
"http://9c-internal-rpc-1.nine-chronicles.com/graphql",
fx_session,
)
transaction = fx_session.query(Transaction).first()
assert result == transaction
Expand All @@ -53,18 +54,18 @@ async def test_stage_transactions_async(fx_session, fx_mainnet_transactions):
fx_session.add_all(fx_mainnet_transactions)
fx_session.flush()
with pytest.raises(TransportQueryError):
await signer.stage_transactions_async(NetworkType.INTERNAL)
await signer.stage_transactions_async(NetworkType.INTERNAL, fx_session)


@pytest.mark.asyncio
async def test_check_transaction_status_async(fx_session, fx_mainnet_transactions):
assert fx_session.query(Transaction).count() == 0
fx_session.add_all(fx_mainnet_transactions)
fx_session.flush()
await signer.check_transaction_status_async(NetworkType.MAIN)
await signer.check_transaction_status_async(NetworkType.MAIN, fx_session)
transactions = fx_session.query(Transaction)
for transaction in transactions:
assert transaction.tx_result == "SUCCESS"
assert transaction.tx_result == "INVALID"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -105,9 +106,9 @@ def test_query_transaction_result(fx_session, fx_mainnet_transactions):
fx_session.add(tx)
fx_session.flush()
url = MINER_URLS[NetworkType.MAIN]
signer.query_transaction_result(url, tx.tx_id)
signer.query_transaction_result(url, tx.tx_id, fx_session)
transaction = fx_session.query(Transaction).one()
assert transaction.tx_result == "SUCCESS"
assert transaction.tx_result == "INVALID"


def test_query_balance(fx_session):
Expand All @@ -133,7 +134,7 @@ def test_query_balance(fx_session):
fx_session.add(reward_amount)
i += 1
fx_session.commit()
currencies = get_currencies()
currencies = get_currencies(fx_session)
for currency in currencies:
balance = signer.query_balance(url, currency)
assert balance == f'0 {currency["ticker"]}'
8 changes: 4 additions & 4 deletions tests/raid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_get_next_tx_nonce(fx_session, nonce_list: List[int], expected: int):
tx.payload = "payload"
fx_session.add(tx)
fx_session.flush()
assert get_next_tx_nonce() == expected
assert get_next_tx_nonce(fx_session) == expected


@pytest.mark.parametrize("tx_exist", [True, False])
Expand All @@ -144,7 +144,7 @@ def test_get_next_tx_nonce_tx_empty(fx_session, tx_exist: bool):
tx.payload = "payload"
fx_session.add(tx)
fx_session.flush()
assert get_next_tx_nonce() == 1
assert get_next_tx_nonce(fx_session) == 1


def test_get_assets(fx_session) -> None:
Expand Down Expand Up @@ -176,7 +176,7 @@ def test_get_assets(fx_session) -> None:
fx_session.commit()
for i, asset in enumerate(assets):
raid_id = i + 1
assert get_assets(raid_id) == [assets[i]]
assert get_assets(raid_id, fx_session) == [assets[i]]


def test_write_tx_result_csv(tmp_path):
Expand Down Expand Up @@ -215,4 +215,4 @@ def test_list_tx_nonce(fx_session, nonce_list: List[int]):
tx.payload = "payload"
fx_session.add(tx)
fx_session.flush()
assert list_tx_nonce() == nonce_list
assert list_tx_nonce(fx_session) == nonce_list
Loading

0 comments on commit f6e5276

Please sign in to comment.