{-# LANGUAGE UndecidableInstances, ScopedTypeVariables #-}
module Algebra.Monad.Logic where

import Algebra.Monad.Base
import Algebra.Monad.Writer

-- | Backtracking transformer in two-continuation form. A computation receives
-- a success continuation, called with each result together with a "rest of the
-- search" action, and an initial failure action, run when no results remain.
-- The answer type @r@ is chosen by whoever runs the search.
newtype LogicT m a = LogicT { runLogicT :: forall r. (a -> m r -> m r) -> m r -> m r }

-- | Mapping post-composes the function onto every result before it reaches
-- the success continuation; the failure continuation is untouched.
instance Functor (LogicT m) where
  map f (LogicT l) = LogicT (\k -> l (k . f))
-- | A single result: hand it to the success continuation, whose "rest of the
-- search" is the caller's failure continuation.
instance Unit (LogicT m) where
  pure a = LogicT (\k -> k a)
-- No method bodies are given here, so these instances rely entirely on the
-- class defaults. Presumably those defaults are derived from 'map', 'pure'
-- and 'join' — confirm in Algebra.Monad.Base.
instance SemiApplicative (LogicT m)
instance Applicative (LogicT m)
-- | Flatten a search of searches. Each inner search produced by the outer one
-- is run, in order, against the outer success continuation.
instance Monad (LogicT m) where
  join (LogicT l) = LogicT (\k -> l (\inner -> runLogicT inner k))
-- | Fixpoint taken in the base monad.
--
-- @f x@ is flattened to an @m [a]@ with 'listLogic'. Only its first element
-- ('head') is tied into the knot, and the fixed point is re-wrapped as a
-- single-result search.
--
-- NOTE(review): any further results of @f x@ are dropped. 'head' is partial,
-- so if @f x@ has no results, forcing the fixed point presumably errors.
instance MonadFix m => MonadFix (LogicT m) where
  mfix f = map pure (mfix (map head . (^..listLogic) . f))^.listLogic
-- | Run the base action once and feed its result to the success continuation,
-- passing the failure continuation through unchanged.
instance MonadTrans LogicT where
  lift ma = LogicT (\sk fk -> ma >>= \x -> sk x fk)
-- | Run the search in @m@, combining results in order as
-- @a1 + (a2 + (... + zero))@. Then fold the resulting @m@ structure.
instance (Monad m,Foldable m) => Foldable (LogicT m) where
  fold (LogicT l) = fold (l (\a acc -> (a+) <$> acc) (pure zero))
-- | Traverse through the list view: collect every result with 'listLogic',
-- sequence the effects inside, and rebuild a search from the list.
instance (Monad m,Traversable m) => Traversable (LogicT m) where
  sequence l = map (^.listLogic) (traverse sequence (l^..listLogic))

-- | Choice: all results of the left search, then all results of the right.
-- The right search is installed as the left one's failure continuation.
instance Semigroup (LogicT m a) where
  LogicT l + LogicT l' = LogicT (\k mr -> l k (l' k mr))
-- | The empty search: ignore the success continuation and fail immediately.
instance Monoid (LogicT m a) where
  zero = LogicT (\_ mr -> mr)
-- | Multiplication combines the results of the two searches with the element
-- type's '+', via 'plusA'. Presumably 'plusA' is the applicative lifting of
-- '+' (all pairs of results) — confirm in Algebra.Monad.Base.
instance Semigroup a => Semiring (LogicT m a) where
  (*) = plusA
-- | The multiplicative unit is 'zeroA'. Presumably this is a single result
-- holding the element type's 'zero' — confirm in Algebra.Monad.Base.
instance Monoid a => Ring (LogicT m a) where
  one = zeroA

-- | State operations are lifted straight from the base monad. Nothing here
-- saves or restores the state when the search backtracks.
instance MonadState s m => MonadState s (LogicT m) where
  get = lift get
  modify = lift . modify
-- | Reader operations on top of the base monad.
instance MonadReader r m => MonadReader r (LogicT m) where
  ask = lift ask
  -- Run the wrapped search under @f env@, and ONLY the wrapped search.
  -- The success continuation @k@ and the outer failure continuation @mr@
  -- belong to the surrounding computation, so they are switched back to the
  -- original environment. The wrapped search's own backtracking action @m@,
  -- which @k@ may invoke later, is switched to @f env@ again.
  -- Without this, @local f s >>= g@ would run @g@ under @f env@.
  local f (LogicT l) = LogicT (\k mr -> ask >>= \env ->
    let inner = local (\_ -> f env)
        outer = local (\_ -> env)
    in inner (l (\a m -> outer (k a (inner m))) (outer mr)))
-- | Writer operations. 'tell' is lifted. 'listen' and 'censor' work one
-- search step at a time: the step that produces the next result is split off
-- with 'deduce' and observed (or censored) in the base monad. The operation
-- is then re-applied to the rest of the search and rebuilt with 'induce'.
instance MonadWriter w m => MonadWriter w (LogicT m) where
  tell = lift . tell
  -- Pair each result with the output @w@ written while producing that step,
  -- and keep listening on the tail. (Assumes '<#>' maps the two pair
  -- components separately — confirm in Algebra.Monad.Base.)
  listen l = induce (listen (deduce l) <&> \(w,ml) -> map ((w,) <#> listen) ml)
  -- Split off the first result @((a,f),l')@, censor that step's output with
  -- @f@, and keep censoring the tail @l'@. When there is no result, the step's
  -- output is passed through unchanged ('id').
  censor l = induce (censor (deduce l <&> \ml -> case ml of
                                Just ((a,f),l') -> (Just (a,censor l'),f)
                                Nothing -> (Nothing,id)))
-- | Failure in 'LogicT' is the empty search ('zero'). It carries no
-- information, hence the 'Void' error type.
--
-- @catch h s@ behaves like @s@ when @s@ yields at least one result, and like
-- @h zero@ when @s@ yields none.
instance Monad m => MonadError Void (LogicT m) where
  throw _ = zero
  catch z (LogicT l) = LogicT l'
    where
      -- Run @l@ with answers tagged 'Just' once a result has reached @k@.
      -- 'Nothing' means @l@ was exhausted without any result. The outer
      -- failure continuation @mr@ is NOT run up front: it runs only after the
      -- handler's results (or after @l@'s own results), so effects happen in
      -- search order.
      l' k mr = l k' (pure Nothing) >>= maybe (runLogicT (z zero) k mr) pure
        where
          -- After a result has been produced, exhausting @l@ falls straight
          -- through to @mr@; the handler is only for the no-result case.
          k' a m = Just <$> k a (m >>= maybe mr pure)

-- | 'deduce' runs a search just far enough to expose its first result together
-- with the rest of the search. 'induce' rebuilds a search from such a split.
instance Monad m => MonadLogic m (LogicT m) where
  -- The success continuation stops at the first result and wraps the
  -- remaining search (its failure continuation @m@) back up with 'induce'.
  -- With no result, the initial failure action yields 'zero' (presumably
  -- 'Nothing' for 'Maybe' — confirm in Algebra.Monad.Base).
  deduce l = runLogicT l (\a m -> pure (pure (a,induce m))) (pure zero)
  -- Run the split step. On 'Nothing', fall through to the failure
  -- continuation; otherwise emit the head and continue with the tail.
  induce mm = LogicT (\k m -> mm >>= maybe m (\(a,l) -> k a (runLogicT l k m)))

-- | Isomorphism between a logic computation and the base-monad action that
-- collects its results into a list.
--
-- One direction is 'deduceAll'. The other turns a list-producing action into
-- a search over the list's elements, in order. The tail is rebuilt from
-- 'pure', so the action's effects belong to the first step only.
listLogic :: (MonadLogic m l,MonadLogic n l') => Iso (l a) (l' b) (m [a]) (n [b])
listLogic = iso fromList deduceAll
  where
    fromList m = induce (map unconsL m)
      where
        unconsL [] = Nothing
        unconsL (a:t) = Just (a, fromList (pure t))
-- | 'deduce' and 'induce' packaged as an isomorphism between a logic
-- computation and its first-step split (head result plus rest of the search)
-- in the base monad.
deduction :: (MonadLogic m l,MonadLogic m' l') => Iso (m (Maybe (a,l a))) (m' (Maybe (b,l' b))) (l a) (l' b) 
deduction = iso deduce induce

-- | @deduceMany n l@ collects at most @n@ results of @l@, in order, stopping
-- early if @l@ runs out. As with 'take', a non-positive @n@ yields @[]@
-- without running @l@. Negative counts therefore never fall through to an
-- unbounded 'deduceAll'.
deduceMany :: MonadLogic m l => Int -> l a -> m [a]
deduceMany n _ | n <= 0 = pure []
deduceMany n l = deduce l >>= maybe (pure []) (\(a,t) -> (a:)<$>deduceMany (n-1) t)
-- | Collect every result of a logic computation, in order, by repeatedly
-- splitting off the next result with 'deduce' until none is left.
-- NOTE(review): for an infinite search this presumably never returns in a
-- strict base monad.
deduceAll :: MonadLogic m l => l a -> m [a]
deduceAll l = deduce l >>= \step -> case step of
  Nothing -> pure []
  Just (a, rest) -> map (a:) (deduceAll rest)

-- | Turn a list of alternatives into a search that yields its elements in
-- order, going through the 'listLogic' isomorphism.
logicChoose :: MonadLogic m l => [a] -> l a
logicChoose = (^.listLogic) . pure

-- | Keep only the first result of a search.
--
-- The search is viewed through 'deduction' as its first-step split. In the
-- 'Just' case, the remaining search is replaced by one whose split is
-- @pure Nothing@, i.e. an empty search. This is commit-to-first (Prolog's
-- @once@) rather than a scoped cut.
-- NOTE(review): assumes 'l'2' focuses the second pair component and @%-@ sets
-- a focus to a constant — confirm in Algebra.Monad.Base.
cut :: MonadLogic m l => l a -> l a
cut = deduction %~ map (traverse.l'2.deduction %- pure Nothing)