Remove attribution from file headers, per discussion on llvmdev.
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering vector datatypes into more primitive
11 // vector datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Argument.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Support/Compiler.h"
23 #include "llvm/Support/InstVisitor.h"
24 #include "llvm/Support/Streams.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include <algorithm>
28 #include <map>
29 #include <functional>
30 using namespace llvm;
31
32 namespace {
33
34 /// This pass converts packed operators to an
35 /// equivalent operations on smaller packed data, to possibly
36 /// scalar operations.  Currently it supports lowering
37 /// to scalar operations.
38 ///
39 /// @brief Transforms packed instructions to simpler instructions.
40 ///
41 class VISIBILITY_HIDDEN LowerPacked 
42   : public FunctionPass, public InstVisitor<LowerPacked> {
43 public:
44     static char ID; // Pass identification, replacement for typeid
45     LowerPacked() : FunctionPass((intptr_t)&ID) {}
46
47    /// @brief Lowers packed operations to scalar operations.
48    /// @param F The fuction to process
49    virtual bool runOnFunction(Function &F);
50
51    /// @brief Lowers packed load instructions.
52    /// @param LI the load instruction to convert
53    void visitLoadInst(LoadInst& LI);
54
55    /// @brief Lowers packed store instructions.
56    /// @param SI the store instruction to convert
57    void visitStoreInst(StoreInst& SI);
58
59    /// @brief Lowers packed binary operations.
60    /// @param BO the binary operator to convert
61    void visitBinaryOperator(BinaryOperator& BO);
62
63    /// @brief Lowers packed icmp operations.
64    /// @param IC the icmp operator to convert
65    void visitICmpInst(ICmpInst& IC);
66
67    /// @brief Lowers packed select instructions.
68    /// @param SELI the select operator to convert
69    void visitSelectInst(SelectInst& SELI);
70
71    /// @brief Lowers packed extractelement instructions.
72    /// @param EE the extractelement operator to convert
73    void visitExtractElementInst(ExtractElementInst& EE);
74
75    /// @brief Lowers packed insertelement instructions.
76    /// @param IE the insertelement operator to convert
77    void visitInsertElementInst(InsertElementInst& IE);
78
79    /// This function asserts that the given instruction does not have
80    /// vector type. Instructions with vector type should be handled by
81    /// the other functions in this class.
82    ///
83    /// @brief Asserts if VectorType instruction is not handled elsewhere.
84    /// @param I the unhandled instruction
85    void visitInstruction(Instruction &I) {
86      assert(!isa<VectorType>(I.getType()) &&
87             "Unhandled Instruction with Packed ReturnType!");
88    }
89 private:
90    /// @brief Retrieves lowered values for a packed value.
91    /// @param val the packed value
92    /// @return the lowered values
93    std::vector<Value*>& getValues(Value* val);
94
95    /// @brief Sets lowered values for a packed value.
96    /// @param val the packed value
97    /// @param values the corresponding lowered values
98    void setValues(Value* val,const std::vector<Value*>& values);
99
100    // Data Members
101    /// @brief whether we changed the function or not
102    bool Changed;
103
104    /// @brief a map from old packed values to new smaller packed values
105    std::map<Value*,std::vector<Value*> > packedToScalarMap;
106
107    /// Instructions in the source program to get rid of
108    /// after we do a pass (the old packed instructions)
109    std::vector<Instruction*> instrsToRemove;
110 };
111
112 char LowerPacked::ID = 0;
113 RegisterPass<LowerPacked>
114 X("lower-packed",
115   "lowers packed operations to operations on smaller packed datatypes");
116
117 } // end namespace
118
119 FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); }
120
121
122 // This function sets lowered values for a corresponding
123 // packed value.  Note, in the case of a forward reference
124 // getValues(Value*) will have already been called for
125 // the packed parameter.  This function will then replace
126 // all references in the in the function of the "dummy"
127 // value the previous getValues(Value*) call
128 // returned with actual references.
129 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
130 {
131    std::map<Value*,std::vector<Value*> >::iterator it =
132          packedToScalarMap.lower_bound(value);
133    if (it == packedToScalarMap.end() || it->first != value) {
134        // there was not a forward reference to this element
135        packedToScalarMap.insert(it,std::make_pair(value,values));
136    }
137    else {
138       // replace forward declarations with actual definitions
139       assert(it->second.size() == values.size() &&
140              "Error forward refences and actual definition differ in size");
141       for (unsigned i = 0, e = values.size(); i != e; ++i) {
142            // replace and get rid of old forward references
143            it->second[i]->replaceAllUsesWith(values[i]);
144            delete it->second[i];
145            it->second[i] = values[i];
146       }
147    }
148 }
149
150 // This function will examine the packed value parameter
151 // and if it is a packed constant or a forward reference
152 // properly create the lowered values needed.  Otherwise
153 // it will simply retreive values from a
154 // setValues(Value*,const std::vector<Value*>&)
155 // call.  Failing both of these cases, it will abort
156 // the program.
157 std::vector<Value*>& LowerPacked::getValues(Value* value)
158 {
159    assert(isa<VectorType>(value->getType()) &&
160           "Value must be VectorType");
161
162    // reject further processing if this one has
163    // already been handled
164    std::map<Value*,std::vector<Value*> >::iterator it =
165       packedToScalarMap.lower_bound(value);
166    if (it != packedToScalarMap.end() && it->first == value) {
167        return it->second;
168    }
169
170    if (ConstantVector* CP = dyn_cast<ConstantVector>(value)) {
171        // non-zero constant case
172        std::vector<Value*> results;
173        results.reserve(CP->getNumOperands());
174        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
175           results.push_back(CP->getOperand(i));
176        }
177        return packedToScalarMap.insert(it,
178                                        std::make_pair(value,results))->second;
179    }
180    else if (ConstantAggregateZero* CAZ =
181             dyn_cast<ConstantAggregateZero>(value)) {
182        // zero constant
183        const VectorType* PKT = cast<VectorType>(CAZ->getType());
184        std::vector<Value*> results;
185        results.reserve(PKT->getNumElements());
186
187        Constant* C = Constant::getNullValue(PKT->getElementType());
188        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
189             results.push_back(C);
190        }
191        return packedToScalarMap.insert(it,
192                                        std::make_pair(value,results))->second;
193    }
194    else if (isa<Instruction>(value)) {
195        // foward reference
196        const VectorType* PKT = cast<VectorType>(value->getType());
197        std::vector<Value*> results;
198        results.reserve(PKT->getNumElements());
199
200       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
201            results.push_back(new Argument(PKT->getElementType()));
202       }
203       return packedToScalarMap.insert(it,
204                                       std::make_pair(value,results))->second;
205    }
206    else {
207        // we don't know what it is, and we are trying to retrieve
208        // a value for it
209        assert(false && "Unhandled VectorType value");
210        abort();
211    }
212 }
213
214 void LowerPacked::visitLoadInst(LoadInst& LI)
215 {
216    // Make sure what we are dealing with is a vector type
217    if (const VectorType* PKT = dyn_cast<VectorType>(LI.getType())) {
218        // Initialization, Idx is needed for getelementptr needed later
219        Value *Idx[2];
220        Idx[0] = ConstantInt::get(Type::Int32Ty, 0);
221
222        // Convert this load into num elements number of loads
223        std::vector<Value*> values;
224        values.reserve(PKT->getNumElements());
225
226        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
227             // Calculate the second index we will need
228             Idx[1] = ConstantInt::get(Type::Int32Ty,i);
229
230             // Get the pointer
231             Value* val = new GetElementPtrInst(LI.getPointerOperand(),
232                                                Idx, array_endof(Idx),
233                                                LI.getName() +
234                                                ".ge." + utostr(i),
235                                                &LI);
236
237             // generate the new load and save the result in packedToScalar map
238             values.push_back(new LoadInst(val, LI.getName()+"."+utostr(i),
239                              LI.isVolatile(), &LI));
240        }
241
242        setValues(&LI,values);
243        Changed = true;
244        instrsToRemove.push_back(&LI);
245    }
246 }
247
248 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
249 {
250    // Make sure both operands are VectorTypes
251    if (isa<VectorType>(BO.getOperand(0)->getType())) {
252        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
253        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
254        std::vector<Value*> result;
255        assert((op0Vals.size() == op1Vals.size()) &&
256               "The two packed operand to scalar maps must be equal in size.");
257
258        result.reserve(op0Vals.size());
259
260        // generate the new binary op and save the result
261        for (unsigned i = 0; i != op0Vals.size(); ++i) {
262             result.push_back(BinaryOperator::create(BO.getOpcode(),
263                                                     op0Vals[i],
264                                                     op1Vals[i],
265                                                     BO.getName() +
266                                                     "." + utostr(i),
267                                                     &BO));
268        }
269
270        setValues(&BO,result);
271        Changed = true;
272        instrsToRemove.push_back(&BO);
273    }
274 }
275
276 void LowerPacked::visitICmpInst(ICmpInst& IC)
277 {
278    // Make sure both operands are VectorTypes
279    if (isa<VectorType>(IC.getOperand(0)->getType())) {
280        std::vector<Value*>& op0Vals = getValues(IC.getOperand(0));
281        std::vector<Value*>& op1Vals = getValues(IC.getOperand(1));
282        std::vector<Value*> result;
283        assert((op0Vals.size() == op1Vals.size()) &&
284               "The two packed operand to scalar maps must be equal in size.");
285
286        result.reserve(op0Vals.size());
287
288        // generate the new binary op and save the result
289        for (unsigned i = 0; i != op0Vals.size(); ++i) {
290             result.push_back(CmpInst::create(IC.getOpcode(),
291                                              IC.getPredicate(),
292                                              op0Vals[i],
293                                              op1Vals[i],
294                                              IC.getName() +
295                                              "." + utostr(i),
296                                              &IC));
297        }
298
299        setValues(&IC,result);
300        Changed = true;
301        instrsToRemove.push_back(&IC);
302    }
303 }
304
305 void LowerPacked::visitStoreInst(StoreInst& SI)
306 {
307    if (const VectorType* PKT =
308        dyn_cast<VectorType>(SI.getOperand(0)->getType())) {
309        // We will need this for getelementptr
310        Value *Idx[2];
311        Idx[0] = ConstantInt::get(Type::Int32Ty, 0);
312
313        std::vector<Value*>& values = getValues(SI.getOperand(0));
314
315        assert((values.size() == PKT->getNumElements()) &&
316               "Scalar must have the same number of elements as Vector Type");
317
318        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
319             // Generate the indices for getelementptr
320             Idx[1] = ConstantInt::get(Type::Int32Ty,i);
321             Value* val = new GetElementPtrInst(SI.getPointerOperand(),
322                                                Idx, array_endof(Idx),
323                                                "store.ge." +
324                                                utostr(i) + ".",
325                                                &SI);
326             new StoreInst(values[i], val, SI.isVolatile(),&SI);
327        }
328
329        Changed = true;
330        instrsToRemove.push_back(&SI);
331    }
332 }
333
334 void LowerPacked::visitSelectInst(SelectInst& SELI)
335 {
336    // Make sure both operands are VectorTypes
337    if (isa<VectorType>(SELI.getType())) {
338        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
339        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
340        std::vector<Value*> result;
341
342       assert((op0Vals.size() == op1Vals.size()) &&
343              "The two packed operand to scalar maps must be equal in size.");
344
345       for (unsigned i = 0; i != op0Vals.size(); ++i) {
346            result.push_back(new SelectInst(SELI.getCondition(),
347                                            op0Vals[i],
348                                            op1Vals[i],
349                                            SELI.getName()+ "." + utostr(i),
350                                            &SELI));
351       }
352
353       setValues(&SELI,result);
354       Changed = true;
355       instrsToRemove.push_back(&SELI);
356    }
357 }
358
359 void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
360 {
361   std::vector<Value*>& op0Vals = getValues(EI.getOperand(0));
362   const VectorType *PTy = cast<VectorType>(EI.getOperand(0)->getType());
363   Value *op1 = EI.getOperand(1);
364
365   if (ConstantInt *C = dyn_cast<ConstantInt>(op1)) {
366     EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]);
367   } else {
368     AllocaInst *alloca = 
369       new AllocaInst(PTy->getElementType(),
370                      ConstantInt::get(Type::Int32Ty, PTy->getNumElements()),
371                      EI.getName() + ".alloca", 
372                      EI.getParent()->getParent()->getEntryBlock().begin());
373     for (unsigned i = 0; i < PTy->getNumElements(); ++i) {
374       GetElementPtrInst *GEP = 
375         new GetElementPtrInst(alloca, ConstantInt::get(Type::Int32Ty, i),
376                               "store.ge", &EI);
377       new StoreInst(op0Vals[i], GEP, &EI);
378     }
379     GetElementPtrInst *GEP = 
380       new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI);
381     LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI);
382     EI.replaceAllUsesWith(load);
383   }
384
385   Changed = true;
386   instrsToRemove.push_back(&EI);
387 }
388
389 void LowerPacked::visitInsertElementInst(InsertElementInst& IE)
390 {
391   std::vector<Value*>& Vals = getValues(IE.getOperand(0));
392   Value *Elt = IE.getOperand(1);
393   Value *Idx = IE.getOperand(2);
394   std::vector<Value*> result;
395   result.reserve(Vals.size());
396
397   if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
398     unsigned idxVal = C->getZExtValue();
399     for (unsigned i = 0; i != Vals.size(); ++i) {
400       result.push_back(i == idxVal ? Elt : Vals[i]);
401     }
402   } else {
403     for (unsigned i = 0; i != Vals.size(); ++i) {
404       ICmpInst *icmp =
405         new ICmpInst(ICmpInst::ICMP_EQ, Idx, 
406                      ConstantInt::get(Type::Int32Ty, i),
407                      "icmp", &IE);
408       SelectInst *select =
409         new SelectInst(icmp, Elt, Vals[i], "select", &IE);
410       result.push_back(select);
411     }
412   }
413
414   setValues(&IE, result);
415   Changed = true;
416   instrsToRemove.push_back(&IE);
417 }
418
419 bool LowerPacked::runOnFunction(Function& F)
420 {
421    // initialize
422    Changed = false;
423
424    // Does three passes:
425    // Pass 1) Converts Packed Operations to
426    //         new Packed Operations on smaller
427    //         datatypes
428    visit(F);
429
430    // Pass 2) Drop all references
431    std::for_each(instrsToRemove.begin(),
432                  instrsToRemove.end(),
433                  std::mem_fun(&Instruction::dropAllReferences));
434
435    // Pass 3) Delete the Instructions to remove aka packed instructions
436    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(),
437                                             e = instrsToRemove.end();
438         i != e; ++i) {
439         (*i)->getParent()->getInstList().erase(*i);
440    }
441
442    // clean-up
443    packedToScalarMap.clear();
444    instrsToRemove.clear();
445
446    return Changed;
447 }
448