From a92a5aac159df6146356225725be9ab2612561c8 Mon Sep 17 00:00:00 2001 From: Edsko de Vries Date: Fri, 25 Aug 2023 11:47:49 +0200 Subject: [PATCH] Close all streams on termination Importantly, this means we properly close all streams when a client disconnects. --- Network/HTTP2/Arch/Receiver.hs | 22 ++++++++++++++-------- Network/HTTP2/Arch/Stream.hs | 19 +++++++++++++++++++ Network/HTTP2/Arch/Types.hs | 2 +- Network/HTTP2/Client/Run.hs | 1 + Network/HTTP2/Server/Run.hs | 1 + 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/Network/HTTP2/Arch/Receiver.hs b/Network/HTTP2/Arch/Receiver.hs index 730b30c1..51bfd744 100644 --- a/Network/HTTP2/Arch/Receiver.hs +++ b/Network/HTTP2/Arch/Receiver.hs @@ -376,7 +376,7 @@ stream FrameHeaders header@FrameHeader{flags,streamId} bs ctx (Open (Body q _ _ if endOfStream then do tbl <- hpackDecodeTrailer frag streamId ctx writeIORef tlr (Just tbl) - atomically $ writeTQueue q "" + atomically $ writeTQueue q $ Right "" return HalfClosedRemote else -- we don't support continuation here. @@ -412,13 +412,13 @@ stream FrameData E.throwIO $ ConnectionErrorIsSent ProtocolError streamId "too many empty data" else do writeIORef bodyLength len - atomically $ writeTQueue q body + atomically $ writeTQueue q $ Right body if endOfStream then do case mcl of Nothing -> return () Just cl -> when (cl /= len) $ E.throwIO $ StreamErrorIsSent ProtocolError streamId "actual body length is not the same as content-length" -- no trailers - atomically $ writeTQueue q "" + atomically $ writeTQueue q $ Right "" return HalfClosedRemote else return s @@ -498,11 +498,11 @@ stream _ FrameHeader{streamId} _ _ _ _ = E.throwIO $ StreamErrorIsSent ProtocolE -- | Type for input streaming. data Source = Source (Int -> IO ()) - (TQueue ByteString) + (TQueue (Either E.SomeException ByteString)) (IORef ByteString) (IORef Bool) -mkSource :: TQueue ByteString -> (Int -> IO ()) -> IO Source +mkSource :: TQueue (Either E.SomeException ByteString) -> (Int -> IO ()) -> IO Source mkSource q inform = Source inform q <$> newIORef "" <*> newIORef False readSource :: Source -> IO ByteString @@ -516,12 +516,18 @@ readSource (Source inform q refBS refEOF) = do inform len return bs where + readBS :: IO ByteString readBS = do bs0 <- readIORef refBS if bs0 == "" then do - bs <- atomically $ readTQueue q - when (bs == "") $ writeIORef refEOF True - return bs + mBS <- atomically $ readTQueue q + case mBS of + Left err -> do + writeIORef refEOF True + E.throwIO err + Right bs -> do + when (bs == "") $ writeIORef refEOF True + return bs else do writeIORef refBS "" return bs0 diff --git a/Network/HTTP2/Arch/Stream.hs b/Network/HTTP2/Arch/Stream.hs index f835b9d3..23bf0b64 100644 --- a/Network/HTTP2/Arch/Stream.hs +++ b/Network/HTTP2/Arch/Stream.hs @@ -2,6 +2,7 @@ module Network.HTTP2.Arch.Stream where +import Control.Exception import Data.IORef import qualified Data.IntMap.Strict as M import UnliftIO.Concurrent @@ -75,3 +76,21 @@ updateAllStreamWindow :: (WindowSize -> WindowSize) -> StreamTable -> IO () updateAllStreamWindow adst (StreamTable ref) = do strms <- M.elems <$> readIORef ref forM_ strms $ \strm -> atomically $ modifyTVar (streamWindow strm) adst + +closeAllStreams :: StreamTable -> Maybe SomeException -> IO () +closeAllStreams (StreamTable ref) mErr' = do + strms <- atomicModifyIORef' ref $ \m -> (M.empty, m) + forM_ strms $ \strm -> do + st <- readStreamState strm + case st of + Open (Body q _ _ _) -> + atomically $ writeTQueue q $ maybe (Right mempty) Left mErr + _otherwise -> + return () + where + mErr :: Maybe SomeException + mErr = case mErr of + Just err | Just ConnectionIsClosed <- fromException err -> + Nothing + _otherwise -> + mErr' diff --git a/Network/HTTP2/Arch/Types.hs b/Network/HTTP2/Arch/Types.hs index 24baed2a..c6d2eb9e 100644 --- a/Network/HTTP2/Arch/Types.hs +++ b/Network/HTTP2/Arch/Types.hs @@ -204,7 +204,7 @@ data OpenState = Bool -- End of stream | NoBody HeaderTable | HasBody HeaderTable - | Body (TQueue ByteString) + | Body (TQueue (Either SomeException ByteString)) (Maybe Int) -- received Content-Length -- compared the body length for error checking (IORef Int) -- actual body length diff --git a/Network/HTTP2/Client/Run.hs b/Network/HTTP2/Client/Run.hs index 63589848..798696eb 100644 --- a/Network/HTTP2/Client/Run.hs +++ b/Network/HTTP2/Client/Run.hs @@ -40,6 +40,7 @@ run ClientConfig{..} conf@Config{..} client = do enqueueControl (controlQ ctx) $ CFrames Nothing [frame] return x stopAfter mgr (race runBackgroundThreads runClient) $ \res -> do + closeAllStreams (streamTable ctx) $ either Just (const Nothing) res case res of Left err -> throwIO err diff --git a/Network/HTTP2/Server/Run.hs b/Network/HTTP2/Server/Run.hs index 5dbe53d3..41180aeb 100644 --- a/Network/HTTP2/Server/Run.hs +++ b/Network/HTTP2/Server/Run.hs @@ -34,6 +34,7 @@ run conf@Config{..} server = do let runReceiver = frameReceiver ctx conf runSender = frameSender ctx conf mgr stopAfter mgr (concurrently_ runReceiver runSender) $ \res -> do + closeAllStreams (streamTable ctx) $ either Just (const Nothing) res case res of Left err -> throwIO err