--------------------------------------------------------------------------------
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TupleSections     #-}
module Network.WebSockets.Extensions.PermessageDeflate
    ( defaultPermessageDeflate
    , PermessageDeflate(..)
    , negotiateDeflate

      -- * Considered internal
    , makeMessageInflater
    , makeMessageDeflater
    ) where


--------------------------------------------------------------------------------
import           Control.Applicative                       ((<$>))
import           Control.Exception                         (throwIO)
import           Control.Monad                             (foldM, unless)
import qualified Data.ByteString                           as B
import qualified Data.ByteString.Char8                     as B8
import qualified Data.ByteString.Lazy                      as BL
import qualified Data.ByteString.Lazy.Char8                as BL8
import qualified Data.ByteString.Lazy.Internal             as BL
import           Data.Int                                  (Int64)
import           Data.Monoid
import qualified Data.Streaming.Zlib                       as Zlib
import           Network.WebSockets.Connection.Options
import           Network.WebSockets.Extensions
import           Network.WebSockets.Extensions.Description
import           Network.WebSockets.Http
import           Network.WebSockets.Types
import           Prelude
import           Text.Read                                 (readMaybe)


--------------------------------------------------------------------------------
-- | Render negotiated 'PermessageDeflate' settings as an
-- 'ExtensionDescription' suitable for a @Sec-WebSocket-Extensions@ header.
--
-- The two @*_no_context_takeover@ flags are listed only when set. The two
-- @*_max_window_bits@ parameters are listed, with their numeric value, only
-- when they differ from 15.
toExtensionDescription :: PermessageDeflate -> ExtensionDescription
toExtensionDescription PermessageDeflate {..} = ExtensionDescription
    { extName   = "permessage-deflate"
    , extParams = concat
        [ flag "server_no_context_takeover" serverNoContextTakeover
        , flag "client_no_context_takeover" clientNoContextTakeover
        , windowBits "server_max_window_bits" serverMaxWindowBits
        , windowBits "client_max_window_bits" clientMaxWindowBits
        ]
    }
  where
    -- A value-less parameter, present only when the flag is on.
    flag name enabled = [(name, Nothing) | enabled]

    -- A parameter carrying its decimal value, omitted when it equals 15.
    windowBits name bits
        | bits == 15 = []
        | otherwise  = [(name, Just (B8.pack (show bits)))]


--------------------------------------------------------------------------------
-- | A single @Sec-WebSocket-Extensions@ header announcing the given
-- settings, encoded through 'toExtensionDescription'.
toHeaders :: PermessageDeflate -> Headers
toHeaders pmd = [("Sec-WebSocket-Extensions", encoded)]
  where
    encoded = encodeExtensionDescriptions [toExtensionDescription pmd]


--------------------------------------------------------------------------------
-- | Negotiate per-message compression against the client's offered
-- extensions.
--
-- If the server has settings (@pmd0@ is 'Just'), the first offer named
-- @x-webkit-deflate-frame@ or @permessage-deflate@ is accepted. For
-- @permessage-deflate@, every offered parameter is folded into the settings
-- with 'setParam'; a 'Left' from 'setParam' fails the whole negotiation.
-- Without server settings, or without a matching offer, no headers are
-- emitted and the parse/write hooks are built from 'Nothing'.
--
-- The resulting 'Extension' creates its inflater and deflater lazily, when
-- the parse and write hooks are installed.
negotiateDeflate
    :: SizeLimit -> Maybe PermessageDeflate -> NegotiateExtension
negotiateDeflate messageLimit pmd0 exts0 = do
    (headers, pmd1) <- maybe (Right ([], Nothing)) (acceptFirst exts0) pmd0
    return Extension
        { extHeaders = headers
        , extParse   = \parseRaw -> do
            inflate <- makeMessageInflater messageLimit pmd1
            -- Inflate each parsed message; an absent message stays absent.
            return (parseRaw >>= traverse inflate)
        , extWrite   = \writeRaw -> do
            deflate <- makeMessageDeflater pmd1
            return (\msgs -> traverse deflate msgs >>= writeRaw)
        }
  where
    -- Walk the offers in order and accept the first supported one.
    acceptFirst
        :: ExtensionDescriptions
        -> PermessageDeflate
        -> Either String (Headers, Maybe PermessageDeflate)
    acceptFirst [] _ = Right ([], Nothing)
    acceptFirst (ext : rest) x = case extName ext of
        "x-webkit-deflate-frame" -> Right
            ([("Sec-WebSocket-Extensions", "x-webkit-deflate-frame")], Just x)
        "permessage-deflate" -> do
            x' <- foldM setParam x (extParams ext)
            Right (toHeaders x', Just x')
        _ -> acceptFirst rest x


--------------------------------------------------------------------------------
-- | Apply one client-offered extension parameter to the settings.
--
-- * @server_no_context_takeover@ / @client_no_context_takeover@ switch the
--   matching flag on, whatever value accompanies them.
-- * @server_max_window_bits@ / @client_max_window_bits@ set the matching
--   window size: 15 when no value is given, otherwise the result of
--   'parseWindow'. A 'Left' from 'parseWindow' is propagated.
-- * Any other parameter leaves the settings unchanged.
setParam
    :: PermessageDeflate -> ExtensionParam -> Either String PermessageDeflate
setParam pmd (key, mbValue)
    | key == "server_no_context_takeover" =
        Right pmd {serverNoContextTakeover = True}
    | key == "client_no_context_takeover" =
        Right pmd {clientNoContextTakeover = True}
    | key == "server_max_window_bits" =
        (\w -> pmd {serverMaxWindowBits = w}) <$> requestedBits
    | key == "client_max_window_bits" =
        (\w -> pmd {clientMaxWindowBits = w}) <$> requestedBits
    | otherwise =
        Right pmd
  where
    requestedBits = maybe (Right 15) parseWindow mbValue


--------------------------------------------------------------------------------
-- | Parse the value of a @*_max_window_bits@ extension parameter.
--
-- Only a plain decimal number without a leading zero is accepted, and it
-- must lie in 8..15, matching the value syntax in RFC 7692. A bare
-- @readMaybe@ at type 'Int' is deliberately not used on the raw string. It
-- would also accept surrounding whitespace, parentheses, signs and
-- hexadecimal/octal literals, and it silently wraps on overflow (e.g.
-- @"18446744073709551625"@ would read as 9). The validated digits are read
-- as an 'Integer', so huge values are reported as out of bounds rather than
-- wrapping into range.
--
-- Returns @Left "Can't parse window: ..."@ for malformed input and
-- @Left "Window out of bounds: ..."@ for well-formed numbers outside 8..15.
parseWindow :: B.ByteString -> Either String Int
parseWindow bs8
    | isDecimal, Just w <- readMaybe str = checkBounds w
    | otherwise = Left $ "Can't parse window: " ++ show bs8
  where
    str :: String
    str = B8.unpack bs8

    -- Non-empty, ASCII digits only, and no leading zero.
    isDecimal :: Bool
    isDecimal = case str of
        (d : _) -> d /= '0' && all (\c -> c >= '0' && c <= '9') str
        []      -> False

    checkBounds :: Integer -> Either String Int
    checkBounds w
        | w >= 8 && w <= 15 = Right (fromInteger w)
        | otherwise         = Left $ "Window out of bounds: " ++ show w


--------------------------------------------------------------------------------
-- | Clamp a window-bits value into the range 9..15 before it is handed to
-- zlib.
--
-- The lower bound is the important one: a window_bits of 8 must be turned
-- into 9. Values above 15 are capped at 15.
--
-- Related issues:
-- - https://github.com/haskell/zlib/issues/11
-- - https://github.com/madler/zlib/issues/94
--
-- Quote from zlib manual:
--
-- For the current implementation of deflate(), a windowBits value of 8 (a
-- window size of 256 bytes) is not supported. As a result, a request for 8 will
-- result in 9 (a 512-byte window). In that case, providing 8 to inflateInit2()
-- will result in an error when the zlib header with 9 is checked against the
-- initialization of inflate(). The remedy is to not use 8 with deflateInit2()
-- with this initialization, or at least in that case use 9 with inflateInit2().
fixWindowBits :: Int -> Int
fixWindowBits = max 9 . min 15


--------------------------------------------------------------------------------
-- | The four bytes @00 00 ff ff@. 'maybeStrip' removes them from the end of
-- compressed output, and the inflater appends them again before
-- decompressing.
appTailL :: BL.ByteString
appTailL = BL.pack [0, 0, 255, 255]


--------------------------------------------------------------------------------
-- | Drop a trailing 'appTailL' from a payload when it is present; any other
-- payload is returned untouched.
maybeStrip :: BL.ByteString -> BL.ByteString
maybeStrip payload
    | endsWithTail = BL.take (BL.length payload - BL.length appTailL) payload
    | otherwise    = payload
  where
    endsWithTail = appTailL `BL.isSuffixOf` payload


--------------------------------------------------------------------------------
-- | Message hook used when no compression settings are in effect.
--
-- A data message with any of RSV1, RSV2 or RSV3 set causes a
-- @CloseRequest 1002 "Protocol Error"@ to be thrown. Every other message is
-- returned unchanged.
rejectExtensions :: Message -> IO Message
rejectExtensions msg = case msg of
    DataMessage rsv1 rsv2 rsv3 _
        | or [rsv1, rsv2, rsv3] ->
            throwIO $ CloseRequest 1002 "Protocol Error"
    _ -> return msg


--------------------------------------------------------------------------------
-- | Build the hook that compresses outgoing messages.
--
-- With no settings the hook is 'rejectExtensions'. Otherwise, data messages
-- whose RSV bits are all clear are deflated and re-tagged with RSV1 set;
-- anything else passes through unchanged. When @serverNoContextTakeover@ is
-- set, a fresh deflater is created for every message. Otherwise a single
-- deflater is created up front and reused for all messages.
makeMessageDeflater
    :: Maybe PermessageDeflate -> IO (Message -> IO Message)
makeMessageDeflater Nothing = return rejectExtensions
makeMessageDeflater (Just pmd)
    | serverNoContextTakeover pmd = return $ \msg -> do
        fresh <- initDeflate pmd
        deflateMessageWith (deflateBody fresh) msg
    | otherwise = do
        shared <- initDeflate pmd
        return (deflateMessageWith (deflateBody shared))
  where
    ----------------------------------------------------------------------------
    -- | A deflate stream at the configured compression level, with the
    -- server window size clamped by 'fixWindowBits' and passed negated.
    initDeflate :: PermessageDeflate -> IO Zlib.Deflate
    initDeflate PermessageDeflate {..} =
        Zlib.initDeflate pdCompressionLevel
            (Zlib.WindowBits (negate (fixWindowBits serverMaxWindowBits)))

    ----------------------------------------------------------------------------
    -- | Compress the payload of an RSV-clear text or binary message and mark
    -- it with RSV1. The cached decoded text of a text message is dropped,
    -- since it no longer matches the payload.
    deflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    deflateMessageWith deflater msg = case msg of
        DataMessage False False False (Text x _) ->
            compressed . (\x' -> Text x' Nothing) <$> deflater x
        DataMessage False False False (Binary x) ->
            compressed . Binary <$> deflater x
        _ -> return msg
      where
        compressed = DataMessage True False False

    ----------------------------------------------------------------------------
    -- | Feed each strict chunk into the deflater, finish with
    -- 'Zlib.flushDeflate', then remove a trailing 'appTailL' via
    -- 'maybeStrip'.
    deflateBody :: Zlib.Deflate -> BL.ByteString -> IO BL.ByteString
    deflateBody ptr body = maybeStrip <$> feedAll (BL.toChunks body)
      where
        feedAll []       = dePopper (Zlib.flushDeflate ptr)
        feedAll (c : cs) = do
            out  <- Zlib.feedDeflate ptr c >>= dePopper
            rest <- feedAll cs
            return (out <> rest)


--------------------------------------------------------------------------------
-- | Run a zlib popper until it reports 'Zlib.PRDone', collecting every
-- popped chunk into one lazy 'BL.ByteString'.
--
-- A 'Zlib.PRError' is rethrown as a 'CloseRequest' with code 1002 whose
-- reason is the shown zlib error.
dePopper :: Zlib.Popper -> IO BL.ByteString
dePopper popper = loop
  where
    loop = do
        res <- popper
        case res of
            Zlib.PRDone    -> return BL.empty
            Zlib.PRNext c  -> BL.chunk c <$> loop
            Zlib.PRError e ->
                throwIO $ CloseRequest 1002 (BL8.pack (show e))


--------------------------------------------------------------------------------
-- | Build the hook that decompresses incoming messages.
--
-- With no settings the hook is 'rejectExtensions'. Otherwise, data messages
-- with RSV1 set are inflated and RSV1 is cleared; other messages pass
-- through unchanged. When @clientNoContextTakeover@ is set, a fresh inflater
-- is created for every message. Otherwise a single inflater is created up
-- front and reused.
--
-- The running decompressed size is compared with @messageLimit@ after every
-- piece of output the inflater produces. A 'ParseException' is thrown as
-- soon as the limit is exceeded.
makeMessageInflater
    :: SizeLimit -> Maybe PermessageDeflate
    -> IO (Message -> IO Message)
makeMessageInflater _ Nothing = return rejectExtensions
makeMessageInflater messageLimit (Just pmd)
    | clientNoContextTakeover pmd =
        return $ \msg -> do
            ptr <- initInflate pmd
            inflateMessageWith (inflateBody ptr) msg
    | otherwise = do
        ptr <- initInflate pmd
        return $ \msg ->
            inflateMessageWith (inflateBody ptr) msg
  where
    --------------------------------------------------------------------------------
    -- | An inflate stream using the client window size, clamped by
    -- 'fixWindowBits' and passed negated (zlib treats negative window bits
    -- as a raw deflate stream without a header).
    initInflate :: PermessageDeflate -> IO Zlib.Inflate
    initInflate PermessageDeflate {..} =
        Zlib.initInflate
            (Zlib.WindowBits (- (fixWindowBits clientMaxWindowBits)))


    ----------------------------------------------------------------------------
    -- | Inflate the payload of an RSV1-marked text or binary message and
    -- clear RSV1, keeping RSV2/RSV3. The cached decoded text is dropped,
    -- since it was computed from the compressed payload.
    inflateMessageWith
        :: (BL.ByteString -> IO BL.ByteString)
        -> Message -> IO Message
    inflateMessageWith inflater (DataMessage True a b (Text x _)) = do
        x' <- inflater x
        return (DataMessage False a b (Text x' Nothing))
    inflateMessageWith inflater (DataMessage True a b (Binary x)) = do
        x' <- inflater x
        return (DataMessage False a b (Binary x'))
    inflateMessageWith _ x = return x


    ----------------------------------------------------------------------------
    -- | Decompress a payload after appending 'appTailL' (the tail that the
    -- deflating side removes, cf. 'maybeStrip').
    --
    -- Output produced for each input chunk is consumed piece by piece through
    -- 'drainChecked'. That way a small but highly compressed chunk cannot
    -- have its whole expansion materialised before the size limit is checked.
    inflateBody :: Zlib.Inflate -> BL.ByteString -> IO BL.ByteString
    inflateBody ptr =
        go 0 . BL.toChunks . (<> appTailL)
      where
        go :: Int64 -> [B.ByteString] -> IO BL.ByteString
        go size0 [] = do
            chunk <- Zlib.flushInflate ptr
            checkSize (fromIntegral (B.length chunk) + size0)
            return (BL.fromStrict chunk)
        go size0 (c : cs) = do
            popper         <- Zlib.feedInflate ptr c
            (chunk, size1) <- drainChecked size0 popper
            (chunk <>) <$> go size1 cs


    ----------------------------------------------------------------------------
    -- | Like 'dePopper', but it adds the length of each popped piece to the
    -- running total and calls 'checkSize' before requesting the next piece.
    -- Returns the collected output together with the updated total.
    drainChecked :: Int64 -> Zlib.Popper -> IO (BL.ByteString, Int64)
    drainChecked size0 popper = do
        res <- popper
        case res of
            Zlib.PRDone    -> return (BL.empty, size0)
            Zlib.PRNext c  -> do
                let size1 = size0 + fromIntegral (B.length c)
                checkSize size1
                (rest, size2) <- drainChecked size1 popper
                return (BL.chunk c rest, size2)
            Zlib.PRError x ->
                throwIO $ CloseRequest 1002 (BL8.pack (show x))


    ----------------------------------------------------------------------------
    -- | Throw a 'ParseException' when @size@ exceeds @messageLimit@.
    checkSize :: Int64 -> IO ()
    checkSize size = unless (atMostSizeLimit size messageLimit) $ throwIO $
        ParseException $ "Message of size " ++ show size ++ " exceeded limit"