{-# LANGUAGE UndecidableInstances #-}
module Algebra.Monad.State (
MonadState(..),
StateT,State,StateRes(..),
stateT,state,
(<~),(=~),(=-),(^>=),gets,getl,saving,swapWith,
Next,Prev,
mapAccum,mapAccum_,mapAccumR,mapAccumR_,scanl,scanr,push,pop,withPrev,withNext,
StateA(..),stateA,
runAtomic
) where
import Algebra.Monad.RWS
import Algebra.Monad.Base
import Data.IORef
instance MonadState (IO ()) IO where
get = return unit
put a = a
modify f = put (f unit)
newtype StateT s m a = StateT (RWST Void Void s m a)
deriving (Unit,Functor,SemiApplicative,Applicative,MonadFix,
MonadTrans,MonadInternal,
MonadCont,MonadState s,MonadList)
type State s a = StateT s Id a
instance Monad m => Monad (StateT s m) where join = coerceJoin StateT
instance MonadReader r m => MonadReader r (StateT s m) where
ask = ask_ ; local = local_
instance MonadWriter w m => MonadWriter w (StateT s m) where
tell = tell_ ; listen = listen_ ; censor = censor_
instance (MonadCounter w acc m) => MonadCounter w acc (StateT s m) where
getCounter = getCounter_; setCounter = setCounter_
deriving instance MonadError e m => MonadError e (StateT s m)
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Semiring (m (a,s,Void)) => Semiring (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)
deriving instance (Monad m,MonadFuture n m) => MonadFuture n (StateT s m)
deriving instance ConcreteMonad (StateT s)
instance (MonadLogic l m) => MonadLogic (StateT s l) (StateT s m) where
deduce = coerceDeduce StateT StateT
induce = coerceInduce StateT StateT
_StateT :: Iso (StateT s m a) (StateT t n b) (RWST Void Void s m a) (RWST Void Void t n b)
_StateT = iso StateT (\ ~(StateT s) -> s)
stateT :: (Functor m,Functor n) => Iso (StateT s m a) (StateT t n b) (s -> m (s,a)) (t -> n (t,b))
stateT = mapping (mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a)))
.promapping i'_.i'RWST._StateT
class StateRes t s a | t -> s a where
evalS :: t -> a
execS :: t -> s
instance StateRes (s,a) s a where evalS = snd ; execS = fst
instance (StateRes r a b) => StateRes (Id r) a b where evalS = evalS . getId ; execS = execS . getId
instance (StateRes r a b) => StateRes (s -> r) (s -> a) (s -> b) where evalS = map evalS ; execS = map execS
instance Functor m => StateRes (StateT s m a) (s -> m s) (s -> m a) where
execS x = map2 execS (x^..stateT) ; evalS x = map2 evalS (x^..stateT)
state :: Iso (State s a) (State t b) (s -> (s,a)) (t -> (t,b))
state = mapping i'Id.stateT
(=-) :: MonadState s m => FixFold' s s' -> s' -> m ()
infixl 1 =-,=~,<~
(<~) :: MonadState s m => Lens' s a -> (a -> (a,b)) -> m b
(<~) l st = getl l >>= \a -> let (a',b) = st a in b <$ modify (l %- a')
swapWith :: MonadState s m => Lens' s a -> (a -> a) -> m a
swapWith l f = l <~ \a' -> (f a',a')
l =- x = modify (set l x)
(=~) :: MonadState s m => FixFold' s a -> (a -> a) -> m ()
l =~ f = modify (warp l f)
(^>=) :: MonadState s m => LensLike m a a s s -> (a -> m ()) -> m ()
l ^>= k = get >>= \s -> forl_ l s k
gets :: MonadState s m => (s -> a) -> m a
gets = (get<&>)
getl :: MonadState s m => Lens' s a -> m a
getl l = gets (by l)
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = getl l >>= \s -> st <* (l =- s)
newtype StateA m s a = StateA (StateT s m a)
stateA :: Iso (StateA m s a) (StateA m' s' a') (StateT s m a) (StateT s' m' a')
stateA = iso StateA (\(StateA s) -> s)
instance Monad m => Deductive (StateA m) where
StateA sbc . StateA sab = StateA $ (^.stateT) $ \a ->
(sab^..stateT) a >>= \(a',b) -> (a',).snd <$> (sbc^..stateT) b
instance Monad m => Category (StateA m) where
id = StateA get
instance Monad m => Split (StateA m) where
StateA sac <#> StateA sbd = StateA $ (^.stateT)
$ map2 (\((a',c),(b',d)) -> ((a',b'),(c,d)))
$ (Kleisli (sac^..stateT) <#> Kleisli (sbd^..stateT)) ^.. i'Kleisli
instance Monad m => Choice (StateA m) where
StateA sac <|> StateA sbc = StateA $ (^.stateT) $
l Left (sac^..stateT)<|>l Right (sbc^..stateT)
where l = map2 . first
mapAccum :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccum f t = traverse (by state<$>f) t^..state
mapAccum_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccum_ = (map.map.map) snd mapAccum
mapAccumR :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccumR f t = traverse (by (state.i'Backwards)<$>f) t^..state.i'Backwards
mapAccumR_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccumR_ = (map.map.map) snd mapAccumR
scanl :: Traversable t => (a -> b -> a) -> a -> t b -> t a
scanl f a tb = mapAccum_ (\b a' -> let x = f a' b in (x,x)) tb a
scanr :: Traversable t => (a -> b -> a) -> a -> t b -> t a
scanr f a tb = mapAccumR_ (\b a' -> let x = f a' b in (x,x)) tb a
push :: Traversable t => t a -> a -> t a
push = mapAccum_ (,)
pop :: Traversable t => t a -> a -> t a
pop = mapAccumR_ (,)
type Next a = a
type Prev a = a
withPrev :: Traversable t => a -> t a -> t (Prev a,a)
withPrev = flip (mapAccum_ (\a p -> (a,(p,a))))
withNext :: Traversable t => t a -> a -> t (a,Next a)
withNext = mapAccumR_ (\a p -> (a,(a,p)))
runAtomic :: IORef s -> State s a -> IO a
runAtomic r st = atomicModifyIORef r (yb state st)