-- hbs2/hbs2-core/lib/HBS2/Net/Proto/Service.hs
--
-- (source listing: 298 lines, 11 KiB, Haskell)

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
module HBS2.Net.Proto.Service
( module HBS2.Net.Proto.Service
, module HBS2.Net.Proto.Types
) where
import HBS2.Actors.Peer
import HBS2.Net.Messaging.Unix
import HBS2.Net.Proto.Types
import HBS2.Prelude.Plated
import Codec.Serialise
import Control.Monad
import Control.Monad.Reader
import Control.Monad.Trans.Resource
import Data.ByteString.Lazy (ByteString)
import Data.Kind
import Data.List qualified as List
import GHC.TypeLits
-- import Lens.Micro.Platform
import UnliftIO.Async
import UnliftIO qualified as UIO
import UnliftIO (TVar,TQueue,atomically)
import System.Random (randomIO)
import Data.Word
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as HashMap
import Control.Exception (bracket_)
-- | Argument type of the RPC method identified by the tag type @a@.
type family Input a :: Type
-- | Result type of the RPC method identified by the tag type @a@.
type family Output a :: Type
-- FIXME: wrap-those-instances
-- The unit "method" @()@ takes @()@ and returns @()@. The base instance of
-- 'EnumAll' puts a handler for it into every method table.
type instance Input () = ()
type instance Output () = ()
-- | A method tag @a@ that can be served in monad @m@: map the method's
-- 'Input' to its 'Output'. Both payload types must be 'Serialise' because
-- they are serialised into request/response messages (see 'dispatch').
class (Monad m, Serialise (Output a), Serialise (Input a)) => HandleMethod m a where
  handleMethod :: Input a -> m (Output a)
-- | @AllHandlers m xs@ is the conjunction of @HandleMethod m x@ for every
-- method tag @x@ in the type-level list @xs@.
type family AllHandlers m (xs :: [Type]) :: Constraint where
  AllHandlers m '[] = ()
  AllHandlers m (x ': xs) = (HandleMethod m x, AllHandlers m xs)
-- | A handler for some method tag @a@, with @a@ hidden existentially so
-- that handlers of different methods can live in one list. Pattern matching
-- on the constructor brings the 'HandleMethod' dictionary back into scope,
-- and with it the 'Serialise' instances needed to decode the input and
-- encode the output.
data SomeHandler m = forall a . HandleMethod m a => SomeHandler ( Input a -> m (Output a) )
-- | Turn a type-level method list @xs@ into a value-level table of @t@.
-- Neither @xs@ nor @m@ appears in the type of 'enumMethods', so call sites
-- select the instance with type applications, e.g.
-- @enumMethods \@api \@(Int, SomeHandler m) \@m@.
class Monad m => EnumAll (xs :: [Type]) t m where
  enumMethods :: [t]
-- | Base case: an exhausted method list still contributes one entry, a
-- no-op handler for the unit method @()@ at index 0. The cons instance
-- shifts every index by one, so this entry ends up last in the table.
instance (Monad m, HandleMethod m ()) => EnumAll '[] (Int, SomeHandler m) m where
  enumMethods = [ (0, SomeHandler @m @() (const (pure ()))) ]
-- | Cons case: the head method @x@ gets index 0 and every entry produced
-- for the tail @xs@ moves up by one. A method's position in the type-level
-- list is therefore its index in the table, which is what
-- 'FindMethodIndex' computes on the client side.
instance (Monad m, EnumAll xs (Int, SomeHandler m) m, HandleMethod m x) => EnumAll (x ': xs) (Int, SomeHandler m) m where
  enumMethods = headEntry : tailEntries
    where
      headEntry = (0, SomeHandler @m @x (handleMethod @m @x))
      tailEntries =
        [ (ix + 1, h) | (ix, h) <- enumMethods @xs @(Int, SomeHandler m) @m ]
-- | The unit method is a no-op: it ignores its argument and returns @()@.
instance Monad m => HandleMethod m () where
  handleMethod = const (pure ())
-- | Errors carried back to the caller inside a 'ServiceResponse', or
-- produced locally by the caller while decoding a reply.
data ServiceError =
    ErrorMethodNotFound   -- ^ the requested method index is not part of the API
  | ErrorInvalidRequest   -- ^ the request could not be decoded, or a response was sent to the server
  | ErrorInvalidResponse  -- ^ the reply could not be decoded, or was not a 'ServiceResponse'
  deriving stock (Eq,Ord,Generic,Show,Typeable)
instance Serialise ServiceError
-- | Wire message of the service protocol. @api@ and @e@ are phantom
-- parameters (no field mentions them); they tie a message to its method
-- list and transport at the type level.
--
-- NOTE(review): 'reqData' and 'reqResp' are partial selectors; each exists
-- on one constructor only.
data ServiceProto api e =
    -- | A call. @reqData@ is the serialised pair
    -- @(methodIndex :: Int, serialised method input)@.
    ServiceRequest { reqNum :: Word64, reqData :: ByteString }
  | -- | The reply to the request with the same 'reqNum': either the
    -- serialised method output or a 'ServiceError'.
    ServiceResponse { reqNum :: Word64, reqResp :: Either ServiceError ByteString }
  deriving stock (Generic,Show)
instance Serialise (ServiceProto api e)
-- | Server-side handling of one incoming 'ServiceProto' message.
--
-- A 'ServiceRequest' carries a serialised @(methodIndex :: Int, payload)@
-- pair. The index selects a handler from the @api@ method table (see
-- 'enumMethods'). The payload is decoded as that method's 'Input', the
-- handler runs, and its 'Output' is serialised into a 'ServiceResponse'
-- with the same request number.
--
-- Decoding problems are reported back to the caller rather than thrown:
--
-- * 'ErrorInvalidRequest': the message is a 'ServiceResponse', or the
--   envelope or the payload fails to decode;
-- * 'ErrorMethodNotFound': the method index is not in the table.
dispatch :: forall api e m . ( EnumAll api (Int, SomeHandler m) m
                             , MonadIO m
                             , HasProtocol e (ServiceProto api e)
                             )
         => ServiceProto api e
         -> m (ServiceProto api e)
dispatch (ServiceResponse n _) =
  pure $ ServiceResponse n (Left ErrorInvalidRequest)
dispatch (ServiceRequest rn lbs) =
  -- The envelope comes straight off the network, so decode it totally.
  -- With the partial 'deserialise', a malformed message would throw out of
  -- the server handler instead of yielding an error response.
  case deserialiseOrFail @(Int, ByteString) lbs of
    Left{} -> invalidRequest
    Right (n, bss) ->
      maybe1 (List.lookup n handlers) methodNotFound $ \(SomeHandler fn) ->
        case deserialiseOrFail bss of
          Left{}  -> invalidRequest
          Right v -> ServiceResponse rn . Right . serialise <$> fn v
  where
    handlers = enumMethods @api @(Int, SomeHandler m) @m
    invalidRequest = pure (ServiceResponse rn (Left ErrorInvalidRequest))
    methodNotFound = pure (ServiceResponse rn (Left ErrorMethodNotFound))
-- | @FindMethodIndex n x xs@ is @'Just (n + i)@, where @i@ is the zero-based
-- position of the first occurrence of @x@ in @xs@. It is @'Nothing@ when
-- @x@ does not occur. Callers start the count at @n = 0@ to get the method
-- index that is put into a request.
type family FindMethodIndex (n :: Nat) (x :: Type) (xs :: [Type]) :: Maybe Nat where
  FindMethodIndex _ x '[] = 'Nothing
  FindMethodIndex n x (x ': xs) = 'Just n
  FindMethodIndex n x (y ': xs) = FindMethodIndex (n + 1) x xs
-- | Unwrap a type-level 'Maybe'. A @'Nothing@ is turned into a custom
-- compile-time error instead of a stuck type.
type family FromJust (a :: Maybe k) :: k where
  FromJust ('Just a) = a
  FromJust 'Nothing = TypeError ('Text "Element not found")
-- | Term-level index of method @x@ in the API list @xs@. A use where @x@ is
-- not a member of @xs@ is rejected at compile time with the 'FromJust'
-- error.
findMethodIndex :: forall x xs. KnownNat (FromJust (FindMethodIndex 0 x xs)) => Integer
findMethodIndex = natVal (Proxy @(FromJust (FindMethodIndex 0 x xs)))
-- | Build a 'ServiceRequest' numbered @rnum@ that calls @method@ of @api@
-- with argument @input@. The request body is the serialised pair of the
-- method's index (as 'Int') and the separately serialised input.
makeRequest :: forall api method e . ( KnownNat (FromJust (FindMethodIndex 0 method api))
                                     , Serialise (Input method)
                                     )
            => Word64 -> Input method -> ServiceProto api e
makeRequest rnum input =
  let methodIx = fromIntegral (findMethodIndex @method @api) :: Int
      payload  = serialise input
  in ServiceRequest rnum (serialise (methodIx, payload))
-- | Same as 'makeRequest', but the request number is a fresh random
-- 'Word64'.
makeRequestR :: forall api method e m . ( KnownNat (FromJust (FindMethodIndex 0 method api))
                                        , Serialise (Input method)
                                        , MonadIO m
                                        )
             => Input method -> m (ServiceProto api e)
makeRequestR input = do
  rnum <- liftIO randomIO
  pure (makeRequest @api @method @e rnum input)
-- | 'runReaderT' with its arguments flipped: supply the environment first,
-- then the computation.
runWithContext :: r -> ReaderT r m a -> m a
runWithContext = flip runReaderT
-- | Server-side protocol handler. It dispatches each incoming message to
-- the matching @api@ handler (see 'dispatch') and sends the result back
-- with 'response'. The whole step is wrapped in 'deferred' (presumably so
-- it runs outside the receive loop; see 'HasDeferred').
makeServer :: forall api e m proto . ( MonadIO m
                                     , EnumAll api (Int, SomeHandler m) m
                                     , Response e (ServiceProto api e) m
                                     , HasProtocol e proto
                                     , HasDeferred proto e m
                                     , Pretty (Peer e)
                                     , proto ~ ServiceProto api e
                                     )
           => ServiceProto api e
           -> m ()
makeServer msg = deferred @proto (response =<< dispatch @api @e msg)
-- | Client-side state for calling the services of one remote peer.
data ServiceCaller api e =
  ServiceCaller
    { callPeer :: Peer e
      -- ^ the peer that requests are sent to
    , callInQ :: TQueue (ServiceProto api e)
      -- ^ outgoing requests, filled by 'callService' and drained by 'getRequest'
    , callWaiters :: TVar (HashMap Word64 (TQueue (ServiceProto api e)))
      -- ^ pending calls: request number -> queue that receives the reply
    }
-- | Fresh caller state for peer @p@: an empty outgoing queue and no pending
-- calls.
makeServiceCaller :: forall api e m . MonadIO m => Peer e -> m (ServiceCaller api e)
makeServiceCaller p = do
  outgoing <- UIO.newTQueueIO
  pending  <- UIO.newTVarIO mempty
  pure (ServiceCaller p outgoing pending)
-- | Run the client side of a single service connection. It does not return
-- normally; it ends only with an exception.
--
-- Two activities run side by side:
--
-- * a protocol worker ('runProto') that receives 'ServiceProto' messages
--   and hands them to 'makeClient', which routes each reply to the call
--   waiting on its request number;
-- * this thread, which repeatedly takes the requests queued by
--   'callService' and sends them to 'callPeer'.
--
-- The worker is started with 'withAsync', so it is cancelled whenever this
-- function exits (exception or cancellation) instead of being leaked. It is
-- also 'link'ed, so its own failure is re-thrown here (wrapped in
-- 'ExceptionInLinkedThread').
runServiceClient :: forall api e m . ( MonadIO m
                                     , MonadUnliftIO m
                                     , HasProtocol e (ServiceProto api e)
                                     -- FIXME: remove-this-debug-shit
                                     , Show (Peer e)
                                     , Pretty (Peer e)
                                     , PeerMessaging e
                                     , HasOwnPeer e m
                                     , HasFabriq e m
                                     , HasTimeLimits e (ServiceProto api e) m
                                     )
                 => ServiceCaller api e
                 -> m ()
runServiceClient caller =
  withAsync (runProto @e [ makeResponse (makeClient @api caller) ]) $ \proto -> do
    link proto
    forever do
      req <- getRequest caller
      request @e (callPeer caller) req
-- | A 'ServiceCaller' whose API list is hidden existentially, so callers for
-- different services can be collected in one list for
-- 'runServiceClientMulti'. The constructor captures the instances needed to
-- run the caller's protocol.
data Endpoint e m = forall (api :: [Type]) . ( HasProtocol e (ServiceProto api e)
                                             , HasTimeLimits e (ServiceProto api e) m
                                             , PeerMessaging e
                                             , Pretty (Peer e)
                                             )
                                          => Endpoint (ServiceCaller api e)
-- | Run the client side of several services at once.
--
-- A single protocol worker receives the replies for all endpoints, and one
-- pump thread per endpoint forwards that endpoint's queued requests to its
-- peer. The function returns as soon as any of these threads finishes, and
-- all the others are cancelled. If the finished thread died with an
-- exception, that exception is re-thrown.
runServiceClientMulti :: forall e m . ( MonadIO m
                                      , MonadUnliftIO m
                                      -- FIXME: remove-this-debug-shit
                                      , Show (Peer e)
                                      , Pretty (Peer e)
                                      , PeerMessaging e
                                      , HasOwnPeer e m
                                      , HasFabriq e m
                                      )
                      => [ Endpoint e m ]
                      -> m ()
runServiceClientMulti endpoints = do
  proto <- async $ runProto @e [ makeResponse @e (makeClient ep) | (Endpoint ep) <- endpoints ]
  pumps <- forM endpoints $ \(Endpoint caller) ->
    async $ forever $ getRequest caller >>= request @e (callPeer caller)
  (_, outcome) <- UIO.waitAnyCatchCancel (proto : pumps)
  either UIO.throwIO (const (pure ())) outcome
-- | Deliver an incoming message to the call registered under its
-- 'reqNum'. If no call is waiting for that number (for example, it was
-- already removed after a timeout), the message is dropped.
notifyServiceCaller :: forall api e m . MonadIO m
                    => ServiceCaller api e
                    -> ServiceProto api e
                    -> m ()
notifyServiceCaller caller msg = do
  waiters <- UIO.readTVarIO (callWaiters caller)
  case HashMap.lookup (reqNum msg) waiters of
    Nothing -> none
    Just q  -> atomically (UIO.writeTQueue q msg)
-- | Block until an outgoing request is queued on the caller, then take it.
getRequest :: forall api e m . MonadIO m
           => ServiceCaller api e
           -> m (ServiceProto api e)
getRequest = atomically . UIO.readTQueue . callInQ
-- | Call @method@ on the remote service and block until its reply arrives.
--
-- The request, which gets a fresh random request number, and a private
-- reply queue registered under that number are published in one STM
-- transaction, so the reply cannot arrive before its waiter exists.
-- 'bracket_' removes the waiter however the call ends, including
-- cancellation (e.g. by the timeout in 'callRpcWaitMay').
--
-- Returns the decoded 'Output', the 'ServiceError' sent by the server, or
-- 'ErrorInvalidResponse' when the reply cannot be decoded or is not a
-- 'ServiceResponse'.
--
-- NOTE: there is no timeout here; without a reply this waits forever. The
-- request only leaves the queue while 'runServiceClient' (or
-- 'runServiceClientMulti') is draining it.
callService :: forall method api e m . ( MonadIO m
                                       , HasProtocol e (ServiceProto api e)
                                       , KnownNat (FromJust (FindMethodIndex 0 method api))
                                       , Serialise (Input method)
                                       , Serialise (Output method)
                                       )
            => ServiceCaller api e
            -> Input method
            -> m (Either ServiceError (Output method))
callService caller input = do
  req <- makeRequestR @api @method @e @m input
  resp <- UIO.newTQueueIO
  let rn = reqNum req
      -- Strict updates: the map is a strict HashMap, and lazy 'modifyTVar'
      -- would pile up insert/delete thunks between reads.
      addWaiter = atomically $ do
        UIO.modifyTVar' (callWaiters caller) (HashMap.insert rn resp)
        UIO.writeTQueue (callInQ caller) req
      removeWaiter = atomically $
        UIO.modifyTVar' (callWaiters caller) (HashMap.delete rn)
  liftIO $ bracket_ addWaiter removeWaiter $ do
    msg <- atomically $ UIO.readTQueue resp
    pure $ case msg of
      ServiceResponse _ (Right bs) ->
        case deserialiseOrFail @(Output method) bs of
          Left _  -> Left ErrorInvalidResponse
          Right x -> Right x
      ServiceResponse _ (Left err) -> Left err
      ServiceRequest{} -> Left ErrorInvalidResponse
-- | Call @method@ with a time limit. Yields 'Just' the decoded output on
-- success. Yields 'Nothing' if the timeout @t@ expires first (the pending
-- call is then cancelled) or if the call fails with a 'ServiceError'.
callRpcWaitMay :: forall method (api :: [Type]) m e proto t . ( MonadUnliftIO m
                                                              , KnownNat (FromJust (FindMethodIndex 0 method api))
                                                              , HasProtocol e (ServiceProto api e)
                                                              , Serialise (Input method)
                                                              , Serialise (Output method)
                                                              , IsTimeout t
                                                              , proto ~ ServiceProto api e
                                                              )
               => Timeout t
               -> ServiceCaller api e
               -> Input method
               -> m (Maybe (Output method))
callRpcWaitMay t caller args = do
  outcome <- race (pause t) (callService @method @api @e @m caller args)
  pure $ case outcome of
    Right (Right x) -> Just x
    _               -> Nothing
-- | Client-side protocol handler. Every incoming 'ServiceProto' message is
-- routed to the call waiting on its request number (see
-- 'notifyServiceCaller').
makeClient :: forall api e m . ( MonadIO m
                               , HasProtocol e (ServiceProto api e)
                               , Pretty (Peer e)
                               )
           => ServiceCaller api e
           -> ServiceProto api e
           -> m ()
makeClient caller = notifyServiceCaller caller
-- | In plain 'IO', 'tryLockForPeriod' always grants the lock by returning
-- 'True' unconditionally, so presumably no time limiting applies to service
-- messages there.
instance (HasProtocol e (ServiceProto api e)) => HasTimeLimits e (ServiceProto api e) IO where
  tryLockForPeriod = \_ _ -> pure True