InstCombine: Fold ((A | B) & C1) ^ (B & C2) -> (A & C1) ^ B if C1^C2=-1
[oota-llvm.git] / lib / IR / Constants.cpp
index 45a71dc623011049f2bdcd21802d28d06cfe737f..4cc1e96605fc080361a44b8fa0e15ca7ad34278c 100644 (file)
@@ -803,6 +803,11 @@ ConstantArray::ConstantArray(ArrayType *T, ArrayRef<Constant *> V)
 }
 
 Constant *ConstantArray::get(ArrayType *Ty, ArrayRef<Constant*> V) {
+  if (Constant *C = getImpl(Ty, V))
+    return C;
+  return Ty->getContext().pImpl->ArrayConstants.getOrCreate(Ty, V);
+}
+Constant *ConstantArray::getImpl(ArrayType *Ty, ArrayRef<Constant*> V) {
   // Empty arrays are canonicalized to ConstantAggregateZero.
   if (V.empty())
     return ConstantAggregateZero::get(Ty);
@@ -811,7 +816,6 @@ Constant *ConstantArray::get(ArrayType *Ty, ArrayRef<Constant*> V) {
     assert(V[i]->getType() == Ty->getElementType() &&
            "Wrong type in array element initializer");
   }
-  LLVMContextImpl *pImpl = Ty->getContext().pImpl;
 
   // If this is an all-zero array, return a ConstantAggregateZero object.  If
   // all undef, return an UndefValue, if "all simple", then return a
@@ -893,7 +897,7 @@ Constant *ConstantArray::get(ArrayType *Ty, ArrayRef<Constant*> V) {
   }
 
   // Otherwise, we really do want to create a ConstantArray.
-  return pImpl->ArrayConstants.getOrCreate(Ty, V);
+  return nullptr;
 }
 
 /// getTypeForElements - Return an anonymous struct type to use for a constant
@@ -981,9 +985,14 @@ ConstantVector::ConstantVector(VectorType *T, ArrayRef<Constant *> V)
 
 // ConstantVector accessors.
 Constant *ConstantVector::get(ArrayRef<Constant*> V) {
+  if (Constant *C = getImpl(V))
+    return C;
+  VectorType *Ty = VectorType::get(V.front()->getType(), V.size());
+  return Ty->getContext().pImpl->VectorConstants.getOrCreate(Ty, V);
+}
+Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
   assert(!V.empty() && "Vectors can't be empty");
   VectorType *T = VectorType::get(V.front()->getType(), V.size());
-  LLVMContextImpl *pImpl = T->getContext().pImpl;
 
   // If this is an all-undef or all-zero vector, return a
   // ConstantAggregateZero or UndefValue.
@@ -1075,7 +1084,7 @@ Constant *ConstantVector::get(ArrayRef<Constant*> V) {
 
   // Otherwise, the element type isn't compatible with ConstantDataVector, or
   // the operand list constants a ConstantExpr or something else strange.
-  return pImpl->VectorConstants.getOrCreate(T, V);
+  return nullptr;
 }
 
 Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) {
@@ -1163,8 +1172,8 @@ ConstantExpr::getWithOperandReplaced(unsigned OpNo, Constant *Op) const {
 /// getWithOperands - This returns the current constant expression with the
 /// operands replaced with the specified values.  The specified array must
 /// have the same number of operands as our current one.
-Constant *ConstantExpr::
-getWithOperands(ArrayRef<Constant*> Ops, Type *Ty) const {
+Constant *ConstantExpr::getWithOperands(ArrayRef<Constant *> Ops, Type *Ty,
+                                        bool OnlyIfReduced) const {
   assert(Ops.size() == getNumOperands() && "Operand count mismatch!");
   bool AnyChange = Ty != getType();
   for (unsigned i = 0; i != Ops.size(); ++i)
@@ -1173,6 +1182,7 @@ getWithOperands(ArrayRef<Constant*> Ops, Type *Ty) const {
   if (!AnyChange)  // No operands changed, return self.
     return const_cast<ConstantExpr*>(this);
 
+  Type *OnlyIfReducedTy = OnlyIfReduced ? Ty : nullptr;
   switch (getOpcode()) {
   case Instruction::Trunc:
   case Instruction::ZExt:
@@ -1187,28 +1197,34 @@ getWithOperands(ArrayRef<Constant*> Ops, Type *Ty) const {
   case Instruction::IntToPtr:
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
-    return ConstantExpr::getCast(getOpcode(), Ops[0], Ty);
+    return ConstantExpr::getCast(getOpcode(), Ops[0], Ty, OnlyIfReduced);
   case Instruction::Select:
-    return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
+    return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2], OnlyIfReducedTy);
   case Instruction::InsertElement:
-    return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]);
+    return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2],
+                                          OnlyIfReducedTy);
   case Instruction::ExtractElement:
-    return ConstantExpr::getExtractElement(Ops[0], Ops[1]);
+    return ConstantExpr::getExtractElement(Ops[0], Ops[1], OnlyIfReducedTy);
   case Instruction::InsertValue:
-    return ConstantExpr::getInsertValue(Ops[0], Ops[1], getIndices());
+    return ConstantExpr::getInsertValue(Ops[0], Ops[1], getIndices(),
+                                        OnlyIfReducedTy);
   case Instruction::ExtractValue:
-    return ConstantExpr::getExtractValue(Ops[0], getIndices());
+    return ConstantExpr::getExtractValue(Ops[0], getIndices(), OnlyIfReducedTy);
   case Instruction::ShuffleVector:
-    return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
+    return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2],
+                                          OnlyIfReducedTy);
   case Instruction::GetElementPtr:
     return ConstantExpr::getGetElementPtr(Ops[0], Ops.slice(1),
-                                      cast<GEPOperator>(this)->isInBounds());
+                                          cast<GEPOperator>(this)->isInBounds(),
+                                          OnlyIfReducedTy);
   case Instruction::ICmp:
   case Instruction::FCmp:
-    return ConstantExpr::getCompare(getPredicate(), Ops[0], Ops[1]);
+    return ConstantExpr::getCompare(getPredicate(), Ops[0], Ops[1],
+                                    OnlyIfReducedTy);
   default:
     assert(getNumOperands() == 2 && "Must be binary operator?");
-    return ConstantExpr::get(getOpcode(), Ops[0], Ops[1], SubclassOptionalData);
+    return ConstantExpr::get(getOpcode(), Ops[0], Ops[1], SubclassOptionalData,
+                             OnlyIfReducedTy);
   }
 }
 
@@ -1469,27 +1485,21 @@ void BlockAddress::replaceUsesOfWithOnConstant(Value *From, Value *To, Use *U) {
   // and return early.
   BlockAddress *&NewBA =
     getContext().pImpl->BlockAddresses[std::make_pair(NewF, NewBB)];
-  if (!NewBA) {
-    getBasicBlock()->AdjustBlockAddressRefCount(-1);
-
-    // Remove the old entry, this can't cause the map to rehash (just a
-    // tombstone will get added).
-    getContext().pImpl->BlockAddresses.erase(std::make_pair(getFunction(),
-                                                            getBasicBlock()));
-    NewBA = this;
-    setOperand(0, NewF);
-    setOperand(1, NewBB);
-    getBasicBlock()->AdjustBlockAddressRefCount(1);
+  if (NewBA) {
+    replaceUsesOfWithOnConstantImpl(NewBA);
     return;
   }
 
-  // Otherwise, I do need to replace this with an existing value.
-  assert(NewBA != this && "I didn't contain From!");
-
-  // Everyone using this now uses the replacement.
-  replaceAllUsesWith(NewBA);
+  getBasicBlock()->AdjustBlockAddressRefCount(-1);
 
-  destroyConstant();
+  // Remove the old entry, this can't cause the map to rehash (just a
+  // tombstone will get added).
+  getContext().pImpl->BlockAddresses.erase(std::make_pair(getFunction(),
+                                                          getBasicBlock()));
+  NewBA = this;
+  setOperand(0, NewF);
+  setOperand(1, NewBB);
+  getBasicBlock()->AdjustBlockAddressRefCount(1);
 }
 
 //---- ConstantExpr::get() implementations.
@@ -1497,22 +1507,26 @@ void BlockAddress::replaceUsesOfWithOnConstant(Value *From, Value *To, Use *U) {
 
 /// This is a utility function to handle folding of casts and lookup of the
 /// cast in the ExprConstants map. It is used by the various get* methods below.
-static inline Constant *getFoldedCast(
-  Instruction::CastOps opc, Constant *C, Type *Ty) {
+static Constant *getFoldedCast(Instruction::CastOps opc, Constant *C, Type *Ty,
+                               bool OnlyIfReduced = false) {
   assert(Ty->isFirstClassType() && "Cannot cast to an aggregate type!");
   // Fold a few common cases
   if (Constant *FC = ConstantFoldCastInstruction(opc, C, Ty))
     return FC;
 
+  if (OnlyIfReduced)
+    return nullptr;
+
   LLVMContextImpl *pImpl = Ty->getContext().pImpl;
 
   // Look up the constant in the table first to ensure uniqueness.
-  ExprMapKeyType Key(opc, C);
+  ConstantExprKeyType Key(opc, C);
 
   return pImpl->ExprConstants.getOrCreate(Ty, Key);
 }
 
-Constant *ConstantExpr::getCast(unsigned oc, Constant *C, Type *Ty) {
+Constant *ConstantExpr::getCast(unsigned oc, Constant *C, Type *Ty,
+                                bool OnlyIfReduced) {
   Instruction::CastOps opc = Instruction::CastOps(oc);
   assert(Instruction::isCast(opc) && "opcode out of range");
   assert(C && Ty && "Null arguments to getCast");
@@ -1521,19 +1535,32 @@ Constant *ConstantExpr::getCast(unsigned oc, Constant *C, Type *Ty) {
   switch (opc) {
   default:
     llvm_unreachable("Invalid cast opcode");
-  case Instruction::Trunc:    return getTrunc(C, Ty);
-  case Instruction::ZExt:     return getZExt(C, Ty);
-  case Instruction::SExt:     return getSExt(C, Ty);
-  case Instruction::FPTrunc:  return getFPTrunc(C, Ty);
-  case Instruction::FPExt:    return getFPExtend(C, Ty);
-  case Instruction::UIToFP:   return getUIToFP(C, Ty);
-  case Instruction::SIToFP:   return getSIToFP(C, Ty);
-  case Instruction::FPToUI:   return getFPToUI(C, Ty);
-  case Instruction::FPToSI:   return getFPToSI(C, Ty);
-  case Instruction::PtrToInt: return getPtrToInt(C, Ty);
-  case Instruction::IntToPtr: return getIntToPtr(C, Ty);
-  case Instruction::BitCast:  return getBitCast(C, Ty);
-  case Instruction::AddrSpaceCast:  return getAddrSpaceCast(C, Ty);
+  case Instruction::Trunc:
+    return getTrunc(C, Ty, OnlyIfReduced);
+  case Instruction::ZExt:
+    return getZExt(C, Ty, OnlyIfReduced);
+  case Instruction::SExt:
+    return getSExt(C, Ty, OnlyIfReduced);
+  case Instruction::FPTrunc:
+    return getFPTrunc(C, Ty, OnlyIfReduced);
+  case Instruction::FPExt:
+    return getFPExtend(C, Ty, OnlyIfReduced);
+  case Instruction::UIToFP:
+    return getUIToFP(C, Ty, OnlyIfReduced);
+  case Instruction::SIToFP:
+    return getSIToFP(C, Ty, OnlyIfReduced);
+  case Instruction::FPToUI:
+    return getFPToUI(C, Ty, OnlyIfReduced);
+  case Instruction::FPToSI:
+    return getFPToSI(C, Ty, OnlyIfReduced);
+  case Instruction::PtrToInt:
+    return getPtrToInt(C, Ty, OnlyIfReduced);
+  case Instruction::IntToPtr:
+    return getIntToPtr(C, Ty, OnlyIfReduced);
+  case Instruction::BitCast:
+    return getBitCast(C, Ty, OnlyIfReduced);
+  case Instruction::AddrSpaceCast:
+    return getAddrSpaceCast(C, Ty, OnlyIfReduced);
   }
 }
 
@@ -1606,7 +1633,7 @@ Constant *ConstantExpr::getFPCast(Constant *C, Type *Ty) {
   return getCast(opcode, C, Ty);
 }
 
-Constant *ConstantExpr::getTrunc(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1617,10 +1644,10 @@ Constant *ConstantExpr::getTrunc(Constant *C, Type *Ty) {
   assert(C->getType()->getScalarSizeInBits() > Ty->getScalarSizeInBits()&&
          "SrcTy must be larger than DestTy for Trunc!");
 
-  return getFoldedCast(Instruction::Trunc, C, Ty);
+  return getFoldedCast(Instruction::Trunc, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getSExt(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getSExt(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1631,10 +1658,10 @@ Constant *ConstantExpr::getSExt(Constant *C, Type *Ty) {
   assert(C->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits()&&
          "SrcTy must be smaller than DestTy for SExt!");
 
-  return getFoldedCast(Instruction::SExt, C, Ty);
+  return getFoldedCast(Instruction::SExt, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getZExt(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getZExt(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1645,10 +1672,10 @@ Constant *ConstantExpr::getZExt(Constant *C, Type *Ty) {
   assert(C->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits()&&
          "SrcTy must be smaller than DestTy for ZExt!");
 
-  return getFoldedCast(Instruction::ZExt, C, Ty);
+  return getFoldedCast(Instruction::ZExt, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getFPTrunc(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getFPTrunc(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1657,10 +1684,10 @@ Constant *ConstantExpr::getFPTrunc(Constant *C, Type *Ty) {
   assert(C->getType()->isFPOrFPVectorTy() && Ty->isFPOrFPVectorTy() &&
          C->getType()->getScalarSizeInBits() > Ty->getScalarSizeInBits()&&
          "This is an illegal floating point truncation!");
-  return getFoldedCast(Instruction::FPTrunc, C, Ty);
+  return getFoldedCast(Instruction::FPTrunc, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getFPExtend(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getFPExtend(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1669,10 +1696,10 @@ Constant *ConstantExpr::getFPExtend(Constant *C, Type *Ty) {
   assert(C->getType()->isFPOrFPVectorTy() && Ty->isFPOrFPVectorTy() &&
          C->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits()&&
          "This is an illegal floating point extension!");
-  return getFoldedCast(Instruction::FPExt, C, Ty);
+  return getFoldedCast(Instruction::FPExt, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getUIToFP(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getUIToFP(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1680,10 +1707,10 @@ Constant *ConstantExpr::getUIToFP(Constant *C, Type *Ty) {
   assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
   assert(C->getType()->isIntOrIntVectorTy() && Ty->isFPOrFPVectorTy() &&
          "This is an illegal uint to floating point cast!");
-  return getFoldedCast(Instruction::UIToFP, C, Ty);
+  return getFoldedCast(Instruction::UIToFP, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getSIToFP(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getSIToFP(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1691,10 +1718,10 @@ Constant *ConstantExpr::getSIToFP(Constant *C, Type *Ty) {
   assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
   assert(C->getType()->isIntOrIntVectorTy() && Ty->isFPOrFPVectorTy() &&
          "This is an illegal sint to floating point cast!");
-  return getFoldedCast(Instruction::SIToFP, C, Ty);
+  return getFoldedCast(Instruction::SIToFP, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getFPToUI(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getFPToUI(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1702,10 +1729,10 @@ Constant *ConstantExpr::getFPToUI(Constant *C, Type *Ty) {
   assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
   assert(C->getType()->isFPOrFPVectorTy() && Ty->isIntOrIntVectorTy() &&
          "This is an illegal floating point to uint cast!");
-  return getFoldedCast(Instruction::FPToUI, C, Ty);
+  return getFoldedCast(Instruction::FPToUI, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getFPToSI(Constant *C, Type *Ty) {
+Constant *ConstantExpr::getFPToSI(Constant *C, Type *Ty, bool OnlyIfReduced) {
 #ifndef NDEBUG
   bool fromVec = C->getType()->getTypeID() == Type::VectorTyID;
   bool toVec = Ty->getTypeID() == Type::VectorTyID;
@@ -1713,10 +1740,11 @@ Constant *ConstantExpr::getFPToSI(Constant *C, Type *Ty) {
   assert((fromVec == toVec) && "Cannot convert from scalar to/from vector");
   assert(C->getType()->isFPOrFPVectorTy() && Ty->isIntOrIntVectorTy() &&
          "This is an illegal floating point to sint cast!");
-  return getFoldedCast(Instruction::FPToSI, C, Ty);
+  return getFoldedCast(Instruction::FPToSI, C, Ty, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy) {
+Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy,
+                                    bool OnlyIfReduced) {
   assert(C->getType()->getScalarType()->isPointerTy() &&
          "PtrToInt source must be pointer or pointer vector");
   assert(DstTy->getScalarType()->isIntegerTy() && 
@@ -1725,10 +1753,11 @@ Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy) {
   if (isa<VectorType>(C->getType()))
     assert(C->getType()->getVectorNumElements()==DstTy->getVectorNumElements()&&
            "Invalid cast between a different number of vector elements");
-  return getFoldedCast(Instruction::PtrToInt, C, DstTy);
+  return getFoldedCast(Instruction::PtrToInt, C, DstTy, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy) {
+Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy,
+                                    bool OnlyIfReduced) {
   assert(C->getType()->getScalarType()->isIntegerTy() &&
          "IntToPtr source must be integer or integer vector");
   assert(DstTy->getScalarType()->isPointerTy() &&
@@ -1737,10 +1766,11 @@ Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy) {
   if (isa<VectorType>(C->getType()))
     assert(C->getType()->getVectorNumElements()==DstTy->getVectorNumElements()&&
            "Invalid cast between a different number of vector elements");
-  return getFoldedCast(Instruction::IntToPtr, C, DstTy);
+  return getFoldedCast(Instruction::IntToPtr, C, DstTy, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getBitCast(Constant *C, Type *DstTy) {
+Constant *ConstantExpr::getBitCast(Constant *C, Type *DstTy,
+                                   bool OnlyIfReduced) {
   assert(CastInst::castIsValid(Instruction::BitCast, C, DstTy) &&
          "Invalid constantexpr bitcast!");
 
@@ -1748,10 +1778,11 @@ Constant *ConstantExpr::getBitCast(Constant *C, Type *DstTy) {
   // speedily.
   if (C->getType() == DstTy) return C;
 
-  return getFoldedCast(Instruction::BitCast, C, DstTy);
+  return getFoldedCast(Instruction::BitCast, C, DstTy, OnlyIfReduced);
 }
 
-Constant *ConstantExpr::getAddrSpaceCast(Constant *C, Type *DstTy) {
+Constant *ConstantExpr::getAddrSpaceCast(Constant *C, Type *DstTy,
+                                         bool OnlyIfReduced) {
   assert(CastInst::castIsValid(Instruction::AddrSpaceCast, C, DstTy) &&
          "Invalid constantexpr addrspacecast!");
 
@@ -1768,11 +1799,11 @@ Constant *ConstantExpr::getAddrSpaceCast(Constant *C, Type *DstTy) {
     }
     C = getBitCast(C, MidTy);
   }
-  return getFoldedCast(Instruction::AddrSpaceCast, C, DstTy);
+  return getFoldedCast(Instruction::AddrSpaceCast, C, DstTy, OnlyIfReduced);
 }
 
 Constant *ConstantExpr::get(unsigned Opcode, Constant *C1, Constant *C2,
-                            unsigned Flags) {
+                            unsigned Flags, Type *OnlyIfReducedTy) {
   // Check the operands for consistency first.
   assert(Opcode >= Instruction::BinaryOpsBegin &&
          Opcode <  Instruction::BinaryOpsEnd   &&
@@ -1841,8 +1872,11 @@ Constant *ConstantExpr::get(unsigned Opcode, Constant *C1, Constant *C2,
   if (Constant *FC = ConstantFoldBinaryInstruction(Opcode, C1, C2))
     return FC;          // Fold a few common cases.
 
+  if (OnlyIfReducedTy == C1->getType())
+    return nullptr;
+
   Constant *ArgVec[] = { C1, C2 };
-  ExprMapKeyType Key(Opcode, ArgVec, 0, Flags);
+  ConstantExprKeyType Key(Opcode, ArgVec, 0, Flags);
 
   LLVMContextImpl *pImpl = C1->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(C1->getType(), Key);
@@ -1890,8 +1924,8 @@ Constant *ConstantExpr::getOffsetOf(Type* Ty, Constant *FieldNo) {
                      Type::getInt64Ty(Ty->getContext()));
 }
 
-Constant *ConstantExpr::getCompare(unsigned short Predicate, 
-                                   Constant *C1, Constant *C2) {
+Constant *ConstantExpr::getCompare(unsigned short Predicate, Constant *C1,
+                                   Constant *C2, bool OnlyIfReduced) {
   assert(C1->getType() == C2->getType() && "Op types should be identical!");
 
   switch (Predicate) {
@@ -1902,31 +1936,35 @@ Constant *ConstantExpr::getCompare(unsigned short Predicate,
   case CmpInst::FCMP_UEQ:   case CmpInst::FCMP_UGT: case CmpInst::FCMP_UGE:
   case CmpInst::FCMP_ULT:   case CmpInst::FCMP_ULE: case CmpInst::FCMP_UNE:
   case CmpInst::FCMP_TRUE:
-    return getFCmp(Predicate, C1, C2);
+    return getFCmp(Predicate, C1, C2, OnlyIfReduced);
 
   case CmpInst::ICMP_EQ:  case CmpInst::ICMP_NE:  case CmpInst::ICMP_UGT:
   case CmpInst::ICMP_UGE: case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE:
   case CmpInst::ICMP_SGT: case CmpInst::ICMP_SGE: case CmpInst::ICMP_SLT:
   case CmpInst::ICMP_SLE:
-    return getICmp(Predicate, C1, C2);
+    return getICmp(Predicate, C1, C2, OnlyIfReduced);
   }
 }
 
-Constant *ConstantExpr::getSelect(Constant *C, Constant *V1, Constant *V2) {
+Constant *ConstantExpr::getSelect(Constant *C, Constant *V1, Constant *V2,
+                                  Type *OnlyIfReducedTy) {
   assert(!SelectInst::areInvalidOperands(C, V1, V2)&&"Invalid select operands");
 
   if (Constant *SC = ConstantFoldSelectInstruction(C, V1, V2))
     return SC;        // Fold common cases
 
+  if (OnlyIfReducedTy == V1->getType())
+    return nullptr;
+
   Constant *ArgVec[] = { C, V1, V2 };
-  ExprMapKeyType Key(Instruction::Select, ArgVec);
+  ConstantExprKeyType Key(Instruction::Select, ArgVec);
 
   LLVMContextImpl *pImpl = C->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(V1->getType(), Key);
 }
 
 Constant *ConstantExpr::getGetElementPtr(Constant *C, ArrayRef<Value *> Idxs,
-                                         bool InBounds) {
+                                         bool InBounds, Type *OnlyIfReducedTy) {
   assert(C->getType()->isPtrOrPtrVectorTy() &&
          "Non-pointer type for constant GetElementPtr expression");
 
@@ -1941,6 +1979,9 @@ Constant *ConstantExpr::getGetElementPtr(Constant *C, ArrayRef<Value *> Idxs,
   if (VectorType *VecTy = dyn_cast<VectorType>(C->getType()))
     ReqTy = VectorType::get(ReqTy, VecTy->getNumElements());
 
+  if (OnlyIfReducedTy == ReqTy)
+    return nullptr;
+
   // Look up the constant in the table first to ensure uniqueness
   std::vector<Constant*> ArgVec;
   ArgVec.reserve(1 + Idxs.size());
@@ -1954,15 +1995,15 @@ Constant *ConstantExpr::getGetElementPtr(Constant *C, ArrayRef<Value *> Idxs,
            "getelementptr index type missmatch");
     ArgVec.push_back(cast<Constant>(Idxs[i]));
   }
-  const ExprMapKeyType Key(Instruction::GetElementPtr, ArgVec, 0,
-                           InBounds ? GEPOperator::IsInBounds : 0);
+  const ConstantExprKeyType Key(Instruction::GetElementPtr, ArgVec, 0,
+                                InBounds ? GEPOperator::IsInBounds : 0);
 
   LLVMContextImpl *pImpl = C->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
 }
 
-Constant *
-ConstantExpr::getICmp(unsigned short pred, Constant *LHS, Constant *RHS) {
+Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS,
+                                Constant *RHS, bool OnlyIfReduced) {
   assert(LHS->getType() == RHS->getType());
   assert(pred >= ICmpInst::FIRST_ICMP_PREDICATE && 
          pred <= ICmpInst::LAST_ICMP_PREDICATE && "Invalid ICmp Predicate");
@@ -1970,10 +2011,13 @@ ConstantExpr::getICmp(unsigned short pred, Constant *LHS, Constant *RHS) {
   if (Constant *FC = ConstantFoldCompareInstruction(pred, LHS, RHS))
     return FC;          // Fold a few common cases...
 
+  if (OnlyIfReduced)
+    return nullptr;
+
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { LHS, RHS };
   // Get the key type with both the opcode and predicate
-  const ExprMapKeyType Key(Instruction::ICmp, ArgVec, pred);
+  const ConstantExprKeyType Key(Instruction::ICmp, ArgVec, pred);
 
   Type *ResultTy = Type::getInt1Ty(LHS->getContext());
   if (VectorType *VT = dyn_cast<VectorType>(LHS->getType()))
@@ -1983,18 +2027,21 @@ ConstantExpr::getICmp(unsigned short pred, Constant *LHS, Constant *RHS) {
   return pImpl->ExprConstants.getOrCreate(ResultTy, Key);
 }
 
-Constant *
-ConstantExpr::getFCmp(unsigned short pred, Constant *LHS, Constant *RHS) {
+Constant *ConstantExpr::getFCmp(unsigned short pred, Constant *LHS,
+                                Constant *RHS, bool OnlyIfReduced) {
   assert(LHS->getType() == RHS->getType());
   assert(pred <= FCmpInst::LAST_FCMP_PREDICATE && "Invalid FCmp Predicate");
 
   if (Constant *FC = ConstantFoldCompareInstruction(pred, LHS, RHS))
     return FC;          // Fold a few common cases...
 
+  if (OnlyIfReduced)
+    return nullptr;
+
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { LHS, RHS };
   // Get the key type with both the opcode and predicate
-  const ExprMapKeyType Key(Instruction::FCmp, ArgVec, pred);
+  const ConstantExprKeyType Key(Instruction::FCmp, ArgVec, pred);
 
   Type *ResultTy = Type::getInt1Ty(LHS->getContext());
   if (VectorType *VT = dyn_cast<VectorType>(LHS->getType()))
@@ -2004,7 +2051,8 @@ ConstantExpr::getFCmp(unsigned short pred, Constant *LHS, Constant *RHS) {
   return pImpl->ExprConstants.getOrCreate(ResultTy, Key);
 }
 
-Constant *ConstantExpr::getExtractElement(Constant *Val, Constant *Idx) {
+Constant *ConstantExpr::getExtractElement(Constant *Val, Constant *Idx,
+                                          Type *OnlyIfReducedTy) {
   assert(Val->getType()->isVectorTy() &&
          "Tried to create extractelement operation on non-vector type!");
   assert(Idx->getType()->isIntegerTy() &&
@@ -2013,17 +2061,20 @@ Constant *ConstantExpr::getExtractElement(Constant *Val, Constant *Idx) {
   if (Constant *FC = ConstantFoldExtractElementInstruction(Val, Idx))
     return FC;          // Fold a few common cases.
 
+  Type *ReqTy = Val->getType()->getVectorElementType();
+  if (OnlyIfReducedTy == ReqTy)
+    return nullptr;
+
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { Val, Idx };
-  const ExprMapKeyType Key(Instruction::ExtractElement, ArgVec);
+  const ConstantExprKeyType Key(Instruction::ExtractElement, ArgVec);
 
   LLVMContextImpl *pImpl = Val->getContext().pImpl;
-  Type *ReqTy = Val->getType()->getVectorElementType();
   return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
 }
 
-Constant *ConstantExpr::getInsertElement(Constant *Val, Constant *Elt, 
-                                         Constant *Idx) {
+Constant *ConstantExpr::getInsertElement(Constant *Val, Constant *Elt,
+                                         Constant *Idx, Type *OnlyIfReducedTy) {
   assert(Val->getType()->isVectorTy() &&
          "Tried to create insertelement operation on non-vector type!");
   assert(Elt->getType() == Val->getType()->getVectorElementType() &&
@@ -2033,16 +2084,20 @@ Constant *ConstantExpr::getInsertElement(Constant *Val, Constant *Elt,
 
   if (Constant *FC = ConstantFoldInsertElementInstruction(Val, Elt, Idx))
     return FC;          // Fold a few common cases.
+
+  if (OnlyIfReducedTy == Val->getType())
+    return nullptr;
+
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { Val, Elt, Idx };
-  const ExprMapKeyType Key(Instruction::InsertElement, ArgVec);
+  const ConstantExprKeyType Key(Instruction::InsertElement, ArgVec);
 
   LLVMContextImpl *pImpl = Val->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(Val->getType(), Key);
 }
 
-Constant *ConstantExpr::getShuffleVector(Constant *V1, Constant *V2, 
-                                         Constant *Mask) {
+Constant *ConstantExpr::getShuffleVector(Constant *V1, Constant *V2,
+                                         Constant *Mask, Type *OnlyIfReducedTy) {
   assert(ShuffleVectorInst::isValidOperands(V1, V2, Mask) &&
          "Invalid shuffle vector constant expr operands!");
 
@@ -2053,16 +2108,20 @@ Constant *ConstantExpr::getShuffleVector(Constant *V1, Constant *V2,
   Type *EltTy = V1->getType()->getVectorElementType();
   Type *ShufTy = VectorType::get(EltTy, NElts);
 
+  if (OnlyIfReducedTy == ShufTy)
+    return nullptr;
+
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { V1, V2, Mask };
-  const ExprMapKeyType Key(Instruction::ShuffleVector, ArgVec);
+  const ConstantExprKeyType Key(Instruction::ShuffleVector, ArgVec);
 
   LLVMContextImpl *pImpl = ShufTy->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(ShufTy, Key);
 }
 
 Constant *ConstantExpr::getInsertValue(Constant *Agg, Constant *Val,
-                                       ArrayRef<unsigned> Idxs) {
+                                       ArrayRef<unsigned> Idxs,
+                                       Type *OnlyIfReducedTy) {
   assert(Agg->getType()->isFirstClassType() &&
          "Non-first-class type for constant insertvalue expression");
 
@@ -2074,15 +2133,18 @@ Constant *ConstantExpr::getInsertValue(Constant *Agg, Constant *Val,
   if (Constant *FC = ConstantFoldInsertValueInstruction(Agg, Val, Idxs))
     return FC;
 
+  if (OnlyIfReducedTy == ReqTy)
+    return nullptr;
+
   Constant *ArgVec[] = { Agg, Val };
-  const ExprMapKeyType Key(Instruction::InsertValue, ArgVec, 0, 0, Idxs);
+  const ConstantExprKeyType Key(Instruction::InsertValue, ArgVec, 0, 0, Idxs);
 
   LLVMContextImpl *pImpl = Agg->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
 }
 
-Constant *ConstantExpr::getExtractValue(Constant *Agg,
-                                        ArrayRef<unsigned> Idxs) {
+Constant *ConstantExpr::getExtractValue(Constant *Agg, ArrayRef<unsigned> Idxs,
+                                        Type *OnlyIfReducedTy) {
   assert(Agg->getType()->isFirstClassType() &&
          "Tried to create extractelement operation on non-first-class type!");
 
@@ -2095,8 +2157,11 @@ Constant *ConstantExpr::getExtractValue(Constant *Agg,
   if (Constant *FC = ConstantFoldExtractValueInstruction(Agg, Idxs))
     return FC;
 
+  if (OnlyIfReducedTy == ReqTy)
+    return nullptr;
+
   Constant *ArgVec[] = { Agg };
-  const ExprMapKeyType Key(Instruction::ExtractValue, ArgVec, 0, 0, Idxs);
+  const ConstantExprKeyType Key(Instruction::ExtractValue, ArgVec, 0, 0, Idxs);
 
   LLVMContextImpl *pImpl = Agg->getContext().pImpl;
   return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
@@ -2652,13 +2717,22 @@ Constant *ConstantDataVector::getSplatValue() const {
 /// work, but would be really slow because it would have to unique each updated
 /// array instance.
 ///
+void Constant::replaceUsesOfWithOnConstantImpl(Constant *Replacement) {
+  // I do need to replace this with an existing value.
+  assert(Replacement != this && "I didn't contain From!");
+
+  // Everyone using this now uses the replacement.
+  replaceAllUsesWith(Replacement);
+
+  // Delete the old constant!
+  destroyConstant();
+}
+
 void ConstantArray::replaceUsesOfWithOnConstant(Value *From, Value *To,
                                                 Use *U) {
   assert(isa<Constant>(To) && "Cannot make Constant refer to non-constant!");
   Constant *ToC = cast<Constant>(To);
 
-  LLVMContextImpl *pImpl = getType()->getContext().pImpl;
-
   SmallVector<Constant*, 8> Values;
   Values.reserve(getNumOperands());  // Build replacement array.
 
@@ -2678,52 +2752,25 @@ void ConstantArray::replaceUsesOfWithOnConstant(Value *From, Value *To,
     AllSame &= Val == ToC;
   }
 
-  Constant *Replacement = nullptr;
   if (AllSame && ToC->isNullValue()) {
-    Replacement = ConstantAggregateZero::get(getType());
-  } else if (AllSame && isa<UndefValue>(ToC)) {
-    Replacement = UndefValue::get(getType());
-  } else {
-    // Check to see if we have this array type already.
-    LLVMContextImpl::ArrayConstantsTy::LookupKey Lookup(
-        cast<ArrayType>(getType()), makeArrayRef(Values));
-    LLVMContextImpl::ArrayConstantsTy::MapTy::iterator I =
-      pImpl->ArrayConstants.find(Lookup);
-
-    if (I != pImpl->ArrayConstants.map_end()) {
-      Replacement = I->first;
-    } else {
-      // Okay, the new shape doesn't exist in the system yet.  Instead of
-      // creating a new constant array, inserting it, replaceallusesof'ing the
-      // old with the new, then deleting the old... just update the current one
-      // in place!
-      pImpl->ArrayConstants.remove(this);
-
-      // Update to the new value.  Optimize for the case when we have a single
-      // operand that we're changing, but handle bulk updates efficiently.
-      if (NumUpdated == 1) {
-        unsigned OperandToUpdate = U - OperandList;
-        assert(getOperand(OperandToUpdate) == From &&
-               "ReplaceAllUsesWith broken!");
-        setOperand(OperandToUpdate, ToC);
-      } else {
-        for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
-          if (getOperand(i) == From)
-            setOperand(i, ToC);
-      }
-      pImpl->ArrayConstants.insert(this);
-      return;
-    }
+    replaceUsesOfWithOnConstantImpl(ConstantAggregateZero::get(getType()));
+    return;
+  }
+  if (AllSame && isa<UndefValue>(ToC)) {
+    replaceUsesOfWithOnConstantImpl(UndefValue::get(getType()));
+    return;
   }
 
-  // Otherwise, I do need to replace this with an existing value.
-  assert(Replacement != this && "I didn't contain From!");
-
-  // Everyone using this now uses the replacement.
-  replaceAllUsesWith(Replacement);
+  // Check for any other type of constant-folding.
+  if (Constant *C = getImpl(getType(), Values)) {
+    replaceUsesOfWithOnConstantImpl(C);
+    return;
+  }
 
-  // Delete the old constant!
-  destroyConstant();
+  // Update to the new value.
+  if (Constant *C = getContext().pImpl->ArrayConstants.replaceOperandsInPlace(
+          Values, this, From, ToC, NumUpdated, U - OperandList))
+    replaceUsesOfWithOnConstantImpl(C);
 }
 
 void ConstantStruct::replaceUsesOfWithOnConstant(Value *From, Value *To,
@@ -2761,65 +2808,47 @@ void ConstantStruct::replaceUsesOfWithOnConstant(Value *From, Value *To,
   }
   Values[OperandToUpdate] = ToC;
 
-  LLVMContextImpl *pImpl = getContext().pImpl;
-
-  Constant *Replacement = nullptr;
   if (isAllZeros) {
-    Replacement = ConstantAggregateZero::get(getType());
-  } else if (isAllUndef) {
-    Replacement = UndefValue::get(getType());
-  } else {
-    // Check to see if we have this struct type already.
-    LLVMContextImpl::StructConstantsTy::LookupKey Lookup(
-        cast<StructType>(getType()), makeArrayRef(Values));
-    LLVMContextImpl::StructConstantsTy::MapTy::iterator I =
-      pImpl->StructConstants.find(Lookup);
-
-    if (I != pImpl->StructConstants.map_end()) {
-      Replacement = I->first;
-    } else {
-      // Okay, the new shape doesn't exist in the system yet.  Instead of
-      // creating a new constant struct, inserting it, replaceallusesof'ing the
-      // old with the new, then deleting the old... just update the current one
-      // in place!
-      pImpl->StructConstants.remove(this);
-
-      // Update to the new value.
-      setOperand(OperandToUpdate, ToC);
-      pImpl->StructConstants.insert(this);
-      return;
-    }
+    replaceUsesOfWithOnConstantImpl(ConstantAggregateZero::get(getType()));
+    return;
+  }
+  if (isAllUndef) {
+    replaceUsesOfWithOnConstantImpl(UndefValue::get(getType()));
+    return;
   }
 
-  assert(Replacement != this && "I didn't contain From!");
-
-  // Everyone using this now uses the replacement.
-  replaceAllUsesWith(Replacement);
-
-  // Delete the old constant!
-  destroyConstant();
+  // Update to the new value.
+  if (Constant *C = getContext().pImpl->StructConstants.replaceOperandsInPlace(
+          Values, this, From, ToC))
+    replaceUsesOfWithOnConstantImpl(C);
 }
 
 void ConstantVector::replaceUsesOfWithOnConstant(Value *From, Value *To,
                                                  Use *U) {
   assert(isa<Constant>(To) && "Cannot make Constant refer to non-constant!");
+  Constant *ToC = cast<Constant>(To);
 
   SmallVector<Constant*, 8> Values;
   Values.reserve(getNumOperands());  // Build replacement array...
+  unsigned NumUpdated = 0;
   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
     Constant *Val = getOperand(i);
-    if (Val == From) Val = cast<Constant>(To);
+    if (Val == From) {
+      ++NumUpdated;
+      Val = ToC;
+    }
     Values.push_back(Val);
   }
 
-  Constant *Replacement = get(Values);
-  assert(Replacement != this && "I didn't contain From!");
-
-  // Everyone using this now uses the replacement.
-  replaceAllUsesWith(Replacement);
+  if (Constant *C = getImpl(Values)) {
+    replaceUsesOfWithOnConstantImpl(C);
+    return;
+  }
 
-  // Delete the old constant!
-  destroyConstant();
+  // Update to the new value.
+  if (Constant *C = getContext().pImpl->VectorConstants.replaceOperandsInPlace(
+          Values, this, From, ToC, NumUpdated, U - OperandList))
+    replaceUsesOfWithOnConstantImpl(C);
 }
 
 void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
@@ -2828,19 +2857,26 @@ void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
   Constant *To = cast<Constant>(ToV);
 
   SmallVector<Constant*, 8> NewOps;
+  unsigned NumUpdated = 0;
   for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
     Constant *Op = getOperand(i);
-    NewOps.push_back(Op == From ? To : Op);
+    if (Op == From) {
+      ++NumUpdated;
+      Op = To;
+    }
+    NewOps.push_back(Op);
   }
+  assert(NumUpdated && "I didn't contain From!");
 
-  Constant *Replacement = getWithOperands(NewOps);
-  assert(Replacement != this && "I didn't contain From!");
-
-  // Everyone using this now uses the replacement.
-  replaceAllUsesWith(Replacement);
+  if (Constant *C = getWithOperands(NewOps, getType(), true)) {
+    replaceUsesOfWithOnConstantImpl(C);
+    return;
+  }
 
-  // Delete the old constant!
-  destroyConstant();
+  // Update to the new value.
+  if (Constant *C = getContext().pImpl->ExprConstants.replaceOperandsInPlace(
+          NewOps, this, From, To, NumUpdated, U - OperandList))
+    replaceUsesOfWithOnConstantImpl(C);
 }
 
 Instruction *ConstantExpr::getAsInstruction() {