[AsmPrinter] Fix crash in handleIndirectSymViaGOTPCRel
[oota-llvm.git] / lib / CodeGen / AsmPrinter / AsmPrinter.cpp
index 229082f1a5bf8e561a090714bf193a7582f8e0c9..28f5bc49dcabd2cf914680904b6fbafaddd852ec 100644 (file)
@@ -40,7 +40,7 @@
 #include "llvm/MC/MCInst.h"
 #include "llvm/MC/MCSection.h"
 #include "llvm/MC/MCStreamer.h"
-#include "llvm/MC/MCSymbol.h"
+#include "llvm/MC/MCSymbolELF.h"
 #include "llvm/MC/MCValue.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/Format.h"
@@ -151,7 +151,7 @@ void AsmPrinter::EmitToStreamer(MCStreamer &S, const MCInst &Inst) {
 }
 
 StringRef AsmPrinter::getTargetTriple() const {
-  return TM.getTargetTriple();
+  return TM.getTargetTriple().str();
 }
 
 /// getCurrentSection() - Return the current section we are emitting to.
@@ -172,7 +172,6 @@ void AsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
 
 bool AsmPrinter::doInitialization(Module &M) {
   MMI = getAnalysisIfAvailable<MachineModuleInfo>();
-  MMI->AnalyzeModule(M);
 
   // Initialize TargetLoweringObjectFile.
   const_cast<TargetLoweringObjectFile&>(getObjFileLowering())
@@ -180,7 +179,7 @@ bool AsmPrinter::doInitialization(Module &M) {
 
   OutStreamer->InitSections(false);
 
-  Mang = new Mangler(TM.getDataLayout());
+  Mang = new Mangler();
 
   // Emit the version-min deplyment target directive if needed.
   //
@@ -222,7 +221,8 @@ bool AsmPrinter::doInitialization(Module &M) {
     // We're at the module level. Construct MCSubtarget from the default CPU
     // and target triple.
     std::unique_ptr<MCSubtargetInfo> STI(TM.getTarget().createMCSubtargetInfo(
-        TM.getTargetTriple(), TM.getTargetCPU(), TM.getTargetFeatureString()));
+        TM.getTargetTriple().str(), TM.getTargetCPU(),
+        TM.getTargetFeatureString()));
     OutStreamer->AddComment("Start of file scope inline assembly");
     OutStreamer->AddBlankLine();
     EmitInlineAsm(M.getModuleInlineAsm()+"\n", *STI, TM.Options.MCOptions);
@@ -232,7 +232,7 @@ bool AsmPrinter::doInitialization(Module &M) {
 
   if (MAI->doesSupportDebugInformation()) {
     bool skip_dwarf = false;
-    if (Triple(TM.getTargetTriple()).isKnownWindowsMSVCEnvironment()) {
+    if (TM.getTargetTriple().isKnownWindowsMSVCEnvironment()) {
       Handlers.push_back(HandlerInfo(new WinCodeViewLineTables(this),
                                      DbgTimerName,
                                      CodeViewLineTablesGroupName));
@@ -512,7 +512,8 @@ void AsmPrinter::EmitGlobalVariable(const GlobalVariable *GV) {
 
   if (MAI->hasDotTypeDotSizeDirective())
     // .size foo, 42
-    OutStreamer->EmitELFSize(GVSym, MCConstantExpr::Create(Size, OutContext));
+    OutStreamer->emitELFSize(cast<MCSymbolELF>(GVSym),
+                             MCConstantExpr::create(Size, OutContext));
 
   OutStreamer->AddBlankLine();
 }
@@ -566,7 +567,7 @@ void AsmPrinter::EmitFunctionHeader() {
       MCSymbol *CurPos = OutContext.createTempSymbol();
       OutStreamer->EmitLabel(CurPos);
       OutStreamer->EmitAssignment(CurrentFnBegin,
-                                 MCSymbolRefExpr::Create(CurPos, OutContext));
+                                 MCSymbolRefExpr::create(CurPos, OutContext));
     } else {
       OutStreamer->EmitLabel(CurrentFnBegin);
     }
@@ -776,7 +777,7 @@ void AsmPrinter::emitFrameAlloc(const MachineInstr &MI) {
 
   // Emit a symbol assignment.
   OutStreamer->EmitAssignment(FrameAllocSym,
-                             MCConstantExpr::Create(FrameOffset, OutContext));
+                             MCConstantExpr::create(FrameOffset, OutContext));
 }
 
 /// EmitFunctionBody - This method emits the body and trailer for a
@@ -899,12 +900,11 @@ void AsmPrinter::EmitFunctionBody() {
   if (MAI->hasDotTypeDotSizeDirective()) {
     // We can get the size as difference between the function label and the
     // temp label.
-    const MCExpr *SizeExp =
-      MCBinaryExpr::CreateSub(MCSymbolRefExpr::Create(CurrentFnEnd, OutContext),
-                              MCSymbolRefExpr::Create(CurrentFnSymForSize,
-                                                      OutContext),
-                              OutContext);
-    OutStreamer->EmitELFSize(CurrentFnSym, SizeExp);
+    const MCExpr *SizeExp = MCBinaryExpr::createSub(
+        MCSymbolRefExpr::create(CurrentFnEnd, OutContext),
+        MCSymbolRefExpr::create(CurrentFnSymForSize, OutContext), OutContext);
+    if (auto Sym = dyn_cast<MCSymbolELF>(CurrentFnSym))
+      OutStreamer->emitELFSize(Sym, SizeExp);
   }
 
   for (const HandlerInfo &HI : Handlers) {
@@ -1042,8 +1042,7 @@ bool AsmPrinter::doFinalization(Module &M) {
   if (!ModuleFlags.empty())
     TLOF.emitModuleFlags(*OutStreamer, ModuleFlags, *Mang, TM);
 
-  Triple TT(TM.getTargetTriple());
-  if (TT.isOSBinFormatELF()) {
+  if (TM.getTargetTriple().isOSBinFormatELF()) {
     MachineModuleInfoELF &MMIELF = MMI->getObjFileInfo<MachineModuleInfoELF>();
 
     // Output stubs for external and common global variables.
@@ -1326,9 +1325,9 @@ void AsmPrinter::EmitJumpTableInfo() {
 
         // .set LJTSet, LBB32-base
         const MCExpr *LHS =
-          MCSymbolRefExpr::Create(MBB->getSymbol(), OutContext);
+          MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
         OutStreamer->EmitAssignment(GetJTSetSymbol(JTI, MBB->getNumber()),
-                                    MCBinaryExpr::CreateSub(LHS, Base,
+                                    MCBinaryExpr::createSub(LHS, Base,
                                                             OutContext));
       }
     }
@@ -1368,14 +1367,14 @@ void AsmPrinter::EmitJumpTableEntry(const MachineJumpTableInfo *MJTI,
   case MachineJumpTableInfo::EK_BlockAddress:
     // EK_BlockAddress - Each entry is a plain address of block, e.g.:
     //     .word LBB123
-    Value = MCSymbolRefExpr::Create(MBB->getSymbol(), OutContext);
+    Value = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
     break;
   case MachineJumpTableInfo::EK_GPRel32BlockAddress: {
     // EK_GPRel32BlockAddress - Each entry is an address of block, encoded
     // with a relocation as gp-relative, e.g.:
     //     .gprel32 LBB123
     MCSymbol *MBBSym = MBB->getSymbol();
-    OutStreamer->EmitGPRel32Value(MCSymbolRefExpr::Create(MBBSym, OutContext));
+    OutStreamer->EmitGPRel32Value(MCSymbolRefExpr::create(MBBSym, OutContext));
     return;
   }
 
@@ -1384,7 +1383,7 @@ void AsmPrinter::EmitJumpTableEntry(const MachineJumpTableInfo *MJTI,
     // with a relocation as gp-relative, e.g.:
     //     .gpdword LBB123
     MCSymbol *MBBSym = MBB->getSymbol();
-    OutStreamer->EmitGPRel64Value(MCSymbolRefExpr::Create(MBBSym, OutContext));
+    OutStreamer->EmitGPRel64Value(MCSymbolRefExpr::create(MBBSym, OutContext));
     return;
   }
 
@@ -1397,14 +1396,14 @@ void AsmPrinter::EmitJumpTableEntry(const MachineJumpTableInfo *MJTI,
     //      .set L4_5_set_123, LBB123 - LJTI1_2
     //      .word L4_5_set_123
     if (MAI->doesSetDirectiveSuppressesReloc()) {
-      Value = MCSymbolRefExpr::Create(GetJTSetSymbol(UID, MBB->getNumber()),
+      Value = MCSymbolRefExpr::create(GetJTSetSymbol(UID, MBB->getNumber()),
                                       OutContext);
       break;
     }
-    Value = MCSymbolRefExpr::Create(MBB->getSymbol(), OutContext);
+    Value = MCSymbolRefExpr::create(MBB->getSymbol(), OutContext);
     const TargetLowering *TLI = MF->getSubtarget().getTargetLowering();
     const MCExpr *Base = TLI->getPICJumpTableRelocBaseExpr(MF, UID, OutContext);
-    Value = MCBinaryExpr::CreateSub(Value, Base, OutContext);
+    Value = MCBinaryExpr::createSub(Value, Base, OutContext);
     break;
   }
   }
@@ -1590,25 +1589,7 @@ void AsmPrinter::EmitInt32(int Value) const {
 /// .set if it avoids relocations.
 void AsmPrinter::EmitLabelDifference(const MCSymbol *Hi, const MCSymbol *Lo,
                                      unsigned Size) const {
-  if (!MAI->doesDwarfUseRelocationsAcrossSections())
-    if (OutStreamer->emitAbsoluteSymbolDiff(Hi, Lo, Size))
-      return;
-
-  // Get the Hi-Lo expression.
-  const MCExpr *Diff =
-    MCBinaryExpr::CreateSub(MCSymbolRefExpr::Create(Hi, OutContext),
-                            MCSymbolRefExpr::Create(Lo, OutContext),
-                            OutContext);
-
-  if (!MAI->doesSetDirectiveSuppressesReloc()) {
-    OutStreamer->EmitValue(Diff, Size);
-    return;
-  }
-
-  // Otherwise, emit with .set (aka assignment).
-  MCSymbol *SetLabel = createTempSymbol("set");
-  OutStreamer->EmitAssignment(SetLabel, Diff);
-  OutStreamer->EmitSymbolValue(SetLabel, Size);
+  OutStreamer->emitAbsoluteSymbolDiff(Hi, Lo, Size);
 }
 
 /// EmitLabelPlusOffset - Emit something like ".long Label+Offset"
@@ -1623,10 +1604,10 @@ void AsmPrinter::EmitLabelPlusOffset(const MCSymbol *Label, uint64_t Offset,
   }
 
   // Emit Label+Offset (or just Label if Offset is zero)
-  const MCExpr *Expr = MCSymbolRefExpr::Create(Label, OutContext);
+  const MCExpr *Expr = MCSymbolRefExpr::create(Label, OutContext);
   if (Offset)
-    Expr = MCBinaryExpr::CreateAdd(
-        Expr, MCConstantExpr::Create(Offset, OutContext), OutContext);
+    Expr = MCBinaryExpr::createAdd(
+        Expr, MCConstantExpr::create(Offset, OutContext), OutContext);
 
   OutStreamer->EmitValue(Expr, Size);
 }
@@ -1663,16 +1644,16 @@ const MCExpr *AsmPrinter::lowerConstant(const Constant *CV) {
   MCContext &Ctx = OutContext;
 
   if (CV->isNullValue() || isa<UndefValue>(CV))
-    return MCConstantExpr::Create(0, Ctx);
+    return MCConstantExpr::create(0, Ctx);
 
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
-    return MCConstantExpr::Create(CI->getZExtValue(), Ctx);
+    return MCConstantExpr::create(CI->getZExtValue(), Ctx);
 
   if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV))
-    return MCSymbolRefExpr::Create(getSymbol(GV), Ctx);
+    return MCSymbolRefExpr::create(getSymbol(GV), Ctx);
 
   if (const BlockAddress *BA = dyn_cast<BlockAddress>(CV))
-    return MCSymbolRefExpr::Create(GetBlockAddressSymbol(BA), Ctx);
+    return MCSymbolRefExpr::create(GetBlockAddressSymbol(BA), Ctx);
 
   const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
   if (!CE) {
@@ -1713,7 +1694,7 @@ const MCExpr *AsmPrinter::lowerConstant(const Constant *CV) {
       return Base;
 
     int64_t Offset = OffsetAI.getSExtValue();
-    return MCBinaryExpr::CreateAdd(Base, MCConstantExpr::Create(Offset, Ctx),
+    return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
                                    Ctx);
   }
 
@@ -1756,8 +1737,8 @@ const MCExpr *AsmPrinter::lowerConstant(const Constant *CV) {
     // the high bits so we are sure to get a proper truncation if the input is
     // a constant expr.
     unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
-    const MCExpr *MaskExpr = MCConstantExpr::Create(~0ULL >> (64-InBits), Ctx);
-    return MCBinaryExpr::CreateAnd(OpExpr, MaskExpr, Ctx);
+    const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
+    return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
   }
 
   // The MC library also has a right-shift operator, but it isn't consistently
@@ -1775,15 +1756,15 @@ const MCExpr *AsmPrinter::lowerConstant(const Constant *CV) {
     const MCExpr *RHS = lowerConstant(CE->getOperand(1));
     switch (CE->getOpcode()) {
     default: llvm_unreachable("Unknown binary operator constant cast expr");
-    case Instruction::Add: return MCBinaryExpr::CreateAdd(LHS, RHS, Ctx);
-    case Instruction::Sub: return MCBinaryExpr::CreateSub(LHS, RHS, Ctx);
-    case Instruction::Mul: return MCBinaryExpr::CreateMul(LHS, RHS, Ctx);
-    case Instruction::SDiv: return MCBinaryExpr::CreateDiv(LHS, RHS, Ctx);
-    case Instruction::SRem: return MCBinaryExpr::CreateMod(LHS, RHS, Ctx);
-    case Instruction::Shl: return MCBinaryExpr::CreateShl(LHS, RHS, Ctx);
-    case Instruction::And: return MCBinaryExpr::CreateAnd(LHS, RHS, Ctx);
-    case Instruction::Or:  return MCBinaryExpr::CreateOr (LHS, RHS, Ctx);
-    case Instruction::Xor: return MCBinaryExpr::CreateXor(LHS, RHS, Ctx);
+    case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
+    case Instruction::Sub: return MCBinaryExpr::createSub(LHS, RHS, Ctx);
+    case Instruction::Mul: return MCBinaryExpr::createMul(LHS, RHS, Ctx);
+    case Instruction::SDiv: return MCBinaryExpr::createDiv(LHS, RHS, Ctx);
+    case Instruction::SRem: return MCBinaryExpr::createMod(LHS, RHS, Ctx);
+    case Instruction::Shl: return MCBinaryExpr::createShl(LHS, RHS, Ctx);
+    case Instruction::And: return MCBinaryExpr::createAnd(LHS, RHS, Ctx);
+    case Instruction::Or:  return MCBinaryExpr::createOr (LHS, RHS, Ctx);
+    case Instruction::Xor: return MCBinaryExpr::createXor(LHS, RHS, Ctx);
     }
   }
   }
@@ -1810,40 +1791,30 @@ static int isRepeatedByteSequence(const ConstantDataSequential *V) {
 /// composed of a repeated sequence of identical bytes and return the
 /// byte value.  If it is not a repeated sequence, return -1.
 static int isRepeatedByteSequence(const Value *V, TargetMachine &TM) {
-
   if (const ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
-    if (CI->getBitWidth() > 64) return -1;
-
-    uint64_t Size =
-        TM.getDataLayout()->getTypeAllocSize(V->getType());
-    uint64_t Value = CI->getZExtValue();
+    uint64_t Size = TM.getDataLayout()->getTypeAllocSizeInBits(V->getType());
+    assert(Size % 8 == 0);
 
-    // Make sure the constant is at least 8 bits long and has a power
-    // of 2 bit width.  This guarantees the constant bit width is
-    // always a multiple of 8 bits, avoiding issues with padding out
-    // to Size and other such corner cases.
-    if (CI->getBitWidth() < 8 || !isPowerOf2_64(CI->getBitWidth())) return -1;
+    // Extend the element to take zero padding into account.
+    APInt Value = CI->getValue().zextOrSelf(Size);
+    if (!Value.isSplat(8))
+      return -1;
 
-    uint8_t Byte = static_cast<uint8_t>(Value);
-
-    for (unsigned i = 1; i < Size; ++i) {
-      Value >>= 8;
-      if (static_cast<uint8_t>(Value) != Byte) return -1;
-    }
-    return Byte;
+    return Value.zextOrTrunc(8).getZExtValue();
   }
   if (const ConstantArray *CA = dyn_cast<ConstantArray>(V)) {
     // Make sure all array elements are sequences of the same repeated
     // byte.
     assert(CA->getNumOperands() != 0 && "Should be a CAZ");
-    int Byte = isRepeatedByteSequence(CA->getOperand(0), TM);
-    if (Byte == -1) return -1;
-
-    for (unsigned i = 1, e = CA->getNumOperands(); i != e; ++i) {
-      int ThisByte = isRepeatedByteSequence(CA->getOperand(i), TM);
-      if (ThisByte == -1) return -1;
-      if (Byte != ThisByte) return -1;
-    }
+    Constant *Op0 = CA->getOperand(0);
+    int Byte = isRepeatedByteSequence(Op0, TM);
+    if (Byte == -1)
+      return -1;
+
+    // All array elements must be equal.
+    for (unsigned i = 1, e = CA->getNumOperands(); i != e; ++i)
+      if (CA->getOperand(i) != Op0)
+        return -1;
     return Byte;
   }
 
@@ -2107,16 +2078,20 @@ static void handleIndirectSymViaGOTPCRel(AsmPrinter &AP, const MCExpr **ME,
   //    cstexpr := <gotequiv> - "." + <cst>
   //    cstexpr := <gotequiv> - (<foo> - <offset from @foo base>) + <cst>
   //
-  // After canonicalization by EvaluateAsRelocatable `ME` turns into:
+  // After canonicalization by evaluateAsRelocatable `ME` turns into:
   //
   //  cstexpr := <gotequiv> - <foo> + gotpcrelcst, where
   //    gotpcrelcst := <offset from @foo base> + <cst>
   //
   MCValue MV;
-  if (!(*ME)->EvaluateAsRelocatable(MV, nullptr, nullptr) || MV.isAbsolute())
+  if (!(*ME)->evaluateAsRelocatable(MV, nullptr, nullptr) || MV.isAbsolute())
+    return;
+  const MCSymbolRefExpr *SymA = MV.getSymA();
+  if (!SymA)
     return;
 
-  const MCSymbol *GOTEquivSym = &MV.getSymA()->getSymbol();
+  // Check that GOT equivalent symbol is cached.
+  const MCSymbol *GOTEquivSym = &SymA->getSymbol();
   if (!AP.GlobalGOTEquivs.count(GOTEquivSym))
     return;
 
@@ -2124,8 +2099,11 @@ static void handleIndirectSymViaGOTPCRel(AsmPrinter &AP, const MCExpr **ME,
   if (!BaseGV)
     return;
 
+  // Check for a valid base symbol
   const MCSymbol *BaseSym = AP.getSymbol(BaseGV);
-  if (BaseSym != &MV.getSymB()->getSymbol())
+  const MCSymbolRefExpr *SymB = MV.getSymB();
+
+  if (!SymB || BaseSym != &SymB->getSymbol())
     return;
 
   // Make sure to match:
@@ -2321,11 +2299,10 @@ MCSymbol *AsmPrinter::getSymbolWithGlobalValueBase(const GlobalValue *GV,
                                                            TM);
 }
 
-/// GetExternalSymbolSymbol - Return the MCSymbol for the specified
-/// ExternalSymbol.
+/// Return the MCSymbol for the specified ExternalSymbol.
 MCSymbol *AsmPrinter::GetExternalSymbolSymbol(StringRef Sym) const {
   SmallString<60> NameStr;
-  Mang->getNameWithPrefix(NameStr, Sym);
+  Mangler::getNameWithPrefix(NameStr, Sym, *TM.getDataLayout());
   return OutContext.getOrCreateSymbol(NameStr);
 }