[AArch64] Match interleaved memory accesses into ldN/stN instructions.
[oota-llvm.git] / lib / Target / AArch64 / AArch64InterleavedAccess.cpp
1 //=--------------------- AArch64InterleavedAccess.cpp ----------------------==//
2 //
3 // The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the AArch64InterleavedAccess pass, which identifies
11 // interleaved memory accesses and Transforms them into an AArch64 ldN/stN
12 // intrinsics (N = 2, 3, 4).
13 //
14 // An interleaved load reads data from memory into several vectors, with
15 // DE-interleaving the data on factor. An interleaved store writes several
16 // vectors to memory with RE-interleaving the data on factor. The interleave
17 // factor is equal to the number of vectors. AArch64 backend supports interleave
18 // factor of 2, 3 and 4.
19 //
20 // E.g. Transform an interleaved load (Factor = 2):
21 //        %wide.vec = load <8 x i32>, <8 x i32>* %ptr
22 //        %v0 = shuffle %wide.vec, undef, <0, 2, 4, 6>  ; Extract even elements
23 //        %v1 = shuffle %wide.vec, undef, <1, 3, 5, 7>  ; Extract odd elements
24 //      Into:
25 //        %ld2 = { <4 x i32>, <4 x i32> } call aarch64.neon.ld2(%ptr)
26 //        %v0 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 0
27 //        %v1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
28 //
29 // E.g. Transform an interleaved store (Factor = 2):
30 //        %i.vec = shuffle %v0, %v1, <0, 4, 1, 5, 2, 6, 3, 7>  ; Interleaved vec
31 //        store <8 x i32> %i.vec, <8 x i32>* %ptr
32 //      Into:
33 //        %v0 = shuffle %i.vec, undef, <0, 1, 2, 3>
34 //        %v1 = shuffle %i.vec, undef, <4, 5, 6, 7>
35 //        call void aarch64.neon.st2(%v0, %v1, %ptr)
36 //
37 //===----------------------------------------------------------------------===//
38
39 #include "AArch64.h"
40 #include "llvm/ADT/SetVector.h"
41 #include "llvm/Analysis/TargetTransformInfo.h"
42 #include "llvm/IR/InstIterator.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Module.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/MathExtras.h"
47 #include "llvm/Support/raw_ostream.h"
48
49 using namespace llvm;
50
51 #define DEBUG_TYPE "aarch64-interleaved-access"
52
53 static const unsigned MIN_FACTOR = 2;
54 static const unsigned MAX_FACTOR = 4;
55
56 namespace llvm {
57 static void initializeAArch64InterleavedAccessPass(PassRegistry &);
58 }
59
60 namespace {
61
62 class AArch64InterleavedAccess : public FunctionPass {
63
64 public:
65   static char ID;
66   AArch64InterleavedAccess() : FunctionPass(ID) {
67     initializeAArch64InterleavedAccessPass(*PassRegistry::getPassRegistry());
68   }
69
70   const char *getPassName() const override {
71     return "AArch64 Interleaved Access Pass";
72   }
73
74   bool runOnFunction(Function &F) override;
75
76 private:
77   const DataLayout *DL;
78   Module *M;
79
80   /// \brief Transform an interleaved load into ldN intrinsic.
81   bool matchInterleavedLoad(ShuffleVectorInst *SVI,
82                             SmallSetVector<Instruction *, 32> &DeadInsts);
83
84   /// \brief Transform an interleaved store into stN intrinsic.
85   bool matchInterleavedStore(ShuffleVectorInst *SVI,
86                              SmallSetVector<Instruction *, 32> &DeadInsts);
87 };
88 } // end anonymous namespace.
89
90 char AArch64InterleavedAccess::ID = 0;
91
92 INITIALIZE_PASS_BEGIN(AArch64InterleavedAccess, DEBUG_TYPE,
93                       "AArch64 interleaved access Pass", false, false)
94 INITIALIZE_PASS_END(AArch64InterleavedAccess, DEBUG_TYPE,
95                     "AArch64 interleaved access Pass", false, false)
96
97 FunctionPass *llvm::createAArch64InterleavedAccessPass() {
98   return new AArch64InterleavedAccess();
99 }
100
101 /// \brief Get a ldN/stN intrinsic according to the Factor (2, 3, or 4).
102 static Intrinsic::ID getLdNStNIntrinsic(unsigned Factor, bool IsLoad) {
103   static const Intrinsic::ID LoadInt[3] = {Intrinsic::aarch64_neon_ld2,
104                                            Intrinsic::aarch64_neon_ld3,
105                                            Intrinsic::aarch64_neon_ld4};
106   static const Intrinsic::ID StoreInt[3] = {Intrinsic::aarch64_neon_st2,
107                                             Intrinsic::aarch64_neon_st3,
108                                             Intrinsic::aarch64_neon_st4};
109
110   assert(Factor >= MIN_FACTOR && Factor <= MAX_FACTOR &&
111          "Invalid interleave factor");
112
113   if (IsLoad)
114     return LoadInt[Factor - 2];
115   else
116     return StoreInt[Factor - 2];
117 }
118
119 /// \brief Check if the mask is a DE-interleave mask of the given factor
120 /// \p Factor like:
121 ///     <Index, Index+Factor, ..., Index+(NumElts-1)*Factor>
122 static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor,
123                                        unsigned &Index) {
124   // Check all potential start indices from 0 to (Factor - 1).
125   for (Index = 0; Index < Factor; Index++) {
126     unsigned i = 0;
127
128     // Check that elements are in ascending order by Factor.
129     for (; i < Mask.size(); i++)
130       if (Mask[i] >= 0 && static_cast<unsigned>(Mask[i]) != Index + i * Factor)
131         break;
132
133     if (i == Mask.size())
134       return true;
135   }
136
137   return false;
138 }
139
140 /// \brief Check if the mask is a DE-interleave mask for an interleaved load.
141 ///
142 /// E.g. DE-interleave masks (Factor = 2) could be:
143 ///     <0, 2, 4, 6>    (mask of index 0 to extract even elements)
144 ///     <1, 3, 5, 7>    (mask of index 1 to extract odd elements)
145 static bool isDeInterleaveMask(ArrayRef<int> Mask, unsigned &Factor,
146                                unsigned &Index) {
147   unsigned NumElts = Mask.size();
148   if (NumElts < 2)
149     return false;
150
151   for (Factor = MIN_FACTOR; Factor <= MAX_FACTOR; Factor++)
152     if (isDeInterleaveMaskOfFactor(Mask, Factor, Index))
153       return true;
154
155   return false;
156 }
157
158 /// \brief Check if the given mask \p Mask is RE-interleaved mask of the given
159 /// factor \p Factor.
160 ///
161 /// I.e. <0, NumSubElts, ... , NumSubElts*(Factor - 1), 1, NumSubElts + 1, ...>
162 static bool isReInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor) {
163   unsigned NumElts = Mask.size();
164   if (NumElts % Factor)
165     return false;
166
167   unsigned NumSubElts = NumElts / Factor;
168   if (!isPowerOf2_32(NumSubElts))
169     return false;
170
171   for (unsigned i = 0; i < NumSubElts; i++)
172     for (unsigned j = 0; j < Factor; j++)
173       if (Mask[i * Factor + j] >= 0 &&
174           static_cast<unsigned>(Mask[i * Factor + j]) != j * NumSubElts + i)
175         return false;
176
177   return true;
178 }
179
180 /// \brief Check if the mask is RE-interleave mask for an interleaved store.
181 ///
182 /// E.g. The RE-interleave mask (Factor = 2) could be:
183 ///     <0, 4, 1, 5, 2, 6, 3, 7>
184 static bool isReInterleaveMask(ArrayRef<int> Mask, unsigned &Factor) {
185   if (Mask.size() < 4)
186     return false;
187
188   // Check potential Factors and return true if find a factor for the mask.
189   for (Factor = MIN_FACTOR; Factor <= MAX_FACTOR; Factor++)
190     if (isReInterleaveMaskOfFactor(Mask, Factor))
191       return true;
192
193   return false;
194 }
195
196 /// \brief Get a mask consisting of sequential integers starting from \p Start.
197 ///
198 /// I.e. <Start, Start + 1, ..., Start + NumElts - 1>
199 static Constant *getSequentialMask(IRBuilder<> &Builder, unsigned Start,
200                                    unsigned NumElts) {
201   SmallVector<Constant *, 16> Mask;
202   for (unsigned i = 0; i < NumElts; i++)
203     Mask.push_back(Builder.getInt32(Start + i));
204
205   return ConstantVector::get(Mask);
206 }
207
208 bool AArch64InterleavedAccess::matchInterleavedLoad(
209     ShuffleVectorInst *SVI, SmallSetVector<Instruction *, 32> &DeadInsts) {
210   if (DeadInsts.count(SVI))
211     return false;
212
213   LoadInst *LI = dyn_cast<LoadInst>(SVI->getOperand(0));
214   if (!LI || !LI->isSimple() || !isa<UndefValue>(SVI->getOperand(1)))
215     return false;
216
217   SmallVector<ShuffleVectorInst *, 4> Shuffles;
218
219   // Check if all users of this load are shufflevectors.
220   for (auto UI = LI->user_begin(), E = LI->user_end(); UI != E; UI++) {
221     ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(*UI);
222     if (!SV)
223       return false;
224
225     Shuffles.push_back(SV);
226   }
227
228   // Check if the type of the first shuffle is legal.
229   VectorType *VecTy = Shuffles[0]->getType();
230   unsigned TypeSize = DL->getTypeAllocSizeInBits(VecTy);
231   if (TypeSize != 64 && TypeSize != 128)
232     return false;
233
234   // Check if the mask of the first shuffle is strided and get the start index.
235   unsigned Factor, Index;
236   if (!isDeInterleaveMask(Shuffles[0]->getShuffleMask(), Factor, Index))
237     return false;
238
239   // Holds the corresponding index for each strided shuffle.
240   SmallVector<unsigned, 4> Indices;
241   Indices.push_back(Index);
242
243   // Check if other shufflevectors are of the same type and factor
244   for (unsigned i = 1; i < Shuffles.size(); i++) {
245     if (Shuffles[i]->getType() != VecTy)
246       return false;
247
248     unsigned Index;
249     if (!isDeInterleaveMaskOfFactor(Shuffles[i]->getShuffleMask(), Factor,
250                                     Index))
251       return false;
252
253     Indices.push_back(Index);
254   }
255
256   DEBUG(dbgs() << "Found an interleaved load:" << *LI << "\n");
257
258   // A pointer vector can not be the return type of the ldN intrinsics. Need to
259   // load integer vectors first and then convert to pointer vectors.
260   Type *EltTy = VecTy->getVectorElementType();
261   if (EltTy->isPointerTy())
262     VecTy = VectorType::get(DL->getIntPtrType(EltTy),
263                             VecTy->getVectorNumElements());
264
265   Type *PtrTy = VecTy->getPointerTo(LI->getPointerAddressSpace());
266   Type *Tys[2] = {VecTy, PtrTy};
267   Function *LdNFunc =
268       Intrinsic::getDeclaration(M, getLdNStNIntrinsic(Factor, true), Tys);
269
270   IRBuilder<> Builder(LI);
271   Value *Ptr = Builder.CreateBitCast(LI->getPointerOperand(), PtrTy);
272
273   CallInst *LdN = Builder.CreateCall(LdNFunc, Ptr, "ldN");
274   DEBUG(dbgs() << "   Created:" << *LdN << "\n");
275
276   // Replace each strided shufflevector with the corresponding vector loaded
277   // by ldN.
278   for (unsigned i = 0; i < Shuffles.size(); i++) {
279     ShuffleVectorInst *SV = Shuffles[i];
280     unsigned Index = Indices[i];
281
282     Value *SubVec = Builder.CreateExtractValue(LdN, Index);
283
284     // Convert the integer vector to pointer vector if the element is pointer.
285     if (EltTy->isPointerTy())
286       SubVec = Builder.CreateIntToPtr(SubVec, SV->getType());
287
288     SV->replaceAllUsesWith(SubVec);
289
290     DEBUG(dbgs() << "  Replaced:" << *SV << "\n"
291                  << "      With:" << *SubVec << "\n");
292
293     // Avoid analyzing it twice.
294     DeadInsts.insert(SV);
295   }
296
297   // Mark this load as dead.
298   DeadInsts.insert(LI);
299   return true;
300 }
301
302 bool AArch64InterleavedAccess::matchInterleavedStore(
303     ShuffleVectorInst *SVI, SmallSetVector<Instruction *, 32> &DeadInsts) {
304   if (DeadInsts.count(SVI) || !SVI->hasOneUse())
305     return false;
306
307   StoreInst *SI = dyn_cast<StoreInst>(SVI->user_back());
308   if (!SI || !SI->isSimple())
309     return false;
310
311   // Check if the mask is interleaved and get the interleave factor.
312   unsigned Factor;
313   if (!isReInterleaveMask(SVI->getShuffleMask(), Factor))
314     return false;
315
316   VectorType *VecTy = SVI->getType();
317   unsigned NumSubElts = VecTy->getVectorNumElements() / Factor;
318   Type *EltTy = VecTy->getVectorElementType();
319   VectorType *SubVecTy = VectorType::get(EltTy, NumSubElts);
320
321   // Skip illegal vector types.
322   unsigned TypeSize = DL->getTypeAllocSizeInBits(SubVecTy);
323   if (TypeSize != 64 && TypeSize != 128)
324     return false;
325
326   DEBUG(dbgs() << "Found an interleaved store:" << *SI << "\n");
327
328   Value *Op0 = SVI->getOperand(0);
329   Value *Op1 = SVI->getOperand(1);
330   IRBuilder<> Builder(SI);
331
332   // StN intrinsics don't support pointer vectors as arguments. Convert pointer
333   // vectors to integer vectors.
334   if (EltTy->isPointerTy()) {
335     Type *IntTy = DL->getIntPtrType(EltTy);
336     unsigned NumOpElts =
337         dyn_cast<VectorType>(Op0->getType())->getVectorNumElements();
338
339     // The corresponding integer vector type of the same element size.
340     Type *IntVecTy = VectorType::get(IntTy, NumOpElts);
341
342     Op0 = Builder.CreatePtrToInt(Op0, IntVecTy);
343     Op1 = Builder.CreatePtrToInt(Op1, IntVecTy);
344     SubVecTy = VectorType::get(IntTy, NumSubElts);
345   }
346
347   Type *PtrTy = SubVecTy->getPointerTo(SI->getPointerAddressSpace());
348   Type *Tys[2] = {SubVecTy, PtrTy};
349   Function *StNFunc =
350       Intrinsic::getDeclaration(M, getLdNStNIntrinsic(Factor, false), Tys);
351
352   SmallVector<Value *, 5> Ops;
353
354   // Split the shufflevector operands into sub vectors for the new stN call.
355   for (unsigned i = 0; i < Factor; i++)
356     Ops.push_back(Builder.CreateShuffleVector(
357         Op0, Op1, getSequentialMask(Builder, NumSubElts * i, NumSubElts)));
358
359   Ops.push_back(Builder.CreateBitCast(SI->getPointerOperand(), PtrTy));
360   CallInst *StN = Builder.CreateCall(StNFunc, Ops);
361
362   (void)StN; // silence warning.
363   DEBUG(dbgs() << "  Replaced:" << *SI << "'\n");
364   DEBUG(dbgs() << "      with:" << *StN << "\n");
365
366   // Mark this shufflevector and store as dead.
367   DeadInsts.insert(SI);
368   DeadInsts.insert(SVI);
369   return true;
370 }
371
372 bool AArch64InterleavedAccess::runOnFunction(Function &F) {
373   DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n");
374
375   M = F.getParent();
376   DL = &M->getDataLayout();
377
378   // Holds dead instructions that will be erased later.
379   SmallSetVector<Instruction *, 32> DeadInsts;
380   bool Changed = false;
381   for (auto &I : inst_range(F)) {
382     if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(&I)) {
383       Changed |= matchInterleavedLoad(SVI, DeadInsts);
384       Changed |= matchInterleavedStore(SVI, DeadInsts);
385     }
386   }
387
388   for (auto I : DeadInsts)
389     I->eraseFromParent();
390
391   return Changed;
392 }