2 changes: 1 addition & 1 deletion core/application/distributed.go
@@ -169,7 +169,7 @@ func initDistributed(cfg *config.ApplicationConfig, authDB *gorm.DB, configLoade
cfg.Distributed.HealthCheckIntervalOrDefault(),
cfg.Distributed.StaleNodeThresholdOrDefault(),
routerAuthToken,
cfg.Distributed.PerModelHealthCheck,
!cfg.Distributed.DisablePerModelHealthCheck,
)

// Initialize job store
10 changes: 9 additions & 1 deletion core/config/distributed_config.go
@@ -31,7 +31,15 @@ type DistributedConfig struct {
DrainTimeout time.Duration // Time to wait for in-flight requests during drain (default 30s)
HealthCheckInterval time.Duration // Health monitor check interval (default 15s)
StaleNodeThreshold time.Duration // Time before a node is considered stale (default 60s)
PerModelHealthCheck bool // Enable per-model backend health checking (default false)
// DisablePerModelHealthCheck turns off the health monitor's per-model
// gRPC probe. When the probe is enabled (the default), the monitor pings
// each model's gRPC address and removes stale node_models rows whose
// backend has crashed even though the worker's node-level heartbeat is
// still arriving.
// Without per-model probing, /embeddings and /completions can be dispatched
// to a backend that silently returns garbage (see also the cascading
// model-row cleanup on MarkUnhealthy / MarkDraining).
DisablePerModelHealthCheck bool

MCPCIJobTimeout time.Duration // MCP CI job execution timeout (default 10m)

MaxUploadSize int64 // Maximum upload body size in bytes (default 50 GB)
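Because the new flag is phrased as a negative, the default-on behavior falls out of Go's zero value: an unset DisablePerModelHealthCheck leaves probing enabled, and initDistributed hands the monitor the inverted value. A tiny self-contained illustration (the struct below is a local stand-in for DistributedConfig, not an import of the real package):

```go
package main

import "fmt"

// Stand-in for config.DistributedConfig with only the field discussed here.
type distributedConfig struct {
	DisablePerModelHealthCheck bool
}

func main() {
	var cfg distributedConfig // zero value: DisablePerModelHealthCheck == false
	fmt.Println(!cfg.DisablePerModelHealthCheck) // true: per-model probing on by default

	cfg.DisablePerModelHealthCheck = true // an operator explicitly opts out
	fmt.Println(!cfg.DisablePerModelHealthCheck) // false: rely on node-level heartbeats only
}
```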
67 changes: 59 additions & 8 deletions core/services/nodes/health.go
@@ -12,6 +12,24 @@ import (
"gorm.io/gorm"
)

// perModelMissThreshold is the number of consecutive failed gRPC probes
// against a model's backend before the model is removed from the registry.
// A single failure can be transient (network blip, brief GC pause on the
// worker, a long-running request hogging the gRPC server thread); requiring
// N consecutive misses avoids deleting healthy rows over noise. At the
// default 15s tick this means a model has to be unreachable for ~45s before
// it gets reaped.
const perModelMissThreshold = 3

// modelKey identifies a specific (node, model, replica) tuple. We track miss
// counts per tuple because the same model name can be loaded on multiple
// replicas on the same node.
type modelKey struct {
NodeID string
ModelName string
ReplicaIndex int
}

// HealthMonitor periodically checks the health of registered backend nodes.
type HealthMonitor struct {
registry NodeHealthStore
@@ -21,6 +39,8 @@ type HealthMonitor struct {
autoOffline bool // mark stale nodes as offline (preserves approval status)
clientFactory BackendClientFactory // creates gRPC backend clients
perModelHealthCheck bool // check each model's backend process individually
missesMu sync.Mutex
misses map[modelKey]int // consecutive failed-probe counts; reset on success or model removal
cancel context.CancelFunc
cancelMu sync.Mutex
}
@@ -46,6 +66,7 @@ func NewHealthMonitor(registry NodeHealthStore, db *gorm.DB, checkInterval, stal
autoOffline: true,
clientFactory: factory,
perModelHealthCheck: perModelHealthCheck,
misses: make(map[modelKey]int),
}
}

@@ -152,9 +173,11 @@ func (hm *HealthMonitor) doCheckAll(ctx context.Context) {
}
}

// Per-model backend health check (opt-in): probe each model's gRPC address
// and remove stale model records. This does NOT affect the node's status —
// a crashed backend process is a model-level issue, not a node-level one.
// Per-model backend health check: probe each model's gRPC address and
// remove stale model records. This does NOT affect the node's status —
// a crashed backend process is a model-level issue, not a node-level
// one. A model is only removed after perModelMissThreshold consecutive
// failed probes so a single network/GC blip doesn't force a reload.
if hm.perModelHealthCheck {
models, _ := hm.registry.GetNodeModels(ctx, node.ID)
for _, m := range models {
@@ -163,15 +186,43 @@ func (hm *HealthMonitor) doCheckAll(ctx context.Context) {
}
mClient := hm.clientFactory.NewClient(m.Address, false)
mCheckCtx, mCancel := context.WithTimeout(ctx, 5*time.Second)
if ok, _ := mClient.HealthCheck(mCheckCtx); !ok {
xlog.Warn("Model backend unhealthy, removing from registry",
"node", node.ID, "model", m.ModelName, "replica", m.ReplicaIndex, "address", m.Address)
hm.registry.RemoveNodeModel(ctx, node.ID, m.ModelName, m.ReplicaIndex)
}
ok, _ := mClient.HealthCheck(mCheckCtx)
mCancel()
if closer, ok := mClient.(io.Closer); ok {
closer.Close()
}

key := modelKey{NodeID: node.ID, ModelName: m.ModelName, ReplicaIndex: m.ReplicaIndex}
hm.missesMu.Lock()
if ok {
// Probe succeeded — wipe any previous miss streak.
delete(hm.misses, key)
hm.missesMu.Unlock()
continue
}
hm.misses[key]++
misses := hm.misses[key]
hm.missesMu.Unlock()

if misses < perModelMissThreshold {
xlog.Debug("Model backend probe failed, awaiting threshold before removal",
"node", node.ID, "model", m.ModelName, "replica", m.ReplicaIndex,
"address", m.Address, "misses", misses, "threshold", perModelMissThreshold)
continue
}
xlog.Warn("Model backend unhealthy after consecutive misses, removing from registry",
"node", node.ID, "model", m.ModelName, "replica", m.ReplicaIndex,
"address", m.Address, "misses", misses)
if err := hm.registry.RemoveNodeModel(ctx, node.ID, m.ModelName, m.ReplicaIndex); err != nil {
xlog.Warn("Failed to remove unhealthy model from registry",
"node", node.ID, "model", m.ModelName, "replica", m.ReplicaIndex, "error", err)
// Leave the miss counter in place so the next tick retries
// the removal rather than starting the streak over.
continue
}
hm.missesMu.Lock()
delete(hm.misses, key)
hm.missesMu.Unlock()
}
}
}
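For anyone sanity-checking the "~45s" figure in the perModelMissThreshold comment: it is simply the miss threshold multiplied by the default health-check tick, with the final probe's 5s timeout on top in the worst case. A throwaway sketch using the defaults from this diff:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	const (
		healthCheckInterval   = 15 * time.Second // HealthCheckInterval default
		probeTimeout          = 5 * time.Second  // per-model HealthCheck context timeout
		perModelMissThreshold = 3                 // consecutive misses before removal
	)
	// A model backend must fail the probe on this many consecutive ticks
	// before its node_models row is removed.
	reapAfter := perModelMissThreshold * healthCheckInterval
	fmt.Println(reapAfter)                // 45s
	fmt.Println(reapAfter + probeTimeout) // 50s worst case, counting the last probe's timeout
}
```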
1 change: 1 addition & 0 deletions core/services/nodes/health_mock_test.go
@@ -321,6 +321,7 @@ func newTestHealthMonitor(store NodeHealthStore, factory BackendClientFactory, a
staleThreshold: staleThreshold,
autoOffline: autoOffline,
clientFactory: factory,
misses: make(map[modelKey]int),
}
}

42 changes: 41 additions & 1 deletion core/services/nodes/health_test.go
@@ -255,7 +255,7 @@ var _ = Describe("HealthMonitor (mock-based)", func() {
Expect(calls).NotTo(ContainElement(ContainSubstring("MarkUnhealthy")))
})

It("removes stale model via per-model health check without affecting node status", func() {
It("removes stale model via per-model health check after consecutive failures", func() {
store := newFakeNodeHealthStore()
factory := newFakeBackendClientFactory()
hm := newTestHealthMonitor(store, factory, true, staleThreshold)
@@ -268,12 +268,52 @@
// Model backend is dead
factory.setClient("10.0.0.10:50053", &fakeBackendClient{healthy: false, err: fmt.Errorf("connection refused")})

// The first (perModelMissThreshold-1) probes must NOT remove the row —
// a single failure could be a transient blip.
for i := 0; i < perModelMissThreshold-1; i++ {
hm.doCheckAll(context.Background())
Expect(store.getCalls()).NotTo(ContainElement(ContainSubstring("RemoveNodeModel")),
"removed too early at miss %d", i+1)
}

// Threshold-th consecutive miss triggers removal.
hm.doCheckAll(context.Background())

// Node should remain healthy — only the specific replica record is removed.
Expect(store.getNode("node-model").Status).To(Equal(StatusHealthy))
Expect(store.getCalls()).To(ContainElement("RemoveNodeModel:node-model:piper-model:0"))
Expect(store.getCalls()).NotTo(ContainElement(ContainSubstring("MarkUnhealthy")))
})

It("preserves model row when an intermittent failure is followed by a success", func() {
store := newFakeNodeHealthStore()
factory := newFakeBackendClientFactory()
hm := newTestHealthMonitor(store, factory, true, staleThreshold)
hm.perModelHealthCheck = true

node := makeTestNode("node-flap", "flap-worker", "10.0.0.11:50051", StatusHealthy, freshTime())
store.addNode(node)
store.addNodeModel("node-flap", NodeModel{NodeID: "node-flap", ModelName: "piper-model", Address: "10.0.0.11:50053"})

deadClient := &fakeBackendClient{healthy: false, err: fmt.Errorf("connection refused")}
liveClient := &fakeBackendClient{healthy: true}

// Two failing probes then a recovery — should NOT remove the row,
// and should reset the miss counter so two more failures don't tip
// it over.
factory.setClient("10.0.0.11:50053", deadClient)
hm.doCheckAll(context.Background())
hm.doCheckAll(context.Background())
factory.setClient("10.0.0.11:50053", liveClient)
hm.doCheckAll(context.Background())

Expect(store.getCalls()).NotTo(ContainElement(ContainSubstring("RemoveNodeModel")))

// Counter is reset; two more failures must not be enough to remove.
factory.setClient("10.0.0.11:50053", deadClient)
hm.doCheckAll(context.Background())
hm.doCheckAll(context.Background())
Expect(store.getCalls()).NotTo(ContainElement(ContainSubstring("RemoveNodeModel")))
})
})
})
37 changes: 33 additions & 4 deletions core/services/nodes/registry.go
@@ -546,7 +546,13 @@ func (r *NodeRegistry) GetByName(ctx context.Context, name string) (*BackendNode
return &node, nil
}

// MarkUnhealthy sets a node status to unhealthy.
// MarkUnhealthy sets a node status to unhealthy. Deliberately status-only:
// callers fire this on transient triggers (a single nats.ErrNoResponders from
// managers_distributed / reconciler) where the next heartbeat is expected to
// flip the node back to healthy, and cascade-deleting node_models here would
// force a full model reload on every brief NATS hiccup. Stale rows are reaped
// by the per-model health probe (on by default; see HealthMonitor) and by
// MarkOffline when the heartbeat really has gone away.
func (r *NodeRegistry) MarkUnhealthy(ctx context.Context, nodeID string) error {
return r.setStatus(ctx, nodeID, StatusUnhealthy)
}
@@ -556,9 +562,23 @@ func (r *NodeRegistry) MarkHealthy(ctx context.Context, nodeID string) error {
return r.setStatus(ctx, nodeID, StatusHealthy)
}

// MarkDraining sets a node status to draining (no new requests).
// MarkDraining sets a node status to draining (no new requests) and clears its
// model records. Routing already filters out non-healthy nodes, so removing
// the rows on drain doesn't change new-request behavior — but it does stop the
// Models UI from showing the node's models as "running" while the box has been
// taken out of rotation, and it prevents stale rows from being selected if
// (re)scheduling logic gets relaxed elsewhere. In-flight requests already hold
// their gRPC client through Route() and will finish normally; the only
// observable effect is that the per-call IncrementInFlight bookkeeping logs a
// non-fatal warning, which is acceptable for a drain.
func (r *NodeRegistry) MarkDraining(ctx context.Context, nodeID string) error {
return r.setStatus(ctx, nodeID, StatusDraining)
if err := r.setStatus(ctx, nodeID, StatusDraining); err != nil {
return err
}
if err := r.db.WithContext(ctx).Where("node_id = ?", nodeID).Delete(&NodeModel{}).Error; err != nil {
xlog.Warn("Failed to clear model records on draining", "node", nodeID, "error", err)
}
return nil
}

// FindStaleNodes returns nodes that haven't sent a heartbeat within the given threshold.
@@ -673,9 +693,18 @@ func (r *NodeRegistry) FindAndLockNodeWithModel(ctx context.Context, modelName s
// to moderate concurrency where requests don't overlap) collapses to
// "biggest GPU wins every time" and one node ends up taking nearly all
// the load while replicas on other nodes sit idle.
// Filter on backend_nodes.status = healthy in the inner JOIN itself,
// not only in the later node-fetch step. The previous version picked
// a (node_id, replica) pair purely on node_models state, then bailed
// out when the second query couldn't find a healthy node row — but
// any concurrent reader of node_models could still pick the same
// stale row in the same window, and other helpers that mirror this
// JOIN need the same invariant. Belt-and-braces: status filter here
// AND the status-checked node fetch below.
q := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
Joins("JOIN backend_nodes ON backend_nodes.id = node_models.node_id").
Where("node_models.model_name = ? AND node_models.state = ?", modelName, "loaded")
Where("node_models.model_name = ? AND node_models.state = ? AND backend_nodes.status = ?",
modelName, "loaded", StatusHealthy)
if len(candidateNodeIDs) > 0 {
q = q.Where("node_models.node_id IN ?", candidateNodeIDs)
}