7763921d1843e81d3fdbbdd006d27ee289ef5b02
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Argument.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Support/InstVisitor.h"
23 #include "llvm/Support/Streams.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include <algorithm>
26 #include <map>
27 #include <functional>
28 using namespace llvm;
29
30 namespace {
31
32 /// This pass converts packed operators to an
33 /// equivalent operations on smaller packed data, to possibly
34 /// scalar operations.  Currently it supports lowering
35 /// to scalar operations.
36 ///
37 /// @brief Transforms packed instructions to simpler instructions.
38 ///
39 class LowerPacked : public FunctionPass, public InstVisitor<LowerPacked> {
40 public:
41    /// @brief Lowers packed operations to scalar operations.
42    /// @param F The fuction to process
43    virtual bool runOnFunction(Function &F);
44
45    /// @brief Lowers packed load instructions.
46    /// @param LI the load instruction to convert
47    void visitLoadInst(LoadInst& LI);
48
49    /// @brief Lowers packed store instructions.
50    /// @param SI the store instruction to convert
51    void visitStoreInst(StoreInst& SI);
52
53    /// @brief Lowers packed binary operations.
54    /// @param BO the binary operator to convert
55    void visitBinaryOperator(BinaryOperator& BO);
56
57    /// @brief Lowers packed icmp operations.
58    /// @param CI the icmp operator to convert
59    void visitICmpInst(ICmpInst& IC);
60
61    /// @brief Lowers packed select instructions.
62    /// @param SELI the select operator to convert
63    void visitSelectInst(SelectInst& SELI);
64
65    /// @brief Lowers packed extractelement instructions.
66    /// @param EI the extractelement operator to convert
67    void visitExtractElementInst(ExtractElementInst& EE);
68
69    /// @brief Lowers packed insertelement instructions.
70    /// @param EI the insertelement operator to convert
71    void visitInsertElementInst(InsertElementInst& IE);
72
73    /// This function asserts if the instruction is a PackedType but
74    /// is handled by another function.
75    ///
76    /// @brief Asserts if PackedType instruction is not handled elsewhere.
77    /// @param I the unhandled instruction
78    void visitInstruction(Instruction &I) {
79      if (isa<PackedType>(I.getType()))
80        cerr << "Unhandled Instruction with Packed ReturnType: " << I << '\n';
81    }
82 private:
83    /// @brief Retrieves lowered values for a packed value.
84    /// @param val the packed value
85    /// @return the lowered values
86    std::vector<Value*>& getValues(Value* val);
87
88    /// @brief Sets lowered values for a packed value.
89    /// @param val the packed value
90    /// @param values the corresponding lowered values
91    void setValues(Value* val,const std::vector<Value*>& values);
92
93    // Data Members
94    /// @brief whether we changed the function or not
95    bool Changed;
96
97    /// @brief a map from old packed values to new smaller packed values
98    std::map<Value*,std::vector<Value*> > packedToScalarMap;
99
100    /// Instructions in the source program to get rid of
101    /// after we do a pass (the old packed instructions)
102    std::vector<Instruction*> instrsToRemove;
103 };
104
105 RegisterPass<LowerPacked>
106 X("lower-packed",
107   "lowers packed operations to operations on smaller packed datatypes");
108
109 } // end namespace
110
111 FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); }
112
113
114 // This function sets lowered values for a corresponding
115 // packed value.  Note, in the case of a forward reference
116 // getValues(Value*) will have already been called for
117 // the packed parameter.  This function will then replace
118 // all references in the in the function of the "dummy"
119 // value the previous getValues(Value*) call
120 // returned with actual references.
121 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
122 {
123    std::map<Value*,std::vector<Value*> >::iterator it =
124          packedToScalarMap.lower_bound(value);
125    if (it == packedToScalarMap.end() || it->first != value) {
126        // there was not a forward reference to this element
127        packedToScalarMap.insert(it,std::make_pair(value,values));
128    }
129    else {
130       // replace forward declarations with actual definitions
131       assert(it->second.size() == values.size() &&
132              "Error forward refences and actual definition differ in size");
133       for (unsigned i = 0, e = values.size(); i != e; ++i) {
134            // replace and get rid of old forward references
135            it->second[i]->replaceAllUsesWith(values[i]);
136            delete it->second[i];
137            it->second[i] = values[i];
138       }
139    }
140 }
141
142 // This function will examine the packed value parameter
143 // and if it is a packed constant or a forward reference
144 // properly create the lowered values needed.  Otherwise
145 // it will simply retreive values from a
146 // setValues(Value*,const std::vector<Value*>&)
147 // call.  Failing both of these cases, it will abort
148 // the program.
149 std::vector<Value*>& LowerPacked::getValues(Value* value)
150 {
151    assert(isa<PackedType>(value->getType()) &&
152           "Value must be PackedType");
153
154    // reject further processing if this one has
155    // already been handled
156    std::map<Value*,std::vector<Value*> >::iterator it =
157       packedToScalarMap.lower_bound(value);
158    if (it != packedToScalarMap.end() && it->first == value) {
159        return it->second;
160    }
161
162    if (ConstantPacked* CP = dyn_cast<ConstantPacked>(value)) {
163        // non-zero constant case
164        std::vector<Value*> results;
165        results.reserve(CP->getNumOperands());
166        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
167           results.push_back(CP->getOperand(i));
168        }
169        return packedToScalarMap.insert(it,
170                                        std::make_pair(value,results))->second;
171    }
172    else if (ConstantAggregateZero* CAZ =
173             dyn_cast<ConstantAggregateZero>(value)) {
174        // zero constant
175        const PackedType* PKT = cast<PackedType>(CAZ->getType());
176        std::vector<Value*> results;
177        results.reserve(PKT->getNumElements());
178
179        Constant* C = Constant::getNullValue(PKT->getElementType());
180        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
181             results.push_back(C);
182        }
183        return packedToScalarMap.insert(it,
184                                        std::make_pair(value,results))->second;
185    }
186    else if (isa<Instruction>(value)) {
187        // foward reference
188        const PackedType* PKT = cast<PackedType>(value->getType());
189        std::vector<Value*> results;
190        results.reserve(PKT->getNumElements());
191
192       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
193            results.push_back(new Argument(PKT->getElementType()));
194       }
195       return packedToScalarMap.insert(it,
196                                       std::make_pair(value,results))->second;
197    }
198    else {
199        // we don't know what it is, and we are trying to retrieve
200        // a value for it
201        assert(false && "Unhandled PackedType value");
202        abort();
203    }
204 }
205
206 void LowerPacked::visitLoadInst(LoadInst& LI)
207 {
208    // Make sure what we are dealing with is a packed type
209    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
210        // Initialization, Idx is needed for getelementptr needed later
211        std::vector<Value*> Idx(2);
212        Idx[0] = ConstantInt::get(Type::UIntTy,0);
213
214        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
215                                       PKT->getNumElements());
216        PointerType* APT = PointerType::get(AT);
217
218        // Cast the pointer to packed type to an equivalent array
219        Value* array = new BitCastInst(LI.getPointerOperand(), APT, 
220                                       LI.getName() + ".a", &LI);
221
222        // Convert this load into num elements number of loads
223        std::vector<Value*> values;
224        values.reserve(PKT->getNumElements());
225
226        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
227             // Calculate the second index we will need
228             Idx[1] = ConstantInt::get(Type::UIntTy,i);
229
230             // Get the pointer
231             Value* val = new GetElementPtrInst(array,
232                                                Idx,
233                                                LI.getName() +
234                                                ".ge." + utostr(i),
235                                                &LI);
236
237             // generate the new load and save the result in packedToScalar map
238             values.push_back(new LoadInst(val, LI.getName()+"."+utostr(i),
239                              LI.isVolatile(), &LI));
240        }
241
242        setValues(&LI,values);
243        Changed = true;
244        instrsToRemove.push_back(&LI);
245    }
246 }
247
248 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
249 {
250    // Make sure both operands are PackedTypes
251    if (isa<PackedType>(BO.getOperand(0)->getType())) {
252        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
253        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
254        std::vector<Value*> result;
255        assert((op0Vals.size() == op1Vals.size()) &&
256               "The two packed operand to scalar maps must be equal in size.");
257
258        result.reserve(op0Vals.size());
259
260        // generate the new binary op and save the result
261        for (unsigned i = 0; i != op0Vals.size(); ++i) {
262             result.push_back(BinaryOperator::create(BO.getOpcode(),
263                                                     op0Vals[i],
264                                                     op1Vals[i],
265                                                     BO.getName() +
266                                                     "." + utostr(i),
267                                                     &BO));
268        }
269
270        setValues(&BO,result);
271        Changed = true;
272        instrsToRemove.push_back(&BO);
273    }
274 }
275
276 void LowerPacked::visitICmpInst(ICmpInst& IC)
277 {
278    // Make sure both operands are PackedTypes
279    if (isa<PackedType>(IC.getOperand(0)->getType())) {
280        std::vector<Value*>& op0Vals = getValues(IC.getOperand(0));
281        std::vector<Value*>& op1Vals = getValues(IC.getOperand(1));
282        std::vector<Value*> result;
283        assert((op0Vals.size() == op1Vals.size()) &&
284               "The two packed operand to scalar maps must be equal in size.");
285
286        result.reserve(op0Vals.size());
287
288        // generate the new binary op and save the result
289        for (unsigned i = 0; i != op0Vals.size(); ++i) {
290             result.push_back(CmpInst::create(IC.getOpcode(),
291                                              IC.getPredicate(),
292                                              op0Vals[i],
293                                              op1Vals[i],
294                                              IC.getName() +
295                                              "." + utostr(i),
296                                              &IC));
297        }
298
299        setValues(&IC,result);
300        Changed = true;
301        instrsToRemove.push_back(&IC);
302    }
303 }
304
305 void LowerPacked::visitStoreInst(StoreInst& SI)
306 {
307    if (const PackedType* PKT =
308        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
309        // We will need this for getelementptr
310        std::vector<Value*> Idx(2);
311        Idx[0] = ConstantInt::get(Type::UIntTy,0);
312
313        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
314                                       PKT->getNumElements());
315        PointerType* APT = PointerType::get(AT);
316
317        // Cast the pointer to packed to an array of equivalent type
318        Value* array = new BitCastInst(SI.getPointerOperand(), APT, 
319                                       "store.ge.a.", &SI);
320
321        std::vector<Value*>& values = getValues(SI.getOperand(0));
322
323        assert((values.size() == PKT->getNumElements()) &&
324               "Scalar must have the same number of elements as Packed Type");
325
326        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
327             // Generate the indices for getelementptr
328             Idx[1] = ConstantInt::get(Type::UIntTy,i);
329             Value* val = new GetElementPtrInst(array,
330                                                Idx,
331                                                "store.ge." +
332                                                utostr(i) + ".",
333                                                &SI);
334             new StoreInst(values[i], val, SI.isVolatile(),&SI);
335        }
336
337        Changed = true;
338        instrsToRemove.push_back(&SI);
339    }
340 }
341
342 void LowerPacked::visitSelectInst(SelectInst& SELI)
343 {
344    // Make sure both operands are PackedTypes
345    if (isa<PackedType>(SELI.getType())) {
346        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
347        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
348        std::vector<Value*> result;
349
350       assert((op0Vals.size() == op1Vals.size()) &&
351              "The two packed operand to scalar maps must be equal in size.");
352
353       for (unsigned i = 0; i != op0Vals.size(); ++i) {
354            result.push_back(new SelectInst(SELI.getCondition(),
355                                            op0Vals[i],
356                                            op1Vals[i],
357                                            SELI.getName()+ "." + utostr(i),
358                                            &SELI));
359       }
360
361       setValues(&SELI,result);
362       Changed = true;
363       instrsToRemove.push_back(&SELI);
364    }
365 }
366
367 void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
368 {
369   std::vector<Value*>& op0Vals = getValues(EI.getOperand(0));
370   const PackedType *PTy = cast<PackedType>(EI.getOperand(0)->getType());
371   Value *op1 = EI.getOperand(1);
372
373   if (ConstantInt *C = dyn_cast<ConstantInt>(op1)) {
374     EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]);
375   } else {
376     AllocaInst *alloca = 
377       new AllocaInst(PTy->getElementType(),
378                      ConstantInt::get(Type::UIntTy, PTy->getNumElements()),
379                      EI.getName() + ".alloca", 
380                      EI.getParent()->getParent()->getEntryBlock().begin());
381     for (unsigned i = 0; i < PTy->getNumElements(); ++i) {
382       GetElementPtrInst *GEP = 
383         new GetElementPtrInst(alloca, ConstantInt::get(Type::UIntTy, i),
384                               "store.ge", &EI);
385       new StoreInst(op0Vals[i], GEP, &EI);
386     }
387     GetElementPtrInst *GEP = 
388       new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI);
389     LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI);
390     EI.replaceAllUsesWith(load);
391   }
392
393   Changed = true;
394   instrsToRemove.push_back(&EI);
395 }
396
397 void LowerPacked::visitInsertElementInst(InsertElementInst& IE)
398 {
399   std::vector<Value*>& Vals = getValues(IE.getOperand(0));
400   Value *Elt = IE.getOperand(1);
401   Value *Idx = IE.getOperand(2);
402   std::vector<Value*> result;
403   result.reserve(Vals.size());
404
405   if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
406     unsigned idxVal = C->getZExtValue();
407     for (unsigned i = 0; i != Vals.size(); ++i) {
408       result.push_back(i == idxVal ? Elt : Vals[i]);
409     }
410   } else {
411     for (unsigned i = 0; i != Vals.size(); ++i) {
412       ICmpInst *icmp =
413         new ICmpInst(ICmpInst::ICMP_EQ, Idx, 
414                      ConstantInt::get(Type::UIntTy, i),
415                      "icmp", &IE);
416       SelectInst *select =
417         new SelectInst(icmp, Elt, Vals[i], "select", &IE);
418       result.push_back(select);
419     }
420   }
421
422   setValues(&IE, result);
423   Changed = true;
424   instrsToRemove.push_back(&IE);
425 }
426
427 bool LowerPacked::runOnFunction(Function& F)
428 {
429    // initialize
430    Changed = false;
431
432    // Does three passes:
433    // Pass 1) Converts Packed Operations to
434    //         new Packed Operations on smaller
435    //         datatypes
436    visit(F);
437
438    // Pass 2) Drop all references
439    std::for_each(instrsToRemove.begin(),
440                  instrsToRemove.end(),
441                  std::mem_fun(&Instruction::dropAllReferences));
442
443    // Pass 3) Delete the Instructions to remove aka packed instructions
444    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(),
445                                             e = instrsToRemove.end();
446         i != e; ++i) {
447         (*i)->getParent()->getInstList().erase(*i);
448    }
449
450    // clean-up
451    packedToScalarMap.clear();
452    instrsToRemove.clear();
453
454    return Changed;
455 }
456