Fix PR23045.
[oota-llvm.git] / lib / IR / Constants.cpp
index 21dbaccd6097c6bddc494c06995ef24f5e270cf7..e51a396a0bf14bce979fcd2d4354c42bdbd0bc8c 100644 (file)
@@ -257,11 +257,11 @@ Constant *Constant::getAggregateElement(unsigned Elt) const {
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     return Elt < CV->getNumOperands() ? CV->getOperand(Elt) : nullptr;
 
-  if (const ConstantAggregateZero *CAZ =dyn_cast<ConstantAggregateZero>(this))
-    return CAZ->getElementValue(Elt);
+  if (const ConstantAggregateZero *CAZ = dyn_cast<ConstantAggregateZero>(this))
+    return Elt < CAZ->getNumElements() ? CAZ->getElementValue(Elt) : nullptr;
 
   if (const UndefValue *UV = dyn_cast<UndefValue>(this))
-    return UV->getElementValue(Elt);
+    return Elt < UV->getNumElements() ? UV->getElementValue(Elt) : nullptr;
 
   if (const ConstantDataSequential *CDS =dyn_cast<ConstantDataSequential>(this))
     return Elt < CDS->getNumElements() ? CDS->getElementAsConstant(Elt)
@@ -316,7 +316,7 @@ static bool canTrapImpl(const Constant *C,
   // ConstantExpr traps if any operands can trap.
   for (unsigned i = 0, e = C->getNumOperands(); i != e; ++i) {
     if (ConstantExpr *Op = dyn_cast<ConstantExpr>(CE->getOperand(i))) {
-      if (NonTrappingOps.insert(Op) && canTrapImpl(Op, NonTrappingOps))
+      if (NonTrappingOps.insert(Op).second && canTrapImpl(Op, NonTrappingOps))
         return true;
     }
   }
@@ -363,7 +363,7 @@ ConstHasGlobalValuePredicate(const Constant *C,
       const Constant *ConstOp = dyn_cast<Constant>(Op);
       if (!ConstOp)
         continue;
-      if (Visited.insert(ConstOp))
+      if (Visited.insert(ConstOp).second)
         WorkList.push_back(ConstOp);
     }
   }
@@ -554,19 +554,17 @@ Constant *ConstantInt::getFalse(Type *Ty) {
                                   ConstantInt::getFalse(Ty->getContext()));
 }
 
-
-// Get a ConstantInt from an APInt. Note that the value stored in the DenseMap 
-// as the key, is a DenseMapAPIntKeyInfo::KeyTy which has provided the
-// operator== and operator!= to ensure that the DenseMap doesn't attempt to
-// compare APInt's of different widths, which would violate an APInt class
-// invariant which generates an assertion.
+// Get a ConstantInt from an APInt.
 ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
-  // Get the corresponding integer type for the bit width of the value.
-  IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
   // get an existing value or the insertion position
   LLVMContextImpl *pImpl = Context.pImpl;
-  ConstantInt *&Slot = pImpl->IntConstants[DenseMapAPIntKeyInfo::KeyTy(V, ITy)];
-  if (!Slot) Slot = new ConstantInt(ITy, V);
+  ConstantInt *&Slot = pImpl->IntConstants[V];
+  if (!Slot) {
+    // Get the corresponding integer type for the bit width of the value.
+    IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+    Slot = new ConstantInt(ITy, V);
+  }
+  assert(Slot->getType() == IntegerType::get(Context, V.getBitWidth()));
   return Slot;
 }
 
@@ -689,7 +687,7 @@ Constant *ConstantFP::getZeroValueForNegation(Type *Ty) {
 ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
   LLVMContextImpl* pImpl = Context.pImpl;
 
-  ConstantFP *&Slot = pImpl->FPConstants[DenseMapAPFloatKeyInfo::KeyTy(V)];
+  ConstantFP *&Slot = pImpl->FPConstants[V];
 
   if (!Slot) {
     Type *Ty;
@@ -766,6 +764,14 @@ Constant *ConstantAggregateZero::getElementValue(unsigned Idx) const {
   return getStructElement(Idx);
 }
 
+unsigned ConstantAggregateZero::getNumElements() const {
+  const Type *Ty = getType();
+  if (const auto *AT = dyn_cast<ArrayType>(Ty))
+    return AT->getNumElements();
+  if (const auto *VT = dyn_cast<VectorType>(Ty))
+    return VT->getNumElements();
+  return Ty->getStructNumElements();
+}
 
 //===----------------------------------------------------------------------===//
 //                         UndefValue Implementation
@@ -799,7 +805,14 @@ UndefValue *UndefValue::getElementValue(unsigned Idx) const {
   return getStructElement(Idx);
 }
 
-
+unsigned UndefValue::getNumElements() const {
+  const Type *Ty = getType();
+  if (const auto *AT = dyn_cast<ArrayType>(Ty))
+    return AT->getNumElements();
+  if (const auto *VT = dyn_cast<VectorType>(Ty))
+    return VT->getNumElements();
+  return Ty->getStructNumElements();
+}
 
 //===----------------------------------------------------------------------===//
 //                            ConstantXXX Classes
@@ -898,23 +911,25 @@ Constant *ConstantArray::getImpl(ArrayType *Ty, ArrayRef<Constant*> V) {
 
     if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       if (CFP->getType()->isFloatTy()) {
-        SmallVector<float, 16> Elts;
+        SmallVector<uint32_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToFloat());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
+          return ConstantDataArray::getFP(C->getContext(), Elts);
       } else if (CFP->getType()->isDoubleTy()) {
-        SmallVector<double, 16> Elts;
+        SmallVector<uint64_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToDouble());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
+          return ConstantDataArray::getFP(C->getContext(), Elts);
       }
     }
   }
@@ -1084,23 +1099,25 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
 
     if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       if (CFP->getType()->isFloatTy()) {
-        SmallVector<float, 16> Elts;
+        SmallVector<uint32_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToFloat());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
+          return ConstantDataVector::getFP(C->getContext(), Elts);
       } else if (CFP->getType()->isDoubleTy()) {
-        SmallVector<double, 16> Elts;
+        SmallVector<uint64_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToDouble());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
+          return ConstantDataVector::getFP(C->getContext(), Elts);
       }
     }
   }
@@ -1198,11 +1215,9 @@ ConstantExpr::getWithOperandReplaced(unsigned OpNo, Constant *Op) const {
 Constant *ConstantExpr::getWithOperands(ArrayRef<Constant *> Ops, Type *Ty,
                                         bool OnlyIfReduced) const {
   assert(Ops.size() == getNumOperands() && "Operand count mismatch!");
-  bool AnyChange = Ty != getType();
-  for (unsigned i = 0; i != Ops.size(); ++i)
-    AnyChange |= Ops[i] != getOperand(i);
 
-  if (!AnyChange)  // No operands changed, return self.
+  // If no operands changed return self.
+  if (Ty == getType() && std::equal(Ops.begin(), Ops.end(), op_begin()))
     return const_cast<ConstantExpr*>(this);
 
   Type *OnlyIfReducedTy = OnlyIfReduced ? Ty : nullptr;
@@ -1919,7 +1934,7 @@ Constant *ConstantExpr::getAlignOf(Type* Ty) {
   // alignof is implemented as: (i64) gep ({i1,Ty}*)null, 0, 1
   // Note that a non-inbounds gep is used, as null isn't within any object.
   Type *AligningTy = 
-    StructType::get(Type::getInt1Ty(Ty->getContext()), Ty, NULL);
+    StructType::get(Type::getInt1Ty(Ty->getContext()), Ty, nullptr);
   Constant *NullPtr = Constant::getNullValue(AligningTy->getPointerTo(0));
   Constant *Zero = ConstantInt::get(Type::getInt64Ty(Ty->getContext()), 0);
   Constant *One = ConstantInt::get(Type::getInt32Ty(Ty->getContext()), 1);
@@ -2436,14 +2451,16 @@ Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) {
     return ConstantAggregateZero::get(Ty);
 
   // Do a lookup to see if we have already formed one of these.
-  StringMap<ConstantDataSequential*>::MapEntryTy &Slot =
-    Ty->getContext().pImpl->CDSConstants.GetOrCreateValue(Elements);
+  auto &Slot =
+      *Ty->getContext()
+           .pImpl->CDSConstants.insert(std::make_pair(Elements, nullptr))
+           .first;
 
   // The bucket can point to a linked list of different CDS's that have the same
   // body but different types.  For example, 0,0,0,1 could be a 4 element array
   // of i8, or a 1-element array of i32.  They'll both end up in the same
   /// StringMap bucket, linked up by their Next pointers.  Walk the list.
-  ConstantDataSequential **Entry = &Slot.getValue();
+  ConstantDataSequential **Entry = &Slot.second;
   for (ConstantDataSequential *Node = *Entry; Node;
        Entry = &Node->Next, Node = *Entry)
     if (Node->getType() == Ty)
@@ -2452,10 +2469,10 @@ Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) {
   // Okay, we didn't get a hit.  Create a node of the right class, link it in,
   // and return it.
   if (isa<ArrayType>(Ty))
-    return *Entry = new ConstantDataArray(Ty, Slot.getKeyData());
+    return *Entry = new ConstantDataArray(Ty, Slot.first().data());
 
   assert(isa<VectorType>(Ty));
-  return *Entry = new ConstantDataVector(Ty, Slot.getKeyData());
+  return *Entry = new ConstantDataVector(Ty, Slot.first().data());
 }
 
 void ConstantDataSequential::destroyConstant() {
@@ -2529,7 +2546,31 @@ Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<float> Elts) {
 Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
-  return getImpl(StringRef(const_cast<char *>(Data), Elts.size()*8), Ty);
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
+}
+
+/// getFP() constructors - Return a constant with array type with an element
+/// count and element type of float with precision matching the number of
+/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
+/// double for 64bits) Note that this can return a ConstantAggregateZero
+/// object.
+Constant *ConstantDataArray::getFP(LLVMContext &Context,
+                                   ArrayRef<uint16_t> Elts) {
+  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 2), Ty);
+}
+Constant *ConstantDataArray::getFP(LLVMContext &Context,
+                                   ArrayRef<uint32_t> Elts) {
+  Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 4), Ty);
+}
+Constant *ConstantDataArray::getFP(LLVMContext &Context,
+                                   ArrayRef<uint64_t> Elts) {
+  Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
 }
 
 /// getString - This method constructs a CDS and initializes it with a text
@@ -2582,7 +2623,31 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<float> Elts) {
 Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
-  return getImpl(StringRef(const_cast<char *>(Data), Elts.size()*8), Ty);
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
+}
+
+/// getFP() constructors - Return a constant with vector type with an element
+/// count and element type of float with the precision matching the number of
+/// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
+/// double for 64bits) Note that this can return a ConstantAggregateZero
+/// object.
+Constant *ConstantDataVector::getFP(LLVMContext &Context,
+                                    ArrayRef<uint16_t> Elts) {
+  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 2), Ty);
+}
+Constant *ConstantDataVector::getFP(LLVMContext &Context,
+                                    ArrayRef<uint32_t> Elts) {
+  Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 4), Ty);
+}
+Constant *ConstantDataVector::getFP(LLVMContext &Context,
+                                    ArrayRef<uint64_t> Elts) {
+  Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
 }
 
 Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
@@ -2608,13 +2673,14 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
 
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
     if (CFP->getType()->isFloatTy()) {
-      SmallVector<float, 16> Elts(NumElts, CFP->getValueAPF().convertToFloat());
-      return get(V->getContext(), Elts);
+      SmallVector<uint32_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getContext(), Elts);
     }
     if (CFP->getType()->isDoubleTy()) {
-      SmallVector<double, 16> Elts(NumElts,
-                                   CFP->getValueAPF().convertToDouble());
-      return get(V->getContext(), Elts);
+      SmallVector<uint64_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getContext(), Elts);
     }
   }
   return ConstantVector::getSplat(NumElts, V);
@@ -2652,13 +2718,13 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   default:
     llvm_unreachable("Accessor can only be used when element is float/double!");
   case Type::FloatTyID: {
-      const float *FloatPrt = reinterpret_cast<const float *>(EltPtr);
-      return APFloat(*const_cast<float *>(FloatPrt));
-    }
+    auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
+    return APFloat(APFloat::IEEEsingle, APInt(32, EltVal));
+  }
   case Type::DoubleTyID: {
-      const double *DoublePtr = reinterpret_cast<const double *>(EltPtr);
-      return APFloat(*const_cast<double *>(DoublePtr));
-    }
+    auto EltVal = *reinterpret_cast<const uint64_t *>(EltPtr);
+    return APFloat(APFloat::IEEEdouble, APInt(64, EltVal));
+  }
   }
 }
 
@@ -2712,7 +2778,7 @@ bool ConstantDataSequential::isCString() const {
 }
 
 /// getSplatValue - If this is a splat constant, meaning that all of the
-/// elements have the same value, return that value. Otherwise return NULL.
+/// elements have the same value, return that value. Otherwise return nullptr.
 Constant *ConstantDataVector::getSplatValue() const {
   const char *Base = getRawDataValues().data();
 
@@ -2903,10 +2969,7 @@ void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
 }
 
 Instruction *ConstantExpr::getAsInstruction() {
-  SmallVector<Value*,4> ValueOperands;
-  for (op_iterator I = op_begin(), E = op_end(); I != E; ++I)
-    ValueOperands.push_back(cast<Value>(I));
-
+  SmallVector<Value *, 4> ValueOperands(op_begin(), op_end());
   ArrayRef<Value*> Ops(ValueOperands);
 
   switch (getOpcode()) {
@@ -2938,12 +3001,14 @@ Instruction *ConstantExpr::getAsInstruction() {
   case Instruction::ShuffleVector:
     return new ShuffleVectorInst(Ops[0], Ops[1], Ops[2]);
 
-  case Instruction::GetElementPtr:
-    if (cast<GEPOperator>(this)->isInBounds())
-      return GetElementPtrInst::CreateInBounds(Ops[0], Ops.slice(1));
-    else
-      return GetElementPtrInst::Create(Ops[0], Ops.slice(1));
-
+  case Instruction::GetElementPtr: {
+    const auto *GO = cast<GEPOperator>(this);
+    if (GO->isInBounds())
+      return GetElementPtrInst::CreateInBounds(GO->getSourceElementType(),
+                                               Ops[0], Ops.slice(1));
+    return GetElementPtrInst::Create(GO->getSourceElementType(), Ops[0],
+                                     Ops.slice(1));
+  }
   case Instruction::ICmp:
   case Instruction::FCmp:
     return CmpInst::Create((Instruction::OtherOps)getOpcode(),