The process of linking types can cause their addresses to become invalid. For this...
[oota-llvm.git] / lib / Linker / LinkModules.cpp
index 3fdbc7f03db51598218e588df15645862fa6ad82..e6e89c36d30dd9f5f570d92925ec2b78a2939bc7 100644 (file)
@@ -27,25 +27,91 @@ static inline bool Error(std::string *E, const std::string &Message) {
 // ResolveTypes - Attempt to link the two specified types together.  Return true
 // if there is an error and they cannot yet be linked.
 //
-static bool ResolveTypes(Type *DestTy, Type *SrcTy, SymbolTable *DestST, 
-                         const std::string &Name) {
+static bool ResolveTypes(const Type *DestTy, const Type *SrcTy,
+                         SymbolTable *DestST, const std::string &Name) {
+  if (DestTy == SrcTy) return false;       // If already equal, noop
+
   // Does the type already exist in the module?
   if (DestTy && !isa<OpaqueType>(DestTy)) {  // Yup, the type already exists...
-    if (DestTy == SrcTy) return false;       // If already equal, noop
-    if (OpaqueType *OT = dyn_cast<OpaqueType>(SrcTy)) {
-      OT->refineAbstractTypeTo(DestTy);
+    if (const OpaqueType *OT = dyn_cast<OpaqueType>(SrcTy)) {
+      const_cast<OpaqueType*>(OT)->refineAbstractTypeTo(DestTy);
     } else {
       return true;  // Cannot link types... neither is opaque and not-equal
     }
   } else {                       // Type not in dest module.  Add it now.
     if (DestTy)                  // Type _is_ in module, just opaque...
-      cast<OpaqueType>(DestTy)->refineAbstractTypeTo(SrcTy);
+      const_cast<OpaqueType*>(cast<OpaqueType>(DestTy))
+                           ->refineAbstractTypeTo(SrcTy);
     else
-      DestST->insert(Name, SrcTy);
+      DestST->insert(Name, const_cast<Type*>(SrcTy));
   }
   return false;
 }
 
+static const FunctionType *getFT(const PATypeHolder &TH) {
+  return cast<FunctionType>(TH.get());
+}
+static const StructType *getsT(const PATypeHolder &TH) {
+  return cast<StructType>(TH.get());
+}
+
+// RecursiveResolveTypes - This is just like ResolveTypes, except that it
+// recurses down into derived types, merging the used types if the parent types
+// are compatible.
+//
+static bool RecursiveResolveTypes(const PATypeHolder &DestTy,
+                                  const PATypeHolder &SrcTy,
+                                  SymbolTable *DestST, const std::string &Name){
+  const Type *SrcTyT = SrcTy.get();
+  const Type *DestTyT = DestTy.get();
+  if (DestTyT == SrcTyT) return false;       // If already equal, noop
+  
+  // If we found our opaque type, resolve it now!
+  if (isa<OpaqueType>(DestTyT) || isa<OpaqueType>(SrcTyT))
+    return ResolveTypes(DestTyT, SrcTyT, DestST, Name);
+  
+  // Two types cannot be resolved together if they are of different primitive
+  // type.  For example, we cannot resolve an int to a float.
+  if (DestTyT->getPrimitiveID() != SrcTyT->getPrimitiveID()) return true;
+
+  // Otherwise, resolve the used type used by this derived type...
+  switch (DestTyT->getPrimitiveID()) {
+  case Type::FunctionTyID: {
+    if (cast<FunctionType>(DestTyT)->isVarArg() !=
+        cast<FunctionType>(SrcTyT)->isVarArg())
+      return true;
+    for (unsigned i = 0, e = getFT(DestTy)->getNumContainedTypes(); i != e; ++i)
+      if (RecursiveResolveTypes(getFT(DestTy)->getContainedType(i),
+                                getFT(SrcTy)->getContainedType(i), DestST,Name))
+        return true;
+    return false;
+  }
+  case Type::StructTyID: {
+    if (getST(DestTy)->getNumContainedTypes() != 
+        getST(SrcTy)->getNumContainedTypes()) return 1;
+    for (unsigned i = 0, e = getST(DestTy)->getNumContainedTypes(); i != e; ++i)
+      if (RecursiveResolveTypes(getST(DestTy)->getContainedType(i),
+                                getST(SrcTy)->getContainedType(i), DestST,Name))
+        return true;
+    return false;
+  }
+  case Type::ArrayTyID: {
+    const ArrayType *DAT = cast<ArrayType>(DestTy.get());
+    const ArrayType *SAT = cast<ArrayType>(SrcTy.get());
+    if (DAT->getNumElements() != SAT->getNumElements()) return true;
+    return RecursiveResolveTypes(DAT->getElementType(), SAT->getElementType(),
+                                 DestST, Name);
+  }
+  case Type::PointerTyID:
+    return RecursiveResolveTypes(
+                              cast<PointerType>(DestTy.get())->getElementType(),
+                              cast<PointerType>(SrcTy.get())->getElementType(),
+                                 DestST, Name);
+  default: assert(0 && "Unexpected type!"); return true;
+  }  
+}
+
+
 // LinkTypes - Go through the symbol table of the Src module and see if any
 // types are named in the src module that are not named in the Dst module.
 // Make sure there are no type name conflicts.
@@ -83,6 +149,7 @@ static bool LinkTypes(Module *Dest, const Module *Src, std::string *Err) {
     // Loop over all of the types, attempting to resolve them if possible...
     unsigned OldSize = DelayedTypesToResolve.size();
 
+    // Try direct resolution by name...
     for (unsigned i = 0; i != DelayedTypesToResolve.size(); ++i) {
       const std::string &Name = DelayedTypesToResolve[i];
       Type *T1 = cast<Type>(VM.find(Name)->second);
@@ -96,18 +163,39 @@ static bool LinkTypes(Module *Dest, const Module *Src, std::string *Err) {
 
     // Did we not eliminate any types?
     if (DelayedTypesToResolve.size() == OldSize) {
-      // Build up an error message of all of the mismatched types.
-      std::string ErrorMessage;
+      // Attempt to resolve subelements of types.  This allows us to merge these
+      // two types: { int* } and { opaque* }
       for (unsigned i = 0, e = DelayedTypesToResolve.size(); i != e; ++i) {
         const std::string &Name = DelayedTypesToResolve[i];
-        const Type *T1 = cast<Type>(VM.find(Name)->second);
-        const Type *T2 = cast<Type>(DestST->lookup(Type::TypeTy, Name));
-        ErrorMessage += "  Type named '" + Name + 
-                        "' conflicts.\n    Src='" + T1->getDescription() +
-                        "'.\n   Dest='" + T2->getDescription() + "'\n";
+        PATypeHolder T1(cast<Type>(VM.find(Name)->second));
+        PATypeHolder T2(cast<Type>(DestST->lookup(Type::TypeTy, Name)));
+
+        if (!RecursiveResolveTypes(T2, T1, DestST, Name)) {
+          // We are making progress!
+          DelayedTypesToResolve.erase(DelayedTypesToResolve.begin()+i);
+          
+          // Go back to the main loop, perhaps we can resolve directly by name
+          // now...
+          break;
+        }
+      }
+
+      // If we STILL cannot resolve the types, then there is something wrong.
+      // Report the error.
+      if (DelayedTypesToResolve.size() == OldSize) {
+        // Build up an error message of all of the mismatched types.
+        std::string ErrorMessage;
+        for (unsigned i = 0, e = DelayedTypesToResolve.size(); i != e; ++i) {
+          const std::string &Name = DelayedTypesToResolve[i];
+          const Type *T1 = cast<Type>(VM.find(Name)->second);
+          const Type *T2 = cast<Type>(DestST->lookup(Type::TypeTy, Name));
+          ErrorMessage += "  Type named '" + Name + 
+                          "' conflicts.\n    Src='" + T1->getDescription() +
+                          "'.\n   Dest='" + T2->getDescription() + "'\n";
+        }
+        return Error(Err, "Type conflict between types in modules:\n" +
+                     ErrorMessage);
       }
-      return Error(Err, "Type conflict between types in modules:\n" +
-                        ErrorMessage);
     }
   }