Previously, all operands to Constant were themselves constant.
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 2c0a67f1d0435da860085a245830c643adc49e12..7f713d15c67a4f00dc30da0aada133324ab24f16 100644 (file)
@@ -215,7 +215,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
   switch (CE->getOpcode()) {
   default: return 0;
   case Instruction::Or: {
-    Constant *RHS = ExtractConstantBytes(C->getOperand(1), ByteStart, ByteSize);
+    Constant *RHS = ExtractConstantBytes(CE->getOperand(1), ByteStart,ByteSize);
     if (RHS == 0)
       return 0;
     
@@ -224,13 +224,13 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
       if (RHSC->isAllOnesValue())
         return RHSC;
     
-    Constant *LHS = ExtractConstantBytes(C->getOperand(0), ByteStart, ByteSize);
+    Constant *LHS = ExtractConstantBytes(CE->getOperand(0), ByteStart,ByteSize);
     if (LHS == 0)
       return 0;
     return ConstantExpr::getOr(LHS, RHS);
   }
   case Instruction::And: {
-    Constant *RHS = ExtractConstantBytes(C->getOperand(1), ByteStart, ByteSize);
+    Constant *RHS = ExtractConstantBytes(CE->getOperand(1), ByteStart,ByteSize);
     if (RHS == 0)
       return 0;
     
@@ -238,7 +238,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
     if (RHS->isNullValue())
       return RHS;
     
-    Constant *LHS = ExtractConstantBytes(C->getOperand(0), ByteStart, ByteSize);
+    Constant *LHS = ExtractConstantBytes(CE->getOperand(0), ByteStart,ByteSize);
     if (LHS == 0)
       return 0;
     return ConstantExpr::getAnd(LHS, RHS);
@@ -259,7 +259,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
                                                      ByteSize*8));
     // If the extract is known to be fully in the input, extract it.
     if (ByteStart+ByteSize+ShAmt <= CSize)
-      return ExtractConstantBytes(C->getOperand(0), ByteStart+ShAmt, ByteSize);
+      return ExtractConstantBytes(CE->getOperand(0), ByteStart+ShAmt, ByteSize);
     
     // TODO: Handle the 'partially zero' case.
     return 0;
@@ -281,7 +281,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
                                                      ByteSize*8));
     // If the extract is known to be fully in the input, extract it.
     if (ByteStart >= ShAmt)
-      return ExtractConstantBytes(C->getOperand(0), ByteStart-ShAmt, ByteSize);
+      return ExtractConstantBytes(CE->getOperand(0), ByteStart-ShAmt, ByteSize);
     
     // TODO: Handle the 'partially zero' case.
     return 0;
@@ -289,7 +289,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
       
   case Instruction::ZExt: {
     unsigned SrcBitSize =
-      cast<IntegerType>(C->getOperand(0)->getType())->getBitWidth();
+      cast<IntegerType>(CE->getOperand(0)->getType())->getBitWidth();
     
     // If extracting something that is completely zero, return 0.
     if (ByteStart*8 >= SrcBitSize)
@@ -298,18 +298,18 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
 
     // If exactly extracting the input, return it.
     if (ByteStart == 0 && ByteSize*8 == SrcBitSize)
-      return C->getOperand(0);
+      return CE->getOperand(0);
     
     // If extracting something completely in the input, if if the input is a
     // multiple of 8 bits, recurse.
     if ((SrcBitSize&7) == 0 && (ByteStart+ByteSize)*8 <= SrcBitSize)
-      return ExtractConstantBytes(C->getOperand(0), ByteStart, ByteSize);
+      return ExtractConstantBytes(CE->getOperand(0), ByteStart, ByteSize);
       
     // Otherwise, if extracting a subset of the input, which is not multiple of
     // 8 bits, do a shift and trunc to get the bits.
     if ((ByteStart+ByteSize)*8 < SrcBitSize) {
       assert((SrcBitSize&7) && "Shouldn't get byte sized case here");
-      Constant *Res = C->getOperand(0);
+      Constant *Res = CE->getOperand(0);
       if (ByteStart)
         Res = ConstantExpr::getLShr(Res, 
                                  ConstantInt::get(Res->getType(), ByteStart*8));
@@ -634,7 +634,15 @@ Constant *llvm::ConstantFoldExtractValueInstruction(LLVMContext &Context,
                                                               Idxs + NumIdx));
 
   // Otherwise recurse.
-  return ConstantFoldExtractValueInstruction(Context, Agg->getOperand(*Idxs),
+  if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Agg))
+    return ConstantFoldExtractValueInstruction(Context, CS->getOperand(*Idxs),
+                                               Idxs+1, NumIdx-1);
+
+  if (ConstantArray *CA = dyn_cast<ConstantArray>(Agg))
+    return ConstantFoldExtractValueInstruction(Context, CA->getOperand(*Idxs),
+                                               Idxs+1, NumIdx-1);
+  ConstantVector *CV = cast<ConstantVector>(Agg);
+  return ConstantFoldExtractValueInstruction(Context, CV->getOperand(*Idxs),
                                              Idxs+1, NumIdx-1);
 }
 
@@ -714,11 +722,10 @@ Constant *llvm::ConstantFoldInsertValueInstruction(LLVMContext &Context,
     // Insertion of constant into aggregate constant.
     std::vector<Constant*> Ops(Agg->getNumOperands());
     for (unsigned i = 0; i < Agg->getNumOperands(); ++i) {
-      Constant *Op =
-        (*Idxs == i) ?
-        ConstantFoldInsertValueInstruction(Context, Agg->getOperand(i),
-                                           Val, Idxs+1, NumIdx-1) :
-        Agg->getOperand(i);
+      Constant *Op = cast<Constant>(Agg->getOperand(i));
+      if (*Idxs == i)
+        Op = ConstantFoldInsertValueInstruction(Context, Op,
+                                                Val, Idxs+1, NumIdx-1);
       Ops[i] = Op;
     }