{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}

module Codegen.ExprGen.Function where

import qualified Ast.Types as AT
import qualified Codegen.Errors as CE
import {-# SOURCE #-} qualified Codegen.ExprGen.ExprGen as EG
import qualified Codegen.ExprGen.Types as ET
import qualified Codegen.State as CS
import qualified Codegen.Utils as U
import qualified Control.Monad as CM
import qualified Control.Monad.Except as E
import qualified Control.Monad.State as S
import qualified Data.Foldable as FD
import qualified Data.Maybe as M
import qualified LLVM.AST as AST
import qualified LLVM.AST.Constant as C
import qualified LLVM.AST.Type as T
import qualified LLVM.IRBuilder.Instruction as I
import qualified LLVM.IRBuilder.Module as M
import qualified Shared.Utils as SU

-- | Generate LLVM code for a function definition.
--
-- The function is first registered as a global under its source name (so the
-- body can refer to it, e.g. for recursion) and then emitted with
-- 'M.function'.  Inside the body builder:
--
--   * the local scope is cleared and each parameter name is bound to its
--     incoming argument operand;
--   * stack slots for the body's variables are reserved with
--     'preAllocateVars';
--   * the body is generated and its value returned, or @ret void@ is emitted
--     when the declared return type is 'AT.TVoid';
--   * the caller's 'CS.localState' and 'CS.allocatedVars' are restored, so
--     generating a nested function does not clobber the enclosing scope and
--     parameters do not leak out of the function.
--
-- Parameters of function type are passed as pointers to the lowered function
-- type.  Any expression that is not an 'AT.Function' with an 'AT.TFunction'
-- type raises 'CE.UnsupportedDefinition' at the expression's location.
generateFunction :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateFunction (AT.Function _ name (AT.TFunction ret params var) paramNames body) = do
  let funcName = AST.Name $ U.stringToByteString name
      retType = ET.toLLVM ret
      paramTypes = zipWith mkParam params paramNames
      funcType = T.ptr $ T.FunctionType retType (map fst paramTypes) var
  -- Register before emitting the body so references to the function resolve.
  CS.addGlobalVar name $ AST.ConstantOperand $ C.GlobalReference funcType funcName
  M.function funcName paramTypes retType $ \ops -> do
    -- Remember the enclosing scope; it is put back once the body is done.
    oldLocalState <- S.gets CS.localState
    S.modify (\s -> s {CS.localState = []})
    CM.zipWithM_ CS.addVar paramNames ops
    oldAllocatedVars <- S.gets CS.allocatedVars
    preAllocateVars body
    result <- EG.generateExpr body
    case ret of
      AT.TVoid -> I.retVoid
      _ -> I.ret result
    S.modify
      ( \s ->
          s
            { CS.localState = oldLocalState,
              CS.allocatedVars = oldAllocatedVars
            }
      )
  where
    -- Pair a parameter's lowered LLVM type with its name; function-typed
    -- parameters become pointers to the lowered function type.
    mkParam :: AT.Type -> String -> (T.Type, M.ParameterName)
    mkParam (AT.TFunction fnRet fnParams fnVar) n =
      ( T.ptr $ T.FunctionType (ET.toLLVM fnRet) (map ET.toLLVM fnParams) fnVar,
        M.ParameterName $ U.stringToByteString n
      )
    mkParam t n = (ET.toLLVM t, M.ParameterName $ U.stringToByteString n)
generateFunction expr =
  E.throwError $ CE.CodegenError (SU.getLoc expr) $ CE.UnsupportedDefinition expr

-- | Reserve stack slots for the variables an expression introduces, before
-- the expression itself is generated.
--
-- Declarations and assignments to a plain 'AT.Var' get an @alloca@ of the
-- variable's lowered type via 'allocateSlotIfUnbound'.  The new slots are
-- recorded in 'CS.allocatedVars'.
--
-- The following are walked recursively:
--
--   * blocks;
--   * conditionals;
--   * @from@ and @while@ loops;
--   * declaration initialisers;
--   * assigned values;
--   * function bodies.
--
-- Every other expression contributes nothing.  Initialisers and assigned
-- values are always walked, even when the target name is already bound, so
-- variables declared inside them still get a slot.
--
-- NOTE(review): recursing into a nested 'AT.Function' body emits its allocas
-- into the current (enclosing) builder, while 'generateFunction' walks that
-- body again for the inner function -- confirm this is intended.
preAllocateVars :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m ()
preAllocateVars (AT.Assignment _ (AT.Var _ name typ) valueExpr) = do
  allocateSlotIfUnbound name typ
  -- The assigned value may itself introduce variables (e.g. a block).
  preAllocateVars valueExpr
preAllocateVars (AT.Declaration _ name typ init') = do
  allocateSlotIfUnbound name typ
  -- Walk the initialiser unconditionally, not only when a slot was created.
  FD.for_ init' preAllocateVars
preAllocateVars (AT.Block exprs) = mapM_ preAllocateVars exprs
preAllocateVars (AT.If _ cond thenExpr elseExpr) = do
  preAllocateVars cond
  preAllocateVars thenExpr
  FD.for_ elseExpr preAllocateVars
preAllocateVars (AT.From _ startExpr endExpr stepExpr varExpr bodyExpr) =
  mapM_ preAllocateVars [startExpr, endExpr, stepExpr, varExpr, bodyExpr]
preAllocateVars (AT.While _ condExpr bodyExpr) = do
  preAllocateVars condExpr
  preAllocateVars bodyExpr
preAllocateVars (AT.Assignment _ _ valueExpr) = preAllocateVars valueExpr
preAllocateVars (AT.Function _ _ _ _ bodyExpr) = preAllocateVars bodyExpr
-- Break, Continue and all remaining expressions introduce no variables.
preAllocateVars _ = pure ()

-- | Emit an @alloca@ for @name@ and push the resulting pointer onto
-- 'CS.allocatedVars'.  The slot uses the LLVM lowering of @typ@.
--
-- Nothing is emitted when 'CS.getVar' already resolves @name@.
allocateSlotIfUnbound :: CS.MonadCodegen m => String -> AT.Type -> m ()
allocateSlotIfUnbound name typ = do
  existing <- CS.getVar name
  CM.when (M.isNothing existing) $ do
    ptr <- I.alloca (ET.toLLVM typ) Nothing 0
    S.modify (\s -> s {CS.allocatedVars = (name, ptr) : CS.allocatedVars s})

-- | Generate an external declaration for a foreign function.
--
-- Parameter types are lowered with 'ET.toLLVM', and any that lower to
-- 'T.void' are dropped, so a C-style @f(void)@ signature declares no
-- parameters.  The declaration is emitted with 'M.externVarArgs' for
-- variadic signatures and with 'M.extern' otherwise.
--
-- The function is then registered as a global under its source name, and a
-- reference to it is returned.  Any other expression raises
-- 'CE.UnsupportedDefinition' at the expression's location.
generateForeignFunction :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateForeignFunction (AT.ForeignFunction _ name (AT.TFunction ret params var)) = do
  let funcName = AST.Name $ U.stringToByteString name
      retType = ET.toLLVM ret
      -- One filtered list feeds both the declaration and the reference type,
      -- so the reference's function type matches the declared global.
      paramTypes = filter (/= T.void) $ map ET.toLLVM params
      funcType = T.ptr $ T.FunctionType retType paramTypes var
      funcRef = AST.ConstantOperand $ C.GlobalReference funcType funcName
  _ <- (if var then M.externVarArgs else M.extern) funcName paramTypes retType
  CS.addGlobalVar name funcRef
  pure funcRef
generateForeignFunction expr =
  E.throwError $ CE.CodegenError (SU.getLoc expr) $ CE.UnsupportedDefinition expr

-- | Lower a call expression to an LLVM @call@ instruction.
--
-- Only direct calls whose callee is a plain 'AT.Var' are handled:
--
--   * the callee name is resolved with 'CS.getVar';
--   * each argument is generated left to right with 'EG.generateExpr';
--   * the call is emitted with no parameter attributes.
--
-- An unresolved callee raises 'CE.UnsupportedFunctionCall' at the call's
-- location.  Any other expression raises 'CE.UnsupportedDefinition'.
generateFunctionCall :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateFunctionCall (AT.Call loc (AT.Var _ funcName _) args) =
  CS.getVar funcName >>= \lookedUp -> case lookedUp of
    Nothing ->
      E.throwError (CE.CodegenError loc (CE.UnsupportedFunctionCall funcName))
    Just callee -> do
      argOps <- traverse EG.generateExpr args
      I.call callee [(op, []) | op <- argOps]
generateFunctionCall expr =
  E.throwError (CE.CodegenError (SU.getLoc expr) (CE.UnsupportedDefinition expr))