{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveGeneric #-}
-- | Free monads ('Free'), cofree comonads ('Cofree') and a
-- continuation-style wrapper for categories ('ContC'), with instances for
-- the class hierarchy of "Algebra.Monad.Base".
module Algebra.Monad.Free where

import Algebra.Monad.Base
import Unsafe.Coerce (unsafeCoerce)
import qualified Control.DeepSeq as DSeq
import GHC.Generics (Generic)

-- | The free monad over @f@: either one @f@-layer whose positions hold
-- further 'Free' trees ('Join'), or a plain value ('Pure').
data Free f a = Join (Forest f a) | Pure a
  deriving Generic

-- | One @f@-layer of 'Free' subtrees.
type Forest f a = f (Free f a)

deriving instance (Eq (f (Free f a)), Eq a) => Eq (Free f a)
deriving instance (Ord (f (Free f a)), Ord a) => Ord (Free f a)
deriving instance (Show (f (Free f a)), Show a) => Show (Free f a)

-- | No explicit 'DSeq.rnf': relies on the 'Generic'-based default.
instance (DSeq.NFData (Forest f a), DSeq.NFData a) => DSeq.NFData (Free f a)

-- | Traversal over the layer of a 'Join' node. A 'Pure' node is rebuilt at
-- the new type without calling the handler.
t'Join :: Traversal (f (Free f a)) (g (Free g a)) (Free f a) (Free g a)
t'Join k (Join x) = Join <$> k x
t'Join _ (Pure a) = pure (Pure a)

-- | Traversal over the value of a 'Pure' node; 'Join' nodes are returned
-- unchanged.
t'Pure :: Traversal' (Free f a) a
t'Pure k (Pure a) = Pure <$> k a
t'Pure _ x = pure x

-- | Two 'Join' layers are combined with the layer's own semigroup;
-- otherwise the leftmost 'Pure' operand wins. So a sum is its leftmost
-- 'Pure' operand if there is one, and the 'Join' of all layers otherwise.
--
-- A 'Join' followed by a 'Pure' must yield the 'Pure'. Keeping the 'Join'
-- would break the monoid left-identity law (@zero + Pure b@ would be
-- @Join zero@). It would also break associativity.
instance Semigroup (f (Free f a)) => Semigroup (Free f a) where
  Join a + Join b = Join (a + b)
  Join _ + Pure b = Pure b
  Pure a + _ = Pure a

-- | The identity is an empty layer, @Join zero@.
instance Monoid (f (Free f a)) => Monoid (Free f a) where
  zero = Join zero

instance Functor f => Functor (Free f) where
  map f (Join fa) = Join (map2 f fa)
  map f (Pure a) = Pure (f a)

instance Unit (Free f) where
  pure = Pure

-- No methods given: these rely on the class defaults from
-- "Algebra.Monad.Base".
instance Functor f => SemiApplicative (Free f)
instance Functor f => Applicative (Free f)

-- | 'join' grafts the inner trees in place of the outer 'Pure' leaves.
instance Functor f => Monad (Free f) where
  join (Join f) = Join (map join f)
  join (Pure f) = f

-- | Extracts from each layer in turn until a 'Pure' leaf is reached.
instance Counit f => Counit (Free f) where
  extract (Join f) = extract (extract f)
  extract (Pure a) = a

-- | A 'Join' layer is extended with 'liftF' at every position.
instance Comonad f => Comonad (Free f) where
  duplicate (Pure a) = Pure (Pure a)
  duplicate (Join f) = Join (f =>> liftF)

-- | Runs the whole tree with 'perform' inside the underlying 'mfix' and
-- wraps the result as a single layer.
instance MonadFix f => MonadFix (Free f) where
  mfix f = Join (Pure <$> mfix (\a -> perform (f a)))

instance Foldable f => Foldable (Free f) where
  fold (Join f) = foldMap fold f
  fold (Pure a) = a

instance Traversable f => Traversable (Free f) where
  sequence (Join f) = Join <$> traverse sequence f
  sequence (Pure a) = Pure <$> a

-- | Zipping applicative: 'pure' is a single leaf.
instance Unit (Zip (Free f)) where
  pure = Zip . Pure

-- | Two 'Join' layers are zipped position-wise, combining corresponding
-- subtrees with 'zap'. A 'Pure' operand on either side is mapped over the
-- other operand.
instance (Functor f, SemiApplicative (Zip f)) => SemiApplicative (Zip (Free f)) where
  Zip (Join f) <*> Zip (Join x) = Zip (Join (zipWith zap f x))
  Zip (Pure f) <*> Zip x = Zip (map f x)
  Zip f <*> Zip (Pure x) = Zip (map ($ x) f)

instance (Functor f, Applicative (Zip f)) => Applicative (Zip (Free f))

-- | 'lift' embeds an action as a single layer (see 'liftF').
instance MonadTrans Free where
  lift = liftF

-- | Rebuilds a tree over 'Id' in the target functor, turning each 'Id'
-- layer into a 'pure' layer.
instance ConcreteMonad Free where
  generalize (Join f) = Join ((pure . generalize . getId) f)
  generalize (Pure a) = Pure a

-- The effect classes below lift their operations from the underlying
-- monad @m@.

instance MonadState s m => MonadState s (Free m) where
  get = lift get
  put a = lift (put a)
  modify f = lift (modify f)

-- | 'local' applies the modification to every layer of the tree. Each
-- 'Join' layer runs under @local f@, and so does every subtree it produces.
-- Localising only the outermost layer would let later steps escape the
-- modification (e.g. the second 'ask' in @ask >> ask@).
instance MonadReader r m => MonadReader r (Free m) where
  ask = lift ask
  local f (Join m) = Join (local f (map (local f) m))
  local _ (Pure a) = Pure a

-- | 'listen' and 'censor' run the whole tree with 'perform' and re-embed
-- the result as a single layer.
instance MonadWriter w m => MonadWriter w (Free m) where
  tell w = lift (tell w)
  listen m = lift (listen (perform m))
  censor m = lift (censor (perform m))

instance MonadCounter w a m => MonadCounter w a (Free m) where
  getCounter = lift getCounter
  setCounter c = lift (setCounter c)

instance MonadIO m => MonadIO (Free m) where
  liftIO = lift . liftIO

instance MonadList m => MonadList (Free m) where
  choose l = lift (choose l)

instance MonadFuture m t => MonadFuture m (Free t) where
  future = lift . future

-- | 'catch' runs both the protected tree and the handler's tree with
-- 'perform', so the result is a single layer.
instance MonadError e m => MonadError e (Free m) where
  throw e = lift (throw e)
  catch k m = lift (catch (map perform k) (perform m))

-- | Runs the whole tree in @m@ and returns the result as a single 'Pure'
-- leaf over 'Id'.
concrete :: Monad m => Free m a -> m (Free Id a)
concrete = map Pure . perform

-- | Collapses the tree into one leaf holding the action that runs it.
unliftF :: Monad m => Free m a -> Free m (m a)
unliftF = Pure . perform

-- | Applies a natural transformation to every layer of the tree.
mapF :: (Functor f, Functor g) => (forall a. f a -> g a) -> Free f b -> Free g b
mapF f (Join a) = Join (f (map (mapF f) a))
mapF _ (Pure a) = Pure a

-- | Runs the @g@ effects of a tree over @g :.: f@, leaving a tree over @f@.
-- Each layer's own @g@ effect runs before those of its subtrees.
sequenceF :: (Traversable f, Monad g) => Free (g :.: f) a -> g (Free f a)
sequenceF (Join (Compose gfa)) = map Join (gfa >>= \fa -> traverse sequenceF fa)
sequenceF (Pure a) = pure (Pure a)

-- | Effectful 'mapF': rewrites every layer with the given function and runs
-- the resulting @g@ effects with 'sequenceF'.
traverseF
  :: (Functor f, Traversable f', Monad g)
  => (forall a. f a -> g (f' a)) -> Free f b -> g (Free f' b)
traverseF f = sequenceF . mapF (\fa -> Compose (f fa))

-- | Types @f@ that can be run layer by layer in a base monad @m@. The base
-- monad is determined by @f@.
class MonadFree m f | f -> m where
  -- | Runs one layer, returning the remaining computation.
  step :: Monad m => f a -> m (f a)
  -- | Runs every layer, returning the final value.
  perform :: Monad m => f a -> m a
  -- | Embeds an action as a single layer.
  liftF :: Functor m => m a -> f a

instance MonadFree m (Free m) where
  step (Join j) = j
  step (Pure a) = pure (Pure a)
  perform (Join fa) = fa >>= perform
  perform (Pure a) = pure a
  liftF = Join . map Pure

-- The three helpers below reuse @f@'s 'MonadFree' methods at a type @g a@
-- via 'unsafeCoerce'. Their function argument is never evaluated; it only
-- fixes the types. They are only safe when @g a@ has the same runtime
-- representation as @f a@ (e.g. @g@ is a newtype over @f@). Nothing here
-- checks that.

-- | 'step' of @f@, reused at type @g a@ (see the safety note above).
coerceStep :: forall m f g a. (Monad m, MonadFree m f) => (f a -> g a) -> (g a -> m (g a))
coerceStep _ = unsafeCoerce (step :: f a -> m (f a))

-- | 'perform' of @f@, reused at type @g a@ (see the safety note above).
coercePerform :: forall m f g a. (Monad m, MonadFree m f) => (f a -> g a) -> (g a -> m a)
coercePerform _ = unsafeCoerce (perform :: f a -> m a)

-- | 'liftF' of @f@, reused at type @g a@ (see the safety note above).
coerceLiftF :: forall m f g a. (Functor m, MonadFree m f) => (f a -> g a) -> (m a -> g a)
coerceLiftF _ = unsafeCoerce (liftF :: m a -> f a)

-- | The cofree comonad over @w@: a value paired with a @w@-layer of further
-- 'Cofree' trees.
data Cofree w a = Step a (Coforest w a)

-- | One @w@-layer of 'Cofree' subtrees.
type Coforest w a = w (Cofree w a)

deriving instance (Eq a, Eq (Coforest f a)) => Eq (Cofree f a)
deriving instance (Ord a, Ord (Coforest f a)) => Ord (Cofree f a)
deriving instance (Show a, Show (Coforest f a)) => Show (Cofree f a)

-- | Lens onto the value at the root.
instance Lens1 a a (Cofree f a) (Cofree f a) where
  l'1 k (Step x f) = k x <&> \x' -> Step x' f

-- | Lens onto the layer of subtrees below the root.
instance Lens2 (f (Cofree f a)) (f (Cofree f a)) (Cofree f a) (Cofree f a) where
  l'2 k (Step x f) = k f <&> Step x

-- | A stream: every node has exactly one successor.
type Infinite a = Cofree Id a

-- | A possibly-ending stream: a 'Nothing' layer ends it.
type Colist a = Cofree Maybe a

instance Functor w => Functor (Cofree w) where
  map f (Step a wca) = Step (f a) (map2 f wca)

instance Counit (Cofree w) where
  extract (Step a _) = a

-- | 'duplicate' replaces each node's value with the subtree rooted there.
instance Functor w => Comonad (Cofree w) where
  duplicate d@(Step _ wca) = Step d (map duplicate wca)

instance Foldable w => Foldable (Cofree w) where
  fold (Step a wca) = a + foldMap fold wca

instance Traversable w => Traversable (Cofree w) where
  sequence (Step fa wcfa) = Step <$> fa <*> traverse sequence wcfa

-- | 'pure' builds an unbounded tree: the value at the root, and @m@'s
-- 'pure' of the same tree as the next layer.
instance Unit m => Unit (Cofree m) where
  pure a = Step a (pure (pure a))

instance Applicative m => SemiApplicative (Cofree m)
instance Applicative m => Applicative (Cofree m)

-- | 'join' keeps the root value of the root tree and recurses into the
-- outer layer, discarding the inner tree's own layer.
--
-- NOTE(review): for 'Infinite' this takes the head of every inner stream.
-- As a result @join (pure s)@ is the constant stream of @s@'s head rather
-- than @s@, which violates @join . pure = id@. The usual diagonal stream
-- join would need to run the inner layers, which the @Applicative m@
-- context does not allow. Left unchanged pending a decision on the
-- intended semantics.
instance Applicative m => Monad (Cofree m) where
  join (Step (Step a _) ww) = Step a (map join ww)

-- | A cofree comonad whose branching layer is a free monad over @f@.
type Bifree f a = Cofree (Free f) a

-- | An arrow of @k@ from @a@ to @b@, represented by its action on arrows
-- out of @b@ (composing them after the arrow, yielding arrows out of @a@).
newtype ContC k a b = ContC { runContC :: forall c. k b c -> k a c }

-- | Converts between 'ContC' and the underlying arrow. An arrow @x@ becomes
-- @(x >>>)@; a 'ContC' is turned back by applying it to 'id'.
contC :: (Category k, Category k') => Iso (ContC k a b) (ContC k' a' b') (k a b) (k' a' b')
contC = iso (\x -> ContC (x >>>)) (($ id) . runContC)

-- | Composition is plain function composition of the wrapped functions.
-- It never uses @k@'s own composition.
instance Deductive (ContC k) where
  ContC bc . ContC ab = ContC (\kc -> ab (bc kc))

instance Category (ContC k) where
  id = ContC id