Update GEP constructors to use an iterator interface to fix
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
1 //===- LowerPacked.cpp -  Implementation of LowerPacked Transform ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Brad Jones and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements lowering Packed datatypes into more primitive
11 // Packed datatypes, and finally to scalar operations.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Argument.h"
17 #include "llvm/Constants.h"
18 #include "llvm/DerivedTypes.h"
19 #include "llvm/Function.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Pass.h"
22 #include "llvm/Support/Compiler.h"
23 #include "llvm/Support/InstVisitor.h"
24 #include "llvm/Support/Streams.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include <algorithm>
27 #include <map>
28 #include <functional>
29 using namespace llvm;
30
31 namespace {
32
33 /// This pass converts packed operators to an
34 /// equivalent operations on smaller packed data, to possibly
35 /// scalar operations.  Currently it supports lowering
36 /// to scalar operations.
37 ///
38 /// @brief Transforms packed instructions to simpler instructions.
39 ///
40 class VISIBILITY_HIDDEN LowerPacked 
41   : public FunctionPass, public InstVisitor<LowerPacked> {
42 public:
43     static char ID; // Pass identification, replacement for typeid
44     LowerPacked() : FunctionPass((intptr_t)&ID) {}
45
46    /// @brief Lowers packed operations to scalar operations.
47    /// @param F The fuction to process
48    virtual bool runOnFunction(Function &F);
49
50    /// @brief Lowers packed load instructions.
51    /// @param LI the load instruction to convert
52    void visitLoadInst(LoadInst& LI);
53
54    /// @brief Lowers packed store instructions.
55    /// @param SI the store instruction to convert
56    void visitStoreInst(StoreInst& SI);
57
58    /// @brief Lowers packed binary operations.
59    /// @param BO the binary operator to convert
60    void visitBinaryOperator(BinaryOperator& BO);
61
62    /// @brief Lowers packed icmp operations.
63    /// @param IC the icmp operator to convert
64    void visitICmpInst(ICmpInst& IC);
65
66    /// @brief Lowers packed select instructions.
67    /// @param SELI the select operator to convert
68    void visitSelectInst(SelectInst& SELI);
69
70    /// @brief Lowers packed extractelement instructions.
71    /// @param EE the extractelement operator to convert
72    void visitExtractElementInst(ExtractElementInst& EE);
73
74    /// @brief Lowers packed insertelement instructions.
75    /// @param IE the insertelement operator to convert
76    void visitInsertElementInst(InsertElementInst& IE);
77
78    /// This function asserts if the instruction is a VectorType but
79    /// is handled by another function.
80    ///
81    /// @brief Asserts if VectorType instruction is not handled elsewhere.
82    /// @param I the unhandled instruction
83    void visitInstruction(Instruction &I) {
84      if (isa<VectorType>(I.getType()))
85        cerr << "Unhandled Instruction with Packed ReturnType: " << I << '\n';
86    }
87 private:
88    /// @brief Retrieves lowered values for a packed value.
89    /// @param val the packed value
90    /// @return the lowered values
91    std::vector<Value*>& getValues(Value* val);
92
93    /// @brief Sets lowered values for a packed value.
94    /// @param val the packed value
95    /// @param values the corresponding lowered values
96    void setValues(Value* val,const std::vector<Value*>& values);
97
98    // Data Members
99    /// @brief whether we changed the function or not
100    bool Changed;
101
102    /// @brief a map from old packed values to new smaller packed values
103    std::map<Value*,std::vector<Value*> > packedToScalarMap;
104
105    /// Instructions in the source program to get rid of
106    /// after we do a pass (the old packed instructions)
107    std::vector<Instruction*> instrsToRemove;
108 };
109
110 char LowerPacked::ID = 0;
111 RegisterPass<LowerPacked>
112 X("lower-packed",
113   "lowers packed operations to operations on smaller packed datatypes");
114
115 } // end namespace
116
117 FunctionPass *llvm::createLowerPackedPass() { return new LowerPacked(); }
118
119
120 // This function sets lowered values for a corresponding
121 // packed value.  Note, in the case of a forward reference
122 // getValues(Value*) will have already been called for
123 // the packed parameter.  This function will then replace
124 // all references in the in the function of the "dummy"
125 // value the previous getValues(Value*) call
126 // returned with actual references.
127 void LowerPacked::setValues(Value* value,const std::vector<Value*>& values)
128 {
129    std::map<Value*,std::vector<Value*> >::iterator it =
130          packedToScalarMap.lower_bound(value);
131    if (it == packedToScalarMap.end() || it->first != value) {
132        // there was not a forward reference to this element
133        packedToScalarMap.insert(it,std::make_pair(value,values));
134    }
135    else {
136       // replace forward declarations with actual definitions
137       assert(it->second.size() == values.size() &&
138              "Error forward refences and actual definition differ in size");
139       for (unsigned i = 0, e = values.size(); i != e; ++i) {
140            // replace and get rid of old forward references
141            it->second[i]->replaceAllUsesWith(values[i]);
142            delete it->second[i];
143            it->second[i] = values[i];
144       }
145    }
146 }
147
148 // This function will examine the packed value parameter
149 // and if it is a packed constant or a forward reference
150 // properly create the lowered values needed.  Otherwise
151 // it will simply retreive values from a
152 // setValues(Value*,const std::vector<Value*>&)
153 // call.  Failing both of these cases, it will abort
154 // the program.
155 std::vector<Value*>& LowerPacked::getValues(Value* value)
156 {
157    assert(isa<VectorType>(value->getType()) &&
158           "Value must be VectorType");
159
160    // reject further processing if this one has
161    // already been handled
162    std::map<Value*,std::vector<Value*> >::iterator it =
163       packedToScalarMap.lower_bound(value);
164    if (it != packedToScalarMap.end() && it->first == value) {
165        return it->second;
166    }
167
168    if (ConstantVector* CP = dyn_cast<ConstantVector>(value)) {
169        // non-zero constant case
170        std::vector<Value*> results;
171        results.reserve(CP->getNumOperands());
172        for (unsigned i = 0, e = CP->getNumOperands(); i != e; ++i) {
173           results.push_back(CP->getOperand(i));
174        }
175        return packedToScalarMap.insert(it,
176                                        std::make_pair(value,results))->second;
177    }
178    else if (ConstantAggregateZero* CAZ =
179             dyn_cast<ConstantAggregateZero>(value)) {
180        // zero constant
181        const VectorType* PKT = cast<VectorType>(CAZ->getType());
182        std::vector<Value*> results;
183        results.reserve(PKT->getNumElements());
184
185        Constant* C = Constant::getNullValue(PKT->getElementType());
186        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
187             results.push_back(C);
188        }
189        return packedToScalarMap.insert(it,
190                                        std::make_pair(value,results))->second;
191    }
192    else if (isa<Instruction>(value)) {
193        // foward reference
194        const VectorType* PKT = cast<VectorType>(value->getType());
195        std::vector<Value*> results;
196        results.reserve(PKT->getNumElements());
197
198       for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
199            results.push_back(new Argument(PKT->getElementType()));
200       }
201       return packedToScalarMap.insert(it,
202                                       std::make_pair(value,results))->second;
203    }
204    else {
205        // we don't know what it is, and we are trying to retrieve
206        // a value for it
207        assert(false && "Unhandled VectorType value");
208        abort();
209    }
210 }
211
212 void LowerPacked::visitLoadInst(LoadInst& LI)
213 {
214    // Make sure what we are dealing with is a vector type
215    if (const VectorType* PKT = dyn_cast<VectorType>(LI.getType())) {
216        // Initialization, Idx is needed for getelementptr needed later
217        std::vector<Value*> Idx(2);
218        Idx[0] = ConstantInt::get(Type::Int32Ty,0);
219
220        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
221                                       PKT->getNumElements());
222        PointerType* APT = PointerType::get(AT);
223
224        // Cast the pointer to vector type to an equivalent array
225        Value* array = new BitCastInst(LI.getPointerOperand(), APT, 
226                                       LI.getName() + ".a", &LI);
227
228        // Convert this load into num elements number of loads
229        std::vector<Value*> values;
230        values.reserve(PKT->getNumElements());
231
232        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
233             // Calculate the second index we will need
234             Idx[1] = ConstantInt::get(Type::Int32Ty,i);
235
236             // Get the pointer
237             Value* val = new GetElementPtrInst(array,
238                                                Idx.begin(), Idx.end(),
239                                                LI.getName() +
240                                                ".ge." + utostr(i),
241                                                &LI);
242
243             // generate the new load and save the result in packedToScalar map
244             values.push_back(new LoadInst(val, LI.getName()+"."+utostr(i),
245                              LI.isVolatile(), &LI));
246        }
247
248        setValues(&LI,values);
249        Changed = true;
250        instrsToRemove.push_back(&LI);
251    }
252 }
253
254 void LowerPacked::visitBinaryOperator(BinaryOperator& BO)
255 {
256    // Make sure both operands are VectorTypes
257    if (isa<VectorType>(BO.getOperand(0)->getType())) {
258        std::vector<Value*>& op0Vals = getValues(BO.getOperand(0));
259        std::vector<Value*>& op1Vals = getValues(BO.getOperand(1));
260        std::vector<Value*> result;
261        assert((op0Vals.size() == op1Vals.size()) &&
262               "The two packed operand to scalar maps must be equal in size.");
263
264        result.reserve(op0Vals.size());
265
266        // generate the new binary op and save the result
267        for (unsigned i = 0; i != op0Vals.size(); ++i) {
268             result.push_back(BinaryOperator::create(BO.getOpcode(),
269                                                     op0Vals[i],
270                                                     op1Vals[i],
271                                                     BO.getName() +
272                                                     "." + utostr(i),
273                                                     &BO));
274        }
275
276        setValues(&BO,result);
277        Changed = true;
278        instrsToRemove.push_back(&BO);
279    }
280 }
281
282 void LowerPacked::visitICmpInst(ICmpInst& IC)
283 {
284    // Make sure both operands are VectorTypes
285    if (isa<VectorType>(IC.getOperand(0)->getType())) {
286        std::vector<Value*>& op0Vals = getValues(IC.getOperand(0));
287        std::vector<Value*>& op1Vals = getValues(IC.getOperand(1));
288        std::vector<Value*> result;
289        assert((op0Vals.size() == op1Vals.size()) &&
290               "The two packed operand to scalar maps must be equal in size.");
291
292        result.reserve(op0Vals.size());
293
294        // generate the new binary op and save the result
295        for (unsigned i = 0; i != op0Vals.size(); ++i) {
296             result.push_back(CmpInst::create(IC.getOpcode(),
297                                              IC.getPredicate(),
298                                              op0Vals[i],
299                                              op1Vals[i],
300                                              IC.getName() +
301                                              "." + utostr(i),
302                                              &IC));
303        }
304
305        setValues(&IC,result);
306        Changed = true;
307        instrsToRemove.push_back(&IC);
308    }
309 }
310
311 void LowerPacked::visitStoreInst(StoreInst& SI)
312 {
313    if (const VectorType* PKT =
314        dyn_cast<VectorType>(SI.getOperand(0)->getType())) {
315        // We will need this for getelementptr
316        std::vector<Value*> Idx(2);
317        Idx[0] = ConstantInt::get(Type::Int32Ty,0);
318
319        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
320                                       PKT->getNumElements());
321        PointerType* APT = PointerType::get(AT);
322
323        // Cast the pointer to packed to an array of equivalent type
324        Value* array = new BitCastInst(SI.getPointerOperand(), APT, 
325                                       "store.ge.a.", &SI);
326
327        std::vector<Value*>& values = getValues(SI.getOperand(0));
328
329        assert((values.size() == PKT->getNumElements()) &&
330               "Scalar must have the same number of elements as Vector Type");
331
332        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
333             // Generate the indices for getelementptr
334             Idx[1] = ConstantInt::get(Type::Int32Ty,i);
335             Value* val = new GetElementPtrInst(array,
336                                                Idx.begin(), Idx.end(),
337                                                "store.ge." +
338                                                utostr(i) + ".",
339                                                &SI);
340             new StoreInst(values[i], val, SI.isVolatile(),&SI);
341        }
342
343        Changed = true;
344        instrsToRemove.push_back(&SI);
345    }
346 }
347
348 void LowerPacked::visitSelectInst(SelectInst& SELI)
349 {
350    // Make sure both operands are VectorTypes
351    if (isa<VectorType>(SELI.getType())) {
352        std::vector<Value*>& op0Vals = getValues(SELI.getTrueValue());
353        std::vector<Value*>& op1Vals = getValues(SELI.getFalseValue());
354        std::vector<Value*> result;
355
356       assert((op0Vals.size() == op1Vals.size()) &&
357              "The two packed operand to scalar maps must be equal in size.");
358
359       for (unsigned i = 0; i != op0Vals.size(); ++i) {
360            result.push_back(new SelectInst(SELI.getCondition(),
361                                            op0Vals[i],
362                                            op1Vals[i],
363                                            SELI.getName()+ "." + utostr(i),
364                                            &SELI));
365       }
366
367       setValues(&SELI,result);
368       Changed = true;
369       instrsToRemove.push_back(&SELI);
370    }
371 }
372
373 void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
374 {
375   std::vector<Value*>& op0Vals = getValues(EI.getOperand(0));
376   const VectorType *PTy = cast<VectorType>(EI.getOperand(0)->getType());
377   Value *op1 = EI.getOperand(1);
378
379   if (ConstantInt *C = dyn_cast<ConstantInt>(op1)) {
380     EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]);
381   } else {
382     AllocaInst *alloca = 
383       new AllocaInst(PTy->getElementType(),
384                      ConstantInt::get(Type::Int32Ty, PTy->getNumElements()),
385                      EI.getName() + ".alloca", 
386                      EI.getParent()->getParent()->getEntryBlock().begin());
387     for (unsigned i = 0; i < PTy->getNumElements(); ++i) {
388       GetElementPtrInst *GEP = 
389         new GetElementPtrInst(alloca, ConstantInt::get(Type::Int32Ty, i),
390                               "store.ge", &EI);
391       new StoreInst(op0Vals[i], GEP, &EI);
392     }
393     GetElementPtrInst *GEP = 
394       new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI);
395     LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI);
396     EI.replaceAllUsesWith(load);
397   }
398
399   Changed = true;
400   instrsToRemove.push_back(&EI);
401 }
402
403 void LowerPacked::visitInsertElementInst(InsertElementInst& IE)
404 {
405   std::vector<Value*>& Vals = getValues(IE.getOperand(0));
406   Value *Elt = IE.getOperand(1);
407   Value *Idx = IE.getOperand(2);
408   std::vector<Value*> result;
409   result.reserve(Vals.size());
410
411   if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
412     unsigned idxVal = C->getZExtValue();
413     for (unsigned i = 0; i != Vals.size(); ++i) {
414       result.push_back(i == idxVal ? Elt : Vals[i]);
415     }
416   } else {
417     for (unsigned i = 0; i != Vals.size(); ++i) {
418       ICmpInst *icmp =
419         new ICmpInst(ICmpInst::ICMP_EQ, Idx, 
420                      ConstantInt::get(Type::Int32Ty, i),
421                      "icmp", &IE);
422       SelectInst *select =
423         new SelectInst(icmp, Elt, Vals[i], "select", &IE);
424       result.push_back(select);
425     }
426   }
427
428   setValues(&IE, result);
429   Changed = true;
430   instrsToRemove.push_back(&IE);
431 }
432
433 bool LowerPacked::runOnFunction(Function& F)
434 {
435    // initialize
436    Changed = false;
437
438    // Does three passes:
439    // Pass 1) Converts Packed Operations to
440    //         new Packed Operations on smaller
441    //         datatypes
442    visit(F);
443
444    // Pass 2) Drop all references
445    std::for_each(instrsToRemove.begin(),
446                  instrsToRemove.end(),
447                  std::mem_fun(&Instruction::dropAllReferences));
448
449    // Pass 3) Delete the Instructions to remove aka packed instructions
450    for (std::vector<Instruction*>::iterator i = instrsToRemove.begin(),
451                                             e = instrsToRemove.end();
452         i != e; ++i) {
453         (*i)->getParent()->getInstList().erase(*i);
454    }
455
456    // clean-up
457    packedToScalarMap.clear();
458    instrsToRemove.clear();
459
460    return Changed;
461 }
462