Changes For Bug 352
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Argument.h"
16 #include "llvm/Constants.h"
17 #include "llvm/DerivedTypes.h"
18 #include "llvm/Function.h"
19 #include "llvm/Instructions.h"
20 #include "llvm/Pass.h"
21 #include "llvm/Support/InstVisitor.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include <algorithm>
24 #include <map>
25 #include <iostream>
26
27 using namespace llvm;
28
29 namespace {
30
31 /// This pass converts packed operators to an
32 /// equivalent operations on smaller packed data, to possibly
33 /// scalar operations.  Currently it supports lowering
34 /// to scalar operations.
35 ///
36 /// @brief Transforms packed instructions to simpler instructions.
37 ///
38 class LowerPacked : public FunctionPass, public InstVisitor<LowerPacked> {
39 public:
40    /// @brief Lowers packed operations to scalar operations. 
41    /// @param F The fuction to process
42    virtual bool runOnFunction(Function &F);
43
44    /// @brief Lowers packed load instructions.
45    /// @param LI the load instruction to convert
46    void visitLoadInst(LoadInst& LI);
47
48    /// @brief Lowers packed store instructions.
49    /// @param SI the store instruction to convert
50    void visitStoreInst(StoreInst& SI);
51
52    /// @brief Lowers packed binary operations.
53    /// @param BO the binary operator to convert
54    void visitBinaryOperator(BinaryOperator& BO);
55
56    /// @brief Lowers packed select instructions.
57    /// @param SELI the select operator to convert
58    void visitSelectInst(SelectInst& SELI);
59
60    /// This function asserts if the instruction is a PackedType but
61    /// is handled by another function.
62    /// 
63    /// @brief Asserts if PackedType instruction is not handled elsewhere.
64    /// @param I the unhandled instruction
65    void visitInstruction(Instruction &I)
66    {
67       if(isa<PackedType>(I.getType())) {
68          std::cerr << "Unhandled Instruction with Packed ReturnType: " << 
69                       I << '\n';
70       }
71    }
72 private:
73    /// @brief Retrieves lowered values for a packed value.
74    /// @param val the packed value
75    /// @return the lowered values
76    std::vector<Value*>& getValues(Value* val);
77
78    /// @brief Sets lowered values for a packed value.
79    /// @param val the packed value
80    /// @param values the corresponding lowered values
81    void setValues(Value* val,const std::vector<Value*>& values);
82
83    // Data Members
84    /// @brief whether we changed the function or not   
85    bool Changed;
86
87    /// @brief a map from old packed values to new smaller packed values
88    std::map<Value*,std::vector<Value*> > packedToScalarMap;
89
90    /// Instructions in the source program to get rid of
91    /// after we do a pass (the old packed instructions)
92    std::vector<Instruction*> instrsToRemove;
93 }; 
94
95 RegisterOpt<LowerPacked> 
96 X("lower-packed", 
97   "lowers packed operations to operations on smaller packed datatypes");
98
99 } // end namespace   
100
101 // This function sets lowered values for a corresponding
102 // packed value.  Note, in the case of a forward reference
103 // getValues(Value*) will have already been called for 
104 // the packed parameter.  This function will then replace 
105 // all references in the in the function of the "dummy" 
106 // value the previous getValues(Value*) call 
107 // returned with actual references.
108 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
109 {
110    std::map<Value*,std::vector<Value*> >::iterator it = 
111          packedToScalarMap.lower_bound(value);
112    if (it == packedToScalarMap.end() || it->first != value) {
113        // there was not a forward reference to this element
114        packedToScalarMap.insert(it,std::make_pair(value,values));
115    }
116    else {
117       // replace forward declarations with actual definitions
118       assert(it->second.size() == values.size() && 
119              "Error forward refences and actual definition differ in size");
120       for (unsigned i = 0, e = values.size(); i != e; ++i) {
121            // replace and get rid of old forward references
122            it->second[i]->replaceAllUsesWith(values[i]);
123            delete it->second[i];
124            it->second[i] = values[i];
125       }
126    }
127 }
128
129 // This function will examine the packed value parameter
130 // and if it is a packed constant or a forward reference
131 // properly create the lowered values needed.  Otherwise
132 // it will simply retreive values from a  
133 // setValues(Value*,const std::vector<Value*>&) 
134 // call.  Failing both of these cases, it will abort
135 // the program.
136 std::vector<Value*>& LowerPacked::getValues(Value* value)
137 {
138    assert(isa<PackedType>(value->getType()) &&
139           "Value must be PackedType");
140
141    // reject further processing if this one has
142    // already been handled
143    std::map<Value*,std::vector<Value*> >::iterator it = 
144       packedToScalarMap.lower_bound(value);
145    if (it != packedToScalarMap.end() && it->first == value) {
146        return it->second;
147    }
148
149    if (ConstantPacked* CP = dyn_cast<ConstantPacked>(value)) {
150        // non-zero constant case
151        std::vector<Value*> results;
152        results.reserve(CP->getNumOperands());
153        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
154           results.push_back(CP->getOperand(i));
155        }
156        return packedToScalarMap.insert(it,
157                                        std::make_pair(value,results))->second;
158    }
159    else if (ConstantAggregateZero* CAZ =
160             dyn_cast<ConstantAggregateZero>(value)) {
161        // zero constant 
162        const PackedType* PKT = cast<PackedType>(CAZ->getType());
163        std::vector<Value*> results;
164        results.reserve(PKT->getNumElements());
165    
166        Constant* C = Constant::getNullValue(PKT->getElementType());
167        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
168             results.push_back(C);
169        }
170        return packedToScalarMap.insert(it,
171                                        std::make_pair(value,results))->second;
172    }
173    else if (isa<Instruction>(value)) {
174        // foward reference
175        const PackedType* PKT = cast<PackedType>(value->getType());
176        std::vector<Value*> results;
177        results.reserve(PKT->getNumElements());
178    
179       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
180            results.push_back(new Argument(PKT->getElementType()));
181       }
182       return packedToScalarMap.insert(it,
183                                       std::make_pair(value,results))->second;
184    }
185    else {
186        // we don't know what it is, and we are trying to retrieve
187        // a value for it
188        assert(false && "Unhandled PackedType value");
189        abort();
190    }
191 }
192
193 void LowerPacked::visitLoadInst(LoadInst& LI)
194 {
195    // Make sure what we are dealing with is a packed type
196    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
197        // Initialization, Idx is needed for getelementptr needed later
198        std::vector<Value*> Idx(2);
199        Idx[0] = ConstantUInt::get(Type::UIntTy,0);
200
201        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
202                                       PKT->getNumElements());
203        PointerType* APT = PointerType::get(AT);
204
205        // Cast the packed type to an array
206        Value* array = new CastInst(LI.getPointerOperand(),
207                                    APT,
208                                    LI.getName() + ".a",
209                                    &LI);
210
211        // Convert this load into num elements number of loads
212        std::vector<Value*> values;
213        values.reserve(PKT->getNumElements());
214
215        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
216             // Calculate the second index we will need
217             Idx[1] = ConstantUInt::get(Type::UIntTy,i);
218
219             // Get the pointer
220             Value* val = new GetElementPtrInst(array, 
221                                                Idx,
222                                                LI.getName() + 
223                                                ".ge." + utostr(i),
224                                                &LI);
225
226             // generate the new load and save the result in packedToScalar map
227             values.push_back(new LoadInst(val, 
228                              LI.getName()+"."+utostr(i),
229                              LI.isVolatile(),
230                              &LI));
231        }
232                
233        setValues(&LI,values);
234        Changed = true;
235        instrsToRemove.push_back(&LI);
236    }
237 }
238
239 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
240 {
241    // Make sure both operands are PackedTypes
242    if (isa<PackedType>(BO.getOperand(0)->getType())) {
243        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
244        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
245        std::vector<Value*> result;
246        assert((op0Vals.size() == op1Vals.size()) &&
247               "The two packed operand to scalar maps must be equal in size.");
248
249        result.reserve(op0Vals.size());
250    
251        // generate the new binary op and save the result
252        for (unsigned i = 0; i != op0Vals.size(); ++i) {
253             result.push_back(BinaryOperator::create(BO.getOpcode(), 
254                                                     op0Vals[i], 
255                                                     op1Vals[i],
256                                                     BO.getName() + 
257                                                     "." + utostr(i),
258                                                     &BO));
259        }
260
261        setValues(&BO,result);
262        Changed = true;
263        instrsToRemove.push_back(&BO);
264    }
265 }
266
267 void LowerPacked::visitStoreInst(StoreInst& SI)
268 {
269    if (const PackedType* PKT = 
270        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
271        // We will need this for getelementptr
272        std::vector<Value*> Idx(2);
273        Idx[0] = ConstantUInt::get(Type::UIntTy,0);
274          
275        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
276                                       PKT->getNumElements());
277        PointerType* APT = PointerType::get(AT);
278
279        // cast the packed to an array type
280        Value* array = new CastInst(SI.getPointerOperand(),
281                                    APT,
282                                    "store.ge.a.",
283                                    &SI);
284        std::vector<Value*>& values = getValues(SI.getOperand(0));
285       
286        assert((values.size() == PKT->getNumElements()) &&
287               "Scalar must have the same number of elements as Packed Type");
288
289        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
290             // Generate the indices for getelementptr
291             Idx[1] = ConstantUInt::get(Type::UIntTy,i);
292             Value* val = new GetElementPtrInst(array, 
293                                                Idx,
294                                                "store.ge." +
295                                                utostr(i) + ".",
296                                                &SI);
297             new StoreInst(values[i], val, SI.isVolatile(),&SI);
298        }
299                  
300        Changed = true;
301        instrsToRemove.push_back(&SI);
302    }
303 }
304
305 void LowerPacked::visitSelectInst(SelectInst& SELI)
306 {
307    // Make sure both operands are PackedTypes
308    if (isa<PackedType>(SELI.getType())) {
309        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
310        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
311        std::vector<Value*> result;
312
313       assert((op0Vals.size() == op1Vals.size()) &&
314              "The two packed operand to scalar maps must be equal in size.");
315
316       for (unsigned i = 0; i != op0Vals.size(); ++i) {
317            result.push_back(new SelectInst(SELI.getCondition(),
318                                            op0Vals[i], 
319                                            op1Vals[i],
320                                            SELI.getName()+ "." + utostr(i),
321                                            &SELI));
322       }
323    
324       setValues(&SELI,result);
325       Changed = true;
326       instrsToRemove.push_back(&SELI);
327    }
328 }
329
330 bool LowerPacked::runOnFunction(Function& F)
331 {
332    // initialize
333    Changed = false; 
334   
335    // Does three passes:
336    // Pass 1) Converts Packed Operations to 
337    //         new Packed Operations on smaller
338    //         datatypes
339    visit(F);
340   
341    // Pass 2) Drop all references
342    std::for_each(instrsToRemove.begin(),
343                  instrsToRemove.end(),
344                  std::mem_fun(&Instruction::dropAllReferences));
345
346    // Pass 3) Delete the Instructions to remove aka packed instructions
347    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(), 
348                                             e = instrsToRemove.end(); 
349         i != e; ++i) {
350         (*i)->getParent()->getInstList().erase(*i);   
351    }
352
353    // clean-up
354    packedToScalarMap.clear();
355    instrsToRemove.clear();
356
357    return Changed;
358 }
359