IR: Allow vectors of halfs to be ConstantDataVectors
[oota-llvm.git] / unittests / IR / ConstantsTest.cpp
index db783f72d46760f8b3e2d3cdf40e511780fb847c..8c33453d293dfb6432bf7fabaf7caf219a2c4bed 100644 (file)
@@ -7,12 +7,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 namespace llvm {
@@ -151,16 +153,18 @@ TEST(ConstantsTest, PointerCast) {
               Constant::getNullValue(Int8PtrVecTy), Int32PtrVecTy));
 }
 
-#define CHECK(x, y) {                                           \
-    std::string __s;                                            \
-    raw_string_ostream __o(__s);                                \
-    cast<ConstantExpr>(x)->getAsInstruction()->print(__o);      \
-    __o.flush();                                                \
-    EXPECT_EQ(std::string("  <badref> = " y), __s);             \
+#define CHECK(x, y) {                                                  \
+    std::string __s;                                                   \
+    raw_string_ostream __o(__s);                                       \
+    Instruction *__I = cast<ConstantExpr>(x)->getAsInstruction();      \
+    __I->print(__o);                                                   \
+    delete __I;                                                        \
+    __o.flush();                                                       \
+    EXPECT_EQ(std::string("  <badref> = " y), __s);                    \
   }
 
 TEST(ConstantsTest, AsInstructionsTest) {
-  Module *M = new Module("MyModule", getGlobalContext());
+  std::unique_ptr<Module> M(new Module("MyModule", getGlobalContext()));
 
   Type *Int64Ty = Type::getInt64Ty(getGlobalContext());
   Type *Int32Ty = Type::getInt32Ty(getGlobalContext());
@@ -183,6 +187,13 @@ TEST(ConstantsTest, AsInstructionsTest) {
   Constant *P6 = ConstantExpr::getBitCast(P4, VectorType::get(Int16Ty, 2));
 
   Constant *One = ConstantInt::get(Int32Ty, 1);
+  Constant *Two = ConstantInt::get(Int64Ty, 2);
+  Constant *Big = ConstantInt::get(getGlobalContext(),
+                                   APInt{256, uint64_t(-1), true});
+  Constant *Elt = ConstantInt::get(Int16Ty, 2015);
+  Constant *Undef16  = UndefValue::get(Int16Ty);
+  Constant *Undef64  = UndefValue::get(Int64Ty);
+  Constant *UndefV16 = UndefValue::get(P6->getType());
 
   #define P0STR "ptrtoint (i32** @dummy to i32)"
   #define P1STR "uitofp (i32 ptrtoint (i32** @dummy to i32) to float)"
@@ -244,15 +255,160 @@ TEST(ConstantsTest, AsInstructionsTest) {
   // FIXME: getGetElementPtr() actually creates an inbounds ConstantGEP,
   //        not a normal one!
   //CHECK(ConstantExpr::getGetElementPtr(Global, V, false),
-  //      "getelementptr i32** @dummy, i32 1");
-  CHECK(ConstantExpr::getInBoundsGetElementPtr(Global, V),
-        "getelementptr inbounds i32** @dummy, i32 1");
+  //      "getelementptr i32*, i32** @dummy, i32 1");
+  CHECK(ConstantExpr::getInBoundsGetElementPtr(PointerType::getUnqual(Int32Ty),
+                                               Global, V),
+        "getelementptr inbounds i32*, i32** @dummy, i32 1");
 
   CHECK(ConstantExpr::getExtractElement(P6, One), "extractelement <2 x i16> "
         P6STR ", i32 1");
+
+  EXPECT_EQ(Undef16, ConstantExpr::getExtractElement(P6, Two));
+  EXPECT_EQ(Undef16, ConstantExpr::getExtractElement(P6, Big));
+  EXPECT_EQ(Undef16, ConstantExpr::getExtractElement(P6, Undef64));
+
+  EXPECT_EQ(Elt, ConstantExpr::getExtractElement(
+                 ConstantExpr::getInsertElement(P6, Elt, One), One));
+  EXPECT_EQ(UndefV16, ConstantExpr::getInsertElement(P6, Elt, Two));
+  EXPECT_EQ(UndefV16, ConstantExpr::getInsertElement(P6, Elt, Big));
+  EXPECT_EQ(UndefV16, ConstantExpr::getInsertElement(P6, Elt, Undef64));
+}
+
+#ifdef GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(ConstantsTest, ReplaceWithConstantTest) {
+  std::unique_ptr<Module> M(new Module("MyModule", getGlobalContext()));
+
+  Type *Int32Ty = Type::getInt32Ty(getGlobalContext());
+  Constant *One = ConstantInt::get(Int32Ty, 1);
+
+  Constant *Global =
+      M->getOrInsertGlobal("dummy", PointerType::getUnqual(Int32Ty));
+  Constant *GEP = ConstantExpr::getGetElementPtr(
+      PointerType::getUnqual(Int32Ty), Global, One);
+  EXPECT_DEATH(Global->replaceAllUsesWith(GEP),
+               "this->replaceAllUsesWith\\(expr\\(this\\)\\) is NOT valid!");
 }
 
+#endif
+#endif
+
 #undef CHECK
 
+TEST(ConstantsTest, ConstantArrayReplaceWithConstant) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M(new Module("MyModule", Context));
+
+  Type *IntTy = Type::getInt8Ty(Context);
+  ArrayType *ArrayTy = ArrayType::get(IntTy, 2);
+  Constant *A01Vals[2] = {ConstantInt::get(IntTy, 0),
+                          ConstantInt::get(IntTy, 1)};
+  Constant *A01 = ConstantArray::get(ArrayTy, A01Vals);
+
+  Constant *Global = new GlobalVariable(*M, IntTy, false,
+                                        GlobalValue::ExternalLinkage, nullptr);
+  Constant *GlobalInt = ConstantExpr::getPtrToInt(Global, IntTy);
+  Constant *A0GVals[2] = {ConstantInt::get(IntTy, 0), GlobalInt};
+  Constant *A0G = ConstantArray::get(ArrayTy, A0GVals);
+  ASSERT_NE(A01, A0G);
+
+  GlobalVariable *RefArray =
+      new GlobalVariable(*M, ArrayTy, false, GlobalValue::ExternalLinkage, A0G);
+  ASSERT_EQ(A0G, RefArray->getInitializer());
+
+  GlobalInt->replaceAllUsesWith(ConstantInt::get(IntTy, 1));
+  ASSERT_EQ(A01, RefArray->getInitializer());
+}
+
+TEST(ConstantsTest, ConstantExprReplaceWithConstant) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M(new Module("MyModule", Context));
+
+  Type *IntTy = Type::getInt8Ty(Context);
+  Constant *G1 = new GlobalVariable(*M, IntTy, false,
+                                    GlobalValue::ExternalLinkage, nullptr);
+  Constant *G2 = new GlobalVariable(*M, IntTy, false,
+                                    GlobalValue::ExternalLinkage, nullptr);
+  ASSERT_NE(G1, G2);
+
+  Constant *Int1 = ConstantExpr::getPtrToInt(G1, IntTy);
+  Constant *Int2 = ConstantExpr::getPtrToInt(G2, IntTy);
+  ASSERT_NE(Int1, Int2);
+
+  GlobalVariable *Ref =
+      new GlobalVariable(*M, IntTy, false, GlobalValue::ExternalLinkage, Int1);
+  ASSERT_EQ(Int1, Ref->getInitializer());
+
+  G1->replaceAllUsesWith(G2);
+  ASSERT_EQ(Int2, Ref->getInitializer());
+}
+
+TEST(ConstantsTest, GEPReplaceWithConstant) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M(new Module("MyModule", Context));
+
+  Type *IntTy = Type::getInt32Ty(Context);
+  Type *PtrTy = PointerType::get(IntTy, 0);
+  auto *C1 = ConstantInt::get(IntTy, 1);
+  auto *Placeholder = new GlobalVariable(
+      *M, IntTy, false, GlobalValue::ExternalWeakLinkage, nullptr);
+  auto *GEP = ConstantExpr::getGetElementPtr(IntTy, Placeholder, C1);
+  ASSERT_EQ(GEP->getOperand(0), Placeholder);
+
+  auto *Ref =
+      new GlobalVariable(*M, PtrTy, false, GlobalValue::ExternalLinkage, GEP);
+  ASSERT_EQ(GEP, Ref->getInitializer());
+
+  auto *Global = new GlobalVariable(*M, PtrTy, false,
+                                    GlobalValue::ExternalLinkage, nullptr);
+  auto *Alias = GlobalAlias::create(IntTy, 0, GlobalValue::ExternalLinkage,
+                                    "alias", Global, M.get());
+  Placeholder->replaceAllUsesWith(Alias);
+  ASSERT_EQ(GEP, Ref->getInitializer());
+  ASSERT_EQ(GEP->getOperand(0), Alias);
+}
+
+TEST(ConstantsTest, AliasCAPI) {
+  LLVMContext Context;
+  SMDiagnostic Error;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("@g = global i32 42", Error, Context);
+  GlobalVariable *G = M->getGlobalVariable("g");
+  Type *I16Ty = Type::getInt16Ty(Context);
+  Type *I16PTy = PointerType::get(I16Ty, 0);
+  Constant *Aliasee = ConstantExpr::getBitCast(G, I16PTy);
+  LLVMValueRef AliasRef =
+      LLVMAddAlias(wrap(M.get()), wrap(I16PTy), wrap(Aliasee), "a");
+  ASSERT_EQ(unwrap<GlobalAlias>(AliasRef)->getAliasee(), Aliasee);
+}
+
+static std::string getNameOfType(Type *T) {
+  std::string S;
+  raw_string_ostream RSOS(S);
+  T->print(RSOS);
+  return S;
+}
+
+TEST(ConstantsTest, BuildConstantDataVectors) {
+  LLVMContext Context;
+  std::unique_ptr<Module> M(new Module("MyModule", Context));
+
+  for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context),
+                  Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) {
+    Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)};
+    Constant *CDV = ConstantVector::get(Vals);
+    ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
+        << " T = " << getNameOfType(T);
+  }
+
+  for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context),
+                  Type::getDoubleTy(Context)}) {
+    Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)};
+    Constant *CDV = ConstantVector::get(Vals);
+    ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
+        << " T = " << getNameOfType(T);
+  }
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm