Cost model: Add check for reverse shuffles to CostModel analysis
[oota-llvm.git] / lib / Analysis / CostModel.cpp
1 //===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the cost model analysis. It provides a very basic cost
11 // estimation for LLVM-IR. This analysis uses the services of the codegen
12 // to approximate the cost of any IR instruction when lowered to machine
13 // instructions. The cost results are unit-less and the cost number represents
14 // the throughput of the machine assuming that all loads hit the cache, all
15 // branches are predicted, etc. The cost numbers can be added in order to
16 // compare two or more transformation alternatives.
17 //
18 //===----------------------------------------------------------------------===//
19
20 #define CM_NAME "cost-model"
21 #define DEBUG_TYPE CM_NAME
22 #include "llvm/Analysis/Passes.h"
23 #include "llvm/Analysis/TargetTransformInfo.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Instructions.h"
26 #include "llvm/IR/Value.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 using namespace llvm;
31
32 namespace {
33   class CostModelAnalysis : public FunctionPass {
34
35   public:
36     static char ID; // Class identification, replacement for typeinfo
37     CostModelAnalysis() : FunctionPass(ID), F(0), TTI(0) {
38       initializeCostModelAnalysisPass(
39         *PassRegistry::getPassRegistry());
40     }
41
42     /// Returns the expected cost of the instruction.
43     /// Returns -1 if the cost is unknown.
44     /// Note, this method does not cache the cost calculation and it
45     /// can be expensive in some cases.
46     unsigned getInstructionCost(const Instruction *I) const;
47
48   private:
49     virtual void getAnalysisUsage(AnalysisUsage &AU) const;
50     virtual bool runOnFunction(Function &F);
51     virtual void print(raw_ostream &OS, const Module*) const;
52
53     /// The function that we analyze.
54     Function *F;
55     /// Target information.
56     const TargetTransformInfo *TTI;
57   };
58 }  // End of anonymous namespace
59
60 // Register this pass.
61 char CostModelAnalysis::ID = 0;
62 static const char cm_name[] = "Cost Model Analysis";
63 INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
64 INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true)
65
66 FunctionPass *llvm::createCostModelAnalysisPass() {
67   return new CostModelAnalysis();
68 }
69
70 void
71 CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
72   AU.setPreservesAll();
73 }
74
75 bool
76 CostModelAnalysis::runOnFunction(Function &F) {
77  this->F = &F;
78  TTI = getAnalysisIfAvailable<TargetTransformInfo>();
79
80  return false;
81 }
82
83 static bool isReverseVectorMask(SmallVector<int, 16> &Mask) {
84   for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
85     if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i))
86       return false;
87   return true;
88 }
89
90 unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
91   if (!TTI)
92     return -1;
93
94   switch (I->getOpcode()) {
95   case Instruction::GetElementPtr:{
96     Type *ValTy = I->getOperand(0)->getType()->getPointerElementType();
97     return TTI->getAddressComputationCost(ValTy);
98   }
99
100   case Instruction::Ret:
101   case Instruction::PHI:
102   case Instruction::Br: {
103     return TTI->getCFInstrCost(I->getOpcode());
104   }
105   case Instruction::Add:
106   case Instruction::FAdd:
107   case Instruction::Sub:
108   case Instruction::FSub:
109   case Instruction::Mul:
110   case Instruction::FMul:
111   case Instruction::UDiv:
112   case Instruction::SDiv:
113   case Instruction::FDiv:
114   case Instruction::URem:
115   case Instruction::SRem:
116   case Instruction::FRem:
117   case Instruction::Shl:
118   case Instruction::LShr:
119   case Instruction::AShr:
120   case Instruction::And:
121   case Instruction::Or:
122   case Instruction::Xor: {
123     return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType());
124   }
125   case Instruction::Select: {
126     const SelectInst *SI = cast<SelectInst>(I);
127     Type *CondTy = SI->getCondition()->getType();
128     return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
129   }
130   case Instruction::ICmp:
131   case Instruction::FCmp: {
132     Type *ValTy = I->getOperand(0)->getType();
133     return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
134   }
135   case Instruction::Store: {
136     const StoreInst *SI = cast<StoreInst>(I);
137     Type *ValTy = SI->getValueOperand()->getType();
138     return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
139                                  SI->getAlignment(),
140                                  SI->getPointerAddressSpace());
141   }
142   case Instruction::Load: {
143     const LoadInst *LI = cast<LoadInst>(I);
144     return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
145                                  LI->getAlignment(),
146                                  LI->getPointerAddressSpace());
147   }
148   case Instruction::ZExt:
149   case Instruction::SExt:
150   case Instruction::FPToUI:
151   case Instruction::FPToSI:
152   case Instruction::FPExt:
153   case Instruction::PtrToInt:
154   case Instruction::IntToPtr:
155   case Instruction::SIToFP:
156   case Instruction::UIToFP:
157   case Instruction::Trunc:
158   case Instruction::FPTrunc:
159   case Instruction::BitCast: {
160     Type *SrcTy = I->getOperand(0)->getType();
161     return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
162   }
163   case Instruction::ExtractElement: {
164     const ExtractElementInst * EEI = cast<ExtractElementInst>(I);
165     ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
166     unsigned Idx = -1;
167     if (CI)
168       Idx = CI->getZExtValue();
169     return TTI->getVectorInstrCost(I->getOpcode(),
170                                    EEI->getOperand(0)->getType(), Idx);
171   }
172   case Instruction::InsertElement: {
173       const InsertElementInst * IE = cast<InsertElementInst>(I);
174       ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
175       unsigned Idx = -1;
176       if (CI)
177         Idx = CI->getZExtValue();
178       return TTI->getVectorInstrCost(I->getOpcode(),
179                                      IE->getType(), Idx);
180     }
181   case Instruction::ShuffleVector: {
182     const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
183     Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
184     unsigned NumVecElems = VecTypOp0->getVectorNumElements();
185     SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
186
187     if (NumVecElems == Mask.size() && isReverseVectorMask(Mask))
188       return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0, 0,
189                                  0);
190     return -1;
191   }
192   default:
193     // We don't have any information on this instruction.
194     return -1;
195   }
196 }
197
198 void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
199   if (!F)
200     return;
201
202   for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) {
203     for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) {
204       Instruction *Inst = it;
205       unsigned Cost = getInstructionCost(Inst);
206       if (Cost != (unsigned)-1)
207         OS << "Cost Model: Found an estimated cost of " << Cost;
208       else
209         OS << "Cost Model: Unknown cost";
210
211       OS << " for instruction: "<< *Inst << "\n";
212     }
213   }
214 }