Skip to content
90 changes: 90 additions & 0 deletions modelopt/torch/speculative/plugins/hf_streaming_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
from typing import TypedDict

import httpx
import base64
import time
import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from safetensors import SafetensorError, safe_open
Expand Down Expand Up @@ -410,7 +412,95 @@ def _next_url(self) -> str:
self._rr += 1
return url

# ---- RDMA transport (no-disk hidden states via NIXL) ----
def _rdma(self):
pid = os.getpid()
if getattr(self, "_nixl_pid", None) != pid:
from nixl._api import nixl_agent, nixl_agent_config
self._nixl = nixl_agent(f"hs-trainer-{pid}",
nixl_agent_config(backends=["UCX"]))
self._nixl_pid = pid
self._remote_by_host: dict = {}
self._recv = None
_wi = torch.utils.data.get_worker_info()
self._rr = int(os.environ.get("RANK", "0")) * (_wi.num_workers if _wi else 1) + (_wi.id if _wi else 0)
self._http_rdma = httpx.Client(
timeout=httpx.Timeout(self.config.request_timeout, connect=10.0))
return self._nixl

def _remote(self, host, port):
if host not in self._remote_by_host:
m = self._http_rdma.get(f"http://{host}:{port}/meta").json()["agent_metadata"]
self._remote_by_host[host] = self._nixl.add_remote_agent(base64.b64decode(m))
return self._remote_by_host[host]

def _fetch_rdma(self, sample: dict):
agent = self._rdma()
url = self._next_url()
host = url.split("://", 1)[-1].split(":")[0]
port = int(os.environ.get("HS_SIDECAR_PORT", "18999"))
r = self._http_rdma.post(
f"{url}/v1/completions",
json={"model": self.config.model, "prompt": sample["token_ids"],
"max_tokens": 1, "temperature": 0})
r.raise_for_status()
rid = (r.json().get("kv_transfer_params") or {}).get("hs_req_id")
if rid is None:
warn_rank_0(f"[streaming] no hs_req_id for {sample['cid']}")
return None
remote = self._remote(host, port)
desc = None
deadline = time.time() + self.config.request_timeout
while time.time() < deadline:
rr = self._http_rdma.get(f"http://{host}:{port}/desc", params={"req_id": rid})
if rr.status_code == 200 and rr.json().get("ready"):
desc = rr.json(); break
time.sleep(0.002)
if desc is None:
warn_rank_0(f"[streaming] rdma desc timeout for {sample['cid']}")
return None
shape = tuple(desc["hs_shape"]); dtype = getattr(torch, desc["hs_dtype"])
feat = shape[1:]
maxtok = self.config.max_seq_len
if self._recv is None or self._recv.dtype != dtype or tuple(self._recv.shape[1:]) != feat:
# plain (pageable) host tensor: NIXL/ibv_reg_mr pins the pages itself.
# Do NOT call .pin_memory() here — dataloader workers are forked and have no
# valid CUDA context (cudaHostAlloc -> CUDA initialization error).
self._recv = torch.empty((maxtok, *feat), dtype=dtype)
agent.register_memory([self._recv])
view = self._recv[:shape[0]]
ldescs = agent.get_xfer_descs([view])
rdescs = agent.deserialize_descs(base64.b64decode(desc["hs_descs"]))
h = agent.initialize_xfer("READ", ldescs, rdescs, remote)
agent.transfer(h)
while True:
st = agent.check_xfer_state(h)
if st == "DONE":
break
if st == "ERR":
agent.release_xfer_handle(h)
warn_rank_0(f"[streaming] rdma xfer ERR for {sample['cid']}")
return None
time.sleep(0.0002)
agent.release_xfer_handle(h)
try:
self._http_rdma.get(f"http://{host}:{port}/done", params={"req_id": rid})
except Exception:
pass
hidden_states = view.clone()
token_ids = torch.tensor(desc["token_ids"], dtype=torch.long)
client_ids = torch.as_tensor(sample["token_ids"], dtype=token_ids.dtype)
n = client_ids.shape[0]
if token_ids.shape[0] not in (n, n + 1) or not torch.equal(token_ids[:n], client_ids):
raise RuntimeError(
f"server token_ids drift for {sample['cid']}: client_len={n}, "
f"server_len={token_ids.shape[0]}")
loss_mask = self._align_loss_mask(sample["loss_mask"], token_ids.shape[0])
return {"token_ids": token_ids, "hidden_states": hidden_states, "loss_mask": loss_mask}

def _fetch(self, sample: dict) -> EagleFetchPayload | None:
if os.environ.get("HS_TRANSPORT") == "rdma":
return self._fetch_rdma(sample)
client = self._client()
url = self._next_url()
r = client.post(
Expand Down
Loading