Fix PR23045.
[oota-llvm.git] / lib / IR / Constants.cpp
index e0cb835c2c675eb5940501102946dad066dd4627..e51a396a0bf14bce979fcd2d4354c42bdbd0bc8c 100644 (file)
@@ -257,11 +257,11 @@ Constant *Constant::getAggregateElement(unsigned Elt) const {
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     return Elt < CV->getNumOperands() ? CV->getOperand(Elt) : nullptr;
 
-  if (const ConstantAggregateZero *CAZ =dyn_cast<ConstantAggregateZero>(this))
-    return CAZ->getElementValue(Elt);
+  if (const ConstantAggregateZero *CAZ = dyn_cast<ConstantAggregateZero>(this))
+    return Elt < CAZ->getNumElements() ? CAZ->getElementValue(Elt) : nullptr;
 
   if (const UndefValue *UV = dyn_cast<UndefValue>(this))
-    return UV->getElementValue(Elt);
+    return Elt < UV->getNumElements() ? UV->getElementValue(Elt) : nullptr;
 
   if (const ConstantDataSequential *CDS =dyn_cast<ConstantDataSequential>(this))
     return Elt < CDS->getNumElements() ? CDS->getElementAsConstant(Elt)
@@ -554,19 +554,17 @@ Constant *ConstantInt::getFalse(Type *Ty) {
                                   ConstantInt::getFalse(Ty->getContext()));
 }
 
-
-// Get a ConstantInt from an APInt. Note that the value stored in the DenseMap 
-// as the key, is a DenseMapAPIntKeyInfo::KeyTy which has provided the
-// operator== and operator!= to ensure that the DenseMap doesn't attempt to
-// compare APInt's of different widths, which would violate an APInt class
-// invariant which generates an assertion.
+// Get a ConstantInt from an APInt.
 ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
-  // Get the corresponding integer type for the bit width of the value.
-  IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
   // get an existing value or the insertion position
   LLVMContextImpl *pImpl = Context.pImpl;
-  ConstantInt *&Slot = pImpl->IntConstants[DenseMapAPIntKeyInfo::KeyTy(V, ITy)];
-  if (!Slot) Slot = new ConstantInt(ITy, V);
+  ConstantInt *&Slot = pImpl->IntConstants[V];
+  if (!Slot) {
+    // Get the corresponding integer type for the bit width of the value.
+    IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
+    Slot = new ConstantInt(ITy, V);
+  }
+  assert(Slot->getType() == IntegerType::get(Context, V.getBitWidth()));
   return Slot;
 }
 
@@ -689,7 +687,7 @@ Constant *ConstantFP::getZeroValueForNegation(Type *Ty) {
 ConstantFP* ConstantFP::get(LLVMContext &Context, const APFloat& V) {
   LLVMContextImpl* pImpl = Context.pImpl;
 
-  ConstantFP *&Slot = pImpl->FPConstants[DenseMapAPFloatKeyInfo::KeyTy(V)];
+  ConstantFP *&Slot = pImpl->FPConstants[V];
 
   if (!Slot) {
     Type *Ty;
@@ -766,6 +764,14 @@ Constant *ConstantAggregateZero::getElementValue(unsigned Idx) const {
   return getStructElement(Idx);
 }
 
+unsigned ConstantAggregateZero::getNumElements() const {
+  const Type *Ty = getType();
+  if (const auto *AT = dyn_cast<ArrayType>(Ty))
+    return AT->getNumElements();
+  if (const auto *VT = dyn_cast<VectorType>(Ty))
+    return VT->getNumElements();
+  return Ty->getStructNumElements();
+}
 
 //===----------------------------------------------------------------------===//
 //                         UndefValue Implementation
@@ -799,7 +805,14 @@ UndefValue *UndefValue::getElementValue(unsigned Idx) const {
   return getStructElement(Idx);
 }
 
-
+unsigned UndefValue::getNumElements() const {
+  const Type *Ty = getType();
+  if (const auto *AT = dyn_cast<ArrayType>(Ty))
+    return AT->getNumElements();
+  if (const auto *VT = dyn_cast<VectorType>(Ty))
+    return VT->getNumElements();
+  return Ty->getStructNumElements();
+}
 
 //===----------------------------------------------------------------------===//
 //                            ConstantXXX Classes
@@ -898,23 +911,25 @@ Constant *ConstantArray::getImpl(ArrayType *Ty, ArrayRef<Constant*> V) {
 
     if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       if (CFP->getType()->isFloatTy()) {
-        SmallVector<float, 16> Elts;
+        SmallVector<uint32_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToFloat());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
+          return ConstantDataArray::getFP(C->getContext(), Elts);
       } else if (CFP->getType()->isDoubleTy()) {
-        SmallVector<double, 16> Elts;
+        SmallVector<uint64_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToDouble());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
+          return ConstantDataArray::getFP(C->getContext(), Elts);
       }
     }
   }
@@ -1084,23 +1099,25 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
 
     if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       if (CFP->getType()->isFloatTy()) {
-        SmallVector<float, 16> Elts;
+        SmallVector<uint32_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToFloat());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
+          return ConstantDataVector::getFP(C->getContext(), Elts);
       } else if (CFP->getType()->isDoubleTy()) {
-        SmallVector<double, 16> Elts;
+        SmallVector<uint64_t, 16> Elts;
         for (unsigned i = 0, e = V.size(); i != e; ++i)
           if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(CFP->getValueAPF().convertToDouble());
+            Elts.push_back(
+                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
           else
             break;
         if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
+          return ConstantDataVector::getFP(C->getContext(), Elts);
       }
     }
   }
@@ -1198,11 +1215,9 @@ ConstantExpr::getWithOperandReplaced(unsigned OpNo, Constant *Op) const {
 Constant *ConstantExpr::getWithOperands(ArrayRef<Constant *> Ops, Type *Ty,
                                         bool OnlyIfReduced) const {
   assert(Ops.size() == getNumOperands() && "Operand count mismatch!");
-  bool AnyChange = Ty != getType();
-  for (unsigned i = 0; i != Ops.size(); ++i)
-    AnyChange |= Ops[i] != getOperand(i);
 
-  if (!AnyChange)  // No operands changed, return self.
+  // If no operands changed return self.
+  if (Ty == getType() && std::equal(Ops.begin(), Ops.end(), op_begin()))
     return const_cast<ConstantExpr*>(this);
 
   Type *OnlyIfReducedTy = OnlyIfReduced ? Ty : nullptr;
@@ -2531,7 +2546,31 @@ Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<float> Elts) {
 Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
-  return getImpl(StringRef(const_cast<char *>(Data), Elts.size()*8), Ty);
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
+}
+
+/// getFP() constructors - Return a constant with array type with an element
+/// count and element type of float with precision matching the number of
+/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
+/// double for 64bits) Note that this can return a ConstantAggregateZero
+/// object.
+Constant *ConstantDataArray::getFP(LLVMContext &Context,
+                                   ArrayRef<uint16_t> Elts) {
+  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 2), Ty);
+}
+Constant *ConstantDataArray::getFP(LLVMContext &Context,
+                                   ArrayRef<uint32_t> Elts) {
+  Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 4), Ty);
+}
+Constant *ConstantDataArray::getFP(LLVMContext &Context,
+                                   ArrayRef<uint64_t> Elts) {
+  Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
 }
 
 /// getString - This method constructs a CDS and initializes it with a text
@@ -2584,7 +2623,31 @@ Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<float> Elts) {
 Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
-  return getImpl(StringRef(const_cast<char *>(Data), Elts.size()*8), Ty);
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
+}
+
+/// getFP() constructors - Return a constant with vector type with an element
+/// count and element type of float with the precision matching the number of
+/// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
+/// double for 64bits) Note that this can return a ConstantAggregateZero
+/// object.
+Constant *ConstantDataVector::getFP(LLVMContext &Context,
+                                    ArrayRef<uint16_t> Elts) {
+  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 2), Ty);
+}
+Constant *ConstantDataVector::getFP(LLVMContext &Context,
+                                    ArrayRef<uint32_t> Elts) {
+  Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 4), Ty);
+}
+Constant *ConstantDataVector::getFP(LLVMContext &Context,
+                                    ArrayRef<uint64_t> Elts) {
+  Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+  const char *Data = reinterpret_cast<const char *>(Elts.data());
+  return getImpl(StringRef(const_cast<char *>(Data), Elts.size() * 8), Ty);
 }
 
 Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
@@ -2610,13 +2673,14 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
 
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
     if (CFP->getType()->isFloatTy()) {
-      SmallVector<float, 16> Elts(NumElts, CFP->getValueAPF().convertToFloat());
-      return get(V->getContext(), Elts);
+      SmallVector<uint32_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getContext(), Elts);
     }
     if (CFP->getType()->isDoubleTy()) {
-      SmallVector<double, 16> Elts(NumElts,
-                                   CFP->getValueAPF().convertToDouble());
-      return get(V->getContext(), Elts);
+      SmallVector<uint64_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getContext(), Elts);
     }
   }
   return ConstantVector::getSplat(NumElts, V);
@@ -2654,13 +2718,13 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   default:
     llvm_unreachable("Accessor can only be used when element is float/double!");
   case Type::FloatTyID: {
-      const float *FloatPrt = reinterpret_cast<const float *>(EltPtr);
-      return APFloat(*const_cast<float *>(FloatPrt));
-    }
+    auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
+    return APFloat(APFloat::IEEEsingle, APInt(32, EltVal));
+  }
   case Type::DoubleTyID: {
-      const double *DoublePtr = reinterpret_cast<const double *>(EltPtr);
-      return APFloat(*const_cast<double *>(DoublePtr));
-    }
+    auto EltVal = *reinterpret_cast<const uint64_t *>(EltPtr);
+    return APFloat(APFloat::IEEEdouble, APInt(64, EltVal));
+  }
   }
 }
 
@@ -2905,10 +2969,7 @@ void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
 }
 
 Instruction *ConstantExpr::getAsInstruction() {
-  SmallVector<Value*,4> ValueOperands;
-  for (op_iterator I = op_begin(), E = op_end(); I != E; ++I)
-    ValueOperands.push_back(cast<Value>(I));
-
+  SmallVector<Value *, 4> ValueOperands(op_begin(), op_end());
   ArrayRef<Value*> Ops(ValueOperands);
 
   switch (getOpcode()) {
@@ -2940,12 +3001,14 @@ Instruction *ConstantExpr::getAsInstruction() {
   case Instruction::ShuffleVector:
     return new ShuffleVectorInst(Ops[0], Ops[1], Ops[2]);
 
-  case Instruction::GetElementPtr:
-    if (cast<GEPOperator>(this)->isInBounds())
-      return GetElementPtrInst::CreateInBounds(Ops[0], Ops.slice(1));
-    else
-      return GetElementPtrInst::Create(Ops[0], Ops.slice(1));
-
+  case Instruction::GetElementPtr: {
+    const auto *GO = cast<GEPOperator>(this);
+    if (GO->isInBounds())
+      return GetElementPtrInst::CreateInBounds(GO->getSourceElementType(),
+                                               Ops[0], Ops.slice(1));
+    return GetElementPtrInst::Create(GO->getSourceElementType(), Ops[0],
+                                     Ops.slice(1));
+  }
   case Instruction::ICmp:
   case Instruction::FCmp:
     return CmpInst::Create((Instruction::OtherOps)getOpcode(),