module Ast.Parser.Type where

import qualified Ast.Parser.State as PS
import qualified Ast.Parser.Utils as PU
import qualified Ast.Types as AT
import qualified Control.Monad.State as S
import qualified Data.Maybe as M
import qualified Text.Megaparsec as M
import qualified Text.Megaparsec.Char as MC
import qualified Text.Megaparsec.Char.Lexer as ML

-- | Parse any type expression.
--
-- A function type (e.g. @int char -> never@) is attempted first. It begins
-- with one or more ordinary term types, so it is wrapped in 'M.try': input
-- that turns out to have no @->@ backtracks and is re-read by
-- 'parseTermType' (base, sized-int, @mut@, array, pointer or custom type).
--
-- NOTE(review): because of that backtracking, and because array and @mut@
-- element types recurse through this parser, nested types may be parsed more
-- than once; worth factoring if type parsing ever shows up in profiles.
parseType :: PU.Parser AT.Type
parseType = M.try functionType M.<|> parseTermType

-- | Parse a single non-function ("terminal") type.
--
-- Alternatives are tried in order and the first success wins; an alternative
-- that fails after consuming input aborts the whole choice. 'customIntType'
-- shares the @int@ prefix with the plain @int@ keyword, so it is wrapped in
-- 'M.try'. 'customType' comes last so that built-in keywords take precedence
-- over user-defined names.
parseTermType :: PU.Parser AT.Type
parseTermType = M.choice alternatives
  where
    alternatives =
      [ M.try customIntType
      , baseType
      , mutableType
      , arrayType
      , pointerType
      , customType
      ]

-- | Built-in type keywords paired with the 'AT.Type' each one denotes.
--
-- @int@ maps to @AT.TInt 32@ and @byte@ to @AT.TInt 8@; @never@ maps to
-- 'AT.TVoid' (there is no @void@ keyword). 'baseType' tries the entries in
-- list order.
baseTypes :: [(String, AT.Type)]
baseTypes =
  [ ("int", AT.TInt 32)
  , ("float", AT.TFloat)
  , ("double", AT.TDouble)
  , ("char", AT.TChar)
  , ("bool", AT.TBoolean)
  , ("never", AT.TVoid)
  , ("byte", AT.TInt 8)
  ]

-- | Parse an integer type with an explicit size: the keyword @int@ followed
-- by a decimal number, e.g. @int128@ gives @AT.TInt 128@.
--
-- The number must not run straight into an identifier character. Without
-- that check, @int8x@ would parse as @AT.TInt 8@ with @x@ left over, and a
-- user type named @int8x@ could never be reached. On failure the caller
-- ('parseTermType') backtracks via 'M.try'.
--
-- NOTE(review): assumes identifiers are built from alphanumerics and @_@;
-- keep 'identChar' in sync with 'PU.identifier'.
customIntType :: PU.Parser AT.Type
customIntType = do
  _ <- PU.symbol "int"
  AT.TInt <$> PU.lexeme (ML.decimal <* M.notFollowedBy identChar)
  where
    identChar :: PU.Parser Char
    identChar = MC.alphaNumChar M.<|> MC.char '_'

-- | Parse one of the built-in type keywords listed in 'baseTypes',
-- e.g. @int@ or @bool@.
--
-- A keyword only matches as a whole word: it must not be immediately
-- followed by an identifier character. Without this check, @integer@ would
-- parse as @int@ with @eger@ left over, so user types whose names start with
-- a keyword (@integer@, @character@, @byteArray@, ...) could never reach
-- 'customType'. A failed keyword match consumes no input.
--
-- NOTE(review): assumes identifiers are built from alphanumerics and @_@;
-- keep 'identChar' in sync with 'PU.identifier'.
baseType :: PU.Parser AT.Type
baseType = M.choice [ty <$ keyword kw | (kw, ty) <- baseTypes]
  where
    -- Like 'PU.symbol' (assumed to be @PU.lexeme . MC.string@), but the
    -- word-boundary check runs before trailing whitespace is skipped.
    keyword :: String -> PU.Parser String
    keyword kw = PU.lexeme . M.try $ MC.string kw <* M.notFollowedBy identChar

    identChar :: PU.Parser Char
    identChar = MC.alphaNumChar M.<|> MC.char '_'

-- | Parse a pointer type: a @*@ immediately followed by the pointee type,
-- e.g. @*int@ gives @AT.TPointer (AT.TInt 32)@.
--
-- The @*@ is matched with 'MC.char', so no whitespace is skipped after it.
-- The pointee is a term type, not a full type, so @*int -> never@ reads as a
-- function taking a pointer rather than a pointer to a function.
pointerType :: PU.Parser AT.Type
pointerType = MC.char '*' *> (AT.TPointer <$> parseTermType)

-- | Parse a mutable type: the keyword @mut@ followed by any type,
-- e.g. @mut int@ gives @AT.TMutable (AT.TInt 32)@.
--
-- The inner type is parsed with 'parseType', so @mut int -> never@ is a
-- mutable function type.
--
-- @mut@ only matches as a whole word. Previously @PU.symbol "mut"@ consumed
-- the first three letters of identifiers such as @mutex@ and then failed
-- after consuming input, which aborted 'parseTermType' before 'customType'
-- was tried.
--
-- NOTE(review): assumes identifiers are built from alphanumerics and @_@;
-- keep 'identChar' in sync with 'PU.identifier'.
mutableType :: PU.Parser AT.Type
mutableType = AT.TMutable <$> (mutKeyword *> parseType)
  where
    -- Whole-word @mut@; consumes nothing on failure, then skips whitespace.
    mutKeyword :: PU.Parser String
    mutKeyword = PU.lexeme . M.try $ MC.string "mut" <* M.notFollowedBy identChar

    identChar :: PU.Parser Char
    identChar = MC.alphaNumChar M.<|> MC.char '_'

-- | Parse an array type: square brackets holding an optional element count,
-- immediately followed by the element type.
--
-- @[]int@ yields @AT.TArray (AT.TInt 32) Nothing@ and @[4]int@ yields
-- @AT.TArray (AT.TInt 32) (Just 4)@. The brackets and the count are matched
-- without skipping whitespace. The element is parsed with 'parseType', so it
-- may itself be a function type.
arrayType :: PU.Parser AT.Type
arrayType = do
  count <- MC.char '[' *> M.optional ML.decimal <* MC.char ']'
  elemTy <- parseType
  pure (AT.TArray elemTy count)

-- | Parse a function type: one or more parameter types, an optional @...@
-- that marks the function as variadic, then @->@ and the return type.
--
-- Examples: @int -> float@, @int char -> never@, @char ... -> int@.
--
-- Every parameter and the return type is either a term type
-- ('parseTermType') or a parenthesised function type, so nested function
-- types must be written in parentheses, e.g. @(int -> never) -> never@.
--
-- TODO: find a way to drop the parentheses without 'parseType' looping back
-- into 'functionType' forever.
functionType :: PU.Parser AT.Type
functionType = do
  params <- M.some (PU.lexeme operand)
  variadic <- M.isJust <$> M.optional (PU.symbol "...")
  ret <- PU.symbol "->" *> PU.lexeme operand
  pure
    AT.TFunction
      { AT.returnType = ret
      , AT.paramTypes = params
      , AT.isVariadic = variadic
      }
  where
    -- A single parameter/return slot.
    operand :: PU.Parser AT.Type
    operand = parenthesised M.<|> parseTermType

    -- A nested function type must be bracketed to avoid left recursion.
    parenthesised :: PU.Parser AT.Type
    parenthesised = M.between (PU.symbol "(") (PU.symbol ")") functionType

-- | Parse a user-defined type name and resolve it through the parser state.
--
-- The identifier is looked up with 'PS.lookupType' on the current parser
-- state. If no type is registered under that name, parsing fails with the
-- custom error @PU.UnknownType name@ (after the identifier has been
-- consumed).
customType :: PU.Parser AT.Type
customType = do
  name <- PU.identifier
  resolved <- S.gets (PS.lookupType name)
  maybe (M.customFailure (PU.UnknownType name)) pure resolved