X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=include%2Fllvm%2FSupport%2FPatternMatch.h;h=97739b08694d7ad6cc58b058f22eca8c0df423ce;hb=bef2236283c333f17613b2ea4904878228fedb6e;hp=1cc59952727aca2fefd13f4fc6269bcbd2f122f2;hpb=b9f08a00af689eb54d25f4cec9a71899d1984f56;p=oota-llvm.git diff --git a/include/llvm/Support/PatternMatch.h b/include/llvm/Support/PatternMatch.h index 1cc59952727..97739b08694 100644 --- a/include/llvm/Support/PatternMatch.h +++ b/include/llvm/Support/PatternMatch.h @@ -29,8 +29,11 @@ #ifndef LLVM_SUPPORT_PATTERNMATCH_H #define LLVM_SUPPORT_PATTERNMATCH_H -#include "llvm/Constants.h" -#include "llvm/Instructions.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/CallSite.h" namespace llvm { namespace PatternMatch { @@ -40,25 +43,154 @@ bool match(Val *V, const Pattern &P) { return const_cast(P).match(V); } + +template +struct OneUse_match { + SubPattern_t SubPattern; + + OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} + + template + bool match(OpTy *V) { + return V->hasOneUse() && SubPattern.match(V); + } +}; + +template +inline OneUse_match m_OneUse(const T &SubPattern) { return SubPattern; } + + template -struct leaf_ty { +struct class_match { template bool match(ITy *V) { return isa(V); } }; /// m_Value() - Match an arbitrary value and ignore it. -inline leaf_ty m_Value() { return leaf_ty(); } +inline class_match m_Value() { return class_match(); } /// m_ConstantInt() - Match an arbitrary ConstantInt and ignore it. -inline leaf_ty m_ConstantInt() { return leaf_ty(); } +inline class_match m_ConstantInt() { + return class_match(); +} +/// m_Undef() - Match an arbitrary undef constant. +inline class_match m_Undef() { return class_match(); } + +inline class_match m_Constant() { return class_match(); } + +/// Matching combinators +template +struct match_combine_or { + LTy L; + RTy R; + + match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) { } + + template + bool match(ITy *V) { + if (L.match(V)) + return true; + if (R.match(V)) + return true; + return false; + } +}; + +template +struct match_combine_and { + LTy L; + RTy R; + + match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) { } + + template + bool match(ITy *V) { + if (L.match(V)) + if (R.match(V)) + return true; + return false; + } +}; + +/// Combine two pattern matchers matching L || R +template +inline match_combine_or m_CombineOr(const LTy &L, const RTy &R) { + return match_combine_or(L, R); +} + +/// Combine two pattern matchers matching L && R +template +inline match_combine_and m_CombineAnd(const LTy &L, const RTy &R) { + return match_combine_and(L, R); +} + +struct match_zero { + template + bool match(ITy *V) { + if (const Constant *C = dyn_cast(V)) + return C->isNullValue(); + return false; + } +}; + +/// m_Zero() - Match an arbitrary zero/null constant. This includes +/// zero_initializer for vectors and ConstantPointerNull for pointers. +inline match_zero m_Zero() { return match_zero(); } + +struct match_neg_zero { + template + bool match(ITy *V) { + if (const Constant *C = dyn_cast(V)) + return C->isNegativeZeroValue(); + return false; + } +}; + +/// m_NegZero() - Match an arbitrary zero/null constant. This includes +/// zero_initializer for vectors and ConstantPointerNull for pointers. For +/// floating point constants, this will match negative zero but not positive +/// zero +inline match_neg_zero m_NegZero() { return match_neg_zero(); } + +/// m_AnyZero() - Match an arbitrary zero/null constant. This includes +/// zero_initializer for vectors and ConstantPointerNull for pointers. For +/// floating point constants, this will match negative zero and positive zero +inline match_combine_or m_AnyZero() { + return m_CombineOr(m_Zero(), m_NegZero()); +} + +struct apint_match { + const APInt *&Res; + apint_match(const APInt *&R) : Res(R) {} + template + bool match(ITy *V) { + if (ConstantInt *CI = dyn_cast(V)) { + Res = &CI->getValue(); + return true; + } + if (V->getType()->isVectorTy()) + if (const Constant *C = dyn_cast(V)) + if (ConstantInt *CI = + dyn_cast_or_null(C->getSplatValue())) { + Res = &CI->getValue(); + return true; + } + return false; + } +}; + +/// m_APInt - Match a ConstantInt or splatted ConstantVector, binding the +/// specified pointer to the contained APInt. +inline apint_match m_APInt(const APInt *&Res) { return Res; } + template -struct constantint_ty { +struct constantint_match { template bool match(ITy *V) { if (const ConstantInt *CI = dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) - return CIV == Val; + return CIV == static_cast(Val); // If Val is negative, and CI is shorter than it, truncate to the right // number of bits. If it is larger, then we have to sign extend. Just // compare their negated values. @@ -68,26 +200,87 @@ struct constantint_ty { } }; -/// m_ConstantInt(int64_t) - Match a ConstantInt with a specific value -/// and ignore it. +/// m_ConstantInt - Match a ConstantInt with a specific value. template -inline constantint_ty m_ConstantInt() { - return constantint_ty(); +inline constantint_match m_ConstantInt() { + return constantint_match(); } -struct zero_ty { +/// cst_pred_ty - This helper class is used to match scalar and vector constants +/// that satisfy a specified predicate. +template +struct cst_pred_ty : public Predicate { template bool match(ITy *V) { - if (const Constant *C = dyn_cast(V)) - return C->isNullValue(); + if (const ConstantInt *CI = dyn_cast(V)) + return this->isValue(CI->getValue()); + if (V->getType()->isVectorTy()) + if (const Constant *C = dyn_cast(V)) + if (const ConstantInt *CI = + dyn_cast_or_null(C->getSplatValue())) + return this->isValue(CI->getValue()); return false; } }; -/// m_Zero() - Match an arbitrary zero/null constant. -inline zero_ty m_Zero() { return zero_ty(); } +/// api_pred_ty - This helper class is used to match scalar and vector constants +/// that satisfy a specified predicate, and bind them to an APInt. +template +struct api_pred_ty : public Predicate { + const APInt *&Res; + api_pred_ty(const APInt *&R) : Res(R) {} + template + bool match(ITy *V) { + if (const ConstantInt *CI = dyn_cast(V)) + if (this->isValue(CI->getValue())) { + Res = &CI->getValue(); + return true; + } + if (V->getType()->isVectorTy()) + if (const Constant *C = dyn_cast(V)) + if (ConstantInt *CI = dyn_cast_or_null(C->getSplatValue())) + if (this->isValue(CI->getValue())) { + Res = &CI->getValue(); + return true; + } + + return false; + } +}; +struct is_one { + bool isValue(const APInt &C) { return C == 1; } +}; + +/// m_One() - Match an integer 1 or a vector with all elements equal to 1. +inline cst_pred_ty m_One() { return cst_pred_ty(); } +inline api_pred_ty m_One(const APInt *&V) { return V; } + +struct is_all_ones { + bool isValue(const APInt &C) { return C.isAllOnesValue(); } +}; + +/// m_AllOnes() - Match an integer or vector with all bits set to true. +inline cst_pred_ty m_AllOnes() {return cst_pred_ty();} +inline api_pred_ty m_AllOnes(const APInt *&V) { return V; } + +struct is_sign_bit { + bool isValue(const APInt &C) { return C.isSignBit(); } +}; + +/// m_SignBit() - Match an integer or vector with only the sign bit(s) set. +inline cst_pred_ty m_SignBit() {return cst_pred_ty();} +inline api_pred_ty m_SignBit(const APInt *&V) { return V; } + +struct is_power2 { + bool isValue(const APInt &C) { return C.isPowerOf2(); } +}; + +/// m_Power2() - Match an integer or vector power of 2. +inline cst_pred_ty m_Power2() { return cst_pred_ty(); } +inline api_pred_ty m_Power2(const APInt *&V) { return V; } + template struct bind_ty { Class *&VR; @@ -108,28 +301,76 @@ inline bind_ty m_Value(Value *&V) { return V; } /// m_ConstantInt - Match a ConstantInt, capturing the value if we match. inline bind_ty m_ConstantInt(ConstantInt *&CI) { return CI; } - + +/// m_Constant - Match a Constant, capturing the value if we match. +inline bind_ty m_Constant(Constant *&C) { return C; } + +/// m_ConstantFP - Match a ConstantFP, capturing the value if we match. +inline bind_ty m_ConstantFP(ConstantFP *&C) { return C; } + /// specificval_ty - Match a specified Value*. struct specificval_ty { const Value *Val; specificval_ty(const Value *V) : Val(V) {} - + template bool match(ITy *V) { return V == Val; } }; - + /// m_Specific - Match if we have a specific specified value. inline specificval_ty m_Specific(const Value *V) { return V; } - + +/// Match a specified floating point value or vector of all elements of that +/// value. +struct specific_fpval { + double Val; + specific_fpval(double V) : Val(V) {} + + template + bool match(ITy *V) { + if (const ConstantFP *CFP = dyn_cast(V)) + return CFP->isExactlyValue(Val); + if (V->getType()->isVectorTy()) + if (const Constant *C = dyn_cast(V)) + if (ConstantFP *CFP = dyn_cast_or_null(C->getSplatValue())) + return CFP->isExactlyValue(Val); + return false; + } +}; + +/// Match a specific floating point value or vector with all elements equal to +/// the value. +inline specific_fpval m_SpecificFP(double V) { return specific_fpval(V); } + +/// Match a float 1.0 or vector with all elements equal to 1.0. +inline specific_fpval m_FPOne() { return m_SpecificFP(1.0); } + +struct bind_const_intval_ty { + uint64_t &VR; + bind_const_intval_ty(uint64_t &V) : VR(V) {} + + template + bool match(ITy *V) { + if (ConstantInt *CV = dyn_cast(V)) + if (CV->getBitWidth() <= 64) { + VR = CV->getZExtValue(); + return true; + } + return false; + } +}; + +/// m_ConstantInt - Match a ConstantInt and bind to its value. This does not +/// match ConstantInts wider than 64-bits. +inline bind_const_intval_ty m_ConstantInt(uint64_t &V) { return V; } //===----------------------------------------------------------------------===// // Matchers for specific binary operators. // -template +template struct BinaryOp_match { LHS_t L; RHS_t R; @@ -139,9 +380,8 @@ struct BinaryOp_match { template bool match(OpTy *V) { if (V->getValueID() == Value::InstructionVal + Opcode) { - ConcreteTy *I = cast(V); - return I->getOpcode() == Opcode && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + BinaryOperator *I = cast(V); + return L.match(I->getOperand(0)) && R.match(I->getOperand(1)); } if (ConstantExpr *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && L.match(CE->getOperand(0)) && @@ -151,176 +391,268 @@ struct BinaryOp_match { }; template -inline BinaryOp_match m_Add(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_Add(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_Sub(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_FAdd(const LHS &L, const RHS &R) { + return BinaryOp_match(L, R); +} + +template +inline BinaryOp_match +m_Sub(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_Mul(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_FSub(const LHS &L, const RHS &R) { + return BinaryOp_match(L, R); +} + +template +inline BinaryOp_match +m_Mul(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_UDiv(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_FMul(const LHS &L, const RHS &R) { + return BinaryOp_match(L, R); +} + +template +inline BinaryOp_match +m_UDiv(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_SDiv(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_SDiv(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_FDiv(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_FDiv(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_URem(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_URem(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_SRem(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_SRem(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_FRem(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_FRem(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_And(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_And(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_Or(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_Or(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_Xor(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_Xor(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_Shl(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_Shl(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_LShr(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_LShr(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } template -inline BinaryOp_match m_AShr(const LHS &L, - const RHS &R) { +inline BinaryOp_match +m_AShr(const LHS &L, const RHS &R) { return BinaryOp_match(L, R); } -//===----------------------------------------------------------------------===// -// Matchers for either AShr or LShr .. for convenience -// -template -struct Shr_match { +template +struct OverflowingBinaryOp_match { LHS_t L; RHS_t R; - Shr_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} + OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Instruction::LShr || - V->getValueID() == Value::InstructionVal + Instruction::AShr) { - ConcreteTy *I = cast(V); - return (I->getOpcode() == Instruction::AShr || - I->getOpcode() == Instruction::LShr) && - L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + if (OverflowingBinaryOperator *Op = dyn_cast(V)) { + if (Op->getOpcode() != Opcode) + return false; + if (WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap && + !Op->hasNoUnsignedWrap()) + return false; + if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && + !Op->hasNoSignedWrap()) + return false; + return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); } - if (ConstantExpr *CE = dyn_cast(V)) - return (CE->getOpcode() == Instruction::LShr || - CE->getOpcode() == Instruction::AShr) && - L.match(CE->getOperand(0)) && - R.match(CE->getOperand(1)); return false; } }; -template -inline Shr_match m_Shr(const LHS &L, const RHS &R) { - return Shr_match(L, R); +template +inline OverflowingBinaryOp_match +m_NSWAdd(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} +template +inline OverflowingBinaryOp_match +m_NSWSub(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} +template +inline OverflowingBinaryOp_match +m_NSWMul(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} +template +inline OverflowingBinaryOp_match +m_NSWShl(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} + +template +inline OverflowingBinaryOp_match +m_NUWAdd(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} +template +inline OverflowingBinaryOp_match +m_NUWSub(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} +template +inline OverflowingBinaryOp_match +m_NUWMul(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); +} +template +inline OverflowingBinaryOp_match +m_NUWShl(const LHS &L, const RHS &R) { + return OverflowingBinaryOp_match( + L, R); } //===----------------------------------------------------------------------===// -// Matchers for binary classes +// Class that matches two different binary ops. // - -template -struct BinaryOpClass_match { - OpcType *Opcode; +template +struct BinOp2_match { LHS_t L; RHS_t R; - BinaryOpClass_match(OpcType &Op, const LHS_t &LHS, - const RHS_t &RHS) - : Opcode(&Op), L(LHS), R(RHS) {} - BinaryOpClass_match(const LHS_t &LHS, const RHS_t &RHS) - : Opcode(0), L(LHS), R(RHS) {} + BinOp2_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (Class *I = dyn_cast(V)) - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { - if (Opcode) - *Opcode = I->getOpcode(); - return true; - } -#if 0 // Doesn't handle constantexprs yet! + if (V->getValueID() == Value::InstructionVal + Opc1 || + V->getValueID() == Value::InstructionVal + Opc2) { + BinaryOperator *I = cast(V); + return L.match(I->getOperand(0)) && R.match(I->getOperand(1)); + } if (ConstantExpr *CE = dyn_cast(V)) - return CE->getOpcode() == Opcode && L.match(CE->getOperand(0)) && - R.match(CE->getOperand(1)); -#endif + return (CE->getOpcode() == Opc1 || CE->getOpcode() == Opc2) && + L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); return false; } }; +/// m_Shr - Matches LShr or AShr. +template +inline BinOp2_match +m_Shr(const LHS &L, const RHS &R) { + return BinOp2_match(L, R); +} + +/// m_LogicalShift - Matches LShr or Shl. template -inline BinaryOpClass_match -m_Shift(Instruction::BinaryOps &Op, const LHS &L, const RHS &R) { - return BinaryOpClass_match(Op, L, R); +inline BinOp2_match +m_LogicalShift(const LHS &L, const RHS &R) { + return BinOp2_match(L, R); } +/// m_IDiv - Matches UDiv and SDiv. template -inline BinaryOpClass_match -m_Shift(const LHS &L, const RHS &R) { - return BinaryOpClass_match(L, R); +inline BinOp2_match +m_IDiv(const LHS &L, const RHS &R) { + return BinOp2_match(L, R); } +//===----------------------------------------------------------------------===// +// Class that matches exact binary ops. +// +template +struct Exact_match { + SubPattern_t SubPattern; + + Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} + + template + bool match(OpTy *V) { + if (PossiblyExactOperator *PEO = dyn_cast(V)) + return PEO->isExact() && SubPattern.match(V); + return false; + } +}; + +template +inline Exact_match m_Exact(const T &SubPattern) { return SubPattern; } + //===----------------------------------------------------------------------===// // Matchers for CmpInst classes // @@ -331,8 +663,7 @@ struct CmpClass_match { LHS_t L; RHS_t R; - CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, - const RHS_t &RHS) + CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} template @@ -385,19 +716,17 @@ struct SelectClass_match { }; template -inline SelectClass_match +inline SelectClass_match m_Select(const Cond &C, const LHS &L, const RHS &R) { return SelectClass_match(C, L, R); } /// m_SelectCst - This matches a select of two constants, e.g.: -/// m_SelectCst(m_Value(V), -1, 0) +/// m_SelectCst<-1, 0>(m_Value(V)) template -inline SelectClass_match, constantint_ty > +inline SelectClass_match, constantint_match > m_SelectCst(const Cond &C) { - return SelectClass_match, - constantint_ty >(C, m_ConstantInt(), - m_ConstantInt()); + return m_Select(C, m_ConstantInt(), m_ConstantInt()); } @@ -405,26 +734,69 @@ m_SelectCst(const Cond &C) { // Matchers for CastInst classes // -template +template struct CastClass_match { Op_t Op; - + CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - + template bool match(OpTy *V) { - if (Class *I = dyn_cast(V)) - return Op.match(I->getOperand(0)); + if (Operator *O = dyn_cast(V)) + return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); return false; } }; -template -inline CastClass_match m_Cast(const OpTy &Op) { - return CastClass_match(Op); +/// m_BitCast +template +inline CastClass_match +m_BitCast(const OpTy &Op) { + return CastClass_match(Op); +} + +/// m_PtrToInt +template +inline CastClass_match +m_PtrToInt(const OpTy &Op) { + return CastClass_match(Op); +} + +/// m_Trunc +template +inline CastClass_match +m_Trunc(const OpTy &Op) { + return CastClass_match(Op); +} + +/// m_SExt +template +inline CastClass_match +m_SExt(const OpTy &Op) { + return CastClass_match(Op); +} + +/// m_ZExt +template +inline CastClass_match +m_ZExt(const OpTy &Op) { + return CastClass_match(Op); +} + +/// m_UIToFP +template +inline CastClass_match +m_UIToFP(const OpTy &Op) { + return CastClass_match(Op); +} + +/// m_SIToFP +template +inline CastClass_match +m_SIToFP(const OpTy &Op) { + return CastClass_match(Op); } - //===----------------------------------------------------------------------===// // Matchers for unary operators // @@ -437,27 +809,18 @@ struct not_match { template bool match(OpTy *V) { - if (Instruction *I = dyn_cast(V)) - if (I->getOpcode() == Instruction::Xor) - return matchIfNot(I->getOperand(0), I->getOperand(1)); - if (ConstantExpr *CE = dyn_cast(V)) - if (CE->getOpcode() == Instruction::Xor) - return matchIfNot(CE->getOperand(0), CE->getOperand(1)); - if (ConstantInt *CI = dyn_cast(V)) - return L.match(ConstantExpr::getNot(CI)); + if (Operator *O = dyn_cast(V)) + if (O->getOpcode() == Instruction::Xor) + return matchIfNot(O->getOperand(0), O->getOperand(1)); return false; } private: bool matchIfNot(Value *LHS, Value *RHS) { - if (ConstantInt *CI = dyn_cast(RHS)) - return CI->isAllOnesValue() && L.match(LHS); - if (ConstantInt *CI = dyn_cast(LHS)) - return CI->isAllOnesValue() && L.match(RHS); - if (ConstantVector *CV = dyn_cast(RHS)) - return CV->isAllOnesValue() && L.match(LHS); - if (ConstantVector *CV = dyn_cast(LHS)) - return CV->isAllOnesValue() && L.match(RHS); - return false; + return (isa(RHS) || isa(RHS) || + // FIXME: Remove CV. + isa(RHS)) && + cast(RHS)->isAllOnesValue() && + L.match(LHS); } }; @@ -468,36 +831,78 @@ inline not_match m_Not(const LHS &L) { return L; } template struct neg_match { LHS_t L; - + neg_match(const LHS_t &LHS) : L(LHS) {} - + template bool match(OpTy *V) { - if (Instruction *I = dyn_cast(V)) - if (I->getOpcode() == Instruction::Sub) - return matchIfNeg(I->getOperand(0), I->getOperand(1)); - if (ConstantExpr *CE = dyn_cast(V)) - if (CE->getOpcode() == Instruction::Sub) - return matchIfNeg(CE->getOperand(0), CE->getOperand(1)); - if (ConstantInt *CI = dyn_cast(V)) - return L.match(ConstantExpr::getNeg(CI)); + if (Operator *O = dyn_cast(V)) + if (O->getOpcode() == Instruction::Sub) + return matchIfNeg(O->getOperand(0), O->getOperand(1)); return false; } private: bool matchIfNeg(Value *LHS, Value *RHS) { - return LHS == ConstantExpr::getZeroValueForNegationExpr(LHS->getType()) && + return ((isa(LHS) && cast(LHS)->isZero()) || + isa(LHS)) && L.match(RHS); } }; +/// m_Neg - Match an integer negate. template inline neg_match m_Neg(const LHS &L) { return L; } +template +struct fneg_match { + LHS_t L; + + fneg_match(const LHS_t &LHS) : L(LHS) {} + + template + bool match(OpTy *V) { + if (Operator *O = dyn_cast(V)) + if (O->getOpcode() == Instruction::FSub) + return matchIfFNeg(O->getOperand(0), O->getOperand(1)); + return false; + } +private: + bool matchIfFNeg(Value *LHS, Value *RHS) { + if (ConstantFP *C = dyn_cast(LHS)) + return C->isNegativeZeroValue() && L.match(RHS); + return false; + } +}; + +/// m_FNeg - Match a floating point negate. +template +inline fneg_match m_FNeg(const LHS &L) { return L; } + + //===----------------------------------------------------------------------===// -// Matchers for control flow +// Matchers for control flow. // +struct br_match { + BasicBlock *&Succ; + br_match(BasicBlock *&Succ) + : Succ(Succ) { + } + + template + bool match(OpTy *V) { + if (BranchInst *BI = dyn_cast(V)) + if (BI->isUnconditional()) { + Succ = BI->getSuccessor(0); + return true; + } + return false; + } +}; + +inline br_match m_UnconditionalBr(BasicBlock *&Succ) { return br_match(Succ); } + template struct brc_match { Cond_t Cond; @@ -509,12 +914,10 @@ struct brc_match { template bool match(OpTy *V) { if (BranchInst *BI = dyn_cast(V)) - if (BI->isConditional()) { - if (Cond.match(BI->getCondition())) { - T = BI->getSuccessor(0); - F = BI->getSuccessor(1); - return true; - } + if (BI->isConditional() && Cond.match(BI->getCondition())) { + T = BI->getSuccessor(0); + F = BI->getSuccessor(1); + return true; } return false; } @@ -525,6 +928,283 @@ inline brc_match m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F) { return brc_match(C, T, F); } + +//===----------------------------------------------------------------------===// +// Matchers for max/min idioms, eg: "select (sgt x, y), x, y" -> smax(x,y). +// + +template +struct MaxMin_match { + LHS_t L; + RHS_t R; + + MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) + : L(LHS), R(RHS) {} + + template + bool match(OpTy *V) { + // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". + SelectInst *SI = dyn_cast(V); + if (!SI) + return false; + CmpInst_t *Cmp = dyn_cast(SI->getCondition()); + if (!Cmp) + return false; + // At this point we have a select conditioned on a comparison. Check that + // it is the values returned by the select that are being compared. + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + Value *LHS = Cmp->getOperand(0); + Value *RHS = Cmp->getOperand(1); + if ((TrueVal != LHS || FalseVal != RHS) && + (TrueVal != RHS || FalseVal != LHS)) + return false; + typename CmpInst_t::Predicate Pred = LHS == TrueVal ? + Cmp->getPredicate() : Cmp->getSwappedPredicate(); + // Does "(x pred y) ? x : y" represent the desired max/min operation? + if (!Pred_t::match(Pred)) + return false; + // It does! Bind the operands. + return L.match(LHS) && R.match(RHS); + } +}; + +/// smax_pred_ty - Helper class for identifying signed max predicates. +struct smax_pred_ty { + static bool match(ICmpInst::Predicate Pred) { + return Pred == CmpInst::ICMP_SGT || Pred == CmpInst::ICMP_SGE; + } +}; + +/// smin_pred_ty - Helper class for identifying signed min predicates. +struct smin_pred_ty { + static bool match(ICmpInst::Predicate Pred) { + return Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_SLE; + } +}; + +/// umax_pred_ty - Helper class for identifying unsigned max predicates. +struct umax_pred_ty { + static bool match(ICmpInst::Predicate Pred) { + return Pred == CmpInst::ICMP_UGT || Pred == CmpInst::ICMP_UGE; + } +}; + +/// umin_pred_ty - Helper class for identifying unsigned min predicates. +struct umin_pred_ty { + static bool match(ICmpInst::Predicate Pred) { + return Pred == CmpInst::ICMP_ULT || Pred == CmpInst::ICMP_ULE; + } +}; + +/// ofmax_pred_ty - Helper class for identifying ordered max predicates. +struct ofmax_pred_ty { + static bool match(FCmpInst::Predicate Pred) { + return Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_OGE; + } +}; + +/// ofmin_pred_ty - Helper class for identifying ordered min predicates. +struct ofmin_pred_ty { + static bool match(FCmpInst::Predicate Pred) { + return Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_OLE; + } +}; + +/// ufmax_pred_ty - Helper class for identifying unordered max predicates. +struct ufmax_pred_ty { + static bool match(FCmpInst::Predicate Pred) { + return Pred == CmpInst::FCMP_UGT || Pred == CmpInst::FCMP_UGE; + } +}; + +/// ufmin_pred_ty - Helper class for identifying unordered min predicates. +struct ufmin_pred_ty { + static bool match(FCmpInst::Predicate Pred) { + return Pred == CmpInst::FCMP_ULT || Pred == CmpInst::FCMP_ULE; + } +}; + +template +inline MaxMin_match +m_SMax(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +template +inline MaxMin_match +m_SMin(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +template +inline MaxMin_match +m_UMax(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +template +inline MaxMin_match +m_UMin(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +/// \brief Match an 'ordered' floating point maximum function. +/// Floating point has one special value 'NaN'. Therefore, there is no total +/// order. However, if we can ignore the 'NaN' value (for example, because of a +/// 'no-nans-float-math' flag) a combination of a fcmp and select has 'maximum' +/// semantics. In the presence of 'NaN' we have to preserve the original +/// select(fcmp(ogt/ge, L, R), L, R) semantics matched by this predicate. +/// +/// max(L, R) iff L and R are not NaN +/// m_OrdFMax(L, R) = R iff L or R are NaN +template +inline MaxMin_match +m_OrdFMax(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +/// \brief Match an 'ordered' floating point minimum function. +/// Floating point has one special value 'NaN'. Therefore, there is no total +/// order. However, if we can ignore the 'NaN' value (for example, because of a +/// 'no-nans-float-math' flag) a combination of a fcmp and select has 'minimum' +/// semantics. In the presence of 'NaN' we have to preserve the original +/// select(fcmp(olt/le, L, R), L, R) semantics matched by this predicate. +/// +/// max(L, R) iff L and R are not NaN +/// m_OrdFMin(L, R) = R iff L or R are NaN +template +inline MaxMin_match +m_OrdFMin(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +/// \brief Match an 'unordered' floating point maximum function. +/// Floating point has one special value 'NaN'. Therefore, there is no total +/// order. However, if we can ignore the 'NaN' value (for example, because of a +/// 'no-nans-float-math' flag) a combination of a fcmp and select has 'maximum' +/// semantics. In the presence of 'NaN' we have to preserve the original +/// select(fcmp(ugt/ge, L, R), L, R) semantics matched by this predicate. +/// +/// max(L, R) iff L and R are not NaN +/// m_UnordFMin(L, R) = L iff L or R are NaN +template +inline MaxMin_match +m_UnordFMax(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +/// \brief Match an 'unordered' floating point minimum function. +/// Floating point has one special value 'NaN'. Therefore, there is no total +/// order. However, if we can ignore the 'NaN' value (for example, because of a +/// 'no-nans-float-math' flag) a combination of a fcmp and select has 'minimum' +/// semantics. In the presence of 'NaN' we have to preserve the original +/// select(fcmp(ult/le, L, R), L, R) semantics matched by this predicate. +/// +/// max(L, R) iff L and R are not NaN +/// m_UnordFMin(L, R) = L iff L or R are NaN +template +inline MaxMin_match +m_UnordFMin(const LHS &L, const RHS &R) { + return MaxMin_match(L, R); +} + +template +struct Argument_match { + unsigned OpI; + Opnd_t Val; + Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) { } + + template + bool match(OpTy *V) { + CallSite CS(V); + return CS.isCall() && Val.match(CS.getArgument(OpI)); + } +}; + +/// Match an argument +template +inline Argument_match m_Argument(const Opnd_t &Op) { + return Argument_match(OpI, Op); +} + +/// Intrinsic matchers. +struct IntrinsicID_match { + unsigned ID; + IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) { } + + template + bool match(OpTy *V) { + IntrinsicInst *II = dyn_cast(V); + return II && II->getIntrinsicID() == ID; + } +}; + +/// Intrinsic matches are combinations of ID matchers, and argument +/// matchers. Higher arity matcher are defined recursively in terms of and-ing +/// them with lower arity matchers. Here's some convenient typedefs for up to +/// several arguments, and more can be added as needed +template struct m_Intrinsic_Ty; +template +struct m_Intrinsic_Ty { + typedef match_combine_and > Ty; +}; +template +struct m_Intrinsic_Ty { + typedef match_combine_and::Ty, + Argument_match > Ty; +}; +template +struct m_Intrinsic_Ty { + typedef match_combine_and::Ty, + Argument_match > Ty; +}; +template +struct m_Intrinsic_Ty { + typedef match_combine_and::Ty, + Argument_match > Ty; +}; + +/// Match intrinsic calls like this: +/// m_Intrinsic(m_Value(X)) +template +inline IntrinsicID_match +m_Intrinsic() { return IntrinsicID_match(IntrID); } + +template +inline typename m_Intrinsic_Ty::Ty +m_Intrinsic(const T0 &Op0) { + return m_CombineAnd(m_Intrinsic(), m_Argument<0>(Op0)); +} + +template +inline typename m_Intrinsic_Ty::Ty +m_Intrinsic(const T0 &Op0, const T1 &Op1) { + return m_CombineAnd(m_Intrinsic(Op0), m_Argument<1>(Op1)); +} + +template +inline typename m_Intrinsic_Ty::Ty +m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2) { + return m_CombineAnd(m_Intrinsic(Op0, Op1), m_Argument<2>(Op2)); +} + +template +inline typename m_Intrinsic_Ty::Ty +m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) { + return m_CombineAnd(m_Intrinsic(Op0, Op1, Op2), m_Argument<3>(Op3)); +} + +// Helper intrinsic matching specializations +template +inline typename m_Intrinsic_Ty::Ty +m_BSwap(const Opnd0 &Op0) { + return m_Intrinsic(Op0); +} + } // end namespace PatternMatch } // end namespace llvm