Include optional subclass flags, such as inbounds, nsw, etc., in the
[oota-llvm.git] / lib / VMCore / Constants.cpp
index 37efafc9b208902ba9f1624cdf26ff1c54ea3852..a5b4f289688b417473d2c2337c39c694f7372cd4 100644 (file)
@@ -632,21 +632,13 @@ Constant* ConstantVector::get(Constant* const* Vals, unsigned NumVals) {
 }
 
 Constant* ConstantExpr::getNSWAdd(Constant* C1, Constant* C2) {
-  Constant *C = getAdd(C1, C2);
-  // Set nsw attribute, assuming constant folding didn't eliminate the
-  // Add.
-  if (AddOperator *Add = dyn_cast<AddOperator>(C))
-    Add->setHasNoSignedWrap(true);
-  return C;
+  return getTy(C1->getType(), Instruction::Add, C1, C2,
+               OverflowingBinaryOperator::NoSignedWrap);
 }
 
 Constant* ConstantExpr::getExactSDiv(Constant* C1, Constant* C2) {
-  Constant *C = getSDiv(C1, C2);
-  // Set exact attribute, assuming constant folding didn't eliminate the
-  // SDiv.
-  if (SDivOperator *SDiv = dyn_cast<SDivOperator>(C))
-    SDiv->setIsExact(true);
-  return C;
+  return getTy(C1->getType(), Instruction::SDiv, C1, C2,
+               SDivOperator::IsExact);
 }
 
 // Utility function for determining if a ConstantExpr is a CastOp or not. This
@@ -729,15 +721,19 @@ ConstantExpr::getWithOperandReplaced(unsigned OpNo, Constant *Op) const {
     for (unsigned i = 1, e = getNumOperands(); i != e; ++i)
       Ops[i-1] = getOperand(i);
     if (OpNo == 0)
-      return ConstantExpr::getGetElementPtr(Op, &Ops[0], Ops.size());
+      return cast<GEPOperator>(this)->isInBounds() ?
+        ConstantExpr::getInBoundsGetElementPtr(Op, &Ops[0], Ops.size()) :
+        ConstantExpr::getGetElementPtr(Op, &Ops[0], Ops.size());
     Ops[OpNo-1] = Op;
-    return ConstantExpr::getGetElementPtr(getOperand(0), &Ops[0], Ops.size());
+    return cast<GEPOperator>(this)->isInBounds() ?
+      ConstantExpr::getInBoundsGetElementPtr(getOperand(0), &Ops[0], Ops.size()) :
+      ConstantExpr::getGetElementPtr(getOperand(0), &Ops[0], Ops.size());
   }
   default:
     assert(getNumOperands() == 2 && "Must be binary operator?");
     Op0 = (OpNo == 0) ? Op : getOperand(0);
     Op1 = (OpNo == 1) ? Op : getOperand(1);
-    return ConstantExpr::get(getOpcode(), Op0, Op1);
+    return ConstantExpr::get(getOpcode(), Op0, Op1, SubclassData);
   }
 }
 
@@ -779,13 +775,15 @@ getWithOperands(Constant* const *Ops, unsigned NumOps) const {
   case Instruction::ShuffleVector:
     return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
   case Instruction::GetElementPtr:
-    return ConstantExpr::getGetElementPtr(Ops[0], &Ops[1], NumOps-1);
+    return cast<GEPOperator>(this)->isInBounds() ?
+      ConstantExpr::getInBoundsGetElementPtr(Ops[0], &Ops[1], NumOps-1) :
+      ConstantExpr::getGetElementPtr(Ops[0], &Ops[1], NumOps-1);
   case Instruction::ICmp:
   case Instruction::FCmp:
     return ConstantExpr::getCompare(getPredicate(), Ops[0], Ops[1]);
   default:
     assert(getNumOperands() == 2 && "Must be binary operator?");
-    return ConstantExpr::get(getOpcode(), Ops[0], Ops[1]);
+    return ConstantExpr::get(getOpcode(), Ops[0], Ops[1], SubclassData);
   }
 }
 
@@ -1031,8 +1029,9 @@ static ExprMapKeyType getValType(ConstantExpr *CE) {
   Operands.reserve(CE->getNumOperands());
   for (unsigned i = 0, e = CE->getNumOperands(); i != e; ++i)
     Operands.push_back(cast<Constant>(CE->getOperand(i)));
-  return ExprMapKeyType(CE->getOpcode(), Operands, 
+  return ExprMapKeyType(CE->getOpcode(), Operands,
       CE->isCompare() ? CE->getPredicate() : 0,
+      CE->getRawSubclassOptionalData(),
       CE->hasIndices() ?
         CE->getIndices() : SmallVector<unsigned, 4>());
 }
@@ -1280,7 +1279,8 @@ Constant *ConstantExpr::getBitCast(Constant *C, const Type *DstTy) {
 }
 
 Constant *ConstantExpr::getTy(const Type *ReqTy, unsigned Opcode,
-                              Constant *C1, Constant *C2) {
+                              Constant *C1, Constant *C2,
+                              unsigned Flags) {
   // Check the operands for consistency first
   assert(Opcode >= Instruction::BinaryOpsBegin &&
          Opcode <  Instruction::BinaryOpsEnd   &&
@@ -1294,7 +1294,7 @@ Constant *ConstantExpr::getTy(const Type *ReqTy, unsigned Opcode,
       return FC;          // Fold a few common cases...
 
   std::vector<Constant*> argVec(1, C1); argVec.push_back(C2);
-  ExprMapKeyType Key(Opcode, argVec);
+  ExprMapKeyType Key(Opcode, argVec, 0, Flags);
   
   LLVMContextImpl *pImpl = ReqTy->getContext().pImpl;
   
@@ -1322,7 +1322,8 @@ Constant *ConstantExpr::getCompareTy(unsigned short predicate,
   }
 }
 
-Constant *ConstantExpr::get(unsigned Opcode, Constant *C1, Constant *C2) {
+Constant *ConstantExpr::get(unsigned Opcode, Constant *C1, Constant *C2,
+                            unsigned Flags) {
   // API compatibility: Adjust integer opcodes to floating-point opcodes.
   if (C1->getType()->isFPOrFPVector()) {
     if (Opcode == Instruction::Add) Opcode = Instruction::FAdd;
@@ -1387,7 +1388,7 @@ Constant *ConstantExpr::get(unsigned Opcode, Constant *C1, Constant *C2) {
   }
 #endif
 
-  return getTy(C1->getType(), Opcode, C1, C2);
+  return getTy(C1->getType(), Opcode, C1, C2, Flags);
 }
 
 Constant* ConstantExpr::getSizeOf(const Type* Ty) {
@@ -1481,6 +1482,36 @@ Constant *ConstantExpr::getGetElementPtrTy(const Type *ReqTy, Constant *C,
   return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
 }
 
+Constant *ConstantExpr::getInBoundsGetElementPtrTy(const Type *ReqTy,
+                                                   Constant *C,
+                                                   Value* const *Idxs,
+                                                   unsigned NumIdx) {
+  assert(GetElementPtrInst::getIndexedType(C->getType(), Idxs,
+                                           Idxs+NumIdx) ==
+         cast<PointerType>(ReqTy)->getElementType() &&
+         "GEP indices invalid!");
+
+  if (Constant *FC = ConstantFoldGetElementPtr(
+                              ReqTy->getContext(), C, (Constant**)Idxs, NumIdx))
+    return FC;          // Fold a few common cases...
+
+  assert(isa<PointerType>(C->getType()) &&
+         "Non-pointer type for constant GetElementPtr expression");
+  // Look up the constant in the table first to ensure uniqueness
+  std::vector<Constant*> ArgVec;
+  ArgVec.reserve(NumIdx+1);
+  ArgVec.push_back(C);
+  for (unsigned i = 0; i != NumIdx; ++i)
+    ArgVec.push_back(cast<Constant>(Idxs[i]));
+  const ExprMapKeyType Key(Instruction::GetElementPtr, ArgVec, 0,
+                           GEPOperator::IsInBounds);
+
+  LLVMContextImpl *pImpl = ReqTy->getContext().pImpl;
+
+  // Implicitly locked.
+  return pImpl->ExprConstants.getOrCreate(ReqTy, Key);
+}
+
 Constant *ConstantExpr::getGetElementPtr(Constant *C, Value* const *Idxs,
                                          unsigned NumIdx) {
   // Get the result type of the getelementptr!
@@ -1494,12 +1525,12 @@ Constant *ConstantExpr::getGetElementPtr(Constant *C, Value* const *Idxs,
 Constant *ConstantExpr::getInBoundsGetElementPtr(Constant *C,
                                                  Value* const *Idxs,
                                                  unsigned NumIdx) {
-  Constant *Result = getGetElementPtr(C, Idxs, NumIdx);
-  // Set in bounds attribute, assuming constant folding didn't eliminate the
-  // GEP.
-  if (GEPOperator *GEP = dyn_cast<GEPOperator>(Result))
-    GEP->setIsInBounds(true);
-  return Result;
+  // Get the result type of the getelementptr!
+  const Type *Ty = 
+    GetElementPtrInst::getIndexedType(C->getType(), Idxs, Idxs+NumIdx);
+  assert(Ty && "GEP indices invalid!");
+  unsigned As = cast<PointerType>(C->getType())->getAddressSpace();
+  return getInBoundsGetElementPtrTy(PointerType::get(Ty, As), C, Idxs, NumIdx);
 }
 
 Constant *ConstantExpr::getGetElementPtr(Constant *C, Constant* const *Idxs,
@@ -2104,7 +2135,7 @@ void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
     Constant *C2 = getOperand(1);
     if (C1 == From) C1 = To;
     if (C2 == From) C2 = To;
-    Replacement = ConstantExpr::get(getOpcode(), C1, C2);
+    Replacement = ConstantExpr::get(getOpcode(), C1, C2, SubclassData);
   } else {
     llvm_unreachable("Unknown ConstantExpr type!");
     return;