diff --git a/api/gym_walk_env/gym_walk_env_api.py b/api/gym_walk_env/gym_walk_env_api.py index 4b8cac3..92d71d8 100644 --- a/api/gym_walk_env/gym_walk_env_api.py +++ b/api/gym_walk_env/gym_walk_env_api.py @@ -108,32 +108,11 @@ async def reset(idx: str, reset_ops: RestEnvRequestModel) -> JSONResponse: detail={"message": f"Environment {ENV_NAME} is not initialized." " Have you called make()?"}) - # global envs - # if cidx in envs: - # env = envs[cidx] - # - # if env is not None: - # observation, info = envs[cidx].reset(seed=seed) - # - # step = TimeStep(observation=observation, - # reward=0.0, - # step_type=TimeStepType.FIRST, - # info=info, - # discount=1.0) - # logger.info(f'Reset environment {ENV_NAME} and index {cidx}') - # return JSONResponse(status_code=status.HTTP_202_ACCEPTED, - # content={"time_step": step.model_dump()}) - # - # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, - # detail={"message": f"Environment {ENV_NAME} is not initialized." - # " Have you called make()?"}) - @gym_walk_env_router.post("/{idx}/step", status_code=status.HTTP_202_ACCEPTED, response_model=TimeStepResponse) async def step(idx: str, action: DiscreteAction, api_config: Annotated[Config, Depends(get_api_config)]) -> JSONResponse: - if idx not in manager: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail={"message": "NOT_ALIVE/NOT_CREATED. 
Call make/reset"}) @@ -154,39 +133,16 @@ async def step(idx: str, action: DiscreteAction, observation = step_result.observation step_ = TimeStep(observation=observation, - reward=step_result.reward, - step_type=step_type, - info=step_result.info, - discount=1.0) + reward=step_result.reward, + step_type=step_type, + info=step_result.info, + discount=1.0) if api_config.LOG_INFO: logger.info(f'Step in environment {ENV_NAME} and index {idx}') return JSONResponse(status_code=status.HTTP_202_ACCEPTED, content={"time_step": step_.model_dump()}) - # global envs - # if cidx in envs: - # env = envs[cidx] - # - # if env is not None: - # observation, reward, terminated, truncated, info = envs[cidx].step(action) - # - # step_type = TimeStepType.MID - # if terminated or truncated: - # step_type = TimeStepType.LAST - # - # step = TimeStep(observation=observation, - # reward=reward, - # step_type=step_type, - # info=info, - # discount=1.0) - # logger.info(f'Step in environment {ENV_NAME} and index {cidx}') - # return JSONResponse(status_code=status.HTTP_202_ACCEPTED, - # content={"time_step": step.model_dump()}) - # - # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, - # detail=f"Environment {ENV_NAME} is not initialized. 
class RestaurantEnv:
    """Small episodic environment with a countdown and a one-shot payoff.

    The observation is the number of hours left. On every step the agent
    either takes the small-party revenue immediately (action ``1``) or lets
    one hour pass (any other action). After an hour passes the episode ends
    with no reward if the countdown hit zero; otherwise a large party shows
    up with probability ``large_party_prob`` and the episode ends with the
    large-party revenue. Once an episode has ended, further ``step`` calls
    keep returning a zero-reward terminal transition until ``reset``.

    :param hours_left: initial countdown, must be at least 1.
    :param large_party_prob: per-hour probability in ``[0, 1]`` that the
        large party arrives.
    :param small_party_revenue: reward for action ``1``.
    :param large_party_revenue: reward when the large party arrives.
    :param seed: optional seed for the environment's private random
        generator; ``None`` draws entropy from the OS.
    :raises ValueError: if ``hours_left`` or ``large_party_prob`` is out of
        range.
    """

    def __init__(self, *, hours_left: int = 4,
                 large_party_prob: float = 0.3,
                 small_party_revenue: float = 50.0,
                 large_party_revenue: float = 120.0,
                 seed: Optional[int] = None):
        if hours_left < 1:
            raise ValueError(f"hours_left must be >= 1, got {hours_left}")
        if not 0.0 <= large_party_prob <= 1.0:
            raise ValueError(
                f"large_party_prob must be in [0, 1], got {large_party_prob}")

        self.hours_left = hours_left
        # Kept so reset() can restore the countdown.
        self._hours_left_copy = hours_left
        self.large_party_prob = large_party_prob
        self.small_party_revenue = small_party_revenue
        self.large_party_revenue = large_party_revenue
        # Private generator: seeding one environment never disturbs another
        # (or any other user of NumPy's global random state).
        self._rng = np.random.default_rng(seed)
        self._terminated = False

    def reset(self, seed: Optional[int] = None):
        """Start a new episode and return the initial observation.

        :param seed: when given, re-seeds the private random generator so
            the following episode is reproducible.
        :return: the initial number of hours left.
        """
        if seed is not None:
            self._rng = np.random.default_rng(seed)
        self.hours_left = self._hours_left_copy
        self._terminated = False
        return self.hours_left

    def step(self, action: int):
        """Advance the episode by one decision.

        :param action: ``1`` takes the small-party revenue; any other value
            lets one hour pass.
        :return: ``(observation, reward, terminated)``.
        """
        # Without this guard the countdown would go negative after the end
        # of an episode and the environment could report terminated=False
        # again.
        if self._terminated:
            return 0, 0.0, True

        if action == 1:
            self._terminated = True
            return 0, self.small_party_revenue, True

        self.hours_left -= 1
        if self.hours_left <= 0:
            self._terminated = True
            return 0, 0.0, True

        if self._rng.random() < self.large_party_prob:
            self._terminated = True
            return 0, self.large_party_revenue, True

        return self.hours_left, 0.0, False

    def close(self) -> None:
        """Release resources; this environment holds none."""
        pass
@restaurant_env_router.get("/{idx}/is-alive")
async def get_is_alive(idx: str) -> JSONResponse:
    """Report whether the environment registered under ``idx`` exists."""
    is_alive_ = manager.is_alive(idx=idx)

    return JSONResponse(status_code=status.HTTP_200_OK,
                        content={"result": is_alive_})


@restaurant_env_router.post("/{idx}/close")
async def close(idx: str) -> JSONResponse:
    """Close the environment ``idx``.

    Responds 202 ``OK`` when the environment was closed and 400 ``FAILED``
    when it does not exist.
    """
    try:
        closed = await manager.close(idx=idx)
    except ValueError:
        # The manager looks up a per-environment lock before closing and
        # raises ValueError for an unknown idx; report that as a failed
        # close instead of letting it surface as a 500.
        closed = False

    if closed:
        return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
                            content={"message": "OK"})

    return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST,
                        content={"message": "FAILED"})


@restaurant_env_router.post("/make", status_code=status.HTTP_201_CREATED,
                            response_model=MakeEnvResponseModel)
async def make(request: MakeEnvRequestModel,
               api_config: Annotated[Config, Depends(get_api_config)]
               ) -> JSONResponse:
    """Create a new environment copy and return its index.

    User-supplied options override ``DEFAULT_OPTIONS`` key by key.

    :raises HTTPException: 400 when the merged options are rejected by the
        environment constructor (unknown option name or invalid value).
    """
    version = request.version or DEFAULT_VERSION

    # merge defaults with user overrides
    options = DEFAULT_OPTIONS | (request.options or {})

    env_type = f"{ENV_NAME}-{version}"

    if api_config.LOG_INFO:
        logger.info(f'Creating environment {env_type}')

    try:
        idx = await manager.make(env_name=env_type, **options)
    except (TypeError, ValueError) as e:
        # Options are forwarded as keyword arguments to the environment
        # constructor, so an unknown key raises TypeError there; that is a
        # client error rather than a server fault.
        logger.opt(exception=e).info("Logging exception traceback")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail={"message": f"Invalid options for {ENV_NAME}: {e}"}) from e

    if api_config.LOG_INFO:
        logger.info(f'Created environment {ENV_NAME} and index {idx}')
    return JSONResponse(status_code=status.HTTP_201_CREATED,
                        content={"message": "OK", "idx": idx})


@restaurant_env_router.post("/{idx}/reset", status_code=status.HTTP_202_ACCEPTED,
                            response_model=TimeStepResponse)
async def reset(idx: str, reset_ops: RestEnvRequestModel) -> JSONResponse:
    """Reset the environment ``idx`` and return its first time step.

    :raises HTTPException: 400 when ``idx`` is unknown or the reset fails.
    :return: 202 response whose ``time_step`` has ``step_type`` FIRST and
        zero reward.
    """

    if idx not in manager:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail={"message": "NOT_ALIVE/NOT_CREATED"})

    try:
        reset_step = await manager.reset(idx=idx, seed=reset_ops.seed)

        observation = reset_step.observation
        step_ = TimeStep(observation=observation,
                         reward=0.0,
                         step_type=TimeStepType.FIRST,
                         info=reset_step.info,
                         discount=1.0)
        return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
                            content={"time_step": step_.model_dump()})
    except Exception as e:
        # Boundary handler: keep the traceback in the log and chain the
        # cause so the original error is not lost.
        logger.opt(exception=e).info("Logging exception traceback")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail={"message": f"Environment {ENV_NAME} is not initialized."
                                               " Have you called make()?"}) from e


@restaurant_env_router.post("/{idx}/step", status_code=status.HTTP_202_ACCEPTED,
                            response_model=TimeStepResponse)
async def step(idx: str, action: DiscreteAction,
               api_config: Annotated[Config, Depends(get_api_config)]) -> JSONResponse:
    """Apply ``action`` to the environment ``idx`` and return the time step.

    :raises HTTPException: 400 when ``idx`` is unknown or the action is not
        a key of ``ACTIONS_SPACE``.
    :return: 202 response whose ``time_step`` has ``step_type`` LAST when the
        episode terminated and MID otherwise.
    """
    if idx not in manager:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail={"message": "NOT_ALIVE/NOT_CREATED. Call make/reset"})

    if action.action not in ACTIONS_SPACE:
        # Report the offending value, not the repr of the request model.
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail=f"Action {action.action} not in {list(ACTIONS_SPACE.keys())}")

    step_result = await manager.step(idx=idx, action=action.action)

    step_type = TimeStepType.LAST if step_result.terminated else TimeStepType.MID

    # Work on a copy so the manager's result is not mutated, and fall back
    # to an empty mapping when the environment supplied no info.
    # NOTE(review): assumes TimeStep.info accepts a dict - confirm.
    info = dict(step_result.info or {})
    info['truncated'] = step_result.truncated

    step_ = TimeStep(observation=step_result.observation,
                     reward=step_result.reward,
                     step_type=step_type,
                     info=info,
                     discount=1.0)

    if api_config.LOG_INFO:
        logger.info(f'Step in environment {ENV_NAME} and index {idx}')
    return JSONResponse(status_code=status.HTTP_202_ACCEPTED,
                        content={"time_step": step_.model_dump()})
# Result of a reset: the first observation plus an info mapping.
GymEnvResetResult = namedtuple(typename='GymEnvResetResult',
                               field_names=['observation', 'info'])
# Result of a step, shaped like a Gymnasium step tuple.
GymEnvStepResult = namedtuple(typename="GymEnvStepResult",
                              field_names=["observation",
                                           "reward",
                                           "terminated",
                                           "truncated",
                                           "info"])


class RestaurantEnvManager:
    """
    Async registry of independent ``RestaurantEnv`` instances.

    Every environment is stored under a random hex index and guarded by its
    own ``asyncio.Lock`` so that coroutines operating on the same index do
    not interleave. The locks are asyncio locks: they serialize coroutines
    on one event loop and do not protect against access from other threads.
    """

    def __init__(self, verbose: bool = True):
        # When True, reset() logs a message.
        self.verbose = verbose
        self.envs: dict[str, RestaurantEnv] = {}
        self.locks: dict[str, asyncio.Lock] = {}

    def __len__(self) -> int:
        """Return the number of registered environments."""
        return len(self.envs)

    def __contains__(self, idx: str) -> bool:
        """Allow `if idx in manager:` syntax."""
        return self.is_alive(idx)

    def get_lock(self, idx: str) -> asyncio.Lock:
        """Return the lock guarding environment ``idx``.

        :raises ValueError: if no lock is registered for ``idx``.
        """
        try:
            return self.locks[idx]
        except KeyError:
            raise ValueError(f"idx {idx!r} not in locks") from None

    async def make(self, env_name: str, **kwargs) -> str:
        """Create an environment and return its new index.

        :param env_name: accepted for interface parity; not used when
            building the environment.
        :param kwargs: forwarded to the ``RestaurantEnv`` constructor.
        """
        idx = uuid.uuid4().hex
        self.envs[idx] = RestaurantEnv(**kwargs)
        self.locks[idx] = asyncio.Lock()
        return idx

    async def close(self, idx: str) -> bool:
        """Close and unregister environment ``idx``.

        :return: True if an environment was closed, False if ``idx`` is
            unknown.
        """
        lock = self.locks.get(idx)
        if lock is None:
            # Unknown index: report failure rather than raising.
            return False

        async with lock:
            env = self.envs.get(idx)
            # Drop the lock entry as well so closed environments do not
            # leave locks behind.
            self.locks.pop(idx, None)
            if env is None:
                return False
            env.close()
            del self.envs[idx]
            return True

    async def step(self, idx: str, action: int) -> GymEnvStepResult:
        """Apply ``action`` to environment ``idx``.

        :raises ValueError: if ``idx`` is unknown.
        """
        async with self.get_lock(idx):
            env = self.envs.get(idx)
            if env is None:
                raise ValueError("Env not found.")
            # RestaurantEnv.step returns (observation, reward, terminated);
            # it produces neither an info mapping nor a truncation signal.
            observation, reward, terminated = env.step(action)
            return GymEnvStepResult(observation=observation, reward=reward,
                                    terminated=terminated, truncated=False, info={})

    async def reset(self, idx: str, seed: Optional[int] = None, **kwargs) -> GymEnvResetResult:
        """Reset the environment and return (observation, info).

        ``seed`` and ``kwargs`` are accepted for interface parity and are
        currently not forwarded to the environment.

        :raises ValueError: if ``idx`` is unknown.
        """
        async with self.get_lock(idx):
            env = self.envs.get(idx)
            if env is None:
                raise ValueError(f"Environment {idx} not found. Have you called make()?")

            obs = env.reset()
            if self.verbose:
                logger.info(f"Reset environment {idx}")
            return GymEnvResetResult(observation=obs, info={})

    def is_alive(self, idx: str) -> bool:
        """Check if an environment exists and is active."""
        return self.envs.get(idx) is not None