ConstantInt has some getters which return ConstantInt's or ConstantVector's of
authorNick Lewycky <nicholas@mxc.ca>
Sun, 6 Mar 2011 03:36:19 +0000 (03:36 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Sun, 6 Mar 2011 03:36:19 +0000 (03:36 +0000)
the value splatted into every element. Extend this to getTrue and getFalse which
by providing new overloads that take Types that are either i1 or <N x i1>. Use
it in InstCombine to add vector support to some code, fixing PR8469!

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@127116 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Constants.h
lib/Transforms/InstCombine/InstCombineCompares.cpp
lib/VMCore/Constants.cpp
test/Transforms/InstCombine/icmp.ll

index 7da8e23a442f11062a8da948feea90e83d18c5f4..c12b33fae71f783fc037574b319730b1674c4dab 100644 (file)
@@ -57,6 +57,8 @@ protected:
 public:
   static ConstantInt *getTrue(LLVMContext &Context);
   static ConstantInt *getFalse(LLVMContext &Context);
+  static Constant *getTrue(const Type *Ty);
+  static Constant *getFalse(const Type *Ty);
   
   /// If Ty is a vector type, return a Constant with a splat of the given
   /// value. Otherwise return a ConstantInt for the given value.
index bdef3dbc53668eb8b786d4209d1676d17161962f..108806a82b7668fcd8e429fbdf6fe027a6c69fb9 100644 (file)
@@ -1915,10 +1915,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     // that code below can assume that Min != Max.
     if (!isa<Constant>(Op0) && Op0Min == Op0Max)
       return new ICmpInst(I.getPredicate(),
-                          ConstantInt::get(I.getContext(), Op0Min), Op1);
+                          ConstantInt::get(Op0->getType(), Op0Min), Op1);
     if (!isa<Constant>(Op1) && Op1Min == Op1Max)
       return new ICmpInst(I.getPredicate(), Op0,
-                          ConstantInt::get(I.getContext(), Op1Min));
+                          ConstantInt::get(Op1->getType(), Op1Min));
 
     // Based on the range information we know about the LHS, see if we can
     // simplify this comparison.  For example, (x&4) < 8 is always true.
@@ -1926,7 +1926,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     default: llvm_unreachable("Unknown icmp opcode!");
     case ICmpInst::ICMP_EQ: {
       if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
         
       // If all bits are known zero except for one, then we know at most one
       // bit is set.   If the comparison is against zero, then this is a check
@@ -1963,7 +1963,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     }
     case ICmpInst::ICMP_NE: {
       if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       
       // If all bits are known zero except for one, then we know at most one
       // bit is set.   If the comparison is against zero, then this is a check
@@ -2000,9 +2000,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     }
     case ICmpInst::ICMP_ULT:
       if (Op0Max.ult(Op1Min))          // A <u B -> true if max(A) < min(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Min.uge(Op1Max))          // A <u B -> false if min(A) >= max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
       if (Op1Min == Op0Max)            // A <u B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
       if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
@@ -2018,9 +2018,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       break;
     case ICmpInst::ICMP_UGT:
       if (Op0Min.ugt(Op1Max))          // A >u B -> true if min(A) > max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Max.ule(Op1Min))          // A >u B -> false if max(A) <= max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
 
       if (Op1Max == Op0Min)            // A >u B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
@@ -2037,9 +2037,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       break;
     case ICmpInst::ICMP_SLT:
       if (Op0Max.slt(Op1Min))          // A <s B -> true if max(A) < min(C)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Min.sge(Op1Max))          // A <s B -> false if min(A) >= max(C)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
       if (Op1Min == Op0Max)            // A <s B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
       if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
@@ -2050,9 +2050,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       break;
     case ICmpInst::ICMP_SGT:
       if (Op0Min.sgt(Op1Max))          // A >s B -> true if min(A) > max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Max.sle(Op1Min))          // A >s B -> false if max(A) <= min(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
 
       if (Op1Max == Op0Min)            // A >s B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
@@ -2065,30 +2065,30 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     case ICmpInst::ICMP_SGE:
       assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
       if (Op0Min.sge(Op1Max))          // A >=s B -> true if min(A) >= max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Max.slt(Op1Min))          // A >=s B -> false if max(A) < min(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
       break;
     case ICmpInst::ICMP_SLE:
       assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
       if (Op0Max.sle(Op1Min))          // A <=s B -> true if max(A) <= min(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Min.sgt(Op1Max))          // A <=s B -> false if min(A) > max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
       break;
     case ICmpInst::ICMP_UGE:
       assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
       if (Op0Min.uge(Op1Max))          // A >=u B -> true if min(A) >= max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Max.ult(Op1Min))          // A >=u B -> false if max(A) < min(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
       break;
     case ICmpInst::ICMP_ULE:
       assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
       if (Op0Max.ule(Op1Min))          // A <=u B -> true if max(A) <= min(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
       if (Op0Min.ugt(Op1Max))          // A <=u B -> false if min(A) > max(B)
-        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
       break;
     }
 
@@ -2329,9 +2329,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) {
         default: break;
         case ICmpInst::ICMP_EQ:
-          return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+          return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
         case ICmpInst::ICMP_NE:
-          return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+          return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
         case ICmpInst::ICMP_SGT:
         case ICmpInst::ICMP_SGE:
           return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1),
index 26c248df5d5fc5fa6de2368e64f2580ecd7807ca..7a4dcf92bb2f1c51ae97b8b4c58ec636ae4dc6dd 100644 (file)
@@ -326,27 +326,53 @@ ConstantInt::ConstantInt(const IntegerType *Ty, const APInt& V)
   assert(V.getBitWidth() == Ty->getBitWidth() && "Invalid constant for type");
 }
 
-ConstantIntConstantInt::getTrue(LLVMContext &Context) {
+ConstantInt *ConstantInt::getTrue(LLVMContext &Context) {
   LLVMContextImpl *pImpl = Context.pImpl;
   if (!pImpl->TheTrueVal)
     pImpl->TheTrueVal = ConstantInt::get(Type::getInt1Ty(Context), 1);
   return pImpl->TheTrueVal;
 }
 
-ConstantIntConstantInt::getFalse(LLVMContext &Context) {
+ConstantInt *ConstantInt::getFalse(LLVMContext &Context) {
   LLVMContextImpl *pImpl = Context.pImpl;
   if (!pImpl->TheFalseVal)
     pImpl->TheFalseVal = ConstantInt::get(Type::getInt1Ty(Context), 0);
   return pImpl->TheFalseVal;
 }
 
+Constant *ConstantInt::getTrue(const Type *Ty) {
+  const VectorType *VTy = dyn_cast<VectorType>(Ty);
+  if (!VTy) {
+    assert(Ty->isIntegerTy(1) && "True must be i1 or vector of i1.");
+    return ConstantInt::getTrue(Ty->getContext());
+  }
+  assert(VTy->getElementType()->isIntegerTy(1) &&
+         "True must be vector of i1 or i1.");
+  SmallVector<Constant*, 16> Splat(VTy->getNumElements(),
+                                   ConstantInt::getTrue(Ty->getContext()));
+  return ConstantVector::get(Splat);
+}
+
+Constant *ConstantInt::getFalse(const Type *Ty) {
+  const VectorType *VTy = dyn_cast<VectorType>(Ty);
+  if (!VTy) {
+    assert(Ty->isIntegerTy(1) && "False must be i1 or vector of i1.");
+    return ConstantInt::getFalse(Ty->getContext());
+  }
+  assert(VTy->getElementType()->isIntegerTy(1) &&
+         "False must be vector of i1 or i1.");
+  SmallVector<Constant*, 16> Splat(VTy->getNumElements(),
+                                   ConstantInt::getFalse(Ty->getContext()));
+  return ConstantVector::get(Splat);
+}
+
 
 // Get a ConstantInt from an APInt. Note that the value stored in the DenseMap 
 // as the key, is a DenseMapAPIntKeyInfo::KeyTy which has provided the
 // operator== and operator!= to ensure that the DenseMap doesn't attempt to
 // compare APInt's of different widths, which would violate an APInt class
 // invariant which generates an assertion.
-ConstantInt *ConstantInt::get(LLVMContext &Context, const APIntV) {
+ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt &V) {
   // Get the corresponding integer type for the bit width of the value.
   const IntegerType *ITy = IntegerType::get(Context, V.getBitWidth());
   // get an existing value or the insertion position
@@ -356,9 +382,8 @@ ConstantInt *ConstantInt::get(LLVMContext &Context, const APInt& V) {
   return Slot;
 }
 
-Constant *ConstantInt::get(const Type* Ty, uint64_t V, bool isSigned) {
-  Constant *C = get(cast<IntegerType>(Ty->getScalarType()),
-                               V, isSigned);
+Constant *ConstantInt::get(const Type *Ty, uint64_t V, bool isSigned) {
+  Constant *C = get(cast<IntegerType>(Ty->getScalarType()), V, isSigned);
 
   // For vectors, broadcast the value.
   if (const VectorType *VTy = dyn_cast<VectorType>(Ty))
index 0230506697b7617eb16ef15c4da8ac2ad1ef9c2e..63f76313f8ec5e9508d54d1ff9bacd28bcae1845 100644 (file)
@@ -465,3 +465,13 @@ define i1 @test48(i32 %X, i32 %Y, i32 %Z) {
   %C = icmp eq i32 %A, %B
   ret i1 %C
 }
+
+; PR8469
+; CHECK: @test49
+; CHECK: ret <2 x i1> <i1 true, i1 true>
+define <2 x i1> @test49(<2 x i32> %tmp3) {
+entry:
+  %tmp11 = and <2 x i32> %tmp3, <i32 3, i32 3>
+  %cmp = icmp ult <2 x i32> %tmp11, <i32 4, i32 4>
+  ret <2 x i1> %cmp  
+}