Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 4 additions & 76 deletions api/gym_walk_env/gym_walk_env_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,32 +108,11 @@ async def reset(idx: str, reset_ops: RestEnvRequestModel) -> JSONResponse:
detail={"message": f"Environment {ENV_NAME} is not initialized."
" Have you called make()?"})

# global envs
# if cidx in envs:
# env = envs[cidx]
#
# if env is not None:
# observation, info = envs[cidx].reset(seed=seed)
#
# step = TimeStep(observation=observation,
# reward=0.0,
# step_type=TimeStepType.FIRST,
# info=info,
# discount=1.0)
# logger.info(f'Reset environment {ENV_NAME} and index {cidx}')
# return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
# content={"time_step": step.model_dump()})
#
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
# detail={"message": f"Environment {ENV_NAME} is not initialized."
# " Have you called make()?"})


@gym_walk_env_router.post("/{idx}/step", status_code=status.HTTP_202_ACCEPTED,
response_model=TimeStepResponse)
async def step(idx: str, action: DiscreteAction,
api_config: Annotated[Config, Depends(get_api_config)]) -> JSONResponse:

if idx not in manager:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "NOT_ALIVE/NOT_CREATED. Call make/reset"})
Expand All @@ -154,39 +133,16 @@ async def step(idx: str, action: DiscreteAction,

observation = step_result.observation
step_ = TimeStep(observation=observation,
reward=step_result.reward,
step_type=step_type,
info=step_result.info,
discount=1.0)
reward=step_result.reward,
step_type=step_type,
info=step_result.info,
discount=1.0)

if api_config.LOG_INFO:
logger.info(f'Step in environment {ENV_NAME} and index {idx}')
return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
content={"time_step": step_.model_dump()})

# global envs
# if cidx in envs:
# env = envs[cidx]
#
# if env is not None:
# observation, reward, terminated, truncated, info = envs[cidx].step(action)
#
# step_type = TimeStepType.MID
# if terminated or truncated:
# step_type = TimeStepType.LAST
#
# step = TimeStep(observation=observation,
# reward=reward,
# step_type=step_type,
# info=info,
# discount=1.0)
# logger.info(f'Step in environment {ENV_NAME} and index {cidx}')
# return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
# content={"time_step": step.model_dump()})
#
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"Environment {ENV_NAME} is not initialized. Have you called make()?")


@gym_walk_env_router.get("/{idx}/dynamics", response_model=GetEnvDynmicsResponseModel)
async def get_dynamics(idx: str, dyn_req: Annotated[GetEnvDynmicsRequestModel, Query()],
Expand All @@ -212,31 +168,3 @@ async def get_dynamics(idx: str, dyn_req: Annotated[GetEnvDynmicsRequestModel, Q
logger.info(f'Get dynamics for state={dyn_req.state_id}/action={dyn_req.action_id}')
return JSONResponse(status_code=status.HTTP_200_OK,
content={"dynamics": dynamics})

# global envs
#
# env = None
# if cidx in envs:
# env = envs[cidx]
#
# if env is None:
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"Environment {ENV_NAME} does not exposes dynamics.")
#
# if state >= env.nS:
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"Action {state} should be in [0, {env.nS})")
#
# if action is not None:
#
# if action not in ACTIONS_SPACE:
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"Action {action} not in {list(ACTIONS_SPACE.keys())}")
#
# p = env.P[state][action]
# return JSONResponse(status_code=status.HTTP_200_OK,
# content={"p": p})
#
# p = env.P[state]
# return JSONResponse(status_code=status.HTTP_200_OK,
# content={"p": p})
Empty file added api/restaurant_env/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions api/restaurant_env/restaurant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
from typing import Optional


class RestaurantEnv:
def __init__(self, *, hours_left: Optional[int] = 4,
large_party_prob: Optional[float] = 0.3,
small_party_revenue: Optional[float] = 50.0,
large_party_revenue: Optional[float] = 120.0):
self.hours_left = hours_left
self._hours_left_copy = hours_left
self.large_party_prob = large_party_prob
self.small_party_revenue = small_party_revenue
self.large_party_revenue = large_party_revenue

def reset(self):
self.hours_left = self._hours_left_copy
return self.hours_left

def step(self, action: int):
if action == 1:
return 0, self.small_party_revenue, True

self.hours_left -= 1
if self.hours_left == 0:
return 0, 0, True

if np.random.rand() < self.large_party_prob:
return 0, self.large_party_revenue, True
else:
return self.hours_left, 0, False

def close(self) -> None:
pass
145 changes: 145 additions & 0 deletions api/restaurant_env/restaurant_env_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import sys
from loguru import logger
from typing import Annotated
from fastapi import APIRouter, status, Depends, Query
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from api.utils.time_step_response import TimeStep, TimeStepType, TimeStepResponse

from api.utils.make_env_request_model import MakeEnvRequestModel
from api.utils.make_env_response_model import MakeEnvResponseModel
from api.utils.reset_request_model import RestEnvRequestModel
from api.utils.spaces.actions import DiscreteAction

from api.api_config import get_api_config, Config
from .restaurant_manager import RestaurantEnvManager

restaurant_env_router = APIRouter(prefix="/gdrl/gym-walk-env", tags=["gym-walk-env"])

ENV_NAME = "RestaurantEnv"

# actions that the environment accepts
ACTIONS_SPACE = {1: "Choose small party", 0: "EAST"}

DEFAULT_OPTIONS = {'hours_left': 4,
'large_party_prob': 0.3,
'small_party_revenue': 50.0,
'large_party_revenue': 120.0}
DEFAULT_VERSION = "v1"

manager = RestaurantEnvManager(verbose=True)


@restaurant_env_router.get("/copies")
async def get_n_copies() -> JSONResponse:
return JSONResponse(status_code=status.HTTP_200_OK,
content={"copies": len(manager)})


@restaurant_env_router.get("/{idx}/is-alive")
async def get_is_alive(idx: str) -> JSONResponse:
is_alive_ = manager.is_alive(idx=idx)

return JSONResponse(status_code=status.HTTP_200_OK,
content={"result": is_alive_})


@restaurant_env_router.post("/{idx}/close")
async def close(idx: str) -> JSONResponse:
closed = await manager.close(idx=idx)

if closed:
return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
content={"message": "OK"})

return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
content={"message": "FAILED"})


@restaurant_env_router.post("/make", status_code=status.HTTP_201_CREATED,
response_model=MakeEnvResponseModel)
async def make(request: MakeEnvRequestModel,
api_config: Annotated[Config, Depends(get_api_config)]
) -> JSONResponse:
version = request.version or DEFAULT_VERSION

# merge defaults with user overrides
options = DEFAULT_OPTIONS | (request.options or {})

env_type = f"{ENV_NAME}-{version}"

if api_config.LOG_INFO:
logger.info(f'Creating environment {env_type}')

idx = await manager.make(env_name=env_type, **options)

if api_config.LOG_INFO:
logger.info(f'Created environment {ENV_NAME} and index {idx}')
return JSONResponse(status_code=status.HTTP_201_CREATED,
content={"message": "OK", "idx": idx})


@restaurant_env_router.post("/{idx}/reset", status_code=status.HTTP_202_ACCEPTED,
response_model=TimeStepResponse)
async def reset(idx: str, reset_ops: RestEnvRequestModel) -> JSONResponse:
"""Reset the environment

:return:
"""

if idx not in manager:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "NOT_ALIVE/NOT_CREATED"})

try:
reset_step = await manager.reset(idx=idx, seed=reset_ops.seed)

observation = reset_step.observation
step_ = TimeStep(observation=observation,
reward=0.0,
step_type=TimeStepType.FIRST,
info=reset_step.info,
discount=1.0)
return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
content={"time_step": step_.model_dump()})
except Exception as e:
exception = sys.exc_info()
logger.opt(exception=exception).info("Logging exception traceback")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": f"Environment {ENV_NAME} is not initialized."
" Have you called make()?"})


@restaurant_env_router.post("/{idx}/step", status_code=status.HTTP_202_ACCEPTED,
response_model=TimeStepResponse)
async def step(idx: str, action: DiscreteAction,
api_config: Annotated[Config, Depends(get_api_config)]) -> JSONResponse:
if idx not in manager:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail={"message": "NOT_ALIVE/NOT_CREATED. Call make/reset"})

if action.action not in ACTIONS_SPACE:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Action {action} not in {list(ACTIONS_SPACE.keys())}")

step_result = await manager.step(idx=idx, action=action.action)

step_type = TimeStepType.MID
if step_result.terminated:
step_type = TimeStepType.LAST

info = step_result.info
if info is not None:
info['truncated'] = step_result.truncated

observation = step_result.observation
step_ = TimeStep(observation=observation,
reward=step_result.reward,
step_type=step_type,
info=step_result.info,
discount=1.0)

if api_config.LOG_INFO:
logger.info(f'Step in environment {ENV_NAME} and index {idx}')
return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
content={"time_step": step_.model_dump()})
80 changes: 80 additions & 0 deletions api/restaurant_env/restaurant_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Optional
from collections import namedtuple
import uuid
import gymnasium as gym
import asyncio
from loguru import logger


from .restaurant import RestaurantEnv

GymEnvResetResult = namedtuple(typename='GymEnvResetResult', field_names=['observation',
'info'])
GymEnvStepResult = namedtuple(typename="GymEnvStepResult", field_names=["observation",
"reward",
"terminated",
"truncated",
"info"])


class RestaurantEnvManager:
"""
Thread-safe async environment manager for Gymnasium environments.
Supports multiple independent environments (for distributed RL setups).
"""

def __init__(self, verbose: bool = True):
self.verbose = verbose
self.envs: dict[str, RestaurantEnv] = {}
self.locks: dict[str, asyncio.Lock] = {}

def __len__(self) -> int:
return len(self.envs)

def __contains__(self, idx: str) -> bool:
"""Allow `if idx in manager:` syntax."""
return self.is_alive(idx)

def get_lock(self, idx: str) -> asyncio.Lock:
if idx not in self.locks:
raise ValueError(f"idx not in locks")
return self.locks[idx]

async def make(self, env_name: str, **kwargs) -> str:

idx = uuid.uuid4().hex
self.envs[idx] = RestaurantEnv(**kwargs)
self.locks[idx] = asyncio.Lock()
return idx

async def close(self, idx: str) -> bool:
async with self.get_lock(idx):
if idx in self.envs:
self.envs[idx].close()
del self.envs[idx]
return True
return False

async def step(self, idx: str, action: int) -> GymEnvStepResult:
async with self.get_lock(idx):
env = self.envs.get(idx)
if env is None:
raise ValueError("Env not found.")
observation, reward, terminated, info = env.step(action)
return GymEnvStepResult(observation=observation, reward=reward,
terminated=terminated, truncated=None, info=info)

async def reset(self, idx: str, seed: Optional[int] = None, **kwargs) -> GymEnvResetResult:
"""Reset the environment and return (observation, info)."""
async with self.get_lock(idx):
env = self.envs.get(idx)
if env is None:
raise ValueError(f"Environment {idx} not found. Have you called make()?")

obs = env.reset()
logger.info(f"Reset environment {idx}")
return GymEnvResetResult(observation=obs, info={})

def is_alive(self, idx: str) -> bool:
"""Check if an environment exists and is active."""
return idx in self.envs and self.envs[idx] is not None