Fix a bunch of bugs handling vector compare constant expressions, fixing
[oota-llvm.git] / lib / VMCore / Constants.cpp
index 51e85478b36fd8a50d65c09084da34508c2873dc..48b0a8f31a5f31ccbb418cf07f1787aedacd7dab 100644 (file)
@@ -730,7 +730,8 @@ bool ConstantExpr::isCast() const {
 }
 
 bool ConstantExpr::isCompare() const {
-  return getOpcode() == Instruction::ICmp || getOpcode() == Instruction::FCmp;
+  return getOpcode() == Instruction::ICmp || getOpcode() == Instruction::FCmp ||
+         getOpcode() == Instruction::VICmp || getOpcode() == Instruction::VFCmp;
 }
 
 bool ConstantExpr::hasIndices() const {
@@ -2201,9 +2202,8 @@ ConstantExpr::getFCmp(unsigned short pred, Constant* LHS, Constant* RHS) {
 
 Constant *
 ConstantExpr::getVICmp(unsigned short pred, Constant* LHS, Constant* RHS) {
-  assert(isa<VectorType>(LHS->getType()) &&
+  assert(isa<VectorType>(LHS->getType()) && LHS->getType() == RHS->getType() &&
          "Tried to create vicmp operation on non-vector type!");
-  assert(LHS->getType() == RHS->getType());
   assert(pred >= ICmpInst::FIRST_ICMP_PREDICATE && 
          pred <= ICmpInst::LAST_ICMP_PREDICATE && "Invalid VICmp Predicate");
 
@@ -2211,23 +2211,30 @@ ConstantExpr::getVICmp(unsigned short pred, Constant* LHS, Constant* RHS) {
   const Type *EltTy = VTy->getElementType();
   unsigned NumElts = VTy->getNumElements();
 
-  SmallVector<Constant *, 16> Elts;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    Constant *FC = ConstantFoldCompareInstruction(pred, LHS->getOperand(i),
-                                                  RHS->getOperand(i));
-    if (ConstantInt *FCI = dyn_cast_or_null<ConstantInt>(FC)) {
-      if (FCI->getZExtValue())
-        Elts.push_back(ConstantInt::getAllOnesValue(EltTy));
-      else
-        Elts.push_back(ConstantInt::get(EltTy, 0ULL));
-    } else if (FC && isa<UndefValue>(FC)) {
-      Elts.push_back(UndefValue::get(EltTy));
-    } else {
-      break;
+  // See if we can fold the element-wise comparison of the LHS and RHS.
+  SmallVector<Constant *, 16> LHSElts, RHSElts;
+  LHS->getVectorElements(LHSElts);
+  RHS->getVectorElements(RHSElts);
+                    
+  if (!LHSElts.empty() && !RHSElts.empty()) {
+    SmallVector<Constant *, 16> Elts;
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *FC = ConstantFoldCompareInstruction(pred, LHSElts[i],
+                                                    RHSElts[i]);
+      if (ConstantInt *FCI = dyn_cast_or_null<ConstantInt>(FC)) {
+        if (FCI->getZExtValue())
+          Elts.push_back(ConstantInt::getAllOnesValue(EltTy));
+        else
+          Elts.push_back(ConstantInt::get(EltTy, 0ULL));
+      } else if (FC && isa<UndefValue>(FC)) {
+        Elts.push_back(UndefValue::get(EltTy));
+      } else {
+        break;
+      }
     }
+    if (Elts.size() == NumElts)
+      return ConstantVector::get(&Elts[0], Elts.size());
   }
-  if (Elts.size() == NumElts)
-    return ConstantVector::get(&Elts[0], Elts.size());
 
   // Look up the constant in the table first to ensure uniqueness
   std::vector<Constant*> ArgVec;
@@ -2251,23 +2258,30 @@ ConstantExpr::getVFCmp(unsigned short pred, Constant* LHS, Constant* RHS) {
   const Type *REltTy = IntegerType::get(EltTy->getPrimitiveSizeInBits());
   const Type *ResultTy = VectorType::get(REltTy, NumElts);
 
-  SmallVector<Constant *, 16> Elts;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    Constant *FC = ConstantFoldCompareInstruction(pred, LHS->getOperand(i),
-                                                        RHS->getOperand(i));
-    if (ConstantInt *FCI = dyn_cast_or_null<ConstantInt>(FC)) {
-      if (FCI->getZExtValue())
-        Elts.push_back(ConstantInt::getAllOnesValue(REltTy));
-      else
-        Elts.push_back(ConstantInt::get(REltTy, 0ULL));
-    } else if (FC && isa<UndefValue>(FC)) {
-      Elts.push_back(UndefValue::get(REltTy));
-    } else {
-      break;
+  // See if we can fold the element-wise comparison of the LHS and RHS.
+  SmallVector<Constant *, 16> LHSElts, RHSElts;
+  LHS->getVectorElements(LHSElts);
+  RHS->getVectorElements(RHSElts);
+  
+  if (!LHSElts.empty() && !RHSElts.empty()) {
+    SmallVector<Constant *, 16> Elts;
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *FC = ConstantFoldCompareInstruction(pred, LHSElts[i],
+                                                    RHSElts[i]);
+      if (ConstantInt *FCI = dyn_cast_or_null<ConstantInt>(FC)) {
+        if (FCI->getZExtValue())
+          Elts.push_back(ConstantInt::getAllOnesValue(REltTy));
+        else
+          Elts.push_back(ConstantInt::get(REltTy, 0ULL));
+      } else if (FC && isa<UndefValue>(FC)) {
+        Elts.push_back(UndefValue::get(REltTy));
+      } else {
+        break;
+      }
     }
+    if (Elts.size() == NumElts)
+      return ConstantVector::get(&Elts[0], Elts.size());
   }
-  if (Elts.size() == NumElts)
-    return ConstantVector::get(&Elts[0], Elts.size());
 
   // Look up the constant in the table first to ensure uniqueness
   std::vector<Constant*> ArgVec;
@@ -2683,8 +2697,14 @@ void ConstantExpr::replaceUsesOfWithOnConstant(Value *From, Value *ToV,
     if (C2 == From) C2 = To;
     if (getOpcode() == Instruction::ICmp)
       Replacement = ConstantExpr::getICmp(getPredicate(), C1, C2);
-    else
+    else if (getOpcode() == Instruction::FCmp)
       Replacement = ConstantExpr::getFCmp(getPredicate(), C1, C2);
+    else if (getOpcode() == Instruction::VICmp)
+      Replacement = ConstantExpr::getVICmp(getPredicate(), C1, C2);
+    else {
+      assert(getOpcode() == Instruction::VFCmp);
+      Replacement = ConstantExpr::getVFCmp(getPredicate(), C1, C2);
+    }
   } else if (getNumOperands() == 2) {
     Constant *C1 = getOperand(0);
     Constant *C2 = getOperand(1);