Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One pool per node #19

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 81 additions & 54 deletions src/Database/Redis/Cluster.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
module Database.Redis.Cluster
( Connection(..)
, NodeRole(..)
Expand All @@ -13,8 +15,11 @@ module Database.Redis.Cluster
, HashSlot
, Shard(..)
, TimeoutException(..)
, TcpInfo(..)
, Host
, NodeID
, connect
, disconnect
, destroyNodeResources
, requestPipelined
, requestMasterNodes
, nodes
Expand All @@ -28,6 +33,7 @@ import Data.List(nub, sortBy, find)
import Data.Map(fromListWith, assocs)
import Data.Function(on)
import Control.Exception(Exception, SomeException, throwIO, BlockedIndefinitelyOnMVar(..), catches, Handler(..), try, fromException)
import Data.Pool(Pool, createPool, withResource, destroyAllResources)
import Control.Concurrent.Async(race)
import Control.Concurrent(threadDelay)
import Control.Concurrent.MVar(MVar, newMVar, readMVar, modifyMVar, modifyMVar_)
Expand All @@ -45,6 +51,7 @@ import Text.Read (readMaybe)

import Database.Redis.Protocol(Reply(Error), renderRequest, reply)
import qualified Database.Redis.Cluster.Command as CMD
import Network.TLS (ClientParams)

-- This module implements a clustered connection whilst maintaining
-- compatibility with the original Hedis codebase. In particular it still
Expand All @@ -59,10 +66,10 @@ import qualified Database.Redis.Cluster.Command as CMD
-- | 'NodeConnection's, a 'Pipeline', and a 'ShardMap'
type IsReadOnly = Bool

data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap IsReadOnly
data Connection = Connection (MVar NodeConnectionMap) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap IsReadOnly TcpInfo

-- | A connection to a single node in the cluster, similar to 'ProtocolPipelining.Connection'
data NodeConnection = NodeConnection CC.ConnectionContext (IOR.IORef (Maybe B.ByteString)) NodeID
data NodeConnection = NodeConnection (Pool CC.ConnectionContext) (IOR.IORef (Maybe B.ByteString)) NodeID

instance Show NodeConnection where
show (NodeConnection _ _ id1) = "nodeId: " <> show id1
Expand Down Expand Up @@ -113,6 +120,17 @@ data Shard = Shard MasterNode [SlaveNode] deriving (Show, Eq, Ord)
-- A map from hashslot to shards
newtype ShardMap = ShardMap (IntMap.IntMap Shard) deriving (Show)

type NodeConnectionMap = HM.HashMap NodeID NodeConnection

-- Object for storing Tcp Connection Info which will be used when cluster is refreshed
data TcpInfo = TcpInfo
{ connectAuth :: Maybe B.ByteString
, connectTLSParams :: Maybe ClientParams
, idleTime :: Time.NominalDiffTime
, maxResources :: Int
, timeoutOpt :: Maybe Int
} deriving Show

newtype MissingNodeException = MissingNodeException [B.ByteString] deriving (Show, Typeable)
instance Exception MissingNodeException

Expand All @@ -128,8 +146,8 @@ instance Exception NoNodeException
data TimeoutException = TimeoutException String deriving (Show, Typeable)
instance Exception TimeoutException

connect :: (Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext) -> [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> Bool -> ([NodeConnection] -> IO ShardMap) -> IO Connection
connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap = do
connect :: (Host -> CC.PortID -> Maybe Int -> IO CC.ConnectionContext) -> [CMD.CommandInfo] -> MVar ShardMap -> Bool -> ([NodeConnection] -> IO ShardMap) -> TcpInfo -> IO Connection
connect withAuth commandInfos shardMapVar isReadOnly refreshShardMap (tcpInfo@TcpInfo{ timeoutOpt, maxResources, idleTime }) = do
shardMap <- readMVar shardMapVar
stateVar <- newMVar $ Pending []
pipelineVar <- newMVar $ Pipeline stateVar
Expand All @@ -148,7 +166,8 @@ connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap
throwIO NoNodeException
else
return eNodeConns
return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) isReadOnly where
nodeConnsVar <- newMVar nodeConns
return $ Connection nodeConnsVar pipelineVar shardMapVar (CMD.newInfoMap commandInfos) isReadOnly tcpInfo where
simpleNodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection)
simpleNodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap)
nodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection, Bool)
Expand All @@ -161,34 +180,34 @@ connect withAuth commandInfos shardMapVar timeoutOpt isReadOnly refreshShardMap
) (mempty, False) info
connectNode :: Node -> IO (NodeID, NodeConnection)
connectNode (Node n _ host port) = do
ctx <- withAuth host (CC.PortNumber $ toEnum port) timeoutOpt
ctx <- createPool (withAuth host (CC.PortNumber $ toEnum port) timeoutOpt) CC.disconnect 1 idleTime maxResources
ref <- IOR.newIORef Nothing
return (n, NodeConnection ctx ref n)
refreshShardMapVar :: ShardMap -> IO ()
refreshShardMapVar shardMap = hasLocked $ modifyMVar_ shardMapVar (const (pure shardMap))

disconnect :: Connection -> IO ()
disconnect (Connection nodeConnMap _ _ _ _ ) = mapM_ disconnectNode (HM.elems nodeConnMap) where
disconnectNode (NodeConnection nodeCtx _ _) = CC.disconnect nodeCtx
destroyNodeResources :: Connection -> IO ()
destroyNodeResources (Connection nodeConnMapVar _ _ _ _ _) = readMVar nodeConnMapVar >>= (mapM_ disconnectNode . HM.elems) where
disconnectNode (NodeConnection nodePool _ _) = destroyAllResources nodePool

-- Add a request to the current pipeline for this connection. The pipeline will
-- be executed implicitly as soon as any result returned from this function is
-- evaluated.
requestPipelined :: IO ShardMap -> Connection -> [B.ByteString] -> IO Reply
requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
requestPipelined refreshShardmapAction conn@(Connection _ pipelineVar shardMapVar _ _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
(newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case
Pending requests | isMulti nextRequest -> do
replies <- evaluatePipeline shardMapVar refreshAction conn requests
replies <- evaluatePipeline shardMapVar refreshShardmapAction conn requests
s' <- newMVar $ TransactionPending [nextRequest]
return (Executed replies, (s', 0))
Pending requests | length requests > 1000 -> do
replies <- evaluatePipeline shardMapVar refreshAction conn (nextRequest:requests)
replies <- evaluatePipeline shardMapVar refreshShardmapAction conn (nextRequest:requests)
return (Executed replies, (stateVar, length requests))
Pending requests ->
return (Pending (nextRequest:requests), (stateVar, length requests))
TransactionPending requests ->
if isExec nextRequest then do
replies <- evaluateTransactionPipeline shardMapVar refreshAction conn (nextRequest:requests)
replies <- evaluateTransactionPipeline shardMapVar refreshShardmapAction conn (nextRequest:requests)
return (Executed replies, (stateVar, length requests))
else
return (TransactionPending (nextRequest:requests), (stateVar, length requests))
Expand All @@ -204,10 +223,10 @@ requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) n
Executed replies ->
return (Executed replies, replies)
Pending requests-> do
replies <- evaluatePipeline shardMapVar refreshAction conn requests
replies <- evaluatePipeline shardMapVar refreshShardmapAction conn requests
return (Executed replies, replies)
TransactionPending requests-> do
replies <- evaluateTransactionPipeline shardMapVar refreshAction conn requests
replies <- evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests
return (Executed replies, replies)
return $ replies !! repliesIndex
return (Pipeline newStateVar, evaluateAction)
Expand Down Expand Up @@ -270,7 +289,7 @@ evaluatePipeline shardMapVar refreshShardmapAction conn requests = do
Left (err :: SomeException) ->
case fromException err of
Just (er :: TimeoutException) -> throwIO er
_ -> executeRequests (getRandomConnection cc conn) r
_ -> getRandomConnection cc conn >>= (`executeRequests` r)
) (zip eresps requestsByNode)
-- check for any moved in both responses and continue the flow.
when (any (moved . rawResponse) resps) refreshShardMapVar
Expand Down Expand Up @@ -305,14 +324,14 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
-- there is one.
case last replies of
(Error errString) | B.isPrefixOf "MOVED" errString -> do
let (Connection _ _ _ infoMap _) = conn
let (Connection _ _ _ infoMap _ _) = conn
keys <- mconcat <$> mapM (requestKeys infoMap) requests
hashSlot <- hashSlotForKeys (CrossSlotException requests) keys
nodeConn <- nodeConnForHashSlot shardMapVar conn (MissingNodeException (head requests)) hashSlot
requestNode nodeConn requests
(askingRedirection -> Just (host, port)) -> do
shardMap <- hasLocked $ readMVar shardMapVar
let maybeAskNode = nodeConnWithHostAndPort shardMap conn host port
maybeAskNode <- nodeConnWithHostAndPort shardMap conn host port
case maybeAskNode of
Just askNode -> tail <$> requestNode askNode (["ASKING"] : requests)
Nothing -> case retryCount of
Expand All @@ -327,7 +346,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
evaluateTransactionPipeline :: MVar ShardMap -> IO ShardMap -> Connection -> [[B.ByteString]] -> IO [Reply]
evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = do
let requests = reverse requests'
let (Connection _ _ _ infoMap _) = conn
let (Connection _ _ _ infoMap _ _) = conn
keys <- mconcat <$> mapM (requestKeys infoMap) requests
-- In cluster mode Redis expects commands in transactions to all work on the
-- same hashslot. We find that hashslot here.
Expand All @@ -345,7 +364,7 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d
resps <-
case eresps of
Right v -> return v
Left (_ :: SomeException) -> requestNode (getRandomConnection nodeConn conn) requests
Left (_ :: SomeException) -> getRandomConnection nodeConn conn >>= (`requestNode` requests)
-- The Redis documentation has the following to say on the effect of
-- resharding on multi-key operations:
--
Expand Down Expand Up @@ -379,8 +398,9 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d

nodeConnForHashSlot :: Exception e => MVar ShardMap -> Connection -> e -> HashSlot -> IO NodeConnection
nodeConnForHashSlot shardMapVar conn exception hashSlot = do
let (Connection nodeConns _ _ _ _) = conn
let (Connection nodeConnsVar _ _ _ _ _) = conn
(ShardMap shardMap) <- hasLocked $ readMVar shardMapVar
nodeConns <- readMVar nodeConnsVar
node <-
case IntMap.lookup (fromEnum hashSlot) shardMap of
Nothing -> throwIO exception
Expand Down Expand Up @@ -421,13 +441,16 @@ moved (Error errString) = case Char8.words errString of
moved _ = False


nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> Maybe NodeConnection
nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _ _) host port = do
node <- nodeWithHostAndPort shardMap host port
HM.lookup (nodeId node) nodeConns
nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> IO (Maybe NodeConnection)
nodeConnWithHostAndPort shardMap (Connection nodeConnsVar _ _ _ _ _) host port =
case nodeWithHostAndPort shardMap host port of
Nothing -> return Nothing
Just node -> do
nodeConns <- readMVar nodeConnsVar
return (HM.lookup (nodeId node) nodeConns)

nodeConnectionForCommand :: Connection -> ShardMap -> [B.ByteString] -> IO [NodeConnection]
nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap shardMap) request =
nodeConnectionForCommand conn@(Connection nodeConnsVar _ _ infoMap _ _) (ShardMap shardMap) request =
case request of
("FLUSHALL" : _) -> allNodes
("FLUSHDB" : _) -> allNodes
Expand All @@ -439,48 +462,50 @@ nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap sha
node <- case IntMap.lookup (fromEnum hashSlot) shardMap of
Nothing -> throwIO $ MissingNodeException request
Just (Shard master _) -> return master
nodeConns <- readMVar nodeConnsVar
maybe (throwIO $ MissingNodeException request) (return . return) (HM.lookup (nodeId node) nodeConns)
where
allNodes =
case allMasterNodes conn (ShardMap shardMap) of
allNodes = do
maybeNodes <- allMasterNodes conn (ShardMap shardMap)
case maybeNodes of
Nothing -> throwIO $ MissingNodeException request
Just allNodes' -> return allNodes'

allMasterNodes :: Connection -> ShardMap -> Maybe [NodeConnection]
allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) =
mapM (flip HM.lookup nodeConns . nodeId) onlyMasterNodes
allMasterNodes :: Connection -> ShardMap -> IO (Maybe [NodeConnection])
allMasterNodes (Connection nodeConnsVar _ _ _ _ _) (ShardMap shardMap) = do
nodeConns <- readMVar nodeConnsVar
return $ mapM (flip HM.lookup nodeConns . nodeId) onlyMasterNodes
where
onlyMasterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap)

requestNode :: NodeConnection -> [[B.ByteString]] -> IO [Reply]
requestNode (NodeConnection ctx lastRecvRef _) requests = do
requestNode (NodeConnection pool lastRecvRef _) requests = withResource pool $ \ctx -> do
envTimeout <- round . (\x -> (x :: Time.NominalDiffTime) * 1000000) . realToFrac . fromMaybe (0.5 :: Double) . (>>= readMaybe) <$> lookupEnv "REDIS_REQUEST_NODE_TIMEOUT"
eresp <- race requestNodeImpl (threadDelay envTimeout)
eresp <- race (requestNodeImpl ctx) (threadDelay envTimeout)
case eresp of
Left e -> return e
Right _ -> putStrLn "timeout happened" *> throwIO (TimeoutException "Request Timeout")

where
requestNodeImpl :: IO [Reply]
requestNodeImpl = do
mapM_ (sendNode . renderRequest) requests
requestNodeImpl :: CC.ConnectionContext -> IO [Reply]
requestNodeImpl ctx = do
mapM_ (sendNode ctx . renderRequest) requests
_ <- CC.flush ctx
replicateM (length requests) recvNode
sendNode :: B.ByteString -> IO ()
sendNode = CC.send ctx
recvNode :: IO Reply
recvNode = do
replicateM (length requests) $ recvNode ctx
sendNode :: CC.ConnectionContext -> B.ByteString -> IO ()
sendNode = CC.send
recvNode :: CC.ConnectionContext -> IO Reply
recvNode ctx = do
maybeLastRecv <- IOR.readIORef lastRecvRef
scanResult <- case maybeLastRecv of
Just lastRecv -> Scanner.scanWith (CC.recv ctx) reply lastRecv
Nothing -> Scanner.scanWith (CC.recv ctx) reply B.empty

case scanResult of
Scanner.Fail{} -> CC.errConnClosed
Scanner.More{} -> error "Hedis: parseWith returned Partial"
Scanner.Done rest' r -> do
IOR.writeIORef lastRecvRef (Just rest')
return r
Scanner.Fail{} -> CC.errConnClosed
Scanner.More{} -> error "Hedis: parseWith returned Partial"
Scanner.Done rest' r -> do
IOR.writeIORef lastRecvRef (Just rest')
return r

{-# INLINE nodes #-}
nodes :: ShardMap -> [Node]
Expand Down Expand Up @@ -508,14 +533,16 @@ requestMasterNodes conn req = do
concat <$> mapM (`requestNode` [req]) masterNodeConns

masterNodes :: Connection -> IO [NodeConnection]
masterNodes (Connection nodeConns _ shardMapVar _ _) = do
masterNodes (Connection nodeConnsVar _ shardMapVar _ _ _) = do
(ShardMap shardMap) <- readMVar shardMapVar
let masters = map ((\(Shard m _) -> m) . snd) $ IntMap.toList shardMap
let masterNodeIds = map nodeId masters
nodeConns <- readMVar nodeConnsVar
return $ mapMaybe (`HM.lookup` nodeConns) masterNodeIds

getRandomConnection :: NodeConnection -> Connection -> NodeConnection
getRandomConnection nc conn =
let (Connection hmn _ _ _ _) = conn
conns = HM.elems hmn
in fromMaybe (head conns) $ find (nc /= ) conns
getRandomConnection :: NodeConnection -> Connection -> IO NodeConnection
getRandomConnection nc conn = do
let (Connection hmnVar _ _ _ _ _) = conn
hmn <- readMVar hmnVar
let conns = HM.elems hmn
return $ fromMaybe (head conns) $ find (nc /= ) conns
Loading