Teach instsimplify how to constant fold pointer differences.
authorChandler Carruth <chandlerc@gmail.com>
Mon, 12 Mar 2012 11:19:31 +0000 (11:19 +0000)
committerChandler Carruth <chandlerc@gmail.com>
Mon, 12 Mar 2012 11:19:31 +0000 (11:19 +0000)
Typically instcombine has handled this, but pointer differences show up
in several contexts where we would like to get constant folding, and
cannot afford to run instcombine. Specifically, I'm working on improving
the constant folding of arguments used in inline cost analysis with
instsimplify.

Doing this in instsimplify implies some algorithm changes. We have to
handle multiple layers of all-constant GEPs because instsimplify cannot
fold them into a single GEP the way instcombine can. Also, we're only
interested in all-constant GEPs. The result is that this doesn't really
replace the instcombine logic, it's just complimentary and focused on
constant folding.

Reviewed on IRC by Benjamin Kramer.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@152555 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Analysis/InstructionSimplify.cpp
test/Transforms/InstSimplify/ptr_diff.ll [new file with mode: 0644]

index 0dd0d6ed3a1e0bca8ba95919bd55e03c469de938..cd97dfb197d965aab2f8a4a19df898174893a1a7 100644 (file)
@@ -18,6 +18,7 @@
 //===----------------------------------------------------------------------===//
 
 #define DEBUG_TYPE "instsimplify"
+#include "llvm/GlobalAlias.h"
 #include "llvm/Operator.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -26,6 +27,7 @@
 #include "llvm/Analysis/Dominators.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Support/ConstantRange.h"
+#include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/PatternMatch.h"
 #include "llvm/Support/ValueHandle.h"
 #include "llvm/Target/TargetData.h"
@@ -665,6 +667,112 @@ Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, TD, TLI, DT, RecursionLimit);
 }
 
+/// \brief Compute the constant integer offset a GEP represents.
+///
+/// Given a getelementptr instruction/constantexpr, form a constant expression
+/// which computes the offset from the base pointer (without adding in the base
+/// pointer).
+static Constant *computeGEPOffset(const TargetData &TD, GEPOperator *GEP) {
+  Type *IntPtrTy = TD.getIntPtrType(GEP->getContext());
+  Constant *Result = Constant::getNullValue(IntPtrTy);
+
+  // If the GEP is inbounds, we know that none of the addressing operations will
+  // overflow in an unsigned sense.
+  bool IsInBounds = GEP->isInBounds();
+
+  // Build a mask for high order bits.
+  unsigned IntPtrWidth = TD.getPointerSizeInBits();
+  uint64_t PtrSizeMask = ~0ULL >> (64-IntPtrWidth);
+
+  gep_type_iterator GTI = gep_type_begin(GEP);
+  for (User::op_iterator I = GEP->op_begin() + 1, E = GEP->op_end(); I != E;
+       ++I, ++GTI) {
+    ConstantInt *OpC = dyn_cast<ConstantInt>(*I);
+    if (!OpC) return 0;
+    if (OpC->isZero()) continue;
+
+    uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType()) & PtrSizeMask;
+
+    // Handle a struct index, which adds its field offset to the pointer.
+    if (StructType *STy = dyn_cast<StructType>(*GTI)) {
+      Size = TD.getStructLayout(STy)->getElementOffset(OpC->getZExtValue());
+
+      if (Size)
+        Result = ConstantExpr::getAdd(Result, ConstantInt::get(IntPtrTy, Size));
+      continue;
+    }
+
+    Constant *Scale = ConstantInt::get(IntPtrTy, Size);
+    Constant *OC = ConstantExpr::getIntegerCast(OpC, IntPtrTy, true /*SExt*/);
+    Scale = ConstantExpr::getMul(OC, Scale, IsInBounds/*NUW*/);
+    Result = ConstantExpr::getAdd(Result, Scale);
+  }
+  return Result;
+}
+
+/// \brief Compute the base pointer and cumulative constant offsets for V.
+///
+/// This strips all constant offsets off of V, leaving it the base pointer, and
+/// accumulates the total constant offset applied in the returned constant. It
+/// returns 0 if V is not a pointer, and returns the constant '0' if there are
+/// no constant offsets applied.
+static Constant *stripAndComputeConstantOffsets(const TargetData &TD,
+                                                Value *&V) {
+  if (!V->getType()->isPointerTy())
+    return 0;
+
+  Type *IntPtrTy = TD.getIntPtrType(V->getContext());
+  Constant *Result = Constant::getNullValue(IntPtrTy);
+
+  // Even though we don't look through PHI nodes, we could be called on an
+  // instruction in an unreachable block, which may be on a cycle.
+  SmallPtrSet<Value *, 4> Visited;
+  Visited.insert(V);
+  do {
+    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+      Constant *Offset = computeGEPOffset(TD, GEP);
+      if (!Offset)
+        break;
+      Result = ConstantExpr::getAdd(Result, Offset);
+      V = GEP->getPointerOperand();
+    } else if (Operator::getOpcode(V) == Instruction::BitCast) {
+      V = cast<Operator>(V)->getOperand(0);
+    } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+      if (GA->mayBeOverridden())
+        break;
+      V = GA->getAliasee();
+    } else {
+      break;
+    }
+    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
+  } while (Visited.insert(V));
+
+  return Result;
+}
+
+/// \brief Compute the constant difference between two pointer values.
+/// If the difference is not a constant, returns zero.
+static Constant *computePointerDifference(const TargetData &TD,
+                                          Value *LHS, Value *RHS) {
+  Constant *LHSOffset = stripAndComputeConstantOffsets(TD, LHS);
+  if (!LHSOffset)
+    return 0;
+  Constant *RHSOffset = stripAndComputeConstantOffsets(TD, RHS);
+  if (!RHSOffset)
+    return 0;
+
+  // If LHS and RHS are not related via constant offsets to the same base
+  // value, there is nothing we can do here.
+  if (LHS != RHS)
+    return 0;
+
+  // Otherwise, the difference of LHS - RHS can be computed as:
+  //    LHS - RHS
+  //  = (LHSOffset + Base) - (RHSOffset + Base)
+  //  = LHSOffset - RHSOffset
+  return ConstantExpr::getSub(LHSOffset, RHSOffset);
+}
+
 /// SimplifySubInst - Given operands for a Sub, see if we can
 /// fold the result.  If not, this returns null.
 static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
@@ -699,6 +807,20 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
       match(Op0, m_Shl(m_Specific(Op1), m_One())))
     return Op1;
 
+  if (TD) {
+    Value *LHSOp, *RHSOp;
+    if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
+        match(Op1, m_PtrToInt(m_Value(RHSOp))))
+      if (Constant *Result = computePointerDifference(*TD, LHSOp, RHSOp))
+        return ConstantExpr::getIntegerCast(Result, Op0->getType(), true);
+
+    // trunc(p)-trunc(q) -> trunc(p-q)
+    if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
+        match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
+      if (Constant *Result = computePointerDifference(*TD, LHSOp, RHSOp))
+        return ConstantExpr::getIntegerCast(Result, Op0->getType(), true);
+  }
+
   // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies.
   // For example, (X + Y) - Y -> X; (Y + X) - Y -> X
   Value *Y = 0, *Z = Op1;
diff --git a/test/Transforms/InstSimplify/ptr_diff.ll b/test/Transforms/InstSimplify/ptr_diff.ll
new file mode 100644 (file)
index 0000000..013964c
--- /dev/null
@@ -0,0 +1,33 @@
+; RUN: opt < %s -instsimplify -S | FileCheck %s
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define i64 @ptrdiff1(i8* %ptr) {
+; CHECK: @ptrdiff1
+; CHECK-NEXT: ret i64 42
+
+  %first = getelementptr i8* %ptr, i32 0
+  %last = getelementptr i8* %ptr, i32 42
+  %first.int = ptrtoint i8* %first to i64
+  %last.int = ptrtoint i8* %last to i64
+  %diff = sub i64 %last.int, %first.int
+  ret i64 %diff
+}
+
+define i64 @ptrdiff2(i8* %ptr) {
+; CHECK: @ptrdiff2
+; CHECK-NEXT: ret i64 42
+
+  %first1 = getelementptr i8* %ptr, i32 0
+  %first2 = getelementptr i8* %first1, i32 1
+  %first3 = getelementptr i8* %first2, i32 2
+  %first4 = getelementptr i8* %first3, i32 4
+  %last1 = getelementptr i8* %first2, i32 48
+  %last2 = getelementptr i8* %last1, i32 8
+  %last3 = getelementptr i8* %last2, i32 -4
+  %last4 = getelementptr i8* %last3, i32 -4
+  %first.int = ptrtoint i8* %first4 to i64
+  %last.int = ptrtoint i8* %last4 to i64
+  %diff = sub i64 %last.int, %first.int
+  ret i64 %diff
+}