diff --git a/src/Database/Redis/Cluster.hs b/src/Database/Redis/Cluster.hs index 4a1367d4..09ebb1bc 100644 --- a/src/Database/Redis/Cluster.hs +++ b/src/Database/Redis/Cluster.hs @@ -443,24 +443,24 @@ allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) = requestNode :: NodeConnection -> [[B.ByteString]] -> IO [Reply] requestNode (NodeConnection pool lastRecvRef nid) requests = do envTimeout <- round . (\x -> (x :: Time.NominalDiffTime) * 1000000) . realToFrac . fromMaybe (0.5 :: Double) . (>>= readMaybe) <$> lookupEnv "REDIS_REQUEST_NODE_TIMEOUT" - eresp <- race requestNodeImpl (threadDelay envTimeout) + eresp <- race (withResource pool requestNodeImpl) (threadDelay envTimeout) case eresp of Left e -> return e Right _ -> putStrLn ("timeout happened" ++ show nid) *> throwIO NoNodeException where - requestNodeImpl :: IO [Reply] - requestNodeImpl = do - mapM_ (sendNode . renderRequest) requests - _ <- withResource pool CC.flush - replicateM (length requests) recvNode - sendNode :: B.ByteString -> IO () - sendNode reqs = withResource pool (`CC.send` reqs) - recvNode :: IO Reply - recvNode = do + requestNodeImpl :: CC.ConnectionContext -> IO [Reply] + requestNodeImpl ctx = do + mapM_ (sendNode ctx . renderRequest) requests + _ <- CC.flush ctx + replicateM (length requests) $ recvNode ctx + sendNode :: CC.ConnectionContext -> B.ByteString -> IO () + sendNode = CC.send + recvNode :: CC.ConnectionContext -> IO Reply + recvNode ctx = do maybeLastRecv <- IOR.readIORef lastRecvRef scanResult <- case maybeLastRecv of - Just lastRecv -> Scanner.scanWith (withResource pool CC.recv) reply lastRecv - Nothing -> Scanner.scanWith (withResource pool CC.recv) reply B.empty + Just lastRecv -> Scanner.scanWith (CC.recv ctx) reply lastRecv + Nothing -> Scanner.scanWith (CC.recv ctx) reply B.empty case scanResult of Scanner.Fail{} -> CC.errConnClosed