{-# LANGUAGE FlexibleContexts #-}

module Codegen.ExprGen.Assembly where

import qualified Ast.Types as AT
import qualified Codegen.Errors as CC
import {-# SOURCE #-} qualified Codegen.ExprGen.ExprGen as EG
import qualified Codegen.ExprGen.Types as ET
import qualified Codegen.State as CS
import qualified Control.Monad.Except as E
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Short as BS
import qualified Data.List as L
import qualified LLVM.AST as AST
import qualified LLVM.AST.Attribute as A
import qualified LLVM.AST.CallingConvention as ACC
import qualified LLVM.AST.Constant as C
import qualified LLVM.AST.InlineAssembly as IA
import qualified LLVM.AST.Type as T
import qualified LLVM.IRBuilder as IRB
import qualified LLVM.IRBuilder.Monad as IRM
import qualified Shared.Utils as SU

-- | Emit a call to an inline-assembly blob and return the call's result.
--
-- The IRBuilder's @call@ helper only accepts an ordinary function
-- 'AST.Operand'. An inline-assembly callee is an 'IA.InlineAssembly' (the
-- 'Left' side of 'AST.CallableOperand'). So the 'AST.Call' instruction is
-- built by hand here and emitted with the low-level 'IRB.emitInstr' /
-- 'IRB.emitInstrVoid'.
--
-- * @asm@: the inline-assembly descriptor used as the callee.
-- * @retType@: the LLVM return type of the call. When it is 'T.VoidType',
--   the call is emitted as a void instruction and an @undef void@ constant
--   is returned as a placeholder operand.
-- * @args'@: the argument operands, each with its parameter attributes.
--
-- The call uses the C calling convention, no tail-call marker, and no
-- return/function attributes or metadata.
callInlineAssembly ::
  (IRM.MonadIRBuilder m) =>
  IA.InlineAssembly ->
  AST.Type ->
  [(AST.Operand, [A.ParameterAttribute])] ->
  m AST.Operand
callInlineAssembly asm retType args' =
  case retType of
    T.VoidType -> do
      IRB.emitInstrVoid callInstr
      -- A void call yields no value; return an @undef@ placeholder so the
      -- caller still gets an operand back.
      pure (AST.ConstantOperand (C.Undef T.void))
    _ -> IRB.emitInstr retType callInstr
  where
    callInstr :: AST.Instruction
    callInstr =
      AST.Call
        { AST.tailCallKind = Nothing,
          AST.callingConvention = ACC.C,
          AST.returnAttributes = [],
          AST.function = Left asm,
          AST.arguments = args',
          AST.functionAttributes = [],
          AST.metadata = []
        }

-- | Generate LLVM IR for an inline-assembly expression ('AT.Assembly').
--
-- Builds an 'IA.InlineAssembly' callee from the expression's code, constraint
-- strings, dialect and flags. It then generates each argument expression in
-- order and emits the call via 'callInlineAssembly'.
--
-- The callee's LLVM function type is derived from 'AT.asmReturnType' and
-- 'AT.asmParameters'. It is not variadic.
--
-- Any other expression is rejected with 'CC.UnsupportedDefinition'.
generateAssembly :: (CS.MonadCodegen m, EG.ExprGen AT.Expr) => AT.Expr -> m AST.Operand
generateAssembly (AT.Assembly _ asmExpr) = do
  -- Arguments are generated left to right and passed without attributes.
  asmOperands <- mapM (fmap withoutAttrs . EG.generateExpr) (AT.asmArgs asmExpr)
  callInlineAssembly inlAsm llvmRetTy asmOperands
  where
    llvmRetTy :: AST.Type
    llvmRetTy = ET.toLLVM (AT.asmReturnType asmExpr)

    inlineType :: AST.Type
    inlineType =
      T.FunctionType llvmRetTy (map ET.toLLVM (AT.asmParameters asmExpr)) False

    asmCons = AT.asmConstraints asmExpr

    -- LLVM constraint string: the output constraint (if non-empty) comes
    -- first, prefixed with "=", then the input constraints. All pieces are
    -- comma-separated. An empty output is dropped so no leading comma appears.
    combinedConstraints :: String
    combinedConstraints =
      L.intercalate "," (outputPart ++ AT.inputConstraints asmCons)

    outputPart :: [String]
    outputPart =
      ["=" ++ out | let out = AT.outputConstraint asmCons, not (null out)]

    llvmDialect :: IA.Dialect
    llvmDialect = case AT.asmDialect asmExpr of
      AT.Intel -> IA.IntelDialect
      AT.ATT -> IA.ATTDialect

    inlAsm :: IA.InlineAssembly
    inlAsm =
      IA.InlineAssembly
        { IA.type' = inlineType,
          IA.assembly = utf8 (AT.asmCode asmExpr),
          IA.constraints = BS.toShort (utf8 combinedConstraints),
          IA.hasSideEffects = AT.asmSideEffects asmExpr,
          IA.alignStack = AT.asmAlignStack asmExpr,
          IA.dialect = llvmDialect
        }

    withoutAttrs :: AST.Operand -> (AST.Operand, [A.ParameterAttribute])
    withoutAttrs op = (op, [])

    -- Encode as UTF-8. 'B.pack' (Char8) would silently truncate every
    -- character above U+00FF to its low byte and corrupt the asm text.
    utf8 :: String -> B.ByteString
    utf8 = BL.toStrict . BB.toLazyByteString . BB.stringUtf8
generateAssembly expr =
  E.throwError $ CC.CodegenError (SU.getLoc expr) $ CC.UnsupportedDefinition expr