Removed #include <iostream> and replaced with llvm_* streams.
[oota-llvm.git] / lib / Transforms / Scalar / LowerPacked.cpp
index e8338bb96f72f116c080ca2c8b0a29255a0aef1d..ae8506e2fcedda1a471e9e20800252fd1a59e5f5 100644 (file)
 #include "llvm/Instructions.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/InstVisitor.h"
+#include "llvm/Support/Streams.h"
 #include "llvm/ADT/StringExtras.h"
 #include <algorithm>
 #include <map>
-#include <iostream>
 #include <functional>
-
 using namespace llvm;
 
 namespace {
@@ -61,19 +60,21 @@ public:
 
    /// @brief Lowers packed extractelement instructions.
    /// @param EI the extractelement operator to convert
-   void visitExtractElementInst(ExtractElementInst& EI);
+   void visitExtractElementInst(ExtractElementInst& EE);
+
+   /// @brief Lowers packed insertelement instructions.
+   /// @param EI the insertelement operator to convert
+   void visitInsertElementInst(InsertElementInst& IE);
 
    /// This function asserts if the instruction is a PackedType but
    /// is handled by another function.
    ///
    /// @brief Asserts if PackedType instruction is not handled elsewhere.
    /// @param I the unhandled instruction
-   void visitInstruction(Instruction &I)
-   {
-      if(isa<PackedType>(I.getType())) {
-         std::cerr << "Unhandled Instruction with Packed ReturnType: " <<
-                      I << '\n';
-      }
+   void visitInstruction(Instruction &I) {
+     if (isa<PackedType>(I.getType()))
+       llvm_cerr << "Unhandled Instruction with Packed ReturnType: "
+                 << I << '\n';
    }
 private:
    /// @brief Retrieves lowered values for a packed value.
@@ -98,7 +99,7 @@ private:
    std::vector<Instruction*> instrsToRemove;
 };
 
-RegisterOpt<LowerPacked>
+RegisterPass<LowerPacked>
 X("lower-packed",
   "lowers packed operations to operations on smaller packed datatypes");
 
@@ -205,7 +206,7 @@ void LowerPacked::visitLoadInst(LoadInst& LI)
    if (const PackedType* PKT = dyn_cast<PackedType>(LI.getType())) {
        // Initialization, Idx is needed for getelementptr needed later
        std::vector<Value*> Idx(2);
-       Idx[0] = ConstantUInt::get(Type::UIntTy,0);
+       Idx[0] = ConstantInt::get(Type::UIntTy,0);
 
        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
                                       PKT->getNumElements());
@@ -223,7 +224,7 @@ void LowerPacked::visitLoadInst(LoadInst& LI)
 
        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
             // Calculate the second index we will need
-            Idx[1] = ConstantUInt::get(Type::UIntTy,i);
+            Idx[1] = ConstantInt::get(Type::UIntTy,i);
 
             // Get the pointer
             Value* val = new GetElementPtrInst(array,
@@ -279,7 +280,7 @@ void LowerPacked::visitStoreInst(StoreInst& SI)
        dyn_cast<PackedType>(SI.getOperand(0)->getType())) {
        // We will need this for getelementptr
        std::vector<Value*> Idx(2);
-       Idx[0] = ConstantUInt::get(Type::UIntTy,0);
+       Idx[0] = ConstantInt::get(Type::UIntTy,0);
 
        ArrayType* AT = ArrayType::get(PKT->getContainedType(0),
                                       PKT->getNumElements());
@@ -297,7 +298,7 @@ void LowerPacked::visitStoreInst(StoreInst& SI)
 
        for (unsigned i = 0, e = PKT->getNumElements(); i != e; ++i) {
             // Generate the indices for getelementptr
-            Idx[1] = ConstantUInt::get(Type::UIntTy,i);
+            Idx[1] = ConstantInt::get(Type::UIntTy,i);
             Value* val = new GetElementPtrInst(array,
                                                Idx,
                                                "store.ge." +
@@ -342,19 +343,22 @@ void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
   const PackedType *PTy = cast<PackedType>(EI.getOperand(0)->getType());
   Value *op1 = EI.getOperand(1);
 
-  if (ConstantUInt *C = dyn_cast<ConstantUInt>(op1)) {
-    EI.replaceAllUsesWith(op0Vals[C->getValue()]);
+  if (ConstantInt *C = dyn_cast<ConstantInt>(op1)) {
+    EI.replaceAllUsesWith(op0Vals[C->getZExtValue()]);
   } else {
-    AllocaInst *alloca = new AllocaInst(PTy->getElementType(),
-                                       ConstantUInt::get(Type::UIntTy, PTy->getNumElements()),
-                                       EI.getName() + ".alloca", &(EI.getParent()->getParent()->getEntryBlock().front()));
+    AllocaInst *alloca = 
+      new AllocaInst(PTy->getElementType(),
+                     ConstantInt::get(Type::UIntTy, PTy->getNumElements()),
+                     EI.getName() + ".alloca", 
+                    EI.getParent()->getParent()->getEntryBlock().begin());
     for (unsigned i = 0; i < PTy->getNumElements(); ++i) {
-      GetElementPtrInst *GEP = new GetElementPtrInst(alloca, ConstantUInt::get(Type::UIntTy, i),
-                                                    "store.ge", &EI);
+      GetElementPtrInst *GEP = 
+        new GetElementPtrInst(alloca, ConstantInt::get(Type::UIntTy, i),
+                              "store.ge", &EI);
       new StoreInst(op0Vals[i], GEP, &EI);
     }
-    GetElementPtrInst *GEP = new GetElementPtrInst(alloca, op1,
-                                                  EI.getName() + ".ge", &EI);
+    GetElementPtrInst *GEP = 
+      new GetElementPtrInst(alloca, op1, EI.getName() + ".ge", &EI);
     LoadInst *load = new LoadInst(GEP, EI.getName() + ".load", &EI);
     EI.replaceAllUsesWith(load);
   }
@@ -363,6 +367,36 @@ void LowerPacked::visitExtractElementInst(ExtractElementInst& EI)
   instrsToRemove.push_back(&EI);
 }
 
+void LowerPacked::visitInsertElementInst(InsertElementInst& IE)
+{
+  std::vector<Value*>& Vals = getValues(IE.getOperand(0));
+  Value *Elt = IE.getOperand(1);
+  Value *Idx = IE.getOperand(2);
+  std::vector<Value*> result;
+  result.reserve(Vals.size());
+
+  if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
+    unsigned idxVal = C->getZExtValue();
+    for (unsigned i = 0; i != Vals.size(); ++i) {
+      result.push_back(i == idxVal ? Elt : Vals[i]);
+    }
+  } else {
+    for (unsigned i = 0; i != Vals.size(); ++i) {
+      SetCondInst *setcc =
+        new SetCondInst(Instruction::SetEQ, Idx, 
+                        ConstantInt::get(Type::UIntTy, i),
+                        "setcc", &IE);
+      SelectInst *select =
+        new SelectInst(setcc, Elt, Vals[i], "select", &IE);
+      result.push_back(select);
+    }
+  }
+
+  setValues(&IE, result);
+  Changed = true;
+  instrsToRemove.push_back(&IE);
+}
+
 bool LowerPacked::runOnFunction(Function& F)
 {
    // initialize