recognize an unsigned add with overflow idiom into uadd.
authorChris Lattner <sabre@nondot.org>
Sun, 19 Dec 2010 19:37:52 +0000 (19:37 +0000)
committerChris Lattner <sabre@nondot.org>
Sun, 19 Dec 2010 19:37:52 +0000 (19:37 +0000)
This resolves a README entry and technically resolves PR4916,
but we still get poor code for the testcase in that PR because
GVN isn't CSE'ing uadd with add, filed as PR8817.

Previously we got:

_test7:                                 ## @test7
addq %rsi, %rdi
cmpq %rdi, %rsi
movl $42, %eax
cmovaq %rsi, %rax
ret

Now we get:

_test7:                                 ## @test7
addq %rsi, %rdi
movl $42, %eax
cmovbq %rsi, %rax
ret

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@122182 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/README.txt
lib/Transforms/InstCombine/InstCombineCompares.cpp
test/Transforms/InstCombine/overflow.ll

index 5c97b8984c1d6e9e172445cca7099c3a5f57ba89..abb8ee6f4999138124365f5cf0a0a38bf9e0156e 100644 (file)
@@ -74,26 +74,7 @@ This has a number of uses:
 //===---------------------------------------------------------------------===//
 
 We should recognized various "overflow detection" idioms and translate them into
-llvm.uadd.with.overflow and similar intrinsics.  For example, we compile this:
-
-size_t add(size_t a,size_t b) {
- if (a+b<a)
-   exit(0);
- return a+b;
-}
-
-into:
-
-       addq    %rdi, %rbx
-       cmpq    %rdi, %rbx
-       jae     LBB0_2
-
-when it would be better to generate:
-
-       addq    %rdi, %rbx
-       jno     LBB0_2
-
-Apparently some version of GCC knows this.  Here is a multiply idiom:
+llvm.uadd.with.overflow and similar intrinsics.  Here is a multiply idiom:
 
 unsigned int mul(unsigned int a,unsigned int b) {
  if ((unsigned long long)a*b>0xffffffff)
index 86a0ed2652dbd856116d2c3ec2dd9c7da61c676f..6c6d26c81c70dcf74d0d902657d8a25a5ff02d5f 100644 (file)
@@ -1648,7 +1648,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   
   // Put the new code above the original add, in case there are any uses of the
   // add between the add and the compare.
-  Builder->SetInsertPoint(OrigAdd->getParent(), BasicBlock::iterator(OrigAdd));
+  Builder->SetInsertPoint(OrigAdd);
   
   Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc");
   Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc");
@@ -1664,6 +1664,35 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
+static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV,
+                                     InstCombiner &IC) {
+  // Don't bother doing this transformation for pointers, don't do it for
+  // vectors.
+  if (!isa<IntegerType>(OrigAddV->getType())) return 0;
+  
+  // If the add is a constant expr, then we don't bother transforming it.
+  Instruction *OrigAdd = dyn_cast<Instruction>(OrigAddV);
+  if (OrigAdd == 0) return 0;
+  
+  Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1);
+  
+  // Put the new code above the original add, in case there are any uses of the
+  // add between the add and the compare.
+  InstCombiner::BuilderTy *Builder = IC.Builder;
+  Builder->SetInsertPoint(OrigAdd);
+
+  Module *M = I.getParent()->getParent()->getParent();
+  const Type *Ty = LHS->getType();
+  Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, &Ty,1);
+  CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd");
+  Value *Add = Builder->CreateExtractValue(Call, 0);
+
+  IC.ReplaceInstUsesWith(*OrigAdd, Add);
+
+  // The original icmp gets replaced with the overflow value.
+  return ExtractValueInst::Create(Call, 1, "uadd.overflow");
+}
+
 
 Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
@@ -1726,11 +1755,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   }
 
   unsigned BitWidth = 0;
-  if (TD)
-    BitWidth = TD->getTypeSizeInBits(Ty->getScalarType());
-  else if (Ty->isIntOrIntVectorTy())
+  if (Ty->isIntOrIntVectorTy())
     BitWidth = Ty->getScalarSizeInBits();
-
+  else if (TD)  // Pointers require TD info to get their size.
+    BitWidth = TD->getTypeSizeInBits(Ty->getScalarType());
+  
   bool isSignBit = false;
 
   // See if we are doing a comparison with a constant.
@@ -2225,6 +2254,22 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     if (match(Op0, m_Not(m_Value(A))) &&
         match(Op1, m_Not(m_Value(B))))
       return new ICmpInst(I.getPredicate(), B, A);
+
+    // (a+b) <u a  --> llvm.uadd.with.overflow.
+    // (a+b) <u b  --> llvm.uadd.with.overflow.
+    if (I.getPredicate() == ICmpInst::ICMP_ULT &&
+        match(Op0, m_Add(m_Value(A), m_Value(B))) && 
+        (Op1 == A || Op1 == B))
+      if (Instruction *R = ProcessUAddIdiom(I, Op0, *this))
+        return R;
+                                 
+    // a >u (a+b)  --> llvm.uadd.with.overflow.
+    // b >u (a+b)  --> llvm.uadd.with.overflow.
+    if (I.getPredicate() == ICmpInst::ICMP_UGT &&
+        match(Op1, m_Add(m_Value(A), m_Value(B))) &&
+        (Op0 == A || Op0 == B))
+      if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
+        return R;
   }
   
   if (I.isEquality()) {
index 6a53d27749a7e8fa57d1e90200b7ad4e97c710d8..9123283988de965afd2fe9fd73b9c894d2f23b39 100644 (file)
@@ -97,3 +97,37 @@ if.end:                                           ; preds = %entry
 ; CHECK: ret i8
 }
 
+; CHECK: @test5
+; CHECK: llvm.uadd.with.overflow
+; CHECK: ret i64
+define i64 @test5(i64 %a, i64 %b) nounwind ssp {
+entry:
+  %add = add i64 %b, %a
+  %cmp = icmp ult i64 %add, %a
+  %Q = select i1 %cmp, i64 %b, i64 42
+  ret i64 %Q
+}
+
+; CHECK: @test6
+; CHECK: llvm.uadd.with.overflow
+; CHECK: ret i64
+define i64 @test6(i64 %a, i64 %b) nounwind ssp {
+entry:
+  %add = add i64 %b, %a
+  %cmp = icmp ult i64 %add, %b
+  %Q = select i1 %cmp, i64 %b, i64 42
+  ret i64 %Q
+}
+
+; CHECK: @test7
+; CHECK: llvm.uadd.with.overflow
+; CHECK: ret i64
+define i64 @test7(i64 %a, i64 %b) nounwind ssp {
+entry:
+  %add = add i64 %b, %a
+  %cmp = icmp ugt i64 %b, %add
+  %Q = select i1 %cmp, i64 %b, i64 42
+  ret i64 %Q
+}
+
+