[bp][spark] Make xgboost spark support large model size (dmlc#10984)
---------

Signed-off-by: Weichen Xu <[email protected]>
WeichenXu123 authored and trivialfis committed Nov 19, 2024
1 parent f199039 commit 60fe694
Showing 2 changed files with 42 additions and 12 deletions.
34 changes: 22 additions & 12 deletions python-package/xgboost/spark/core.py
@@ -597,6 +597,9 @@ def _get_unwrapped_vec_cols(feature_col: Column) -> List[Column]:
)


_MODEL_CHUNK_SIZE = 4096 * 1024


class _SparkXGBEstimator(Estimator, _SparkXGBParams, MLReadable, MLWritable):
_input_kwargs: Dict[str, Any]
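
For context on the new constant: the trained booster is returned to the driver as string cells of a pandas DataFrame, so the serialized JSON is cut into 4 MiB pieces instead of one huge cell. A minimal, standalone sketch of the split-and-rejoin round trip (illustrative only; the real logic is in `_train_booster` and `_run_job` below):

```python
# Minimal sketch of the chunking round trip (illustrative; not the library code).
_MODEL_CHUNK_SIZE = 4096 * 1024  # 4 MiB per chunk

def split_model(model_json: str, chunk_size: int = _MODEL_CHUNK_SIZE) -> list:
    """Cut a serialized model string into fixed-size chunks."""
    return [model_json[i : i + chunk_size] for i in range(0, len(model_json), chunk_size)]

# Joining the chunks in order restores the original string exactly.
payload = "x" * (10 * 1024 * 1024)  # stand-in for a ~10 MiB serialized model
assert "".join(split_model(payload)) == payload
```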

@@ -1091,25 +1094,27 @@ def _train_booster(
context.barrier()

if context.partitionId() == 0:
yield pd.DataFrame(
data={
"config": [booster.save_config()],
"booster": [booster.save_raw("json").decode("utf-8")],
}
)
config = booster.save_config()
yield pd.DataFrame({"data": [config]})
booster_json = booster.save_raw("json").decode("utf-8")

for offset in range(0, len(booster_json), _MODEL_CHUNK_SIZE):
booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
yield pd.DataFrame({"data": [booster_chunk]})

def _run_job() -> Tuple[str, str]:
rdd = (
dataset.mapInPandas(
_train_booster, # type: ignore
schema="config string, booster string",
schema="data string",
)
.rdd.barrier()
.mapPartitions(lambda x: x)
)
rdd_with_resource = self._try_stage_level_scheduling(rdd)
ret = rdd_with_resource.collect()[0]
return ret[0], ret[1]
ret = rdd_with_resource.collect()
data = [v[0] for v in ret]
return data[0], "".join(data[1:])

get_logger(_LOG_TAG).info(
"Running xgboost-%s on %s workers with"
@@ -1690,7 +1695,12 @@ def saveImpl(self, path: str) -> None:
_SparkXGBSharedReadWrite.saveMetadata(self.instance, path, self.sc, self.logger)
model_save_path = os.path.join(path, "model")
booster = xgb_model.get_booster().save_raw("json").decode("utf-8")
_get_spark_session().sparkContext.parallelize([booster], 1).saveAsTextFile(
booster_chunks = []

for offset in range(0, len(booster), _MODEL_CHUNK_SIZE):
booster_chunks.append(booster[offset : offset + _MODEL_CHUNK_SIZE])

_get_spark_session().sparkContext.parallelize(booster_chunks, 1).saveAsTextFile(
model_save_path
)
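
`saveImpl` writes each chunk as one line of a single-partition text file, and the matching reader joins the lines back into the full JSON. A sketch of that round trip, assuming the serialized JSON contains no newline characters (saveAsTextFile emits one element per line, so embedded newlines would break the split):

```python
# Sketch of the chunked text-file round trip (assumes model_json has no newlines).
from pyspark.sql import SparkSession

def save_chunked(spark: SparkSession, model_json: str, path: str, chunk_size: int) -> None:
    chunks = [model_json[i : i + chunk_size] for i in range(0, len(model_json), chunk_size)]
    # numSlices=1 keeps every chunk in a single part file, preserving order.
    spark.sparkContext.parallelize(chunks, 1).saveAsTextFile(path)

def load_chunked(spark: SparkSession, path: str) -> str:
    # textFile yields one element per line; joining restores the original string.
    return "".join(spark.sparkContext.textFile(path).collect())
```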

@@ -1721,8 +1731,8 @@ def load(self, path: str) -> "_SparkXGBModel":
)
model_load_path = os.path.join(path, "model")

ser_xgb_model = (
_get_spark_session().sparkContext.textFile(model_load_path).collect()[0]
ser_xgb_model = "".join(
_get_spark_session().sparkContext.textFile(model_load_path).collect()
)

def create_xgb_model() -> "XGBModel":
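One observation about the new `load` path (not stated in the commit, but it follows from the code above): joining all collected lines keeps the reader backward compatible with models saved as a single line by earlier versions, since joining a one-element list returns that element unchanged.

```python
# Both layouts read back to the same string through "".join(collect()):
old_format_lines = ['{"learner": {"gradient_booster": {}}}']       # one line, pre-chunking
new_format_lines = ['{"learner": {"gradient', '_booster": {}}}']   # chunked into two lines
assert "".join(old_format_lines) == "".join(new_format_lines)
```
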
20 changes: 20 additions & 0 deletions tests/test_distributed/test_with_spark/test_spark_local.py
@@ -867,6 +867,26 @@ def test_regressor_model_pipeline_save_load(self, reg_data: RegData) -> None:
)
assert_model_compatible(model.stages[0], tmpdir)

def test_with_small_model_chunk_size(self, reg_data: RegData, monkeypatch) -> None:
import xgboost.spark.core

monkeypatch.setattr(xgboost.spark.core, "_MODEL_CHUNK_SIZE", 4)
with tempfile.TemporaryDirectory() as tmpdir:
path = "file:" + tmpdir
regressor = SparkXGBRegressor(**reg_data.reg_params)
model = regressor.fit(reg_data.reg_df_train)
model.save(path)
loaded_model = SparkXGBRegressorModel.load(path)
assert model.uid == loaded_model.uid
for k, v in reg_data.reg_params.items():
assert loaded_model.getOrDefault(k) == v

pred_result = loaded_model.transform(reg_data.reg_df_test).collect()
for row in pred_result:
assert np.isclose(
row.prediction, row.expected_prediction_with_params, atol=1e-3
)

def test_device_param(self, reg_data: RegData, clf_data: ClfData) -> None:
clf = SparkXGBClassifier(device="cuda", tree_method="exact")
with pytest.raises(ValueError, match="not supported for distributed"):
