fdb2c50a2c71a099625230d8c6a5eb5b5da9bdfc
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Argument.h"
16 #include "llvm/Constants.h"
17 #include "llvm/DerivedTypes.h"
18 #include "llvm/Function.h"
19 #include "llvm/Instructions.h"
20 #include "llvm/Pass.h"
21 #include "llvm/Support/InstVisitor.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include <algorithm>
24 #include <map>
25 #include <iostream>
26
27 using namespace llvm;
28
29 namespace {
30
31 /// This pass converts packed operators to an
32 /// equivalent operations on smaller packed data, to possibly
33 /// scalar operations.  Currently it supports lowering
34 /// to scalar operations.
35 ///
36 /// @brief Transforms packed instructions to simpler instructions.
37 ///
38 class LowerPacked : public FunctionPass, public InstVisitor<LowerPacked> {
39 public:
40    /// @brief Lowers packed operations to scalar operations. 
41    /// @param F The fuction to process
42    virtual bool runOnFunction(Function &F);
43
44    /// @brief Lowers packed load instructions.
45    /// @param LI the load instruction to convert
46    void visitLoadInst(LoadInst& LI);
47
48    /// @brief Lowers packed store instructions.
49    /// @param SI the store instruction to convert
50    void visitStoreInst(StoreInst& SI);
51
52    /// @brief Lowers packed binary operations.
53    /// @param BO the binary operator to convert
54    void visitBinaryOperator(BinaryOperator& BO);
55
56    /// @brief Lowers packed select instructions.
57    /// @param SELI the select operator to convert
58    void visitSelectInst(SelectInst& SELI);
59
60    /// This function asserts if the instruction is a PackedType but
61    /// is handled by another function.
62    /// 
63    /// @brief Asserts if PackedType instruction is not handled elsewhere.
64    /// @param I the unhandled instruction
65    void visitInstruction(Instruction &I)
66    {
67       if(isa<PackedType>(I.getType())) {
68          std::cerr << "Unhandled Instruction with Packed ReturnType: " << 
69                       I << '\n';
70       }
71    }
72 private:
73    /// @brief Retrieves lowered values for a packed value.
74    /// @param val the packed value
75    /// @return the lowered values
76    std::vector<Value*>& getValues(Value* val);
77
78    /// @brief Sets lowered values for a packed value.
79    /// @param val the packed value
80    /// @param values the corresponding lowered values
81    void setValues(Value* val,const std::vector<Value*>& values);
82
83    // Data Members
84    /// @brief whether we changed the function or not   
85    bool Changed;
86
87    /// @brief a map from old packed values to new smaller packed values
88    std::map<Value*,std::vector<Value*> > packedToScalarMap;
89
90    /// Instructions in the source program to get rid of
91    /// after we do a pass (the old packed instructions)
92    std::vector<Instruction*> instrsToRemove;
93 }; 
94
95 RegisterOpt<LowerPacked> 
96 X("lower-packed", 
97   "lowers packed operations to operations on smaller packed datatypes");
98
99 } // end namespace   
100
101 FunctionPass *createLowerPackedPass() { return new LowerPacked(); }
102
103
104 // This function sets lowered values for a corresponding
105 // packed value.  Note, in the case of a forward reference
106 // getValues(Value*) will have already been called for 
107 // the packed parameter.  This function will then replace 
108 // all references in the in the function of the "dummy" 
109 // value the previous getValues(Value*) call 
110 // returned with actual references.
111 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
112 {
113    std::map<Value*,std::vector<Value*> >::iterator it = 
114          packedToScalarMap.lower_bound(value);
115    if (it == packedToScalarMap.end() || it->first != value) {
116        // there was not a forward reference to this element
117        packedToScalarMap.insert(it,std::make_pair(value,values));
118    }
119    else {
120       // replace forward declarations with actual definitions
121       assert(it->second.size() == values.size() && 
122              "Error forward refences and actual definition differ in size");
123       for (unsigned i = 0, e = values.size(); i != e; ++i) {
124            // replace and get rid of old forward references
125            it->second[i]->replaceAllUsesWith(values[i]);
126            delete it->second[i];
127            it->second[i] = values[i];
128       }
129    }
130 }
131
132 // This function will examine the packed value parameter
133 // and if it is a packed constant or a forward reference
134 // properly create the lowered values needed.  Otherwise
135 // it will simply retreive values from a  
136 // setValues(Value*,const std::vector<Value*>&) 
137 // call.  Failing both of these cases, it will abort
138 // the program.
139 std::vector<Value*>& LowerPacked::getValues(Value* value)
140 {
141    assert(isa<PackedType>(value->getType()) &&
142           "Value must be PackedType");
143
144    // reject further processing if this one has
145    // already been handled
146    std::map<Value*,std::vector<Value*> >::iterator it = 
147       packedToScalarMap.lower_bound(value);
148    if (it != packedToScalarMap.end() && it->first == value) {
149        return it->second;
150    }
151
152    if (ConstantPacked* CP = dyn_cast<ConstantPacked>(value)) {
153        // non-zero constant case
154        std::vector<Value*> results;
155        results.reserve(CP->getNumOperands());
156        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
157           results.push_back(CP->getOperand(i));
158        }
159        return packedToScalarMap.insert(it,
160                                        std::make_pair(value,results))->second;
161    }
162    else if (ConstantAggregateZero* CAZ =
163             dyn_cast<ConstantAggregateZero>(value)) {
164        // zero constant 
165        const PackedType* PKT = cast<PackedType>(CAZ->getType());
166        std::vector<Value*> results;
167        results.reserve(PKT->getNumElements());
168    
169        Constant* C = Constant::getNullValue(PKT->getElementType());
170        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
171             results.push_back(C);
172        }
173        return packedToScalarMap.insert(it,
174                                        std::make_pair(value,results))->second;
175    }
176    else if (isa<Instruction>(value)) {
177        // foward reference
178        const PackedType* PKT = cast<PackedType>(value->getType());
179        std::vector<Value*> results;
180        results.reserve(PKT->getNumElements());
181    
182       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
183            results.push_back(new Argument(PKT->getElementType()));
184       }
185       return packedToScalarMap.insert(it,
186                                       std::make_pair(value,results))->second;
187    }
188    else {
189        // we don't know what it is, and we are trying to retrieve
190        // a value for it
191        assert(false && "Unhandled PackedType value");
192        abort();
193    }
194 }
195
196 void LowerPacked::visitLoadInst(LoadInst& LI)
197 {
198    // Make sure what we are dealing with is a packed type
199    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
200        // Initialization, Idx is needed for getelementptr needed later
201        std::vector<Value*> Idx(2);
202        Idx[0] = ConstantUInt::get(Type::UIntTy,0);
203
204        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
205                                       PKT->getNumElements());
206        PointerType* APT = PointerType::get(AT);
207
208        // Cast the packed type to an array
209        Value* array = new CastInst(LI.getPointerOperand(),
210                                    APT,
211                                    LI.getName() + ".a",
212                                    &LI);
213
214        // Convert this load into num elements number of loads
215        std::vector<Value*> values;
216        values.reserve(PKT->getNumElements());
217
218        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
219             // Calculate the second index we will need
220             Idx[1] = ConstantUInt::get(Type::UIntTy,i);
221
222             // Get the pointer
223             Value* val = new GetElementPtrInst(array, 
224                                                Idx,
225                                                LI.getName() + 
226                                                ".ge." + utostr(i),
227                                                &LI);
228
229             // generate the new load and save the result in packedToScalar map
230             values.push_back(new LoadInst(val, 
231                              LI.getName()+"."+utostr(i),
232                              LI.isVolatile(),
233                              &LI));
234        }
235                
236        setValues(&LI,values);
237        Changed = true;
238        instrsToRemove.push_back(&LI);
239    }
240 }
241
242 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
243 {
244    // Make sure both operands are PackedTypes
245    if (isa<PackedType>(BO.getOperand(0)->getType())) {
246        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
247        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
248        std::vector<Value*> result;
249        assert((op0Vals.size() == op1Vals.size()) &&
250               "The two packed operand to scalar maps must be equal in size.");
251
252        result.reserve(op0Vals.size());
253    
254        // generate the new binary op and save the result
255        for (unsigned i = 0; i != op0Vals.size(); ++i) {
256             result.push_back(BinaryOperator::create(BO.getOpcode(), 
257                                                     op0Vals[i], 
258                                                     op1Vals[i],
259                                                     BO.getName() + 
260                                                     "." + utostr(i),
261                                                     &BO));
262        }
263
264        setValues(&BO,result);
265        Changed = true;
266        instrsToRemove.push_back(&BO);
267    }
268 }
269
270 void LowerPacked::visitStoreInst(StoreInst& SI)
271 {
272    if (const PackedType* PKT = 
273        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
274        // We will need this for getelementptr
275        std::vector<Value*> Idx(2);
276        Idx[0] = ConstantUInt::get(Type::UIntTy,0);
277          
278        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
279                                       PKT->getNumElements());
280        PointerType* APT = PointerType::get(AT);
281
282        // cast the packed to an array type
283        Value* array = new CastInst(SI.getPointerOperand(),
284                                    APT,
285                                    "store.ge.a.",
286                                    &SI);
287        std::vector<Value*>& values = getValues(SI.getOperand(0));
288       
289        assert((values.size() == PKT->getNumElements()) &&
290               "Scalar must have the same number of elements as Packed Type");
291
292        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
293             // Generate the indices for getelementptr
294             Idx[1] = ConstantUInt::get(Type::UIntTy,i);
295             Value* val = new GetElementPtrInst(array, 
296                                                Idx,
297                                                "store.ge." +
298                                                utostr(i) + ".",
299                                                &SI);
300             new StoreInst(values[i], val, SI.isVolatile(),&SI);
301        }
302                  
303        Changed = true;
304        instrsToRemove.push_back(&SI);
305    }
306 }
307
308 void LowerPacked::visitSelectInst(SelectInst& SELI)
309 {
310    // Make sure both operands are PackedTypes
311    if (isa<PackedType>(SELI.getType())) {
312        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
313        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
314        std::vector<Value*> result;
315
316       assert((op0Vals.size() == op1Vals.size()) &&
317              "The two packed operand to scalar maps must be equal in size.");
318
319       for (unsigned i = 0; i != op0Vals.size(); ++i) {
320            result.push_back(new SelectInst(SELI.getCondition(),
321                                            op0Vals[i], 
322                                            op1Vals[i],
323                                            SELI.getName()+ "." + utostr(i),
324                                            &SELI));
325       }
326    
327       setValues(&SELI,result);
328       Changed = true;
329       instrsToRemove.push_back(&SELI);
330    }
331 }
332
333 bool LowerPacked::runOnFunction(Function& F)
334 {
335    // initialize
336    Changed = false; 
337   
338    // Does three passes:
339    // Pass 1) Converts Packed Operations to 
340    //         new Packed Operations on smaller
341    //         datatypes
342    visit(F);
343   
344    // Pass 2) Drop all references
345    std::for_each(instrsToRemove.begin(),
346                  instrsToRemove.end(),
347                  std::mem_fun(&Instruction::dropAllReferences));
348
349    // Pass 3) Delete the Instructions to remove aka packed instructions
350    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(), 
351                                             e = instrsToRemove.end(); 
352         i != e; ++i) {
353         (*i)->getParent()->getInstList().erase(*i);   
354    }
355
356    // clean-up
357    packedToScalarMap.clear();
358    instrsToRemove.clear();
359
360    return Changed;
361 }
362