-- | Internal module exposing the guts of the package.  Use at
-- your own risk.  No API stability guarantees apply.
module Web.ServerSession.Backend.Redis.Internal
  ( RedisStorage(..)
  , RedisStorageException(..)

  , transaction
  , unwrap
  , rSessionKey
  , rAuthKey

  , RedisSession(..)
  , parseSession
  , printSession
  , parseUTCTime
  , printUTCTime
  , timeFormat

  , getSessionImpl
  , deleteSessionImpl
  , removeSessionFromAuthId
  , insertSessionForAuthId
  , deleteAllSessionsOfAuthIdImpl
  , insertSessionImpl
  , replaceSessionImpl
  , throwRS
  ) where

import Control.Applicative as A
import Control.Arrow (first)
import Control.Monad (void, when)
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.List (partition)
import Data.Maybe (fromMaybe, catMaybes)
import Data.Proxy (Proxy(..))
import Data.Typeable (Typeable)
import Web.PathPieces (toPathPiece)
import Web.ServerSession.Core

import qualified Control.Exception as E
import qualified Database.Redis as R
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import qualified Data.HashMap.Strict as HM
import qualified Data.Text.Encoding as TE
import qualified Data.Time.Clock as TI
import qualified Data.Time.Clock.POSIX as TP
import qualified Data.Time.Format as TI

#if MIN_VERSION_time(1,5,0)
import Data.Time.Format (defaultTimeLocale)
#else
import System.Locale (defaultTimeLocale)
#endif

----------------------------------------------------------------------


-- | Redis-backed session storage, built on the @hedis@ client.
data RedisStorage sess = RedisStorage
  { connPool :: R.Connection
    -- ^ Connection to the Redis server; 'runTransactionM' runs every
    -- storage action through it with 'R.runRedis'.
  , idleTimeout :: Maybe TI.NominalDiffTime
    -- ^ Maximum lifetime of a session counted from its last access
    -- ('sessionAccessedAt').  Used to schedule the Redis key's expiry;
    -- 'Nothing' means no idle-based expiry is scheduled.
  , absoluteTimeout :: Maybe TI.NominalDiffTime
    -- ^ Maximum lifetime of a session counted from its creation
    -- ('sessionCreatedAt').  Used to schedule the Redis key's expiry;
    -- 'Nothing' means no creation-based expiry is scheduled.
  } deriving (Typeable)


-- | 'Storage' instance for Redis.  Each method simply forwards to the
-- matching @*Impl@ function of this module.
--
-- Separate actions executed in the same @TransactionM RedisStorage@
-- are NOT given any ACID guarantees with respect to each other.
instance RedisSession sess => Storage (RedisStorage sess) where
  type SessionData  (RedisStorage sess) = sess
  type TransactionM (RedisStorage sess) = R.Redis

  -- Every storage action is executed on the backend's connection.
  runTransactionM sto action = R.runRedis (connPool sto) action

  -- These three do not need the storage value itself.
  getSession                _ sid    = getSessionImpl sid
  deleteSession             _ sid    = deleteSessionImpl sid
  deleteAllSessionsOfAuthId _ authId = deleteAllSessionsOfAuthIdImpl authId

  -- Writers need the storage value to know the configured timeouts.
  insertSession  sto session = insertSessionImpl  sto session
  replaceSession sto session = replaceSessionImpl sto session


-- | Errors raised by the @serversession-backend-redis@ package when
-- Redis answers with something other than what was expected.
data RedisStorageException
  = ExpectedTxSuccess (R.TxResult ())
    -- ^ A @MULTI@/@EXEC@ transaction ('R.multiExec') finished with a
    -- result other than 'TxSuccess'.
  | ExpectedRight R.Reply
    -- ^ A command produced @Left@ (an error 'R.Reply') where a
    -- @Right@ value was expected from an @Either 'R.Reply' a@.
  deriving (Show, Typeable)

-- | Allows the errors above to be thrown with 'E.throwIO'.
instance E.Exception RedisStorageException


----------------------------------------------------------------------


-- | Execute a queued Redis transaction with @MULTI@/@EXEC@ and make
-- sure it succeeded.
--
-- Any outcome other than 'TxSuccess' is thrown (in 'IO') as
-- 'ExpectedTxSuccess' carrying the actual result.
transaction :: R.RedisTx (R.Queued ()) -> R.Redis ()
transaction tx = R.multiExec tx >>= check
  where
    check (R.TxSuccess ()) = return ()
    check failure          = liftIO (E.throwIO (ExpectedTxSuccess failure))


-- | Run a Redis action that returns @Either 'R.Reply' a@ and extract
-- the @Right@ value.
--
-- A @Left reply@ is thrown as 'ExpectedRight' with that reply.
unwrap :: R.Redis (Either R.Reply a) -> R.Redis a
unwrap act = do
  result <- act
  case result of
    Left reply -> liftIO (E.throwIO (ExpectedRight reply))
    Right x    -> return x


-- | Redis key of the hash that stores a session.  It is
-- @ssr:session:@ followed by the UTF-8 encoding of the session ID's
-- path piece.
rSessionKey :: SessionId sess -> ByteString
rSessionKey sid = B.append "ssr:session:" (TE.encodeUtf8 (toPathPiece sid))


-- | Redis key of the set that collects the session keys belonging to
-- an auth ID.  It is @ssr:authid:@ followed by the raw auth ID bytes.
rAuthKey :: AuthId -> ByteString
rAuthKey authId = B.append "ssr:authid:" authId


----------------------------------------------------------------------


-- | Session data types that the Redis backend is able to persist.
--
-- Instances are expected to satisfy
--
-- @
-- fromHash p . perm . toHash p  ===  id
-- @
--
-- for every list permutation @perm :: [a] -> [a]@, where
-- @p :: Proxy sess@.  In other words, 'fromHash' must not depend on
-- the order of the fields it receives.
class IsSessionData sess => RedisSession sess where
  -- | Flatten decomposed session data into Redis hash fields.  Each
  -- key is stored with a @\"data:\"@ prefix added in front of it.
  toHash   :: Proxy sess -> Decomposed sess -> [(ByteString, ByteString)]

  -- | Rebuild decomposed session data from Redis hash fields whose
  -- @\"data:\"@ prefix has already been removed.
  fromHash :: Proxy sess -> [(ByteString, ByteString)] -> Decomposed sess


-- | Stores each 'SessionMap' entry as one hash field.  The 'Text' key
-- is encoded as UTF-8.
--
-- Reading back decodes keys with 'TE.decodeUtf8', so it assumes they
-- are valid UTF-8.  This holds whenever they were written by 'toHash'.
instance RedisSession SessionMap where
  toHash _ (SessionMap entries) =
    [ (TE.encodeUtf8 key, value) | (key, value) <- HM.toList entries ]

  fromHash _ fields =
    SessionMap (HM.fromList [ (TE.decodeUtf8 key, value) | (key, value) <- fields ])


-- | Rebuild a 'Session' from the field/value pairs of its Redis hash.
--
-- * An empty field list yields 'Nothing'.
-- * Fields whose key starts with @\"data:\"@ have that prefix removed
--   and are passed to 'fromHash'.
-- * All other fields are internal metadata: @internal:authId@
--   (optional), @internal:createdAt@ and @internal:accessedAt@.
--
-- A missing timestamp field is reported through 'error', naming the
-- missing key.
parseSession
  :: forall sess. RedisSession sess
  => SessionId sess
  -> [(ByteString, ByteString)]
  -> Maybe (Session sess)
parseSession _ [] = Nothing
parseSession sid fields =
    Just Session
      { sessionKey        = sid
      , sessionAuthId     = lookup "internal:authId" internals
      , sessionData       = fromHash proxy (map (first dropDataPrefix) external)
      , sessionCreatedAt  = parseUTCTime (required "internal:createdAt")
      , sessionAccessedAt = parseUTCTime (required "internal:accessedAt")
      }
  where
    (external, internals) = partition (B8.isPrefixOf "data:" . fst) fields

    -- Every key in 'external' starts with the 5-byte "data:" prefix,
    -- guaranteed by the partition above.
    dropDataPrefix = B8.drop 5

    -- Look up a mandatory internal field.
    required key =
      fromMaybe
        (error ("serversession-backend-redis/parseSession: missing key " ++ show key))
        (lookup key internals)

    proxy = Proxy :: Proxy sess


-- | Serialise a 'Session' into the field/value pairs of its Redis
-- hash, in this order:
--
-- 1. @internal:authId@, only when the session has an auth ID;
-- 2. @internal:createdAt@ and @internal:accessedAt@;
-- 3. every pair produced by 'toHash', with @\"data:\"@ prepended to
--    its key.
printSession :: forall sess. RedisSession sess => Session sess -> [(ByteString, ByteString)]
printSession Session {..} =
    authField ++ timestamps ++ dataFields
  where
    authField = maybe [] (\authId -> [("internal:authId", authId)]) sessionAuthId

    timestamps =
      [ ("internal:createdAt",  printUTCTime sessionCreatedAt)
      , ("internal:accessedAt", printUTCTime sessionAccessedAt)
      ]

    dataFields =
      [ (B8.append "data:" key, value)
      | (key, value) <- toHash (Proxy :: Proxy sess) sessionData
      ]


-- | Read back a 'TI.UTCTime' that was written by 'printUTCTime',
-- using 'timeFormat'.
--
-- Malformed input is a fatal error (reported through 'error').
parseUTCTime :: ByteString -> TI.UTCTime
#if MIN_VERSION_time(1,5,0)
parseUTCTime bs =
  TI.parseTimeOrError True defaultTimeLocale timeFormat (B8.unpack bs)
#else
parseUTCTime bs =
  fromMaybe
    (error "Web.ServerSession.Backend.Redis.Internal.parseUTCTime")
    (TI.parseTime defaultTimeLocale timeFormat (B8.unpack bs))
#endif


-- | Render a 'TI.UTCTime' with 'timeFormat' as the 'ByteString' that
-- gets stored in Redis.  'parseUTCTime' reverses this.
printUTCTime :: TI.UTCTime -> ByteString
printUTCTime t = B8.pack (TI.formatTime defaultTimeLocale timeFormat t)


-- | Format string shared by 'printUTCTime' and 'parseUTCTime'.  It
-- gives date, a literal @T@, time of day, then @%Q@ for the
-- fractional seconds.
timeFormat :: String
timeFormat = "%Y-%m-%dT%H:%M:%S%Q"


----------------------------------------------------------------------


-- | Feed a list-consuming command its arguments in chunks of at most
-- @511*1024@ elements, and return the result of the final chunk.
--
-- This exists for @HMSET@: Redis refuses commands with more than
-- @1024*1024@ arguments, and every (field, value) pair counts twice.
-- An empty input still calls the command once, with @[]@.
batched :: Monad m => ([a] -> m b) -> [a] -> m b
batched f xs =
    case splitAt chunkSize xs of
      (chunk, [])        -> f chunk
      (chunk, remaining) -> f chunk >> batched f remaining
  where
    chunkSize = 511 * 1024


-- | Load a session with @HGETALL@ on its key.  A missing key comes
-- back as an empty hash, which 'parseSession' turns into 'Nothing'.
-- An error reply is thrown via 'unwrap'.
getSessionImpl :: RedisSession sess => SessionId sess -> R.Redis (Maybe (Session sess))
getSessionImpl sid = do
  fields <- unwrap (R.hgetall (rSessionKey sid))
  return (parseSession sid fields)


-- | Delete the session stored under the given ID, if there is one.
--
-- The session is read first, because its auth ID is needed.  Deleting
-- the session hash and removing it from that auth ID's set then
-- happen inside one 'transaction'.  Nothing is done if the session
-- does not exist.
deleteSessionImpl :: RedisSession sess => SessionId sess -> R.Redis ()
deleteSessionImpl sid =
    getSessionImpl sid >>= maybe (return ()) purge
  where
    purge session = transaction $ do
      deleted <- R.del [rSessionKey sid]
      removeSessionFromAuthId sid (sessionAuthId session)
      return (() <$ deleted)


-- | @SREM@ the session's key from the set of its auth ID.  Does
-- nothing when the auth ID is @Nothing@.
removeSessionFromAuthId :: (R.RedisCtx m f, Functor m) => SessionId sess -> Maybe AuthId -> m ()
removeSessionFromAuthId sid mAuthId = fooSessionBarAuthId R.srem sid mAuthId

-- | @SADD@ the session's key to the set of its auth ID.  Does
-- nothing when the auth ID is @Nothing@.
insertSessionForAuthId :: (R.RedisCtx m f, Functor m) => SessionId sess -> Maybe AuthId -> m ()
insertSessionForAuthId sid mAuthId = fooSessionBarAuthId R.sadd sid mAuthId


-- | (Internal) Shared implementation of 'removeSessionFromAuthId' and
-- 'insertSessionForAuthId'.
--
-- Applies a set command (@R.srem@ / @R.sadd@) to the auth ID's set
-- key ('rAuthKey'), with the session's key ('rSessionKey') as the
-- only member.  The command's reply is discarded.  With @Nothing@ as
-- the auth ID, no command is issued.
fooSessionBarAuthId
  :: (R.RedisCtx m f, Functor m)
  => (ByteString -> [ByteString] -> m (f Integer))
  -> SessionId sess
  -> Maybe AuthId
  -> m ()
fooSessionBarAuthId fun sid =
  maybe (return ()) (\authId -> void (fun (rAuthKey authId) [rSessionKey sid]))


-- | Delete every session recorded for the given auth ID, together
-- with the auth ID's set.
--
-- The set members (session keys) are read with @SMEMBERS@, and then a
-- single @DEL@ removes the set and all those keys.  The two commands
-- are not wrapped in a transaction.
deleteAllSessionsOfAuthIdImpl :: AuthId -> R.Redis ()
deleteAllSessionsOfAuthIdImpl authId = do
  let setKey = rAuthKey authId
  sessionKeys <- unwrap (R.smembers setKey)
  _ <- unwrap (R.del (setKey : sessionKeys))
  return ()


-- | Store a brand-new session.
--
-- If a session with the same ID is already present,
-- 'SessionAlreadyExists' is thrown.  That check is done before the
-- transaction, not inside it.  Otherwise one 'transaction' does three
-- things: it writes the hash (with 'batched' @HMSET@), schedules the
-- key's expiry, and adds the key to the auth ID's set.
insertSessionImpl :: RedisSession sess => RedisStorage sess -> Session sess -> R.Redis ()
insertSessionImpl sto session = do
  let sid = sessionKey session
  existing <- getSessionImpl sid
  case existing of
    Just oldSession -> throwRS (SessionAlreadyExists oldSession session)
    Nothing ->
      transaction $ do
        reply <- batched (R.hmset (rSessionKey sid)) (printSession session)
        expireSession session sto
        insertSessionForAuthId sid (sessionAuthId session)
        return (() <$ reply)


-- | Overwrite an existing session.
--
-- If there is no session with that ID, 'SessionDoesNotExist' is
-- thrown; this check happens before the transaction.  Otherwise one
-- 'transaction' does the following:
--
-- * deletes the old hash;
-- * writes the new one and schedules its expiry;
-- * moves the key between auth ID sets, but only when the auth ID
--   changed.
replaceSessionImpl :: RedisSession sess => RedisStorage sess -> Session sess -> R.Redis ()
replaceSessionImpl sto session = do
  let sid = sessionKey session
  previous <- getSessionImpl sid
  case previous of
    Nothing -> throwRS (SessionDoesNotExist session)
    Just oldSession ->
      transaction $ do
        let sk        = rSessionKey sid
            oldAuthId = sessionAuthId oldSession
            newAuthId = sessionAuthId session

        -- Start from an empty hash so stale fields do not survive.
        _ <- R.del [sk]
        reply <- batched (R.hmset sk) (printSession session)
        expireSession session sto

        when (oldAuthId /= newAuthId) $ do
          removeSessionFromAuthId sid oldAuthId
          insertSessionForAuthId sid newAuthId

        return (() <$ reply)


-- | Throw a 'StorageException' for 'RedisStorage' from inside
-- 'R.Redis'.  It lifts 'E.throwIO' with 'liftIO'.
throwRS
  :: Storage (RedisStorage sess)
  => StorageException (RedisStorage sess)
  -> R.Redis a
throwRS err = liftIO (E.throwIO err)


-- | Make a session's Redis key expire (with @EXPIREAT@) at the moment
-- the session itself times out.
--
-- That moment is the earlier of:
--
-- * @sessionAccessedAt + idleTimeout@
-- * @sessionCreatedAt + absoluteTimeout@
--
-- A timeout that is 'Nothing' is ignored.  When both are 'Nothing',
-- no expiry is scheduled.  Call this on every write to a session so
-- the expiry keeps following its timestamps.
--
-- @EXPIREAT@ only accepts whole POSIX seconds, so the instant is
-- rounded /up/ with 'ceiling'.  Rounding to the nearest second could
-- delete the key up to half a second before the session actually
-- times out.  Rounding up keeps the key for at most one extra second.
expireSession :: Session sess -> RedisStorage sess -> R.RedisTx ()
expireSession Session {..} RedisStorage {..} =
  case minimum' (catMaybes [viaIdle, viaAbsolute]) of
    Nothing -> return ()
    Just t  ->
      let ts = ceiling (TP.utcTimeToPOSIXSeconds t) :: Integer
      in void (R.expireat sk ts)
  where
    sk = rSessionKey sessionKey

    -- Total version of 'minimum'.
    minimum' [] = Nothing
    minimum' xs = Just (minimum xs)

    viaIdle     = flip TI.addUTCTime sessionAccessedAt <$> idleTimeout
    viaAbsolute = flip TI.addUTCTime sessionCreatedAt  <$> absoluteTimeout