Teach InstCombine's ComputeMaskedBits to handle pointer expressions
authorDan Gohman <gohman@apple.com>
Thu, 10 Apr 2008 18:43:06 +0000 (18:43 +0000)
committerDan Gohman <gohman@apple.com>
Thu, 10 Apr 2008 18:43:06 +0000 (18:43 +0000)
in addition to integer expressions. Rewrite GetOrEnforceKnownAlignment
as a ComputeMaskedBits problem, moving all of its special alignment
knowledge to ComputeMaskedBits as low-zero-bits knowledge.

Also, teach ComputeMaskedBits a few basic things about Mul and PHI
instructions.

This improves ComputeMaskedBits-based simplifications in a few cases,
but more noticeably it significantly improves instcombine's alignment
detection for loads, stores, and memory intrinsics.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@49492 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/InstructionCombining.cpp
test/Transforms/InstCombine/align-2d-gep.ll [new file with mode: 0644]
test/Transforms/InstCombine/align-addr.ll [new file with mode: 0644]

index beeee4911a5f043748455077c3f07c435a5849af..7b974731b1b05a6d013aa20024f15d1d63d66b29 100644 (file)
@@ -372,6 +372,15 @@ namespace {
 
 
     Value *EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned);
+
+    void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero, 
+                           APInt& KnownOne, unsigned Depth = 0);
+    bool MaskedValueIsZero(Value *V, const APInt& Mask, unsigned Depth = 0);
+    bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
+                                    unsigned CastOpc,
+                                    int &NumCastsRemoved);
+    unsigned GetOrEnforceKnownAlignment(Value *V,
+                                        unsigned PrefAlign = 0);
   };
 
   char InstCombiner::ID = 0;
@@ -580,6 +589,17 @@ static User *dyn_castGetElementPtr(Value *V) {
   return false;
 }
 
+/// getOpcode - If this is an Instruction or a ConstantExpr, return the
+/// opcode value. Otherwise return UserOp1.
+static unsigned getOpcode(User *U) {
+  if (Instruction *I = dyn_cast<Instruction>(U))
+    return I->getOpcode();
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U))
+    return CE->getOpcode();
+  // Use UserOp1 to mean there's no opcode.
+  return Instruction::UserOp1;
+}
+
 /// AddOne - Add one to a ConstantInt
 static ConstantInt *AddOne(ConstantInt *C) {
   APInt Val(C->getValue());
@@ -639,12 +659,17 @@ static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) {
 /// optimized based on the contradictory assumption that it is non-zero.
 /// Because instcombine aggressively folds operations with undef args anyway,
 /// this won't lose us code quality.
-static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero, 
-                              APInt& KnownOne, unsigned Depth = 0) {
+void InstCombiner::ComputeMaskedBits(Value *V, const APInt &Mask,
+                                     APInt& KnownZero, APInt& KnownOne,
+                                     unsigned Depth) {
   assert(V && "No Value?");
   assert(Depth <= 6 && "Limit Search Depth");
   uint32_t BitWidth = Mask.getBitWidth();
-  assert(cast<IntegerType>(V->getType())->getBitWidth() == BitWidth &&
+  assert((V->getType()->isInteger() || isa<PointerType>(V->getType())) &&
+         "Not integer or pointer type!");
+  assert((!TD || TD->getTypeSizeInBits(V->getType()) == BitWidth) &&
+         (!isa<IntegerType>(V->getType()) ||
+          V->getType()->getPrimitiveSizeInBits() == BitWidth) &&
          KnownZero.getBitWidth() == BitWidth && 
          KnownOne.getBitWidth() == BitWidth &&
          "V, Mask, KnownOne and KnownZero should have same BitWidth");
@@ -654,17 +679,37 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
     KnownZero = ~KnownOne & Mask;
     return;
   }
+  // Null is all-zeros.
+  if (isa<ConstantPointerNull>(V)) {
+    KnownOne.clear();
+    KnownZero = Mask;
+    return;
+  }
+  // The address of an aligned GlobalValue has trailing zeros.
+  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
+    unsigned Align = GV->getAlignment();
+    if (Align == 0 && TD && GV->getType()->getElementType()->isSized()) 
+      Align = TD->getPrefTypeAlignment(GV->getType()->getElementType());
+    if (Align > 0)
+      KnownZero = Mask & APInt::getLowBitsSet(BitWidth,
+                                              CountTrailingZeros_32(Align));
+    else
+      KnownZero.clear();
+    KnownOne.clear();
+    return;
+  }
 
   if (Depth == 6 || Mask == 0)
     return;  // Limit search depth.
 
-  Instruction *I = dyn_cast<Instruction>(V);
+  User *I = dyn_cast<User>(V);
   if (!I) return;
 
   KnownZero.clear(); KnownOne.clear();   // Don't know anything.
   APInt KnownZero2(KnownZero), KnownOne2(KnownOne);
   
-  switch (I->getOpcode()) {
+  switch (getOpcode(I)) {
+  default: break;
   case Instruction::And: {
     // If either the LHS or the RHS are Zero, the result is zero.
     ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
@@ -705,6 +750,24 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
     KnownZero = KnownZeroOut;
     return;
   }
+  case Instruction::Mul: {
+    APInt Mask2 = APInt::getAllOnesValue(BitWidth);
+    ComputeMaskedBits(I->getOperand(1), Mask2, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, Depth+1);
+    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
+    assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); 
+    
+    // If low bits are zero in either operand, output low known-0 bits.
+    // More trickiness is possible, but this is sufficient for the
+    // interesting case of alignment computation.
+    KnownOne.clear();
+    unsigned TrailZ = KnownZero.countTrailingOnes() +
+                      KnownZero2.countTrailingOnes();
+    TrailZ = std::min(TrailZ, BitWidth);
+    KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ);
+    KnownZero &= Mask;
+    return;
+  }
   case Instruction::Select:
     ComputeMaskedBits(I->getOperand(2), Mask, KnownZero, KnownOne, Depth+1);
     ComputeMaskedBits(I->getOperand(1), Mask, KnownZero2, KnownOne2, Depth+1);
@@ -720,48 +783,40 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
   case Instruction::FPToUI:
   case Instruction::FPToSI:
   case Instruction::SIToFP:
-  case Instruction::PtrToInt:
   case Instruction::UIToFP:
+    return; // Can't work with floating point.
+  case Instruction::PtrToInt:
   case Instruction::IntToPtr:
-    return; // Can't work with floating point or pointers
+    // We can't handle these if we don't know the pointer size.
+    if (!TD) return;
+    // Fall through and handle them the same as zext/trunc.
+  case Instruction::ZExt:
   case Instruction::Trunc: {
     // All these have integer operands
-    uint32_t SrcBitWidth = 
-      cast<IntegerType>(I->getOperand(0)->getType())->getBitWidth();
+    const Type *SrcTy = I->getOperand(0)->getType();
+    uint32_t SrcBitWidth = TD ?
+      TD->getTypeSizeInBits(SrcTy) :
+      SrcTy->getPrimitiveSizeInBits();
     APInt MaskIn(Mask);
-    MaskIn.zext(SrcBitWidth);
-    KnownZero.zext(SrcBitWidth);
-    KnownOne.zext(SrcBitWidth);
+    MaskIn.zextOrTrunc(SrcBitWidth);
+    KnownZero.zextOrTrunc(SrcBitWidth);
+    KnownOne.zextOrTrunc(SrcBitWidth);
     ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, Depth+1);
-    KnownZero.trunc(BitWidth);
-    KnownOne.trunc(BitWidth);
+    KnownZero.zextOrTrunc(BitWidth);
+    KnownOne.zextOrTrunc(BitWidth);
+    // Any top bits are known to be zero.
+    if (BitWidth > SrcBitWidth)
+      KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth);
     return;
   }
   case Instruction::BitCast: {
     const Type *SrcTy = I->getOperand(0)->getType();
-    if (SrcTy->isInteger()) {
+    if (SrcTy->isInteger() || isa<PointerType>(SrcTy)) {
       ComputeMaskedBits(I->getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
       return;
     }
     break;
   }
-  case Instruction::ZExt:  {
-    // Compute the bits in the result that are not present in the input.
-    const IntegerType *SrcTy = cast<IntegerType>(I->getOperand(0)->getType());
-    uint32_t SrcBitWidth = SrcTy->getBitWidth();
-      
-    APInt MaskIn(Mask);
-    MaskIn.trunc(SrcBitWidth);
-    KnownZero.trunc(SrcBitWidth);
-    KnownOne.trunc(SrcBitWidth);
-    ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, Depth+1);
-    assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
-    // The top bits are known to be zero.
-    KnownZero.zext(BitWidth);
-    KnownOne.zext(BitWidth);
-    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth);
-    return;
-  }
   case Instruction::SExt: {
     // Compute the bits in the result that are not present in the input.
     const IntegerType *SrcTy = cast<IntegerType>(I->getOperand(0)->getType());
@@ -835,6 +890,32 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
       return;
     }
     break;
+  case Instruction::Sub: {
+    if (ConstantInt *CLHS = dyn_cast<ConstantInt>(I->getOperand(0))) {
+      // We know that the top bits of C-X are clear if X contains less bits
+      // than C (i.e. no wrap-around can happen).  For example, 20-X is
+      // positive if we can prove that X is >= 0 and < 16.
+      if (!CLHS->getValue().isNegative()) {
+        unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros();
+        // NLZ can't be BitWidth with no sign bit
+        APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1);
+        ComputeMaskedBits(I->getOperand(1), MaskV, KnownZero, KnownOne, Depth+1);
+    
+        // If all of the MaskV bits are known to be zero, then we know the output
+        // top bits are zero, because we now know that the output is from [0-C].
+        if ((KnownZero & MaskV) == MaskV) {
+          unsigned NLZ2 = CLHS->getValue().countLeadingZeros();
+          // Top bits known zero.
+          KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2) & Mask;
+          KnownOne = APInt(BitWidth, 0);   // No one bits known.
+        } else {
+          KnownZero = KnownOne = APInt(BitWidth, 0);  // Otherwise, nothing known.
+        }
+        return;
+      }        
+    }
+  }
+  // fall through
   case Instruction::Add: {
     // If either the LHS or the RHS are Zero, the result is zero.
     ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
@@ -852,33 +933,6 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
     KnownOne = APInt(BitWidth, 0);
     return;
   }
-  case Instruction::Sub: {
-    ConstantInt *CLHS = dyn_cast<ConstantInt>(I->getOperand(0));
-    if (!CLHS) break;
-        
-    // We know that the top bits of C-X are clear if X contains less bits
-    // than C (i.e. no wrap-around can happen).  For example, 20-X is
-    // positive if we can prove that X is >= 0 and < 16.
-    if (CLHS->getValue().isNegative())
-      break;
-    
-    unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros();
-    // NLZ can't be BitWidth with no sign bit
-    APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1);
-    ComputeMaskedBits(I->getOperand(1), MaskV, KnownZero, KnownOne, Depth+1);
-    
-    // If all of the MaskV bits are known to be zero, then we know the output
-    // top bits are zero, because we now know that the output is from [0-C].
-    if ((KnownZero & MaskV) == MaskV) {
-      unsigned NLZ2 = CLHS->getValue().countLeadingZeros();
-      // Top bits known zero.
-      KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2) & Mask;
-      KnownOne = APInt(BitWidth, 0);   // No one bits known.
-    } else {
-      KnownZero = KnownOne = APInt(BitWidth, 0);  // Otherwise, nothing known.
-    }
-    return;
-  }        
   case Instruction::SRem:
     if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
       APInt RA = Rem->getValue();
@@ -923,13 +977,124 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
       assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
     }
     break;
+
+  case Instruction::Alloca:
+  case Instruction::Malloc: {
+    AllocationInst *AI = cast<AllocationInst>(V);
+    unsigned Align = AI->getAlignment();
+    if (Align == 0 && TD) {
+      if (isa<AllocaInst>(AI))
+        Align = TD->getPrefTypeAlignment(AI->getType()->getElementType());
+      else if (isa<MallocInst>(AI)) {
+        // Malloc returns maximally aligned memory.
+        Align = TD->getABITypeAlignment(AI->getType()->getElementType());
+        Align =
+          std::max(Align,
+                   (unsigned)TD->getABITypeAlignment(Type::DoubleTy));
+        Align =
+          std::max(Align,
+                   (unsigned)TD->getABITypeAlignment(Type::Int64Ty));
+      }
+    }
+    
+    if (Align > 0)
+      KnownZero = Mask & APInt::getLowBitsSet(BitWidth,
+                                              CountTrailingZeros_32(Align));
+    break;
+  }
+  case Instruction::GetElementPtr: {
+    // Analyze all of the subscripts of this getelementptr instruction
+    // to determine if we can prove known low zero bits.
+    APInt LocalMask = APInt::getAllOnesValue(BitWidth);
+    APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0);
+    ComputeMaskedBits(I->getOperand(0), LocalMask,
+                      LocalKnownZero, LocalKnownOne, Depth+1);
+    unsigned TrailZ = LocalKnownZero.countTrailingOnes();
+
+    gep_type_iterator GTI = gep_type_begin(I);
+    for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
+      Value *Index = I->getOperand(i);
+      if (const StructType *STy = dyn_cast<StructType>(*GTI)) {
+        // Handle struct member offset arithmetic.
+        if (!TD) return;
+        const StructLayout *SL = TD->getStructLayout(STy);
+        unsigned Idx = cast<ConstantInt>(Index)->getZExtValue();
+        uint64_t Offset = SL->getElementOffset(Idx);
+        TrailZ = std::min(TrailZ,
+                          CountTrailingZeros_64(Offset));
+      } else {
+        // Handle array index arithmetic.
+        const Type *IndexedTy = GTI.getIndexedType();
+        if (!IndexedTy->isSized()) return;
+        unsigned GEPOpiBits = Index->getType()->getPrimitiveSizeInBits();
+        uint64_t TypeSize = TD ? TD->getABITypeSize(IndexedTy) : 1;
+        LocalMask = APInt::getAllOnesValue(GEPOpiBits);
+        LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0);
+        ComputeMaskedBits(Index, LocalMask,
+                          LocalKnownZero, LocalKnownOne, Depth+1);
+        TrailZ = std::min(TrailZ,
+                          CountTrailingZeros_64(TypeSize) +
+                            LocalKnownZero.countTrailingOnes());
+      }
+    }
+    
+    KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ) & Mask;
+    break;
+  }
+  case Instruction::PHI: {
+    PHINode *P = cast<PHINode>(I);
+    // Handle the case of a simple two-predecessor recurrence PHI.
+    // There's a lot more that could theoretically be done here, but
+    // this is sufficient to catch some interesting cases.
+    if (P->getNumIncomingValues() == 2) {
+      for (unsigned i = 0; i != 2; ++i) {
+        Value *L = P->getIncomingValue(i);
+        Value *R = P->getIncomingValue(!i);
+        User *LU = dyn_cast<User>(L);
+        unsigned Opcode = LU ? getOpcode(LU) : (unsigned)Instruction::UserOp1;
+        // Check for operations that have the property that if
+        // both their operands have low zero bits, the result
+        // will have low zero bits.
+        if (Opcode == Instruction::Add ||
+            Opcode == Instruction::Sub ||
+            Opcode == Instruction::And ||
+            Opcode == Instruction::Or ||
+            Opcode == Instruction::Mul) {
+          Value *LL = LU->getOperand(0);
+          Value *LR = LU->getOperand(1);
+          // Find a recurrence.
+          if (LL == I)
+            L = LR;
+          else if (LR == I)
+            L = LL;
+          else
+            break;
+          // Ok, we have a PHI of the form L op= R. Check for low
+          // zero bits.
+          APInt Mask2 = APInt::getAllOnesValue(BitWidth);
+          ComputeMaskedBits(R, Mask2, KnownZero2, KnownOne2, Depth+1);
+          Mask2 = APInt::getLowBitsSet(BitWidth,
+                                       KnownZero2.countTrailingOnes());
+          KnownOne2.clear();
+          KnownZero2.clear();
+          ComputeMaskedBits(L, Mask2, KnownZero2, KnownOne2, Depth+1);
+          KnownZero = Mask &
+                      APInt::getLowBitsSet(BitWidth,
+                                           KnownZero2.countTrailingOnes());
+          break;
+        }
+      }
+    }
+    break;
+  }
   }
 }
 
 /// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero.  We use
 /// this predicate to simplify operations downstream.  Mask is known to be zero
 /// for bits that V cannot have.
-static bool MaskedValueIsZero(Value *V, const APInt& Mask, unsigned Depth = 0) {
+bool InstCombiner::MaskedValueIsZero(Value *V, const APInt& Mask,
+                                     unsigned Depth) {
   APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0);
   ComputeMaskedBits(V, Mask, KnownZero, KnownOne, Depth);
   assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
@@ -6695,8 +6860,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
 ///
 /// This is a truncation operation if Ty is smaller than V->getType(), or an
 /// extension operation if Ty is larger.
-static bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
-                                       unsigned CastOpc, int &NumCastsRemoved) {
+bool InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
+                                              unsigned CastOpc,
+                                              int &NumCastsRemoved) {
   // We can always evaluate constants in another type.
   if (isa<ConstantInt>(V))
     return true;
@@ -8062,94 +8228,83 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   return 0;
 }
 
-/// GetOrEnforceKnownAlignment - If the specified pointer has an alignment that
-/// we can determine, return it, otherwise return 0.  If PrefAlign is specified,
-/// and it is more than the alignment of the ultimate object, see if we can
-/// increase the alignment of the ultimate object, making this check succeed.
-static unsigned GetOrEnforceKnownAlignment(Value *V, TargetData *TD,
-                                           unsigned PrefAlign = 0) {
-  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
-    unsigned Align = GV->getAlignment();
-    if (Align == 0 && TD && GV->getType()->getElementType()->isSized()) 
-      Align = TD->getPrefTypeAlignment(GV->getType()->getElementType());
+/// EnforceKnownAlignment - If the specified pointer points to an object that
+/// we control, modify the object's alignment to PrefAlign. This isn't
+/// often possible though. If alignment is important, a more reliable approach
+/// is to simply align all global variables and allocation instructions to
+/// their preferred alignment from the beginning.
+///
+static unsigned EnforceKnownAlignment(Value *V,
+                                      unsigned Align, unsigned PrefAlign) {
+
+  User *U = dyn_cast<User>(V);
+  if (!U) return Align;
+
+  switch (getOpcode(U)) {
+  default: break;
+  case Instruction::BitCast:
+    return EnforceKnownAlignment(U->getOperand(0), Align, PrefAlign);
+  case Instruction::GetElementPtr: {
+    // If all indexes are zero, it is just the alignment of the base pointer.
+    bool AllZeroOperands = true;
+    for (unsigned i = 1, e = U->getNumOperands(); i != e; ++i)
+      if (!isa<Constant>(U->getOperand(i)) ||
+          !cast<Constant>(U->getOperand(i))->isNullValue()) {
+        AllZeroOperands = false;
+        break;
+      }
+
+    if (AllZeroOperands) {
+      // Treat this like a bitcast.
+      return EnforceKnownAlignment(U->getOperand(0), Align, PrefAlign);
+    }
+    break;
+  }
+  }
 
+  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
     // If there is a large requested alignment and we can, bump up the alignment
     // of the global.
-    if (PrefAlign > Align && GV->hasInitializer()) {
+    if (!GV->isDeclaration()) {
       GV->setAlignment(PrefAlign);
       Align = PrefAlign;
     }
-    return Align;
   } else if (AllocationInst *AI = dyn_cast<AllocationInst>(V)) {
-    unsigned Align = AI->getAlignment();
-    if (Align == 0 && TD) {
-      if (isa<AllocaInst>(AI))
-        Align = TD->getPrefTypeAlignment(AI->getType()->getElementType());
-      else if (isa<MallocInst>(AI)) {
-        // Malloc returns maximally aligned memory.
-        Align = TD->getABITypeAlignment(AI->getType()->getElementType());
-        Align =
-          std::max(Align,
-                   (unsigned)TD->getABITypeAlignment(Type::DoubleTy));
-        Align =
-          std::max(Align,
-                   (unsigned)TD->getABITypeAlignment(Type::Int64Ty));
-      }
-    }
-    
     // If there is a requested alignment and if this is an alloca, round up.  We
     // don't do this for malloc, because some systems can't respect the request.
-    if (PrefAlign > Align && isa<AllocaInst>(AI)) {
+    if (isa<AllocaInst>(AI)) {
       AI->setAlignment(PrefAlign);
       Align = PrefAlign;
     }
-    return Align;
-  } else if (isa<BitCastInst>(V) ||
-             (isa<ConstantExpr>(V) && 
-              cast<ConstantExpr>(V)->getOpcode() == Instruction::BitCast)) {
-    return GetOrEnforceKnownAlignment(cast<User>(V)->getOperand(0),
-                                      TD, PrefAlign);
-  } else if (User *GEPI = dyn_castGetElementPtr(V)) {
-    // If all indexes are zero, it is just the alignment of the base pointer.
-    bool AllZeroOperands = true;
-    for (unsigned i = 1, e = GEPI->getNumOperands(); i != e; ++i)
-      if (!isa<Constant>(GEPI->getOperand(i)) ||
-          !cast<Constant>(GEPI->getOperand(i))->isNullValue()) {
-        AllZeroOperands = false;
-        break;
-      }
-
-    if (AllZeroOperands) {
-      // Treat this like a bitcast.
-      return GetOrEnforceKnownAlignment(GEPI->getOperand(0), TD, PrefAlign);
-    }
+  }
 
-    unsigned BaseAlignment = GetOrEnforceKnownAlignment(GEPI->getOperand(0),TD);
-    if (BaseAlignment == 0) return 0;
+  return Align;
+}
 
-    // Otherwise, if the base alignment is >= the alignment we expect for the
-    // base pointer type, then we know that the resultant pointer is aligned at
-    // least as much as its type requires.
-    if (!TD) return 0;
+/// GetOrEnforceKnownAlignment - If the specified pointer has an alignment that
+/// we can determine, return it, otherwise return 0.  If PrefAlign is specified,
+/// and it is more than the alignment of the ultimate object, see if we can
+/// increase the alignment of the ultimate object, making this check succeed.
+unsigned InstCombiner::GetOrEnforceKnownAlignment(Value *V,
+                                                  unsigned PrefAlign) {
+  unsigned BitWidth = TD ? TD->getTypeSizeInBits(V->getType()) :
+                      sizeof(PrefAlign) * CHAR_BIT;
+  APInt Mask = APInt::getAllOnesValue(BitWidth);
+  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+  ComputeMaskedBits(V, Mask, KnownZero, KnownOne);
+  unsigned TrailZ = KnownZero.countTrailingOnes();
+  unsigned Align = 1u << std::min(BitWidth - 1, TrailZ);
 
-    const Type *BasePtrTy = GEPI->getOperand(0)->getType();
-    const PointerType *PtrTy = cast<PointerType>(BasePtrTy);
-    unsigned Align = TD->getABITypeAlignment(PtrTy->getElementType());
-    if (Align <= BaseAlignment) {
-      const Type *GEPTy = GEPI->getType();
-      const PointerType *GEPPtrTy = cast<PointerType>(GEPTy);
-      Align = std::min(Align, (unsigned)
-                       TD->getABITypeAlignment(GEPPtrTy->getElementType()));
-      return Align;
-    }
-    return 0;
-  }
-  return 0;
+  if (PrefAlign > Align)
+    Align = EnforceKnownAlignment(V, Align, PrefAlign);
+  
+    // We don't need to make any adjustment.
+  return Align;
 }
 
 Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
-  unsigned DstAlign = GetOrEnforceKnownAlignment(MI->getOperand(1), TD);
-  unsigned SrcAlign = GetOrEnforceKnownAlignment(MI->getOperand(2), TD);
+  unsigned DstAlign = GetOrEnforceKnownAlignment(MI->getOperand(1));
+  unsigned SrcAlign = GetOrEnforceKnownAlignment(MI->getOperand(2));
   unsigned MinAlign = std::min(DstAlign, SrcAlign);
   unsigned CopyAlign = MI->getAlignment()->getZExtValue();
 
@@ -8270,7 +8425,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       if (Instruction *I = SimplifyMemTransfer(MI))
         return I;
     } else if (isa<MemSetInst>(MI)) {
-      unsigned Alignment = GetOrEnforceKnownAlignment(MI->getDest(), TD);
+      unsigned Alignment = GetOrEnforceKnownAlignment(MI->getDest());
       if (MI->getAlignment()->getZExtValue() < Alignment) {
         MI->setAlignment(ConstantInt::get(Type::Int32Ty, Alignment));
         Changed = true;
@@ -8288,7 +8443,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     case Intrinsic::x86_sse2_loadu_dq:
       // Turn PPC lvx     -> load if the pointer is known aligned.
       // Turn X86 loadups -> load if the pointer is known aligned.
-      if (GetOrEnforceKnownAlignment(II->getOperand(1), TD, 16) >= 16) {
+      if (GetOrEnforceKnownAlignment(II->getOperand(1), 16) >= 16) {
         Value *Ptr = InsertBitCastBefore(II->getOperand(1),
                                          PointerType::getUnqual(II->getType()),
                                          CI);
@@ -8298,7 +8453,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     case Intrinsic::ppc_altivec_stvx:
     case Intrinsic::ppc_altivec_stvxl:
       // Turn stvx -> store if the pointer is known aligned.
-      if (GetOrEnforceKnownAlignment(II->getOperand(2), TD, 16) >= 16) {
+      if (GetOrEnforceKnownAlignment(II->getOperand(2), 16) >= 16) {
         const Type *OpPtrTy = 
           PointerType::getUnqual(II->getOperand(1)->getType());
         Value *Ptr = InsertBitCastBefore(II->getOperand(2), OpPtrTy, CI);
@@ -8310,7 +8465,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     case Intrinsic::x86_sse2_storeu_dq:
     case Intrinsic::x86_sse2_storel_dq:
       // Turn X86 storeu -> store if the pointer is known aligned.
-      if (GetOrEnforceKnownAlignment(II->getOperand(1), TD, 16) >= 16) {
+      if (GetOrEnforceKnownAlignment(II->getOperand(1), 16) >= 16) {
         const Type *OpPtrTy = 
           PointerType::getUnqual(II->getOperand(2)->getType());
         Value *Ptr = InsertBitCastBefore(II->getOperand(1), OpPtrTy, CI);
@@ -9762,8 +9917,10 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
   Value *Op = LI.getOperand(0);
 
   // Attempt to improve the alignment.
-  unsigned KnownAlign = GetOrEnforceKnownAlignment(Op, TD);
-  if (KnownAlign > LI.getAlignment())
+  unsigned KnownAlign = GetOrEnforceKnownAlignment(Op);
+  if (KnownAlign >
+      (LI.getAlignment() == 0 ? TD->getABITypeAlignment(LI.getType()) :
+                                LI.getAlignment()))
     LI.setAlignment(KnownAlign);
 
   // load (cast X) --> cast (load X) iff safe
@@ -9980,8 +10137,10 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
   }
 
   // Attempt to improve the alignment.
-  unsigned KnownAlign = GetOrEnforceKnownAlignment(Ptr, TD);
-  if (KnownAlign > SI.getAlignment())
+  unsigned KnownAlign = GetOrEnforceKnownAlignment(Ptr);
+  if (KnownAlign >
+      (SI.getAlignment() == 0 ? TD->getABITypeAlignment(Val->getType()) :
+                                SI.getAlignment()))
     SI.setAlignment(KnownAlign);
 
   // Do really simple DSE, to catch cases where there are several consequtive
diff --git a/test/Transforms/InstCombine/align-2d-gep.ll b/test/Transforms/InstCombine/align-2d-gep.ll
new file mode 100644 (file)
index 0000000..c826e31
--- /dev/null
@@ -0,0 +1,43 @@
+; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep {align 16} | count 1
+
+; A multi-dimensional array in a nested loop doing vector stores that
+; aren't yet aligned. Instcombine can understand the addressing in the
+; Nice case to prove 16 byte alignment. In the Awkward case, the inner
+; array dimension is not even, so the stores to it won't always be
+; aligned. Instcombine should prove alignment in exactly one of the two
+; stores.
+
+@Nice    = global [1001 x [20000 x double]] zeroinitializer, align 32
+@Awkward = global [1001 x [20001 x double]] zeroinitializer, align 32
+
+define void @foo() nounwind  {
+entry:
+  br label %bb7.outer
+
+bb7.outer:
+  %i = phi i64 [ 0, %entry ], [ %indvar.next26, %bb11 ]
+  br label %bb1
+
+bb1:
+  %j = phi i64 [ 0, %bb7.outer ], [ %indvar.next, %bb1 ]
+
+  %t4 = getelementptr [1001 x [20000 x double]]* @Nice, i64 0, i64 %i, i64 %j
+  %q = bitcast double* %t4 to <2 x double>*
+  store <2 x double><double 0.0, double 0.0>, <2 x double>* %q, align 8
+
+  %s4 = getelementptr [1001 x [20001 x double]]* @Awkward, i64 0, i64 %i, i64 %j
+  %r = bitcast double* %s4 to <2 x double>*
+  store <2 x double><double 0.0, double 0.0>, <2 x double>* %r, align 8
+
+  %indvar.next = add i64 %j, 2
+  %exitcond = icmp eq i64 %indvar.next, 557
+  br i1 %exitcond, label %bb11, label %bb1
+
+bb11:
+  %indvar.next26 = add i64 %i, 1
+  %exitcond27 = icmp eq i64 %indvar.next26, 991
+  br i1 %exitcond27, label %return.split, label %bb7.outer
+
+return.split:
+  ret void
+}
diff --git a/test/Transforms/InstCombine/align-addr.ll b/test/Transforms/InstCombine/align-addr.ll
new file mode 100644 (file)
index 0000000..a05c513
--- /dev/null
@@ -0,0 +1,30 @@
+; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep {align 16} | count 1
+
+; Instcombine should be able to prove vector alignment in the
+; presence of a few mild address computation tricks.
+
+define void @foo(i8* %b, i64 %n, i64 %u, i64 %y) nounwind  {
+entry:
+  %c = ptrtoint i8* %b to i64
+  %d = and i64 %c, -16
+  %e = inttoptr i64 %d to double*
+  %v = mul i64 %u, 2
+  %z = and i64 %y, -2
+  %t1421 = icmp eq i64 %n, 0
+  br i1 %t1421, label %return, label %bb
+
+bb:
+  %i = phi i64 [ %indvar.next, %bb ], [ 20, %entry ]
+  %j = mul i64 %i, %v
+  %h = add i64 %j, %z
+  %t8 = getelementptr double* %e, i64 %h
+  %p = bitcast double* %t8 to <2 x double>*
+  store <2 x double><double 0.0, double 0.0>, <2 x double>* %p, align 8
+  %indvar.next = add i64 %i, 1
+  %exitcond = icmp eq i64 %indvar.next, %n
+  br i1 %exitcond, label %return, label %bb
+
+return:
+  ret void
+}
+