Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] #907 #908 #911

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skyplane/compute/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def check_stderr(tup):
logger.fs.debug(f"{self.uuid()} gateway_api_url = {self.gateway_api_url}")

# wait for gateways to start (check status API)
http_pool = urllib3.PoolManager()
http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=10))

def is_api_ready():
try:
Expand Down
39 changes: 22 additions & 17 deletions skyplane/gateway/operators/gateway_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,23 @@ def make_socket(self, dst_host):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return sock

def send_data(self, dst_host, header, data):
    """Send one wire header followed by its payload to ``dst_host``.

    A socket is created through ``self.make_socket`` whenever no port is
    recorded for ``dst_host`` in ``self.destination_ports``; otherwise the
    socket cached in ``self.destination_sockets`` is reused.

    Args:
        dst_host: Key identifying the destination in the port/socket caches.
        header: Object exposing ``to_socket(sock)``; written before the payload.
        data: Bytes-like payload handed to ``sock.sendall``.

    Returns:
        True if both header and payload were written. False if a socket
        error occurred; the cached port and socket for ``dst_host`` are then
        discarded so the next call sets up a fresh connection.
    """
    # NOTE(review): this presumes make_socket records the port in
    # self.destination_ports as a side effect -- confirm; otherwise every
    # call would open a new socket.
    if self.destination_ports.get(dst_host) is None:
        self.destination_sockets[dst_host] = self.make_socket(dst_host)
    sock = self.destination_sockets[dst_host]

    try:
        header.to_socket(sock)
        sock.sendall(data)
    except socket.error as e:
        print(e)
        # pop() instead of del: a missing entry must not raise KeyError from
        # inside the error handler and mask the original failure.
        self.destination_ports.pop(dst_host, None)
        # Drop and close the broken socket so its descriptor is not leaked
        # and it can never be reused by a later call.
        broken = self.destination_sockets.pop(dst_host, None)
        if broken is not None:
            try:
                broken.close()
            except OSError:
                pass  # best-effort cleanup of an already-failed connection
        return False
    # if successful, return True
    return True


# send chunks to other instances
def process(self, chunk_req: ChunkRequest, dst_host: str):
"""Send list of chunks to gateway server, pipelining small chunks together into a single socket stream."""
Expand Down Expand Up @@ -307,15 +324,6 @@ def process(self, chunk_req: ChunkRequest, dst_host: str):
print(f"[{self.handle}:{self.worker_id}] Error registering chunks {chunk_ids} to {dst_host}: {e}")
raise e

# contact server to set up socket connection
if self.destination_ports.get(dst_host) is None:
print(f"[sender-{self.worker_id}]:{chunk_ids} creating new socket")
self.destination_sockets[dst_host] = retry_backoff(
partial(self.make_socket, dst_host), max_retries=3, exception_class=socket.timeout
)
print(f"[sender-{self.worker_id}]:{chunk_ids} created new socket")
sock = self.destination_sockets[dst_host]

# TODO: cleanup so this isn't a loop
for idx, chunk_req in enumerate(chunk_reqs):
# self.chunk_store.state_start_upload(chunk_id, f"sender:{self.worker_id}")
Expand Down Expand Up @@ -347,16 +355,13 @@ def process(self, chunk_req: ChunkRequest, dst_host: str):
raw_wire_length=raw_wire_length,
is_compressed=(compressed_length is not None),
)
# print(f"[sender-{self.worker_id}]:{chunk_id} sending chunk header {header}")
header.to_socket(sock)
# print(f"[sender-{self.worker_id}]:{chunk_id} sent chunk header")

# send chunk data
assert chunk_file_path.exists(), f"chunk file {chunk_file_path} does not exist"
# file_size = os.path.getsize(chunk_file_path)

with Timer() as t:
sock.sendall(data)

while True:
with Timer() as t:
is_suc = self.send_data(dst_host=dst_host, header=header, data=data)
if is_suc: break

# logger.debug(f"[sender:{self.worker_id}]:{chunk_id} sent at {chunk.chunk_length_bytes * 8 / t.elapsed / MB:.2f}Mbps")
print(f"[sender:{self.worker_id}]:{chunk_id} sent at {wire_length * 8 / t.elapsed / MB:.2f}Mbps")
Expand Down
102 changes: 54 additions & 48 deletions skyplane/gateway/operators/gateway_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,31 +145,31 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
init_space = self.chunk_store.remaining_bytes()
print("Init space", init_space)
while True:
# receive header and write data to file
logger.debug(f"[receiver:{server_port}] Blocking for next header")
chunk_header = WireProtocolHeader.from_socket(conn)
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} Got chunk header {chunk_header}")

# TODO: this won't work
# chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id)

should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region
should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region

# wait for space
# while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks:
# print(
# f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
# )
# time.sleep(0.1)

# get data
# self.chunk_store.state_queue_download(chunk_header.chunk_id)
# self.chunk_store.state_start_download(chunk_header.chunk_id, f"receiver:{self.worker_id}")
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} wire header length {chunk_header.data_len}")
with Timer() as t:
fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
with fpath.open("wb") as f:
try:
# receive header and write data to file
logger.debug(f"[receiver:{server_port}] Blocking for next header")
chunk_header = WireProtocolHeader.from_socket(conn)
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} Got chunk header {chunk_header}")

# TODO: this won't work
# chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id)

should_decrypt = self.e2ee_secretbox is not None # and chunk_request.dst_region == self.region
should_decompress = chunk_header.is_compressed # and chunk_request.dst_region == self.region

# wait for space
# while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks:
# print(
# f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
# )
# time.sleep(0.1)

# get data
# self.chunk_store.state_queue_download(chunk_header.chunk_id)
# self.chunk_store.state_start_download(chunk_header.chunk_id, f"receiver:{self.worker_id}")
logger.debug(f"[receiver:{server_port}]:{chunk_header.chunk_id} wire header length {chunk_header.data_len}")
with Timer() as t:

socket_data_len = chunk_header.data_len
chunk_received_size, chunk_received_size_decompressed = 0, 0
to_write = bytearray(socket_data_len)
Expand Down Expand Up @@ -199,29 +199,35 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
print(
f"[receiver:{server_port}]:{chunk_header.chunk_id} Decompressing {len(to_write)} bytes to {chunk_received_size_decompressed} bytes"
)

# try to write data until successful
while True:
try:
f.seek(0, 0)
f.write(to_write)
f.flush()

# check write succeeds
assert os.path.exists(fpath)

# check size
file_size = os.path.getsize(fpath)
if file_size == chunk_header.raw_data_len:
break
elif file_size >= chunk_header.raw_data_len:
raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}")
except Exception as e:
print(e)
print(
f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
)
time.sleep(1)
except socket.error as e:
print(e)
# A broken-pipe error may occur here; if it does, restart the receiver.
continue

fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
with fpath.open("wb") as f:
# try to write data until successful
while True:
try:
f.seek(0, 0)
f.write(to_write)
f.flush()

# check write succeeds
assert os.path.exists(fpath)

# check size
file_size = os.path.getsize(fpath)
if file_size == chunk_header.raw_data_len:
break
elif file_size >= chunk_header.raw_data_len:
raise ValueError(f"[Gateway] File size {file_size} greater than chunk size {chunk_header.raw_data_len}")
except Exception as e:
print(e)
print(
f"[receiver:{server_port}]: No remaining space with bytes {self.chunk_store.remaining_bytes()} data len {chunk_header.data_len} max pending {self.max_pending_chunks}, total space {init_space}"
)
time.sleep(1)
assert (
socket_data_len == 0 and chunk_received_size == chunk_header.data_len
), f"Size mismatch: got {chunk_received_size} expected {chunk_header.data_len} and had {socket_data_len} bytes remaining"
Expand Down
Loading