663e4afc5a33bfec014e9fab5fe7830921d46e7f
[oota-llvm.git] / include / llvm / CodeGen / PBQP / ReductionRules.h
1 //===----------- ReductionRules.h - Reduction Rules -------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Reduction Rules.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_CODEGEN_PBQP_REDUCTIONRULES_H
15 #define LLVM_CODEGEN_PBQP_REDUCTIONRULES_H
16
17 #include "Graph.h"
18 #include "Math.h"
19 #include "Solution.h"
20
21 namespace PBQP {
22
23   /// \brief Reduce a node of degree one.
24   ///
25   /// Propagate costs from the given node, which must be of degree one, to its
26   /// neighbor. Notify the problem domain.
27   template <typename GraphT>
28   void applyR1(GraphT &G, typename GraphT::NodeId NId) {
29     typedef typename GraphT::NodeId NodeId;
30     typedef typename GraphT::EdgeId EdgeId;
31     typedef typename GraphT::Vector Vector;
32     typedef typename GraphT::Matrix Matrix;
33     typedef typename GraphT::RawVector RawVector;
34
35     assert(G.getNodeDegree(NId) == 1 &&
36            "R1 applied to node with degree != 1.");
37
38     EdgeId EId = *G.adjEdgeIds(NId).begin();
39     NodeId MId = G.getEdgeOtherNodeId(EId, NId);
40
41     const Matrix &ECosts = G.getEdgeCosts(EId);
42     const Vector &XCosts = G.getNodeCosts(NId);
43     RawVector YCosts = G.getNodeCosts(MId);
44
45     // Duplicate a little to avoid transposing matrices.
46     if (NId == G.getEdgeNode1Id(EId)) {
47       for (unsigned j = 0; j < YCosts.getLength(); ++j) {
48         PBQPNum Min = ECosts[0][j] + XCosts[0];
49         for (unsigned i = 1; i < XCosts.getLength(); ++i) {
50           PBQPNum C = ECosts[i][j] + XCosts[i];
51           if (C < Min)
52             Min = C;
53         }
54         YCosts[j] += Min;
55       }
56     } else {
57       for (unsigned i = 0; i < YCosts.getLength(); ++i) {
58         PBQPNum Min = ECosts[i][0] + XCosts[0];
59         for (unsigned j = 1; j < XCosts.getLength(); ++j) {
60           PBQPNum C = ECosts[i][j] + XCosts[j];
61           if (C < Min)
62             Min = C;
63         }
64         YCosts[i] += Min;
65       }
66     }
67     G.setNodeCosts(MId, YCosts);
68     G.disconnectEdge(EId, MId);
69   }
70
71   template <typename GraphT>
72   void applyR2(GraphT &G, typename GraphT::NodeId NId) {
73     typedef typename GraphT::NodeId NodeId;
74     typedef typename GraphT::EdgeId EdgeId;
75     typedef typename GraphT::Vector Vector;
76     typedef typename GraphT::Matrix Matrix;
77     typedef typename GraphT::RawMatrix RawMatrix;
78
79     assert(G.getNodeDegree(NId) == 2 &&
80            "R2 applied to node with degree != 2.");
81
82     const Vector &XCosts = G.getNodeCosts(NId);
83
84     typename GraphT::AdjEdgeItr AEItr = G.adjEdgeIds(NId).begin();
85     EdgeId YXEId = *AEItr,
86            ZXEId = *(++AEItr);
87
88     NodeId YNId = G.getEdgeOtherNodeId(YXEId, NId),
89            ZNId = G.getEdgeOtherNodeId(ZXEId, NId);
90
91     bool FlipEdge1 = (G.getEdgeNode1Id(YXEId) == NId),
92          FlipEdge2 = (G.getEdgeNode1Id(ZXEId) == NId);
93
94     const Matrix *YXECosts = FlipEdge1 ?
95       new Matrix(G.getEdgeCosts(YXEId).transpose()) :
96       &G.getEdgeCosts(YXEId);
97
98     const Matrix *ZXECosts = FlipEdge2 ?
99       new Matrix(G.getEdgeCosts(ZXEId).transpose()) :
100       &G.getEdgeCosts(ZXEId);
101
102     unsigned XLen = XCosts.getLength(),
103       YLen = YXECosts->getRows(),
104       ZLen = ZXECosts->getRows();
105
106     RawMatrix Delta(YLen, ZLen);
107
108     for (unsigned i = 0; i < YLen; ++i) {
109       for (unsigned j = 0; j < ZLen; ++j) {
110         PBQPNum Min = (*YXECosts)[i][0] + (*ZXECosts)[j][0] + XCosts[0];
111         for (unsigned k = 1; k < XLen; ++k) {
112           PBQPNum C = (*YXECosts)[i][k] + (*ZXECosts)[j][k] + XCosts[k];
113           if (C < Min) {
114             Min = C;
115           }
116         }
117         Delta[i][j] = Min;
118       }
119     }
120
121     if (FlipEdge1)
122       delete YXECosts;
123
124     if (FlipEdge2)
125       delete ZXECosts;
126
127     EdgeId YZEId = G.findEdge(YNId, ZNId);
128
129     if (YZEId == G.invalidEdgeId()) {
130       YZEId = G.addEdge(YNId, ZNId, Delta);
131     } else {
132       const Matrix &YZECosts = G.getEdgeCosts(YZEId);
133       if (YNId == G.getEdgeNode1Id(YZEId)) {
134         G.setEdgeCosts(YZEId, Delta + YZECosts);
135       } else {
136         G.setEdgeCosts(YZEId, Delta.transpose() + YZECosts);
137       }
138     }
139
140     G.disconnectEdge(YXEId, YNId);
141     G.disconnectEdge(ZXEId, ZNId);
142
143     // TODO: Try to normalize newly added/modified edge.
144   }
145
146
147   // \brief Find a solution to a fully reduced graph by backpropagation.
148   //
149   // Given a graph and a reduction order, pop each node from the reduction
150   // order and greedily compute a minimum solution based on the node costs, and
151   // the dependent costs due to previously solved nodes.
152   //
153   // Note - This does not return the graph to its original (pre-reduction)
154   //        state: the existing solvers destructively alter the node and edge
155   //        costs. Given that, the backpropagate function doesn't attempt to
156   //        replace the edges either, but leaves the graph in its reduced
157   //        state.
158   template <typename GraphT, typename StackT>
159   Solution backpropagate(GraphT& G, StackT stack) {
160     typedef GraphBase::NodeId NodeId;
161     typedef typename GraphT::Matrix Matrix;
162     typedef typename GraphT::RawVector RawVector;
163
164     Solution s;
165
166     while (!stack.empty()) {
167       NodeId NId = stack.back();
168       stack.pop_back();
169
170       RawVector v = G.getNodeCosts(NId);
171
172       for (auto EId : G.adjEdgeIds(NId)) {
173         const Matrix& edgeCosts = G.getEdgeCosts(EId);
174         if (NId == G.getEdgeNode1Id(EId)) {
175           NodeId mId = G.getEdgeNode2Id(EId);
176           v += edgeCosts.getColAsVector(s.getSelection(mId));
177         } else {
178           NodeId mId = G.getEdgeNode1Id(EId);
179           v += edgeCosts.getRowAsVector(s.getSelection(mId));
180         }
181       }
182
183       s.setSelection(NId, v.minIndex());
184     }
185
186     return s;
187   }
188
189 }
190
191 #endif