Loop Strength Reduce: Scaling factor cost.
authorQuentin Colombet <qcolombet@apple.com>
Fri, 31 May 2013 21:29:03 +0000 (21:29 +0000)
committerQuentin Colombet <qcolombet@apple.com>
Fri, 31 May 2013 21:29:03 +0000 (21:29 +0000)
Account for the cost of scaling factor in Loop Strength Reduce when rating the
formulae. This uses a target hook.

The default implementation of the hook is: if the addressing mode is legal, the
scaling factor is free.

<rdar://problem/13806271>

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@183045 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/TargetTransformInfo.h
include/llvm/Target/TargetLowering.h
lib/Analysis/TargetTransformInfo.cpp
lib/CodeGen/BasicTargetTransformInfo.cpp
lib/Transforms/Scalar/LoopStrengthReduce.cpp

index a9d6725d86b0999e42bb21cd920e592068463141..eb29e3483d831077cc56e1c4e650ee853df12b70 100644 (file)
@@ -225,6 +225,16 @@ public:
                                      int64_t BaseOffset, bool HasBaseReg,
                                      int64_t Scale) const;
 
+  /// \brief Return the cost of the scaling factor used in the addressing
+  /// mode represented by AM for this target, for a load/store
+  /// of the specified type.
+  /// If the AM is supported, the return value must be >= 0.
+  /// If the AM is not supported, it returns a negative value.
+  /// TODO: Handle pre/postinc as well.
+  virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+                                   int64_t BaseOffset, bool HasBaseReg,
+                                   int64_t Scale) const;
+
   /// isTruncateFree - Return true if it's free to truncate a value of
   /// type Ty1 to type Ty2. e.g. On x86 it's free to truncate a i32 value in
   /// register EAX to i16 by referencing its sub-register AX.
index 41a4a2b838f328b10bda1ff641d0fc1d5651ce54..d67e55dc6655a0d7c7b3916f02b2d56e931e8826 100644 (file)
@@ -1139,6 +1139,18 @@ public:
   /// TODO: Handle pre/postinc as well.
   virtual bool isLegalAddressingMode(const AddrMode &AM, Type *Ty) const;
 
+  /// \brief Return the cost of the scaling factor used in the addressing
+  /// mode represented by AM for this target, for a load/store
+  /// of the specified type.
+  /// If the AM is supported, the return value must be >= 0.
+  /// If the AM is not supported, it returns a negative value.
+  /// TODO: Handle pre/postinc as well.
+  virtual int getScalingFactorCost(const AddrMode &AM, Type *Ty) const {
+    // Default: assume that any scaling factor used in a legal AM is free.
+    if (isLegalAddressingMode(AM, Ty)) return 0;
+    return -1;
+  }
+
   /// isLegalICmpImmediate - Return true if the specified immediate is legal
   /// icmp immediate, that is the target has icmp instructions which can compare
   /// a register against the immediate without having to materialize the
index 64f8e96884c716a8dc9cd1cb823b9da854b465a8..35ce794c7f12099c310e7c5614ef4baeabd331b9 100644 (file)
@@ -108,6 +108,14 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
                                         Scale);
 }
 
+int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+                                              int64_t BaseOffset,
+                                              bool HasBaseReg,
+                                              int64_t Scale) const {
+  return PrevTTI->getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg,
+                                       Scale);
+}
+
 bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const {
   return PrevTTI->isTruncateFree(Ty1, Ty2);
 }
@@ -457,6 +465,15 @@ struct NoTTI : ImmutablePass, TargetTransformInfo {
     return !BaseGV && BaseOffset == 0 && Scale <= 1;
   }
 
+  int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
+                           bool HasBaseReg, int64_t Scale) const {
+    // Guess that all legal addressing mode are free.
+    if(isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale))
+      return 0;
+    return -1;
+  }
+
+
   bool isTruncateFree(Type *Ty1, Type *Ty2) const {
     return false;
   }
index 4a99184f5eecfaf5c9e8dd23ef09735cd2d034bd..92a5bb70f432f3ea0dcec72e438eee020ed0ba66 100644 (file)
@@ -71,6 +71,9 @@ public:
   virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
                                      int64_t BaseOffset, bool HasBaseReg,
                                      int64_t Scale) const;
+  virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+                                   int64_t BaseOffset, bool HasBaseReg,
+                                   int64_t Scale) const;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) const;
   virtual bool isTypeLegal(Type *Ty) const;
   virtual unsigned getJumpBufAlignment() const;
@@ -139,6 +142,17 @@ bool BasicTTI::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
   return TLI->isLegalAddressingMode(AM, Ty);
 }
 
+int BasicTTI::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+                                   int64_t BaseOffset, bool HasBaseReg,
+                                   int64_t Scale) const {
+  TargetLoweringBase::AddrMode AM;
+  AM.BaseGV = BaseGV;
+  AM.BaseOffs = BaseOffset;
+  AM.HasBaseReg = HasBaseReg;
+  AM.Scale = Scale;
+  return TLI->getScalingFactorCost(AM, Ty);
+}
+
 bool BasicTTI::isTruncateFree(Type *Ty1, Type *Ty2) const {
   return TLI->isTruncateFree(Ty1, Ty2);
 }
index ecc96ae0b22c9bb626098d2ced7fe65a27c2a29f..b107fef35a0fdd9380e11c7690983a5b2a88ca4f 100644 (file)
@@ -779,6 +779,9 @@ class LSRUse;
 // Check if it is legal to fold 2 base registers.
 static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
                              const Formula &F);
+// Get the cost of the scaling factor used in F for LU.
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F);
 
 namespace {
 
@@ -792,11 +795,12 @@ class Cost {
   unsigned NumBaseAdds;
   unsigned ImmCost;
   unsigned SetupCost;
+  unsigned ScaleCost;
 
 public:
   Cost()
     : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0),
-      SetupCost(0) {}
+      SetupCost(0), ScaleCost(0) {}
 
   bool operator<(const Cost &Other) const;
 
@@ -806,9 +810,9 @@ public:
   // Once any of the metrics loses, they must all remain losers.
   bool isValid() {
     return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds
-             | ImmCost | SetupCost) != ~0u)
+             | ImmCost | SetupCost | ScaleCost) != ~0u)
       || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds
-           & ImmCost & SetupCost) == ~0u);
+           & ImmCost & SetupCost & ScaleCost) == ~0u);
   }
 #endif
 
@@ -947,6 +951,9 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,
     // allows to fold 2 registers.
     NumBaseAdds += NumBaseParts - (1 + isLegal2RegAMUse(TTI, LU, F));
 
+  // Accumulate non-free scaling amounts.
+  ScaleCost += getScalingFactorCost(TTI, LU, F);
+
   // Tally up the non-zero immediates.
   for (SmallVectorImpl<int64_t>::const_iterator I = Offsets.begin(),
        E = Offsets.end(); I != E; ++I) {
@@ -968,6 +975,7 @@ void Cost::Loose() {
   NumBaseAdds = ~0u;
   ImmCost = ~0u;
   SetupCost = ~0u;
+  ScaleCost = ~0u;
 }
 
 /// operator< - Choose the lower cost.
@@ -980,6 +988,8 @@ bool Cost::operator<(const Cost &Other) const {
     return NumIVMuls < Other.NumIVMuls;
   if (NumBaseAdds != Other.NumBaseAdds)
     return NumBaseAdds < Other.NumBaseAdds;
+  if (ScaleCost != Other.ScaleCost)
+    return ScaleCost < Other.ScaleCost;
   if (ImmCost != Other.ImmCost)
     return ImmCost < Other.ImmCost;
   if (SetupCost != Other.SetupCost)
@@ -996,6 +1006,8 @@ void Cost::print(raw_ostream &OS) const {
   if (NumBaseAdds != 0)
     OS << ", plus " << NumBaseAdds << " base add"
        << (NumBaseAdds == 1 ? "" : "s");
+  if (ScaleCost != 0)
+    OS << ", plus " << ScaleCost << " scale cost";
   if (ImmCost != 0)
     OS << ", plus " << ImmCost << " imm cost";
   if (SetupCost != 0)
@@ -1396,6 +1408,34 @@ static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
                     F.BaseGV, F.BaseOffset, F.HasBaseReg, 1);
  }
 
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F) {
+  if (!F.Scale)
+    return 0;
+  assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
+                    LU.AccessTy, F) && "Illegal formula in use.");
+
+  switch (LU.Kind) {
+  case LSRUse::Address: {
+    int CurScaleCost = TTI.getScalingFactorCost(LU.AccessTy, F.BaseGV,
+                                                F.BaseOffset, F.HasBaseReg,
+                                                F.Scale);
+    assert(CurScaleCost >= 0 && "Legal addressing mode has an illegal cost!");
+    return CurScaleCost;
+  }
+  case LSRUse::ICmpZero:
+    // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg.
+    // Therefore, return 0 in case F.Scale == -1. 
+    return F.Scale != -1;
+
+  case LSRUse::Basic:
+  case LSRUse::Special:
+    return 0;
+  }
+
+  llvm_unreachable("Invalid LSRUse Kind!");
+}
+
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, Type *AccessTy,
                              GlobalValue *BaseGV, int64_t BaseOffset,