Fix a case where vector comparison constant folding would cause an
authorChris Lattner <sabre@nondot.org>
Thu, 10 Jul 2008 00:29:28 +0000 (00:29 +0000)
committerChris Lattner <sabre@nondot.org>
Thu, 10 Jul 2008 00:29:28 +0000 (00:29 +0000)
infinite recursion.  part of PR2529

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@53383 91177308-0d34-0410-b5e6-96231b3b80d8

lib/VMCore/ConstantFold.cpp
test/Transforms/ConstProp/2008-07-07-VectorCompare.ll

index 57b6a4e05d17f96b06b2b9dfb6e9259541b1ad3a..5167c12d49ba8b5eb8a819ced596f014bebf7a89 100644 (file)
@@ -1348,43 +1348,43 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       return ConstantInt::get(Type::Int1Ty, R==APFloat::cmpGreaterThan ||
                                             R==APFloat::cmpEqual);
     }
-  } else if (const ConstantVector *CP1 = dyn_cast<ConstantVector>(C1)) {
-    if (const ConstantVector *CP2 = dyn_cast<ConstantVector>(C2)) {
-      // If we can constant fold the comparison of each element, constant fold
-      // the whole vector comparison.
-      SmallVector<Constant*, 4> Elts;
-      const Type *InEltTy = CP1->getOperand(0)->getType();
-      bool isFP = InEltTy->isFloatingPoint();
-      const Type *ResEltTy = InEltTy;
+  } else if (isa<VectorType>(C1->getType())) {
+    SmallVector<Constant*, 16> C1Elts, C2Elts;
+    C1->getVectorElements(C1Elts);
+    C2->getVectorElements(C2Elts);
+    
+    // If we can constant fold the comparison of each element, constant fold
+    // the whole vector comparison.
+    SmallVector<Constant*, 4> ResElts;
+    const Type *InEltTy = C1Elts[0]->getType();
+    bool isFP = InEltTy->isFloatingPoint();
+    const Type *ResEltTy = InEltTy;
+    if (isFP)
+      ResEltTy = IntegerType::get(InEltTy->getPrimitiveSizeInBits());
+    
+    for (unsigned i = 0, e = C1Elts.size(); i != e; ++i) {
+      // Compare the elements, producing an i1 result or constant expr.
+      Constant *C;
       if (isFP)
-        ResEltTy = IntegerType::get(InEltTy->getPrimitiveSizeInBits());
-      
-      for (unsigned i = 0, e = CP1->getNumOperands(); i != e; ++i) {
-        // Compare the elements, producing an i1 result or constant expr.
-        Constant *C;
-        if (isFP)
-          C = ConstantExpr::getFCmp(pred, CP1->getOperand(i),
-                                    CP2->getOperand(i));
-        else
-          C = ConstantExpr::getICmp(pred, CP1->getOperand(i),
-                                    CP2->getOperand(i));
+        C = ConstantExpr::getFCmp(pred, C1Elts[i], C2Elts[i]);
+      else
+        C = ConstantExpr::getICmp(pred, C1Elts[i], C2Elts[i]);
 
-        // If it is a bool or undef result, convert to the dest type.
-        if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
-          if (CI->isZero())
-            Elts.push_back(Constant::getNullValue(ResEltTy));
-          else
-            Elts.push_back(Constant::getAllOnesValue(ResEltTy));
-        } else if (isa<UndefValue>(C)) {
-          Elts.push_back(UndefValue::get(ResEltTy));
-        } else {
-          break;
-        }
+      // If it is a bool or undef result, convert to the dest type.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
+        if (CI->isZero())
+          ResElts.push_back(Constant::getNullValue(ResEltTy));
+        else
+          ResElts.push_back(Constant::getAllOnesValue(ResEltTy));
+      } else if (isa<UndefValue>(C)) {
+        ResElts.push_back(UndefValue::get(ResEltTy));
+      } else {
+        break;
       }
-      
-      if (Elts.size() == CP1->getNumOperands())
-        return ConstantVector::get(&Elts[0], Elts.size());
     }
+    
+    if (ResElts.size() == C1Elts.size())
+      return ConstantVector::get(&ResElts[0], ResElts.size());
   }
 
   if (C1->getType()->isFloatingPoint()) {
index b42b0248496172ce8b360b506ea2c74ab86f1054..4c71463204850118cfb0cb27667641f7db55206f 100644 (file)
@@ -20,3 +20,9 @@ undef>, <float 1.0, float 1.0, float 1.0, float undef>
        ret <4 x i32> %foo
 }
 
+define <4 x i32> @test4() {
+   %foo = vfcmp ueq <4 x float> <float 0.0, float 0.0, float 0.0, float 0.0>, <float 1.0, float 1.0, float 1.0, float 0.0>
+
+       ret <4 x i32> %foo
+}
+