{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.ByteArray.Parse
    ( Parser
    , Result(..)
    
    , parse
    , parseFeed
    
    , hasMore
    , byte
    , anyByte
    , bytes
    , take
    , takeWhile
    , takeAll
    , skip
    , skipWhile
    , skipAll
    , takeStorable
    ) where
import           Control.Monad
import           Foreign.Storable              (Storable, peek, sizeOf)
import           Data.Word
import           Data.Memory.Internal.Imports
import           Data.Memory.Internal.Compat
import           Data.ByteArray.Types          (ByteArrayAccess, ByteArray)
import qualified Data.ByteArray.Types     as B
import qualified Data.ByteArray.Methods   as B
import           Prelude hiding (take, takeWhile)
-- | Outcome of running a 'Parser', which may be only partial:
--
-- * 'ParseFail': the parser failed; carries the error message.
--
-- * 'ParseMore': the parser needs more input. Apply the continuation to
--   @Just chunk@ to supply data, or to @Nothing@ to signal that no further
--   input is available.
--
-- * 'ParseOK': the parser succeeded; carries the remaining unparsed buffer
--   followed by the parsed value.
data Result byteArray a =
      ParseFail String
    | ParseMore (Maybe byteArray -> Result byteArray a)
    | ParseOK   byteArray a
-- | Debug rendering of a 'Result'. The continuation inside 'ParseMore'
-- cannot be shown and is printed as a placeholder. For 'ParseOK' the parsed
-- value is printed before the remaining buffer.
instance (Show ba, Show a) => Show (Result ba a) where
    show result = case result of
        ParseFail message   -> "ParseFailure: " ++ message
        ParseMore _         -> "ParseMore _"
        ParseOK rest value  -> concat ["ParseOK ", show value, " ", show rest]
-- | Failure continuation: receives the buffer at the point of failure and
-- the error message, and produces the final 'Result'.
type Failure byteArray r = byteArray -> String -> Result byteArray r
-- | Success continuation: receives the buffer left after parsing and the
-- parsed value, and produces the final 'Result'.
type Success byteArray a r = byteArray -> a -> Result byteArray r
-- | A parser over a @byteArray@ buffer, written in continuation-passing
-- style.
--
-- 'runParser' takes the current buffer, a 'Failure' continuation and a
-- 'Success' continuation, and must work for any final result type @r@
-- (hence the rank-2 @forall@).
newtype Parser byteArray a = Parser
    { runParser :: forall r . byteArray
                           -> Failure byteArray r
                           -> Success byteArray a r
                           -> Result byteArray r }
-- | Map a function over the value produced by a parser; the buffer handling
-- and the failure path are left untouched.
instance Functor (Parser byteArray) where
    fmap f p = Parser $ \buf err ok ->
        -- Wrap the success continuation so it sees the transformed value.
        let ok' rest = ok rest . f
         in runParser p buf err ok'
-- | Applicative sequencing: parsers run left to right, each one starting
-- from the buffer left by the previous one.
instance Applicative (Parser byteArray) where
    -- Succeed immediately with the value, leaving the buffer untouched.
    pure value = Parser $ \buf _ ok -> ok buf value
    -- Run the function parser, then the argument parser on what is left,
    -- and apply the first result to the second.
    pf <*> pa = do
        f <- pf
        a <- pa
        return (f a)
-- | Monadic sequencing of parsers.
--
-- NOTE(review): 'fail' is defined here as a 'Monad' method, which only
-- compiles with base versions where 'fail' still belongs to 'Monad' --
-- confirm the supported GHC range.
instance Monad (Parser byteArray) where
    -- Succeed with the value without consuming anything.
    return value = Parser $ \buf _ ok -> ok buf value

    -- Run the first parser; on success feed its value to @next@ and run the
    -- resulting parser on the leftover buffer. Failures go straight to @err@.
    parser >>= next = Parser $ \buf err ok ->
        let continue rest value = runParser (next value) rest err ok
         in runParser parser buf err continue

    -- Invoke the failure continuation with the current buffer and a
    -- prefixed message.
    fail reason = Parser $ \buf err _ ->
        err buf $ "Parser failed: " ++ reason
-- | Choice between parsers: try the first one and, if it fails, run the
-- second one instead.
--
-- The second parser is started from the buffer as it was when the first
-- parser began; the buffer and message reported by the first parser's
-- failure are discarded.
--
-- NOTE(review): input that the first parser obtained through 'ParseMore'
-- before failing is not part of that original buffer, so it is not replayed
-- to the second parser -- verify that callers do not rely on backtracking
-- across chunk boundaries.
instance MonadPlus (Parser byteArray) where
    mzero = fail "MonadPlus.mzero"
    mplus primary alternative = Parser $ \buf err ok ->
        -- Replace the failure continuation of the primary parser with one
        -- that runs the alternative on the original buffer.
        let fallback _ _ = runParser alternative buf err ok
         in runParser primary buf fallback ok
-- | 'Alternative' mirrors the 'MonadPlus' instance: 'empty' always fails and
-- '<|>' is 'mplus'.
instance Alternative (Parser byteArray) where
    empty   = fail "Alternative.empty"
    p <|> q = p `mplus` q
-- | Run a parser on an initial buffer, pulling more input on demand.
--
-- Each time the parser answers 'ParseMore', the @feeder@ action is run and
-- its result (@Just chunk@, or @Nothing@ for end of input) is handed to the
-- continuation. The loop stops at the first 'ParseFail' or 'ParseOK', which
-- is returned as is.
parseFeed :: (ByteArrayAccess byteArray, Monad m)
          => m (Maybe byteArray)
          -> Parser byteArray a
          -> byteArray
          -> m (Result byteArray a)
parseFeed feeder p initial = drive (parse p initial)
  where
    drive result = case result of
        ParseMore resume -> do
            chunk <- feeder
            drive (resume chunk)
        _ -> return result
-- | Run a 'Parser' on a buffer and return the 'Result'.
--
-- A failure becomes 'ParseFail' with the error message (the buffer at the
-- point of failure is dropped); a success becomes 'ParseOK' with the
-- remaining buffer and the parsed value.
parse :: ByteArrayAccess byteArray
      => Parser byteArray a -> byteArray -> Result byteArray a
parse p s = runParser p s onFailure ParseOK
  where
    onFailure _ message = ParseFail message
-- Ask for more input and append it to the current buffer.
--
-- The parser suspends with 'ParseMore'. When resumed:
--
-- * with @Nothing@ (no more input), the failure continuation is called;
-- * with an empty chunk, it asks again without touching the buffer;
-- * with a non-empty chunk, the chunk is appended to the buffer and the
--   success continuation is called.
getMore :: ByteArray byteArray => Parser byteArray ()
getMore = Parser $ \buf err ok ->
    let resume Nothing = err buf "EOL: need more data"
        resume (Just chunk)
            | B.null chunk = runParser getMore buf err ok
            | otherwise    = ok (B.append buf chunk) ()
     in ParseMore resume
-- Accumulate every remaining chunk into the buffer.
--
-- Keeps suspending with 'ParseMore' and appending each chunk it is given,
-- until it is resumed with @Nothing@; it then succeeds with the accumulated
-- buffer. The failure continuation is never invoked by this parser itself.
getAll :: ByteArray byteArray => Parser byteArray ()
getAll = Parser $ \buf err ok ->
    ParseMore $ maybe
        (ok buf ())
        (\chunk -> runParser getAll (B.append buf chunk) err ok)
-- Discard every remaining chunk.
--
-- Keeps suspending with 'ParseMore'; each chunk received is dropped and the
-- buffer is reset to empty. When resumed with @Nothing@ it succeeds with
-- whatever buffer it currently holds (the buffer it was started with if no
-- chunk was ever received, empty otherwise).
flushAll :: ByteArray byteArray => Parser byteArray ()
flushAll = Parser $ \buf err ok ->
    ParseMore $ maybe
        (ok buf ())
        (const (runParser flushAll B.empty err ok))
-- | Check whether any input remains, without consuming it.
--
-- Returns 'True' when the buffer is non-empty. With an empty buffer the
-- parser asks for more input: @Nothing@ yields 'False', while a chunk
-- replaces the buffer and the check is repeated (so empty chunks simply
-- cause another request).
hasMore :: ByteArray byteArray => Parser byteArray Bool
hasMore = Parser $ \buf err ok ->
    let onChunk Nothing      = ok buf False
        onChunk (Just chunk) = runParser hasMore chunk err ok
     in if B.null buf
            then ParseMore onChunk
            else ok buf True
-- | Consume and return the next byte, requesting more input first if the
-- buffer is empty.
anyByte :: ByteArray byteArray => Parser byteArray Word8
anyByte = Parser $ \buf err ok ->
    maybe (runParser (getMore >> anyByte) buf err ok)
          (\(w, rest) -> ok rest w)
          (B.uncons buf)
-- | Consume one byte and require it to equal @w@.
--
-- Requests more input if the buffer is empty. If the next byte differs from
-- @w@ the parser fails, reporting both the expected and the actual byte, and
-- nothing is consumed.
byte :: ByteArray byteArray => Word8 -> Parser byteArray ()
byte w = Parser $ \buf err ok ->
    let mismatch got = err buf ("byte " ++ show w ++ " : failed : got " ++ show got)
        refill       = runParser (getMore >> byte w) buf err ok
     in case B.uncons buf of
            Just (got, rest)
                | got == w  -> ok rest ()
                | otherwise -> mismatch got
            Nothing -> refill
-- | Consume a fixed sequence of bytes.
--
-- The input must match @allExpected@ completely, otherwise the parser fails
-- with a message showing the full expected sequence. The match may span
-- several input chunks.
bytes :: (Show ba, Eq ba, ByteArray ba) => ba -> Parser ba ()
bytes allExpected = go allExpected
  where
    failure = "bytes " ++ show allExpected ++ " : failed"

    -- Match as much of @wanted@ as the buffer allows; when the buffer is too
    -- short but agrees so far, ask for more input and continue with the
    -- still-unmatched tail.
    go wanted = Parser $ \have err ok ->
        let wantedLen = B.length wanted
            haveLen   = B.length have
            reject    = err have failure
         in if haveLen < wantedLen
                then
                    -- Not enough data: compare the whole buffer against the
                    -- corresponding prefix of what is wanted.
                    let (wantedNow, wantedLater) = B.splitAt haveLen wanted
                     in if have == wantedNow
                            then runParser (getMore >> go wantedLater) B.empty err ok
                            else reject
                else
                    -- Enough data for a full comparison.
                    let (candidate, rest) = B.splitAt wantedLen have
                     in if candidate == wanted
                            then ok rest ()
                            else reject
-- | Read one 'Storable' value from the current position.
--
-- Takes @sizeOf@ bytes (for the result type) from the stream and peeks the
-- value out of them.
takeStorable :: (ByteArray byteArray, Storable d)
             => Parser byteArray d
takeStorable = anyStorable undefined
  where
    -- The argument is only a type witness handed to 'sizeOf'; it ties the
    -- number of bytes taken to the type of the value being read.
    anyStorable :: ByteArray byteArray => Storable d => d -> Parser byteArray d
    anyStorable witness = fmap readValue (take (sizeOf witness))
      where
        -- Peek the value from the start of the taken bytes.
        readValue raw = unsafeDoIO (B.withByteArray raw peek)
-- | Take exactly @n@ bytes from the current position, requesting more input
-- until the buffer holds at least @n@ bytes.
take :: ByteArray byteArray => Int -> Parser byteArray byteArray
take n = Parser $ \buf err ok ->
    let (taken, rest) = B.splitAt n buf
     in if B.length buf < n
            then runParser (getMore >> take n) buf err ok
            else ok rest taken
-- | Take bytes for as long as @predicate@ holds.
--
-- The parser only succeeds once it has seen a byte that does not satisfy
-- the predicate (that byte is left in the buffer). While every buffered
-- byte matches, it asks for more input and rescans; if the input ends at
-- that point, the request for more data fails and so does this parser.
takeWhile :: ByteArray byteArray => (Word8 -> Bool) -> Parser byteArray byteArray
takeWhile predicate = Parser $ \buf err ok ->
    case B.span predicate buf of
        (matched, rest)
            | B.null rest -> runParser (getMore >> takeWhile predicate) buf err ok
            | otherwise   -> ok rest matched
-- | Take everything that remains: the current buffer plus every chunk
-- supplied until end of input. The remaining buffer is left empty.
takeAll :: ByteArray byteArray => Parser byteArray byteArray
takeAll = getAll >> Parser (\buf _ ok -> ok B.empty buf)
-- | Skip @n@ bytes from the current position.
--
-- When the buffer holds fewer than @n@ bytes, the whole buffer is dropped,
-- more input is requested, and the shortfall is skipped from what follows.
skip :: ByteArray byteArray => Int -> Parser byteArray ()
skip n = Parser $ \buf err ok ->
    let available = B.length buf
     in if available >= n
            then ok (B.drop n buf) ()
            else runParser (getMore >> skip (n - available)) B.empty err ok
-- | Skip bytes for as long as the predicate holds.
--
-- Succeeds once a byte that does not satisfy the predicate is seen (that
-- byte stays in the buffer). While every buffered byte matches, the buffer
-- is dropped and more input is requested; if the input ends at that point,
-- the request for more data fails and so does this parser.
skipWhile :: ByteArray byteArray => (Word8 -> Bool) -> Parser byteArray ()
skipWhile p = Parser $ \buf err ok ->
    let rest = snd (B.span p buf)
     in if B.null rest
            then runParser (getMore >> skipWhile p) B.empty err ok
            else ok rest ()
-- | Skip all the remaining bytes: the current buffer and every chunk
-- supplied until end of input. The remaining buffer is always left empty.
skipAll :: ByteArray byteArray => Parser byteArray ()
skipAll = Parser $ \_ err ok ->
    -- Drop the current buffer before flushing. Handing it to 'flushAll'
    -- would return it untouched as the leftover when the very first
    -- continuation input is @Nothing@, i.e. the buffered bytes would not be
    -- skipped and the leftover would depend on whether any chunk arrived.
    runParser flushAll B.empty err ok