Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
index 5f925ff87b938de911867ef415360c24b7d902c3..d785fd2a478b86fb0546adb95177a5dfa0e0acce 100644 (file)
 //===----------------------------------------------------------------------===//
 
 
-#include "NVPTX.h"
 #include "NVPTXISelLowering.h"
+#include "NVPTX.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXTargetObjectFile.h"
 #include "NVPTXUtilities.h"
-#include "llvm/Intrinsics.h"
-#include "llvm/IntrinsicInst.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/DerivedTypes.h"
-#include "llvm/GlobalValue.h"
-#include "llvm/Module.h"
-#include "llvm/Function.h"
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Function.h"
+#include "llvm/GlobalValue.h"
+#include "llvm/IntrinsicInst.h"
+#include "llvm/Intrinsics.h"
+#include "llvm/MC/MCSectionELF.h"
+#include "llvm/Module.h"
 #include "llvm/Support/CallSite.h"
-#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
-#include "llvm/MC/MCSectionELF.h"
 #include <sstream>
 
 #undef DEBUG_TYPE
@@ -174,10 +174,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
 
   // PTX does not support load / store predicate registers
-  setOperationAction(ISD::LOAD, MVT::i1, Expand);
+  setOperationAction(ISD::LOAD, MVT::i1, Custom);
+  setOperationAction(ISD::STORE, MVT::i1, Custom);
+
   setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
   setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
-  setOperationAction(ISD::STORE, MVT::i1, Expand);
   setTruncStoreAction(MVT::i64, MVT::i1, Expand);
   setTruncStoreAction(MVT::i32, MVT::i1, Expand);
   setTruncStoreAction(MVT::i16, MVT::i1, Expand);
@@ -270,6 +271,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
   }
 }
 
+bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
+  return VT == MVT::i1;
+}
 
 SDValue
 NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
@@ -402,7 +406,7 @@ std::string NVPTXTargetLowering::getPrototype(Type *retTy,
 
     if (isABI) {
       unsigned align = Outs[i].Flags.getByValAlign();
-      unsigned sz = getTargetData()->getTypeAllocSize(ETy);
+      unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
       O << ".param .align " << align
           << " .b8 ";
       O << "_";
@@ -438,17 +442,21 @@ std::string NVPTXTargetLowering::getPrototype(Type *retTy,
 }
 
 
-#if 0
 SDValue
-NVPTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
-                               CallingConv::ID CallConv, bool isVarArg,
-                               bool doesNotRet, bool &isTailCall,
-                               const SmallVectorImpl<ISD::OutputArg> &Outs,
-                               const SmallVectorImpl<SDValue> &OutVals,
-                               const SmallVectorImpl<ISD::InputArg> &Ins,
-                               DebugLoc dl, SelectionDAG &DAG,
-                               SmallVectorImpl<SDValue> &InVals, Type *retTy,
-                               const ArgListTy &Args) const {
+NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
+                               SmallVectorImpl<SDValue> &InVals) const {
+  SelectionDAG &DAG                     = CLI.DAG;
+  DebugLoc &dl                          = CLI.DL;
+  SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
+  SmallVector<SDValue, 32> &OutVals     = CLI.OutVals;
+  SmallVector<ISD::InputArg, 32> &Ins   = CLI.Ins;
+  SDValue Chain                         = CLI.Chain;
+  SDValue Callee                        = CLI.Callee;
+  bool &isTailCall                      = CLI.IsTailCall;
+  ArgListTy &Args                       = CLI.Args;
+  Type *retTy                           = CLI.RetTy;
+  ImmutableCallSite *CS                 = CLI.CS;
+
   bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
 
   SDValue tempChain = Chain;
@@ -649,20 +657,14 @@ NVPTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
         InFlag = Chain.getValue(1);
       }
       else {
-        // @TODO: Re-enable getAlign calls.  We do not have the
-        // ImmutableCallSite object here anymore.
-        //if (Func) { // direct call
-        //if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
-        //retAlignment = TD->getABITypeAlignment(retTy);
-        //}
-        //else { // indirect call
-        //const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
-        //if (!llvm::getAlign(*CallI, 0, retAlignment))
-        //retAlignment = TD->getABITypeAlignment(retTy);
-        //}
-        // @TODO: Remove this hack!
-        // Functions with explicit alignment metadata will be broken, for now.
-        retAlignment = 16;
+        if (Func) { // direct call
+          if (!llvm::getAlign(*(CS->getCalledFunction()), 0, retAlignment))
+            retAlignment = getDataLayout()->getABITypeAlignment(retTy);
+        } else { // indirect call
+          const CallInst *CallI = dyn_cast<CallInst>(CS->getInstruction());
+          if (!llvm::getAlign(*CallI, 0, retAlignment))
+            retAlignment = getDataLayout()->getABITypeAlignment(retTy);
+        }
         SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
         SDValue DeclareRetOps[] = { Chain, DAG.getConstant(retAlignment,
                                                            MVT::i32),
@@ -823,7 +825,6 @@ NVPTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
   isTailCall = false;
   return Chain;
 }
-#endif
 
 // By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
 // (see LegalizeDAG.cpp). This is slow and uses local memory.
@@ -859,11 +860,64 @@ LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
   case ISD::CONCAT_VECTORS: return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::STORE: return LowerSTORE(Op, DAG);
+  case ISD::LOAD: return LowerLOAD(Op, DAG);
   default:
-    assert(0 && "Custom lowering not defined for operation");
+    llvm_unreachable("Custom lowering not defined for operation");
   }
 }
 
+
+// v = ld i1* addr
+//   =>
+// v1 = ld i8* addr
+// v = trunc v1 to i1
+SDValue NVPTXTargetLowering::
+LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
+  SDNode *Node = Op.getNode();
+  LoadSDNode *LD = cast<LoadSDNode>(Node);
+  DebugLoc dl = Node->getDebugLoc();
+  assert(LD->getExtensionType() == ISD::NON_EXTLOAD) ;
+  assert(Node->getValueType(0) == MVT::i1 &&
+         "Custom lowering for i1 load only");
+  SDValue newLD = DAG.getLoad(MVT::i8, dl, LD->getChain(), LD->getBasePtr(),
+                              LD->getPointerInfo(),
+                              LD->isVolatile(), LD->isNonTemporal(),
+                              LD->isInvariant(),
+                              LD->getAlignment());
+  SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
+  // The legalizer (the caller) is expecting two values from the legalized
+  // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
+  // in LegalizeDAG.cpp which also uses MergeValues.
+  SDValue Ops[] = {result, LD->getChain()};
+  return DAG.getMergeValues(Ops, 2, dl);
+}
+
+// st i1 v, addr
+//    =>
+// v1 = zxt v to i8
+// st i8, addr
+SDValue NVPTXTargetLowering::
+LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
+  SDNode *Node = Op.getNode();
+  DebugLoc dl = Node->getDebugLoc();
+  StoreSDNode *ST = cast<StoreSDNode>(Node);
+  SDValue Tmp1 = ST->getChain();
+  SDValue Tmp2 = ST->getBasePtr();
+  SDValue Tmp3 = ST->getValue();
+  assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
+  unsigned Alignment = ST->getAlignment();
+  bool isVolatile = ST->isVolatile();
+  bool isNonTemporal = ST->isNonTemporal();
+  Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl,
+                     MVT::i8, Tmp3);
+  SDValue Result = DAG.getStore(Tmp1, dl, Tmp3, Tmp2,
+                                ST->getPointerInfo(), isVolatile,
+                                isNonTemporal, Alignment);
+  return Result;
+}
+
+
 SDValue
 NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, int idx,
                                 EVT v) const {
@@ -887,10 +941,10 @@ NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
 // Check to see if the kernel argument is image*_t or sampler_t
 
 bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
-  const char *specialTypes[] = {
-                                "struct._image2d_t",
-                                "struct._image3d_t",
-                                "struct._sampler_t"
+  static const char *const specialTypes[] = {
+                                             "struct._image2d_t",
+                                             "struct._image3d_t",
+                                             "struct._sampler_t"
   };
 
   const Type *Ty = arg->getType();
@@ -905,7 +959,7 @@ bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
   const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
   const std::string TypeName = STy ? STy->getName() : "";
 
-  for (int i=0, e=sizeof(specialTypes)/sizeof(specialTypes[0]); i!=e; ++i)
+  for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
     if (TypeName == specialTypes[i])
       return true;
 
@@ -919,7 +973,7 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
                                           DebugLoc dl, SelectionDAG &DAG,
                                        SmallVectorImpl<SDValue> &InVals) const {
   MachineFunction &MF = DAG.getMachineFunction();
-  const TargetData *TD = getTargetData();
+  const DataLayout *TD = getDataLayout();
 
   const Function *F = MF.getFunction();
   const AttrListPtr &PAL = F->getAttributes();
@@ -968,7 +1022,7 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
     // to newly created nodes. The SDNOdes for params have to
     // appear in the same order as their order of appearance
     // in the original function. "idx+1" holds that order.
-    if (PAL.paramHasAttr(i+1, Attribute::ByVal) == false) {
+    if (PAL.getParamAttributes(i+1).hasAttribute(Attributes::ByVal) == false) {
       // A plain scalar.
       if (isABI || isKernel) {
         // If ABI, load from the param symbol