diff --git a/src/Network/WebSockets/Connection.hs b/src/Network/WebSockets/Connection.hs index 9e87951..6ad00f7 100644 --- a/src/Network/WebSockets/Connection.hs +++ b/src/Network/WebSockets/Connection.hs @@ -12,6 +12,7 @@ module Network.WebSockets.Connection , RejectRequest(..) , defaultRejectRequest , rejectRequestWith + , rejectRequestAndCloseWith , Connection (..) @@ -245,6 +246,16 @@ rejectRequest rejectRequest pc body = rejectRequestWith pc defaultRejectRequest {rejectBody = body} +-------------------------------------------------------------------------------- + +-- | Send a rejection message to the client and close the underlying connection +rejectRequestAndCloseWith + :: PendingConnection -- ^ Connection to reject and close + -> RejectRequest -- ^ Params on how to reject the request + -> IO () +rejectRequestAndCloseWith pc reject = do + rejectRequestWith pc reject + Stream.close $ pendingStream pc -------------------------------------------------------------------------------- data Connection = Connection diff --git a/tests/haskell/Network/WebSockets/Handshake/Tests.hs b/tests/haskell/Network/WebSockets/Handshake/Tests.hs index dbf6921..43797e4 100644 --- a/tests/haskell/Network/WebSockets/Handshake/Tests.hs +++ b/tests/haskell/Network/WebSockets/Handshake/Tests.hs @@ -1,4 +1,5 @@ -------------------------------------------------------------------------------- +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE OverloadedStrings #-} module Network.WebSockets.Handshake.Tests ( tests @@ -7,14 +8,14 @@ module Network.WebSockets.Handshake.Tests -------------------------------------------------------------------------------- import Control.Concurrent (forkIO) -import Control.Exception (handle) +import Control.Exception (catch, handle, throwIO) import Data.ByteString.Char8 () import Data.IORef (newIORef, readIORef, writeIORef) import Data.Maybe (fromJust) import Test.Framework (Test, testGroup) import Test.Framework.Providers.HUnit (testCase) -import Test.HUnit (Assertion, assert, (@?=)) +import Test.HUnit (Assertion, assert, assertFailure, (@?=)) -------------------------------------------------------------------------------- @@ -22,6 +23,7 @@ import Network.WebSockets import Network.WebSockets.Connection import Network.WebSockets.Http import qualified Network.WebSockets.Stream as Stream +import Network.WebSockets.Types (ConnectionException(ConnectionClosed)) -------------------------------------------------------------------------------- @@ -33,6 +35,7 @@ tests = testGroup "Network.WebSockets.Handshake.Test" , testCase "handshake Hybi13 with subprotocols and headers" testHandshakeHybi13WithProtoAndHeaders , testCase "handshake reject" testHandshakeReject , testCase "handshake reject with custom code" testHandshakeRejectWithCode + , testCase "handshake reject and close connection" testHandshakeRejectAndClose , testCase "handshake Hybi9000" testHandshakeHybi9000 ] @@ -157,6 +160,31 @@ testHandshakeRejectWithCode = do code @?= 401 +-------------------------------------------------------------------------------- +testHandshakeRejectAndClose :: Assertion +testHandshakeRejectAndClose = do + ResponseHead code _ _ <- test' rq13 $ \pc -> do + rejectRequestAndCloseWith pc defaultRejectRequest + (do + Stream.write (pendingStream pc) "Stream should be closed" + assertFailure "Stream should be closed" + ) `catch` (\(e :: ConnectionException) -> + case e of + ConnectionClosed -> pure () + _ -> throwIO e + ) + + code @?= 400 + where + test' rq app = do + echo <- Stream.makeEchoStream + _ <- forkIO $ do + _ <- app (PendingConnection defaultConnectionOptions rq (const $ return ()) echo) + return () + mbRh <- Stream.parse echo decodeResponseHead + case mbRh of + Nothing -> fail "testHandshake: No response" + Just rh -> return rh -------------------------------------------------------------------------------- -- I don't believe this one is supported yet