R600/SI: Fix src1_modifiers for class instructions
[oota-llvm.git] / lib / IR / ConstantFold.cpp
index cdfb41f7dcce4b81de3161d01a3f9d2bfae1abd4..a915d285b14a1a743493d6c4dd1d6c602be5a727 100644 (file)
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/MathExtras.h"
 #include <limits>
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 //===----------------------------------------------------------------------===//
 //                ConstantFold*Instruction Implementations
@@ -913,49 +915,70 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return C1;
       return Constant::getNullValue(C1->getType());   // undef & X -> 0
     case Instruction::Mul: {
-      ConstantInt *CI;
-      // X * undef -> undef   if X is odd or undef
-      if (((CI = dyn_cast<ConstantInt>(C1)) && CI->getValue()[0]) ||
-          ((CI = dyn_cast<ConstantInt>(C2)) && CI->getValue()[0]) ||
-          (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
-        return UndefValue::get(C1->getType());
+      // undef * undef -> undef
+      if (isa<UndefValue>(C1) && isa<UndefValue>(C2))
+        return C1;
+      const APInt *CV;
+      // X * undef -> undef   if X is odd
+      if (match(C1, m_APInt(CV)) || match(C2, m_APInt(CV)))
+        if ((*CV)[0])
+          return UndefValue::get(C1->getType());
 
       // X * undef -> 0       otherwise
       return Constant::getNullValue(C1->getType());
     }
-    case Instruction::UDiv:
     case Instruction::SDiv:
+    case Instruction::UDiv:
+      // X / undef -> undef
+      if (match(C1, m_Zero()))
+        return C2;
+      // undef / 0 -> undef
       // undef / 1 -> undef
-      if (Opcode == Instruction::UDiv || Opcode == Instruction::SDiv)
-        if (ConstantInt *CI2 = dyn_cast<ConstantInt>(C2))
-          if (CI2->isOne())
-            return C1;
-      // FALL THROUGH
+      if (match(C2, m_Zero()) || match(C2, m_One()))
+        return C1;
+      // undef / X -> 0       otherwise
+      return Constant::getNullValue(C1->getType());
     case Instruction::URem:
     case Instruction::SRem:
-      if (!isa<UndefValue>(C2))                    // undef / X -> 0
-        return Constant::getNullValue(C1->getType());
-      return C2;                                   // X / undef -> undef
+      // X % undef -> undef
+      if (match(C2, m_Undef()))
+        return C2;
+      // undef % 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // undef % X -> 0       otherwise
+      return Constant::getNullValue(C1->getType());
     case Instruction::Or:                          // X | undef -> -1
       if (isa<UndefValue>(C1) && isa<UndefValue>(C2)) // undef | undef -> undef
         return C1;
       return Constant::getAllOnesValue(C1->getType()); // undef | X -> ~0
     case Instruction::LShr:
-      if (isa<UndefValue>(C2) && isa<UndefValue>(C1))
-        return C1;                                  // undef lshr undef -> undef
-      return Constant::getNullValue(C1->getType()); // X lshr undef -> 0
-                                                    // undef lshr X -> 0
+      // X >>l undef -> undef
+      if (isa<UndefValue>(C2))
+        return C2;
+      // undef >>l 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // undef >>l X -> 0
+      return Constant::getNullValue(C1->getType());
     case Instruction::AShr:
-      if (!isa<UndefValue>(C2))                     // undef ashr X --> all ones
-        return Constant::getAllOnesValue(C1->getType());
-      else if (isa<UndefValue>(C1)) 
-        return C1;                                  // undef ashr undef -> undef
-      else
-        return C1;                                  // X ashr undef --> X
+      // X >>a undef -> undef
+      if (isa<UndefValue>(C2))
+        return C2;
+      // undef >>a 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // TODO: undef >>a X -> undef if the shift is exact
+      // undef >>a X -> 0
+      return Constant::getNullValue(C1->getType());
     case Instruction::Shl:
-      if (isa<UndefValue>(C2) && isa<UndefValue>(C1))
-        return C1;                                  // undef shl undef -> undef
-      // undef << X -> 0   or   X << undef -> 0
+      // X << undef -> undef
+      if (isa<UndefValue>(C2))
+        return C2;
+      // undef << 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // undef << X -> 0
       return Constant::getNullValue(C1->getType());
     }
   }
@@ -1259,15 +1282,17 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) {
   if (!isa<ConstantInt>(C1) || !isa<ConstantInt>(C2))
     return -2; // don't know!
 
-  // Ok, we have two differing integer indices.  Sign extend them to be the same
-  // type.  Long is always big enough, so we use it.
-  if (!C1->getType()->isIntegerTy(64))
-    C1 = ConstantExpr::getSExt(C1, Type::getInt64Ty(C1->getContext()));
+  // We cannot compare the indices if they don't fit in an int64_t.
+  if (cast<ConstantInt>(C1)->getValue().getActiveBits() > 64 ||
+      cast<ConstantInt>(C2)->getValue().getActiveBits() > 64)
+    return -2; // don't know!
 
-  if (!C2->getType()->isIntegerTy(64))
-    C2 = ConstantExpr::getSExt(C2, Type::getInt64Ty(C1->getContext()));
+  // Ok, we have two differing integer indices.  Sign extend them to be the same
+  // type.
+  int64_t C1Val = cast<ConstantInt>(C1)->getSExtValue();
+  int64_t C2Val = cast<ConstantInt>(C2)->getSExtValue();
 
-  if (C1 == C2) return 0;  // They are equal
+  if (C1Val == C2Val) return 0;  // They are equal
 
   // If the type being indexed over is really just a zero sized type, there is
   // no pointer difference being made here.
@@ -1276,8 +1301,7 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) {
 
   // If they are really different, now that they are the same type, then we
   // found a difference!
-  if (cast<ConstantInt>(C1)->getSExtValue() < 
-      cast<ConstantInt>(C2)->getSExtValue())
+  if (C1Val < C2Val)
     return -1;
   else
     return 1;
@@ -1348,9 +1372,24 @@ static FCmpInst::Predicate evaluateFCmpRelation(Constant *V1, Constant *V2) {
 
 static ICmpInst::Predicate areGlobalsPotentiallyEqual(const GlobalValue *GV1,
                                                       const GlobalValue *GV2) {
+  auto isGlobalUnsafeForEquality = [](const GlobalValue *GV) {
+    if (GV->hasExternalWeakLinkage() || GV->hasWeakAnyLinkage())
+      return true;
+    if (const auto *GVar = dyn_cast<GlobalVariable>(GV)) {
+      Type *Ty = GVar->getType()->getPointerElementType();
+      // A global with opaque type might end up being zero sized.
+      if (!Ty->isSized())
+        return true;
+      // A global with an empty type might lie at the address of any other
+      // global.
+      if (Ty->isEmptyTy())
+        return true;
+    }
+    return false;
+  };
   // Don't try to decide equality of aliases.
   if (!isa<GlobalAlias>(GV1) && !isa<GlobalAlias>(GV2))
-    if (!GV1->hasExternalWeakLinkage() || !GV2->hasExternalWeakLinkage())
+    if (!isGlobalUnsafeForEquality(GV1) && !isGlobalUnsafeForEquality(GV2))
       return ICmpInst::ICMP_NE;
   return ICmpInst::BAD_ICMP_PREDICATE;
 }
@@ -2040,8 +2079,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
       if (PerformFold) {
         SmallVector<Value*, 16> NewIndices;
         NewIndices.reserve(Idxs.size() + CE->getNumOperands());
-        for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i)
-          NewIndices.push_back(CE->getOperand(i));
+        NewIndices.append(CE->op_begin() + 1, CE->op_end() - 1);
 
         // Add the last index of the source with the first index of the new GEP.
         // Make sure to handle the case when they are actually different types.
@@ -2050,9 +2088,15 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
         if (!Idx0->isNullValue()) {
           Type *IdxTy = Combined->getType();
           if (IdxTy != Idx0->getType()) {
-            Type *Int64Ty = Type::getInt64Ty(IdxTy->getContext());
-            Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, Int64Ty);
-            Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, Int64Ty);
+            unsigned CommonExtendedWidth =
+                std::max(IdxTy->getIntegerBitWidth(),
+                         Idx0->getType()->getIntegerBitWidth());
+            CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
+
+            Type *CommonTy =
+                Type::getIntNTy(IdxTy->getContext(), CommonExtendedWidth);
+            Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, CommonTy);
+            Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, CommonTy);
             Combined = ConstantExpr::get(Instruction::Add, C1, C2);
           } else {
             Combined =
@@ -2125,14 +2169,20 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
             Constant *PrevIdx = cast<Constant>(Idxs[i-1]);
             Constant *Div = ConstantExpr::getSDiv(CI, Factor);
 
+            unsigned CommonExtendedWidth =
+                std::max(PrevIdx->getType()->getIntegerBitWidth(),
+                         Div->getType()->getIntegerBitWidth());
+            CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
+
             // Before adding, extend both operands to i64 to avoid
             // overflow trouble.
-            if (!PrevIdx->getType()->isIntegerTy(64))
-              PrevIdx = ConstantExpr::getSExt(PrevIdx,
-                                           Type::getInt64Ty(Div->getContext()));
-            if (!Div->getType()->isIntegerTy(64))
-              Div = ConstantExpr::getSExt(Div,
-                                          Type::getInt64Ty(Div->getContext()));
+            if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth))
+              PrevIdx = ConstantExpr::getSExt(
+                  PrevIdx,
+                  Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
+            if (!Div->getType()->isIntegerTy(CommonExtendedWidth))
+              Div = ConstantExpr::getSExt(
+                  Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
 
             NewIdxs[i-1] = ConstantExpr::getAdd(PrevIdx, Div);
           } else {