From: Nadav Rotem Date: Sun, 25 Nov 2012 16:27:16 +0000 (+0000) Subject: Refactor the ptr runtime check generation code. No functionality change. X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=8c6b73666bdd08f15b31c00bd2fd663b632a1d65;p=oota-llvm.git Refactor the ptr runtime check generation code. No functionality change. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@168568 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 8ed4caf4ec4..2ca5feae95e 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -128,6 +128,10 @@ public: } private: + /// Add code that checks at runtime if the accessed arrays overlap. + /// Returns the comperator value or NULL if no check is needed. + Value* addRuntimeCheck(LoopVectorizationLegality *Legal, + Instruction *Loc); /// Create an empty loop, based on the loop ranges of the old loop. void createEmptyLoop(LoopVectorizationLegality *Legal); /// Copy and widen the instructions from the old loop. @@ -671,6 +675,67 @@ void SingleBlockLoopVectorizer::scalarizeInstruction(Instruction *Instr) { WidenMap[Instr] = VecResults; } +Value* +SingleBlockLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal, + Instruction *Loc) { + LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck = + Legal->getRuntimePointerCheck(); + + if (!PtrRtCheck->Need) + return NULL; + + Value *MemoryRuntimeCheck = 0; + unsigned NumPointers = PtrRtCheck->Pointers.size(); + SmallVector Starts; + SmallVector Ends; + + SCEVExpander Exp(*SE, "induction"); + + // Use this type for pointer arithmetic. + Type* PtrArithTy = PtrRtCheck->Pointers[0]->getType(); + + for (unsigned i=0; i < NumPointers; ++i) { + Value *Ptr = PtrRtCheck->Pointers[i]; + const SCEV *Sc = SE->getSCEV(Ptr); + + if (SE->isLoopInvariant(Sc, OrigLoop)) { + DEBUG(dbgs() << "LV1: Adding RT check for a loop invariant ptr:" << + *Ptr <<"\n"); + Starts.push_back(Ptr); + Ends.push_back(Ptr); + } else { + DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr <<"\n"); + + Value *Start = Exp.expandCodeFor(PtrRtCheck->Starts[i], + PtrArithTy, Loc); + Value *End = Exp.expandCodeFor(PtrRtCheck->Ends[i], PtrArithTy, Loc); + Starts.push_back(Start); + Ends.push_back(End); + } + } + + for (unsigned i = 0; i < NumPointers; ++i) { + for (unsigned j = i+1; j < NumPointers; ++j) { + Value *Cmp0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, + Starts[i], Ends[j], "bound0", Loc); + Value *Cmp1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, + Starts[j], Ends[i], "bound1", Loc); + Value *IsConflict = BinaryOperator::Create(Instruction::And, Cmp0, Cmp1, + "found.conflict", Loc); + if (MemoryRuntimeCheck) { + MemoryRuntimeCheck = BinaryOperator::Create(Instruction::Or, + MemoryRuntimeCheck, + IsConflict, + "conflict.rdx", Loc); + } else { + MemoryRuntimeCheck = IsConflict; + } + } + } + + return MemoryRuntimeCheck; +} + void SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) { /* @@ -791,56 +856,7 @@ SingleBlockLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) { StartIdx, "cmp.zero", Loc); - LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck = - Legal->getRuntimePointerCheck(); - Value *MemoryRuntimeCheck = 0; - if (PtrRtCheck->Need) { - unsigned NumPointers = PtrRtCheck->Pointers.size(); - SmallVector Starts; - SmallVector Ends; - - // Use this type for pointer arithmetic. - Type* PtrArithTy = PtrRtCheck->Pointers[0]->getType(); - - for (unsigned i=0; i < NumPointers; ++i) { - Value *Ptr = PtrRtCheck->Pointers[i]; - const SCEV *Sc = SE->getSCEV(Ptr); - - if (SE->isLoopInvariant(Sc, OrigLoop)) { - DEBUG(dbgs() << "LV1: Adding RT check for a loop invariant ptr:" << - *Ptr <<"\n"); - Starts.push_back(Ptr); - Ends.push_back(Ptr); - } else { - DEBUG(dbgs() << "LV: Adding RT check for range:" << *Ptr <<"\n"); - - Value *Start = Exp.expandCodeFor(PtrRtCheck->Starts[i], - PtrArithTy, Loc); - Value *End = Exp.expandCodeFor(PtrRtCheck->Ends[i], PtrArithTy, Loc); - Starts.push_back(Start); - Ends.push_back(End); - } - } - - for (unsigned i = 0; i < NumPointers; ++i) { - for (unsigned j = i+1; j < NumPointers; ++j) { - Value *Cmp0 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, - Starts[i], Ends[j], "bound0", Loc); - Value *Cmp1 = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_ULE, - Starts[j], Ends[i], "bound1", Loc); - Value *IsConflict = BinaryOperator::Create(Instruction::And, Cmp0, Cmp1, - "found.conflict", Loc); - if (MemoryRuntimeCheck) { - MemoryRuntimeCheck = BinaryOperator::Create(Instruction::Or, - MemoryRuntimeCheck, - IsConflict, - "conflict.rdx", Loc); - } else { - MemoryRuntimeCheck = IsConflict; - } - } - } - }// end of need-runtime-check code. + Value *MemoryRuntimeCheck = addRuntimeCheck(Legal, Loc); // If we are using memory runtime checks, include them in. if (MemoryRuntimeCheck) {