Implement a complete type inference system for dag patterns, based on the
authorChris Lattner <sabre@nondot.org>
Thu, 8 Sep 2005 23:22:48 +0000 (23:22 +0000)
committerChris Lattner <sabre@nondot.org>
Thu, 8 Sep 2005 23:22:48 +0000 (23:22 +0000)
constraints defined in the DAG node definitions in the .td files.  This
allows us to infer (and check!) the types for all nodes in the current
ppc .td file.  For example, instead of:

Inst pattern EQV:       (set GPRC:i32:$rT, (xor (xor GPRC:i32:$rA, GPRC:i32:$rB), (imm)<<Predicate_immAllOnes>>))

we now fully infer:

Inst pattern EQV:       (set:void GPRC:i32:$rT, (xor:i32 (xor:i32 GPRC:i32:$rA, GPRC:i32:$rB), (imm:i32)<<Predicate_immAllOnes>>))

from:  (set GPRC:$rT, (not (xor GPRC:$rA, GPRC:$rB)))

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@23284 91177308-0d34-0410-b5e6-96231b3b80d8

utils/TableGen/DAGISelEmitter.cpp
utils/TableGen/DAGISelEmitter.h

index dfb17f3ae42c1b4896cd557a2af20e006b8c58a3..555a6341c46daf170e57e95fa226cba5e622524d 100644 (file)
@@ -45,6 +45,87 @@ SDTypeConstraint::SDTypeConstraint(Record *R) {
   }
 }
 
+/// getOperandNum - Return the node corresponding to operand #OpNo in tree
+/// N, which has NumResults results.
+TreePatternNode *SDTypeConstraint::getOperandNum(unsigned OpNo,
+                                                 TreePatternNode *N,
+                                                 unsigned NumResults) const {
+  assert(NumResults == 1 && "We only work with single result nodes so far!");
+  
+  if (OpNo < NumResults)
+    return N;  // FIXME: need value #
+  else
+    return N->getChild(OpNo-NumResults);
+}
+
+/// ApplyTypeConstraint - Given a node in a pattern, apply this type
+/// constraint to the nodes operands.  This returns true if it makes a
+/// change, false otherwise.  If a type contradiction is found, throw an
+/// exception.
+bool SDTypeConstraint::ApplyTypeConstraint(TreePatternNode *N,
+                                           const SDNodeInfo &NodeInfo,
+                                           TreePattern &TP) const {
+  unsigned NumResults = NodeInfo.getNumResults();
+  assert(NumResults == 1 && "We only work with single result nodes so far!");
+  
+  // Check that the number of operands is sane.
+  if (NodeInfo.getNumOperands() >= 0) {
+    if (N->getNumChildren() != (unsigned)NodeInfo.getNumOperands())
+      TP.error(N->getOperator()->getName() + " node requires exactly " +
+               itostr(NodeInfo.getNumOperands()) + " operands!");
+  }
+  
+  TreePatternNode *NodeToApply = getOperandNum(OperandNo, N, NumResults);
+  
+  switch (ConstraintType) {
+  default: assert(0 && "Unknown constraint type!");
+  case SDTCisVT:
+    // Operand must be a particular type.
+    return NodeToApply->UpdateNodeType(x.SDTCisVT_Info.VT, TP);
+  case SDTCisInt:
+    if (NodeToApply->hasTypeSet() && !MVT::isInteger(NodeToApply->getType()))
+      NodeToApply->UpdateNodeType(MVT::i1, TP);  // throw an error.
+
+    // FIXME: can tell from the target if there is only one Int type supported.
+    return false;
+  case SDTCisFP:
+    if (NodeToApply->hasTypeSet() &&
+        !MVT::isFloatingPoint(NodeToApply->getType()))
+      NodeToApply->UpdateNodeType(MVT::f32, TP);  // throw an error.
+    // FIXME: can tell from the target if there is only one FP type supported.
+    return false;
+  case SDTCisSameAs: {
+    TreePatternNode *OtherNode =
+      getOperandNum(x.SDTCisSameAs_Info.OtherOperandNum, N, NumResults);
+    return NodeToApply->UpdateNodeType(OtherNode->getType(), TP) |
+           OtherNode->UpdateNodeType(NodeToApply->getType(), TP);
+  }
+  case SDTCisVTSmallerThanOp: {
+    // The NodeToApply must be a leaf node that is a VT.  OtherOperandNum must
+    // have an integer type that is smaller than the VT.
+    if (!NodeToApply->isLeaf() ||
+        !dynamic_cast<DefInit*>(NodeToApply->getLeafValue()) ||
+        !static_cast<DefInit*>(NodeToApply->getLeafValue())->getDef()
+               ->isSubClassOf("ValueType"))
+      TP.error(N->getOperator()->getName() + " expects a VT operand!");
+    MVT::ValueType VT =
+     getValueType(static_cast<DefInit*>(NodeToApply->getLeafValue())->getDef());
+    if (!MVT::isInteger(VT))
+      TP.error(N->getOperator()->getName() + " VT operand must be integer!");
+    
+    TreePatternNode *OtherNode =
+      getOperandNum(x.SDTCisVTSmallerThanOp_Info.OtherOperandNum, N,NumResults);
+    if (OtherNode->hasTypeSet() &&
+        (!MVT::isInteger(OtherNode->getType()) ||
+         OtherNode->getType() <= VT))
+      OtherNode->UpdateNodeType(MVT::Other, TP);  // Throw an error.
+    return false;
+  }
+  }  
+  return false;
+}
+
+
 //===----------------------------------------------------------------------===//
 // SDNodeInfo implementation
 //
@@ -77,6 +158,23 @@ TreePatternNode::~TreePatternNode() {
 #endif
 }
 
+/// UpdateNodeType - Set the node type of N to VT if VT contains
+/// information.  If N already contains a conflicting type, then throw an
+/// exception.  This returns true if any information was updated.
+///
+bool TreePatternNode::UpdateNodeType(MVT::ValueType VT, TreePattern &TP) {
+  if (VT == MVT::LAST_VALUETYPE || getType() == VT) return false;
+  if (getType() == MVT::LAST_VALUETYPE) {
+    setType(VT);
+    return true;
+  }
+  
+  TP.error("Type inference contradiction found in node " + 
+           getOperator()->getName() + "!");
+  return true; // unreachable
+}
+
+
 void TreePatternNode::print(std::ostream &OS) const {
   if (isLeaf()) {
     OS << *getLeafValue();
@@ -132,6 +230,8 @@ TreePatternNode *TreePatternNode::clone() const {
   return New;
 }
 
+/// SubstituteFormalArguments - Replace the formal arguments in this tree
+/// with actual values specified by ArgMap.
 void TreePatternNode::
 SubstituteFormalArguments(std::map<std::string, TreePatternNode*> &ArgMap) {
   if (isLeaf()) return;
@@ -196,6 +296,35 @@ TreePatternNode *TreePatternNode::InlinePatternFragments(TreePattern &TP) {
   return FragTree;
 }
 
+/// ApplyTypeConstraints - Apply all of the type constraints relevent to
+/// this node and its children in the tree.  This returns true if it makes a
+/// change, false otherwise.  If a type contradiction is found, throw an
+/// exception.
+bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP) {
+  if (isLeaf()) return false;
+  
+  // special handling for set, which isn't really an SDNode.
+  if (getOperator()->getName() == "set") {
+    assert (getNumChildren() == 2 && "Only handle 2 operand set's for now!");
+    bool MadeChange = getChild(0)->ApplyTypeConstraints(TP);
+    MadeChange |= getChild(1)->ApplyTypeConstraints(TP);
+    
+    // Types of operands must match.
+    MadeChange |= getChild(0)->UpdateNodeType(getChild(1)->getType(), TP);
+    MadeChange |= getChild(1)->UpdateNodeType(getChild(0)->getType(), TP);
+    MadeChange |= UpdateNodeType(MVT::isVoid, TP);
+    return MadeChange;
+  }
+  
+  const SDNodeInfo &NI = TP.getDAGISelEmitter().getSDNodeInfo(getOperator());
+  
+  bool MadeChange = NI.ApplyTypeConstraints(this, TP);
+  for (unsigned i = 0, e = getNumChildren(); i != e; ++i)
+    MadeChange |= getChild(i)->ApplyTypeConstraints(TP);
+  return MadeChange;  
+}
+
+
 //===----------------------------------------------------------------------===//
 // TreePattern implementation
 //
@@ -311,9 +440,8 @@ TreePatternNode *TreePattern::ParseTreePattern(DagInit *Dag) {
       return 0;
     }
     
-    // Apply the type cast...
-    assert(0 && "unimp yet");
-    //New->updateNodeType(getValueType(Operator), TheRecord->getName());
+    // Apply the type cast.
+    New->UpdateNodeType(getValueType(Operator), *this);
     return New;
   }
   
@@ -361,6 +489,23 @@ TreePatternNode *TreePattern::ParseTreePattern(DagInit *Dag) {
   return new TreePatternNode(Operator, Children);
 }
 
+/// InferAllTypes - Infer/propagate as many types throughout the expression
+/// patterns as possible.  Return true if all types are infered, false
+/// otherwise.  Throw an exception if a type contradiction is found.
+bool TreePattern::InferAllTypes() {
+  bool MadeChange = true;
+  while (MadeChange) {
+    MadeChange = false;
+    for (unsigned i = 0, e = Trees.size(); i != e; ++i)
+      MadeChange |= Trees[i]->ApplyTypeConstraints(*this);
+  }
+  
+  bool HasUnresolvedTypes = false;
+  for (unsigned i = 0, e = Trees.size(); i != e; ++i)
+    HasUnresolvedTypes |= Trees[i]->ContainsUnresolvedType();
+  return !HasUnresolvedTypes;
+}
+
 void TreePattern::print(std::ostream &OS) const {
   switch (getPatternType()) {
   case TreePattern::PatFrag:     OS << "PatFrag pattern "; break;
@@ -449,9 +594,22 @@ void DAGISelEmitter::ParseAndResolvePatternFragments(std::ostream &OS) {
   // that there are not references to PatFrags left inside of them.
   for (std::map<Record*, TreePattern*>::iterator I = PatternFragments.begin(),
        E = PatternFragments.end(); I != E; ++I) {
-    I->second->InlinePatternFragments();
+    TreePattern *ThePat = I->second;
+    ThePat->InlinePatternFragments();
+    
+    // Infer as many types as possible.  Don't worry about it if we don't infer
+    // all of them, some may depend on the inputs of the pattern.
+    try {
+      ThePat->InferAllTypes();
+    } catch (...) {
+      // If this pattern fragment is not supported by this target (no types can
+      // satisfy its constraints), just ignore it.  If the bogus pattern is
+      // actually used by instructions, the type consistency error will be
+      // reported there.
+    }
+    
     // If debugging, print out the pattern fragment result.
-    DEBUG(I->second->dump());
+    DEBUG(ThePat->dump());
   }
 }
 
@@ -473,12 +631,21 @@ void DAGISelEmitter::ParseAndResolveInstructions() {
       Trees.push_back((DagInit*)LI->getElement(j));
 
     // Parse the instruction.
-    Instructions.push_back(new TreePattern(TreePattern::Instruction, Instrs[i],
-                                           Trees, *this));
+    TreePattern *I = new TreePattern(TreePattern::Instruction, Instrs[i],
+                                     Trees, *this);
     // Inline pattern fragments into it.
-    Instructions.back()->InlinePatternFragments();
+    I->InlinePatternFragments();
+    
+    // Infer as many types as possible.  Don't worry about it if we don't infer
+    // all of them, some may depend on the inputs of the pattern.
+    if (!I->InferAllTypes()) {
+      I->dump();
+      I->error("Could not infer all types in pattern!");
+    }
+
+    DEBUG(I->dump());
     
-    DEBUG(Instructions.back()->dump());
+    Instructions.push_back(I);
   }
 }
 
index e2fade9917130d8aaf9213f077a91140c6cd1f76..20d24fb97524f1b56c4ffa5e6fe21860e624c630 100644 (file)
@@ -21,7 +21,9 @@ namespace llvm {
   class Record;
   class Init;
   class DagInit;
+  class SDNodeInfo;
   class TreePattern;
+  class TreePatternNode;
   class DAGISelEmitter;
   
   /// SDTypeConstraint - This is a discriminated union of constraints,
@@ -45,6 +47,18 @@ namespace llvm {
         unsigned OtherOperandNum;
       } SDTCisVTSmallerThanOp_Info;
     } x;
+
+    /// ApplyTypeConstraint - Given a node in a pattern, apply this type
+    /// constraint to the nodes operands.  This returns true if it makes a
+    /// change, false otherwise.  If a type contradiction is found, throw an
+    /// exception.
+    bool ApplyTypeConstraint(TreePatternNode *N, const SDNodeInfo &NodeInfo,
+                             TreePattern &TP) const;
+    
+    /// getOperandNum - Return the node corresponding to operand #OpNo in tree
+    /// N, which has NumResults results.
+    TreePatternNode *getOperandNum(unsigned OpNo, TreePatternNode *N,
+                                   unsigned NumResults) const;
   };
   
   /// SDNodeInfo - One of these records is created for each SDNode instance in
@@ -54,20 +68,32 @@ namespace llvm {
     Record *Def;
     std::string EnumName;
     std::string SDClassName;
-    int NumResults, NumOperands;
+    unsigned NumResults;
+    int NumOperands;
     std::vector<SDTypeConstraint> TypeConstraints;
   public:
     SDNodeInfo(Record *R);  // Parse the specified record.
     
-    int getNumResults() const { return NumResults; }
+    unsigned getNumResults() const { return NumResults; }
     int getNumOperands() const { return NumOperands; }
     Record *getRecord() const { return Def; }
     const std::string &getEnumName() const { return EnumName; }
     const std::string &getSDClassName() const { return SDClassName; }
     
-    const std::vector<SDTypeConstraint> &getTypeConstraints() {
+    const std::vector<SDTypeConstraint> &getTypeConstraints() const {
       return TypeConstraints;
     }
+
+    /// ApplyTypeConstraints - Given a node in a pattern, apply the type
+    /// constraints for this node to the operands of the node.  This returns
+    /// true if it makes a change, false otherwise.  If a type contradiction is
+    /// found, throw an exception.
+    bool ApplyTypeConstraints(TreePatternNode *N, TreePattern &TP) const {
+      bool MadeChange = false;
+      for (unsigned i = 0, e = TypeConstraints.size(); i != e; ++i)
+        MadeChange |= TypeConstraints[i].ApplyTypeConstraint(N, *this, TP);
+      return MadeChange;
+    }
   };
 
   /// FIXME: TreePatternNode's can be shared in some cases (due to dag-shaped
@@ -106,6 +132,7 @@ namespace llvm {
     void setName(const std::string &N) { Name = N; }
     
     bool isLeaf() const { return Val != 0; }
+    bool hasTypeSet() const { return Ty != MVT::LAST_VALUETYPE; }
     MVT::ValueType getType() const { return Ty; }
     void setType(MVT::ValueType VT) { Ty = VT; }
     
@@ -130,6 +157,8 @@ namespace llvm {
     ///
     TreePatternNode *clone() const;
     
+    /// SubstituteFormalArguments - Replace the formal arguments in this tree
+    /// with actual values specified by ArgMap.
     void SubstituteFormalArguments(std::map<std::string,
                                             TreePatternNode*> &ArgMap);
 
@@ -137,7 +166,27 @@ namespace llvm {
     /// fragments, inline them into place, giving us a pattern without any
     /// PatFrag references.
     TreePatternNode *InlinePatternFragments(TreePattern &TP);
-        
+    
+    /// ApplyTypeConstraints - Apply all of the type constraints relevent to
+    /// this node and its children in the tree.  This returns true if it makes a
+    /// change, false otherwise.  If a type contradiction is found, throw an
+    /// exception.
+    bool ApplyTypeConstraints(TreePattern &TP);
+    
+    /// UpdateNodeType - Set the node type of N to VT if VT contains
+    /// information.  If N already contains a conflicting type, then throw an
+    /// exception.  This returns true if any information was updated.
+    ///
+    bool UpdateNodeType(MVT::ValueType VT, TreePattern &TP);
+    
+    /// ContainsUnresolvedType - Return true if this tree contains any
+    /// unresolved types.
+    bool ContainsUnresolvedType() const {
+      if (Ty == MVT::LAST_VALUETYPE) return true;
+      for (unsigned i = 0, e = getNumChildren(); i != e; ++i)
+        if (getChild(i)->ContainsUnresolvedType()) return true;
+      return false;
+    }
   };
   
   
@@ -206,6 +255,11 @@ namespace llvm {
         Trees[i] = Trees[i]->InlinePatternFragments(*this);
     }
     
+    /// InferAllTypes - Infer/propagate as many types throughout the expression
+    /// patterns as possible.  Return true if all types are infered, false
+    /// otherwise.  Throw an exception if a type contradiction is found.
+    bool InferAllTypes();
+    
     /// error - Throw an exception, prefixing it with information about this
     /// pattern.
     void error(const std::string &Msg) const;