s/Method/Function
[oota-llvm.git] / lib / Transforms / Scalar / DecomposeMultiDimRefs.cpp
1 //===- llvm/Transforms/DecomposeMultiDimRefs.cpp - Lower array refs to 1D -----=//
2 //
3 // DecomposeMultiDimRefs - 
4 // Convert multi-dimensional references consisting of any combination
5 // of 2 or more array and structure indices into a sequence of
6 // instructions (using getelementpr and cast) so that each instruction
7 // has at most one index (except structure references,
8 // which need an extra leading index of [0]).
9 //
10 //===---------------------------------------------------------------------===//
11
12 #include "llvm/Transforms/Scalar/DecomposeMultiDimRefs.h"
13 #include "llvm/ConstantVals.h"
14 #include "llvm/iMemory.h"
15 #include "llvm/iOther.h"
16 #include "llvm/BasicBlock.h"
17 #include "llvm/Function.h"
18 #include "llvm/Pass.h"
19
20
21 // 
22 // For any combination of 2 or more array and structure indices,
23 // this function repeats the foll. until we have a one-dim. reference: {
24 //      ptr1 = getElementPtr [CompositeType-N] * lastPtr, uint firstIndex
25 //      ptr2 = cast [CompositeType-N] * ptr1 to [CompositeType-N] *
26 // }
27 // Then it replaces the original instruction with an equivalent one that
28 // uses the last ptr2 generated in the loop and a single index.
29 // If any index is (uint) 0, we omit the getElementPtr instruction.
30 // 
31 static BasicBlock::iterator
32 decomposeArrayRef(BasicBlock::iterator& BBI)
33 {
34   MemAccessInst *memI = cast<MemAccessInst>(*BBI);
35   BasicBlock* BB = memI->getParent();
36   Value* lastPtr = memI->getPointerOperand();
37   vector<Instruction*> newIvec;
38   
39   // Process each index except the last one.
40   // 
41   MemAccessInst::const_op_iterator OI = memI->idx_begin();
42   MemAccessInst::const_op_iterator OE = memI->idx_end();
43   for ( ; OI != OE; ++OI)
44     {
45       assert(isa<PointerType>(lastPtr->getType()));
46       
47       if (OI+1 == OE)                   // stop before the last operand
48         break;
49       
50       // Check for a zero index.  This will need a cast instead of
51       // a getElementPtr, or it may need neither.
52       bool indexIsZero = bool(isa<ConstantUInt>(*OI) && 
53                               cast<ConstantUInt>(*OI)->getValue() == 0);
54       
55       // Extract the first index.  If the ptr is a pointer to a structure
56       // and the next index is a structure offset (i.e., not an array offset), 
57       // we need to include an initial [0] to index into the pointer.
58       vector<Value*> idxVec(1, *OI);
59       PointerType* ptrType = cast<PointerType>(lastPtr->getType());
60       if (isa<StructType>(ptrType->getElementType())
61           && ! ptrType->indexValid(*OI))
62         idxVec.insert(idxVec.begin(), ConstantUInt::get(Type::UIntTy, 0));
63       
64       // Get the type obtained by applying the first index.
65       // It must be a structure or array.
66       const Type* nextType = MemAccessInst::getIndexedType(lastPtr->getType(),
67                                                            idxVec, true);
68       assert(isa<StructType>(nextType) || isa<ArrayType>(nextType));
69       
70       // Get a pointer to the structure or to the elements of the array.
71       const Type* nextPtrType =
72         PointerType::get(isa<StructType>(nextType)? nextType
73                          : cast<ArrayType>(nextType)->getElementType());
74       
75       // Instruction 1: nextPtr1 = GetElementPtr lastPtr, idxVec
76       // This is not needed if the index is zero.
77       Value* gepValue;
78       if (indexIsZero)
79         gepValue = lastPtr;
80       else
81         {
82           gepValue = new GetElementPtrInst(lastPtr, idxVec,"ptr1");
83           newIvec.push_back(cast<Instruction>(gepValue));
84         }
85       
86       // Instruction 2: nextPtr2 = cast nextPtr1 to nextPtrType
87       // This is not needed if the two types are identical.
88       Value* castInst;
89       if (gepValue->getType() == nextPtrType)
90         castInst = gepValue;
91       else
92         {
93           castInst = new CastInst(gepValue, nextPtrType, "ptr2");
94           newIvec.push_back(cast<Instruction>(castInst));
95         }
96       
97       lastPtr = castInst;
98     }
99   
100   // 
101   // Now create a new instruction to replace the original one
102   //
103   PointerType* ptrType = cast<PointerType>(lastPtr->getType());
104   assert(ptrType);
105
106   // First, get the final index vector.  As above, we may need an initial [0].
107   vector<Value*> idxVec(1, *OI);
108   if (isa<StructType>(ptrType->getElementType())
109       && ! ptrType->indexValid(*OI))
110     idxVec.insert(idxVec.begin(), ConstantUInt::get(Type::UIntTy, 0));
111   
112   const std::string newInstName = memI->hasName()? memI->getName()
113                                                  : string("finalRef");
114   Instruction* newInst = NULL;
115   
116   switch(memI->getOpcode())
117     {
118     case Instruction::Load:
119       newInst = new LoadInst(lastPtr, idxVec /*, newInstName */); break;
120     case Instruction::Store:
121       newInst = new StoreInst(memI->getOperand(0),
122                               lastPtr, idxVec /*, newInstName */); break;
123       break;
124     case Instruction::GetElementPtr:
125       newInst = new GetElementPtrInst(lastPtr, idxVec /*, newInstName */); break;
126     default:
127       assert(0 && "Unrecognized memory access instruction"); break;
128     }
129   
130   newIvec.push_back(newInst);
131   
132   // Replace all uses of the old instruction with the new
133   memI->replaceAllUsesWith(newInst);
134   
135   BasicBlock::iterator newI = BBI;;
136   for (int i = newIvec.size()-1; i >= 0; i--)
137     newI = BB->getInstList().insert(newI, newIvec[i]);
138   
139   // Now delete the old instruction and return a pointer to the last new one
140   BB->getInstList().remove(memI);
141   delete memI;
142   
143   return newI + newIvec.size() - 1;           // pointer to last new instr
144 }
145
146
147 //---------------------------------------------------------------------------
148 // Entry point for array or  structure references with multiple indices.
149 //---------------------------------------------------------------------------
150
151 static bool
152 doDecomposeMultiDimRefs(Function *F)
153 {
154   bool changed = false;
155   
156   for (Function::iterator BI = F->begin(), BE = F->end(); BI != BE; ++BI)
157     for (BasicBlock::iterator newI, II = (*BI)->begin();
158          II != (*BI)->end(); II = ++newI)
159       {
160         newI = II;
161         if (MemAccessInst *memI = dyn_cast<MemAccessInst>(*II))
162           if (memI->getNumOperands() > 1 + memI->getFirstIndexOperandNumber())
163             {
164               newI = decomposeArrayRef(II);
165               changed = true;
166             }
167       }
168   
169   return changed;
170 }
171
172
173 namespace {
174   struct DecomposeMultiDimRefsPass : public MethodPass {
175     virtual bool runOnMethod(Function *F) { return doDecomposeMultiDimRefs(F); }
176   };
177 }
178
179 Pass *createDecomposeMultiDimRefsPass() { return new DecomposeMultiDimRefsPass(); }