* Allow datasize to be specified on the commandline
[oota-llvm.git] / lib / Transforms / IPO / OldPoolAllocate.cpp
index 1d1e2fe14c76945ca816f311c44304361c01e5d6..0776ad79fe2e50736ea64c7c716690771b8265b9 100644 (file)
 //
 #define DEBUG_CREATE_POOLS 1
 
+#include "Support/CommandLine.h"
+enum PtrSize {
+  Ptr8bits, Ptr16bits, Ptr32bits
+};
+
+static cl::Enum<enum PtrSize> ReqPointerSize("ptrsize", 0,
+                                      "Set pointer size for pool allocation",
+  clEnumValN(Ptr32bits, "32", "Use 32 bit indices for pointers"),
+  clEnumValN(Ptr16bits, "16", "Use 16 bit indices for pointers"),
+  clEnumValN(Ptr8bits ,  "8", "Use 8 bit indices for pointers"), 0);
+
 const Type *POINTERTYPE;
 
 // FIXME: This is dependant on the sparc backend layout conventions!!
 static TargetData TargetData("test");
 
+static const Type *getPointerTransformedType(const Type *Ty) {
+  if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
+    return POINTERTYPE;
+  } else if (StructType *STy = dyn_cast<StructType>(Ty)) {
+    vector<const Type *> NewElTypes;
+    NewElTypes.reserve(STy->getElementTypes().size());
+    for (StructType::ElementTypes::const_iterator
+           I = STy->getElementTypes().begin(),
+           E = STy->getElementTypes().end(); I != E; ++I)
+      NewElTypes.push_back(getPointerTransformedType(*I));
+    return StructType::get(NewElTypes);
+  } else if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
+    return ArrayType::get(getPointerTransformedType(ATy->getElementType()),
+                                                    ATy->getNumElements());
+  } else {
+    assert(Ty->isPrimitiveType() && "Unknown derived type!");
+    return Ty;
+  }
+}
+
 namespace {
   struct PoolInfo {
     DSNode *Node;           // The node this pool allocation represents
@@ -62,18 +93,7 @@ namespace {
 
       // The new type of the memory object is the same as the old type, except
       // that all of the pointer values are replaced with POINTERTYPE values.
-      assert(isa<StructType>(getOldType()) && "Can only handle structs!");
-      StructType *OldTy = cast<StructType>(getOldType());
-      vector<const Type *> NewElTypes;
-      NewElTypes.reserve(OldTy->getElementTypes().size());
-      for (StructType::ElementTypes::const_iterator
-             I = OldTy->getElementTypes().begin(),
-             E = OldTy->getElementTypes().end(); I != E; ++I)
-        if (PointerType *PT = dyn_cast<PointerType>(I->get()))
-          NewElTypes.push_back(POINTERTYPE);
-        else
-          NewElTypes.push_back(*I);
-      NewType = StructType::get(NewElTypes);
+      NewType = getPointerTransformedType(getOldType());
     }
   };
 
@@ -149,7 +169,11 @@ namespace {
   // Define the pass class that we implement...
   struct PoolAllocate : public Pass {
     PoolAllocate() {
-      POINTERTYPE = Type::UShortTy;
+      switch (ReqPointerSize) {
+      case Ptr32bits: POINTERTYPE = Type::UIntTy; break;
+      case Ptr16bits: POINTERTYPE = Type::UShortTy; break;
+      case Ptr8bits:  POINTERTYPE = Type::UByteTy; break;
+      }
 
       CurModule = 0; DS = 0;
       PoolInit = PoolDestroy = PoolAlloc = PoolFree = 0;
@@ -747,6 +771,27 @@ void PoolAllocate::transformFunctionBody(Function *F, FunctionDSGraph &IPFGraph,
       InstToFix.push_back(cast<Instruction>(*UI));
   }
 
+  // Make sure that we get return instructions that return a null value from the
+  // function...
+  //
+  if (!IPFGraph.getRetNodes().empty()) {
+    assert(IPFGraph.getRetNodes().size() == 1 && "Can only return one node?");
+    PointerVal RetNode = IPFGraph.getRetNodes()[0];
+    assert(RetNode.Index == 0 && "Subindexing not implemented yet!");
+
+    // Only process return instructions if the return value of this function is
+    // part of one of the data structures we are transforming...
+    //
+    if (PoolDescs.count(RetNode.Node)) {
+      // Loop over all of the basic blocks, adding return instructions...
+      for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I)
+        if (ReturnInst *RI = dyn_cast<ReturnInst>((*I)->getTerminator()))
+          InstToFix.push_back(RI);
+    }
+  }
+
+
+
   // Eliminate duplicates by sorting, then removing equal neighbors.
   sort(InstToFix.begin(), InstToFix.end());
   InstToFix.erase(unique(InstToFix.begin(), InstToFix.end()), InstToFix.end());