Allow min/max detection to see through casts.
authorJames Molloy <james.molloy@arm.com>
Fri, 15 May 2015 16:04:50 +0000 (16:04 +0000)
committerJames Molloy <james.molloy@arm.com>
Fri, 15 May 2015 16:04:50 +0000 (16:04 +0000)
This teaches the min/max idiom detector in ValueTracking to see through
casts such as SExt/ZExt/Trunc. SCEV can already do this, so we're bringing
non-SCEV analyses up to the same level.

The returned LHS/RHS will not match the type of the original SelectInst
any more, so a CastOp is returned too to inform the caller how to
convert to the SelectInst's type.

No in-tree users yet; this will be used by InstCombine in a followup.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@237452 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ValueTracking.h
lib/Analysis/ValueTracking.cpp

index 32b3805b229b8a65665de5ab0f20adf72c0df2ac..204d5f06d4a2a7e6fb401953af8066020a147ff7 100644 (file)
@@ -16,6 +16,7 @@
 #define LLVM_ANALYSIS_VALUETRACKING_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/Support/DataTypes.h"
 
 namespace llvm {
@@ -280,7 +281,21 @@ namespace llvm {
   };
   /// Pattern match integer [SU]MIN, [SU]MAX and ABS idioms, returning the kind
   /// and providing the out parameter results if we successfully match.
-  SelectPatternFlavor matchSelectPattern(Value *V, Value *&LHS, Value *&RHS);
+  ///
+  /// If CastOp is not nullptr, also match MIN/MAX idioms where the type does
+  /// not match that of the original select. If this is the case, the cast
+  /// operation (one of Trunc,SExt,Zext) that must be done to transform the
+  /// type of LHS and RHS into the type of V is returned in CastOp.
+  ///
+  /// For example:
+  ///   %1 = icmp slt i32 %a, i32 4
+  ///   %2 = sext i32 %a to i64
+  ///   %3 = select i1 %1, i64 %2, i64 4
+  ///
+  /// -> LHS = %a, RHS = i32 4, *CastOp = Instruction::SExt
+  ///
+  SelectPatternFlavor matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
+                                         Instruction::CastOps *CastOp = nullptr);
 
 } // end namespace llvm
 
index 21d9e2e7f69148183323f418631ec862faaa4f20..c70a8087d868419e7c8a2e15e65fe721bf26d8e3 100644 (file)
@@ -3220,20 +3220,10 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(Value *LHS, Value *RHS,
   return OverflowResult::MayOverflow;
 }
 
-SelectPatternFlavor llvm::matchSelectPattern(Value *V,
-                                             Value *&LHS, Value *&RHS) {
-  SelectInst *SI = dyn_cast<SelectInst>(V);
-  if (!SI) return SPF_UNKNOWN;
-
-  ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition());
-  if (!ICI) return SPF_UNKNOWN;
-
-  ICmpInst::Predicate Pred = ICI->getPredicate();
-  Value *CmpLHS = ICI->getOperand(0);
-  Value *CmpRHS = ICI->getOperand(1);
-  Value *TrueVal = SI->getTrueValue();
-  Value *FalseVal = SI->getFalseValue();
-
+static SelectPatternFlavor matchSelectPattern(ICmpInst::Predicate Pred,
+                                              Value *CmpLHS, Value *CmpRHS,
+                                              Value *TrueVal, Value *FalseVal,
+                                              Value *&LHS, Value *&RHS) {
   LHS = CmpLHS;
   RHS = CmpRHS;
 
@@ -3300,3 +3290,55 @@ SelectPatternFlavor llvm::matchSelectPattern(Value *V,
 
   return SPF_UNKNOWN;
 }
+
+static Constant *lookThroughCast(ICmpInst *CmpI, Value *V1, Value *V2,
+                                 Instruction::CastOps *CastOp) {
+  CastInst *CI = dyn_cast<CastInst>(V1);
+  Constant *C = dyn_cast<Constant>(V2);
+  if (!CI || !C)
+    return nullptr;
+  *CastOp = CI->getOpcode();
+
+  if ((isa<SExtInst>(CI) && CmpI->isSigned()) ||
+      (isa<ZExtInst>(CI) && CmpI->isUnsigned()))
+    return ConstantExpr::getTrunc(C, CI->getSrcTy());
+
+  if (isa<TruncInst>(CI))
+    return ConstantExpr::getIntegerCast(C, CI->getSrcTy(), CmpI->isSigned());
+
+  return nullptr;
+}
+
+SelectPatternFlavor llvm::matchSelectPattern(Value *V,
+                                             Value *&LHS, Value *&RHS,
+                                             Instruction::CastOps *CastOp) {
+  SelectInst *SI = dyn_cast<SelectInst>(V);
+  if (!SI) return SPF_UNKNOWN;
+
+  ICmpInst *CmpI = dyn_cast<ICmpInst>(SI->getCondition());
+  if (!CmpI) return SPF_UNKNOWN;
+
+  ICmpInst::Predicate Pred = CmpI->getPredicate();
+  Value *CmpLHS = CmpI->getOperand(0);
+  Value *CmpRHS = CmpI->getOperand(1);
+  Value *TrueVal = SI->getTrueValue();
+  Value *FalseVal = SI->getFalseValue();
+
+  // Bail out early.
+  if (CmpI->isEquality())
+    return SPF_UNKNOWN;
+
+  // Deal with type mismatches.
+  if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
+    if (Constant *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp))
+      return ::matchSelectPattern(Pred, CmpLHS, CmpRHS,
+                                  cast<CastInst>(TrueVal)->getOperand(0), C,
+                                  LHS, RHS);
+    if (Constant *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp))
+      return ::matchSelectPattern(Pred, CmpLHS, CmpRHS,
+                                  C, cast<CastInst>(FalseVal)->getOperand(0),
+                                  LHS, RHS);
+  }
+  return ::matchSelectPattern(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal,
+                              LHS, RHS);
+}