Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 85 additions & 47 deletions pyzoo/zoo/orca/learn/mpi/mpi_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,13 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, validate_batc
assert feature_cols is not None and label_cols is not None, \
"feature_cols and label_cols must be provided if data is a Spark DataFrame"
data = data.rdd.map(convert_row(feature_cols, label_cols))
# TODO: make object store memory configurable?
object_store_address = self.mpi_runner.launch_plasma(object_store_memory="100g")
# partition_id, subpartition_id, subpartition_size, object_id, node_ip
plasma_meta = data.mapPartitionsWithIndex(
put_to_plasma(object_store_address)).collect()
# partition_id, subpartition_id, subpartition_size, file_name, node_ip
train_save_path = "/mnt/disk1/dlrm/saved/train/"
file_meta = data.mapPartitionsWithIndex(
save_file(train_save_path, feature_cols, label_cols)).collect()
# The following is mainly for debugging and confirmation purpose.
train_size_map = {}
for partition_id, subpartition_id, subpartition_size, object_id, ip in plasma_meta:
for partition_id, subpartition_id, subpartition_size, object_id, ip in file_meta:
if ip not in train_size_map:
train_size_map[ip] = {}
if partition_id not in train_size_map[ip]:
Expand All @@ -83,17 +82,18 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, validate_batc
print("Node {} has {} subpartitions and {} train records".format(node, count, size))
size = 0
count = 0
data_creator = plasma_data_creator(plasma_meta, object_store_address,
self.mpi_runner.processes_per_node, batch_size)
data_creator = file_data_creator(file_meta, train_save_path, feature_cols, label_cols,
self.mpi_runner.processes_per_node, batch_size)
data.unpersist()
if validation_data:
assert isinstance(validation_data, DataFrame)
validation_data = validation_data.rdd.map(convert_row(feature_cols, label_cols))
validate_plasma_meta = validation_data.mapPartitionsWithIndex(
put_to_plasma(object_store_address)).collect()
validate_save_path = "/mnt/disk1/dlrm/saved/test/"
validate_file_meta = validation_data.mapPartitionsWithIndex(
save_file(validate_save_path, feature_cols, label_cols)).collect()
validate_size_map = {}
for partition_id, subpartition_id, subpartition_size, object_id, ip \
in validate_plasma_meta:
in validate_file_meta:
if ip not in validate_size_map:
validate_size_map[ip] = {}
if partition_id not in validate_size_map[ip]:
Expand All @@ -109,8 +109,8 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, validate_batc
.format(node, count, size))
size = 0
count = 0
validation_data_creator = plasma_data_creator(
validate_plasma_meta, object_store_address,
validation_data_creator = file_data_creator(
validate_file_meta, validate_save_path, feature_cols, label_cols,
self.mpi_runner.processes_per_node, validate_batch_size)
validation_data.unpersist()
else:
Expand All @@ -135,10 +135,9 @@ def fit(self, data, epochs=1, batch_size=32, validation_data=None, validate_batc
validate_batches, validate_steps), f)
self.mpi_runner.scp_file("mpi_train_data.pkl", self.dir)
self.mpi_runner.run("{}/mpi_train.py".format(self.dir), pkl_path=self.dir)
self.mpi_runner.shutdown_plasma()

def shutdown(self):
    """Shut down the estimator.

    Intentionally a no-op: cleanup of temporary training resources now
    happens at the end of fit(), so nothing remains to release here.
    The method is kept so callers that pair fit() with shutdown() keep
    working.
    """
    pass


def convert_row(feature_cols, label_cols):
Expand All @@ -160,42 +159,68 @@ def transform(row):
return transform


def save_file(folder_path, feature_cols, label_cols):
    """Return a function for RDD.mapPartitionsWithIndex that spills each
    partition's records to compressed ``.npz`` shard files under
    ``folder_path``.

    Records are grouped into sub-partitions of at most ``part_size`` rows.
    For every shard written, the returned function yields a metadata tuple
    ``(partition_id, subpartition_id, subpartition_size, file_name, node_ip)``
    which the driver collects so workers later know which shards live on
    which node.

    :param folder_path: Destination directory; ``file_name`` is appended
        directly, so it must end with a path separator.
    :param feature_cols: Feature column names; each feature array is stored
        in the ``.npz`` under its column name.
    :param label_cols: Label column names; each label array is stored in the
        ``.npz`` under its column name.
    """

    def save_subpartition(index, sub_index, records):
        # Convert raw records and write them as one .npz shard; returns the
        # shard's file name. Shared by the full-buffer and tail flush paths.
        res_buffer = process_records(records)
        saved_tmp = {}
        # process_records packs multiple columns as a sequence under
        # "x"/"y" and a single column as a bare array — presumably; the
        # helper is defined elsewhere in this file (TODO confirm).
        if len(feature_cols) > 1:
            for i in range(len(feature_cols)):
                saved_tmp[feature_cols[i]] = res_buffer["x"][i]
        else:
            saved_tmp[feature_cols[0]] = res_buffer["x"]
        if len(label_cols) > 1:
            for i in range(len(label_cols)):
                saved_tmp[label_cols[i]] = res_buffer["y"][i]
        else:
            saved_tmp[label_cols[0]] = res_buffer["y"]
        size = len(records)
        file_name = "partition{}_{}_{}.npz".format(index, sub_index, size)
        np.savez_compressed(
            folder_path + file_name,
            size=size,
            **saved_tmp
        )
        return file_name

    def f(index, iterator):
        part_size = 1000000  # TODO: Make subpartition size configurable?
        buffer = []
        sub_index = 0
        for record in iterator:
            if len(buffer) == part_size:
                # Buffer is full: flush it as one shard and start the next
                # sub-partition with the current record.
                file_name = save_subpartition(index, sub_index, buffer)
                buffer = [record]
                yield index, sub_index, part_size, file_name, get_node_ip()
                sub_index += 1
            else:
                buffer.append(record)
        remain_size = len(buffer)
        if remain_size > 0:
            # Flush the trailing, possibly smaller, sub-partition.
            file_name = save_subpartition(index, sub_index, buffer)
            buffer = []
            yield index, sub_index, remain_size, file_name, get_node_ip()

    return f


class PlasmaNDArrayDataset(Dataset):
def __init__(self, meta_data, object_store_address, workers_per_node=1, batch_size=1):
import pyarrow.plasma as plasma
self.client = plasma.connect(object_store_address)
print("Connected to plasma")

class FileNDArrayDataset(Dataset):
def __init__(self, meta_data, file_path, feature_cols, label_cols, workers_per_node=1, batch_size=1):
self.file_path = file_path
self.feature_cols = feature_cols
self.label_cols = label_cols
# All the subpartitions on this node
all_data = [subpartition for subpartition in meta_data if subpartition[4] == get_node_ip()]
rank = int(os.environ.get("PMI_RANK", 0))
Expand All @@ -209,7 +234,7 @@ def __init__(self, meta_data, object_store_address, workers_per_node=1, batch_si
remain_data = data_splits[-1]
if local_rank < len(remain_data):
worker_data += [remain_data[local_rank]]
self.object_ids = [subpartition[3] for subpartition in worker_data]
self.files = [subpartition[3] for subpartition in worker_data]
self.sizes = [subpartition[2] for subpartition in worker_data]
print("Data size for worker: ", sum(self.sizes))
self.batch_size = batch_size
Expand All @@ -221,17 +246,30 @@ def __init__(self, meta_data, object_store_address, workers_per_node=1, batch_si
offsets.append(offsets[-1] + i)
self.offsets = offsets
self.current_index = 0 # Current index for object_id; data loaded
self.load_from_plasma(self.current_index)
self.load_file(self.current_index)

def reset(self):
    """Rewind the dataset to its first shard so iteration can restart.

    Resets the shard cursor and eagerly loads shard 0 back into memory,
    mirroring what __init__ does after computing the shard assignment.
    """
    self.current_index = 0
    self.load_file(self.current_index)

def load_file(self, index):
    """Load the ``index``-th assigned .npz shard into memory.

    Populates ``self.current_x`` / ``self.current_y`` with the feature and
    label arrays stored under the configured column names, and sets
    ``self.current_offset`` to the shard's starting record offset.
    When there are multiple feature (or label) columns the arrays are
    collected into a list, matching how save_file wrote them; a single
    column is kept as a bare array.

    :param index: Position in ``self.files`` / ``self.sizes`` of the shard
        to load.
    """
    file = self.files[index]
    size = self.sizes[index]
    print("Loading {} of size {}".format(file, size))
    with np.load(self.file_path + file) as data:
        # Sanity-check the record count recorded at save time.
        assert data["size"] == size
        if len(self.feature_cols) > 1:
            self.current_x = []
            for i in range(len(self.feature_cols)):
                self.current_x.append(data[self.feature_cols[i]])
        else:
            self.current_x = data[self.feature_cols[0]]
        if len(self.label_cols) > 1:
            self.current_y = []
            for i in range(len(self.label_cols)):
                self.current_y.append(data[self.label_cols[i]])
        else:
            self.current_y = data[self.label_cols[0]]
    self.current_offset = self.offsets[index]

def __len__(self):
Expand All @@ -253,7 +291,7 @@ def __getitem__(self, i): # Directly get a batch
remain_size = self.batch_size - current_available_size
while True:
self.current_index += 1
self.load_from_plasma(self.current_index)
self.load_file(self.current_index)
if self.sizes[self.current_index] >= remain_size:
x_list.append(index(self.current_x, end=remain_size))
y_list.append(index(self.current_y, end=remain_size))
Expand Down Expand Up @@ -285,12 +323,12 @@ def __getitem__(self, i): # Directly get a batch
return x_np, y_np


def plasma_data_creator(meta_data, object_store_address,
workers_per_node=1, batch_size=1):
def file_data_creator(meta_data, file_path, feature_cols, label_cols,
workers_per_node=1, batch_size=1):

def create_plasma_dataloader(config):
dataset = PlasmaNDArrayDataset(meta_data, object_store_address,
workers_per_node, batch_size)
def create_file_dataloader(config):
dataset = FileNDArrayDataset(meta_data, file_path, feature_cols, label_cols,
workers_per_node, batch_size)
# TODO: support more options
loader = DataLoader(
dataset,
Expand All @@ -300,7 +338,7 @@ def create_plasma_dataloader(config):
)
return loader

return create_plasma_dataloader
return create_file_dataloader


def train_epoch(config, model, train_ld, train_batches, optimizer, loss, scheduler,
Expand Down