Skip to content

Commit d15ac9e

Browse files
committed
add 8bit quant all reduce support
1 parent 2de80a3 commit d15ac9e

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

open_diloco/train_fsdp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class HvConfig(BaseConfig):
9797
announce_maddrs: list[str] | None = None
9898
matchmaking_time: float | None = None
9999
averaging_timeout: float | None = None
100-
hivemind_compression: Literal["none", "fp16", "scaled-fp16"] = "none"
100+
hivemind_compression: Literal["none", "fp16", "scaled-fp16", "uniform8bit", "quantile8bit"] = "none"
101101
all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL
102102
timeout_waiting_for_peers: float | None = None
103103
skip_load_from_peers: bool = False

open_diloco/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,16 @@ def get_compression_kwargs(hivemind_compression: str) -> dict:
108108

109109
ret_kwargs["grad_compression"] = NoCompression()
110110
ret_kwargs["state_averaging_compression"] = NoCompression()
111+
elif hivemind_compression == "uniform8bit":
112+
from hivemind import Uniform8BitQuantization
113+
114+
ret_kwargs["grad_compression"] = Uniform8BitQuantization()
115+
ret_kwargs["state_averaging_compression"] = Uniform8BitQuantization()
116+
elif hivemind_compression == "quantile8bit":
117+
from hivemind import Quantile8BitQuantization
118+
119+
ret_kwargs["grad_compression"] = Quantile8BitQuantization()
120+
ret_kwargs["state_averaging_compression"] = Quantile8BitQuantization()
111121
else:
112122
raise ValueError(
113123
f"Invalid hivemind_compression: {hivemind_compression}. Please choose 'none', 'fp16', 'scaled-fp16', 'uniform8bit', or 'quantile8bit'."

0 commit comments

Comments
 (0)