Removed #include <iostream> and replaced with llvm_* streams.
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Argument.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Support/InstVisitor.h"
23 #include "llvm/Support/Streams.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include <algorithm>
26 #include <map>
27 #include <functional>
28 using namespace llvm;
29
30 namespace {
31
32 /// This pass converts packed operators to an
33 /// equivalent operations on smaller packed data, to possibly
34 /// scalar operations.  Currently it supports lowering
35 /// to scalar operations.
36 ///
37 /// @brief Transforms packed instructions to simpler instructions.
38 ///
39 class LowerPacked : public FunctionPass, public InstVisitor<LowerPacked> {
40 public:
41    /// @brief Lowers packed operations to scalar operations.
42    /// @param F The fuction to process
43    virtual bool runOnFunction(Function &F);
44
45    /// @brief Lowers packed load instructions.
46    /// @param LI the load instruction to convert
47    void visitLoadInst(LoadInst& LI);
48
49    /// @brief Lowers packed store instructions.
50    /// @param SI the store instruction to convert
51    void visitStoreInst(StoreInst& SI);
52
53    /// @brief Lowers packed binary operations.
54    /// @param BO the binary operator to convert
55    void visitBinaryOperator(BinaryOperator& BO);
56
57    /// @brief Lowers packed select instructions.
58    /// @param SELI the select operator to convert
59    void visitSelectInst(SelectInst& SELI);
60
61    /// @brief Lowers packed extractelement instructions.
62    /// @param EI the extractelement operator to convert
63    void visitExtractElementInst(ExtractElementInst& EE);
64
65    /// @brief Lowers packed insertelement instructions.
66    /// @param EI the insertelement operator to convert
67    void visitInsertElementInst(InsertElementInst& IE);
68
69    /// This function asserts if the instruction is a PackedType but
70    /// is handled by another function.
71    ///
72    /// @brief Asserts if PackedType instruction is not handled elsewhere.
73    /// @param I the unhandled instruction
74    void visitInstruction(Instruction &I) {
75      if (isa<PackedType>(I.getType()))
76        llvm_cerr << "Unhandled Instruction with Packed ReturnType: "
77                  << I << '\n';
78    }
79 private:
80    /// @brief Retrieves lowered values for a packed value.
81    /// @param val the packed value
82    /// @return the lowered values
83    std::vector<Value*>& getValues(Value* val);
84
85    /// @brief Sets lowered values for a packed value.
86    /// @param val the packed value
87    /// @param values the corresponding lowered values
88    void setValues(Value* val,const std::vector<Value*>& values);
89
90    // Data Members
91    /// @brief whether we changed the function or not
92    bool Changed;
93
94    /// @brief a map from old packed values to new smaller packed values
95    std::map<Value*,std::vector<Value*> > packedToScalarMap;
96
97    /// Instructions in the source program to get rid of
98    /// after we do a pass (the old packed instructions)
99    std::vector<Instruction*> instrsToRemove;
100 };
101
102 RegisterPass<LowerPacked>
103 X("lower-packed",
104   "lowers packed operations to operations on smaller packed datatypes");
105
106 } // end namespace
107
108 FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); }
109
110
111 // This function sets lowered values for a corresponding
112 // packed value.  Note, in the case of a forward reference
113 // getValues(Value*) will have already been called for
114 // the packed parameter.  This function will then replace
115 // all references in the in the function of the "dummy"
116 // value the previous getValues(Value*) call
117 // returned with actual references.
118 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
119 {
120    std::map<Value*,std::vector<Value*> >::iterator it =
121          packedToScalarMap.lower_bound(value);
122    if (it == packedToScalarMap.end() || it->first != value) {
123        // there was not a forward reference to this element
124        packedToScalarMap.insert(it,std::make_pair(value,values));
125    }
126    else {
127       // replace forward declarations with actual definitions
128       assert(it->second.size() == values.size() &&
129              "Error forward refences and actual definition differ in size");
130       for (unsigned i = 0, e = values.size(); i != e; ++i) {
131            // replace and get rid of old forward references
132            it->second[i]->replaceAllUsesWith(values[i]);
133            delete it->second[i];
134            it->second[i] = values[i];
135       }
136    }
137 }
138
139 // This function will examine the packed value parameter
140 // and if it is a packed constant or a forward reference
141 // properly create the lowered values needed.  Otherwise
142 // it will simply retreive values from a
143 // setValues(Value*,const std::vector<Value*>&)
144 // call.  Failing both of these cases, it will abort
145 // the program.
146 std::vector<Value*>& LowerPacked::getValues(Value* value)
147 {
148    assert(isa<PackedType>(value->getType()) &&
149           "Value must be PackedType");
150
151    // reject further processing if this one has
152    // already been handled
153    std::map<Value*,std::vector<Value*> >::iterator it =
154       packedToScalarMap.lower_bound(value);
155    if (it != packedToScalarMap.end() && it->first == value) {
156        return it->second;
157    }
158
159    if (ConstantPacked* CP = dyn_cast<ConstantPacked>(value)) {
160        // non-zero constant case
161        std::vector<Value*> results;
162        results.reserve(CP->getNumOperands());
163        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
164           results.push_back(CP->getOperand(i));
165        }
166        return packedToScalarMap.insert(it,
167                                        std::make_pair(value,results))->second;
168    }
169    else if (ConstantAggregateZero* CAZ =
170             dyn_cast<ConstantAggregateZero>(value)) {
171        // zero constant
172        const PackedType* PKT = cast<PackedType>(CAZ->getType());
173        std::vector<Value*> results;
174        results.reserve(PKT->getNumElements());
175
176        Constant* C = Constant::getNullValue(PKT->getElementType());
177        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
178             results.push_back(C);
179        }
180        return packedToScalarMap.insert(it,
181                                        std::make_pair(value,results))->second;
182    }
183    else if (isa<Instruction>(value)) {
184        // foward reference
185        const PackedType* PKT = cast<PackedType>(value->getType());
186        std::vector<Value*> results;
187        results.reserve(PKT->getNumElements());
188
189       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
190            results.push_back(new Argument(PKT->getElementType()));
191       }
192       return packedToScalarMap.insert(it,
193                                       std::make_pair(value,results))->second;
194    }
195    else {
196        // we don't know what it is, and we are trying to retrieve
197        // a value for it
198        assert(false && "Unhandled PackedType value");
199        abort();
200    }
201 }
202
203 void LowerPacked::visitLoadInst(LoadInst& LI)
204 {
205    // Make sure what we are dealing with is a packed type
206    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
207        // Initialization, Idx is needed for getelementptr needed later
208        std::vector<Value*> Idx(2);
209        Idx[0] = ConstantInt::get(Type::UIntTy,0);
210
211        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
212                                       PKT->getNumElements());
213        PointerType* APT = PointerType::get(AT);
214
215        // Cast the packed type to an array
216        Value* array = new CastInst(LI.getPointerOperand(),
217                                    APT,
218                                    LI.getName() + ".a",
219                                    &LI);
220
221        // Convert this load into num elements number of loads
222        std::vector<Value*> values;
223        values.reserve(PKT->getNumElements());
224
225        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
226             // Calculate the second index we will need
227             Idx[1] = ConstantInt::get(Type::UIntTy,i);
228
229             // Get the pointer
230             Value* val = new GetElementPtrInst(array,
231                                                Idx,
232                                                LI.getName() +
233                                                ".ge." + utostr(i),
234                                                &LI);
235
236             // generate the new load and save the result in packedToScalar map
237             values.push_back(new LoadInst(val,
238                              LI.getName()+"."+utostr(i),
239                              LI.isVolatile(),
240                              &LI));
241        }
242
243        setValues(&LI,values);
244        Changed = true;
245        instrsToRemove.push_back(&LI);
246    }
247 }
248
249 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
250 {
251    // Make sure both operands are PackedTypes
252    if (isa<PackedType>(BO.getOperand(0)->getType())) {
253        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
254        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
255        std::vector<Value*> result;
256        assert((op0Vals.size() == op1Vals.size()) &&
257               "The two packed operand to scalar maps must be equal in size.");
258
259        result.reserve(op0Vals.size());
260
261        // generate the new binary op and save the result
262        for (unsigned i = 0; i != op0Vals.size(); ++i) {
263             result.push_back(BinaryOperator::create(BO.getOpcode(),
264                                                     op0Vals[i],
265                                                     op1Vals[i],
266                                                     BO.getName() +
267                                                     "." + utostr(i),
268                                                     &BO));
269        }
270
271        setValues(&BO,result);
272        Changed = true;
273        instrsToRemove.push_back(&BO);
274    }
275 }
276
277 void LowerPacked::visitStoreInst(StoreInst& SI)
278 {
279    if (const PackedType* PKT =
280        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
281        // We will need this for getelementptr
282        std::vector<Value*> Idx(2);
283        Idx[0] = ConstantInt::get(Type::UIntTy,0);
284
285        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
286                                       PKT->getNumElements());
287        PointerType* APT = PointerType::get(AT);
288
289        // cast the packed to an array type
290        Value* array = new CastInst(SI.getPointerOperand(),
291                                    APT,
292                                    "store.ge.a.",
293                                    &SI);
294        std::vector<Value*>& values = getValues(SI.getOperand(0));
295
296        assert((values.size() == PKT->getNumElements()) &&
297               "Scalar must have the same number of elements as Packed Type");
298
299        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
300             // Generate the indices for getelementptr
301             Idx[1] = ConstantInt::get(Type::UIntTy,i);
302             Value* val = new GetElementPtrInst(array,
303                                                Idx,
304                                                "store.ge." +
305                                                utostr(i) + ".",
306                                                &SI);
307             new StoreInst(values[i], val, SI.isVolatile(),&SI);
308        }
309
310        Changed = true;
311        instrsToRemove.push_back(&SI);
312    }
313 }
314
315 void LowerPacked::visitSelectInst(SelectInst& SELI)
316 {
317    // Make sure both operands are PackedTypes
318    if (isa<PackedType>(SELI.getType())) {
319        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
320        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
321        std::vector<Value*> result;
322
323       assert((op0Vals.size() == op1Vals.size()) &&
324              "The two packed operand to scalar maps must be equal in size.");
325
326       for (unsigned i = 0; i != op0Vals.size(); ++i) {
327            result.push_back(new SelectInst(SELI.getCondition(),
328                                            op0Vals[i],
329                                            op1Vals[i],
330                                            SELI.getName()+ "." + utostr(i),
331                                            &SELI));
332       }
333
334       setValues(&SELI,result);
335       Changed = true;
336       instrsToRemove.push_back(&SELI);
337    }
338 }
339
340 void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
341 {
342   std::vector<Value*>& op0Vals = getValues(EI.getOperand(0));
343   const PackedType *PTy = cast<PackedType>(EI.getOperand(0)->getType());
344   Value *op1 = EI.getOperand(1);
345
346   if (ConstantInt *C = dyn_cast<ConstantInt>(op1)) {
347     EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]);
348   } else {
349     AllocaInst *alloca = 
350       new AllocaInst(PTy->getElementType(),
351                      ConstantInt::get(Type::UIntTy, PTy->getNumElements()),
352                      EI.getName() + ".alloca", 
353                      EI.getParent()->getParent()->getEntryBlock().begin());
354     for (unsigned i = 0; i < PTy->getNumElements(); ++i) {
355       GetElementPtrInst *GEP = 
356         new GetElementPtrInst(alloca, ConstantInt::get(Type::UIntTy, i),
357                               "store.ge", &EI);
358       new StoreInst(op0Vals[i], GEP, &EI);
359     }
360     GetElementPtrInst *GEP = 
361       new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI);
362     LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI);
363     EI.replaceAllUsesWith(load);
364   }
365
366   Changed = true;
367   instrsToRemove.push_back(&EI);
368 }
369
370 void LowerPacked::visitInsertElementInst(InsertElementInst& IE)
371 {
372   std::vector<Value*>& Vals = getValues(IE.getOperand(0));
373   Value *Elt = IE.getOperand(1);
374   Value *Idx = IE.getOperand(2);
375   std::vector<Value*> result;
376   result.reserve(Vals.size());
377
378   if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
379     unsigned idxVal = C->getZExtValue();
380     for (unsigned i = 0; i != Vals.size(); ++i) {
381       result.push_back(i == idxVal ? Elt : Vals[i]);
382     }
383   } else {
384     for (unsigned i = 0; i != Vals.size(); ++i) {
385       SetCondInst *setcc =
386         new SetCondInst(Instruction::SetEQ, Idx, 
387                         ConstantInt::get(Type::UIntTy, i),
388                         "setcc", &IE);
389       SelectInst *select =
390         new SelectInst(setcc, Elt, Vals[i], "select", &IE);
391       result.push_back(select);
392     }
393   }
394
395   setValues(&IE, result);
396   Changed = true;
397   instrsToRemove.push_back(&IE);
398 }
399
400 bool LowerPacked::runOnFunction(Function& F)
401 {
402    // initialize
403    Changed = false;
404
405    // Does three passes:
406    // Pass 1) Converts Packed Operations to
407    //         new Packed Operations on smaller
408    //         datatypes
409    visit(F);
410
411    // Pass 2) Drop all references
412    std::for_each(instrsToRemove.begin(),
413                  instrsToRemove.end(),
414                  std::mem_fun(&Instruction::dropAllReferences));
415
416    // Pass 3) Delete the Instructions to remove aka packed instructions
417    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(),
418                                             e = instrsToRemove.end();
419         i != e; ++i) {
420         (*i)->getParent()->getInstList().erase(*i);
421    }
422
423    // clean-up
424    packedToScalarMap.clear();
425    instrsToRemove.clear();
426
427    return Changed;
428 }
429