From 1879788705fcf20c0d0c5b60c624052dfbbd88d8 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 17 Aug 2023 01:59:37 +0300 Subject: [PATCH 01/35] typos --- src/petals/utils/packaging.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/petals/utils/packaging.py b/src/petals/utils/packaging.py index c6d9faa3d..f7d2e9bcc 100644 --- a/src/petals/utils/packaging.py +++ b/src/petals/utils/packaging.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Tuple, Sequence import torch from hivemind import nested_flatten, nested_pack @@ -18,7 +18,7 @@ def _get_tensor_index(item: bytes) -> int: return int(item[3:]) -def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]: +def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]: """ Check the function's arguments and pack all tensors into different flattened lists. :returns: a flattened list of tensors and args and kwargs, where tensors were masked @@ -35,7 +35,7 @@ def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]: return flat_tensors, nested_pack(masked_flat_values, (args, kwargs)) -def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any): +def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], args_structure: Any): """ Restore arguments after `pack_args_kwargs` function. :returns: list of args and dict of kwargs From f313730767ce389539f51f98cec8e422e65ec982 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 17 Aug 2023 03:09:45 +0300 Subject: [PATCH 02/35] WIP --- src/petals/server/backend.py | 50 ++++++++++----- src/petals/server/block_functions.py | 95 ++++++++++++++++++---------- src/petals/server/handler.py | 10 +-- src/petals/server/task_pool.py | 15 ++--- 4 files changed, 105 insertions(+), 65 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 8b788b0d5..1d171a808 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple, Union import torch -from hivemind import BatchTensorDescriptor, TensorDescriptor +from hivemind import BatchTensorDescriptor, TensorDescriptor, nested_flatten, nested_map from hivemind.moe.expert_uid import ExpertUID from hivemind.moe.server.module_backend import ModuleBackend from hivemind.utils import get_logger @@ -96,21 +96,26 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S cache_tensors.extend((keys, values)) return cache_tensors - def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: - *inputs, active_adapter = inputs - with self._peft_module.using_adapter(active_adapter): - return super().forward(*inputs) + def forward(self, *args: torch.Tensor, active_adapter: Optional[str], **kwargs) -> Tuple[torch.Tensor, ...]: + with self._peft_module.using_adapter(active_adapter), torch.no_grad(): + return self.module(*args, **kwargs) - def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: - *inputs, active_adapter = inputs - with self._peft_module.using_adapter(active_adapter): - return super().backward(*inputs) + def backward( + self, grad_outputs: torch.Tensor, *args, active_adapter: Optional[str], **kwargs + ) -> Tuple[torch.Tensor, ...]: + assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor)) + with self._peft_module.using_adapter(active_adapter), torch.enable_grad(): + (outputs,) = self.module(*args, **kwargs) + 
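# torch.autograd.backward below propagates grad_outputs through this re-computed forward pass, populating .grad on leaf inputs that require grad
+ 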
assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape + torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False) + return nested_map(lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None) @torch.inference_mode() def inference_step( self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, + kwargs: Dict[str, torch.Tensor], inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" @@ -129,8 +134,9 @@ def inference_step( layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) for offset in range(0, seq_len, max_chunk_length): hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :] + kwargs_chunk = self._select_kwargs_chunk(kwargs, seq_len, offset, max_chunk_length) output_hidden_states_chunk, new_kvs = self.module.forward( - hidden_states_chunk, layer_past=layer_past, use_cache=True + hidden_states_chunk, layer_past=layer_past, use_cache=True, **kwargs_chunk ) if seq_len > max_chunk_length: output_hidden_states[:, offset : offset + max_chunk_length] = output_hidden_states_chunk @@ -178,6 +184,17 @@ def _update_cache_inplace( new_value = new_value.view(*cache_value.shape[:2], new_length, head_dim) cache_value[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :] + @staticmethod + def _select_kwargs_chunk(kwargs: Dict[str, Any], seq_len: int, offset: int, max_chunk_length: int): + if offset == 0 and max_chunk_length >= seq_len: + return kwargs + kwargs_chunk = {} + for key, value in kwargs.items(): + if isinstance(value, torch.Tensor) and value.ndim >= 2 and value.shape[-2] == seq_len: + value = value[:, offset : offset + max_chunk_length] + kwargs_chunk[key] = value + return kwargs_chunk + def get_pools(self) -> Sequence[PrioritizedTaskPool]: return self.forward_pool, self.backward_pool, self.inference_pool @@ -220,14 +237,17 @@ def __call__( self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, + backend_kwargs: Sequence[Dict[str, torch.Tensor]], inference_infos: Sequence[InferenceMetadata], *optional_prompts: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, ...]: - assert len(inference_infos) == len( - optional_prompts - ), f"found {len(inference_infos)} blocks but {len(optional_prompts)} prompts" - for inference_info, optional_prompt in zip(inference_infos, optional_prompts): + assert ( + len(inference_infos) == len(optional_prompts) == len(backend_kwargs) + ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(backend_kwargs)} kwargs" + for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, backend_kwargs): if optional_prompt is not None: hidden_states[:, : optional_prompt.shape[1]] += optional_prompt - (hidden_states,) = self.backends[inference_info.uid].inference_step(hidden_states, hypo_ids, inference_info) + (hidden_states,) = self.backends[inference_info.uid].inference_step( + hidden_states, hypo_ids, kwargs, inference_info + ) return (hidden_states,) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index c682663fa..d55d99045 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -35,7 +35,7 @@ async def run_rpc_forward( active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - args_structure: Any = None, + 
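# nested (args, kwargs) template for flat_tensors; unpacked via unpack_args_kwargs in _check_inputs below
+ 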
structure: Any, ) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream @@ -45,22 +45,19 @@ async def run_rpc_forward( :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ - if args_structure is not None: - # TODO: kwargs currently is unused, it can be used later for peft-like adaptation - flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) - hidden_states, prompts, *_ = flat_tensors - + (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure) dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) assert hidden_states.ndim == 3 + num_tokens = hidden_states.shape[0] * hidden_states.shape[1] if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a chain of requested backends - for backend, prompt in zip(requested_backends, prompts): + for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt @@ -70,8 +67,10 @@ async def run_rpc_forward( ) (hidden_states,) = await backend.forward_pool.submit_task( hidden_states, - active_adapter, + active_adapter=active_adapter, + **kwargs, priority=priority, + size=num_tokens, ) assert isinstance(hidden_states, torch.Tensor) assert ( @@ -87,15 +86,13 @@ async def run_rpc_backward( active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - args_structure: Any = None, + structure: Any, ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - if args_structure is not None: - # TODO: kwargs currently is unused, it can be used later for peft-like adaptation - flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) - inputs, grad_outputs, prompts, *_ = flat_tensors - + (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure) # Cast inputs & grad outputs to backend dtype - inputs = inputs.to(requested_backends[0].dtype) + assert hidden_states.ndim == 3 + num_tokens = hidden_states.shape[0] * hidden_states.shape[1] + hidden_states = hidden_states.to(requested_backends[0].dtype) grad_outputs = grad_outputs.to(requested_backends[-1].dtype) if prompts is None or is_dummy(prompts): @@ -106,32 +103,36 @@ async def run_rpc_backward( # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output inter_inputs = [] - for backend, prompt in zip(requested_backends[:-1], prompts[:-1]): - assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" + for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs): + assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" if not is_dummy(prompt): - inputs[:, : prompt.shape[1]] += prompt - inter_inputs.append(inputs) + hidden_states[:, : prompt.shape[1]] += prompt + inter_inputs.append(hidden_states) assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( - inputs, points=points / len(requested_backends), backend=backend, 
type="forward_in_backward" + hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward" + ) + (hidden_states,) = await backend.forward_pool.submit_task( + hidden_states, active_adapter, **kwargs, priority=priority, size=num_tokens ) - (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) - assert isinstance(inputs, torch.Tensor) + assert isinstance(hidden_states, torch.Tensor) if not is_dummy(prompts[-1]): - inputs[:, : prompts[-1].shape[1]] += prompts[-1] - inter_inputs.append(inputs) + hidden_states[:, : prompts[-1].shape[1]] += prompts[-1] + inter_inputs.append(hidden_states) assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" grad_prompts_reversed = [] # Run a chain of requested backends - for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))): + for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) + (grad_outputs,) = await backend.backward_pool.submit_task( + inp, grad_outputs, active_adapter, **kwargs, priority=priority, size=num_tokens + ) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): @@ -152,7 +153,7 @@ async def iterate_rpc_inference( prioritizer: TaskPrioritizerBase, points: int, quant_type: QuantType, - args_structure: Any = None, + structure: Any = None, ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]: assert len(cache_handles) == len(requested_backends) @@ -161,12 +162,9 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) - if args_structure is not None: - # TODO: kwargs currently is unused, it can be used later for peft-like adaptation - flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) - - hidden_states, prompts, hypo_ids, *_ = flat_tensors + (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure) batch_size, length_increment, _ = hidden_states.shape + num_tokens = batch_size * length_increment # Cast inputs to backend dtype hidden_states = hidden_states.to(requested_backends[0].dtype) @@ -209,13 +207,21 @@ async def iterate_rpc_inference( for uid, handles in zip(requested_uids, cache_handles) ) (hidden_states,) = await requested_backends[0].inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, *prompts, priority=priority + hidden_states, + hypo_ids, + inference_infos, + *prompts, + backend_kwargs, + priority=priority, + size=num_tokens, ) else: - for backend, uid, handles, prompt in zip(requested_backends, requested_uids, cache_handles, prompts): + for backend, uid, handles, prompt, kwargs in zip( + requested_backends, requested_uids, cache_handles, prompts, backend_kwargs + ): inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),) (hidden_states,) = await backend.inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, prompt, priority=priority + hidden_states, hypo_ids, inference_infos, prompt, 
**kwargs, priority=priority, size=num_tokens ) # serialize and send last layer outputs @@ -228,3 +234,22 @@ async def iterate_rpc_inference( # prepare for next step prefix_length += length_increment + + +def _check_inputs( + requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any +): + if structure is not None: + args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure) + else: + args, *backend_kwargs = flat_tensors, {} # backward compatibility + + if len(backend_kwargs) not in (1, len(requested_backends)): + raise RuntimeError( + f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts " + f"(one for each block). Found {len(backend_kwargs)} instead." + ) + if len(backend_kwargs) == 1: + backend_kwargs = (backend_kwargs,) * len(requested_backends) + assert len(backend_kwargs) == len(requested_backends) + return args, backend_kwargs diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 0dd63bde8..ab27a3305 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -180,7 +180,7 @@ async def rpc_inference( prioritizer=self._prioritizer, points=points, quant_type=self.quant_type, - args_structure=args_structure, + structure=args_structure, ): if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) @@ -368,7 +368,7 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - args_structure=args_structure, + structure=args_structure, ) return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) @@ -397,7 +397,7 @@ async def rpc_forward_stream( prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - args_structure=args_structure, + structure=args_structure, ) # Split the serialized_output for streaming and respond to client @@ -450,7 +450,7 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - args_structure=args_structure, + structure=args_structure, ) return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) @@ -477,7 +477,7 @@ async def rpc_backward_stream( prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - args_structure=args_structure, + structure=args_structure, ) # Split the serialized_grad_inputs for streaming and respond for tensor in self._serialize_grads(grads, requested_backends, metadata): diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index e027d52f5..7374c9fb6 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -20,7 +20,8 @@ class Task: priority: float time_submitted: float future: MPFuture = field(compare=False) - args: Sequence[torch.Tensor] = field(compare=False) + args: Sequence[Union[torch.Tensor, Any]] = field(compare=False) + size: int = 1 @property def uid(self) -> int: @@ -104,15 +105,15 @@ def shutdown(self, timeout: float = 3): logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") self.terminate() - def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: + def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1) -> MPFuture: """Add task to this pool's queue, return Future for its output""" future = MPFuture() # 
Remove shmem from MPFuture. This disables the .cancel() feature but # saves the server from "could not unlink the shared memory file" crashes during rebalancing future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8) - task = Task(priority, time.monotonic(), future, args) - if self.get_task_size(task) > self.max_batch_size: + task = Task(priority, time.monotonic(), future, args, size=size) + if task.size > self.max_batch_size: exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed") task.future.set_exception(exc) else: @@ -122,12 +123,6 @@ def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture: self.priority = (task.priority, task.time_submitted) return task.future - def get_task_size(self, task: Task) -> int: - """compute task processing complexity; defaults to the total number of tokens""" - if task.args and task.args[0].ndim >= 2: - return task.args[0].shape[0] * task.args[0].shape[1] - return 1 - def load_batch_to_runtime( self, timeout: Optional[float] = None, device: Optional[torch.device] = None ) -> Tuple[Any, List[torch.Tensor]]: From ed8d7f41b84b09b7706d1ebcadd8c3b6a296295f Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 17 Aug 2023 04:18:37 +0300 Subject: [PATCH 03/35] mwp --- src/petals/server/backend.py | 36 +++++++++++++++++----------- src/petals/server/block_functions.py | 21 +++++++++++----- src/petals/server/task_pool.py | 25 ++++++++++--------- src/petals/utils/packaging.py | 2 +- 4 files changed, 52 insertions(+), 32 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 1d171a808..2e225d146 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -53,13 +53,22 @@ def __init__( max_batch_size = self.forward_pool.max_batch_size device = self.module.devices[self.module.output_device_index] self.inference_pool = PrioritizedTaskPool( - self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference" + lambda args, kwargs: self.inference_step(*args, **kwargs), + max_batch_size=max_batch_size, + device=device, + name=f"{self.name}_inference", ) # note: inference_pools may be merged later, see merge_inference_pools_inplace self.forward_pool = PrioritizedTaskPool( - self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward" + lambda args, kwargs: self.forward(*args, **kwargs), + max_batch_size=max_batch_size, + device=device, + name=f"{self.name}_forward", ) self.backward_pool = PrioritizedTaskPool( - self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward" + lambda args, kwargs: self.backward(*args, **kwargs), + max_batch_size=max_batch_size, + device=device, + name=f"{self.name}_backward", ) self.dtype = backend_dtype @@ -96,27 +105,25 @@ def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> S cache_tensors.extend((keys, values)) return cache_tensors - def forward(self, *args: torch.Tensor, active_adapter: Optional[str], **kwargs) -> Tuple[torch.Tensor, ...]: + def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]: with self._peft_module.using_adapter(active_adapter), torch.no_grad(): return self.module(*args, **kwargs) def backward( - self, grad_outputs: torch.Tensor, *args, active_adapter: Optional[str], **kwargs + self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs ) -> Tuple[torch.Tensor, ...]: assert 
any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor)) with self._peft_module.using_adapter(active_adapter), torch.enable_grad(): (outputs,) = self.module(*args, **kwargs) assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False) - return nested_map(lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None) + return nested_map( + lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs) + ) @torch.inference_mode() def inference_step( - self, - hidden_states: torch.Tensor, - hypo_ids: torch.LongTensor, - kwargs: Dict[str, torch.Tensor], - inference_info: InferenceMetadata, + self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, inference_info: InferenceMetadata, **kwargs ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" seq_len = hidden_states.shape[1] @@ -217,8 +224,9 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]) """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" assert len(backends) != 0 and all(isinstance(b, TransformerBackend) for b in backends.values()) first_pool = next(iter(backends.values())).inference_pool + merged_inference_func = _MergedInferenceStep(backends) merged_pool = PrioritizedTaskPool( - _MergedInferenceStep(backends), + lambda args, kwargs: merged_inference_func(*args, **kwargs), max_batch_size=first_pool.max_batch_size, device=first_pool.device, name=f"merged_inference", @@ -237,9 +245,9 @@ def __call__( self, hidden_states: torch.Tensor, hypo_ids: torch.LongTensor, - backend_kwargs: Sequence[Dict[str, torch.Tensor]], inference_infos: Sequence[InferenceMetadata], *optional_prompts: Optional[torch.Tensor], + backend_kwargs: Sequence[Dict[str, torch.Tensor]], ) -> Tuple[torch.Tensor, ...]: assert ( len(inference_infos) == len(optional_prompts) == len(backend_kwargs) @@ -248,6 +256,6 @@ def __call__( if optional_prompt is not None: hidden_states[:, : optional_prompt.shape[1]] += optional_prompt (hidden_states,) = self.backends[inference_info.uid].inference_step( - hidden_states, hypo_ids, kwargs, inference_info + hidden_states, hypo_ids, inference_info, **kwargs ) return (hidden_states,) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index d55d99045..d1d8d69dc 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -66,8 +66,8 @@ async def run_rpc_forward( hidden_states, points=points / len(requested_backends), backend=backend, type="forward" ) (hidden_states,) = await backend.forward_pool.submit_task( + active_adapter, hidden_states, - active_adapter=active_adapter, **kwargs, priority=priority, size=num_tokens, @@ -113,7 +113,7 @@ async def run_rpc_backward( hidden_states, points=points / len(requested_backends), backend=backend, type="forward_in_backward" ) (hidden_states,) = await backend.forward_pool.submit_task( - hidden_states, active_adapter, **kwargs, priority=priority, size=num_tokens + active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens ) assert isinstance(hidden_states, torch.Tensor) @@ -131,7 +131,7 @@ async def run_rpc_backward( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) (grad_outputs,) = 
await backend.backward_pool.submit_task( - inp, grad_outputs, active_adapter, **kwargs, priority=priority, size=num_tokens + active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens ) assert isinstance(grad_outputs, torch.Tensor) @@ -211,7 +211,7 @@ async def iterate_rpc_inference( hypo_ids, inference_infos, *prompts, - backend_kwargs, + backend_kwargs=backend_kwargs, priority=priority, size=num_tokens, ) @@ -221,7 +221,13 @@ async def iterate_rpc_inference( ): inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),) (hidden_states,) = await backend.inference_pool.submit_task( - hidden_states, hypo_ids, inference_infos, prompt, **kwargs, priority=priority, size=num_tokens + hidden_states, + hypo_ids, + inference_infos, + prompt, + backend_kwargs=(kwargs,), + priority=priority, + size=num_tokens, ) # serialize and send last layer outputs @@ -250,6 +256,9 @@ def _check_inputs( f"(one for each block). Found {len(backend_kwargs)} instead." ) if len(backend_kwargs) == 1: - backend_kwargs = (backend_kwargs,) * len(requested_backends) + backend_kwargs = backend_kwargs * len(requested_backends) assert len(backend_kwargs) == len(requested_backends) + for i, kwargs in enumerate(backend_kwargs): + if not isinstance(kwargs, dict): + raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}") return args, backend_kwargs diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py index 7374c9fb6..a173d3211 100644 --- a/src/petals/server/task_pool.py +++ b/src/petals/server/task_pool.py @@ -4,14 +4,17 @@ import time from concurrent.futures._base import PENDING from dataclasses import dataclass, field +from functools import partial from queue import PriorityQueue from typing import Any, List, Optional, Sequence, Tuple, Union import torch -from hivemind import get_logger +from hivemind import get_logger, nested_map from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.utils.mpfuture import ALL_STATES, MPFuture +from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs + logger = get_logger(__name__) @@ -19,9 +22,10 @@ class Task: priority: float time_submitted: float + size: int future: MPFuture = field(compare=False) - args: Sequence[Union[torch.Tensor, Any]] = field(compare=False) - size: int = 1 + flat_tensors: Sequence[torch.Tensor] = field(compare=False) + structure: Any @property def uid(self) -> int: @@ -105,14 +109,13 @@ def shutdown(self, timeout: float = 3): logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") self.terminate() - def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1) -> MPFuture: + def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1, **kwargs: Any) -> MPFuture: """Add task to this pool's queue, return Future for its output""" future = MPFuture() # Remove shmem from MPFuture. 
This disables the .cancel() feature but # saves the server from "could not unlink the shared memory file" crashes during rebalancing future._shared_state_code = torch.tensor([ALL_STATES.index(PENDING)], dtype=torch.uint8) - - task = Task(priority, time.monotonic(), future, args, size=size) + task = Task(priority, time.monotonic(), size, future, *pack_args_kwargs(*args, **kwargs)) if task.size > self.max_batch_size: exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed") task.future.set_exception(exc) @@ -125,25 +128,25 @@ def submit_task(self, *args: Any, priority: float = 0.0, size: int = 1) -> MPFut def load_batch_to_runtime( self, timeout: Optional[float] = None, device: Optional[torch.device] = None - ) -> Tuple[Any, List[torch.Tensor]]: + ) -> Tuple[int, Any]: """receive next batch of arrays""" device = device if device is not None else self.device task = self._ordered_tasks.get(block=True, timeout=timeout) - batch_inputs = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.args] + device_flat_tensors = [_move_to_device_if_tensor(arg, device, share_memory=False) for arg in task.flat_tensors] self._dispatched_tasks[task.uid] = task self.batch_receiver.recv() # reduce the number of active batches if not self._ordered_tasks.empty(): first_remaining_task: Task = self._ordered_tasks.queue[0] self.priority = (first_remaining_task.priority, first_remaining_task.time_submitted) - return task.uid, batch_inputs + return task.uid, unpack_args_kwargs(device_flat_tensors, task.structure) def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]): """send results for a processed batch, previously loaded through load_batch_to_runtime""" - batch_outputs = [_move_to_device_if_tensor(output, device="cpu", share_memory=True) for output in batch_outputs] + batch_outputs = nested_map(partial(_move_to_device_if_tensor, device="cpu", share_memory=True), batch_outputs) task = self._dispatched_tasks.pop(uid, None) if task is None: logger.error( - f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result" + f"Internal error: task task with index {uid} is missing from the dictionary; Could not set result" ) else: task.future.set_result(batch_outputs) diff --git a/src/petals/utils/packaging.py b/src/petals/utils/packaging.py index f7d2e9bcc..a50e8b25b 100644 --- a/src/petals/utils/packaging.py +++ b/src/petals/utils/packaging.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, Sequence +from typing import Any, Sequence, Tuple import torch from hivemind import nested_flatten, nested_pack From fb9b21132c4ae0fa8f985208537b270347c83058 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 17 Aug 2023 04:31:56 +0300 Subject: [PATCH 04/35] black-isort --- src/petals/server/backend.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 2e225d146..3020ac495 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -17,6 +17,7 @@ from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy +from petals.utils.packaging import pack_args_kwargs logger = get_logger(__name__) @@ -112,14 +113,17 @@ def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) def backward( self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs ) -> Tuple[torch.Tensor, 
...]: + args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args] + # TODO remove this WITHIN PR#467; make sure args are passed properly and retain requires_grad assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor)) with self._peft_module.using_adapter(active_adapter), torch.enable_grad(): (outputs,) = self.module(*args, **kwargs) assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False) - return nested_map( - lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs) - ) + # flat_tensors, structure = pack_args_kwargs(nested_map( + # lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)) + # ) + return (args[0].grad,) # TODO pass additional kwarg-grads back to the user WITHIN #467 @torch.inference_mode() def inference_step( From 355c1509e16e3a555862d1c0ed1e87942b02b836 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 17 Aug 2023 04:32:03 +0300 Subject: [PATCH 05/35] black-isort --- src/petals/server/backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 3020ac495..87d16fcb4 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -17,7 +17,6 @@ from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool from petals.utils.misc import is_dummy -from petals.utils.packaging import pack_args_kwargs logger = get_logger(__name__) From 084d565845bb93296a93beb07fd6823fc4399b14 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 17 Aug 2023 04:41:18 +0300 Subject: [PATCH 06/35] priority pool --- tests/test_priority_pool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py index 2623bb196..d9ec0a8c7 100644 --- a/tests/test_priority_pool.py +++ b/tests/test_priority_pool.py @@ -13,7 +13,8 @@ def test_priority_pools(): outputs_queue = mp.SimpleQueue() results_valid = mp.Event() - def dummy_pool_func(x): + def dummy_pool_func(args, kwargs): + (x,) = args # TODO modify the PriorityPool code such that dummy_pool_func can accept x directly time.sleep(0.1) y = x**2 outputs_queue.put((x, y)) From 65e87395bcabe138690bb5ecdee4b32d1c43e485 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Aug 2023 15:20:30 +0300 Subject: [PATCH 07/35] wip (again) --- src/petals/server/backend.py | 16 +++++---- src/petals/server/block_functions.py | 20 ++++++----- src/petals/server/handler.py | 51 +++++++++++++++------------- tests/test_server_stats.py | 1 + 4 files changed, 50 insertions(+), 38 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 87d16fcb4..8626d0426 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -111,18 +111,22 @@ def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) def backward( self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs - ) -> Tuple[torch.Tensor, ...]: + ) -> Tuple[Union[torch.Tensor, Any], ...]: args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args] - # TODO remove this WITHIN PR#467; make sure args are passed properly and retain requires_grad + # ^-- TODO remove this AFTER PR#467; make sure args are passed properly and 
retain requires_grad assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor)) with self._peft_module.using_adapter(active_adapter), torch.enable_grad(): (outputs,) = self.module(*args, **kwargs) assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False) - # flat_tensors, structure = pack_args_kwargs(nested_map( - # lambda x: x.grad if isinstance(x.grad, torch.Tensor) and x.requires_grad else None, (args, kwargs)) - # ) - return (args[0].grad,) # TODO pass additional kwarg-grads back to the user WITHIN #467 + return nested_map(self._get_grad_if_required, (args, kwargs)) + + @staticmethod + def _get_grad_if_required(input: Any) -> Optional[torch.Tensor]: + """Get grad w.r.t. input if input is a tensor that requires grad; otherwise return None""" + if isinstance(input, torch.Tensor) and input.requires_grad: + return input.grad if input.grad is not None else torch.zeros_like(input) + return None @torch.inference_mode() def inference_step( diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index d1d8d69dc..076129e5f 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -18,7 +18,7 @@ from petals.server.task_prioritizer import TaskPrioritizerBase from petals.utils.convert_block import QuantType from petals.utils.misc import DUMMY, is_dummy -from petals.utils.packaging import unpack_args_kwargs +from petals.utils.packaging import unpack_args_kwargs, pack_args_kwargs # We prioritize short inference requests and make them use a *merged* inference pool, # so they are processed without interruptions and extra overheads @@ -86,9 +86,9 @@ async def run_rpc_backward( active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - structure: Any, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure) + args_structure: Any, +) -> Tuple[Sequence[torch.Tensor], Any]: + (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) # Cast inputs & grad outputs to backend dtype assert hidden_states.ndim == 3 num_tokens = hidden_states.shape[0] * hidden_states.shape[1] @@ -124,22 +124,26 @@ async def run_rpc_backward( assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" grad_prompts_reversed = [] + grad_backend_kwargs_reversed = [] + # Run a chain of requested backends for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) - (grad_outputs,) = await backend.backward_pool.submit_task( + (grad_outputs,), grad_kwargs = await backend.backward_pool.submit_task( active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens ) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) + grad_backend_kwargs_reversed.append(grad_kwargs) grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY - return 
[grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] # TODO un-duct-tape + grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] + return pack_args_kwargs((grad_args, reversed(grad_backend_kwargs_reversed))) async def iterate_rpc_inference( @@ -153,7 +157,7 @@ async def iterate_rpc_inference( prioritizer: TaskPrioritizerBase, points: int, quant_type: QuantType, - structure: Any = None, + args_structure: Any = None, ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]: assert len(cache_handles) == len(requested_backends) @@ -162,7 +166,7 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) - (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure) + (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) batch_size, length_increment, _ = hidden_states.shape num_tokens = batch_size * length_increment diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index ab27a3305..0cb173c9b 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -180,7 +180,7 @@ async def rpc_inference( prioritizer=self._prioritizer, points=points, quant_type=self.quant_type, - structure=args_structure, + args_structure=args_structure, ): if can_push: task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata)) @@ -444,16 +444,18 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" - grads = await run_rpc_backward( + flat_grads, grads_structure = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - structure=args_structure, + args_structure=args_structure, ) - return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) + serialized_flat_grads = self._serialize_grads(flat_grads, flat_tensors, metadata) + serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grads_structure)) + return runtime_pb2.ExpertResponse(tensors=serialized_flat_grads, metadata=serialized_output_metadata) async def rpc_backward_stream( self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext @@ -471,18 +473,20 @@ async def rpc_backward_stream( points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" - grads = await run_rpc_backward( + flat_grads, grad_structure = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - structure=args_structure, + args_structure=args_structure, ) # Split the serialized_grad_inputs for streaming and respond - for tensor in self._serialize_grads(grads, requested_backends, metadata): + serialized_output_metadata = MSGPackSerializer.dumps(output_metadata) + for tensor in self._serialize_grads(flat_grads, requested_backends, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): - yield runtime_pb2.ExpertResponse(tensors=[part]) + yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata) + serialized_output_metadata = None # attach metadata to the first 
response only def _get_active_adapter(self, metadata: dict) -> str: active_adapter = metadata.get("active_adapter", "") @@ -492,28 +496,27 @@ def _get_active_adapter(self, metadata: dict) -> str: def _serialize_grads( self, - grads: Sequence[torch.Tensor], - requested_backends: Sequence[TransformerBackend], - metadata: Dict[str, Any], + flat_grads: Sequence[torch.Tensor], + flat_inputs: Sequence[runtime_pb2.Tensor], + input_metadata: Dict[str, Any], ) -> Sequence[runtime_pb2.Tensor]: """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema""" + inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad) + assert len(flat_grads) == len(inputs_with_grad), f"user provides {len(inputs_with_grad)} inputs with grad, " \ + f"but backward produced {len(flat_grads)} gradients" # Modify grad_inputs_schema to support grad_prompts - assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize - flat_grads_schema = tuple( - nested_flatten((requested_backends[0].args_schema * len(grads), requested_backends[0].kwargs_schema)) - ) # TODO generalize - - if metadata.get("output_compression") is not None: - assert isinstance(metadata["output_compression"], (list, tuple)), "output_compression must be a tuple/list" - output_compression = tuple(metadata["output_compression"]) + if input_metadata.get("output_compression") is not None: + output_compression = input_metadata["output_compression"] + assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list" assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" - assert len(output_compression) == len(grads), f"output_compression should have {len(grads)} elements" + assert len(output_compression) == len(flat_grads), f"output_compression should have {len(flat_grads)} " \ + f"elements, one for every tensor thar requires grad" else: - output_compression = tuple(tensor.compression for tensor in flat_grads_schema) - + output_compression = tuple(runtime_pb2.NONE for _ in flat_grads) + output_compression = tuple(output_compression) return [ - serialize_torch_tensor(result.to(proto.dtype), compression, allow_inplace=True) - for result, proto, compression in zip(grads, flat_grads_schema, output_compression) + serialize_torch_tensor(result.to(input.dtype), compression, allow_inplace=True) + for result, input, compression in zip(flat_grads, inputs_with_grad, output_compression) ] def _check_uids(self, uids: str) -> Tuple[ModuleUID, ...]: diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index c8e6ab6e5..7d17d9928 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -9,6 +9,7 @@ from test_utils import * +@pytest.mark.skip @pytest.mark.forked def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) From 13c13d347acf17f3ffe66b2dc7eddb0805f58d88 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Aug 2023 15:21:17 +0300 Subject: [PATCH 08/35] wip (again) --- src/petals/server/block_functions.py | 10 +++++----- src/petals/server/handler.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 076129e5f..cd398fb0d 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -35,7 +35,7 @@ async def run_rpc_forward( 
active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - structure: Any, + args_structure: Any, ) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream @@ -45,7 +45,7 @@ async def run_rpc_forward( :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ - (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, structure) + (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) @@ -247,10 +247,10 @@ async def iterate_rpc_inference( def _check_inputs( - requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], structure: Any + requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any ): - if structure is not None: - args, *backend_kwargs = unpack_args_kwargs(flat_tensors, structure) + if args_structure is not None: + args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure) else: args, *backend_kwargs = flat_tensors, {} # backward compatibility diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 0cb173c9b..8b3a7e848 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -368,7 +368,7 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - structure=args_structure, + args_structure=args_structure, ) return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) @@ -397,7 +397,7 @@ async def rpc_forward_stream( prioritizer=self._prioritizer, active_adapter=active_adapter, points=points, - structure=args_structure, + args_structure=args_structure, ) # Split the serialized_output for streaming and respond to client From 4529471f3fd9fdb7e1b78e5c2514b17bc811bbd3 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Aug 2023 15:21:49 +0300 Subject: [PATCH 09/35] black, isort --- src/petals/server/block_functions.py | 10 +++++++--- src/petals/server/handler.py | 12 ++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index cd398fb0d..2eb2f3cae 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -18,7 +18,7 @@ from petals.server.task_prioritizer import TaskPrioritizerBase from petals.utils.convert_block import QuantType from petals.utils.misc import DUMMY, is_dummy -from petals.utils.packaging import unpack_args_kwargs, pack_args_kwargs +from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs # We prioritize short inference requests and make them use a *merged* inference pool, # so they are processed without interruptions and extra overheads @@ -88,7 +88,9 @@ async def run_rpc_backward( points: int = 0, args_structure: Any, ) -> Tuple[Sequence[torch.Tensor], Any]: - (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) + (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs( + requested_backends, flat_tensors, args_structure + ) # Cast inputs & 
grad outputs to backend dtype assert hidden_states.ndim == 3 num_tokens = hidden_states.shape[0] * hidden_states.shape[1] @@ -166,7 +168,9 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) - (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) + (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs( + requested_backends, flat_tensors, args_structure + ) batch_size, length_increment, _ = hidden_states.shape num_tokens = batch_size * length_increment diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 8b3a7e848..530f6935f 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -502,15 +502,19 @@ def _serialize_grads( ) -> Sequence[runtime_pb2.Tensor]: """Serialize backward gradients w.r.t. inputs using either default schema or custom user-specified schema""" inputs_with_grad = tuple(input for input in flat_inputs if input.requires_grad) - assert len(flat_grads) == len(inputs_with_grad), f"user provides {len(inputs_with_grad)} inputs with grad, " \ - f"but backward produced {len(flat_grads)} gradients" + assert len(flat_grads) == len(inputs_with_grad), ( + f"user provides {len(inputs_with_grad)} inputs with grad, " + f"but backward produced {len(flat_grads)} gradients" + ) # Modify grad_inputs_schema to support grad_prompts if input_metadata.get("output_compression") is not None: output_compression = input_metadata["output_compression"] assert isinstance(output_compression, (list, tuple)), "output_compression must be a tuple/list" assert all(isinstance(c, int) for c in output_compression), "output_compression must contain integers" - assert len(output_compression) == len(flat_grads), f"output_compression should have {len(flat_grads)} " \ - f"elements, one for every tensor thar requires grad" + assert len(output_compression) == len(flat_grads), ( + f"output_compression should have {len(flat_grads)} " + f"elements, one for every tensor thar requires grad" + ) else: output_compression = tuple(runtime_pb2.NONE for _ in flat_grads) output_compression = tuple(output_compression) From d51c08ef208a7788a50591c574c85926bf836791 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 22 Aug 2023 15:22:27 +0300 Subject: [PATCH 10/35] undo debug change --- tests/test_server_stats.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index 7d17d9928..c8e6ab6e5 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -9,7 +9,6 @@ from test_utils import * -@pytest.mark.skip @pytest.mark.forked def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) From 09e9da6eb1483884bd3340cfb1e5435037d19359 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 25 Aug 2023 17:36:47 +0300 Subject: [PATCH 11/35] serialize outputs structure --- src/petals/server/backend.py | 2 +- src/petals/server/block_functions.py | 3 +-- src/petals/server/handler.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 8626d0426..7c66eb539 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -119,7 +119,7 @@ def backward( (outputs,) = self.module(*args, **kwargs) assert isinstance(outputs, torch.Tensor) and 
outputs.shape == grad_outputs.shape torch.autograd.backward((outputs,), grad_tensors=(grad_outputs,), create_graph=False, retain_graph=False) - return nested_map(self._get_grad_if_required, (args, kwargs)) + return nested_map(self._get_grad_if_required, (*args, kwargs)) @staticmethod def _get_grad_if_required(input: Any) -> Optional[torch.Tensor]: diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 47e3e8ae7..f05d77884 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -134,10 +134,9 @@ async def run_rpc_backward( priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) - (grad_outputs,), grad_kwargs = await backend.backward_pool.submit_task( + (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task( active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens ) - assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 530f6935f..0da3b6eb3 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -482,7 +482,7 @@ async def rpc_backward_stream( args_structure=args_structure, ) # Split the serialized_grad_inputs for streaming and respond - serialized_output_metadata = MSGPackSerializer.dumps(output_metadata) + serialized_output_metadata = MSGPackSerializer.dumps(dict(structure=grad_structure)) for tensor in self._serialize_grads(flat_grads, requested_backends, metadata): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part], metadata=serialized_output_metadata) From 84ebd5710509dacbd7efe262723a1bed18346e71 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 28 Aug 2023 06:03:33 +0300 Subject: [PATCH 12/35] WIP, switching to another PR --- src/petals/server/backend.py | 3 --- src/petals/server/block_functions.py | 2 +- tests/test_server_stats.py | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 7c66eb539..e6b380885 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -112,9 +112,6 @@ def forward(self, active_adapter: Optional[str], *args: torch.Tensor, **kwargs) def backward( self, active_adapter: Optional[str], grad_outputs: torch.Tensor, *args, **kwargs ) -> Tuple[Union[torch.Tensor, Any], ...]: - args = [x.detach().requires_grad_(True) if x.is_floating_point() else x.detach() for x in args] - # ^-- TODO remove this AFTER PR#467; make sure args are passed properly and retain requires_grad - assert any(x.requires_grad for x in nested_flatten((args, kwargs)) if isinstance(x, torch.Tensor)) with self._peft_module.using_adapter(active_adapter), torch.enable_grad(): (outputs,) = self.module(*args, **kwargs) assert isinstance(outputs, torch.Tensor) and outputs.shape == grad_outputs.shape diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index f05d77884..df300f2cf 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -92,7 +92,7 @@ async def run_rpc_backward( requested_backends, flat_tensors, args_structure ) # Cast inputs & grad outputs to backend dtype - assert hidden_states.ndim == 3 + assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the 
input tensors requires_grad" num_tokens = hidden_states.shape[0] * hidden_states.shape[1] hidden_states = hidden_states.to(requested_backends[0].dtype) grad_outputs = grad_outputs.to(requested_backends[-1].dtype) diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index c8e6ab6e5..c8aacd570 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -9,7 +9,7 @@ from test_utils import * -@pytest.mark.forked +@pytest.mark.skip def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id From 49ff759d30730fa91161f5fc8e7690b7826f101f Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 28 Aug 2023 06:03:59 +0300 Subject: [PATCH 13/35] undo --- tests/test_server_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_server_stats.py b/tests/test_server_stats.py index c8aacd570..c8e6ab6e5 100644 --- a/tests/test_server_stats.py +++ b/tests/test_server_stats.py @@ -9,7 +9,7 @@ from test_utils import * -@pytest.mark.skip +@pytest.mark.forked def test_server_info(block_from: int = 2, block_to: int = 5, max_length: int = 100, max_length2: int = 50): config = AutoDistributedConfig.from_pretrained(MODEL_NAME) config.allowed_servers = ["QmNV5G3hq2UmAck2htEgsqrmPFBff5goFZAdmKDcZLBZLX"] # PeerID from server2.id From 6c7f7623799aaf527c5ac125b1b7c5dbb09e603a Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 05:21:12 +0300 Subject: [PATCH 14/35] rollback: only generic kwarg --- src/petals/server/backend.py | 15 +++------------ src/petals/server/server.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 3d95bcca5..9ae672b5e 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -53,22 +53,13 @@ def __init__( max_batch_size = self.forward_pool.max_batch_size device = self.module.devices[self.module.output_device_index] self.inference_pool = PrioritizedTaskPool( - lambda args, kwargs: self.inference_step(*args, **kwargs), - max_batch_size=max_batch_size, - device=device, - name=f"{self.name}_inference", + self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference" ) # note: inference_pools may be merged later, see merge_inference_pools_inplace self.forward_pool = PrioritizedTaskPool( - lambda args, kwargs: self.forward(*args, **kwargs), - max_batch_size=max_batch_size, - device=device, - name=f"{self.name}_forward", + self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward" ) self.backward_pool = PrioritizedTaskPool( - lambda args, kwargs: self.backward(*args, **kwargs), - max_batch_size=max_batch_size, - device=device, - name=f"{self.name}_backward", + self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward" ) self.dtype = backend_dtype diff --git a/src/petals/server/server.py b/src/petals/server/server.py index fd9f76660..c85108ab9 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -770,3 +770,15 @@ class RuntimeWithDeduplicatedPools(Runtime): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.pools = tuple(set(self.pools)) + + def process_batch( + self, pool: TaskPoolBase, batch_index: int, args: Sequence[Any], kwargs: Dict[str, Any] + ) -> Tuple[Any, int]: 
+ """process one batch of tasks from a given pool, return a batch of results and total batch size""" + outputs = pool.process_func(*args, **kwargs) + batch_size = 1 + for arg in args: + if isintance(arg, torch.Tensor) and arg.ndim > 2: + batch_size = arg.shape[0] * arg.shape[1] + break + return outputs, batch_size From cc4fe17a9906034e9c59ba6a33cdde3af226b6b5 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 14:05:41 +0300 Subject: [PATCH 15/35] minimize diff --- src/petals/__init__.py | 3 +++ src/petals/server/backend.py | 2 +- src/petals/server/server.py | 5 +++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index 4e4a9d04a..27076ba92 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -24,6 +24,9 @@ assert ( version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0") ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0" + assert version.parse("1.1.10") <= version.parse( + hivemind.__version__ + ), "Please install a proper hivemind version: pip install hivemind>=1.1.10" def _override_bfloat16_mode_default(): diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 9ae672b5e..db4faac4a 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -221,7 +221,7 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]) first_pool = next(iter(backends.values())).inference_pool merged_inference_func = _MergedInferenceStep(backends) merged_pool = PrioritizedTaskPool( - lambda args, kwargs: merged_inference_func(*args, **kwargs), + merged_inference_func, max_batch_size=first_pool.max_batch_size, device=first_pool.device, name=f"merged_inference", diff --git a/src/petals/server/server.py b/src/petals/server/server.py index c85108ab9..e8e3d5931 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -8,7 +8,7 @@ import sys import threading import time -from typing import Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import hivemind import psutil @@ -17,6 +17,7 @@ from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime +from hivemind.moe.server.task_pool import TaskPoolBase from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger from transformers import PretrainedConfig @@ -778,7 +779,7 @@ def process_batch( outputs = pool.process_func(*args, **kwargs) batch_size = 1 for arg in args: - if isintance(arg, torch.Tensor) and arg.ndim > 2: + if isinstance(arg, torch.Tensor) and arg.ndim > 2: batch_size = arg.shape[0] * arg.shape[1] break return outputs, batch_size From 2e760319ab86e0a47511964d9b9fad9eed76cd89 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 14:28:48 +0300 Subject: [PATCH 16/35] add docstr --- src/petals/client/inference_session.py | 2 +- src/petals/client/remote_forward_backward.py | 32 ++++++-------------- src/petals/client/sequential_autograd.py | 16 ++++++++-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 28d363248..4d2ae159d 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -93,7 +93,7 @@ def step( ) -> 
torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs - :prompts: optional DEEP prompts, added to a prefix of each layer's outputs, + :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] """ if self.closed: diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 44abe2686..e7f02184d 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -19,30 +19,30 @@ async def _forward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig ) -> List[torch.Tensor]: outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( - runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)), timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in outputs.tensors] async def _backward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig ) -> List[torch.Tensor]: grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( - runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)), timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] async def _forward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig ) -> List[torch.Tensor]: parts = ( - runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) + runtime_pb2.ExpertRequest(uid=uid, tensors=[part]) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) @@ -52,10 +52,10 @@ async def _forward_stream( async def _backward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig ) -> List[torch.Tensor]: parts = ( - runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) + runtime_pb2.ExpertRequest(uid=uid, tensors=[part]) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) @@ -68,31 +68,19 @@ async def run_remote_forward( uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, - *inputs: torch.Tensor, + *forward_inputs: torch.Tensor, config: ClientConfig, metadata: Optional[bytes] = None, - **kwargs, ) -> Tuple[torch.Tensor, ...]: """ Serializes input tensors and calls "rpc_forward" on a remote server. Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
""" - - # Note: *inputs are flattened input tensors that follow the expert's info['input_schema'] - # detach to avoid pickling the computation graph - assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}" - kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]} - - # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors - forward_inputs = tuple(nested_flatten((inputs, kwargs))) args_schema, kwargs_schema = rpc_info["forward_schema"] compression = args_schema[0].compression forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs) inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs) - # TODO: create more explicit way to check servers schema and client's structure - assert len(inputs) >= len(args_schema) + 1, "Inputs and prompt tensors are necessary for a forward step" - # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( @@ -106,7 +94,7 @@ async def run_remote_forward( size = sum(t.element_size() * t.nelement() for t in inputs) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) + deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata) return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 9d965d2a5..c459711f5 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -4,7 +4,7 @@ import asyncio import itertools from collections import deque -from typing import List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence, Tuple, Dict, Any import torch from hivemind import MSGPackSerializer @@ -29,14 +29,25 @@ async def sequential_forward( sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None, + block_kwargs: Sequence[Dict[str, Any]] = (), ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]: """ Constructs a routing path from to . Performs chained forward for each subsequence of blocks on the path. If some subsequence fails, reconstructs the remaining path and tries to finish the forward. + + :param inputs: initial hidden states of shape [batch_size, sequence length, hidden_size] + :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs, + if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] + :param sequence_manager: a running SequenceManager used to select remote servers and handle failures + :param start_index: run remote blocks starting from this index + :param end_index: run remote blocks up to (but not including) this index + :param block_kwargs: optional per-block keyword arguments. 
Must be a sequence with one dictionary for each block """ assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" + assert len(block_kwargs) in (0, 1, end_index - start_index), \ + f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" inputs_device = inputs.device inputs_dtype = inputs.dtype @@ -68,7 +79,8 @@ async def sequential_forward( span = sequences.popleft() stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) - flat_tensors, args_structure = pack_args_kwargs(inputs, prompts[span.start : span.end]) + flat_tensors, args_structure = pack_args_kwargs( + inputs, prompts[span.start : span.end], *block_kwargs[span.start: span.end]) span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) metadata = sequence_manager.get_request_metadata( From e5c2d8eca472c10172cb9915f80d6f81cdceac97 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 15:32:47 +0300 Subject: [PATCH 17/35] WIP BEFORE MEETING NEED BACKWARD UPDATE --- src/petals/client/remote_forward_backward.py | 76 ++++++++++--------- src/petals/client/routing/sequence_manager.py | 24 ++++-- src/petals/client/sequential_autograd.py | 44 ++++------- 3 files changed, 71 insertions(+), 73 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index e7f02184d..9151c1801 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -5,7 +5,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple import torch -from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor +from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor, PeerID from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor from hivemind.p2p import StubBase from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE @@ -14,8 +14,11 @@ from hivemind.utils.streaming import split_for_streaming from hivemind.utils.tensor_descr import BatchTensorDescriptor +from petals import RemoteSequenceManager from petals.client.config import ClientConfig -from petals.data_structures import ModuleUID, RPCInfo +from petals.data_structures import ModuleUID, RPCInfo, CHAIN_DELIMITER +from petals.server.handler import TransformerConnectionHandler +from petals.utils.packaging import pack_args_kwargs async def _forward_unary( @@ -65,73 +68,74 @@ async def _backward_stream( async def run_remote_forward( - uid: ModuleUID, - stub: StubBase, - rpc_info: RPCInfo, - *forward_inputs: torch.Tensor, - config: ClientConfig, - metadata: Optional[bytes] = None, + sequence_manager: RemoteSequenceManager, + peer_id: PeerID, + span_uids: Sequence[ModuleUID], + *args: torch.Tensor, + **kwargs: torch.Tensor, ) -> Tuple[torch.Tensor, ...]: """ Serializes input tensors and calls "rpc_forward" on a remote server. Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
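    A minimal call sketch (assuming an initialized RemoteSequenceManager `sequence_manager` and a span
    `span` chosen by its routing; it mirrors how sequential_forward invokes this helper):

        span_uids = sequence_manager.block_uids[span.start : span.end]
        (outputs,) = await run_remote_forward(
            sequence_manager, span.peer_id, span_uids, hidden_states, prompts[span.start : span.end]
        )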
""" - args_schema, kwargs_schema = rpc_info["forward_schema"] - compression = args_schema[0].compression - forward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in forward_inputs) - inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs) + merged_uid = CHAIN_DELIMITER.join(span_uids) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id) + flat_inputs, args_structure = pack_args_kwargs(*args, **kwargs) + metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs) + compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs) + if compressions is None: + compressions = [runtime_pb2.CompressionType.NONE] * len(flat_inputs) + compressions = list(nested_flatten(compressions)) + assert len(compressions) == len(flat_inputs), f"got {len(flat_inputs)} tensors but {len(compressions)} codecs" + inputs = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_inputs) + # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( - loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs, forward_schema) + loop.run_in_executor(None, serialize_torch_tensor, tensor, compression) + for tensor, compression in zip(inputs, compressions) ) ) # call RPC on remote server size = sum(t.element_size() * t.nelement() for t in inputs) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary - # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_outputs = await forward_fn(uid, serialized_tensors, stub, config, metadata=metadata) - return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"]) + # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR + return await forward_fn(merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata) async def run_remote_backward( - uid: ModuleUID, + sequence_manager: RemoteSequenceManager, + span_uids: Sequence[ModuleUID], stub: StubBase, - rpc_info: RPCInfo, - *inputs_and_grad_outputs: torch.Tensor, - config: ClientConfig, - metadata: Optional[bytes] = None, - **kwargs, + grad_outputs: Sequence[torch.Tensor], + *args: torch.Tensor, + **kwargs: torch.Tensor, ) -> Sequence[torch.Tensor]: """ Serializes grad outputs and calls "rpc_backward" on a remote server. Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
""" - args_schema, kwargs_schema = rpc_info["forward_schema"] - outputs_schema = rpc_info["outputs_schema"] - compression = args_schema[0].compression - backward_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in inputs_and_grad_outputs) - # TODO: create more explicit way to check servers schema and client's structure - assert ( - len(inputs_and_grad_outputs) >= len(args_schema) + len(outputs_schema) + 1 - ), "Inputs, grad_outputs and prompt tensors are necessary for a backward step" + flat_tensors, args_structure = pack_args_kwargs( + [grad.cpu() for grad in grad_outputs], args, kwargs + ) + metadata = sequence_manager.get_request_metadata( + "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id + ) # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( - loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(inputs_and_grad_outputs, backward_schema) + loop.run_in_executor(None, serialize_torch_tensor, compression) + for tensor, proto in zip(flat_inputs_and_grad_outputs, backward_schema) ) ) - size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs) + size = sum(t.element_size() * t.nelement() for t in flat_inputs_and_grad_outputs) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - deserialized_grad_inputs = await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata, **kwargs) - return deserialized_grad_inputs + return await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 3e239b4b5..560162109 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -474,22 +474,30 @@ def get_retry_delay(self, attempt_no: int) -> float: return 0 return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff) - def get_request_metadata( - self, protocol: str, args_structure: Any = None, *args, **kwargs - ) -> Optional[Dict[str, Any]]: + def get_request_metadata(self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Dict[str, Any]]: """ + :param peer_id: remote server's PeerID :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" - :param args_structure: the structure of flattened tensors from pack_args_kwargs in petals.utils.packaging - :param args: request-specific inputs, typically block uids and input tensors - :param kwargs: additional request context, such as remote peer ID - :returns: msgpack-serialized metadata dict that will be passed alongside a given request + :param args: request-specific input tensors + :param kwargs: additional request keyword arguments + :returns: metadata dict that will be passed alongside a given request """ return dict( points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter, - args_structure=args_structure, ) + def get_compression_codecs( + self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]: + """ + :param peer_id: remote server's PeerID + :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" + :param args: 
request-specific input tensors + :param kwargs: additional request keyword arguments + :returns: compressions for each input tensor; contains as many elements as there are tensors in (args, kwargs) + """ + return None + def shutdown(self): self._thread.shutdown() diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index c459711f5..e8ba8fa65 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -46,8 +46,11 @@ async def sequential_forward( """ assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" - assert len(block_kwargs) in (0, 1, end_index - start_index), \ - f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" + assert len(block_kwargs) in ( + 0, + 1, + end_index - start_index, + ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" inputs_device = inputs.device inputs_dtype = inputs.dtype @@ -78,27 +81,19 @@ async def sequential_forward( span = sequences.popleft() - stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) - flat_tensors, args_structure = pack_args_kwargs( - inputs, prompts[span.start : span.end], *block_kwargs[span.start: span.end]) - - span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) - metadata = sequence_manager.get_request_metadata( - "rpc_forward", args_structure, span_uids, *flat_tensors - ) (outputs,) = await run_remote_forward( - span_uids, - stub, - sequence_manager.rpc_info, - *flat_tensors, - config=sequence_manager.config, - metadata=MSGPackSerializer.dumps(metadata), + sequence_manager, + span.peer_id, + sequence_manager.block_uids[span.start : span.end], + inputs, + prompts[span.start : span.end], + *block_kwargs[span.start : span.end] ) assert isinstance(outputs, torch.Tensor) assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}" - # Save intermediate inputs and subsequences if the forward is already done for them + # Save intermediate inputs and subsequ_peerences if the forward is already done for them intermediate_inputs.append(inputs) done_sequences.append(span) @@ -164,23 +159,14 @@ async def sequential_backward( inputs = intermediate_inputs.pop() span = forward_sequences.pop() - grad_outputs_cpu = [grad.cpu() for grad in grad_outputs] - flat_tensors, args_structure = pack_args_kwargs( - inputs, *grad_outputs_cpu, prompts[span.start : span.end] - ) span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) - metadata = sequence_manager.get_request_metadata( - "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id - ) grad_outputs, *span_grad_prompts = await run_remote_backward( + sequence_manager, + sequence_manager.block_uids[span.start: span.end], span_uids, - stub, - sequence_manager.rpc_info, - *flat_tensors, - config=sequence_manager.config, - metadata=MSGPackSerializer.dumps(metadata), + grad_outputs, inputs, ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) From 49474e5477bda0ca7e7bc2cb87fb6ac381aff5ca Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 22:07:48 +0300 Subject: [PATCH 18/35] wip some more --- src/petals/client/inference_session.py | 20 +++-- src/petals/client/remote_forward_backward.py | 82 +++++++++++-------- src/petals/client/routing/sequence_manager.py | 7 +- 
src/petals/client/sequential_autograd.py | 47 +++++------ 4 files changed, 83 insertions(+), 73 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 4d2ae159d..b6a7ad531 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -4,7 +4,7 @@ import itertools import time import uuid -from typing import AsyncIterator, List, Optional, Tuple +from typing import AsyncIterator, List, Optional, Tuple, Sequence import torch from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor @@ -34,7 +34,7 @@ def __init__( self, config: ClientConfig, span: RemoteSpanInfo, - uid: ModuleUID, + span_uids: Sequence[ModuleUID], rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator, @@ -43,8 +43,8 @@ def __init__( **metadata, ): self.config = config - self.span, self.uid, self.rpc_info = span, uid, rpc_info - self.num_blocks = uid.count(CHAIN_DELIMITER) + 1 + self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info + self.num_blocks = len(span_uids) self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter self.session_id = str(uuid.uuid4()) @@ -62,18 +62,19 @@ async def create( config: ClientConfig, p2p: P2P, span: RemoteSpanInfo, - uid: ModuleUID, + span_uids: Sequence[RemoteSpanInfo], rpc_info: RPCInfo, **metadata, ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" + # TODO YOZH you don't need rpc info here stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), config.connect_timeout, ) - return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata) + return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: @@ -142,6 +143,7 @@ def step( request_metadata["args_structure"] = args_structure + # TODO YOZH FIX THIS BEFORE THE END OF THIS PR # TODO: make possible to use different compression method for different tensors server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"] compression = server_side_inference_schema[0].compression @@ -155,7 +157,7 @@ def step( outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( - uid=self.uid, + uid=CHAIN_DELIMITER.join(self.span_uids), tensors=[ serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) for tensor, proto in zip(input_tensors, inference_schema) @@ -244,8 +246,8 @@ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_Se server_sessions = [] try: for span in chosen_spans: - span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end]) - metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id) + span_uids = self._sequence_manager.block_uids[span.start : span.end] + metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( self._sequence_manager.config, diff --git 
a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 9151c1801..d1e3452e6 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -2,50 +2,49 @@ Utility functions that call RPC forward or backward on a single remote server """ import asyncio -from typing import Iterable, List, Optional, Sequence, Tuple +from typing import Iterable, List, Sequence, Tuple import torch -from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor, PeerID +from hivemind import PeerID, nested_flatten, serialize_torch_tensor from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor -from hivemind.p2p import StubBase from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter +from hivemind.utils.serializer import MSGPackSerializer from hivemind.utils.streaming import split_for_streaming -from hivemind.utils.tensor_descr import BatchTensorDescriptor -from petals import RemoteSequenceManager from petals.client.config import ClientConfig -from petals.data_structures import ModuleUID, RPCInfo, CHAIN_DELIMITER +from petals.client.routing import RemoteSequenceManager +from petals.data_structures import CHAIN_DELIMITER, ModuleUID from petals.server.handler import TransformerConnectionHandler from petals.utils.packaging import pack_args_kwargs async def _forward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward( - runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)), + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in outputs.tensors] async def _backward_unary( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward( - runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors)), + runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs), timeout=config.request_timeout, ) return [deserialize_torch_tensor(t) for t in grad_inputs.tensors] async def _forward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: parts = ( - runtime_pb2.ExpertRequest(uid=uid, tensors=[part]) + runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) @@ -55,10 +54,10 @@ async def _forward_stream( async def _backward_stream( - uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig + uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, config: ClientConfig, **kwargs ) -> List[torch.Tensor]: parts = ( - runtime_pb2.ExpertRequest(uid=uid, tensors=[part]) + runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs) for tensor in 
serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE) ) @@ -81,35 +80,39 @@ async def run_remote_forward( """ merged_uid = CHAIN_DELIMITER.join(span_uids) stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id) - flat_inputs, args_structure = pack_args_kwargs(*args, **kwargs) metadata = sequence_manager.get_request_metadata(peer_id, "rpc_forward", span_uids, *args, **kwargs) - compressions = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs) - if compressions is None: - compressions = [runtime_pb2.CompressionType.NONE] * len(flat_inputs) - compressions = list(nested_flatten(compressions)) - assert len(compressions) == len(flat_inputs), f"got {len(flat_inputs)} tensors but {len(compressions)} codecs" - inputs = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_inputs) + codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_forward", span_uids, *args, **kwargs) + flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs) + flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors) + args_structure = metadata.setdefault("args_structure", args_structure) + if codecs is None: + codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors) + else: + codecs = list(nested_flatten(codecs)) + assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs" # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( loop.run_in_executor(None, serialize_torch_tensor, tensor, compression) - for tensor, compression in zip(inputs, compressions) + for tensor, compression in zip(flat_tensors, codecs) ) ) # call RPC on remote server - size = sum(t.element_size() * t.nelement() for t in inputs) + size = sum(t.element_size() * t.nelement() for t in flat_tensors) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR - return await forward_fn(merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata) + return await forward_fn( + merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata) + ) async def run_remote_backward( sequence_manager: RemoteSequenceManager, + peer_id: PeerID, span_uids: Sequence[ModuleUID], - stub: StubBase, grad_outputs: Sequence[torch.Tensor], *args: torch.Tensor, **kwargs: torch.Tensor, @@ -119,23 +122,32 @@ async def run_remote_backward( Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221 but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here. 
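    A minimal call sketch (assuming `sequence_manager`, `span`, and `span_uids` as in run_remote_forward
    above, that `inputs` holds the hidden states fed to this span during the forward pass, and that
    `grad_outputs` is a list with the gradient w.r.t. this span's outputs):

        grad_outputs, *grad_prompts = await run_remote_backward(
            sequence_manager, span.peer_id, span_uids, grad_outputs, inputs, prompts[span.start : span.end]
        )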
""" - flat_tensors, args_structure = pack_args_kwargs( - [grad.cpu() for grad in grad_outputs], args, kwargs - ) - metadata = sequence_manager.get_request_metadata( - "rpc_backward", args_structure, span_uids, *flat_tensors, peer_id=span.peer_id - ) + merged_uid = CHAIN_DELIMITER.join(span_uids) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id) + metadata = sequence_manager.get_request_metadata(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs) + codecs = sequence_manager.get_compression_codecs(peer_id, "rpc_backward", span_uids, grad_outputs, *args, **kwargs) + flat_tensors, args_structure = pack_args_kwargs(grad_outputs, *args, **kwargs) + flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors) + args_structure = metadata.setdefault("args_structure", args_structure) + + if codecs is None: + codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors) + else: + codecs = list(nested_flatten(codecs)) + assert len(codecs) == len(flat_tensors), f"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs" # Asynchronous serialization loop = asyncio.get_running_loop() serialized_tensors = await asyncio.gather( *( - loop.run_in_executor(None, serialize_torch_tensor, compression) - for tensor, proto in zip(flat_inputs_and_grad_outputs, backward_schema) + loop.run_in_executor(None, serialize_torch_tensor, tensor, compression) + for tensor, compression in zip(flat_tensors, codecs) ) ) - size = sum(t.element_size() * t.nelement() for t in flat_inputs_and_grad_outputs) + size = sum(t.element_size() * t.nelement() for t in flat_tensors) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - return await backward_fn(uid, serialized_tensors, stub, config, metadata=metadata) + return await backward_fn( + merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata) + ) diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 560162109..c0b9fbf6d 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -474,7 +474,9 @@ def get_retry_delay(self, attempt_no: int) -> float: return 0 return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff) - def get_request_metadata(self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Dict[str, Any]]: + def get_request_metadata( + self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs + ) -> Optional[Dict[str, Any]]: """ :param peer_id: remote server's PeerID :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" @@ -488,7 +490,8 @@ def get_request_metadata(self, peer_id: PeerID, protocol: str, uids: Sequence[st ) def get_compression_codecs( - self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]: + self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs + ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]: """ :param peer_id: remote server's PeerID :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index e8ba8fa65..5e26d21a3 100644 
--- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -4,19 +4,16 @@ import asyncio import itertools from collections import deque -from typing import List, Optional, Sequence, Tuple, Dict, Any +from typing import Any, Dict, List, Optional, Sequence, Tuple import torch -from hivemind import MSGPackSerializer from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker from hivemind.utils.logging import get_logger from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward from petals.client.routing import RemoteSequenceManager, maybe_log_traceback -from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo -from petals.server.handler import TransformerConnectionHandler +from petals.data_structures import RemoteSpanInfo from petals.utils.misc import DUMMY, is_dummy -from petals.utils.packaging import pack_args_kwargs logger = get_logger(__name__) @@ -24,12 +21,12 @@ async def sequential_forward( + sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor, - sequence_manager: RemoteSequenceManager, start_index: int = 0, end_index: Optional[int] = None, - block_kwargs: Sequence[Dict[str, Any]] = (), + *block_kwargs: Dict[str, Any], ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]: """ Constructs a routing path from to . @@ -45,13 +42,6 @@ async def sequential_forward( :param block_kwargs: optional per-block keyword arguments. Must be a sequence with one dictionary for each block """ - assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" - assert len(block_kwargs) in ( - 0, - 1, - end_index - start_index, - ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" - inputs_device = inputs.device inputs_dtype = inputs.dtype inputs = inputs.cpu() @@ -59,6 +49,9 @@ async def sequential_forward( end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) + assert len(block_kwargs) in (0, 1, end_index - start_index), \ + f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" + assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" assert is_dummy(prompts) or len(prompts) == len( sequence_manager.block_uids ) # should be n_layers - 1 but add extra prompts for convenience @@ -87,13 +80,13 @@ async def sequential_forward( sequence_manager.block_uids[span.start : span.end], inputs, prompts[span.start : span.end], - *block_kwargs[span.start : span.end] + *block_kwargs[span.start : span.end], ) assert isinstance(outputs, torch.Tensor) assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}" - # Save intermediate inputs and subsequ_peerences if the forward is already done for them + # Save intermediate inputs and subsequences if the forward is already done for them intermediate_inputs.append(inputs) done_sequences.append(span) @@ -118,11 +111,12 @@ async def sequential_forward( async def sequential_backward( + sequence_manager: RemoteSequenceManager, + forward_sequences: List[RemoteSpanInfo], grad_outputs: Sequence[torch.Tensor], intermediate_inputs: List[torch.Tensor], prompts: torch.Tensor, - forward_sequences: List[RemoteSpanInfo], - sequence_manager: RemoteSequenceManager, + *block_kwargs: Dict[str, Any], ) -> Tuple[Sequence[torch.Tensor], torch.Tensor]: """ Performs chained backward for each 
forward subsequence. @@ -148,7 +142,7 @@ async def sequential_backward( try: if attempt_no >= 1: _, backup_inputs, backup_sequences = await sequential_forward( - inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end + sequence_manager, inputs, prompts, start_index=span.start, end_index=span.end ) assert len(backup_inputs) == len(backup_sequences) assert backup_sequences[0].start == span.start @@ -159,14 +153,13 @@ async def sequential_backward( inputs = intermediate_inputs.pop() span = forward_sequences.pop() - - span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end]) - stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) grad_outputs, *span_grad_prompts = await run_remote_backward( sequence_manager, - sequence_manager.block_uids[span.start: span.end], - span_uids, - grad_outputs, inputs, + span.peer_id, + sequence_manager.block_uids[span.start : span.end], + grad_outputs, + *inputs, + *block_kwargs[span.start : span.end], ) grad_outputs = [grad_outputs] grad_prompts_reversed.extend(span_grad_prompts) @@ -198,7 +191,7 @@ async def _gather_forward(input_batches, prompt_batches, sequence_manager): """Wrapper for asyncio.gather to perform parallel sequential forwards""" return await asyncio.gather( *[ - sequential_forward(input_batch, prompt_batch, sequence_manager) + sequential_forward(sequence_manager, input_batch, prompt_batch) for input_batch, prompt_batch in zip(input_batches, prompt_batches) ] ) @@ -210,7 +203,7 @@ async def _gather_backward( """Wrapper for asyncio.gather to perform parallel sequential backwards""" return await asyncio.gather( *[ - sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager) + sequential_backward(sequence_manager, spans, (grad_output,), input_batch, prompt_batch) for grad_output, input_batch, prompt_batch, spans in zip( grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences ) From 4393d99e7839bf22bbe60f6f647a4c35254a1598 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 22:14:38 +0300 Subject: [PATCH 19/35] 1isort --- src/petals/client/inference_session.py | 2 +- src/petals/client/sequential_autograd.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index b6a7ad531..31fa5e7db 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -4,7 +4,7 @@ import itertools import time import uuid -from typing import AsyncIterator, List, Optional, Tuple, Sequence +from typing import AsyncIterator, List, Optional, Sequence, Tuple import torch from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 5e26d21a3..b4e189412 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -49,8 +49,11 @@ async def sequential_forward( end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) - assert len(block_kwargs) in (0, 1, end_index - start_index), \ - f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" + assert len(block_kwargs) in ( + 0, + 1, + end_index - start_index, + ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" 
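    # Usage sketch: `attention_mask` is a hypothetical per-block kwarg; the call below assumes the
    # sequence spans 4 remote blocks and passes one kwargs dict per requested block (the assert above
    # also accepts an empty or single-element sequence of kwargs).
    #     per_block_kwargs = [{"attention_mask": mask} for _ in range(4)]
    #     outputs, intermediates, spans = await sequential_forward(
    #         sequence_manager, inputs, prompts, 0, 4, *per_block_kwargs
    #     )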
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" assert is_dummy(prompts) or len(prompts) == len( sequence_manager.block_uids From 465fd93147ec7adf1f605572b8ba2aa1406d73d7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 5 Sep 2023 23:29:16 +0300 Subject: [PATCH 20/35] more WIP --- src/petals/client/remote_forward_backward.py | 2 ++ src/petals/client/sequential_autograd.py | 7 +++++-- src/petals/server/block_functions.py | 21 ++++++++++++-------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index d1e3452e6..f3eb91657 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -144,6 +144,8 @@ async def run_remote_backward( for tensor, compression in zip(flat_tensors, codecs) ) ) + for tensor, serialized_tensor in zip(flat_tensors, serialized_tensors): + serialized_tensor.requires_grad = tensor.requires_grad size = sum(t.element_size() * t.nelement() for t in flat_tensors) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index b4e189412..4db53be8f 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -161,7 +161,8 @@ async def sequential_backward( span.peer_id, sequence_manager.block_uids[span.start : span.end], grad_outputs, - *inputs, + inputs, + prompts[span.start: span.end], *block_kwargs[span.start : span.end], ) grad_outputs = [grad_outputs] @@ -224,12 +225,14 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function): def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager): batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1) input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size) + input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches) if prompts is None or is_dummy(prompts): prompt_batches = [DUMMY] * len(input_batches) else: prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1) + prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches) - sequence_manager.rpc_info # lazy init + sequence_manager.rpc_info # lazy init #TODO no longer needed outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager)) assert len(outputs) == len(input_batches) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index df300f2cf..4aa0cfc54 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -88,11 +88,11 @@ async def run_rpc_backward( points: int = 0, args_structure: Any, ) -> Tuple[Sequence[torch.Tensor], Any]: - (hidden_states, grad_outputs, prompts), backend_kwargs = _check_inputs( + assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad" + ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs( requested_backends, flat_tensors, args_structure ) # Cast inputs & grad outputs to backend dtype - assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad" num_tokens = hidden_states.shape[0] * hidden_states.shape[1] hidden_states = hidden_states.to(requested_backends[0].dtype) 
grad_outputs = grad_outputs.to(requested_backends[-1].dtype) @@ -117,8 +117,7 @@ async def run_rpc_backward( (hidden_states,) = await backend.forward_pool.submit_task( active_adapter, hidden_states, **kwargs, priority=priority, size=num_tokens ) - - assert isinstance(hidden_states, torch.Tensor) + assert isinstance(hidden_states, torch.Tensor), "intermediate hidden states is not a tensor" if not is_dummy(prompts[-1]): hidden_states[:, : prompts[-1].shape[1]] += prompts[-1] @@ -129,13 +128,15 @@ async def run_rpc_backward( grad_backend_kwargs_reversed = [] # Run a chain of requested backends - for inp, prompt, backend, kwargs in reversed(list(zip(inter_inputs, prompts, requested_backends, backend_kwargs))): + for hidden_states, prompt, backend, kwargs in reversed(list(zip( + inter_inputs, prompts, requested_backends, backend_kwargs))): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" + hidden_states = hidden_states.detach().requires_grad_(True) priority = prioritizer.prioritize( - inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" + hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task( - active_adapter, grad_outputs, inp, **kwargs, priority=priority, size=num_tokens + active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens ) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): @@ -252,10 +253,14 @@ async def iterate_rpc_inference( def _check_inputs( requested_backends: Sequence[TransformerBackend], flat_tensors: Sequence[torch.Tensor], args_structure: Any ): + if len(flat_tensors) == 3: # backward compatibility for rpc_backward, remove after 2.3 + if flat_tensors[0].requires_grad and not flat_tensors[1].requires_grad: + hidden_states, grad_outputs, prompts = flat_tensors + flat_tensors = grad_outputs, hidden_states, prompts if args_structure is not None: args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure) else: - args, *backend_kwargs = flat_tensors, {} # backward compatibility + args, *backend_kwargs = flat_tensors, {} # backward compatibility for grad structure, remove at 2.2 if len(backend_kwargs) not in (1, len(requested_backends)): raise RuntimeError( From f2049658b6bc231afa5fdf102599af25a681d028 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 01:28:01 +0300 Subject: [PATCH 21/35] make it work for fwd, bwd --- src/petals/server/block_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 4aa0cfc54..81cca5561 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -135,7 +135,7 @@ async def run_rpc_backward( priority = prioritizer.prioritize( hidden_states, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) - (*grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task( + (grad_outputs, grad_kwargs) = await backend.backward_pool.submit_task( active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens ) assert isinstance(grad_outputs, torch.Tensor) @@ -145,7 +145,7 @@ async def run_rpc_backward( grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY grad_args = [grad_outputs] if is_dummy(grad_prompts) else 
[grad_outputs, grad_prompts] - return pack_args_kwargs((grad_args, reversed(grad_backend_kwargs_reversed))) + return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed)))) async def iterate_rpc_inference( From b7bd4770d7613f983f0e35c763ef8ac2bfd39c06 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 01:28:12 +0300 Subject: [PATCH 22/35] black-isort --- src/petals/client/sequential_autograd.py | 2 +- src/petals/server/block_functions.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 4db53be8f..d4226400e 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -162,7 +162,7 @@ async def sequential_backward( sequence_manager.block_uids[span.start : span.end], grad_outputs, inputs, - prompts[span.start: span.end], + prompts[span.start : span.end], *block_kwargs[span.start : span.end], ) grad_outputs = [grad_outputs] diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 81cca5561..e0bdd4642 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -128,8 +128,9 @@ async def run_rpc_backward( grad_backend_kwargs_reversed = [] # Run a chain of requested backends - for hidden_states, prompt, backend, kwargs in reversed(list(zip( - inter_inputs, prompts, requested_backends, backend_kwargs))): + for hidden_states, prompt, backend, kwargs in reversed( + list(zip(inter_inputs, prompts, requested_backends, backend_kwargs)) + ): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" hidden_states = hidden_states.detach().requires_grad_(True) priority = prioritizer.prioritize( From 9e29140bb02b022e7e06294f6a9f8c217b9360aa Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 01:37:48 +0300 Subject: [PATCH 23/35] mention reference issue --- src/petals/client/remote_forward_backward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index f3eb91657..15e2d9ef6 100644 --- a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -144,8 +144,8 @@ async def run_remote_backward( for tensor, compression in zip(flat_tensors, codecs) ) ) - for tensor, serialized_tensor in zip(flat_tensors, serialized_tensors): - serialized_tensor.requires_grad = tensor.requires_grad + for tensor, serialized in zip(flat_tensors, serialized_tensors): + serialized.requires_grad = tensor.requires_grad # see https://github.com/learning-at-home/hivemind/pull/591 size = sum(t.element_size() * t.nelement() for t in flat_tensors) backward_fn = _backward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _backward_unary From 17d278e88a6a23f8566f63410cdd709e31c0d083 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 02:38:14 +0300 Subject: [PATCH 24/35] black-isort-clarify --- src/petals/client/remote_forward_backward.py | 6 ++++- src/petals/client/routing/sequence_manager.py | 1 + src/petals/server/block_functions.py | 24 +++++++++++++------ src/petals/server/handler.py | 15 ++++++------ tests/test_remote_sequential.py | 4 ++-- 5 files changed, 33 insertions(+), 17 deletions(-) diff --git a/src/petals/client/remote_forward_backward.py b/src/petals/client/remote_forward_backward.py index 15e2d9ef6..45a30c150 100644 --- 
a/src/petals/client/remote_forward_backward.py +++ b/src/petals/client/remote_forward_backward.py @@ -104,9 +104,13 @@ async def run_remote_forward( size = sum(t.element_size() * t.nelement() for t in flat_tensors) forward_fn = _forward_stream if size > MAX_UNARY_PAYLOAD_SIZE // 2 else _forward_unary # Hotfix: we use "// 2" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR - return await forward_fn( + output_tensors = await forward_fn( merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=MSGPackSerializer.dumps(metadata) ) + # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591 + requires_grad = any(tensor.requires_grad for tensor in flat_tensors) + output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_tensors] + return output_tensors async def run_remote_backward( diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index c0b9fbf6d..884a37bc4 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -493,6 +493,7 @@ def get_compression_codecs( self, peer_id: PeerID, protocol: str, uids: Sequence[str], *args, **kwargs ) -> Optional[Sequence[runtime_pb2.CompressionType.ValueType]]: """ + return a sequence of compression codecs for client-side compression (applied to tensors sent to remote server) :param peer_id: remote server's PeerID :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference" :param args: request-specific input tensors diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index e0bdd4642..f27e276f3 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -31,20 +31,27 @@ async def run_rpc_forward( *flat_tensors: torch.Tensor, + args_structure: Any, requested_backends: Sequence[TransformerBackend], active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, - args_structure: Any, ) -> torch.Tensor: """ Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream :param flat_tensors: a list of tensors that includes first layer inputs, optional prompts and extra tensors - :note: some input tensors can be missing, in which case they will be replaced with dummy tensors (see is_dummy) + :param args_structure: a schema that defines which of flat_tensors corresponds to which arg / kwarg + :note: see pack_args_kwargs function for the definition of args_structure :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass + :param active_adapter: the name of LoRA adapter to use; defaults to no adapter + :param prioritizer: assigns priorities to each sub-request based on the number of points + :param points: client-specified number of points, used to assign priorities + :param args_structure: :returns: hidden states after the last layer [batch_size, seq_length, hid_size] """ + requires_grad = any(tensor.requires_grad for tensor in flat_tensors) + flat_tensors = tuple(tensor.detach() for tensor in flat_tensors) (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes @@ -77,7 +84,7 @@ async def run_rpc_forward( hidden_states.ndim == 3 ), f"inputs to {type(backend)} must be a list with a single 3d 
tensor of hidden states" - return hidden_states + return hidden_states.requires_grad_(requires_grad) async def run_rpc_backward( @@ -88,19 +95,22 @@ async def run_rpc_backward( points: int = 0, args_structure: Any, ) -> Tuple[Sequence[torch.Tensor], Any]: + """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests""" assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad" ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs( requested_backends, flat_tensors, args_structure ) + input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad + # Cast inputs & grad outputs to backend dtype num_tokens = hidden_states.shape[0] * hidden_states.shape[1] - hidden_states = hidden_states.to(requested_backends[0].dtype) - grad_outputs = grad_outputs.to(requested_backends[-1].dtype) + hidden_states = hidden_states.detach().to(requested_backends[0].dtype) + grad_outputs = grad_outputs.detach().to(requested_backends[-1].dtype) if prompts is None or is_dummy(prompts): prompts = [DUMMY] * len(requested_backends) else: - prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] + prompts = [p.squeeze(0).detach() for p in prompts.detach().to(requested_backends[0].dtype).split(1, dim=0)] # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output @@ -140,7 +150,7 @@ async def run_rpc_backward( active_adapter, grad_outputs, hidden_states, **kwargs, priority=priority, size=num_tokens ) assert isinstance(grad_outputs, torch.Tensor) - if not is_dummy(prompt): + if not is_dummy(prompt) and prompts_requires_grad: grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) grad_backend_kwargs_reversed.append(grad_kwargs) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 8108faa97..8e4d84846 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -361,18 +361,19 @@ async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PCont active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) args_structure = metadata.get("args_structure") + assert isinstance( points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" - hidden_states = await run_rpc_forward( *flat_inputs, + args_structure=args_structure, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, - args_structure=args_structure, ) + return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) ) @@ -396,11 +397,11 @@ async def rpc_forward_stream( hidden_states = await run_rpc_forward( *flat_inputs, + args_structure=args_structure, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, - args_structure=args_structure, ) # Split the serialized_output for streaming and respond to client @@ -450,8 +451,8 @@ async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PCon flat_grads, grads_structure = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, 
args_structure=args_structure, ) @@ -479,8 +480,8 @@ async def rpc_backward_stream( flat_grads, grad_structure = await run_rpc_backward( *flat_tensors, requested_backends=requested_backends, - prioritizer=self._prioritizer, active_adapter=active_adapter, + prioritizer=self._prioritizer, points=points, args_structure=args_structure, ) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 20c6011ef..90c403332 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -73,8 +73,8 @@ def rpc_info(self): rpc_info["forward_schema"] = (compressed_input_schema,), dict() # (args, kwargs) return rpc_info - def get_request_metadata(self, protocol: str, *args, **kwargs): - metadata = super().get_request_metadata(protocol, *args, **kwargs) + def get_request_metadata(self, peer_id, protocol, block_uids, *args, **kwargs): + metadata = super().get_request_metadata(peer_id, protocol, block_uids, *args, **kwargs) if protocol == "rpc_forward": metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) elif protocol == "rpc_backward": From 62e780c0542dc1d424f55caebaafe8520047641c Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 03:13:08 +0300 Subject: [PATCH 25/35] check num block kwargs --- src/petals/client/sequential_autograd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index d4226400e..92cdecc5f 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -49,11 +49,11 @@ async def sequential_forward( end_index = end_index if end_index is not None else len(sequence_manager.block_uids) assert start_index >= 0 and end_index <= len(sequence_manager.block_uids) - assert len(block_kwargs) in ( - 0, - 1, - end_index - start_index, + assert ( + len(block_kwargs) in (0, 1, end_index - start_index) ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * (end_index - start_index) assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" assert is_dummy(prompts) or len(prompts) == len( sequence_manager.block_uids From aacd8b2f9d1a60b109a84a43dd2308ec96b131a1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 03:16:03 +0300 Subject: [PATCH 26/35] pass args/kwargs via forward --- src/petals/client/remote_sequential.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index c6d2833d1..a6359bcfb 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -49,13 +49,13 @@ def __init__( self._active_session = ContextVar("active_session", default=None) - def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor: assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" if self.active_session is None: assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}" - return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager) + return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs) else: - return 
self.active_session.step(inputs, prompts, **kwargs)
+            return self.active_session.step(inputs, prompts, *args, **kwargs)
 
     @property
     def active_session(self) -> Optional[InferenceSession]:

From 056cd77f117fa20ab55d04e8b51b83a206941f3a Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 6 Sep 2023 03:16:15 +0300
Subject: [PATCH 27/35] standardize checking block_kwargs

---
 src/petals/client/sequential_autograd.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py
index 92cdecc5f..a86df3187 100644
--- a/src/petals/client/sequential_autograd.py
+++ b/src/petals/client/sequential_autograd.py
@@ -49,11 +49,11 @@ async def sequential_forward(
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
-    assert (
-        len(block_kwargs) in (0, 1, end_index - start_index)
-    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     if len(block_kwargs) == 1:
         block_kwargs = block_kwargs * (end_index - start_index)
+    assert (
+        not block_kwargs or len(block_kwargs) == end_index - start_index
+    ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
     assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
     assert is_dummy(prompts) or len(prompts) == len(
         sequence_manager.block_uids

From a23bd73f3b0cfc0972b3b2f854f7c33314bd0b56 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 6 Sep 2023 03:30:26 +0300
Subject: [PATCH 28/35] probably break everything

---
 src/petals/client/inference_session.py   | 32 ++++++++++++++++++------
 src/petals/client/sequential_autograd.py |  8 +++---
 2 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py
index 31fa5e7db..a3f1130c3 100644
--- a/src/petals/client/inference_session.py
+++ b/src/petals/client/inference_session.py
@@ -4,7 +4,7 @@
 import itertools
 import time
 import uuid
-from typing import AsyncIterator, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
@@ -43,7 +43,7 @@ def __init__(
         **metadata,
     ):
         self.config = config
-        self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info
+        self.span, self.span_uids = span, span_uids
         self.num_blocks = len(span_uids)
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
@@ -67,7 +67,6 @@ async def create(
         **metadata,
     ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. 
This code is meant to be run inside RemoteExpertWorker""" - # TODO YOZH you don't need rpc info here stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( @@ -89,7 +88,7 @@ def step( inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None, - *, + *block_kwargs: Dict[str, Any], step_id: str, ) -> torch.Tensor: """ @@ -97,6 +96,7 @@ def step( :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] """ + # TODO record previous kwargs in case of server failure!!! if self.closed: raise Exception("Session is closed, cannot perform step") @@ -115,6 +115,7 @@ def step( else: inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + assert len(block_kwargs) in (0, self.span.length) if prompts is None or is_dummy(prompts): prompts = DUMMY else: @@ -131,7 +132,7 @@ def step( assert hypo_ids.dtype == torch.int64 # serialize inputs and put them into the queue - input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids) + input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs) request_metadata = dict(session_id=self.session_id, step_id=step_id) if not self.stepped: @@ -141,7 +142,7 @@ def step( if next_servers: request_metadata["next_servers"] = next_servers - request_metadata["args_structure"] = args_structure + args_structure = request_metadata.setdefault("args_structure", args_structure) # TODO YOZH FIX THIS BEFORE THE END OF THIS PR # TODO: make possible to use different compression method for different tensors @@ -277,11 +278,22 @@ def __enter__(self) -> "InferenceSession": assert not self._closed and not self._server_sessions return self - def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + def step( + self, + inputs: torch.Tensor, + prompts: Optional[torch.Tensor] = None, + *block_kwargs: Sequence[Dict[str, torch.Tensor]], + **kwargs, + ) -> torch.Tensor: assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. 
Gradients will *not* be propagated correctly.") + num_blocks = len(self._sequence_manager) + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * num_blocks + assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}" + if prompts is None or is_dummy(prompts): prompts = DUMMY else: @@ -312,7 +324,11 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k server_session = self._server_sessions[server_idx] inputs = server_session.step( - inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs + inputs, + prompts[server_session.span.start : server_session.span.end], + *block_kwargs[server_session.span.start : server_session.span.end], + step_id=step_id, + **kwargs, ) server_idx += 1 diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index a86df3187..174849038 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -52,7 +52,7 @@ async def sequential_forward( if len(block_kwargs) == 1: block_kwargs = block_kwargs * (end_index - start_index) assert ( - len(block_kwargs) in (0, end_index - start_index) + not block_kwargs or len(block_kwargs) == end_index - start_index ), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs" assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}" assert is_dummy(prompts) or len(prompts) == len( @@ -222,7 +222,8 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager): + def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompts: torch.Tensor): + # TODO add kwargs here; figure out a way to split kwargs across servers batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1) input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size) input_batches = tuple(batch.requires_grad_(inputs.requires_grad) for batch in input_batches) @@ -271,4 +272,5 @@ def backward(ctx, grad_outputs: torch.Tensor): grad_inputs = torch.cat(grad_input_batches, dim=0) dummy_grad_prompts = [grad_prompt is None for grad_prompt in grad_prompt_batches] grad_prompts = torch.cat(grad_prompt_batches, dim=1) if not any(dummy_grad_prompts) else None - return (grad_inputs, grad_prompts, None) + # TODO return grads w.r.t. 
kwargs here + return (None, grad_inputs, grad_prompts) From 68b8cea2465eaa110ea4a93d1ee359ca28367fc3 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 03:35:39 +0300 Subject: [PATCH 29/35] note --- src/petals/client/inference_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index a3f1130c3..383e1ac01 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -255,7 +255,7 @@ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_Se self._sequence_manager.state.p2p, span, span_uids, - rpc_info=self._sequence_manager.rpc_info, + rpc_info=self._sequence_manager.rpc_info, #TODO not actually needed max_length=self._max_length, **metadata, ) From 8eb1722f1e95389a7ac8f3d742876eaeb04bf94a Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 03:37:45 +0300 Subject: [PATCH 30/35] standardize: s/backend_kwargs/block_kwargs/g everywhere --- src/petals/server/backend.py | 8 +++--- src/petals/server/block_functions.py | 42 ++++++++++++++-------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index db4faac4a..2f9684ff2 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -242,12 +242,12 @@ def __call__( hypo_ids: torch.LongTensor, inference_infos: Sequence[InferenceMetadata], *optional_prompts: Optional[torch.Tensor], - backend_kwargs: Sequence[Dict[str, torch.Tensor]], + block_kwargs: Sequence[Dict[str, torch.Tensor]], ) -> Tuple[torch.Tensor, ...]: assert ( - len(inference_infos) == len(optional_prompts) == len(backend_kwargs) - ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(backend_kwargs)} kwargs" - for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, backend_kwargs): + len(inference_infos) == len(optional_prompts) == len(block_kwargs) + ), f"mismatch: got {len(inference_infos)} infos, {len(optional_prompts)} prompts, {len(block_kwargs)} kwargs" + for inference_info, optional_prompt, kwargs in zip(inference_infos, optional_prompts, block_kwargs): if optional_prompt is not None: hidden_states[:, : optional_prompt.shape[1]] += optional_prompt (hidden_states,) = self.backends[inference_info.uid].inference_step( diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index f27e276f3..d4898c98a 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -52,7 +52,7 @@ async def run_rpc_forward( """ requires_grad = any(tensor.requires_grad for tensor in flat_tensors) flat_tensors = tuple(tensor.detach() for tensor in flat_tensors) - (hidden_states, prompts), backend_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) + (hidden_states, prompts), block_kwargs = _check_inputs(requested_backends, flat_tensors, args_structure) dtype = requested_backends[0].dtype # check parse input tensors and cast dtypes hidden_states = hidden_states.to(dtype) @@ -64,7 +64,7 @@ async def run_rpc_forward( prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)] # Run a chain of requested backends - for backend, prompt, kwargs in zip(requested_backends, prompts, backend_kwargs): + for backend, prompt, kwargs in zip(requested_backends, prompts, block_kwargs): if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt @@ -97,7 
+97,7 @@ async def run_rpc_backward( ) -> Tuple[Sequence[torch.Tensor], Any]: """A custom backward pass used by the server to service rpc_backward and rpc_backward_stream requests""" assert any(x.requires_grad for x in flat_tensors), "cannot backward: none of the input tensors requires_grad" - ((grad_outputs,), hidden_states, prompts), backend_kwargs = _check_inputs( + ((grad_outputs,), hidden_states, prompts), block_kwargs = _check_inputs( requested_backends, flat_tensors, args_structure ) input_requires_grad, prompts_requires_grad = hidden_states.requires_grad, prompts.requires_grad @@ -115,7 +115,7 @@ async def run_rpc_backward( # Run a forward chain to collect intermediate inputs # Note that we do not forward for the last module since we do not need its output inter_inputs = [] - for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], backend_kwargs): + for backend, prompt, kwargs in zip(requested_backends[:-1], prompts[:-1], block_kwargs): assert hidden_states.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states" if not is_dummy(prompt): hidden_states[:, : prompt.shape[1]] += prompt @@ -135,11 +135,11 @@ async def run_rpc_backward( assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward" grad_prompts_reversed = [] - grad_backend_kwargs_reversed = [] + grad_block_kwargs_reversed = [] # Run a chain of requested backends for hidden_states, prompt, backend, kwargs in reversed( - list(zip(inter_inputs, prompts, requested_backends, backend_kwargs)) + list(zip(inter_inputs, prompts, requested_backends, block_kwargs)) ): assert isinstance(backend.inference_pool, PrioritizedTaskPool), "petals support only prioritized pools" hidden_states = hidden_states.detach().requires_grad_(True) @@ -152,11 +152,11 @@ async def run_rpc_backward( assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt) and prompts_requires_grad: grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0)) - grad_backend_kwargs_reversed.append(grad_kwargs) + grad_block_kwargs_reversed.append(grad_kwargs) grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY grad_args = [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts] - return pack_args_kwargs((grad_args, list(reversed(grad_backend_kwargs_reversed)))) + return pack_args_kwargs((grad_args, list(reversed(grad_block_kwargs_reversed)))) async def iterate_rpc_inference( @@ -179,7 +179,7 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors) - (hidden_states, prompts, hypo_ids), backend_kwargs = _check_inputs( + (hidden_states, prompts, hypo_ids), block_kwargs = _check_inputs( requested_backends, flat_tensors, args_structure ) batch_size, length_increment, _ = hidden_states.shape @@ -230,13 +230,13 @@ async def iterate_rpc_inference( hypo_ids, inference_infos, *prompts, - backend_kwargs=backend_kwargs, + block_kwargs=block_kwargs, priority=priority, size=num_tokens, ) else: for backend, uid, handles, prompt, kwargs in zip( - requested_backends, requested_uids, cache_handles, prompts, backend_kwargs + requested_backends, requested_uids, cache_handles, prompts, block_kwargs ): inference_infos = (InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter),) (hidden_states,) = await backend.inference_pool.submit_task( @@ -244,7 +244,7 @@ async def 
iterate_rpc_inference( hypo_ids, inference_infos, prompt, - backend_kwargs=(kwargs,), + block_kwargs=(kwargs,), priority=priority, size=num_tokens, ) @@ -269,19 +269,19 @@ def _check_inputs( hidden_states, grad_outputs, prompts = flat_tensors flat_tensors = grad_outputs, hidden_states, prompts if args_structure is not None: - args, *backend_kwargs = unpack_args_kwargs(flat_tensors, args_structure) + args, *block_kwargs = unpack_args_kwargs(flat_tensors, args_structure) else: - args, *backend_kwargs = flat_tensors, {} # backward compatibility for grad structure, remove at 2.2 + args, *block_kwargs = flat_tensors, {} # backward compatibility for grad structure, remove at 2.2 - if len(backend_kwargs) not in (1, len(requested_backends)): + if len(block_kwargs) not in (1, len(requested_backends)): raise RuntimeError( f"Server expected either one dict of keyword arguments or {len(requested_backends)} dicts " - f"(one for each block). Found {len(backend_kwargs)} instead." + f"(one for each block). Found {len(block_kwargs)} instead." ) - if len(backend_kwargs) == 1: - backend_kwargs = backend_kwargs * len(requested_backends) - assert len(backend_kwargs) == len(requested_backends) - for i, kwargs in enumerate(backend_kwargs): + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * len(requested_backends) + assert len(block_kwargs) == len(requested_backends) + for i, kwargs in enumerate(block_kwargs): if not isinstance(kwargs, dict): raise RuntimeError(f"Expected kwargs for block {i} to be a dictionary, got {type(kwargs)}") - return args, backend_kwargs + return args, block_kwargs From 3bffcde0feea99c45898f1612a0afebfcf7d7626 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 03:37:59 +0300 Subject: [PATCH 31/35] black+isort --- src/petals/client/inference_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 383e1ac01..b1e36998c 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -255,7 +255,7 @@ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_Se self._sequence_manager.state.p2p, span, span_uids, - rpc_info=self._sequence_manager.rpc_info, #TODO not actually needed + rpc_info=self._sequence_manager.rpc_info, # TODO not actually needed max_length=self._max_length, **metadata, ) From 721f7d2db34f4ed5772a30bda7f3b12ce6755d48 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 03:52:40 +0300 Subject: [PATCH 32/35] unbreak everything --- src/petals/client/inference_session.py | 4 +--- src/petals/client/remote_sequential.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index b1e36998c..c625b08be 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -43,7 +43,7 @@ def __init__( **metadata, ): self.config = config - self.span, self.span_uids = span, span_uids + self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info self.num_blocks = len(span_uids) self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter @@ -283,7 +283,6 @@ def step( inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *block_kwargs: Sequence[Dict[str, torch.Tensor]], - **kwargs, ) -> torch.Tensor: assert not self._closed if torch.is_grad_enabled(): @@ -328,7 
+327,6 @@ def step( prompts[server_session.span.start : server_session.span.end], *block_kwargs[server_session.span.start : server_session.span.end], step_id=step_id, - **kwargs, ) server_idx += 1 diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index a6359bcfb..4d43f310a 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -53,7 +53,7 @@ def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]" if self.active_session is None: assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}" - return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args, **kwargs) + return _RemoteSequentialAutogradFunction.apply(self.sequence_manager, inputs, prompts, *args) else: return self.active_session.step(inputs, prompts, *args, **kwargs) From 3048c3b3ad823dcdc51e500402361280129bff7b Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 04:29:28 +0300 Subject: [PATCH 33/35] rollback --- src/petals/client/inference_session.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index c625b08be..c1f4e2f38 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -87,8 +87,8 @@ def step( self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, - hypo_ids: Optional[torch.Tensor] = None, *block_kwargs: Dict[str, Any], + hypo_ids: Optional[torch.Tensor] = None, step_id: str, ) -> torch.Tensor: """ @@ -283,6 +283,7 @@ def step( inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, *block_kwargs: Sequence[Dict[str, torch.Tensor]], + hypo_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert not self._closed if torch.is_grad_enabled(): @@ -327,6 +328,7 @@ def step( prompts[server_session.span.start : server_session.span.end], *block_kwargs[server_session.span.start : server_session.span.end], step_id=step_id, + hypo_ids=hypo_ids, ) server_idx += 1 From 3f06b53b1dd8b990369624805997cd7893afac1d Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 05:53:07 +0300 Subject: [PATCH 34/35] temporary rollback: allow kwargs only at first inference step --- src/petals/client/inference_session.py | 133 +++++++++++++------------ 1 file changed, 69 insertions(+), 64 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index c1f4e2f38..6d173886b 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -7,15 +7,13 @@ from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple import torch -from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor +from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker -from hivemind.p2p import P2P from hivemind.proto import runtime_pb2 -from hivemind.utils.tensor_descr import BatchTensorDescriptor +from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten -from petals.client.config import ClientConfig from petals.client.routing import RemoteSequenceManager, maybe_log_traceback -from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, 
RPCInfo +from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo from petals.server.handler import TransformerConnectionHandler from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy from petals.utils.packaging import pack_args_kwargs @@ -32,23 +30,21 @@ class _ServerInferenceSession: def __init__( self, - config: ClientConfig, + sequence_manager: RemoteSequenceManager, span: RemoteSpanInfo, span_uids: Sequence[ModuleUID], - rpc_info: RPCInfo, inputs_queue: asyncio.Queue, - outputs_aiter: AsyncIterator, - *, + outputs_stream: AsyncIterator, + *block_kwargs, max_length: int, - **metadata, ): - self.config = config - self.span, self.span_uids, self.rpc_info = span, span_uids, rpc_info + self.sequence_manager = sequence_manager + self.span, self.span_uids = span, span_uids self.num_blocks = len(span_uids) self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue - self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter + self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream self.session_id = str(uuid.uuid4()) - self.session_metadata = dict(max_length=max_length, **metadata) + self.max_length = max_length self.stepped = False self.closed = False @@ -56,24 +52,26 @@ def __init__( self.history = None # Used in case of server failures to regenerate attention caches on new servers self.next_session = None + self.block_kwargs = block_kwargs + assert len(self.block_kwargs) in (0, self.num_blocks) + @classmethod async def create( cls, - config: ClientConfig, - p2p: P2P, + sequence_manager: RemoteSequenceManager, span: RemoteSpanInfo, - span_uids: Sequence[RemoteSpanInfo], - rpc_info: RPCInfo, - **metadata, + span_uids: Sequence[ModuleUID], + *block_kwargs: Dict[str, Any], + **kwargs, ) -> _ServerInferenceSession: """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker""" - stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id) + stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id) inputs_queue = asyncio.Queue() outputs_stream = await asyncio.wait_for( stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)), - config.connect_timeout, + sequence_manager.config.connect_timeout, ) - return cls(config, span, span_uids, rpc_info, inputs_queue, outputs_stream, **metadata) + return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs) @staticmethod async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator: @@ -87,7 +85,7 @@ def step( self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, - *block_kwargs: Dict[str, Any], + *, hypo_ids: Optional[torch.Tensor] = None, step_id: str, ) -> torch.Tensor: @@ -96,7 +94,6 @@ def step( :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs, if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size] """ - # TODO record previous kwargs in case of server failure!!! 
if self.closed: raise Exception("Session is closed, cannot perform step") @@ -112,10 +109,11 @@ def step( if not self.stepped: inputs = self.history # Pass full inputs including prefix + block_kwargs = self.block_kwargs else: inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further + block_kwargs = [] - assert len(block_kwargs) in (0, self.span.length) if prompts is None or is_dummy(prompts): prompts = DUMMY else: @@ -131,39 +129,50 @@ def step( assert len(hypo_ids) == len(inputs) assert hypo_ids.dtype == torch.int64 - # serialize inputs and put them into the queue - input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs) - - request_metadata = dict(session_id=self.session_id, step_id=step_id) - if not self.stepped: - request_metadata.update(self.session_metadata) - elif self.config.use_server_to_server: + metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length) + metadata.update( + self.sequence_manager.get_request_metadata( + self.span.peer_id, + "rpc_inference", + self.span_uids, + inputs, + prompts, + *block_kwargs, + max_length=self.max_length, + session_id=self.session_id, + step_id=step_id, + ) + ) + if self.stepped and self.sequence_manager.config.use_server_to_server: next_servers = self._collect_next_servers() if next_servers: - request_metadata["next_servers"] = next_servers + metadata["next_servers"] = next_servers - args_structure = request_metadata.setdefault("args_structure", args_structure) + codecs = self.sequence_manager.get_compression_codecs( + self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs + ) - # TODO YOZH FIX THIS BEFORE THE END OF THIS PR - # TODO: make possible to use different compression method for different tensors - server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"] - compression = server_side_inference_schema[0].compression - inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors) + # serialize inputs and put them into the queue + input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs) + args_structure = metadata.setdefault("args_structure", args_structure) - # TODO: create more explicit way to check servers schema and client's structure - assert len(input_tensors) >= len( - server_side_inference_schema - ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step" + if codecs is None: + codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors) + else: + codecs = list(nested_flatten(codecs)) + assert len(codecs) == len( + input_tensors + ), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs" outputs_serialized = RemoteExpertWorker.run_coroutine( self._step( runtime_pb2.ExpertRequest( uid=CHAIN_DELIMITER.join(self.span_uids), tensors=[ - serialize_torch_tensor(tensor.to(proto.dtype), proto.compression) - for tensor, proto in zip(input_tensors, inference_schema) + serialize_torch_tensor(tensor, compression) + for tensor, compression in zip(input_tensors, codecs) ], - metadata=MSGPackSerializer.dumps(request_metadata), + metadata=MSGPackSerializer.dumps(metadata), ) ) ) @@ -190,7 +199,7 @@ async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_p """Inference step on serialized data. 
This code is meant to be run inside RemoteExpertWorker""" await self._inputs_queue.put(inputs_serialized) self.stepped = True - return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout) + return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout) def close(self): """Finish a given inference session, close the underlying connection""" @@ -227,7 +236,7 @@ class InferenceSession: An interface to a multi-step *inference* session for a sequence of remote transformer blocks """ - def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): + def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]): self._sequence_manager = sequence_manager self._closed = False self._server_sessions = [] @@ -235,6 +244,12 @@ def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int): self._max_length = max_length self.output_ids = None + num_blocks = len(self._sequence_manager) + if len(block_kwargs) == 1: + block_kwargs = block_kwargs * num_blocks + assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}" + self.block_kwargs = block_kwargs + @property def num_blocks(self) -> int: return len(self._sequence_manager) @@ -247,17 +262,13 @@ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_Se server_sessions = [] try: for span in chosen_spans: - span_uids = self._sequence_manager.block_uids[span.start : span.end] - metadata = self._sequence_manager.get_request_metadata(span.peer_id, "rpc_inference", span_uids) session = RemoteExpertWorker.run_coroutine( _ServerInferenceSession.create( - self._sequence_manager.config, - self._sequence_manager.state.p2p, + self._sequence_manager, span, - span_uids, - rpc_info=self._sequence_manager.rpc_info, # TODO not actually needed + self._sequence_manager.block_uids[span.start : span.end], + *self.block_kwargs[span.start : span.end], max_length=self._max_length, - **metadata, ) ) server_sessions.append(session) @@ -282,18 +293,13 @@ def step( self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, - *block_kwargs: Sequence[Dict[str, torch.Tensor]], hypo_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + assert not self._closed if torch.is_grad_enabled(): logger.warning("Running inference session with grad enabled. 
Gradients will *not* be propagated correctly.") - num_blocks = len(self._sequence_manager) - if len(block_kwargs) == 1: - block_kwargs = block_kwargs * num_blocks - assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}" - if prompts is None or is_dummy(prompts): prompts = DUMMY else: @@ -326,9 +332,8 @@ def step( inputs = server_session.step( inputs, prompts[server_session.span.start : server_session.span.end], - *block_kwargs[server_session.span.start : server_session.span.end], - step_id=step_id, hypo_ids=hypo_ids, + step_id=step_id, ) server_idx += 1 @@ -354,7 +359,7 @@ def step( outputs = outputs.to(device=inputs_device, dtype=inputs_dtype) return outputs - def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int: + def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int): # If there is a failed server session, this code closes it self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1]) From c665c42cf2049990d7d9783466fb58ff4eab3b99 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 6 Sep 2023 06:24:59 +0300 Subject: [PATCH 35/35] reduce diff --- src/petals/client/inference_session.py | 4 ++-- src/petals/client/sequential_autograd.py | 1 - tests/test_priority_pool.py | 9 +++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index 6d173886b..f6195d84a 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -34,7 +34,7 @@ def __init__( span: RemoteSpanInfo, span_uids: Sequence[ModuleUID], inputs_queue: asyncio.Queue, - outputs_stream: AsyncIterator, + outputs_aiter: AsyncIterator, *block_kwargs, max_length: int, ): @@ -42,7 +42,7 @@ def __init__( self.span, self.span_uids = span, span_uids self.num_blocks = len(span_uids) self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue - self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_stream + self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter self.session_id = str(uuid.uuid4()) self.max_length = max_length self.stepped = False diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py index 174849038..6d450e370 100644 --- a/src/petals/client/sequential_autograd.py +++ b/src/petals/client/sequential_autograd.py @@ -233,7 +233,6 @@ def forward(ctx, sequence_manager: RemoteSequenceManager, inputs: torch.Tensor, prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1) prompt_batches = tuple(batch.requires_grad_(prompts.requires_grad) for batch in prompt_batches) - sequence_manager.rpc_info # lazy init #TODO no longer needed outputs = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager)) assert len(outputs) == len(input_batches) diff --git a/tests/test_priority_pool.py b/tests/test_priority_pool.py index 8bf673cf4..15c6de5f4 100644 --- a/tests/test_priority_pool.py +++ b/tests/test_priority_pool.py @@ -4,8 +4,8 @@ import pytest import torch -from hivemind.moe.server.runtime import Runtime +from petals.server.server import RuntimeWithDeduplicatedPools from petals.server.task_pool import PrioritizedTaskPool @@ -35,8 +35,7 @@ def test_priority_pools(): runtime_ready = mp.Event() results_valid = mp.Event() - def dummy_pool_func(args, kwargs): - (x,) = args # TODO modify the PriorityPool code such that 
dummy_pool_func can accept x directly + def dummy_pool_func(x): time.sleep(0.1) y = x**2 outputs_queue.put((x, y)) @@ -58,7 +57,9 @@ def get_pools(self): proc = mp.context.ForkProcess(target=_submit_tasks, args=(runtime_ready, pools, results_valid)) proc.start() - runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0) + runtime = RuntimeWithDeduplicatedPools( + {str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0 + ) runtime.ready = runtime_ready runtime.start()
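
The later patches in this series converge on a single rule for per-block keyword arguments on the
client side: a caller passes either no kwargs, one dict that is shared by every block, or exactly
one dict per block, and sequential_forward, InferenceSession and the server-side _check_inputs all
enforce that invariant. The sketch below restates the broadcasting rule in isolation; it is an
illustrative summary, and the helper name broadcast_block_kwargs as well as the attention_mask key
are assumptions that do not appear in the patches themselves.

    from typing import Any, Dict, Sequence, Tuple

    def broadcast_block_kwargs(block_kwargs: Sequence[Dict[str, Any]], num_blocks: int) -> Tuple[Dict[str, Any], ...]:
        # A single dict of kwargs is shared by every block; otherwise the caller
        # passes either nothing or exactly one dict per block.
        if len(block_kwargs) == 1:
            block_kwargs = tuple(block_kwargs) * num_blocks
        assert (
            not block_kwargs or len(block_kwargs) == num_blocks
        ), f"got {num_blocks} blocks but {len(block_kwargs)} sets of kwargs"
        return tuple(block_kwargs)

    # one shared dict is broadcast; zero dicts and one-dict-per-block pass through unchanged
    assert broadcast_block_kwargs(({"attention_mask": None},), num_blocks=3) == ({"attention_mask": None},) * 3
    assert broadcast_block_kwargs((), num_blocks=3) == ()
    assert broadcast_block_kwargs(({"a": 1}, {"a": 2}), num_blocks=2) == ({"a": 1}, {"a": 2})

On the server, _check_inputs applies the same broadcast but requires either one dict or one dict
per requested block; a legacy request without an args_structure falls back to a single empty dict
shared by all blocks, a backward-compatibility path the code comments mark for removal at 2.2.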