[NVPTX] Handle addrspacecast constant expressions in aggregate initializers
[oota-llvm.git] / lib / Target / NVPTX / NVPTXAsmPrinter.cpp
index 3beee65b7ef09da840373af2023c38f52577245c..85a4fa6c08748ade81061649c201cc81a3795a37 100644 (file)
@@ -1983,6 +1983,212 @@ bool NVPTXAsmPrinter::ignoreLoc(const MachineInstr &MI) {
   return false;
 }
 
+/// lowerConstantForGV - Return an MCExpr for the given Constant.  This is mostly
+/// a copy from AsmPrinter::lowerConstant, except customized to only handle
+/// expressions that are representable in PTX and create
+/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
+const MCExpr *
+NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
+  MCContext &Ctx = OutContext;
+
+  if (CV->isNullValue() || isa<UndefValue>(CV))
+    return MCConstantExpr::Create(0, Ctx);
+
+  if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
+    return MCConstantExpr::Create(CI->getZExtValue(), Ctx);
+
+  if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
+    const MCSymbolRefExpr *Expr =
+      MCSymbolRefExpr::Create(getSymbol(GV), Ctx);
+    if (ProcessingGeneric) {
+      return NVPTXGenericMCSymbolRefExpr::Create(Expr, Ctx);
+    } else {
+      return Expr;
+    }
+  }
+
+  const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
+  if (!CE) {
+    llvm_unreachable("Unknown constant value to lower!");
+  }
+
+  switch (CE->getOpcode()) {
+  default:
+    // If the code isn't optimized, there may be outstanding folding
+    // opportunities. Attempt to fold the expression using DataLayout as a
+    // last resort before giving up.
+    if (Constant *C = ConstantFoldConstantExpression(CE, *TM.getDataLayout()))
+      if (C != CE)
+        return lowerConstantForGV(C, ProcessingGeneric);
+
+    // Otherwise report the problem to the user.
+    {
+      std::string S;
+      raw_string_ostream OS(S);
+      OS << "Unsupported expression in static initializer: ";
+      CE->printAsOperand(OS, /*PrintType=*/false,
+                     !MF ? nullptr : MF->getFunction()->getParent());
+      report_fatal_error(OS.str());
+    }
+
+  case Instruction::AddrSpaceCast: {
+    // Strip the addrspacecast and pass along the operand
+    PointerType *DstTy = cast<PointerType>(CE->getType());
+    if (DstTy->getAddressSpace() == 0) {
+      return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
+    }
+    std::string S;
+    raw_string_ostream OS(S);
+    OS << "Unsupported expression in static initializer: ";
+    CE->printAsOperand(OS, /*PrintType=*/ false,
+                       !MF ? 0 : MF->getFunction()->getParent());
+    report_fatal_error(OS.str());
+  }
+
+  case Instruction::GetElementPtr: {
+    const DataLayout &DL = *TM.getDataLayout();
+
+    // Generate a symbolic expression for the byte address
+    APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
+    cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
+
+    const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
+                                            ProcessingGeneric);
+    if (!OffsetAI)
+      return Base;
+
+    int64_t Offset = OffsetAI.getSExtValue();
+    return MCBinaryExpr::CreateAdd(Base, MCConstantExpr::Create(Offset, Ctx),
+                                   Ctx);
+  }
+
+  case Instruction::Trunc:
+    // We emit the value and depend on the assembler to truncate the generated
+    // expression properly.  This is important for differences between
+    // blockaddress labels.  Since the two labels are in the same function, it
+    // is reasonable to treat their delta as a 32-bit value.
+    // FALL THROUGH.
+  case Instruction::BitCast:
+    return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
+
+  case Instruction::IntToPtr: {
+    const DataLayout &DL = *TM.getDataLayout();
+
+    // Handle casts to pointers by changing them into casts to the appropriate
+    // integer type.  This promotes constant folding and simplifies this code.
+    Constant *Op = CE->getOperand(0);
+    Op = ConstantExpr::getIntegerCast(Op, DL.getIntPtrType(CV->getType()),
+                                      false/*ZExt*/);
+    return lowerConstantForGV(Op, ProcessingGeneric);
+  }
+
+  case Instruction::PtrToInt: {
+    const DataLayout &DL = *TM.getDataLayout();
+
+    // Support only foldable casts to/from pointers that can be eliminated by
+    // changing the pointer to the appropriately sized integer type.
+    Constant *Op = CE->getOperand(0);
+    Type *Ty = CE->getType();
+
+    const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
+
+    // We can emit the pointer value into this slot if the slot is an
+    // integer slot equal to the size of the pointer.
+    if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
+      return OpExpr;
+
+    // Otherwise the pointer is smaller than the resultant integer, mask off
+    // the high bits so we are sure to get a proper truncation if the input is
+    // a constant expr.
+    unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
+    const MCExpr *MaskExpr = MCConstantExpr::Create(~0ULL >> (64-InBits), Ctx);
+    return MCBinaryExpr::CreateAnd(OpExpr, MaskExpr, Ctx);
+  }
+
+  // The MC library also has a right-shift operator, but it isn't consistently
+  // signed or unsigned between different targets.
+  case Instruction::Add: {
+    const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
+    const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
+    switch (CE->getOpcode()) {
+    default: llvm_unreachable("Unknown binary operator constant cast expr");
+    case Instruction::Add: return MCBinaryExpr::CreateAdd(LHS, RHS, Ctx);
+    }
+  }
+  }
+}
+
+// Copy of MCExpr::print customized for NVPTX
+void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
+  switch (Expr.getKind()) {
+  case MCExpr::Target:
+    return cast<MCTargetExpr>(&Expr)->PrintImpl(OS);
+  case MCExpr::Constant:
+    OS << cast<MCConstantExpr>(Expr).getValue();
+    return;
+
+  case MCExpr::SymbolRef: {
+    const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
+    const MCSymbol &Sym = SRE.getSymbol();
+    OS << Sym;
+    return;
+  }
+
+  case MCExpr::Unary: {
+    const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
+    switch (UE.getOpcode()) {
+    case MCUnaryExpr::LNot:  OS << '!'; break;
+    case MCUnaryExpr::Minus: OS << '-'; break;
+    case MCUnaryExpr::Not:   OS << '~'; break;
+    case MCUnaryExpr::Plus:  OS << '+'; break;
+    }
+    printMCExpr(*UE.getSubExpr(), OS);
+    return;
+  }
+
+  case MCExpr::Binary: {
+    const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
+
+    // Only print parens around the LHS if it is non-trivial.
+    if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
+        isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
+      printMCExpr(*BE.getLHS(), OS);
+    } else {
+      OS << '(';
+      printMCExpr(*BE.getLHS(), OS);
+      OS<< ')';
+    }
+
+    switch (BE.getOpcode()) {
+    case MCBinaryExpr::Add:
+      // Print "X-42" instead of "X+-42".
+      if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
+        if (RHSC->getValue() < 0) {
+          OS << RHSC->getValue();
+          return;
+        }
+      }
+
+      OS <<  '+';
+      break;
+    default: llvm_unreachable("Unhandled binary operator");
+    }
+
+    // Only print parens around the LHS if it is non-trivial.
+    if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
+      printMCExpr(*BE.getRHS(), OS);
+    } else {
+      OS << '(';
+      printMCExpr(*BE.getRHS(), OS);
+      OS << ')';
+    }
+    return;
+  }
+  }
+
+  llvm_unreachable("Invalid expression kind!");
+}
+
 /// PrintAsmOperand - Print out an operand for an inline asm expression.
 ///
 bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,