stop using methods that take vectors.
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Argument.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Support/Compiler.h"
23 #include "llvm/Support/InstVisitor.h"
24 #include "llvm/Support/Streams.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include <algorithm>
27 #include <map>
28 #include <functional>
29 using namespace llvm;
30
31 namespace {
32
33 /// This pass converts packed operators to an
34 /// equivalent operations on smaller packed data, to possibly
35 /// scalar operations.  Currently it supports lowering
36 /// to scalar operations.
37 ///
38 /// @brief Transforms packed instructions to simpler instructions.
39 ///
40 class VISIBILITY_HIDDEN LowerPacked 
41   : public FunctionPass, public InstVisitor<LowerPacked> {
42 public:
43    /// @brief Lowers packed operations to scalar operations.
44    /// @param F The fuction to process
45    virtual bool runOnFunction(Function &F);
46
47    /// @brief Lowers packed load instructions.
48    /// @param LI the load instruction to convert
49    void visitLoadInst(LoadInst& LI);
50
51    /// @brief Lowers packed store instructions.
52    /// @param SI the store instruction to convert
53    void visitStoreInst(StoreInst& SI);
54
55    /// @brief Lowers packed binary operations.
56    /// @param BO the binary operator to convert
57    void visitBinaryOperator(BinaryOperator& BO);
58
59    /// @brief Lowers packed icmp operations.
60    /// @param CI the icmp operator to convert
61    void visitICmpInst(ICmpInst& IC);
62
63    /// @brief Lowers packed select instructions.
64    /// @param SELI the select operator to convert
65    void visitSelectInst(SelectInst& SELI);
66
67    /// @brief Lowers packed extractelement instructions.
68    /// @param EI the extractelement operator to convert
69    void visitExtractElementInst(ExtractElementInst& EE);
70
71    /// @brief Lowers packed insertelement instructions.
72    /// @param EI the insertelement operator to convert
73    void visitInsertElementInst(InsertElementInst& IE);
74
75    /// This function asserts if the instruction is a PackedType but
76    /// is handled by another function.
77    ///
78    /// @brief Asserts if PackedType instruction is not handled elsewhere.
79    /// @param I the unhandled instruction
80    void visitInstruction(Instruction &I) {
81      if (isa<PackedType>(I.getType()))
82        cerr << "Unhandled Instruction with Packed ReturnType: " << I << '\n';
83    }
84 private:
85    /// @brief Retrieves lowered values for a packed value.
86    /// @param val the packed value
87    /// @return the lowered values
88    std::vector<Value*>& getValues(Value* val);
89
90    /// @brief Sets lowered values for a packed value.
91    /// @param val the packed value
92    /// @param values the corresponding lowered values
93    void setValues(Value* val,const std::vector<Value*>& values);
94
95    // Data Members
96    /// @brief whether we changed the function or not
97    bool Changed;
98
99    /// @brief a map from old packed values to new smaller packed values
100    std::map<Value*,std::vector<Value*> > packedToScalarMap;
101
102    /// Instructions in the source program to get rid of
103    /// after we do a pass (the old packed instructions)
104    std::vector<Instruction*> instrsToRemove;
105 };
106
107 RegisterPass<LowerPacked>
108 X("lower-packed",
109   "lowers packed operations to operations on smaller packed datatypes");
110
111 } // end namespace
112
113 FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); }
114
115
116 // This function sets lowered values for a corresponding
117 // packed value.  Note, in the case of a forward reference
118 // getValues(Value*) will have already been called for
119 // the packed parameter.  This function will then replace
120 // all references in the in the function of the "dummy"
121 // value the previous getValues(Value*) call
122 // returned with actual references.
123 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
124 {
125    std::map<Value*,std::vector<Value*> >::iterator it =
126          packedToScalarMap.lower_bound(value);
127    if (it == packedToScalarMap.end() || it->first != value) {
128        // there was not a forward reference to this element
129        packedToScalarMap.insert(it,std::make_pair(value,values));
130    }
131    else {
132       // replace forward declarations with actual definitions
133       assert(it->second.size() == values.size() &&
134              "Error forward refences and actual definition differ in size");
135       for (unsigned i = 0, e = values.size(); i != e; ++i) {
136            // replace and get rid of old forward references
137            it->second[i]->replaceAllUsesWith(values[i]);
138            delete it->second[i];
139            it->second[i] = values[i];
140       }
141    }
142 }
143
144 // This function will examine the packed value parameter
145 // and if it is a packed constant or a forward reference
146 // properly create the lowered values needed.  Otherwise
147 // it will simply retreive values from a
148 // setValues(Value*,const std::vector<Value*>&)
149 // call.  Failing both of these cases, it will abort
150 // the program.
151 std::vector<Value*>& LowerPacked::getValues(Value* value)
152 {
153    assert(isa<PackedType>(value->getType()) &&
154           "Value must be PackedType");
155
156    // reject further processing if this one has
157    // already been handled
158    std::map<Value*,std::vector<Value*> >::iterator it =
159       packedToScalarMap.lower_bound(value);
160    if (it != packedToScalarMap.end() && it->first == value) {
161        return it->second;
162    }
163
164    if (ConstantPacked* CP = dyn_cast<ConstantPacked>(value)) {
165        // non-zero constant case
166        std::vector<Value*> results;
167        results.reserve(CP->getNumOperands());
168        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
169           results.push_back(CP->getOperand(i));
170        }
171        return packedToScalarMap.insert(it,
172                                        std::make_pair(value,results))->second;
173    }
174    else if (ConstantAggregateZero* CAZ =
175             dyn_cast<ConstantAggregateZero>(value)) {
176        // zero constant
177        const PackedType* PKT = cast<PackedType>(CAZ->getType());
178        std::vector<Value*> results;
179        results.reserve(PKT->getNumElements());
180
181        Constant* C = Constant::getNullValue(PKT->getElementType());
182        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
183             results.push_back(C);
184        }
185        return packedToScalarMap.insert(it,
186                                        std::make_pair(value,results))->second;
187    }
188    else if (isa<Instruction>(value)) {
189        // foward reference
190        const PackedType* PKT = cast<PackedType>(value->getType());
191        std::vector<Value*> results;
192        results.reserve(PKT->getNumElements());
193
194       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
195            results.push_back(new Argument(PKT->getElementType()));
196       }
197       return packedToScalarMap.insert(it,
198                                       std::make_pair(value,results))->second;
199    }
200    else {
201        // we don't know what it is, and we are trying to retrieve
202        // a value for it
203        assert(false && "Unhandled PackedType value");
204        abort();
205    }
206 }
207
208 void LowerPacked::visitLoadInst(LoadInst& LI)
209 {
210    // Make sure what we are dealing with is a packed type
211    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
212        // Initialization, Idx is needed for getelementptr needed later
213        std::vector<Value*> Idx(2);
214        Idx[0] = ConstantInt::get(Type::Int32Ty,0);
215
216        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
217                                       PKT->getNumElements());
218        PointerType* APT = PointerType::get(AT);
219
220        // Cast the pointer to packed type to an equivalent array
221        Value* array = new BitCastInst(LI.getPointerOperand(), APT, 
222                                       LI.getName() + ".a", &LI);
223
224        // Convert this load into num elements number of loads
225        std::vector<Value*> values;
226        values.reserve(PKT->getNumElements());
227
228        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
229             // Calculate the second index we will need
230             Idx[1] = ConstantInt::get(Type::Int32Ty,i);
231
232             // Get the pointer
233             Value* val = new GetElementPtrInst(array,
234                                                &Idx[0], Idx.size(),
235                                                LI.getName() +
236                                                ".ge." + utostr(i),
237                                                &LI);
238
239             // generate the new load and save the result in packedToScalar map
240             values.push_back(new LoadInst(val, LI.getName()+"."+utostr(i),
241                              LI.isVolatile(), &LI));
242        }
243
244        setValues(&LI,values);
245        Changed = true;
246        instrsToRemove.push_back(&LI);
247    }
248 }
249
250 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
251 {
252    // Make sure both operands are PackedTypes
253    if (isa<PackedType>(BO.getOperand(0)->getType())) {
254        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
255        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
256        std::vector<Value*> result;
257        assert((op0Vals.size() == op1Vals.size()) &&
258               "The two packed operand to scalar maps must be equal in size.");
259
260        result.reserve(op0Vals.size());
261
262        // generate the new binary op and save the result
263        for (unsigned i = 0; i != op0Vals.size(); ++i) {
264             result.push_back(BinaryOperator::create(BO.getOpcode(),
265                                                     op0Vals[i],
266                                                     op1Vals[i],
267                                                     BO.getName() +
268                                                     "." + utostr(i),
269                                                     &BO));
270        }
271
272        setValues(&BO,result);
273        Changed = true;
274        instrsToRemove.push_back(&BO);
275    }
276 }
277
278 void LowerPacked::visitICmpInst(ICmpInst& IC)
279 {
280    // Make sure both operands are PackedTypes
281    if (isa<PackedType>(IC.getOperand(0)->getType())) {
282        std::vector<Value*>& op0Vals = getValues(IC.getOperand(0));
283        std::vector<Value*>& op1Vals = getValues(IC.getOperand(1));
284        std::vector<Value*> result;
285        assert((op0Vals.size() == op1Vals.size()) &&
286               "The two packed operand to scalar maps must be equal in size.");
287
288        result.reserve(op0Vals.size());
289
290        // generate the new binary op and save the result
291        for (unsigned i = 0; i != op0Vals.size(); ++i) {
292             result.push_back(CmpInst::create(IC.getOpcode(),
293                                              IC.getPredicate(),
294                                              op0Vals[i],
295                                              op1Vals[i],
296                                              IC.getName() +
297                                              "." + utostr(i),
298                                              &IC));
299        }
300
301        setValues(&IC,result);
302        Changed = true;
303        instrsToRemove.push_back(&IC);
304    }
305 }
306
307 void LowerPacked::visitStoreInst(StoreInst& SI)
308 {
309    if (const PackedType* PKT =
310        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
311        // We will need this for getelementptr
312        std::vector<Value*> Idx(2);
313        Idx[0] = ConstantInt::get(Type::Int32Ty,0);
314
315        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
316                                       PKT->getNumElements());
317        PointerType* APT = PointerType::get(AT);
318
319        // Cast the pointer to packed to an array of equivalent type
320        Value* array = new BitCastInst(SI.getPointerOperand(), APT, 
321                                       "store.ge.a.", &SI);
322
323        std::vector<Value*>& values = getValues(SI.getOperand(0));
324
325        assert((values.size() == PKT->getNumElements()) &&
326               "Scalar must have the same number of elements as Packed Type");
327
328        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
329             // Generate the indices for getelementptr
330             Idx[1] = ConstantInt::get(Type::Int32Ty,i);
331             Value* val = new GetElementPtrInst(array,
332                                                &Idx[0], Idx.size(),
333                                                "store.ge." +
334                                                utostr(i) + ".",
335                                                &SI);
336             new StoreInst(values[i], val, SI.isVolatile(),&SI);
337        }
338
339        Changed = true;
340        instrsToRemove.push_back(&SI);
341    }
342 }
343
344 void LowerPacked::visitSelectInst(SelectInst& SELI)
345 {
346    // Make sure both operands are PackedTypes
347    if (isa<PackedType>(SELI.getType())) {
348        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
349        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
350        std::vector<Value*> result;
351
352       assert((op0Vals.size() == op1Vals.size()) &&
353              "The two packed operand to scalar maps must be equal in size.");
354
355       for (unsigned i = 0; i != op0Vals.size(); ++i) {
356            result.push_back(new SelectInst(SELI.getCondition(),
357                                            op0Vals[i],
358                                            op1Vals[i],
359                                            SELI.getName()+ "." + utostr(i),
360                                            &SELI));
361       }
362
363       setValues(&SELI,result);
364       Changed = true;
365       instrsToRemove.push_back(&SELI);
366    }
367 }
368
369 void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
370 {
371   std::vector<Value*>& op0Vals = getValues(EI.getOperand(0));
372   const PackedType *PTy = cast<PackedType>(EI.getOperand(0)->getType());
373   Value *op1 = EI.getOperand(1);
374
375   if (ConstantInt *C = dyn_cast<ConstantInt>(op1)) {
376     EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]);
377   } else {
378     AllocaInst *alloca = 
379       new AllocaInst(PTy->getElementType(),
380                      ConstantInt::get(Type::Int32Ty, PTy->getNumElements()),
381                      EI.getName() + ".alloca", 
382                      EI.getParent()->getParent()->getEntryBlock().begin());
383     for (unsigned i = 0; i < PTy->getNumElements(); ++i) {
384       GetElementPtrInst *GEP = 
385         new GetElementPtrInst(alloca, ConstantInt::get(Type::Int32Ty, i),
386                               "store.ge", &EI);
387       new StoreInst(op0Vals[i], GEP, &EI);
388     }
389     GetElementPtrInst *GEP = 
390       new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI);
391     LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI);
392     EI.replaceAllUsesWith(load);
393   }
394
395   Changed = true;
396   instrsToRemove.push_back(&EI);
397 }
398
399 void LowerPacked::visitInsertElementInst(InsertElementInst& IE)
400 {
401   std::vector<Value*>& Vals = getValues(IE.getOperand(0));
402   Value *Elt = IE.getOperand(1);
403   Value *Idx = IE.getOperand(2);
404   std::vector<Value*> result;
405   result.reserve(Vals.size());
406
407   if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
408     unsigned idxVal = C->getZExtValue();
409     for (unsigned i = 0; i != Vals.size(); ++i) {
410       result.push_back(i == idxVal ? Elt : Vals[i]);
411     }
412   } else {
413     for (unsigned i = 0; i != Vals.size(); ++i) {
414       ICmpInst *icmp =
415         new ICmpInst(ICmpInst::ICMP_EQ, Idx, 
416                      ConstantInt::get(Type::Int32Ty, i),
417                      "icmp", &IE);
418       SelectInst *select =
419         new SelectInst(icmp, Elt, Vals[i], "select", &IE);
420       result.push_back(select);
421     }
422   }
423
424   setValues(&IE, result);
425   Changed = true;
426   instrsToRemove.push_back(&IE);
427 }
428
429 bool LowerPacked::runOnFunction(Function& F)
430 {
431    // initialize
432    Changed = false;
433
434    // Does three passes:
435    // Pass 1) Converts Packed Operations to
436    //         new Packed Operations on smaller
437    //         datatypes
438    visit(F);
439
440    // Pass 2) Drop all references
441    std::for_each(instrsToRemove.begin(),
442                  instrsToRemove.end(),
443                  std::mem_fun(&Instruction::dropAllReferences));
444
445    // Pass 3) Delete the Instructions to remove aka packed instructions
446    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(),
447                                             e = instrsToRemove.end();
448         i != e; ++i) {
449         (*i)->getParent()->getInstList().erase(*i);
450    }
451
452    // clean-up
453    packedToScalarMap.clear();
454    instrsToRemove.clear();
455
456    return Changed;
457 }
458