[C++11] More 'nullptr' conversion. In some cases just using a boolean check instead...
[oota-llvm.git] / lib / Analysis / CostModel.cpp
1 //===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the cost model analysis. It provides a very basic cost
11 // estimation for LLVM-IR. This analysis uses the services of the codegen
12 // to approximate the cost of any IR instruction when lowered to machine
13 // instructions. The cost results are unit-less and the cost number represents
14 // the throughput of the machine assuming that all loads hit the cache, all
15 // branches are predicted, etc. The cost numbers can be added in order to
16 // compare two or more transformation alternatives.
17 //
18 //===----------------------------------------------------------------------===//
19
20 #define CM_NAME "cost-model"
21 #define DEBUG_TYPE CM_NAME
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Analysis/Passes.h"
24 #include "llvm/Analysis/TargetTransformInfo.h"
25 #include "llvm/IR/Function.h"
26 #include "llvm/IR/Instructions.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Value.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 using namespace llvm;
34
35 static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
36                                      cl::Hidden,
37                                      cl::desc("Recognize reduction patterns."));
38
39 namespace {
40   class CostModelAnalysis : public FunctionPass {
41
42   public:
43     static char ID; // Class identification, replacement for typeinfo
44     CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
45       initializeCostModelAnalysisPass(
46         *PassRegistry::getPassRegistry());
47     }
48
49     /// Returns the expected cost of the instruction.
50     /// Returns -1 if the cost is unknown.
51     /// Note, this method does not cache the cost calculation and it
52     /// can be expensive in some cases.
53     unsigned getInstructionCost(const Instruction *I) const;
54
55   private:
56     void getAnalysisUsage(AnalysisUsage &AU) const override;
57     bool runOnFunction(Function &F) override;
58     void print(raw_ostream &OS, const Module*) const override;
59
60     /// The function that we analyze.
61     Function *F;
62     /// Target information.
63     const TargetTransformInfo *TTI;
64   };
65 }  // End of anonymous namespace
66
67 // Register this pass.
68 char CostModelAnalysis::ID = 0;
69 static const char cm_name[] = "Cost Model Analysis";
70 INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
71 INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true)
72
73 FunctionPass *llvm::createCostModelAnalysisPass() {
74   return new CostModelAnalysis();
75 }
76
77 void
78 CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
79   AU.setPreservesAll();
80 }
81
82 bool
83 CostModelAnalysis::runOnFunction(Function &F) {
84  this->F = &F;
85  TTI = getAnalysisIfAvailable<TargetTransformInfo>();
86
87  return false;
88 }
89
90 static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) {
91   for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
92     if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i))
93       return false;
94   return true;
95 }
96
97 static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
98   TargetTransformInfo::OperandValueKind OpInfo =
99     TargetTransformInfo::OK_AnyValue;
100
101   // Check for a splat of a constant or for a non uniform vector of constants.
102   if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
103     OpInfo = TargetTransformInfo::OK_NonUniformConstantValue;
104     if (cast<Constant>(V)->getSplatValue() != nullptr)
105       OpInfo = TargetTransformInfo::OK_UniformConstantValue;
106   }
107
108   return OpInfo;
109 }
110
111 static bool matchMask(SmallVectorImpl<int> &M1, SmallVectorImpl<int> &M2) {
112   if (M1.size() != M2.size())
113     return false;
114
115   for (unsigned i = 0, e = M1.size(); i != e; ++i)
116     if (M1[i] != M2[i])
117       return false;
118
119   return true;
120 }
121
122 static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
123                                      unsigned Level) {
124   // We don't need a shuffle if we just want to have element 0 in position 0 of
125   // the vector.
126   if (!SI && Level == 0 && IsLeft)
127     return true;
128   else if (!SI)
129     return false;
130
131   SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);
132
133   // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
134   // we look at the left or right side.
135   for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
136     Mask[i] = val;
137
138   SmallVector<int, 16> ActualMask = SI->getShuffleMask();
139   if (!matchMask(Mask, ActualMask))
140     return false;
141
142   return true;
143 }
144
145 static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
146                                           unsigned Level, unsigned NumLevels) {
147   // Match one level of pairwise operations.
148   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
149   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
150   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
151   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
152   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
153   if (BinOp == nullptr)
154     return false;
155
156   assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");
157
158   unsigned Opcode = BinOp->getOpcode();
159   Value *L = BinOp->getOperand(0);
160   Value *R = BinOp->getOperand(1);
161
162   ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
163   if (!LS && Level)
164     return false;
165   ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
166   if (!RS && Level)
167     return false;
168
169   // On level 0 we can omit one shufflevector instruction.
170   if (!Level && !RS && !LS)
171     return false;
172
173   // Shuffle inputs must match.
174   Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
175   Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
176   Value *NextLevelOp = nullptr;
177   if (NextLevelOpR && NextLevelOpL) {
178     // If we have two shuffles their operands must match.
179     if (NextLevelOpL != NextLevelOpR)
180       return false;
181
182     NextLevelOp = NextLevelOpL;
183   } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
184     // On the first level we can omit the shufflevector <0, undef,...>. So the
185     // input to the other shufflevector <1, undef> must match with one of the
186     // inputs to the current binary operation.
187     // Example:
188     //  %NextLevelOpL = shufflevector %R, <1, undef ...>
189     //  %BinOp        = fadd          %NextLevelOpL, %R
190     if (NextLevelOpL && NextLevelOpL != R)
191       return false;
192     else if (NextLevelOpR && NextLevelOpR != L)
193       return false;
194
195     NextLevelOp = NextLevelOpL ? R : L;
196   } else
197     return false;
198
199   // Check that the next levels binary operation exists and matches with the
200   // current one.
201   BinaryOperator *NextLevelBinOp = nullptr;
202   if (Level + 1 != NumLevels) {
203     if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
204       return false;
205     else if (NextLevelBinOp->getOpcode() != Opcode)
206       return false;
207   }
208
209   // Shuffle mask for pairwise operation must match.
210   if (matchPairwiseShuffleMask(LS, true, Level)) {
211     if (!matchPairwiseShuffleMask(RS, false, Level))
212       return false;
213   } else if (matchPairwiseShuffleMask(RS, true, Level)) {
214     if (!matchPairwiseShuffleMask(LS, false, Level))
215       return false;
216   } else
217     return false;
218
219   if (++Level == NumLevels)
220     return true;
221
222   // Match next level.
223   return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
224 }
225
226 static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
227                                    unsigned &Opcode, Type *&Ty) {
228   if (!EnableReduxCost)
229     return false;
230
231   // Need to extract the first element.
232   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
233   unsigned Idx = ~0u;
234   if (CI)
235     Idx = CI->getZExtValue();
236   if (Idx != 0)
237     return false;
238
239   BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
240   if (!RdxStart)
241     return false;
242
243   Type *VecTy = ReduxRoot->getOperand(0)->getType();
244   unsigned NumVecElems = VecTy->getVectorNumElements();
245   if (!isPowerOf2_32(NumVecElems))
246     return false;
247
248   // We look for a sequence of shuffle,shuffle,add triples like the following
249   // that builds a pairwise reduction tree.
250   //
251   //  (X0, X1, X2, X3)
252   //   (X0 + X1, X2 + X3, undef, undef)
253   //    ((X0 + X1) + (X2 + X3), undef, undef, undef)
254   //
255   // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
256   //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef>
257   // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
258   //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
259   // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
260   // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
261   //       <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
262   // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
263   //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
264   // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
265   // %r = extractelement <4 x float> %bin.rdx8, i32 0
266   if (!matchPairwiseReductionAtLevel(RdxStart, 0,  Log2_32(NumVecElems)))
267     return false;
268
269   Opcode = RdxStart->getOpcode();
270   Ty = VecTy;
271
272   return true;
273 }
274
275 static std::pair<Value *, ShuffleVectorInst *>
276 getShuffleAndOtherOprd(BinaryOperator *B) {
277
278   Value *L = B->getOperand(0);
279   Value *R = B->getOperand(1);
280   ShuffleVectorInst *S = nullptr;
281
282   if ((S = dyn_cast<ShuffleVectorInst>(L)))
283     return std::make_pair(R, S);
284
285   S = dyn_cast<ShuffleVectorInst>(R);
286   return std::make_pair(L, S);
287 }
288
289 static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
290                                           unsigned &Opcode, Type *&Ty) {
291   if (!EnableReduxCost)
292     return false;
293
294   // Need to extract the first element.
295   ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
296   unsigned Idx = ~0u;
297   if (CI)
298     Idx = CI->getZExtValue();
299   if (Idx != 0)
300     return false;
301
302   BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
303   if (!RdxStart)
304     return false;
305   unsigned RdxOpcode = RdxStart->getOpcode();
306
307   Type *VecTy = ReduxRoot->getOperand(0)->getType();
308   unsigned NumVecElems = VecTy->getVectorNumElements();
309   if (!isPowerOf2_32(NumVecElems))
310     return false;
311
312   // We look for a sequence of shuffles and adds like the following matching one
313   // fadd, shuffle vector pair at a time.
314   //
315   // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
316   //                           <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
317   // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
318   // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
319   //                          <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
320   // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
321   // %r = extractelement <4 x float> %bin.rdx8, i32 0
322
323   unsigned MaskStart = 1;
324   Value *RdxOp = RdxStart;
325   SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
326   unsigned NumVecElemsRemain = NumVecElems;
327   while (NumVecElemsRemain - 1) {
328     // Check for the right reduction operation.
329     BinaryOperator *BinOp;
330     if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
331       return false;
332     if (BinOp->getOpcode() != RdxOpcode)
333       return false;
334
335     Value *NextRdxOp;
336     ShuffleVectorInst *Shuffle;
337     std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);
338
339     // Check the current reduction operation and the shuffle use the same value.
340     if (Shuffle == nullptr)
341       return false;
342     if (Shuffle->getOperand(0) != NextRdxOp)
343       return false;
344
345     // Check that shuffle masks matches.
346     for (unsigned j = 0; j != MaskStart; ++j)
347       ShuffleMask[j] = MaskStart + j;
348     // Fill the rest of the mask with -1 for undef.
349     std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);
350
351     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
352     if (!matchMask(ShuffleMask, Mask))
353       return false;
354
355     RdxOp = NextRdxOp;
356     NumVecElemsRemain /= 2;
357     MaskStart *= 2;
358   }
359
360   Opcode = RdxOpcode;
361   Ty = VecTy;
362   return true;
363 }
364
365 unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
366   if (!TTI)
367     return -1;
368
369   switch (I->getOpcode()) {
370   case Instruction::GetElementPtr:{
371     Type *ValTy = I->getOperand(0)->getType()->getPointerElementType();
372     return TTI->getAddressComputationCost(ValTy);
373   }
374
375   case Instruction::Ret:
376   case Instruction::PHI:
377   case Instruction::Br: {
378     return TTI->getCFInstrCost(I->getOpcode());
379   }
380   case Instruction::Add:
381   case Instruction::FAdd:
382   case Instruction::Sub:
383   case Instruction::FSub:
384   case Instruction::Mul:
385   case Instruction::FMul:
386   case Instruction::UDiv:
387   case Instruction::SDiv:
388   case Instruction::FDiv:
389   case Instruction::URem:
390   case Instruction::SRem:
391   case Instruction::FRem:
392   case Instruction::Shl:
393   case Instruction::LShr:
394   case Instruction::AShr:
395   case Instruction::And:
396   case Instruction::Or:
397   case Instruction::Xor: {
398     TargetTransformInfo::OperandValueKind Op1VK =
399       getOperandInfo(I->getOperand(0));
400     TargetTransformInfo::OperandValueKind Op2VK =
401       getOperandInfo(I->getOperand(1));
402     return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
403                                        Op2VK);
404   }
405   case Instruction::Select: {
406     const SelectInst *SI = cast<SelectInst>(I);
407     Type *CondTy = SI->getCondition()->getType();
408     return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
409   }
410   case Instruction::ICmp:
411   case Instruction::FCmp: {
412     Type *ValTy = I->getOperand(0)->getType();
413     return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
414   }
415   case Instruction::Store: {
416     const StoreInst *SI = cast<StoreInst>(I);
417     Type *ValTy = SI->getValueOperand()->getType();
418     return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
419                                  SI->getAlignment(),
420                                  SI->getPointerAddressSpace());
421   }
422   case Instruction::Load: {
423     const LoadInst *LI = cast<LoadInst>(I);
424     return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
425                                  LI->getAlignment(),
426                                  LI->getPointerAddressSpace());
427   }
428   case Instruction::ZExt:
429   case Instruction::SExt:
430   case Instruction::FPToUI:
431   case Instruction::FPToSI:
432   case Instruction::FPExt:
433   case Instruction::PtrToInt:
434   case Instruction::IntToPtr:
435   case Instruction::SIToFP:
436   case Instruction::UIToFP:
437   case Instruction::Trunc:
438   case Instruction::FPTrunc:
439   case Instruction::BitCast:
440   case Instruction::AddrSpaceCast: {
441     Type *SrcTy = I->getOperand(0)->getType();
442     return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
443   }
444   case Instruction::ExtractElement: {
445     const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
446     ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
447     unsigned Idx = -1;
448     if (CI)
449       Idx = CI->getZExtValue();
450
451     // Try to match a reduction sequence (series of shufflevector and vector
452     // adds followed by a extractelement).
453     unsigned ReduxOpCode;
454     Type *ReduxType;
455
456     if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
457       return TTI->getReductionCost(ReduxOpCode, ReduxType, false);
458     else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
459       return TTI->getReductionCost(ReduxOpCode, ReduxType, true);
460
461     return TTI->getVectorInstrCost(I->getOpcode(),
462                                    EEI->getOperand(0)->getType(), Idx);
463   }
464   case Instruction::InsertElement: {
465     const InsertElementInst * IE = cast<InsertElementInst>(I);
466     ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
467     unsigned Idx = -1;
468     if (CI)
469       Idx = CI->getZExtValue();
470     return TTI->getVectorInstrCost(I->getOpcode(),
471                                    IE->getType(), Idx);
472   }
473   case Instruction::ShuffleVector: {
474     const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
475     Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
476     unsigned NumVecElems = VecTypOp0->getVectorNumElements();
477     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
478
479     if (NumVecElems == Mask.size() && isReverseVectorMask(Mask))
480       return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0, 0,
481                                  nullptr);
482     return -1;
483   }
484   case Instruction::Call:
485     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
486       SmallVector<Type*, 4> Tys;
487       for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
488         Tys.push_back(II->getArgOperand(J)->getType());
489
490       return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
491                                         Tys);
492     }
493     return -1;
494   default:
495     // We don't have any information on this instruction.
496     return -1;
497   }
498 }
499
500 void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
501   if (!F)
502     return;
503
504   for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) {
505     for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) {
506       Instruction *Inst = it;
507       unsigned Cost = getInstructionCost(Inst);
508       if (Cost != (unsigned)-1)
509         OS << "Cost Model: Found an estimated cost of " << Cost;
510       else
511         OS << "Cost Model: Unknown cost";
512
513       OS << " for instruction: "<< *Inst << "\n";
514     }
515   }
516 }