From 5fbd53c6167f058582f5834443304aeb16e89ee9 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Mon, 9 Dec 2024 19:08:03 -0800 Subject: [PATCH] delay TTL check until after all requests are complete (#25) - fixes #25 where requests that last longer than the TTL will cause the process to be unloaded before the next request. - new behavior, TTL waits until all requests are complete before checking timeout --- proxy/process.go | 17 +++++++++-------- proxy/process_test.go | 21 +++++++++++++++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/proxy/process.go b/proxy/process.go index 808917b..c20d624 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -122,16 +122,15 @@ func (p *Process) start() error { // start a goroutine to check every second if // the process should be stopped go func() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() maxDuration := time.Duration(p.config.UnloadAfter) * time.Second - for { - <-ticker.C + for range time.Tick(time.Second) { + // wait for all inflight requests to complete and ticker + p.inFlightRequests.Wait() + if time.Since(p.lastRequestHandled) > maxDuration { fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter) p.Stop() - return } } }() @@ -275,7 +274,11 @@ func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { p.inFlightRequests.Add(1) - defer p.inFlightRequests.Done() + + defer func() { + p.lastRequestHandled = time.Now() + p.inFlightRequests.Done() + }() if p.CurrentState() != StateReady { if err := p.start(); err != nil { @@ -285,8 +288,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { } } - p.lastRequestHandled = time.Now() - proxyTo := p.config.Proxy client := &http.Client{} req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) diff --git a/proxy/process_test.go b/proxy/process_test.go index fbf629b..1367ff0 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -82,18 +82,31 @@ func TestProcess_UnloadAfterTTL(t *testing.T) { process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) defer process.Stop() - req := httptest.NewRequest("GET", "/test", nil) + // this should take 4 seconds + req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil) + req2 := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() - // Proxy the request (auto start) - process.ProxyRequest(w, req) + // Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter + process.ProxyRequest(w, req1) + t.Log("sending slow first request (4 seconds)") assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), expectedMessage) + assert.Contains(t, w.Body.String(), "1234") + assert.Equal(t, StateReady, process.CurrentState()) + // ensure the TTL timeout does not race slow requests (see issue #25) + t.Log("sending second request (1 second)") + time.Sleep(time.Second) + w = httptest.NewRecorder() + process.ProxyRequest(w, req2) + assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), expectedMessage) assert.Equal(t, StateReady, process.CurrentState()) // wait 5 seconds + t.Log("sleep 5 seconds and check if unloaded") time.Sleep(5 * time.Second) assert.Equal(t, StateStopped, process.CurrentState()) }