[NVPTX] Fix emitting aggregate parameters
[oota-llvm.git] / lib / Target / NVPTX / NVPTXAsmPrinter.cpp
index 0dcc70d238cab73fb3bceaf6859000fa944e93bb..5fad27e47a90d73ef3d37ee41e60a3f3e4e96d14 100644 (file)
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "NVPTXAsmPrinter.h"
+#include "InstPrinter/NVPTXInstPrinter.h"
 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
 #include "NVPTX.h"
 #include "NVPTXInstrInfo.h"
 #include "NVPTXRegisterInfo.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXUtilities.h"
-#include "InstPrinter/NVPTXInstPrinter.h"
 #include "cl_common_defines.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
-#include "llvm/Assembly/Writer.h"
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
@@ -33,6 +32,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Mangler.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/TargetRegistry.h"
 #include "llvm/Support/TimeValue.h"
-#include "llvm/Target/Mangler.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include <sstream>
 using namespace llvm;
 
-bool RegAllocNilUsed = true;
-
 #define DEPOTNAME "__local_depot"
 
 static cl::opt<bool>
@@ -57,12 +54,10 @@ EmitLineNumbers("nvptx-emit-line-numbers", cl::Hidden,
                 cl::desc("NVPTX Specific: Emit Line numbers even without -G"),
                 cl::init(true));
 
-namespace llvm { bool InterleaveSrcInPtx = false; }
-
-static cl::opt<bool, true>
+static cl::opt<bool>
 InterleaveSrc("nvptx-emit-src", cl::ZeroOrMore, cl::Hidden,
               cl::desc("NVPTX Specific: Emit source line in ptx file"),
-              cl::location(llvm::InterleaveSrcInPtx));
+              cl::init(false));
 
 namespace {
 /// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
@@ -130,7 +125,7 @@ const MCExpr *nvptx::LowerConstant(const Constant *CV, AsmPrinter &AP) {
     return MCConstantExpr::Create(CI->getZExtValue(), Ctx);
 
   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV))
-    return MCSymbolRefExpr::Create(AP.Mang->getSymbol(GV), Ctx);
+    return MCSymbolRefExpr::Create(AP.getSymbol(GV), Ctx);
 
   if (const BlockAddress *BA = dyn_cast<BlockAddress>(CV))
     return MCSymbolRefExpr::Create(AP.GetBlockAddressSymbol(BA), Ctx);
@@ -153,7 +148,7 @@ const MCExpr *nvptx::LowerConstant(const Constant *CV, AsmPrinter &AP) {
       std::string S;
       raw_string_ostream OS(S);
       OS << "Unsupported expression in static initializer: ";
-      WriteAsOperand(OS, CE, /*PrintType=*/ false,
+      CE->printAsOperand(OS, /*PrintType=*/ false,
                      !AP.MF ? 0 : AP.MF->getFunction()->getParent());
       report_fatal_error(OS.str());
     }
@@ -295,7 +290,7 @@ void NVPTXAsmPrinter::emitLineNumberAsDotLoc(const MachineInstr &MI) {
     return;
 
   // Emit the line from the source file.
-  if (llvm::InterleaveSrcInPtx)
+  if (InterleaveSrc)
     this->emitSrcInText(fileName.str(), curLoc.getLine());
 
   std::stringstream temp;
@@ -318,6 +313,14 @@ void NVPTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
 void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
   OutMI.setOpcode(MI->getOpcode());
 
+  // Special: Do not mangle symbol operand of CALL_PROTOTYPE
+  if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
+    const MachineOperand &MO = MI->getOperand(0);
+    OutMI.addOperand(GetSymbolRef(MO,
+      OutContext.GetOrCreateSymbol(Twine(MO.getSymbolName()))));
+    return;
+  }
+
   for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
     const MachineOperand &MO = MI->getOperand(i);
 
@@ -345,7 +348,7 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
     MCOp = GetSymbolRef(MO, GetExternalSymbolSymbol(MO.getSymbolName()));
     break;
   case MachineOperand::MO_GlobalAddress:
-    MCOp = GetSymbolRef(MO, Mang->getSymbol(MO.getGlobal()));
+    MCOp = GetSymbolRef(MO, getSymbol(MO.getGlobal()));
     break;
   case MachineOperand::MO_FPImmediate: {
     const ConstantFP *Cnt = MO.getFPImm();
@@ -426,7 +429,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   O << " (";
 
   if (isABI) {
-    if (Ty->isPrimitiveType() || Ty->isIntegerTy()) {
+    if (Ty->isFloatingPointTy() || Ty->isIntegerTy()) {
       unsigned size = 0;
       if (const IntegerType *ITy = dyn_cast<IntegerType>(Ty)) {
         size = ITy->getBitWidth();
@@ -681,7 +684,7 @@ void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
   else
     O << ".func ";
   printReturnValStr(F, O);
-  O << *Mang->getSymbol(F) << "\n";
+  O << *getSymbol(F) << "\n";
   emitFunctionParamList(F, O);
   O << ";\n";
 }
@@ -891,7 +894,7 @@ bool NVPTXAsmPrinter::doInitialization(Module &M) {
   const_cast<TargetLoweringObjectFile &>(getObjFileLowering())
       .Initialize(OutContext, TM);
 
-  Mang = new Mangler(OutContext, &TM);
+  Mang = new Mangler(TM.getDataLayout());
 
   // Emit header before any dwarf directives are emitted below.
   emitHeader(M, OS1);
@@ -1203,7 +1206,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   else
     O << " .align " << GVar->getAlignment();
 
-  if (ETy->isPrimitiveType() || ETy->isIntegerTy() || isa<PointerType>(ETy)) {
+  if (ETy->isSingleValueType()) {
     O << " .";
     // Special case: ABI requires that we use .u8 for predicates
     if (ETy->isIntegerTy(1))
@@ -1211,7 +1214,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     else
       O << getPTXFundamentalTypeStr(ETy, false);
     O << " ";
-    O << *Mang->getSymbol(GVar);
+    O << *getSymbol(GVar);
 
     // Ptx allows variable initilization only for constant and global state
     // spaces.
@@ -1247,15 +1250,15 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols) {
             if (nvptxSubtarget.is64Bit()) {
-              O << " .u64 " << *Mang->getSymbol(GVar) << "[";
+              O << " .u64 " << *getSymbol(GVar) << "[";
               O << ElementSize / 8;
             } else {
-              O << " .u32 " << *Mang->getSymbol(GVar) << "[";
+              O << " .u32 " << *getSymbol(GVar) << "[";
               O << ElementSize / 4;
             }
             O << "]";
           } else {
-            O << " .b8 " << *Mang->getSymbol(GVar) << "[";
+            O << " .b8 " << *getSymbol(GVar) << "[";
             O << ElementSize;
             O << "]";
           }
@@ -1263,7 +1266,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           aggBuffer.print();
           O << "}";
         } else {
-          O << " .b8 " << *Mang->getSymbol(GVar);
+          O << " .b8 " << *getSymbol(GVar);
           if (ElementSize) {
             O << "[";
             O << ElementSize;
@@ -1271,7 +1274,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           }
         }
       } else {
-        O << " .b8 " << *Mang->getSymbol(GVar);
+        O << " .b8 " << *getSymbol(GVar);
         if (ElementSize) {
           O << "[";
           O << ElementSize;
@@ -1374,11 +1377,11 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   else
     O << " .align " << GVar->getAlignment();
 
-  if (ETy->isPrimitiveType() || ETy->isIntegerTy() || isa<PointerType>(ETy)) {
+  if (ETy->isSingleValueType()) {
     O << " .";
     O << getPTXFundamentalTypeStr(ETy);
     O << " ";
-    O << *Mang->getSymbol(GVar);
+    O << *getSymbol(GVar);
     return;
   }
 
@@ -1393,7 +1396,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   case Type::ArrayTyID:
   case Type::VectorTyID:
     ElementSize = TD->getTypeStoreSize(ETy);
-    O << " .b8 " << *Mang->getSymbol(GVar) << "[";
+    O << " .b8 " << *getSymbol(GVar) << "[";
     if (ElementSize) {
       O << itostr(ElementSize);
     }
@@ -1406,7 +1409,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
 }
 
 static unsigned int getOpenCLAlignment(const DataLayout *TD, Type *Ty) {
-  if (Ty->isPrimitiveType() || Ty->isIntegerTy() || isa<PointerType>(Ty))
+  if (Ty->isSingleValueType())
     return TD->getPrefTypeAlignment(Ty);
 
   const ArrayType *ATy = dyn_cast<ArrayType>(Ty);
@@ -1448,7 +1451,7 @@ void NVPTXAsmPrinter::printParamName(Function::const_arg_iterator I,
                                      int paramIndex, raw_ostream &O) {
   if ((nvptxSubtarget.getDrvInterface() == NVPTX::NVCL) ||
       (nvptxSubtarget.getDrvInterface() == NVPTX::CUDA))
-    O << *Mang->getSymbol(I->getParent()) << "_param_" << paramIndex;
+    O << *getSymbol(I->getParent()) << "_param_" << paramIndex;
   else {
     std::string argName = I->getName();
     const char *p = argName.c_str();
@@ -1507,20 +1510,20 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       if (llvm::isImage(*I)) {
         std::string sname = I->getName();
         if (llvm::isImageWriteOnly(*I))
-          O << "\t.param .surfref " << *Mang->getSymbol(F) << "_param_"
+          O << "\t.param .surfref " << *getSymbol(F) << "_param_"
             << paramIndex;
         else // Default image is read_only
-          O << "\t.param .texref " << *Mang->getSymbol(F) << "_param_"
+          O << "\t.param .texref " << *getSymbol(F) << "_param_"
             << paramIndex;
       } else // Should be llvm::isSampler(*I)
-        O << "\t.param .samplerref " << *Mang->getSymbol(F) << "_param_"
+        O << "\t.param .samplerref " << *getSymbol(F) << "_param_"
           << paramIndex;
       continue;
     }
 
     if (PAL.hasAttribute(paramIndex + 1, Attribute::ByVal) == false) {
-      if (Ty->isVectorTy()) {
-        // Just print .param .b8 .align <a> .param[size];
+      if (Ty->isAggregateType() || Ty->isVectorTy()) {
+        // Just print .param .align <a> .b8 .param[size];
         // <a> = PAL.getparamalignment
         // size = typeallocsize of element type
         unsigned align = PAL.getParamAlignment(paramIndex + 1);
@@ -1576,7 +1579,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
         continue;
       }
       // Non-kernel function, just print .param .b<size> for ABI
-      // and .reg .b<size> for non ABY
+      // and .reg .b<size> for non-ABI
       unsigned sz = 0;
       if (isa<IntegerType>(Ty)) {
         sz = cast<IntegerType>(Ty)->getBitWidth();
@@ -1600,7 +1603,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     Type *ETy = PTy->getElementType();
 
     if (isABI || isKernelFunc) {
-      // Just print .param .b8 .align <a> .param[size];
+      // Just print .param .align <a> .b8 .param[size];
       // <a> = PAL.getparamalignment
       // size = typeallocsize of element type
       unsigned align = PAL.getParamAlignment(paramIndex + 1);
@@ -1760,13 +1763,13 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
     return;
   }
   if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
-    O << *Mang->getSymbol(GVar);
+    O << *getSymbol(GVar);
     return;
   }
   if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
     const Value *v = Cexpr->stripPointerCasts();
     if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
-      O << *Mang->getSymbol(GVar);
+      O << *getSymbol(GVar);
       return;
     } else {
       O << *LowerConstant(CPV, *this);
@@ -2080,23 +2083,8 @@ void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
     break;
 
   case MachineOperand::MO_GlobalAddress:
-    O << *Mang->getSymbol(MO.getGlobal());
-    break;
-
-  case MachineOperand::MO_ExternalSymbol: {
-    const char *symbname = MO.getSymbolName();
-    if (strstr(symbname, ".PARAM") == symbname) {
-      unsigned index;
-      sscanf(symbname + 6, "%u[];", &index);
-      printParamName(index, O);
-    } else if (strstr(symbname, ".HLPPARAM") == symbname) {
-      unsigned index;
-      sscanf(symbname + 9, "%u[];", &index);
-      O << *CurrentFnSym << "_param_" << index << "_offset";
-    } else
-      O << symbname;
+    O << *getSymbol(MO.getGlobal());
     break;
-  }
 
   case MachineOperand::MO_MachineBasicBlock:
     O << *MO.getMBB()->getSymbol();