diff --git a/skyplane/compute/server.py b/skyplane/compute/server.py index 585773105..8529ec308 100644 --- a/skyplane/compute/server.py +++ b/skyplane/compute/server.py @@ -378,7 +378,7 @@ def check_stderr(tup): logger.fs.debug(f"{self.uuid()} gateway_api_url = {self.gateway_api_url}") # wait for gateways to start (check status API) - http_pool = urllib3.PoolManager() + http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=10)) def is_api_ready(): try: diff --git a/skyplane/gateway/operators/gateway_operator.py b/skyplane/gateway/operators/gateway_operator.py index c574caebc..68dcc636b 100644 --- a/skyplane/gateway/operators/gateway_operator.py +++ b/skyplane/gateway/operators/gateway_operator.py @@ -261,6 +261,23 @@ def make_socket(self, dst_host): sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return sock + def send_data(self, dst_host, header, data): + # contact server to set up socket connection + if self.destination_ports.get(dst_host) is None: + self.destination_sockets[dst_host] = self.make_socket(dst_host) + sock = self.destination_sockets[dst_host] + + try: + header.to_socket(sock) + sock.sendall(data) + except socket.error as e: + print(e) + del self.destination_ports[dst_host] + return False + # if successful, return True + return True + + # send chunks to other instances def process(self, chunk_req: ChunkRequest, dst_host: str): """Send list of chunks to gateway server, pipelining small chunks together into a single socket stream.""" @@ -307,15 +324,6 @@ def process(self, chunk_req: ChunkRequest, dst_host: str): print(f"[{self.handle}:{self.worker_id}] Error registering chunks {chunk_ids} to {dst_host}: {e}") raise e - # contact server to set up socket connection - if self.destination_ports.get(dst_host) is None: - print(f"[sender-{self.worker_id}]:{chunk_ids} creating new socket") - self.destination_sockets[dst_host] = retry_backoff( - partial(self.make_socket, dst_host), max_retries=3, exception_class=socket.timeout - ) - print(f"[sender-{self.worker_id}]:{chunk_ids} created new socket") - sock = self.destination_sockets[dst_host] - # TODO: cleanup so this isn't a loop for idx, chunk_req in enumerate(chunk_reqs): # self.chunk_store.state_start_upload(chunk_id, f"sender:{self.worker_id}") @@ -347,16 +355,13 @@ def process(self, chunk_req: ChunkRequest, dst_host: str): raw_wire_length=raw_wire_length, is_compressed=(compressed_length is not None), ) - # print(f"[sender-{self.worker_id}]:{chunk_id} sending chunk header {header}") - header.to_socket(sock) - # print(f"[sender-{self.worker_id}]:{chunk_id} sent chunk header") - # send chunk data assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist" - # file_size = os.path.getsize(chunk_file_path) - - with Timer() as t: - sock.sendall(data) + + while True: + with Timer() as t: + is_suc = self.send_data(dst_host=dst_host, header=header, data=data) + if is_suc: break # logger.debug(f"[sender:{self.worker_id}]:{chunk_id} sent at {chunk.chunk_length_bytes * 8 / t.elapsed / MB:.2f}Mbps") print(f"[sender:{self.worker_id}]:{chunk_id} sent at {wire_length * 8 / t.elapsed / MB:.2f}Mbps") diff --git a/skyplane/gateway/operators/gateway_receiver.py b/skyplane/gateway/operators/gateway_receiver.py index aeed0e341..03d19d700 100644 --- a/skyplane/gateway/operators/gateway_receiver.py +++ b/skyplane/gateway/operators/gateway_receiver.py @@ -145,31 +145,31 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]): init_space = self.chunk_store.remaining_bytes() print("Init space", init_space) while True: - # receive header and write data to file - logger.debug(f"[receiver:{server_port}] Blocking for next header") - chunk_header = WireProtocolHeader.from_socket(conn) - logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} Got chunk header {chunk_header}") - - # TODO: this wont work - # chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id) - - should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region - should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region - - # wait for space - # while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks: - # print( - # f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}" - # ) - # time.sleep(0.1) - - # get data - # self.chunk_store.state_queue_download(chunk_header.chunk_id) - # self.chunk_store.state_start_download(chunk_header.chunk_id, f"receiver:{self.worker_id}") - logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} wire header length {chunk_header.data_len}") - with Timer() as t: - fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id) - with fpath.open("wb") as f: + try: + # receive header and write data to file + logger.debug(f"[receiver:{server_port}] Blocking for next header") + chunk_header = WireProtocolHeader.from_socket(conn) + logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} Got chunk header {chunk_header}") + + # TODO: this wont work + # chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id) + + should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region + should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region + + # wait for space + # while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks: + # print( + # f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}" + # ) + # time.sleep(0.1) + + # get data + # self.chunk_store.state_queue_download(chunk_header.chunk_id) + # self.chunk_store.state_start_download(chunk_header.chunk_id, f"receiver:{self.worker_id}") + logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} wire header length {chunk_header.data_len}") + with Timer() as t: + socket_data_len = chunk_header.data_len chunk_received_size, chunk_received_size_decompressed = 0, 0 to_write = bytearray(socket_data_len) @@ -199,29 +199,35 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]): print( f"[receiver:{server_port}]:{chunk_header.chunk_id} Decompressing {len(to_write)} bytes to {chunk_received_size_decompressed} bytes" ) - - # try to write data until successful - while True: - try: - f.seek(0, 0) - f.write(to_write) - f.flush() - - # check write succeeds - assert os.path.exists(fpath) - - # check size - file_size = os.path.getsize(fpath) - if file_size == chunk_header.raw_data_len: - break - elif file_size >= chunk_header.raw_data_len: - raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}") - except Exception as e: - print(e) - print( - f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}" - ) - time.sleep(1) + except socket.error as e: + print(e) + # This may have pipeline broken error, if happened then restart receiver. + continue + + fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id) + with fpath.open("wb") as f: + # try to write data until successful + while True: + try: + f.seek(0, 0) + f.write(to_write) + f.flush() + + # check write succeeds + assert os.path.exists(fpath) + + # check size + file_size = os.path.getsize(fpath) + if file_size == chunk_header.raw_data_len: + break + elif file_size >= chunk_header.raw_data_len: + raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}") + except Exception as e: + print(e) + print( + f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}" + ) + time.sleep(1) assert ( socket_data_len == 0 and chunk_received_size == chunk_header.data_len ), f"Size mismatch: got {chunk_received_size} expected {chunk_header.data_len} and had {socket_data_len} bytes remaining"