Distribute sext/zext to the operands of and/or/xor
authorJingyue Wu <jingyue@google.com>
Tue, 27 May 2014 18:00:00 +0000 (18:00 +0000)
committerJingyue Wu <jingyue@google.com>
Tue, 27 May 2014 18:00:00 +0000 (18:00 +0000)
This is an enhancement to SeparateConstOffsetFromGEP. With this patch, we can
extract a constant offset from "s/zext and/or/xor A, B".

Added a new test @ext_or to verify this enhancement.

Refactoring the code, I also extracted some common logic to function
Distributable.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@209670 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp
test/Transforms/SeparateConstOffsetFromGEP/NVPTX/split-gep.ll

index ac3e7c4..b8529e1 100644 (file)
@@ -165,6 +165,10 @@ class ConstantOffsetExtractor {
   void ComputeKnownBits(Value *V, APInt &KnownOne, APInt &KnownZero) const;
   /// Finds the first use of Used in U. Returns -1 if not found.
   static unsigned FindFirstUse(User *U, Value *Used);
+  /// Returns whether OPC (sext or zext) can be distributed to the operands of
+  /// BO. e.g., sext can be distributed to the operands of an "add nsw" because
+  /// sext (add nsw a, b) == add nsw (sext a), (sext b).
+  static bool Distributable(unsigned OPC, BinaryOperator *BO);
 
   /// The path from the constant offset to the old GEP index. e.g., if the GEP
   /// index is "a * b + (c + 5)". After running function find, UserChain[0] will
@@ -223,6 +227,25 @@ FunctionPass *llvm::createSeparateConstOffsetFromGEPPass() {
   return new SeparateConstOffsetFromGEP();
 }
 
+bool ConstantOffsetExtractor::Distributable(unsigned OPC, BinaryOperator *BO) {
+  assert(OPC == Instruction::SExt || OPC == Instruction::ZExt);
+
+  // sext (add/sub nsw A, B) == add/sub nsw (sext A), (sext B)
+  // zext (add/sub nuw A, B) == add/sub nuw (zext A), (zext B)
+  if (BO->getOpcode() == Instruction::Add ||
+      BO->getOpcode() == Instruction::Sub) {
+    return (OPC == Instruction::SExt && BO->hasNoSignedWrap()) ||
+           (OPC == Instruction::ZExt && BO->hasNoUnsignedWrap());
+  }
+
+  // sext/zext (and/or/xor A, B) == and/or/xor (sext/zext A), (sext/zext B)
+  // -instcombine also leverages this invariant to do the reverse
+  // transformation to reduce integer casts.
+  return BO->getOpcode() == Instruction::And ||
+         BO->getOpcode() == Instruction::Or ||
+         BO->getOpcode() == Instruction::Xor;
+}
+
 int64_t ConstantOffsetExtractor::findInEitherOperand(User *U, bool IsSub) {
   assert(U->getNumOperands() == 2);
   int64_t ConstantOffset = find(U->getOperand(0));
@@ -273,21 +296,14 @@ int64_t ConstantOffsetExtractor::find(Value *V) {
           ConstantOffset = findInEitherOperand(U, false);
         break;
       }
-      case Instruction::SExt: {
-        // For safety, we trace into sext only when its operand is marked
-        // "nsw" because xxx.nsw guarantees no signed wrap. e.g., we can safely
-        // transform "sext (add nsw a, 5)" into "add nsw (sext a), 5".
-        if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) {
-          if (BO->hasNoSignedWrap())
-            ConstantOffset = find(U->getOperand(0));
-        }
-        break;
-      }
+      case Instruction::SExt:
       case Instruction::ZExt: {
-        // Similarly, we trace into zext only when its operand is marked with
-        // "nuw" because zext (add nuw a, b) == add nuw (zext a), (zext b).
+        // We trace into sext/zext if the operator can be distributed to its
+        // operand. e.g., we can transform into "sext (add nsw a, 5)" and
+        // extract constant 5, because
+        //   sext (add nsw a, 5) == add nsw (sext a), 5
         if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) {
-          if (BO->hasNoUnsignedWrap())
+          if (Distributable(O->getOpcode(), BO))
             ConstantOffset = find(U->getOperand(0));
         }
         break;
index 320af5f..42136d2 100644 (file)
@@ -57,6 +57,25 @@ define float* @ext_add_no_overflow(i64 %a, i32 %b, i64 %c, i32 %d) {
 ; CHECK: [[BASE_PTR:%[0-9]+]] = getelementptr [32 x [32 x float]]* @float_2d_array, i64 0, i64 %{{[0-9]+}}, i64 %{{[0-9]+}}
 ; CHECK: getelementptr float* [[BASE_PTR]], i64 33
 
+; Similar to @ext_add_no_overflow, we should be able to trace into sext/zext if
+; its operand is an "or" instruction.
+define float* @ext_or(i64 %a, i32 %b) {
+entry:
+  %b1 = shl i32 %b, 2
+  %b2 = or i32 %b1, 1
+  %b3 = or i32 %b1, 2
+  %b2.ext = sext i32 %b2 to i64
+  %b3.ext = sext i32 %b3 to i64
+  %i = add i64 %a, %b2.ext
+  %j = add i64 %a, %b3.ext
+  %p = getelementptr inbounds [32 x [32 x float]]* @float_2d_array, i64 0, i64 %i, i64 %j
+  ret float* %p
+}
+; CHECK-LABEL: @ext_or
+; CHECK: [[BASE_PTR:%[0-9]+]] = getelementptr [32 x [32 x float]]* @float_2d_array, i64 0, i64 %{{[0-9]+}}, i64 %{{[0-9]+}}
+; CHECK: [[BASE_INT:%[0-9]+]] = ptrtoint float* [[BASE_PTR]] to i64
+; CHECK: add i64 [[BASE_INT]], 136
+
 ; We should treat "or" with no common bits (%k) as "add", and leave "or" with
 ; potentially common bits (%l) as is.
 define float* @or(i64 %i) {