[NVPTX] Fix emitting aggregate parameters
[oota-llvm.git] / lib / Target / NVPTX / NVPTXAsmPrinter.cpp
index 7552fe7041151a4383400c359c522cbc133d6182..5fad27e47a90d73ef3d37ee41e60a3f3e4e96d14 100644 (file)
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "NVPTXAsmPrinter.h"
+#include "InstPrinter/NVPTXInstPrinter.h"
 #include "MCTargetDesc/NVPTXMCAsmInfo.h"
 #include "NVPTX.h"
 #include "NVPTXInstrInfo.h"
 #include "NVPTXRegisterInfo.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXUtilities.h"
-#include "InstPrinter/NVPTXInstPrinter.h"
 #include "cl_common_defines.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
-#include "llvm/Assembly/Writer.h"
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
@@ -33,6 +32,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Mangler.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/MC/MCStreamer.h"
@@ -43,7 +43,6 @@
 #include "llvm/Support/Path.h"
 #include "llvm/Support/TargetRegistry.h"
 #include "llvm/Support/TimeValue.h"
-#include "llvm/Target/Mangler.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include <sstream>
 using namespace llvm;
@@ -149,7 +148,7 @@ const MCExpr *nvptx::LowerConstant(const Constant *CV, AsmPrinter &AP) {
       std::string S;
       raw_string_ostream OS(S);
       OS << "Unsupported expression in static initializer: ";
-      WriteAsOperand(OS, CE, /*PrintType=*/ false,
+      CE->printAsOperand(OS, /*PrintType=*/ false,
                      !AP.MF ? 0 : AP.MF->getFunction()->getParent());
       report_fatal_error(OS.str());
     }
@@ -430,7 +429,7 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   O << " (";
 
   if (isABI) {
-    if (Ty->isPrimitiveType() || Ty->isIntegerTy()) {
+    if (Ty->isFloatingPointTy() || Ty->isIntegerTy()) {
       unsigned size = 0;
       if (const IntegerType *ITy = dyn_cast<IntegerType>(Ty)) {
         size = ITy->getBitWidth();
@@ -895,7 +894,7 @@ bool NVPTXAsmPrinter::doInitialization(Module &M) {
   const_cast<TargetLoweringObjectFile &>(getObjFileLowering())
       .Initialize(OutContext, TM);
 
-  Mang = new Mangler(&TM);
+  Mang = new Mangler(TM.getDataLayout());
 
   // Emit header before any dwarf directives are emitted below.
   emitHeader(M, OS1);
@@ -1207,7 +1206,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   else
     O << " .align " << GVar->getAlignment();
 
-  if (ETy->isPrimitiveType() || ETy->isIntegerTy() || isa<PointerType>(ETy)) {
+  if (ETy->isSingleValueType()) {
     O << " .";
     // Special case: ABI requires that we use .u8 for predicates
     if (ETy->isIntegerTy(1))
@@ -1378,7 +1377,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   else
     O << " .align " << GVar->getAlignment();
 
-  if (ETy->isPrimitiveType() || ETy->isIntegerTy() || isa<PointerType>(ETy)) {
+  if (ETy->isSingleValueType()) {
     O << " .";
     O << getPTXFundamentalTypeStr(ETy);
     O << " ";
@@ -1410,7 +1409,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
 }
 
 static unsigned int getOpenCLAlignment(const DataLayout *TD, Type *Ty) {
-  if (Ty->isPrimitiveType() || Ty->isIntegerTy() || isa<PointerType>(Ty))
+  if (Ty->isSingleValueType())
     return TD->getPrefTypeAlignment(Ty);
 
   const ArrayType *ATy = dyn_cast<ArrayType>(Ty);
@@ -1523,8 +1522,8 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     }
 
     if (PAL.hasAttribute(paramIndex + 1, Attribute::ByVal) == false) {
-      if (Ty->isVectorTy()) {
-        // Just print .param .b8 .align <a> .param[size];
+      if (Ty->isAggregateType() || Ty->isVectorTy()) {
+        // Just print .param .align <a> .b8 .param[size];
         // <a> = PAL.getparamalignment
         // size = typeallocsize of element type
         unsigned align = PAL.getParamAlignment(paramIndex + 1);
@@ -1580,7 +1579,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
         continue;
       }
       // Non-kernel function, just print .param .b<size> for ABI
-      // and .reg .b<size> for non ABY
+      // and .reg .b<size> for non-ABI
       unsigned sz = 0;
       if (isa<IntegerType>(Ty)) {
         sz = cast<IntegerType>(Ty)->getBitWidth();
@@ -1604,7 +1603,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     Type *ETy = PTy->getElementType();
 
     if (isABI || isKernelFunc) {
-      // Just print .param .b8 .align <a> .param[size];
+      // Just print .param .align <a> .b8 .param[size];
       // <a> = PAL.getparamalignment
       // size = typeallocsize of element type
       unsigned align = PAL.getParamAlignment(paramIndex + 1);
@@ -2087,21 +2086,6 @@ void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
     O << *getSymbol(MO.getGlobal());
     break;
 
-  case MachineOperand::MO_ExternalSymbol: {
-    const char *symbname = MO.getSymbolName();
-    if (strstr(symbname, ".PARAM") == symbname) {
-      unsigned index;
-      sscanf(symbname + 6, "%u[];", &index);
-      printParamName(index, O);
-    } else if (strstr(symbname, ".HLPPARAM") == symbname) {
-      unsigned index;
-      sscanf(symbname + 9, "%u[];", &index);
-      O << *CurrentFnSym << "_param_" << index << "_offset";
-    } else
-      O << symbname;
-    break;
-  }
-
   case MachineOperand::MO_MachineBasicBlock:
     O << *MO.getMBB()->getSymbol();
     return;