Skip to content

Commit

Permalink
♻️ Refactor rebalance API endpoints (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
mingi3314 authored Mar 17, 2024
1 parent af45c30 commit c817832
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 18 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ repos:
types-PyYAML===6.0.12.12,
types-toml===0.10.8.7,
fastapi==0.109.2,
freezegun==1.4.0,
]
47 changes: 32 additions & 15 deletions pyrb/controllers/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pyrb.enums import AssetAllocationStrategyEnum, BrokerageType
from pyrb.exceptions import InitializationError
from pyrb.models.account import Account, AccountFactory
from pyrb.models.order import OrderPlacementResult
from pyrb.models.order import Order, OrderPlacementResult
from pyrb.models.position import Position
from pyrb.services.rebalance import Rebalancer
from pyrb.services.strategy.asset_allocate import AssetAllocationStrategyFactory
Expand Down Expand Up @@ -52,12 +52,16 @@ class PortfolioResponse(BaseModel):
positions: list[Position]


class RebalanceRequest(BaseModel):
investment_amount: float | None
class OrdersPrepareResponse(BaseModel):
orders: list[Order]


class RebalanceResponse(BaseModel):
rebalanced_at: AwareDatetime
class OrdersPlaceRequest(BaseModel):
orders: list[Order]


class OrdersPlaceResponse(BaseModel):
placed_at: AwareDatetime
placed_orders: list[OrderPlacementResult]


Expand Down Expand Up @@ -94,19 +98,32 @@ async def get_portfolio(context: RebalanceContextDep) -> PortfolioResponse:
)


# TODO: Swagger에서 StrEnum이 제대로 표시되지 않는 문제 원인 파악 후 수정
@app.post("/strategies/{strategy_type}/rebalance", response_model=RebalanceResponse)
async def rebalance(
context: RebalanceContextDep, strategy_type: AssetAllocationStrategyEnum, body: RebalanceRequest
) -> RebalanceResponse:
@app.get("/strategies/{strategy_type}/orders", response_model=OrdersPrepareResponse)
async def prepare_orders(
context: RebalanceContextDep,
strategy_type: AssetAllocationStrategyEnum,
) -> OrdersPrepareResponse:
strategy = AssetAllocationStrategyFactory.create(strategy_type)
rebalancer = Rebalancer(context, strategy)
rebalancer = Rebalancer(context)

orders = rebalancer.prepare_orders(
strategy=strategy, investment_amount=context.portfolio.total_value * 0.99
)

return OrdersPrepareResponse(
orders=orders,
)

orders = rebalancer.prepare_orders(investment_amount=body.investment_amount)
placed_orders = rebalancer.place_orders(orders)

return RebalanceResponse(
rebalanced_at=datetime.datetime.now(ZoneInfo("Asia/Seoul")),
@app.post("/strategies/{strategy_type}/orders", response_model=OrdersPlaceResponse)
async def place_orders(
context: RebalanceContextDep,
body: OrdersPlaceRequest,
) -> OrdersPlaceResponse:
rebalancer = Rebalancer(context)
placed_orders = rebalancer.place_orders(body.orders)
return OrdersPlaceResponse(
placed_at=datetime.datetime.now(ZoneInfo("Asia/Seoul")),
placed_orders=placed_orders,
)

Expand Down
141 changes: 138 additions & 3 deletions tests/controllers/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from fastapi.testclient import TestClient
from freezegun import freeze_time

from pyrb.controllers.api.deps import account_repo_dep, context_dep
from pyrb.controllers.api.main import AccountCreateResponse, app
Expand Down Expand Up @@ -66,19 +67,153 @@ def test_get_default_account_without_created_account() -> None:
assert response.json() == {"detail": "No accounts registered"}


def test_rebalance(fake_rebalance_context: RebalanceContext) -> None:
def test_prepare_orders(fake_rebalance_context: RebalanceContext) -> None:
# Given
create_account()
app.dependency_overrides[context_dep] = lambda: fake_rebalance_context

# When
response = client.get(
"/strategies/all-weather-kr/orders",
)

# Then
assert response.status_code == 200
assert response.json() == {
"orders": [
{
"symbol": "379800",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
{
"symbol": "361580",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
{
"symbol": "411060",
"price": 100,
"quantity": 148,
"side": "BUY",
"order_type": "MARKET",
},
{
"symbol": "365780",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
{
"symbol": "308620",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
{
"symbol": "272580",
"price": 100,
"quantity": 148,
"side": "BUY",
"order_type": "MARKET",
},
]
}
app.dependency_overrides.clear()


@freeze_time("2024-01-03T00:00:00+09:00")
def test_place_orders(fake_rebalance_context: RebalanceContext) -> None:
# Given
create_account()
app.dependency_overrides[context_dep] = lambda: fake_rebalance_context
orders = client.get("/strategies/all-weather-kr/orders").json()["orders"]

# When
response = client.post(
"strategies/all-weather-kr/rebalance",
json={"investment_amount": 100000},
"/strategies/all-weather-kr/orders",
json={"orders": orders},
)

# Then
assert response.status_code == 200
assert response.json() == {
"placed_at": "2024-01-03T00:00:00+09:00",
"placed_orders": [
{
"order": {
"symbol": "379800",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
"success": True,
"message": None,
},
{
"order": {
"symbol": "361580",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
"success": True,
"message": None,
},
{
"order": {
"symbol": "411060",
"price": 100,
"quantity": 148,
"side": "BUY",
"order_type": "MARKET",
},
"success": True,
"message": None,
},
{
"order": {
"symbol": "365780",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
"success": True,
"message": None,
},
{
"order": {
"symbol": "308620",
"price": 100,
"quantity": 173,
"side": "BUY",
"order_type": "MARKET",
},
"success": True,
"message": None,
},
{
"order": {
"symbol": "272580",
"price": 100,
"quantity": 148,
"side": "BUY",
"order_type": "MARKET",
},
"success": True,
"message": None,
},
],
}
app.dependency_overrides.clear()


Expand Down
1 change: 1 addition & 0 deletions tests/controllers/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def test_sut_stops_rebalancing_with_insufficient_funds(

# then
assert result.exit_code == 1
assert result.exc_info is not None
assert result.exc_info[0] == InsufficientFundsException


Expand Down

0 comments on commit c817832

Please sign in to comment.