[InterleavedAccess] Add a pass InterleavedAccess to identify interleaved memory acces...
[oota-llvm.git] / lib / CodeGen / InterleavedAccessPass.cpp
1 //=----------------------- InterleavedAccessPass.cpp -----------------------==//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the Interleaved Access pass, which identifies
11 // interleaved memory accesses and transforms into target specific intrinsics.
12 //
13 // An interleaved load reads data from memory into several vectors, with
14 // DE-interleaving the data on a factor. An interleaved store writes several
15 // vectors to memory with RE-interleaving the data on a factor.
16 //
17 // As interleaved accesses are hard to be identified in CodeGen (mainly because
18 // the VECTOR_SHUFFLE DAG node is quite different from the shufflevector IR),
19 // we identify and transform them to intrinsics in this pass. So the intrinsics
20 // can be easily matched into target specific instructions later in CodeGen.
21 //
22 // E.g. An interleaved load (Factor = 2):
23 //        %wide.vec = load <8 x i32>, <8 x i32>* %ptr
24 //        %v0 = shuffle <8 x i32> %wide.vec, <8 x i32> undef, <0, 2, 4, 6>
25 //        %v1 = shuffle <8 x i32> %wide.vec, <8 x i32> undef, <1, 3, 5, 7>
26 //
27 // It could be transformed into a ld2 intrinsic in AArch64 backend or a vld2
28 // intrinsic in ARM backend.
29 //
30 // E.g. An interleaved store (Factor = 3):
31 //        %i.vec = shuffle <8 x i32> %v0, <8 x i32> %v1,
32 //                                    <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11>
33 //        store <12 x i32> %i.vec, <12 x i32>* %ptr
34 //
35 // It could be transformed into a st3 intrinsic in AArch64 backend or a vst3
36 // intrinsic in ARM backend.
37 //
38 //===----------------------------------------------------------------------===//
39
40 #include "llvm/CodeGen/Passes.h"
41 #include "llvm/IR/InstIterator.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Target/TargetLowering.h"
45 #include "llvm/Target/TargetSubtargetInfo.h"
46
47 using namespace llvm;
48
49 #define DEBUG_TYPE "interleaved-access"
50
51 static cl::opt<bool> LowerInterleavedAccesses(
52     "lower-interleaved-accesses",
53     cl::desc("Enable lowering interleaved accesses to intrinsics"),
54     cl::init(false), cl::Hidden);
55
56 static unsigned MaxFactor; // The maximum supported interleave factor.
57
58 namespace llvm {
59 static void initializeInterleavedAccessPass(PassRegistry &);
60 }
61
62 namespace {
63
64 class InterleavedAccess : public FunctionPass {
65
66 public:
67   static char ID;
68   InterleavedAccess(const TargetMachine *TM = nullptr)
69       : FunctionPass(ID), TM(TM), TLI(nullptr) {
70     initializeInterleavedAccessPass(*PassRegistry::getPassRegistry());
71   }
72
73   const char *getPassName() const override { return "Interleaved Access Pass"; }
74
75   bool runOnFunction(Function &F) override;
76
77 private:
78   const TargetMachine *TM;
79   const TargetLowering *TLI;
80
81   /// \brief Transform an interleaved load into target specific intrinsics.
82   bool lowerInterleavedLoad(LoadInst *LI,
83                             SmallVector<Instruction *, 32> &DeadInsts);
84
85   /// \brief Transform an interleaved store into target specific intrinsics.
86   bool lowerInterleavedStore(StoreInst *SI,
87                              SmallVector<Instruction *, 32> &DeadInsts);
88 };
89 } // end anonymous namespace.
90
91 char InterleavedAccess::ID = 0;
92 INITIALIZE_TM_PASS(InterleavedAccess, "interleaved-access",
93     "Lower interleaved memory accesses to target specific intrinsics",
94     false, false)
95
96 FunctionPass *llvm::createInterleavedAccessPass(const TargetMachine *TM) {
97   return new InterleavedAccess(TM);
98 }
99
100 /// \brief Check if the mask is a DE-interleave mask of the given factor
101 /// \p Factor like:
102 ///     <Index, Index+Factor, ..., Index+(NumElts-1)*Factor>
103 static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor,
104                                        unsigned &Index) {
105   // Check all potential start indices from 0 to (Factor - 1).
106   for (Index = 0; Index < Factor; Index++) {
107     unsigned i = 0;
108
109     // Check that elements are in ascending order by Factor. Ignore undef
110     // elements.
111     for (; i < Mask.size(); i++)
112       if (Mask[i] >= 0 && static_cast<unsigned>(Mask[i]) != Index + i * Factor)
113         break;
114
115     if (i == Mask.size())
116       return true;
117   }
118
119   return false;
120 }
121
122 /// \brief Check if the mask is a DE-interleave mask for an interleaved load.
123 ///
124 /// E.g. DE-interleave masks (Factor = 2) could be:
125 ///     <0, 2, 4, 6>    (mask of index 0 to extract even elements)
126 ///     <1, 3, 5, 7>    (mask of index 1 to extract odd elements)
127 static bool isDeInterleaveMask(ArrayRef<int> Mask, unsigned &Factor,
128                                unsigned &Index) {
129   if (Mask.size() < 2)
130     return false;
131
132   // Check potential Factors.
133   for (Factor = 2; Factor <= MaxFactor; Factor++)
134     if (isDeInterleaveMaskOfFactor(Mask, Factor, Index))
135       return true;
136
137   return false;
138 }
139
140 /// \brief Check if the mask is RE-interleave mask for an interleaved store.
141 ///
142 /// I.e. <0, NumSubElts, ... , NumSubElts*(Factor - 1), 1, NumSubElts + 1, ...>
143 ///
144 /// E.g. The RE-interleave mask (Factor = 2) could be:
145 ///     <0, 4, 1, 5, 2, 6, 3, 7>
146 static bool isReInterleaveMask(ArrayRef<int> Mask, unsigned &Factor) {
147   unsigned NumElts = Mask.size();
148   if (NumElts < 4)
149     return false;
150
151   // Check potential Factors.
152   for (Factor = 2; Factor <= MaxFactor; Factor++) {
153     if (NumElts % Factor)
154       continue;
155
156     unsigned NumSubElts = NumElts / Factor;
157     if (!isPowerOf2_32(NumSubElts))
158       continue;
159
160     // Check whether each element matchs the RE-interleaved rule. Ignore undef
161     // elements.
162     unsigned i = 0;
163     for (; i < NumElts; i++)
164       if (Mask[i] >= 0 &&
165           static_cast<unsigned>(Mask[i]) !=
166               (i % Factor) * NumSubElts + i / Factor)
167         break;
168
169     // Find a RE-interleaved mask of current factor.
170     if (i == NumElts)
171       return true;
172   }
173
174   return false;
175 }
176
177 bool InterleavedAccess::lowerInterleavedLoad(
178     LoadInst *LI, SmallVector<Instruction *, 32> &DeadInsts) {
179   if (!LI->isSimple())
180     return false;
181
182   SmallVector<ShuffleVectorInst *, 4> Shuffles;
183
184   // Check if all users of this load are shufflevectors.
185   for (auto UI = LI->user_begin(), E = LI->user_end(); UI != E; UI++) {
186     ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(*UI);
187     if (!SVI || !isa<UndefValue>(SVI->getOperand(1)))
188       return false;
189
190     Shuffles.push_back(SVI);
191   }
192
193   if (Shuffles.empty())
194     return false;
195
196   unsigned Factor, Index;
197
198   // Check if the first shufflevector is DE-interleave shuffle.
199   if (!isDeInterleaveMask(Shuffles[0]->getShuffleMask(), Factor, Index))
200     return false;
201
202   // Holds the corresponding index for each DE-interleave shuffle.
203   SmallVector<unsigned, 4> Indices;
204   Indices.push_back(Index);
205
206   Type *VecTy = Shuffles[0]->getType();
207
208   // Check if other shufflevectors are also DE-interleaved of the same type
209   // and factor as the first shufflevector.
210   for (unsigned i = 1; i < Shuffles.size(); i++) {
211     if (Shuffles[i]->getType() != VecTy)
212       return false;
213
214     if (!isDeInterleaveMaskOfFactor(Shuffles[i]->getShuffleMask(), Factor,
215                                     Index))
216       return false;
217
218     Indices.push_back(Index);
219   }
220
221   DEBUG(dbgs() << "IA: Found an interleaved load: " << *LI << "\n");
222
223   // Try to create target specific intrinsics to replace the load and shuffles.
224   if (!TLI->lowerInterleavedLoad(LI, Shuffles, Indices, Factor))
225     return false;
226
227   for (auto SVI : Shuffles)
228     DeadInsts.push_back(SVI);
229
230   DeadInsts.push_back(LI);
231   return true;
232 }
233
234 bool InterleavedAccess::lowerInterleavedStore(
235     StoreInst *SI, SmallVector<Instruction *, 32> &DeadInsts) {
236   if (!SI->isSimple())
237     return false;
238
239   ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(SI->getValueOperand());
240   if (!SVI || !SVI->hasOneUse())
241     return false;
242
243   // Check if the shufflevector is RE-interleave shuffle.
244   unsigned Factor;
245   if (!isReInterleaveMask(SVI->getShuffleMask(), Factor))
246     return false;
247
248   DEBUG(dbgs() << "IA: Found an interleaved store: " << *SI << "\n");
249
250   // Try to create target specific intrinsics to replace the store and shuffle.
251   if (!TLI->lowerInterleavedStore(SI, SVI, Factor))
252     return false;
253
254   // Already have a new target specific interleaved store. Erase the old store.
255   DeadInsts.push_back(SI);
256   DeadInsts.push_back(SVI);
257   return true;
258 }
259
260 bool InterleavedAccess::runOnFunction(Function &F) {
261   if (!TM || !LowerInterleavedAccesses)
262     return false;
263
264   DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n");
265
266   TLI = TM->getSubtargetImpl(F)->getTargetLowering();
267   MaxFactor = TLI->getMaxSupportedInterleaveFactor();
268
269   // Holds dead instructions that will be erased later.
270   SmallVector<Instruction *, 32> DeadInsts;
271   bool Changed = false;
272
273   for (auto &I : inst_range(F)) {
274     if (LoadInst *LI = dyn_cast<LoadInst>(&I))
275       Changed |= lowerInterleavedLoad(LI, DeadInsts);
276
277     if (StoreInst *SI = dyn_cast<StoreInst>(&I))
278       Changed |= lowerInterleavedStore(SI, DeadInsts);
279   }
280
281   for (auto I : DeadInsts)
282     I->eraseFromParent();
283
284   return Changed;
285 }