Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
[Serve] Refactor opt demo website to use alpa.serve controller (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Oct 1, 2022
1 parent 375729a commit 56fb3e6
Show file tree
Hide file tree
Showing 26 changed files with 684 additions and 807 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ examples/llm_serving/dataset/*.so
examples/llm_serving/dataset/*.c
examples/llm_serving/dataset/*.cpp
examples/llm_serving/weblogs
examples/llm_serving/keys_file.json
examples/llm_serving/benchmark/tmp*
examples/llm_serving/tmp*
examples/opt_finetune/output/
Expand Down
22 changes: 19 additions & 3 deletions alpa/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
time.
"""
from abc import ABC, abstractmethod
import asyncio
from collections import defaultdict, namedtuple
from collections.abc import Iterable
import logging
Expand Down Expand Up @@ -1104,7 +1105,7 @@ def get_remote_buffers(
ary_refs: Union[List["RemoteArrayRef"], "RemoteArrayRef"],
host_local_ids: Sequence[Sequence[Tuple[int, int]]] = None,
batching=False,
asynchronize=False):
return_ray_ref=False):
"""
Get values of remote buffers.
Expand Down Expand Up @@ -1163,7 +1164,7 @@ def get_remote_buffers(
self.workers[host_id].get_buffers.remote(
ary_ref.uuid, local_id))
obj_refs.append(ary_obj_refs)
if asynchronize:
if return_ray_ref:
ret = obj_refs
else:
ret = [ray.get(refs) for refs in obj_refs]
Expand Down Expand Up @@ -1459,7 +1460,7 @@ def __init__(self,
def size(self):
return np.prod(self.shape)

def get_remote_buffers_async(self):
def prefetch(self):
# TODO (yinmin): Move this function out of DistributedArray
# and batch different requests. Also need to add another
# function to `ray.wait` for the remote references.
Expand All @@ -1478,6 +1479,21 @@ def delete(self):
def flush(self):
self._npy_value = None

async def to_np_async(self):
if self._npy_value is None:
npy_value = np.empty(self.aval.shape, self.aval.dtype)
if not self._fetched_np_buffers:
if not self._fetched_np_buffers_ref:
self.prefetch()
fetched_np_buffers = await asyncio.gather(
*self._fetched_np_buffers_ref)
else:
fetched_np_buffers = self._fetched_np_buffers
for ct, i in enumerate(self.one_replica_buffer_ids):
npy_value[self.indices[i]] = fetched_np_buffers[ct]
self._npy_value = npy_value
return self._npy_value

##### distributed save/load #####
def save(self, ckpt_dir: str, local_cache_dir: Union[str, None] = None):
"""
Expand Down
2 changes: 2 additions & 0 deletions alpa/serve/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Alpa serving backend"""
from alpa.serve.controller import CONTROLLER_NAME, run_controller
157 changes: 105 additions & 52 deletions alpa/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,19 @@
import os
import pickle
import socket
from typing import Callable, List, Dict, Optional, Tuple, Any
import time
from typing import Callable, List, Dict, Optional, Tuple, Any, Union

from fastapi.middleware.cors import CORSMiddleware
import ray
from ray.actor import ActorHandle
import uvicorn

from alpa.api import init
from alpa.serve.http_util import (
HTTPRequestWrapper,
receive_http_body,
Response,
set_socket_reuse_port,
ASGIHandler,
build_starlette_request,
new_port,
)
from alpa.serve.http_util import (HTTPRequestWrapper, receive_http_body,
Response, set_socket_reuse_port, ASGIHandler,
build_starlette_request, new_port,
RelayException, make_error_response)

logger = logging.getLogger(__file__)

Expand All @@ -48,7 +45,7 @@ class ModelInfo:
@ray.remote(num_cpus=1)
class DeviceMeshGroupManager:

def __init__(self, virtual_mesh_shape):
def __init__(self, virtual_mesh_shape: Optional[Tuple[int]] = None):
if virtual_mesh_shape:
init(cluster="ray",
num_nodes=virtual_mesh_shape[0],
Expand All @@ -60,28 +57,46 @@ def __init__(self, virtual_mesh_shape):
self.replicas = {}

def create_replica(self, name: str, create_info: CreateInfo):
assert name not in self.replicas

model_def, args, kwargs = (create_info.model_def, create_info.init_args,
create_info.init_kwargs)
args = args or []
kwargs = kwargs or {}
self.replicas[name] = model_def(*args, **kwargs)

def delete_replica(self, name: str):
assert name in self.replicas
del self.replicas[name]

async def handle_request(self, name: str, request_wrapper: bytes):
request_wrapper = pickle.loads(request_wrapper)
request = build_starlette_request(request_wrapper)
response = await self.replicas[name].handle_request(request)
return response
request.tstamp = request_wrapper.scope["tstamp"]
try:
response = await self.replicas[name].handle_request(request)
return response
except Exception as e: # pylint: disable=broad-except
return RelayException(e)


@ray.remote(num_cpus=0)
class Controller:

def __init__(self, host: str, port: int, root_path: str):
def __init__(self,
host: str,
port: int,
root_path: str,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[Union[str, os.PathLike]] = None):
self.host = host
self.port = port
self.root_path = root_path
self.ssl_keyfile = ssl_keyfile
self.ssl_certfile = ssl_certfile

# Dict[str -> ModelInfo]
self.manager_lock = asyncio.Lock()
self.model_info = {}
self.mesh_group_managers = {}

Expand All @@ -90,34 +105,49 @@ def __init__(self, host: str, port: int, root_path: str):
self.http_server_task = asyncio.get_event_loop().create_task(
self.run_http_server())

def launch_mesh_group_manager(self,
group_id: int,
virtual_mesh_shape: Tuple[int] = None):
self.mesh_group_managers[group_id] = (
DeviceMeshGroupManager.remote(virtual_mesh_shape))

def register_model(self,
name: str,
model_def: Callable,
init_args: Optional[List] = None,
init_kwargs: Optional[Dict] = None):
assert name not in self.model_info, (
f"Model {name} is already registered")
self.model_info[name] = ModelInfo([],
CreateInfo(model_def, init_args,
init_kwargs))
def launch_mesh_group_manager(
self,
group_id: int,
virtual_mesh_shape: Optional[Tuple[int]] = None):
assert group_id not in self.mesh_group_managers, (
f"Mesh group {group_id} is already launched")
self.mesh_group_managers[group_id] = (DeviceMeshGroupManager.options(
name=f"mesh_group_manager_{group_id}").remote(virtual_mesh_shape))

async def register_model(self,
name: str,
model_def: Callable,
init_args: Optional[List] = None,
init_kwargs: Optional[Dict] = None,
override: bool = False):
async with self.manager_lock:
if name in self.model_info:
if override:
for manager in self.model_info[name].managers:
await manager.delete_replica.remote(name)
else:
raise ValueError(f"Model {name} is already registered")

self.model_info[name] = ModelInfo([],
CreateInfo(
model_def, init_args,
init_kwargs))

async def create_replica(self, name: str, mesh_group_id: int):
assert mesh_group_id in self.mesh_group_managers
model_info = self.model_info[name]
manager = self.mesh_group_managers[mesh_group_id]
assert manager not in model_info.managers
async with self.manager_lock:
assert mesh_group_id in self.mesh_group_managers
model_info = self.model_info[name]
manager = self.mesh_group_managers[mesh_group_id]
assert manager not in model_info.managers

await manager.create_replica.remote(name, model_info.create_info)
model_info.managers.append(manager)
logger.info(
f"Create replica of model={name} on mesh={mesh_group_id}")
await manager.create_replica.remote(name, model_info.create_info)
model_info.managers.append(manager)

async def handle_asgi(self, scope, receive, send):
assert scope["type"] == "http"
scope["tstamp"] = time.time()

# Receive request
http_body_bytes = await receive_http_body(scope, receive, send)
Expand All @@ -126,23 +156,30 @@ async def handle_asgi(self, scope, receive, send):
request_wrapper = pickle.dumps(request_wrapper)

# Route
obj = await request.json()
name = obj["model"]
if name not in self.model_info:
await Response(f"Model {name} is not registered",
status_code=404).send(scope, receive, send)
return
try:
obj = await request.json()

if not self.model_info[name].managers:
await Response(f"No replica of model {name} is created",
status_code=404).send(scope, receive, send)
return
assert "model" in obj, "Model name is not specified in the request."
name = obj["model"]

manager = self.model_info[name].managers[0]
assert name in self.model_info, (
f"Model '{name}' is not registered.")
assert self.model_info[name].managers, (
f"No replica of model '{name}' is created.")
manager = self.model_info[name].managers[0]

# Process request
response = await manager.handle_request.remote(name, request_wrapper)
await Response(response).send(scope, receive, send)
response = await manager.handle_request.remote(
name, request_wrapper)
if isinstance(response, Exception):
raise response

status_code = 200
except Exception as e: # pylint: disable=broad-except
response = make_error_response(e)
status_code = 400

await Response(response,
status_code=status_code).send(scope, receive, send)

def get_info(self):
return {
Expand Down Expand Up @@ -187,13 +224,23 @@ async def run_http_server(self):
# Note(simon): we have to use lower level uvicorn Config and Server
# class because we want to run the server as a coroutine. The only
# alternative is to call uvicorn.run which is blocking.
app = ASGIHandler(self)
app = CORSMiddleware(
app,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)

config = uvicorn.Config(
ASGIHandler(self),
app,
host=self.host,
port=self.port,
root_path=self.root_path,
lifespan="off",
access_log=False,
ssl_keyfile=self.ssl_keyfile,
ssl_certfile=self.ssl_certfile,
)
server = uvicorn.Server(config=config)

Expand All @@ -206,11 +253,17 @@ async def run_http_server(self):
await server.serve(sockets=[sock])


def run_controller(host, port=None, root_path="/"):
def run_controller(host,
port=None,
root_path="/",
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[Union[str, os.PathLike]] = None):
controller = Controller.options(name=CONTROLLER_NAME).remote(
host=host,
port=port or new_port(),
root_path=root_path,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
)
ray.get(controller.ready.remote())
return controller
20 changes: 20 additions & 0 deletions alpa/serve/http_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
import random
import socket
import traceback
from typing import Any, Dict, Type

from fastapi.encoders import jsonable_encoder
Expand Down Expand Up @@ -358,3 +359,22 @@ async def __call__(self, scope, receive, send):
https://asgi.readthedocs.io/en/latest/specs/index.html.
"""
await self.controller.handle_asgi(scope, receive, send)


class RelayException(Exception):

def __init__(self, e):
self.e = e
self.stacktrace = "".join(traceback.format_tb(e.__traceback__))


def make_error_response(e):
if isinstance(e, RelayException):
msg = str(e.e)
stacktrace = "".join(traceback.format_tb(
e.__traceback__)) + e.stacktrace
else:
msg = str(e)
stacktrace = "".join(traceback.format_tb(e.__traceback__))

return {"type": "error", "message": msg, "stacktrace": stacktrace}
Loading

0 comments on commit 56fb3e6

Please sign in to comment.