{-# LANGUAGE FlexibleContexts #-}

module Codegen.ExprGen.Cast where

import qualified Ast.Types as AT
import qualified Codegen.Errors as CC
import {-# SOURCE #-} qualified Codegen.ExprGen.ExprGen as EG
import qualified Codegen.ExprGen.Types as ET
import qualified Codegen.State as CS
import qualified Control.Monad.Except as E
import qualified LLVM.AST as AST
import qualified LLVM.AST.Constant as C
import qualified LLVM.AST.IntegerPredicate as IP
import qualified LLVM.AST.Type as T
import qualified LLVM.AST.Typed as TD
import qualified LLVM.IRBuilder.Instruction as I
import qualified Shared.Utils as SU

-- | Generate LLVM code for an explicit cast expression.
--
-- For @AT.Cast _ targetTy inner@, the inner expression is lowered first and
-- the resulting operand is converted from its own LLVM type to the LLVM
-- lowering of @targetTy@ via 'llvmCast'. Conversion errors are reported at
-- the location of the inner expression.
--
-- Any expression that is not an 'AT.Cast' is rejected with
-- 'CC.UnsupportedDefinition'.
generateCast :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateCast castExpr = case castExpr of
  AT.Cast _ targetTy inner -> do
    value <- EG.generateExpr inner
    let srcLlvmTy = TD.typeOf value
        dstLlvmTy = ET.toLLVM targetTy
    llvmCast (SU.getLoc inner) value srcLlvmTy dstLlvmTy
  other ->
    E.throwError (CC.CodegenError (SU.getLoc other) (CC.UnsupportedDefinition other))

-- | Cast an operand from one LLVM type to another.
--
-- The conversion instruction is chosen from the (source, target) type pair:
--
-- * identical types: the operand is returned unchanged (no instruction);
-- * integer -> integer: @zext@ when widening, @trunc@ when narrowing;
-- * float -> float: @fpext@ when widening, @fptrunc@ when narrowing;
-- * integer -> float: @sitofp@; float -> integer: @fptosi@;
-- * pointer -> integer: @ptrtoint@; integer -> pointer: @inttoptr@;
-- * pointer -> pointer, array -> pointer, array -> array: @bitcast@.
--
-- Any other pair (including two different floating-point formats whose
-- relative size cannot be ordered) throws 'CC.UnsupportedConversion'
-- located at @loc@.
llvmCast :: (CS.MonadCodegen m) => AT.SrcLoc -> AST.Operand -> T.Type -> T.Type -> m AST.Operand
llvmCast loc operand fromType toType
  -- Nothing to do; also avoids emitting a pointless (or, for aggregates,
  -- invalid) bitcast to the very same type.
  | fromType == toType = pure operand
  | otherwise = case (fromType, toType) of
      (T.IntegerType fromBits, T.IntegerType toBits)
        -- NOTE(review): widening always zero-extends, whereas the int<->float
        -- cases below treat integers as signed (sitofp/fptosi); confirm zext
        -- rather than sext is intended for signed source values.
        | fromBits < toBits -> I.zext operand toType
        -- Types differ, so the widths differ: this is a narrowing.
        | otherwise -> I.trunc operand toType
      (T.FloatingPointType fromFP, T.FloatingPointType toFP) ->
        case (fpWidth fromFP, fpWidth toFP) of
          (Just fromW, Just toW)
            | fromW < toW -> I.fpext operand toType
            | fromW > toW -> I.fptrunc operand toType
          -- Same width but different format, or a format we do not order:
          -- fall through to the error rather than returning a mistyped value.
          _ -> unsupported
      (T.IntegerType {}, T.FloatingPointType {}) -> I.sitofp operand toType
      (T.FloatingPointType {}, T.IntegerType {}) -> I.fptosi operand toType
      -- LLVM forbids bitcast between pointers and integers; use the
      -- dedicated conversion instructions instead.
      (T.PointerType {}, T.IntegerType {}) -> I.ptrtoint operand toType
      (T.IntegerType {}, T.PointerType {}) -> I.inttoptr operand toType
      (T.PointerType {}, T.PointerType {}) -> I.bitcast operand toType
      -- NOTE(review): LLVM's bitcast does not accept aggregate (array)
      -- operands; these two cases are kept for compatibility with existing
      -- callers but likely need a different lowering — confirm.
      (T.ArrayType {}, T.PointerType {}) -> I.bitcast operand toType
      (T.ArrayType {}, T.ArrayType {}) -> I.bitcast operand toType
      _ -> unsupported
  where
    unsupported =
      E.throwError $ CC.CodegenError loc $ CC.UnsupportedConversion fromType toType

    -- Bit width of floating-point formats with an unambiguous size ordering.
    -- ppc_fp128 is left out on purpose: it has the same width as fp128 but a
    -- different representation, so neither fpext nor fptrunc applies.
    fpWidth :: T.FloatingPointType -> Maybe Int
    fpWidth T.HalfFP = Just 16
    fpWidth T.FloatFP = Just 32
    fpWidth T.DoubleFP = Just 64
    fpWidth T.X86_FP80FP = Just 80
    fpWidth T.FP128FP = Just 128
    fpWidth _ = Nothing

-- | Make sure an operand has the requested LLVM type.
--
-- When the operand's type already equals @targetTy@ it is returned as-is;
-- otherwise it is converted with 'llvmCast', which reports failures at @loc@.
ensureMatchingType :: (CS.MonadCodegen m) => AT.SrcLoc -> AST.Operand -> T.Type -> m AST.Operand
ensureMatchingType loc val targetTy =
  let currentTy = TD.typeOf val
   in if currentTy == targetTy
        then pure val
        else llvmCast loc val currentTy targetTy

-- | Convert an operand to an @i1@ boolean value.
--
-- * An @i1@ operand is already a boolean and is returned unchanged.
-- * An integer operand of any other width @n@ is compared against the
--   @i n@ zero constant with @icmp ne@, so every non-zero value is true.
--
-- Any other operand type throws 'CC.UnsupportedConversion' from the operand's
-- actual LLVM type to @i1@, located at @loc@.
toBool :: (CS.MonadCodegen m) => AT.SrcLoc -> AST.Operand -> m AST.Operand
toBool loc val = case TD.typeOf val of
  T.IntegerType 1 -> pure val
  T.IntegerType bits ->
    -- icmp requires both operands to have the same type, so build the zero
    -- constant with the operand's own width.
    I.icmp IP.NE val (AST.ConstantOperand (C.Int bits 0))
  ty ->
    -- Report the real offending type instead of a placeholder.
    E.throwError $ CC.CodegenError loc $ CC.UnsupportedConversion ty T.i1