Load & StoreInst no longer derive from MemAccessInst, so we don't have
[oota-llvm.git] / lib / Transforms / Scalar / DecomposeMultiDimRefs.cpp
1 //===- llvm/Transforms/DecomposeMultiDimRefs.cpp - Lower array refs to 1D -===//
2 //
3 // DecomposeMultiDimRefs - Convert multi-dimensional references consisting of
4 // any combination of 2 or more array and structure indices into a sequence of
5 // instructions (using getelementpr and cast) so that each instruction has at
6 // most one index (except structure references, which need an extra leading
7 // index of [0]).
8 //
9 //===----------------------------------------------------------------------===//
10
11 #include "llvm/Transforms/Scalar.h"
12 #include "llvm/DerivedTypes.h"
13 #include "llvm/Constants.h"
14 #include "llvm/Constant.h"
15 #include "llvm/iMemory.h"
16 #include "llvm/iOther.h"
17 #include "llvm/BasicBlock.h"
18 #include "llvm/Pass.h"
19 #include "Support/StatisticReporter.h"
20
21 static Statistic<> NumAdded("lowerrefs\t\t- New instructions added");
22
23 namespace {
24   struct DecomposePass : public BasicBlockPass {
25     virtual bool runOnBasicBlock(BasicBlock &BB);
26
27   private:
28     static bool decomposeArrayRef(BasicBlock::iterator &BBI);
29   };
30
31   RegisterOpt<DecomposePass> X("lowerrefs", "Decompose multi-dimensional "
32                                "structure/array references");
33 }
34
35 Pass
36 *createDecomposeMultiDimRefsPass()
37 {
38   return new DecomposePass();
39 }
40
41
42 // runOnBasicBlock - Entry point for array or structure references with multiple
43 // indices.
44 //
45 bool
46 DecomposePass::runOnBasicBlock(BasicBlock &BB)
47 {
48   bool Changed = false;
49   for (BasicBlock::iterator II = BB.begin(); II != BB.end(); ) {
50     if (MemAccessInst *MAI = dyn_cast<MemAccessInst>(&*II))
51       if (MAI->getNumIndices() >= 2) {
52         Changed |= decomposeArrayRef(II); // always modifies II
53         continue;
54       }
55     ++II;
56   }
57   return Changed;
58 }
59
60 // Check for a constant (uint) 0.
61 inline bool
62 IsZero(Value* idx)
63 {
64   return (isa<ConstantInt>(idx) && cast<ConstantInt>(idx)->isNullValue());
65 }
66
67 // For any MemAccessInst with 2 or more array and structure indices:
68 // 
69 //      opCode CompositeType* P, [uint|ubyte] idx1, ..., [uint|ubyte] idxN
70 // 
71 // this function generates the foll sequence:
72 // 
73 //      ptr1   = getElementPtr P,         idx1
74 //      ptr2   = getElementPtr ptr1,   0, idx2
75 //      ...
76 //      ptrN-1 = getElementPtr ptrN-2, 0, idxN-1
77 //      opCode                 ptrN-1, 0, idxN  // New-MAI
78 // 
79 // Then it replaces the original instruction with this sequence,
80 // and replaces all uses of the original instruction with New-MAI.
81 // If idx1 is 0, we simply omit the first getElementPtr instruction.
82 // 
83 // On return: BBI points to the instruction after the current one
84 //            (whether or not *BBI was replaced).
85 // 
86 // Return value: true if the instruction was replaced; false otherwise.
87 // 
88 bool
89 DecomposePass::decomposeArrayRef(BasicBlock::iterator &BBI)
90 {
91   MemAccessInst &MAI = cast<MemAccessInst>(*BBI);
92   BasicBlock *BB = MAI.getParent();
93   Value *LastPtr = MAI.getPointerOperand();
94
95   // Remove the instruction from the stream
96   BB->getInstList().remove(BBI);
97
98   // The vector of new instructions to be created
99   std::vector<Instruction*> NewInsts;
100
101   // Process each index except the last one.
102   User::const_op_iterator OI = MAI.idx_begin(), OE = MAI.idx_end();
103   for (; OI+1 != OE; ++OI) {
104     std::vector<Value*> Indices;
105     
106     // If this is the first index and is 0, skip it and move on!
107     if (OI == MAI.idx_begin()) {
108       if (IsZero(*OI)) continue;
109     } else
110       // Not the first index: include initial [0] to deref the last ptr
111       Indices.push_back(Constant::getNullValue(Type::UIntTy));
112
113     Indices.push_back(*OI);
114
115     // New Instruction: nextPtr1 = GetElementPtr LastPtr, Indices
116     LastPtr = new GetElementPtrInst(LastPtr, Indices, "ptr1");
117     NewInsts.push_back(cast<Instruction>(LastPtr));
118     ++NumAdded;
119   }
120
121   // Now create a new instruction to replace the original one
122   //
123   const PointerType *PtrTy = cast<PointerType>(LastPtr->getType());
124
125   // Get the final index vector, including an initial [0] as before.
126   std::vector<Value*> Indices;
127   Indices.push_back(Constant::getNullValue(Type::UIntTy));
128   Indices.push_back(*OI);
129
130   Instruction *NewI = 0;
131   switch(MAI.getOpcode()) {
132   case Instruction::GetElementPtr:
133     NewI = new GetElementPtrInst(LastPtr, Indices, MAI.getName());
134     break;
135   default:
136     assert(0 && "Unrecognized memory access instruction");
137   }
138   NewInsts.push_back(NewI);
139
140   // Replace all uses of the old instruction with the new
141   MAI.replaceAllUsesWith(NewI);
142
143   // Now delete the old instruction...
144   delete &MAI;
145
146   // Insert all of the new instructions...
147   BB->getInstList().insert(BBI, NewInsts.begin(), NewInsts.end());
148
149   // Advance the iterator to the instruction following the one just inserted...
150   BBI = NewInsts.back();
151   ++BBI;
152   return true;
153 }