Add support for simplifying a load from a computed value to a load from a global...
[oota-llvm.git] / lib / Transforms / Scalar / CorrelatedValuePropagation.cpp
1 //===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the Correlated Value Propagation pass.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #define DEBUG_TYPE "correlated-value-propagation"
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Function.h"
17 #include "llvm/Instructions.h"
18 #include "llvm/Pass.h"
19 #include "llvm/Analysis/LazyValueInfo.h"
20 #include "llvm/Transforms/Utils/Local.h"
21 #include "llvm/ADT/Statistic.h"
22 using namespace llvm;
23
24 STATISTIC(NumPhis,      "Number of phis propagated");
25 STATISTIC(NumSelects,   "Number of selects propagated");
26 STATISTIC(NumMemAccess, "Number of memory access targets propagated");
27
28 namespace {
29   class CorrelatedValuePropagation : public FunctionPass {
30     LazyValueInfo *LVI;
31     
32     bool processSelect(SelectInst *SI);
33     bool processPHI(PHINode *P);
34     bool processMemAccess(Instruction *I);
35     
36   public:
37     static char ID;
38     CorrelatedValuePropagation(): FunctionPass(ID) { }
39     
40     bool runOnFunction(Function &F);
41     
42     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
43       AU.addRequired<LazyValueInfo>();
44     }
45   };
46 }
47
48 char CorrelatedValuePropagation::ID = 0;
49 INITIALIZE_PASS(CorrelatedValuePropagation, "correlated-propagation",
50                 "Value Propagation", false, false);
51
52 // Public interface to the Value Propagation pass
53 Pass *llvm::createCorrelatedValuePropagationPass() {
54   return new CorrelatedValuePropagation();
55 }
56
57 bool CorrelatedValuePropagation::processSelect(SelectInst *S) {
58   if (S->getType()->isVectorTy()) return false;
59   if (isa<Constant>(S->getOperand(0))) return false;
60   
61   Constant *C = LVI->getConstant(S->getOperand(0), S->getParent());
62   if (!C) return false;
63   
64   ConstantInt *CI = dyn_cast<ConstantInt>(C);
65   if (!CI) return false;
66   
67   S->replaceAllUsesWith(S->getOperand(CI->isOne() ? 1 : 2));
68   S->eraseFromParent();
69
70   ++NumSelects;
71   
72   return true;
73 }
74
75 bool CorrelatedValuePropagation::processPHI(PHINode *P) {
76   bool Changed = false;
77   
78   BasicBlock *BB = P->getParent();
79   for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) {
80     Value *Incoming = P->getIncomingValue(i);
81     if (isa<Constant>(Incoming)) continue;
82     
83     Constant *C = LVI->getConstantOnEdge(P->getIncomingValue(i),
84                                          P->getIncomingBlock(i),
85                                          BB);
86     if (!C) continue;
87     
88     P->setIncomingValue(i, C);
89     Changed = true;
90   }
91   
92   if (Value *ConstVal = P->hasConstantValue()) {
93     P->replaceAllUsesWith(ConstVal);
94     P->eraseFromParent();
95     Changed = true;
96   }
97   
98   ++NumPhis;
99   
100   return Changed;
101 }
102
103 bool CorrelatedValuePropagation::processMemAccess(Instruction *I) {
104   Value *Pointer = 0;
105   if (LoadInst *L = dyn_cast<LoadInst>(I))
106     Pointer = L->getPointerOperand();
107   else
108     Pointer = cast<StoreInst>(I)->getPointerOperand();
109   
110   if (isa<Constant>(Pointer)) return false;
111   
112   Constant *C = LVI->getConstant(Pointer, I->getParent());
113   if (!C) return false;
114   
115   ++NumMemAccess;
116   I->replaceUsesOfWith(Pointer, C);
117   return true;
118 }
119
120 bool CorrelatedValuePropagation::runOnFunction(Function &F) {
121   LVI = &getAnalysis<LazyValueInfo>();
122   
123   bool FnChanged = false;
124   
125   for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI) {
126     bool BBChanged = false;
127     for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) {
128       Instruction *II = BI++;
129       switch (II->getOpcode()) {
130       case Instruction::Select:
131         BBChanged |= processSelect(cast<SelectInst>(II));
132         break;
133       case Instruction::PHI:
134         BBChanged |= processPHI(cast<PHINode>(II));
135         break;
136       case Instruction::Load:
137       case Instruction::Store:
138         BBChanged |= processMemAccess(II);
139         break;
140       }
141     }
142     
143     // Propagating correlated values might leave cruft around.
144     // Try to clean it up before we continue.
145     if (BBChanged)
146       SimplifyInstructionsInBlock(FI);
147     
148     FnChanged |= BBChanged;
149   }
150   
151   return FnChanged;
152 }