IC: (X & C1) | C2 --> (X | C2) & (C1|C2)
[oota-llvm.git] / lib / Transforms / Scalar / InstructionCombining.cpp
index 505313ba9c3e26b2ae2eb3ce9a08f992b7e54ec6..df0b3566bee646d3b1029f3d75904d82609c1ccd 100644 (file)
@@ -22,6 +22,7 @@
 #include "llvm/Constants.h"
 #include "llvm/ConstantHandling.h"
 #include "llvm/DerivedTypes.h"
+#include "llvm/GlobalVariable.h"
 #include "llvm/Support/InstIterator.h"
 #include "llvm/Support/InstVisitor.h"
 #include "llvm/Support/CallSite.h"
@@ -79,6 +80,7 @@ namespace {
     Instruction *visitPHINode(PHINode &PN);
     Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP);
     Instruction *visitAllocationInst(AllocationInst &AI);
+    Instruction *visitLoadInst(LoadInst &LI);
     Instruction *visitBranchInst(BranchInst &BI);
 
     // visitInstruction - Specify what to return for unhandled instructions...
@@ -289,6 +291,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   return Changed ? &I : 0;
 }
 
+// isSignBit - Return true if the value represented by the constant only has the
+// highest order bit set.
+static bool isSignBit(ConstantInt *CI) {
+  unsigned NumBits = CI->getType()->getPrimitiveSize()*8;
+  return (CI->getRawValue() & ~(-1LL << NumBits)) == (1ULL << (NumBits-1));
+}
+
 Instruction *InstCombiner::visitSub(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
@@ -361,10 +370,10 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
   if (Constant *Op1 = dyn_cast<Constant>(I.getOperand(1))) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
       const Type *Ty = CI->getType();
-      uint64_t Val = Ty->isSigned() ?
-                          (uint64_t)cast<ConstantSInt>(CI)->getValue() : 
-                                    cast<ConstantUInt>(CI)->getValue();
+      int64_t Val = (int64_t)cast<ConstantInt>(CI)->getRawValue();
       switch (Val) {
+      case -1:                               // X * -1 -> -X
+        return BinaryOperator::createNeg(Op0, I.getName());
       case 0:
         return ReplaceInstUsesWith(I, Op1);  // Eliminate 'mul double %X, 0'
       case 1:
@@ -484,19 +493,31 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, Op1);
 
   // and X, -1 == X
-  if (ConstantIntegral *RHS = dyn_cast<ConstantIntegral>(Op1))
+  if (ConstantIntegral *RHS = dyn_cast<ConstantIntegral>(Op1)) {
     if (RHS->isAllOnesValue())
       return ReplaceInstUsesWith(I, Op0);
 
+    // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2)
+    if (Instruction *Op0I = dyn_cast<Instruction>(Op0))
+      if (Op0I->getOpcode() == Instruction::Xor && isOnlyUse(Op0))
+        if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
+          std::string Op0Name = Op0I->getName(); Op0I->setName("");
+          Instruction *And = BinaryOperator::create(Instruction::And,
+                                                    Op0I->getOperand(0), RHS,
+                                                   Op0Name);
+          InsertNewInstBefore(And, I);
+          return BinaryOperator::create(Instruction::Xor, And, *RHS & *Op0CI);
+        }
+  }
+
   Value *Op0NotVal = dyn_castNotVal(Op0);
   Value *Op1NotVal = dyn_castNotVal(Op1);
 
   // (~A & ~B) == (~(A | B)) - Demorgan's Law
   if (Op0NotVal && Op1NotVal && isOnlyUse(Op0) && isOnlyUse(Op1)) {
     Instruction *Or = BinaryOperator::create(Instruction::Or, Op0NotVal,
-                                             Op1NotVal,I.getName()+".demorgan",
-                                             &I);
-    WorkList.push_back(Or);
+                                             Op1NotVal,I.getName()+".demorgan");
+    InsertNewInstBefore(Or, I);
     return BinaryOperator::createNot(Or);
   }
 
@@ -517,10 +538,35 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, Op0);
 
   // or X, -1 == -1
-  if (ConstantIntegral *RHS = dyn_cast<ConstantIntegral>(Op1))
+  if (ConstantIntegral *RHS = dyn_cast<ConstantIntegral>(Op1)) {
     if (RHS->isAllOnesValue())
       return ReplaceInstUsesWith(I, Op1);
 
+    if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
+      // (X & C1) | C2 --> (X | C2) & (C1|C2)
+      if (Op0I->getOpcode() == Instruction::And && isOnlyUse(Op0))
+        if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
+          std::string Op0Name = Op0I->getName(); Op0I->setName("");
+          Instruction *Or = BinaryOperator::create(Instruction::Or,
+                                                   Op0I->getOperand(0), RHS,
+                                                   Op0Name);
+          InsertNewInstBefore(Or, I);
+          return BinaryOperator::create(Instruction::And, Or, *RHS | *Op0CI);
+        }
+
+      // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2)
+      if (Op0I->getOpcode() == Instruction::Xor && isOnlyUse(Op0))
+        if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
+          std::string Op0Name = Op0I->getName(); Op0I->setName("");
+          Instruction *Or = BinaryOperator::create(Instruction::Or,
+                                                   Op0I->getOperand(0), RHS,
+                                                   Op0Name);
+          InsertNewInstBefore(Or, I);
+          return BinaryOperator::create(Instruction::Xor, Or, *Op0CI & *~*RHS);
+        }
+    }
+  }
+
   Value *Op0NotVal = dyn_castNotVal(Op0);
   Value *Op1NotVal = dyn_castNotVal(Op1);
 
@@ -688,15 +734,40 @@ Instruction *InstCombiner::visitSetCondInst(BinaryOperator &I) {
   // integers at the end of their ranges...
   //
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-    if (CI->isNullValue()) {
-      if (I.getOpcode() == Instruction::SetNE)
-        return new CastInst(Op0, Type::BoolTy, I.getName());
-      else if (I.getOpcode() == Instruction::SetEQ) {
+    // Simplify seteq and setne instructions...
+    if (I.getOpcode() == Instruction::SetEQ ||
+        I.getOpcode() == Instruction::SetNE) {
+      bool isSetNE = I.getOpcode() == Instruction::SetNE;
+
+      if (CI->isNullValue()) {   // Simplify [seteq|setne] X, 0
+        CastInst *Val = new CastInst(Op0, Type::BoolTy, I.getName()+".not");
+        if (isSetNE) return Val;
+
         // seteq X, 0 -> not (cast X to bool)
-        Instruction *Val = new CastInst(Op0, Type::BoolTy, I.getName()+".not");
         InsertNewInstBefore(Val, I);
         return BinaryOperator::createNot(Val, I.getName());
       }
+
+      // If the first operand is (and|or|xor) with a constant, and the second
+      // operand is a constant, simplify a bit.
+      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0))
+        if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1)))
+          if (BO->getOpcode() == Instruction::Or) {
+            // If bits are being or'd in that are not present in the constant we
+            // are comparing against, then the comparison could never succeed!
+            if (!(*BOC & *~*CI)->isNullValue())
+              return ReplaceInstUsesWith(I, ConstantBool::get(isSetNE));
+          } else if (BO->getOpcode() == Instruction::And) {
+            // If bits are being compared against that are and'd out, then the
+            // comparison can never succeed!
+            if (!(*CI & *~*BOC)->isNullValue())
+              return ReplaceInstUsesWith(I, ConstantBool::get(isSetNE));
+          } else if (BO->getOpcode() == Instruction::Xor) {
+            // For the xor case, we can always just xor the two constants
+            // together, potentially eliminating the explicit xor.
+            return BinaryOperator::create(I.getOpcode(), BO->getOperand(0),
+                                          *CI ^ *BOC);
+          }
     }
 
     // Check to see if we are comparing against the minimum or maximum value...
@@ -956,13 +1027,15 @@ Instruction *InstCombiner::visitCastInst(CastInst &CI) {
     if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Src)) {
       Value *Op0 = BO->getOperand(0), *Op1 = BO->getOperand(1);
 
-      // Replace (cast (sub A, B) to bool) with (setne A, B)
-      if (BO->getOpcode() == Instruction::Sub)
+      switch (BO->getOpcode()) {
+      case Instruction::Sub:
+      case Instruction::Xor:
+        // Replace (cast ([sub|xor] A, B) to bool) with (setne A, B)
         return new SetCondInst(Instruction::SetNE, Op0, Op1);
 
       // Replace (cast (add A, B) to bool) with (setne A, -B) if B is
       // efficiently invertible, or if the add has just this one use.
-      if (BO->getOpcode() == Instruction::Add)
+      case Instruction::Add:
         if (Value *NegVal = dyn_castNegVal(Op1))
           return new SetCondInst(Instruction::SetNE, Op0, NegVal);
         else if (Value *NegVal = dyn_castNegVal(Op0))
@@ -973,6 +1046,36 @@ Instruction *InstCombiner::visitCastInst(CastInst &CI) {
           InsertNewInstBefore(Neg, CI);
           return new SetCondInst(Instruction::SetNE, Op0, Neg);
         }
+        break;
+
+      case Instruction::And:
+        // Replace (cast (and X, (1 << size(X)-1)) to bool) with x < 0,
+        // converting X to be a signed value as appropriate.  Don't worry about
+        // bool values, as they will be optimized other ways if they occur in
+        // this configuration.
+        if (ConstantInt *CInt = dyn_cast<ConstantInt>(Op1))
+          if (isSignBit(CInt)) {
+            // If 'X' is not signed, insert a cast now...
+            if (!CInt->getType()->isSigned()) {
+              const Type *DestTy;
+              switch (CInt->getType()->getPrimitiveID()) {
+              case Type::UByteTyID:  DestTy = Type::SByteTy; break;
+              case Type::UShortTyID: DestTy = Type::ShortTy; break;
+              case Type::UIntTyID:   DestTy = Type::IntTy;   break;
+              case Type::ULongTyID:  DestTy = Type::LongTy;  break;
+              default: assert(0 && "Invalid unsigned integer type!"); abort();
+              }
+              CastInst *NewCI = new CastInst(Op0, DestTy,
+                                             Op0->getName()+".signed");
+              InsertNewInstBefore(NewCI, CI);
+              Op0 = NewCI;
+            }
+            return new SetCondInst(Instruction::SetLT, Op0,
+                                   Constant::getNullValue(Op0->getType()));
+          }
+        break;
+      default: break;
+      }
     }
   }
 
@@ -1260,6 +1363,52 @@ Instruction *InstCombiner::visitAllocationInst(AllocationInst &AI) {
   return 0;
 }
 
+/// GetGEPGlobalInitializer - Given a constant, and a getelementptr
+/// constantexpr, return the constant value being addressed by the constant
+/// expression, or null if something is funny.
+///
+static Constant *GetGEPGlobalInitializer(Constant *C, ConstantExpr *CE) {
+  if (CE->getOperand(1) != Constant::getNullValue(Type::LongTy))
+    return 0;  // Do not allow stepping over the value!
+
+  // Loop over all of the operands, tracking down which value we are
+  // addressing...
+  for (unsigned i = 2, e = CE->getNumOperands(); i != e; ++i)
+    if (ConstantUInt *CU = dyn_cast<ConstantUInt>(CE->getOperand(i))) {
+      ConstantStruct *CS = cast<ConstantStruct>(C);
+      if (CU->getValue() >= CS->getValues().size()) return 0;
+      C = cast<Constant>(CS->getValues()[CU->getValue()]);
+    } else if (ConstantSInt *CS = dyn_cast<ConstantSInt>(CE->getOperand(i))) {
+      ConstantArray *CA = cast<ConstantArray>(C);
+      if ((uint64_t)CS->getValue() >= CA->getValues().size()) return 0;
+      C = cast<Constant>(CA->getValues()[CS->getValue()]);
+    } else 
+      return 0;
+  return C;
+}
+
+Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
+  Value *Op = LI.getOperand(0);
+  if (ConstantPointerRef *CPR = dyn_cast<ConstantPointerRef>(Op))
+    Op = CPR->getValue();
+
+  // Instcombine load (constant global) into the value loaded...
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Op))
+    if (GV->isConstant() && !GV->isExternal())
+      return ReplaceInstUsesWith(LI, GV->getInitializer());
+
+  // Instcombine load (constantexpr_GEP global, 0, ...) into the value loaded...
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op))
+    if (CE->getOpcode() == Instruction::GetElementPtr)
+      if (ConstantPointerRef *G=dyn_cast<ConstantPointerRef>(CE->getOperand(0)))
+        if (GlobalVariable *GV = dyn_cast<GlobalVariable>(G->getValue()))
+          if (GV->isConstant() && !GV->isExternal())
+            if (Constant *V = GetGEPGlobalInitializer(GV->getInitializer(), CE))
+              return ReplaceInstUsesWith(LI, V);
+  return 0;
+}
+
+
 Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
   // Change br (not X), label True, label False to: br X, label False, True
   if (BI.isConditional() && !isa<Constant>(BI.getCondition()))