Fix UBSan bootstrap: replace shift of negative value with multiplication.
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index 45290fdcc008e63e012367ba3a73ae1993ff5bae..56e7a094976e0a1b935b2f5115ed569b1456e3b5 100644 (file)
@@ -1011,6 +1011,10 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (match(Op1, m_Undef()))
     return Op1;
 
+  // X / 0 -> undef, we don't need to preserve faults!
+  if (match(Op1, m_Zero()))
+    return UndefValue::get(Op1->getType());
+
   // undef / X -> 0
   if (match(Op0, m_Undef()))
     return Constant::getNullValue(Op0->getType());
@@ -1334,6 +1338,11 @@ static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1,
   if (Op0 == Op1)
     return Constant::getNullValue(Op0->getType());
 
+  // undef >> X -> 0
+  // undef >> X -> undef (if it's exact)
+  if (match(Op0, m_Undef()))
+    return isExact ? Op0 : Constant::getNullValue(Op0->getType());
+
   // The low bit cannot be shifted out of an exact shift if it is set.
   if (isExact) {
     unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
@@ -1356,8 +1365,9 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
     return V;
 
   // undef << X -> 0
+  // undef << X -> undef if (if it's NSW/NUW)
   if (match(Op0, m_Undef()))
-    return Constant::getNullValue(Op0->getType());
+    return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType());
 
   // (X >> A) << A -> X
   Value *X;
@@ -1382,10 +1392,6 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
                                     MaxRecurse))
       return V;
 
-  // undef >>l X -> 0
-  if (match(Op0, m_Undef()))
-    return Constant::getNullValue(Op0->getType());
-
   // (X << A) >> A -> X
   Value *X;
   if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
@@ -1416,10 +1422,6 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
   if (match(Op0, m_AllOnes()))
     return Op0;
 
-  // undef >>a X -> all ones
-  if (match(Op0, m_Undef()))
-    return Constant::getAllOnesValue(Op0->getType());
-
   // (X << A) >> A -> X
   Value *X;
   if (match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
@@ -1443,12 +1445,57 @@ Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
                             RecursionLimit);
 }
 
+static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
+                                         ICmpInst *UnsignedICmp, bool IsAnd) {
+  Value *X, *Y;
+
+  ICmpInst::Predicate EqPred;
+  if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) ||
+      !ICmpInst::isEquality(EqPred))
+    return nullptr;
+
+  ICmpInst::Predicate UnsignedPred;
+  if (match(UnsignedICmp, m_ICmp(UnsignedPred, m_Value(X), m_Specific(Y))) &&
+      ICmpInst::isUnsigned(UnsignedPred))
+    ;
+  else if (match(UnsignedICmp,
+                 m_ICmp(UnsignedPred, m_Value(Y), m_Specific(X))) &&
+           ICmpInst::isUnsigned(UnsignedPred))
+    UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+  else
+    return nullptr;
+
+  // X < Y && Y != 0  -->  X < Y
+  // X < Y || Y != 0  -->  Y != 0
+  if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE)
+    return IsAnd ? UnsignedICmp : ZeroICmp;
+
+  // X >= Y || Y != 0  -->  true
+  // X >= Y || Y == 0  -->  X >= Y
+  if (UnsignedPred == ICmpInst::ICMP_UGE && !IsAnd) {
+    if (EqPred == ICmpInst::ICMP_NE)
+      return getTrue(UnsignedICmp->getType());
+    return UnsignedICmp;
+  }
+
+  // X < Y && Y == 0  -->  false
+  if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_EQ &&
+      IsAnd)
+    return getFalse(UnsignedICmp->getType());
+
+  return nullptr;
+}
+
 // Simplify (and (icmp ...) (icmp ...)) to true when we can tell that the range
 // of possible values cannot be satisfied.
 static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   ICmpInst::Predicate Pred0, Pred1;
   ConstantInt *CI1, *CI2;
   Value *V;
+
+  if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true))
+    return X;
+
   if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)),
                          m_ConstantInt(CI2))))
    return nullptr;
@@ -1602,6 +1649,10 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   ICmpInst::Predicate Pred0, Pred1;
   ConstantInt *CI1, *CI2;
   Value *V;
+
+  if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false))
+    return X;
+
   if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)),
                          m_ConstantInt(CI2))))
    return nullptr;
@@ -3101,36 +3152,57 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal,
   if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X
     return TrueVal;
 
-  if (const auto *ICI = dyn_cast<ICmpInst>(CondVal)) {
+  const auto *ICI = dyn_cast<ICmpInst>(CondVal);
+  unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits();
+  if (ICI && BitWidth) {
+    ICmpInst::Predicate Pred = ICI->getPredicate();
+    APInt MinSignedValue = APInt::getSignBit(BitWidth);
     Value *X;
     const APInt *Y;
-    if (ICI->isEquality() &&
+    bool TrueWhenUnset;
+    bool IsBitTest = false;
+    if (ICmpInst::isEquality(Pred) &&
         match(ICI->getOperand(0), m_And(m_Value(X), m_APInt(Y))) &&
         match(ICI->getOperand(1), m_Zero())) {
-      ICmpInst::Predicate Pred = ICI->getPredicate();
+      IsBitTest = true;
+      TrueWhenUnset = Pred == ICmpInst::ICMP_EQ;
+    } else if (Pred == ICmpInst::ICMP_SLT &&
+               match(ICI->getOperand(1), m_Zero())) {
+      X = ICI->getOperand(0);
+      Y = &MinSignedValue;
+      IsBitTest = true;
+      TrueWhenUnset = false;
+    } else if (Pred == ICmpInst::ICMP_SGT &&
+               match(ICI->getOperand(1), m_AllOnes())) {
+      X = ICI->getOperand(0);
+      Y = &MinSignedValue;
+      IsBitTest = true;
+      TrueWhenUnset = true;
+    }
+    if (IsBitTest) {
       const APInt *C;
       // (X & Y) == 0 ? X & ~Y : X  --> X
       // (X & Y) != 0 ? X & ~Y : X  --> X & ~Y
       if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) &&
           *Y == ~*C)
-        return Pred == ICmpInst::ICMP_EQ ? FalseVal : TrueVal;
+        return TrueWhenUnset ? FalseVal : TrueVal;
       // (X & Y) == 0 ? X : X & ~Y  --> X & ~Y
       // (X & Y) != 0 ? X : X & ~Y  --> X
       if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) &&
           *Y == ~*C)
-        return Pred == ICmpInst::ICMP_EQ ? FalseVal : TrueVal;
+        return TrueWhenUnset ? FalseVal : TrueVal;
 
       if (Y->isPowerOf2()) {
         // (X & Y) == 0 ? X | Y : X  --> X | Y
         // (X & Y) != 0 ? X | Y : X  --> X
         if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) &&
             *Y == *C)
-          return Pred == ICmpInst::ICMP_EQ ? TrueVal : FalseVal;
+          return TrueWhenUnset ? TrueVal : FalseVal;
         // (X & Y) == 0 ? X : X | Y  --> X
         // (X & Y) != 0 ? X : X | Y  --> X | Y
         if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) &&
             *Y == *C)
-          return Pred == ICmpInst::ICMP_EQ ? TrueVal : FalseVal;
+          return TrueWhenUnset ? TrueVal : FalseVal;
       }
     }
   }