X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FIR%2FConstantFold.cpp;h=68ddf096754733cd6d308cc1807471c086731609;hb=5ae08fa0db70dd6b6d9e076b20f89af0ac8ec153;hp=bf93d4f95663534061449634a5552b7daa0ea39c;hpb=8df7c39976890d94198d835e032848a374fec158;p=oota-llvm.git diff --git a/lib/IR/ConstantFold.cpp b/lib/IR/ConstantFold.cpp index bf93d4f9566..68ddf096754 100644 --- a/lib/IR/ConstantFold.cpp +++ b/lib/IR/ConstantFold.cpp @@ -22,17 +22,19 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/GetElementPtrTypeIterator.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/MathExtras.h" #include using namespace llvm; +using namespace llvm::PatternMatch; //===----------------------------------------------------------------------===// // ConstantFold*Instruction Implementations @@ -51,7 +53,7 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) { // Analysis/ConstantFolding.cpp unsigned NumElts = DstTy->getNumElements(); if (NumElts != CV->getType()->getVectorNumElements()) - return 0; + return nullptr; Type *DstEltTy = DstTy->getElementType(); @@ -75,25 +77,26 @@ static unsigned foldConstantCastPair( unsigned opc, ///< opcode of the second cast constant expression ConstantExpr *Op, ///< the first cast constant expression - Type *DstTy ///< desintation type of the first cast + Type *DstTy ///< destination type of the first cast ) { assert(Op && Op->isCast() && "Can't fold cast of cast without a cast!"); assert(DstTy && DstTy->isFirstClassType() && "Invalid cast destination type"); assert(CastInst::isCast(opc) && "Invalid cast opcode"); - // The the types and opcodes for the two Cast constant expressions + // The types and opcodes for the two Cast constant expressions Type *SrcTy = Op->getOperand(0)->getType(); Type *MidTy = Op->getType(); Instruction::CastOps firstOp = Instruction::CastOps(Op->getOpcode()); Instruction::CastOps secondOp = Instruction::CastOps(opc); - // Assume that pointers are never more than 64 bits wide. + // Assume that pointers are never more than 64 bits wide, and only use this + // for the middle type. Otherwise we could end up folding away illegal + // bitcasts between address spaces with different sizes. IntegerType *FakeIntPtrTy = Type::getInt64Ty(DstTy->getContext()); // Let CastInst::isEliminableCastPair do the heavy lifting. return CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, DstTy, - FakeIntPtrTy, FakeIntPtrTy, - FakeIntPtrTy); + nullptr, FakeIntPtrTy, nullptr); } static Constant *FoldBitCast(Constant *V, Type *DestTy) { @@ -129,7 +132,8 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) { if (ElTy == DPTy->getElementType()) // This GEP is inbounds because all indices are zero. - return ConstantExpr::getInBoundsGetElementPtr(V, IdxList); + return ConstantExpr::getInBoundsGetElementPtr(PTy->getElementType(), + V, IdxList); } // Handle casts from one vector constant to another. We know that the src @@ -138,7 +142,7 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) { if (VectorType *SrcTy = dyn_cast(V->getType())) { assert(DestPTy->getBitWidth() == SrcTy->getBitWidth() && "Not cast between same sized vectors!"); - SrcTy = NULL; + SrcTy = nullptr; // First, check for null. Undef is already handled. if (isa(V)) return Constant::getNullValue(DestTy); @@ -166,21 +170,32 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) { // be the same. Consequently, we just fold to V. return V; - if (DestTy->isFloatingPointTy()) + // See note below regarding the PPC_FP128 restriction. + if (DestTy->isFloatingPointTy() && !DestTy->isPPC_FP128Ty()) return ConstantFP::get(DestTy->getContext(), APFloat(DestTy->getFltSemantics(), CI->getValue())); // Otherwise, can't fold this (vector?) - return 0; + return nullptr; } // Handle ConstantFP input: FP -> Integral. - if (ConstantFP *FP = dyn_cast(V)) + if (ConstantFP *FP = dyn_cast(V)) { + // PPC_FP128 is really the sum of two consecutive doubles, where the first + // double is always stored first in memory, regardless of the target + // endianness. The memory layout of i128, however, depends on the target + // endianness, and so we can't fold this without target endianness + // information. This should instead be handled by + // Analysis/ConstantFolding.cpp + if (FP->getType()->isPPC_FP128Ty()) + return nullptr; + return ConstantInt::get(FP->getContext(), FP->getValueAPF().bitcastToAPInt()); + } - return 0; + return nullptr; } @@ -215,14 +230,14 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, // In the input is a constant expr, we might be able to recursively simplify. // If not, we definitely can't do anything. ConstantExpr *CE = dyn_cast(C); - if (CE == 0) return 0; - + if (!CE) return nullptr; + switch (CE->getOpcode()) { - default: return 0; + default: return nullptr; case Instruction::Or: { Constant *RHS = ExtractConstantBytes(CE->getOperand(1), ByteStart,ByteSize); - if (RHS == 0) - return 0; + if (!RHS) + return nullptr; // X | -1 -> -1. if (ConstantInt *RHSC = dyn_cast(RHS)) @@ -230,32 +245,32 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, return RHSC; Constant *LHS = ExtractConstantBytes(CE->getOperand(0), ByteStart,ByteSize); - if (LHS == 0) - return 0; + if (!LHS) + return nullptr; return ConstantExpr::getOr(LHS, RHS); } case Instruction::And: { Constant *RHS = ExtractConstantBytes(CE->getOperand(1), ByteStart,ByteSize); - if (RHS == 0) - return 0; + if (!RHS) + return nullptr; // X & 0 -> 0. if (RHS->isNullValue()) return RHS; Constant *LHS = ExtractConstantBytes(CE->getOperand(0), ByteStart,ByteSize); - if (LHS == 0) - return 0; + if (!LHS) + return nullptr; return ConstantExpr::getAnd(LHS, RHS); } case Instruction::LShr: { ConstantInt *Amt = dyn_cast(CE->getOperand(1)); - if (Amt == 0) - return 0; + if (!Amt) + return nullptr; unsigned ShAmt = Amt->getZExtValue(); // Cannot analyze non-byte shifts. if ((ShAmt & 7) != 0) - return 0; + return nullptr; ShAmt >>= 3; // If the extract is known to be all zeros, return zero. @@ -267,17 +282,17 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, return ExtractConstantBytes(CE->getOperand(0), ByteStart+ShAmt, ByteSize); // TODO: Handle the 'partially zero' case. - return 0; + return nullptr; } case Instruction::Shl: { ConstantInt *Amt = dyn_cast(CE->getOperand(1)); - if (Amt == 0) - return 0; + if (!Amt) + return nullptr; unsigned ShAmt = Amt->getZExtValue(); // Cannot analyze non-byte shifts. if ((ShAmt & 7) != 0) - return 0; + return nullptr; ShAmt >>= 3; // If the extract is known to be all zeros, return zero. @@ -289,7 +304,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, return ExtractConstantBytes(CE->getOperand(0), ByteStart-ShAmt, ByteSize); // TODO: Handle the 'partially zero' case. - return 0; + return nullptr; } case Instruction::ZExt: { @@ -323,7 +338,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart, } // TODO: Handle the 'partially zero' case. - return 0; + return nullptr; } } } @@ -375,7 +390,7 @@ static Constant *getFoldedSizeOf(Type *Ty, Type *DestTy, // If there's no interesting folding happening, bail so that we don't create // a constant that looks like it needs folding but really doesn't. if (!Folded) - return 0; + return nullptr; // Base case: Get a regular sizeof expression. Constant *C = ConstantExpr::getSizeOf(Ty); @@ -441,7 +456,7 @@ static Constant *getFoldedAlignOf(Type *Ty, Type *DestTy, // If there's no interesting folding happening, bail so that we don't create // a constant that looks like it needs folding but really doesn't. if (!Folded) - return 0; + return nullptr; // Base case: Get a regular alignof expression. Constant *C = ConstantExpr::getAlignOf(Ty); @@ -472,7 +487,7 @@ static Constant *getFoldedOffsetOf(Type *Ty, Constant *FieldNo, unsigned NumElems = STy->getNumElements(); // An empty struct has no members. if (NumElems == 0) - return 0; + return nullptr; // Check for a struct with all members having the same size. Constant *MemberSize = getFoldedSizeOf(STy->getElementType(0), DestTy, true); @@ -496,7 +511,7 @@ static Constant *getFoldedOffsetOf(Type *Ty, Constant *FieldNo, // If there's no interesting folding happening, bail so that we don't create // a constant that looks like it needs folding but really doesn't. if (!Folded) - return 0; + return nullptr; // Base case: Get a regular offsetof expression. Constant *C = ConstantExpr::getOffsetOf(Ty, FieldNo); @@ -528,7 +543,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, // Try hard to fold cast of cast because they are often eliminable. if (unsigned newOpc = foldConstantCastPair(opc, CE, DestTy)) return ConstantExpr::getCast(newOpc, CE->getOperand(0), DestTy); - } else if (CE->getOpcode() == Instruction::GetElementPtr) { + } else if (CE->getOpcode() == Instruction::GetElementPtr && + // Do not fold addrspacecast (gep 0, .., 0). It might make the + // addrspacecast uncanonicalized. + opc != Instruction::AddrSpaceCast) { // If all of the indexes in the GEP are null values, there is no pointer // adjustment going on. We might as well cast the source pointer. bool isAllNull = true; @@ -581,7 +599,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, APFloat::rmNearestTiesToEven, &ignored); return ConstantFP::get(V->getContext(), Val); } - return 0; // Can't fold. + return nullptr; // Can't fold. case Instruction::FPToUI: case Instruction::FPToSI: if (ConstantFP *FPC = dyn_cast(V)) { @@ -589,16 +607,21 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, bool ignored; uint64_t x[2]; uint32_t DestBitWidth = cast(DestTy)->getBitWidth(); - (void) V.convertToInteger(x, DestBitWidth, opc==Instruction::FPToSI, - APFloat::rmTowardZero, &ignored); + if (APFloat::opInvalidOp == + V.convertToInteger(x, DestBitWidth, opc==Instruction::FPToSI, + APFloat::rmTowardZero, &ignored)) { + // Undefined behavior invoked - the destination type can't represent + // the input constant. + return UndefValue::get(DestTy); + } APInt Val(DestBitWidth, x); return ConstantInt::get(FPC->getContext(), Val); } - return 0; // Can't fold. + return nullptr; // Can't fold. case Instruction::IntToPtr: //always treated as unsigned if (V->isNullValue()) // Is it an integral null value? return ConstantPointerNull::get(cast(DestTy)); - return 0; // Other pointer types cannot be casted + return nullptr; // Other pointer types cannot be casted case Instruction::PtrToInt: // always treated as unsigned // Is it a null pointer value? if (V->isNullValue()) @@ -609,8 +632,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, if (ConstantExpr *CE = dyn_cast(V)) if (CE->getOpcode() == Instruction::GetElementPtr && CE->getOperand(0)->isNullValue()) { - Type *Ty = - cast(CE->getOperand(0)->getType())->getElementType(); + GEPOperator *GEPO = cast(CE); + Type *Ty = GEPO->getSourceElementType(); if (CE->getNumOperands() == 2) { // Handle a sizeof-like expression. Constant *Idx = CE->getOperand(1); @@ -642,34 +665,41 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, } } // Other pointer types cannot be casted - return 0; + return nullptr; case Instruction::UIToFP: case Instruction::SIToFP: if (ConstantInt *CI = dyn_cast(V)) { APInt api = CI->getValue(); APFloat apf(DestTy->getFltSemantics(), APInt::getNullValue(DestTy->getPrimitiveSizeInBits())); - (void)apf.convertFromAPInt(api, - opc==Instruction::SIToFP, - APFloat::rmNearestTiesToEven); + if (APFloat::opOverflow & + apf.convertFromAPInt(api, opc==Instruction::SIToFP, + APFloat::rmNearestTiesToEven)) { + // Undefined behavior invoked - the destination type can't represent + // the input constant. + return UndefValue::get(DestTy); + } return ConstantFP::get(V->getContext(), apf); } - return 0; + return nullptr; case Instruction::ZExt: if (ConstantInt *CI = dyn_cast(V)) { uint32_t BitWidth = cast(DestTy)->getBitWidth(); return ConstantInt::get(V->getContext(), CI->getValue().zext(BitWidth)); } - return 0; + return nullptr; case Instruction::SExt: if (ConstantInt *CI = dyn_cast(V)) { uint32_t BitWidth = cast(DestTy)->getBitWidth(); return ConstantInt::get(V->getContext(), CI->getValue().sext(BitWidth)); } - return 0; + return nullptr; case Instruction::Trunc: { + if (V->getType()->isVectorTy()) + return nullptr; + uint32_t DestBitWidth = cast(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast(V)) { return ConstantInt::get(V->getContext(), @@ -684,10 +714,12 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, if (Constant *Res = ExtractConstantBytes(V, 0, DestBitWidth / 8)) return Res; - return 0; + return nullptr; } case Instruction::BitCast: return FoldBitCast(V, DestTy); + case Instruction::AddrSpaceCast: + return nullptr; } } @@ -702,12 +734,21 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond, SmallVector Result; Type *Ty = IntegerType::get(CondV->getContext(), 32); for (unsigned i = 0, e = V1->getType()->getVectorNumElements(); i != e;++i){ - ConstantInt *Cond = dyn_cast(CondV->getOperand(i)); - if (Cond == 0) break; - - Constant *V = Cond->isNullValue() ? V2 : V1; - Constant *Res = ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i)); - Result.push_back(Res); + Constant *V; + Constant *V1Element = ConstantExpr::getExtractElement(V1, + ConstantInt::get(Ty, i)); + Constant *V2Element = ConstantExpr::getExtractElement(V2, + ConstantInt::get(Ty, i)); + Constant *Cond = dyn_cast(CondV->getOperand(i)); + if (V1Element == V2Element) { + V = V1Element; + } else if (isa(Cond)) { + V = isa(V1Element) ? V1Element : V2Element; + } else { + if (!isa(Cond)) break; + V = Cond->isNullValue() ? V2Element : V1Element; + } + Result.push_back(V); } // If we were able to build the vector, return it. @@ -734,7 +775,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond, return ConstantExpr::getSelect(Cond, V1, FalseVal->getOperand(2)); } - return 0; + return nullptr; } Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val, @@ -748,35 +789,41 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val, return UndefValue::get(Val->getType()->getVectorElementType()); if (ConstantInt *CIdx = dyn_cast(Idx)) { - uint64_t Index = CIdx->getZExtValue(); // ee({w,x,y,z}, wrong_value) -> undef - if (Index >= Val->getType()->getVectorNumElements()) + if (CIdx->uge(Val->getType()->getVectorNumElements())) return UndefValue::get(Val->getType()->getVectorElementType()); - return Val->getAggregateElement(Index); + return Val->getAggregateElement(CIdx->getZExtValue()); } - return 0; + return nullptr; } Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val, Constant *Elt, Constant *Idx) { + if (isa(Idx)) + return UndefValue::get(Val->getType()); + ConstantInt *CIdx = dyn_cast(Idx); - if (!CIdx) return 0; - const APInt &IdxVal = CIdx->getValue(); - + if (!CIdx) return nullptr; + + unsigned NumElts = Val->getType()->getVectorNumElements(); + if (CIdx->uge(NumElts)) + return UndefValue::get(Val->getType()); + SmallVector Result; - Type *Ty = IntegerType::get(Val->getContext(), 32); - for (unsigned i = 0, e = Val->getType()->getVectorNumElements(); i != e; ++i){ + Result.reserve(NumElts); + auto *Ty = Type::getInt32Ty(Val->getContext()); + uint64_t IdxVal = CIdx->getZExtValue(); + for (unsigned i = 0; i != NumElts; ++i) { if (i == IdxVal) { Result.push_back(Elt); continue; } - Constant *C = - ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i)); + Constant *C = ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i)); Result.push_back(C); } - + return ConstantVector::get(Result); } @@ -791,7 +838,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, return UndefValue::get(VectorType::get(EltTy, MaskNumElts)); // Don't break the bitcode reader hack. - if (isa(Mask)) return 0; + if (isa(Mask)) return nullptr; unsigned SrcNumElts = V1->getType()->getVectorNumElements(); @@ -830,7 +877,7 @@ Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg, if (Constant *C = Agg->getAggregateElement(Idxs[0])) return ConstantFoldExtractValueInstruction(C, Idxs.slice(1)); - return 0; + return nullptr; } Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg, @@ -851,8 +898,8 @@ Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg, SmallVector Result; for (unsigned i = 0; i != NumElts; ++i) { Constant *C = Agg->getAggregateElement(i); - if (C == 0) return 0; - + if (!C) return nullptr; + if (Idxs[0] == i) C = ConstantFoldInsertValueInstruction(C, Val, Idxs.slice(1)); @@ -886,49 +933,70 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, return C1; return Constant::getNullValue(C1->getType()); // undef & X -> 0 case Instruction::Mul: { - ConstantInt *CI; - // X * undef -> undef if X is odd or undef - if (((CI = dyn_cast(C1)) && CI->getValue()[0]) || - ((CI = dyn_cast(C2)) && CI->getValue()[0]) || - (isa(C1) && isa(C2))) - return UndefValue::get(C1->getType()); + // undef * undef -> undef + if (isa(C1) && isa(C2)) + return C1; + const APInt *CV; + // X * undef -> undef if X is odd + if (match(C1, m_APInt(CV)) || match(C2, m_APInt(CV))) + if ((*CV)[0]) + return UndefValue::get(C1->getType()); // X * undef -> 0 otherwise return Constant::getNullValue(C1->getType()); } - case Instruction::UDiv: case Instruction::SDiv: + case Instruction::UDiv: + // X / undef -> undef + if (match(C1, m_Zero())) + return C2; + // undef / 0 -> undef // undef / 1 -> undef - if (Opcode == Instruction::UDiv || Opcode == Instruction::SDiv) - if (ConstantInt *CI2 = dyn_cast(C2)) - if (CI2->isOne()) - return C1; - // FALL THROUGH + if (match(C2, m_Zero()) || match(C2, m_One())) + return C1; + // undef / X -> 0 otherwise + return Constant::getNullValue(C1->getType()); case Instruction::URem: case Instruction::SRem: - if (!isa(C2)) // undef / X -> 0 - return Constant::getNullValue(C1->getType()); - return C2; // X / undef -> undef + // X % undef -> undef + if (match(C2, m_Undef())) + return C2; + // undef % 0 -> undef + if (match(C2, m_Zero())) + return C1; + // undef % X -> 0 otherwise + return Constant::getNullValue(C1->getType()); case Instruction::Or: // X | undef -> -1 if (isa(C1) && isa(C2)) // undef | undef -> undef return C1; return Constant::getAllOnesValue(C1->getType()); // undef | X -> ~0 case Instruction::LShr: - if (isa(C2) && isa(C1)) - return C1; // undef lshr undef -> undef - return Constant::getNullValue(C1->getType()); // X lshr undef -> 0 - // undef lshr X -> 0 + // X >>l undef -> undef + if (isa(C2)) + return C2; + // undef >>l 0 -> undef + if (match(C2, m_Zero())) + return C1; + // undef >>l X -> 0 + return Constant::getNullValue(C1->getType()); case Instruction::AShr: - if (!isa(C2)) // undef ashr X --> all ones - return Constant::getAllOnesValue(C1->getType()); - else if (isa(C1)) - return C1; // undef ashr undef -> undef - else - return C1; // X ashr undef --> X + // X >>a undef -> undef + if (isa(C2)) + return C2; + // undef >>a 0 -> undef + if (match(C2, m_Zero())) + return C1; + // TODO: undef >>a X -> undef if the shift is exact + // undef >>a X -> 0 + return Constant::getNullValue(C1->getType()); case Instruction::Shl: - if (isa(C2) && isa(C1)) - return C1; // undef shl undef -> undef - // undef << X -> 0 or X << undef -> 0 + // X << undef -> undef + if (isa(C2)) + return C2; + // undef << 0 -> undef + if (match(C2, m_Zero())) + return C1; + // undef << X -> 0 return Constant::getNullValue(C1->getType()); } } @@ -1070,27 +1138,18 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, return ConstantInt::get(CI1->getContext(), C1V | C2V); case Instruction::Xor: return ConstantInt::get(CI1->getContext(), C1V ^ C2V); - case Instruction::Shl: { - uint32_t shiftAmt = C2V.getZExtValue(); - if (shiftAmt < C1V.getBitWidth()) - return ConstantInt::get(CI1->getContext(), C1V.shl(shiftAmt)); - else - return UndefValue::get(C1->getType()); // too big shift is undef - } - case Instruction::LShr: { - uint32_t shiftAmt = C2V.getZExtValue(); - if (shiftAmt < C1V.getBitWidth()) - return ConstantInt::get(CI1->getContext(), C1V.lshr(shiftAmt)); - else - return UndefValue::get(C1->getType()); // too big shift is undef - } - case Instruction::AShr: { - uint32_t shiftAmt = C2V.getZExtValue(); - if (shiftAmt < C1V.getBitWidth()) - return ConstantInt::get(CI1->getContext(), C1V.ashr(shiftAmt)); - else - return UndefValue::get(C1->getType()); // too big shift is undef - } + case Instruction::Shl: + if (C2V.ult(C1V.getBitWidth())) + return ConstantInt::get(CI1->getContext(), C1V.shl(C2V)); + return UndefValue::get(C1->getType()); // too big shift is undef + case Instruction::LShr: + if (C2V.ult(C1V.getBitWidth())) + return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V)); + return UndefValue::get(C1->getType()); // too big shift is undef + case Instruction::AShr: + if (C2V.ult(C1V.getBitWidth())) + return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V)); + return UndefValue::get(C1->getType()); // too big shift is undef } } @@ -1128,7 +1187,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, (void)C3V.divide(C2V, APFloat::rmNearestTiesToEven); return ConstantFP::get(C1->getContext(), C3V); case Instruction::FRem: - (void)C3V.mod(C2V, APFloat::rmNearestTiesToEven); + (void)C3V.mod(C2V); return ConstantFP::get(C1->getContext(), C3V); } } @@ -1197,7 +1256,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, } // We don't know how to fold this. - return 0; + return nullptr; } /// isZeroSizedType - This type is zero sized if its an array or structure of @@ -1218,9 +1277,9 @@ static bool isMaybeZeroSizedType(Type *Ty) { } /// IdxCompare - Compare the two constants as though they were getelementptr -/// indices. This allows coersion of the types to be the same thing. +/// indices. This allows coercion of the types to be the same thing. /// -/// If the two constants are the "same" (after coersion), return 0. If the +/// If the two constants are the "same" (after coercion), return 0. If the /// first is less than the second, return -1, if the second is less than the /// first, return 1. If the constants are not integral, return -2. /// @@ -1232,15 +1291,17 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) { if (!isa(C1) || !isa(C2)) return -2; // don't know! - // Ok, we have two differing integer indices. Sign extend them to be the same - // type. Long is always big enough, so we use it. - if (!C1->getType()->isIntegerTy(64)) - C1 = ConstantExpr::getSExt(C1, Type::getInt64Ty(C1->getContext())); + // We cannot compare the indices if they don't fit in an int64_t. + if (cast(C1)->getValue().getActiveBits() > 64 || + cast(C2)->getValue().getActiveBits() > 64) + return -2; // don't know! - if (!C2->getType()->isIntegerTy(64)) - C2 = ConstantExpr::getSExt(C2, Type::getInt64Ty(C1->getContext())); + // Ok, we have two differing integer indices. Sign extend them to be the same + // type. + int64_t C1Val = cast(C1)->getSExtValue(); + int64_t C2Val = cast(C2)->getSExtValue(); - if (C1 == C2) return 0; // They are equal + if (C1Val == C2Val) return 0; // They are equal // If the type being indexed over is really just a zero sized type, there is // no pointer difference being made here. @@ -1249,8 +1310,7 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) { // If they are really different, now that they are the same type, then we // found a difference! - if (cast(C1)->getSExtValue() < - cast(C2)->getSExtValue()) + if (C1Val < C2Val) return -1; else return 1; @@ -1276,8 +1336,8 @@ static FCmpInst::Predicate evaluateFCmpRelation(Constant *V1, Constant *V2) { if (!isa(V1)) { if (!isa(V2)) { - // We distilled thisUse the standard constant folder for a few cases - ConstantInt *R = 0; + // Simple case, use the standard constant folder. + ConstantInt *R = nullptr; R = dyn_cast( ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ, V1, V2)); if (R && !R->isZero()) @@ -1319,6 +1379,30 @@ static FCmpInst::Predicate evaluateFCmpRelation(Constant *V1, Constant *V2) { return FCmpInst::BAD_FCMP_PREDICATE; } +static ICmpInst::Predicate areGlobalsPotentiallyEqual(const GlobalValue *GV1, + const GlobalValue *GV2) { + auto isGlobalUnsafeForEquality = [](const GlobalValue *GV) { + if (GV->hasExternalWeakLinkage() || GV->hasWeakAnyLinkage()) + return true; + if (const auto *GVar = dyn_cast(GV)) { + Type *Ty = GVar->getValueType(); + // A global with opaque type might end up being zero sized. + if (!Ty->isSized()) + return true; + // A global with an empty type might lie at the address of any other + // global. + if (Ty->isEmptyTy()) + return true; + } + return false; + }; + // Don't try to decide equality of aliases. + if (!isa(GV1) && !isa(GV2)) + if (!isGlobalUnsafeForEquality(GV1) && !isGlobalUnsafeForEquality(GV2)) + return ICmpInst::ICMP_NE; + return ICmpInst::BAD_ICMP_PREDICATE; +} + /// evaluateICmpRelation - This function determines if there is anything we can /// decide about the two constants provided. This doesn't need to handle simple /// things like integer comparisons, but should instead handle ConstantExprs @@ -1343,7 +1427,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, !isa(V2)) { // We distilled this down to a simple case, use the standard constant // folder. - ConstantInt *R = 0; + ConstantInt *R = nullptr; ICmpInst::Predicate pred = ICmpInst::ICMP_EQ; R = dyn_cast(ConstantExpr::getICmp(pred, V1, V2)); if (R && !R->isZero()) @@ -1380,10 +1464,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, // constant (which, since the types must match, means that it's a // ConstantPointerNull). if (const GlobalValue *GV2 = dyn_cast(V2)) { - // Don't try to decide equality of aliases. - if (!isa(GV) && !isa(GV2)) - if (!GV->hasExternalWeakLinkage() || !GV2->hasExternalWeakLinkage()) - return ICmpInst::ICMP_NE; + return areGlobalsPotentiallyEqual(GV, GV2); } else if (isa(V2)) { return ICmpInst::ICMP_NE; // Globals never equal labels. } else { @@ -1448,7 +1529,8 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, } break; - case Instruction::GetElementPtr: + case Instruction::GetElementPtr: { + GEPOperator *CE1GEP = cast(CE1); // Ok, since this is a getelementptr, we know that the constant has a // pointer type. Check the various cases. if (isa(V2)) { @@ -1495,7 +1577,8 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, "Surprising getelementptr!"); return isSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; } else { - // If they are different globals, we don't know what the value is. + if (CE1GEP->hasAllZeroIndices()) + return areGlobalsPotentiallyEqual(GV, GV2); return ICmpInst::BAD_ICMP_PREDICATE; } } @@ -1511,8 +1594,14 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, // By far the most common case to handle is when the base pointers are // obviously to the same global. if (isa(CE1Op0) && isa(CE2Op0)) { - if (CE1Op0 != CE2Op0) // Don't know relative ordering. + // Don't know relative ordering, but check for inequality. + if (CE1Op0 != CE2Op0) { + GEPOperator *CE2GEP = cast(CE2); + if (CE1GEP->hasAllZeroIndices() && CE2GEP->hasAllZeroIndices()) + return areGlobalsPotentiallyEqual(cast(CE1Op0), + cast(CE2Op0)); return ICmpInst::BAD_ICMP_PREDICATE; + } // Ok, we know that both getelementptr instructions are based on the // same global. From this, we can precisely determine the relative // ordering of the resultant pointers. @@ -1558,6 +1647,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2, } } } + } default: break; } @@ -1584,15 +1674,22 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // Handle some degenerate cases first if (isa(C1) || isa(C2)) { + CmpInst::Predicate Predicate = CmpInst::Predicate(pred); + bool isIntegerPredicate = ICmpInst::isIntPredicate(Predicate); // For EQ and NE, we can always pick a value for the undef to make the // predicate pass or fail, so we can return undef. - // Also, if both operands are undef, we can return undef. - if (ICmpInst::isEquality(ICmpInst::Predicate(pred)) || - (isa(C1) && isa(C2))) + // Also, if both operands are undef, we can return undef for int comparison. + if (ICmpInst::isEquality(Predicate) || (isIntegerPredicate && C1 == C2)) return UndefValue::get(ResultTy); - // Otherwise, pick the same value as the non-undef operand, and fold - // it to true or false. - return ConstantInt::get(ResultTy, CmpInst::isTrueWhenEqual(pred)); + + // Otherwise, for integer compare, pick the same value as the non-undef + // operand, and fold it to true or false. + if (isIntegerPredicate) + return ConstantInt::get(ResultTy, CmpInst::isTrueWhenEqual(pred)); + + // Choosing NaN for the undef will always make unordered comparison succeed + // and ordered comparison fails. + return ConstantInt::get(ResultTy, CmpInst::isUnordered(Predicate)); } // icmp eq/ne(null,GV) -> false/true @@ -1708,7 +1805,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, return ConstantVector::get(ResElts); } - if (C1->getType()->isFloatingPointTy()) { + if (C1->getType()->isFloatingPointTy() && + // Only call evaluateFCmpRelation if we have a constant expr to avoid + // infinite recursive loop + (isa(C1) || isa(C2))) { int Result = -1; // -1 = unknown, 0 = known false, 1 = known true. switch (evaluateFCmpRelation(C1, C2)) { default: llvm_unreachable("Unknown relation!"); @@ -1857,9 +1957,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, if (CE1Inverse == CE1Op0) { // Check whether we can safely truncate the right hand side. Constant *C2Inverse = ConstantExpr::getTrunc(C2, CE1Op0->getType()); - if (ConstantExpr::getZExt(C2Inverse, C2->getType()) == C2) { + if (ConstantExpr::getCast(CE1->getOpcode(), C2Inverse, + C2->getType()) == C2) return ConstantExpr::getICmp(pred, CE1Inverse, C2Inverse); - } } } } @@ -1873,7 +1973,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, return ConstantExpr::getICmp(pred, C2, C1); } } - return 0; + return nullptr; } /// isInBoundsIndices - Test whether the given sequence of *normalized* indices @@ -1896,8 +1996,39 @@ static bool isInBoundsIndices(ArrayRef Idxs) { return true; } +/// \brief Test whether a given ConstantInt is in-range for a SequentialType. +static bool isIndexInRangeOfSequentialType(SequentialType *STy, + const ConstantInt *CI) { + // And indices are valid when indexing along a pointer + if (isa(STy)) + return true; + + uint64_t NumElements = 0; + // Determine the number of elements in our sequential type. + if (auto *ATy = dyn_cast(STy)) + NumElements = ATy->getNumElements(); + else if (auto *VTy = dyn_cast(STy)) + NumElements = VTy->getNumElements(); + + assert((isa(STy) || NumElements > 0) && + "didn't expect non-array type to have zero elements!"); + + // We cannot bounds check the index if it doesn't fit in an int64_t. + if (CI->getValue().getActiveBits() > 64) + return false; + + // A negative index or an index past the end of our sequential type is + // considered out-of-range. + int64_t IndexVal = CI->getSExtValue(); + if (IndexVal < 0 || (NumElements > 0 && (uint64_t)IndexVal >= NumElements)) + return false; + + // Otherwise, it is in-range. + return true; +} + template -static Constant *ConstantFoldGetElementPtrImpl(Constant *C, +static Constant *ConstantFoldGetElementPtrImpl(Type *PointeeTy, Constant *C, bool inBounds, ArrayRef Idxs) { if (Idxs.empty()) return C; @@ -1907,8 +2038,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, if (isa(C)) { PointerType *Ptr = cast(C->getType()); - Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs); - assert(Ty != 0 && "Invalid indices for GEP!"); + Type *Ty = GetElementPtrInst::getIndexedType( + cast(Ptr->getScalarType())->getElementType(), Idxs); + assert(Ty && "Invalid indices for GEP!"); return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace())); } @@ -1921,8 +2053,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, } if (isNull) { PointerType *Ptr = cast(C->getType()); - Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs); - assert(Ty != 0 && "Invalid indices for GEP!"); + Type *Ty = GetElementPtrInst::getIndexedType( + cast(Ptr->getScalarType())->getElementType(), Idxs); + assert(Ty && "Invalid indices for GEP!"); return ConstantPointerNull::get(PointerType::get(Ty, Ptr->getAddressSpace())); } @@ -1934,16 +2067,40 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, // getelementptr instructions into a single instruction. // if (CE->getOpcode() == Instruction::GetElementPtr) { - Type *LastTy = 0; + Type *LastTy = nullptr; for (gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE); I != E; ++I) LastTy = *I; - if ((LastTy && isa(LastTy)) || Idx0->isNullValue()) { + // We cannot combine indices if doing so would take us outside of an + // array or vector. Doing otherwise could trick us if we evaluated such a + // GEP as part of a load. + // + // e.g. Consider if the original GEP was: + // i8* getelementptr ({ [2 x i8], i32, i8, [3 x i8] }* @main.c, + // i32 0, i32 0, i64 0) + // + // If we then tried to offset it by '8' to get to the third element, + // an i8, we should *not* get: + // i8* getelementptr ({ [2 x i8], i32, i8, [3 x i8] }* @main.c, + // i32 0, i32 0, i64 8) + // + // This GEP tries to index array element '8 which runs out-of-bounds. + // Subsequent evaluation would get confused and produce erroneous results. + // + // The following prohibits such a GEP from being formed by checking to see + // if the index is in-range with respect to an array or vector. + bool PerformFold = false; + if (Idx0->isNullValue()) + PerformFold = true; + else if (SequentialType *STy = dyn_cast_or_null(LastTy)) + if (ConstantInt *CI = dyn_cast(Idx0)) + PerformFold = isIndexInRangeOfSequentialType(STy, CI); + + if (PerformFold) { SmallVector NewIndices; NewIndices.reserve(Idxs.size() + CE->getNumOperands()); - for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i) - NewIndices.push_back(CE->getOperand(i)); + NewIndices.append(CE->op_begin() + 1, CE->op_end() - 1); // Add the last index of the source with the first index of the new GEP. // Make sure to handle the case when they are actually different types. @@ -1952,9 +2109,15 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, if (!Idx0->isNullValue()) { Type *IdxTy = Combined->getType(); if (IdxTy != Idx0->getType()) { - Type *Int64Ty = Type::getInt64Ty(IdxTy->getContext()); - Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, Int64Ty); - Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, Int64Ty); + unsigned CommonExtendedWidth = + std::max(IdxTy->getIntegerBitWidth(), + Idx0->getType()->getIntegerBitWidth()); + CommonExtendedWidth = std::max(CommonExtendedWidth, 64U); + + Type *CommonTy = + Type::getIntNTy(IdxTy->getContext(), CommonExtendedWidth); + Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, CommonTy); + Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, CommonTy); Combined = ConstantExpr::get(Instruction::Add, C1, C2); } else { Combined = @@ -1964,10 +2127,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, NewIndices.push_back(Combined); NewIndices.append(Idxs.begin() + 1, Idxs.end()); - return - ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices, - inBounds && - cast(CE)->isInBounds()); + return ConstantExpr::getGetElementPtr( + cast(CE)->getSourceElementType(), CE->getOperand(0), + NewIndices, inBounds && cast(CE)->isInBounds()); } } @@ -1992,45 +2154,55 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, if (SrcArrayTy && DstArrayTy && SrcArrayTy->getElementType() == DstArrayTy->getElementType() && SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace()) - return ConstantExpr::getGetElementPtr((Constant*)CE->getOperand(0), - Idxs, inBounds); + return ConstantExpr::getGetElementPtr( + SrcArrayTy, (Constant *)CE->getOperand(0), Idxs, inBounds); } } } // Check to see if any array indices are not within the corresponding - // notional array bounds. If so, try to determine if they can be factored - // out into preceding dimensions. - bool Unknown = false; + // notional array or vector bounds. If so, try to determine if they can be + // factored out into preceding dimensions. SmallVector NewIdxs; - Type *Ty = C->getType(); - Type *Prev = 0; - for (unsigned i = 0, e = Idxs.size(); i != e; + Type *Ty = PointeeTy; + Type *Prev = C->getType(); + bool Unknown = !isa(Idxs[0]); + for (unsigned i = 1, e = Idxs.size(); i != e; Prev = Ty, Ty = cast(Ty)->getTypeAtIndex(Idxs[i]), ++i) { if (ConstantInt *CI = dyn_cast(Idxs[i])) { - if (ArrayType *ATy = dyn_cast(Ty)) - if (ATy->getNumElements() <= INT64_MAX && - ATy->getNumElements() != 0 && - CI->getSExtValue() >= (int64_t)ATy->getNumElements()) { + if (isa(Ty) || isa(Ty)) + if (CI->getSExtValue() > 0 && + !isIndexInRangeOfSequentialType(cast(Ty), CI)) { if (isa(Prev)) { // It's out of range, but we can factor it into the prior // dimension. NewIdxs.resize(Idxs.size()); - ConstantInt *Factor = ConstantInt::get(CI->getType(), - ATy->getNumElements()); + uint64_t NumElements = 0; + if (auto *ATy = dyn_cast(Ty)) + NumElements = ATy->getNumElements(); + else + NumElements = cast(Ty)->getNumElements(); + + ConstantInt *Factor = ConstantInt::get(CI->getType(), NumElements); NewIdxs[i] = ConstantExpr::getSRem(CI, Factor); Constant *PrevIdx = cast(Idxs[i-1]); Constant *Div = ConstantExpr::getSDiv(CI, Factor); + unsigned CommonExtendedWidth = + std::max(PrevIdx->getType()->getIntegerBitWidth(), + Div->getType()->getIntegerBitWidth()); + CommonExtendedWidth = std::max(CommonExtendedWidth, 64U); + // Before adding, extend both operands to i64 to avoid // overflow trouble. - if (!PrevIdx->getType()->isIntegerTy(64)) - PrevIdx = ConstantExpr::getSExt(PrevIdx, - Type::getInt64Ty(Div->getContext())); - if (!Div->getType()->isIntegerTy(64)) - Div = ConstantExpr::getSExt(Div, - Type::getInt64Ty(Div->getContext())); + if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth)) + PrevIdx = ConstantExpr::getSExt( + PrevIdx, + Type::getIntNTy(Div->getContext(), CommonExtendedWidth)); + if (!Div->getType()->isIntegerTy(CommonExtendedWidth)) + Div = ConstantExpr::getSExt( + Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth)); NewIdxs[i-1] = ConstantExpr::getAdd(PrevIdx, Div); } else { @@ -2049,26 +2221,43 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C, if (!NewIdxs.empty()) { for (unsigned i = 0, e = Idxs.size(); i != e; ++i) if (!NewIdxs[i]) NewIdxs[i] = cast(Idxs[i]); - return ConstantExpr::getGetElementPtr(C, NewIdxs, inBounds); + return ConstantExpr::getGetElementPtr(PointeeTy, C, NewIdxs, inBounds); } // If all indices are known integers and normalized, we can do a simple // check for the "inbounds" property. - if (!Unknown && !inBounds && - isa(C) && isInBoundsIndices(Idxs)) - return ConstantExpr::getInBoundsGetElementPtr(C, Idxs); + if (!Unknown && !inBounds) + if (auto *GV = dyn_cast(C)) + if (!GV->hasExternalWeakLinkage() && isInBoundsIndices(Idxs)) + return ConstantExpr::getInBoundsGetElementPtr(PointeeTy, C, Idxs); - return 0; + return nullptr; } Constant *llvm::ConstantFoldGetElementPtr(Constant *C, bool inBounds, ArrayRef Idxs) { - return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs); + return ConstantFoldGetElementPtrImpl( + cast(C->getType()->getScalarType())->getElementType(), C, + inBounds, Idxs); } Constant *llvm::ConstantFoldGetElementPtr(Constant *C, bool inBounds, ArrayRef Idxs) { - return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs); + return ConstantFoldGetElementPtrImpl( + cast(C->getType()->getScalarType())->getElementType(), C, + inBounds, Idxs); +} + +Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C, + bool inBounds, + ArrayRef Idxs) { + return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs); +} + +Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C, + bool inBounds, + ArrayRef Idxs) { + return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs); }