{-# LANGUAGE FlexibleContexts #-}

module Codegen.ExprGen.DataValue where

import qualified Ast.Types as AT
import qualified Codegen.Errors as CC
import {-# SOURCE #-} qualified Codegen.ExprGen.ExprGen as EG
import qualified Codegen.State as CS
import qualified Control.Monad.Except as E
import qualified Data.List as L
import qualified LLVM.AST as AST
import qualified LLVM.IRBuilder.Constant as IC
import qualified LLVM.IRBuilder.Instruction as I
import qualified Shared.Utils as SU

-- | Generate LLVM code for an array element read (@array[index]@).
--
-- The array expression is generated first, then the index expression. A
-- @getelementptr@ using the index as its only index operand is emitted on
-- the array operand, and the resulting address is loaded (alignment @0@).
--
-- NOTE(review): a single-index GEP presumes the array operand already points
-- at the first element; if it is a pointer to an aggregate array type, a
-- leading @0@ index would be needed — confirm against how arrays are lowered.
--
-- Any expression other than 'AT.ArrayAccess' is rejected with
-- 'CC.UnsupportedDefinition' at that expression's source location.
generateArrayAccess :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateArrayAccess expr = case expr of
  AT.ArrayAccess _ arrExpr idxExpr -> do
    base <- EG.generateExpr arrExpr
    idx <- EG.generateExpr idxExpr
    elemPtr <- I.gep base [idx]
    I.load elemPtr 0
  _ ->
    E.throwError . CC.CodegenError (SU.getLoc expr) $ CC.UnsupportedDefinition expr

-- | Generate LLVM code for a (possibly nested) struct field read.
--
-- The field's address is resolved with 'getStructFieldPointer', which walks
-- every level of an access chain such as @a.b.c@. The field type it also
-- reports is discarded, and the value stored at that address is loaded
-- (alignment @0@).
generateStructAccess :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateStructAccess expr =
  getStructFieldPointer expr >>= \(fieldPtr, _fieldTy) -> I.load fieldPtr 0

-- | Compute the address of a struct field together with the field's type.
--
-- Supported expression shapes:
--
-- * @'AT.Var' loc name ty@ — the base of an access chain. The variable's
--   operand is looked up with 'CS.getVar' and returned alongside the type
--   annotated on the 'AT.Var' node. An unbound name raises
--   'CC.VariableNotFound'.
--
-- * @'AT.StructAccess' loc parent ('AT.Var' _ field _)@ — the parent is
--   resolved recursively, so chains such as @a.b.c@ are supported. The
--   parent's type must be an 'AT.TStruct'. The first field with a matching
--   name is addressed with @getelementptr parent [0, fieldIndex]@, with both
--   indices emitted as @i32@ constants.
--   A missing field raises 'CC.StructureFieldNotFound'. A non-struct parent
--   raises 'CC.UnsupportedStructureAccess' on the parent expression.
--
-- Every other expression raises 'CC.UnsupportedStructureAccess' at its own
-- source location.
getStructFieldPointer :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m (AST.Operand, AT.Type)
getStructFieldPointer (AT.StructAccess structLoc structExpr (AT.Var _ fieldName _)) = do
  (parentPtr, parentType) <- getStructFieldPointer structExpr
  case parentType of
    AT.TStruct _ structFields ->
      -- Find the index and the type of the field in a single pass. This
      -- avoids re-indexing the field list with the partial (!!) after
      -- 'L.findIndex', and the Int/Integer round trip that went with it.
      case L.find ((== fieldName) . fst . snd) (zip [0 :: Integer ..] structFields) of
        Just (fieldIndex, (_, fieldType)) -> do
          fieldPtr <- I.gep parentPtr [IC.int32 0, IC.int32 fieldIndex]
          return (fieldPtr, fieldType)
        Nothing ->
          E.throwError $ CC.CodegenError structLoc $ CC.StructureFieldNotFound fieldName
    _ ->
      E.throwError $ CC.CodegenError structLoc $ CC.UnsupportedStructureAccess structExpr
getStructFieldPointer (AT.Var structLoc structName structType) = do
  maybeVar <- CS.getVar structName
  case maybeVar of
    Just structPtr -> return (structPtr, structType)
    Nothing ->
      E.throwError $ CC.CodegenError structLoc $ CC.VariableNotFound structName
getStructFieldPointer expr =
  E.throwError $ CC.CodegenError (SU.getLoc expr) $ CC.UnsupportedStructureAccess expr