Skip to content

Commit 03efe58

Browse files
committed
Add way to disable caching for http_client
1 parent 9b470b3 commit 03efe58

File tree

1 file changed

+45
-6
lines changed

1 file changed

+45
-6
lines changed

gateway/mw_streaming.go

+45-6
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,9 @@ func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManage
210210
configJSON, _ := json.Marshal(streamsConfig)
211211
cacheKey := fmt.Sprintf("%x", sha256.Sum256(configJSON))
212212

213+
// Determine if caching should be disabled
214+
disableCache := s.shouldDisableCache(streamsConfig)
215+
213216
// Critical section starts here
214217
// This section is called by ProcessRequest method of the middleware implementation
215218
// Concurrent requests can call this method at the same time and those requests
@@ -221,9 +224,11 @@ func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManage
221224

222225
s.Logger().Debug("Attempting to load stream manager from cache")
223226
s.Logger().Debugf("Cache key: %s", cacheKey)
224-
if cachedManager, found := s.streamManagerCache.Load(cacheKey); found {
225-
s.Logger().Debug("Found cached stream manager")
226-
return cachedManager.(*StreamManager)
227+
if !disableCache {
228+
if cachedManager, found := s.streamManagerCache.Load(cacheKey); found {
229+
s.Logger().Debug("Found cached stream manager")
230+
return cachedManager.(*StreamManager)
231+
}
227232
}
228233

229234
newStreamManager := &StreamManager{
@@ -234,12 +239,35 @@ func (s *StreamingMiddleware) createStreamManager(r *http.Request) *StreamManage
234239
}
235240
newStreamManager.initStreams(r, streamsConfig)
236241

237-
if r != nil {
242+
if !disableCache && r != nil {
238243
s.streamManagerCache.Store(cacheKey, newStreamManager)
239244
}
240245
return newStreamManager
241246
}
242247

248+
func (s *StreamingMiddleware) shouldDisableCache(streamsConfig *StreamsConfig) bool {
249+
for _, stream := range streamsConfig.Streams {
250+
if streamMap, ok := stream.(map[string]interface{}); ok {
251+
inputType := s.getComponentType(streamMap, "input")
252+
outputType := s.getComponentType(streamMap, "output")
253+
if inputType == "http_client" && outputType == "http_server" {
254+
return true
255+
}
256+
}
257+
}
258+
return false
259+
}
260+
261+
// getComponentType returns the type of the input or output component from the stream configuration
262+
func (s *StreamingMiddleware) getComponentType(streamConfig map[string]interface{}, component string) string {
263+
if componentMap, ok := streamConfig[component].(map[string]interface{}); ok {
264+
if typeStr, ok := componentMap["type"].(string); ok {
265+
return typeStr
266+
}
267+
}
268+
return ""
269+
}
270+
243271
// Helper function to extract paths from an http_server configuration
244272
func extractPaths(httpConfig map[string]interface{}) map[string]string {
245273
paths := make(map[string]string)
@@ -547,7 +575,7 @@ func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter,
547575
defer h.sm.activityCounter.Add(-1)
548576

549577
hasInput := h.inputHandlers[path] != nil
550-
hasOutput := h.inputHandlers[path] != nil
578+
hasOutput := h.outputHandlers[path] != nil
551579

552580
if !hasInput || !hasOutput {
553581
h.logger.Debugf("Only output handler found for path: %s, executing directly", path)
@@ -576,6 +604,17 @@ func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter,
576604
}
577605

578606
handler(w, r)
607+
if handlerType == "output" {
608+
streamsConfig := &StreamsConfig{
609+
Streams: map[string]any{
610+
"stream": h.config,
611+
},
612+
}
613+
if h.mw.shouldDisableCache(streamsConfig) {
614+
h.logger.Debugf("Cache disabled, removing stream %s after output handler", h.streamID)
615+
h.sm.removeStream(h.streamID)
616+
}
617+
}
579618
case pathKey == "ws_path" && websocket.IsWebSocketUpgrade(r):
580619
h.handleWebSocket(f, w, r, path)
581620
default:
@@ -589,7 +628,7 @@ func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter,
589628

590629
func (h *handleFuncAdapter) handleWebSocket(f func(w http.ResponseWriter, r *http.Request), w http.ResponseWriter, r *http.Request, path string) {
591630
if h.inputHandlers[path] == nil || h.outputHandlers[path] == nil {
592-
h.logger.Debugf("Executing directly", path)
631+
h.logger.Debugf("Executing directly %s", path)
593632
f(w, r)
594633
return
595634
}

0 commit comments

Comments
 (0)