When a function takes a variable number of pointer arguments, with a zero
[oota-llvm.git] / lib / Transforms / IPO / SimplifyLibCalls.cpp
index ae76a4d11bc80a1273cee59c6ca045e3873dd891..63ab333430c8ed5df59c238ff32e7e1381523561 100644 (file)
@@ -311,7 +311,8 @@ public:
     if (!memcpy_func) {
       const Type *SBP = PointerType::get(Type::SByteTy);
       memcpy_func = M->getOrInsertFunction("llvm.memcpy", Type::VoidTy,SBP, SBP,
-                                           Type::UIntTy, Type::UIntTy, 0);
+                                           Type::UIntTy, Type::UIntTy,
+                                           (Type *)0);
     }
     return memcpy_func;
   }
@@ -319,7 +320,7 @@ public:
   Function* get_floorf() {
     if (!floorf_func)
       floorf_func = M->getOrInsertFunction("floorf", Type::FloatTy,
-                                           Type::FloatTy, 0);
+                                           Type::FloatTy, (Type *)0);
     return floorf_func;
   }
   
@@ -383,7 +384,6 @@ struct ExitInMainOptimization : public LibCallOptimization
 {
   ExitInMainOptimization() : LibCallOptimization("exit",
       "Number of 'exit' calls simplified") {}
-  virtual ~ExitInMainOptimization() {}
 
   // Make sure the called function looks like exit (int argument, int return
   // type, external linkage, not varargs).
@@ -451,8 +451,6 @@ public:
       "Number of 'strcat' calls simplified") {}
 
 public:
-  /// @breif  Destructor
-  virtual ~StrCatOptimization() {}
 
   /// @brief Make sure that the "strcat" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
@@ -540,7 +538,6 @@ struct StrChrOptimization : public LibCallOptimization
 public:
   StrChrOptimization() : LibCallOptimization("strchr",
       "Number of 'strchr' calls simplified") {}
-  virtual ~StrChrOptimization() {}
 
   /// @brief Make sure that the "strchr" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
@@ -632,7 +629,6 @@ struct StrCmpOptimization : public LibCallOptimization
 public:
   StrCmpOptimization() : LibCallOptimization("strcmp",
       "Number of 'strcmp' calls simplified") {}
-  virtual ~StrCmpOptimization() {}
 
   /// @brief Make sure that the "strcmp" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
@@ -719,7 +715,6 @@ struct StrNCmpOptimization : public LibCallOptimization
 public:
   StrNCmpOptimization() : LibCallOptimization("strncmp",
       "Number of 'strncmp' calls simplified") {}
-  virtual ~StrNCmpOptimization() {}
 
   /// @brief Make sure that the "strncmp" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
@@ -822,7 +817,6 @@ struct StrCpyOptimization : public LibCallOptimization
 public:
   StrCpyOptimization() : LibCallOptimization("strcpy",
       "Number of 'strcpy' calls simplified") {}
-  virtual ~StrCpyOptimization() {}
 
   /// @brief Make sure that the "strcpy" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
@@ -910,7 +904,6 @@ struct StrLenOptimization : public LibCallOptimization
 {
   StrLenOptimization() : LibCallOptimization("strlen",
       "Number of 'strlen' calls simplified") {}
-  virtual ~StrLenOptimization() {}
 
   /// @brief Make sure that the "strlen" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
@@ -976,6 +969,127 @@ struct StrLenOptimization : public LibCallOptimization
   }
 } StrLenOptimizer;
 
+/// IsOnlyUsedInEqualsComparison - Return true if it only matters that the value
+/// is equal or not-equal to zero. 
+static bool IsOnlyUsedInEqualsZeroComparison(Instruction *I) {
+  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
+       UI != E; ++UI) {
+    Instruction *User = cast<Instruction>(*UI);
+    if (User->getOpcode() == Instruction::SetNE ||
+        User->getOpcode() == Instruction::SetEQ) {
+      if (isa<Constant>(User->getOperand(1)) && 
+          cast<Constant>(User->getOperand(1))->isNullValue())
+        continue;
+    } else if (CastInst *CI = dyn_cast<CastInst>(User))
+      if (CI->getType() == Type::BoolTy)
+        continue;
+    // Unknown instruction.
+    return false;
+  }
+  return true;
+}
+
+/// This memcmpOptimization will simplify a call to the memcmp library
+/// function.
+struct memcmpOptimization : public LibCallOptimization {
+  /// @brief Default Constructor
+  memcmpOptimization()
+    : LibCallOptimization("memcmp", "Number of 'memcmp' calls simplified") {}
+  
+  /// @brief Make sure that the "memcmp" function has the right prototype
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &TD) {
+    Function::const_arg_iterator AI = F->arg_begin();
+    if (F->arg_size() != 3 || !isa<PointerType>(AI->getType())) return false;
+    if (!isa<PointerType>((++AI)->getType())) return false;
+    if (!(++AI)->getType()->isInteger()) return false;
+    if (!F->getReturnType()->isInteger()) return false;
+    return true;
+  }
+  
+  /// Because of alignment and instruction information that we don't have, we
+  /// leave the bulk of this to the code generators.
+  ///
+  /// Note that we could do much more if we could force alignment on otherwise
+  /// small aligned allocas, or if we could indicate that loads have a small
+  /// alignment.
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &TD) {
+    Value *LHS = CI->getOperand(1), *RHS = CI->getOperand(2);
+
+    // If the two operands are the same, return zero.
+    if (LHS == RHS) {
+      // memcmp(s,s,x) -> 0
+      CI->replaceAllUsesWith(Constant::getNullValue(CI->getType()));
+      CI->eraseFromParent();
+      return true;
+    }
+    
+    // Make sure we have a constant length.
+    ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getOperand(3));
+    if (!LenC) return false;
+    uint64_t Len = LenC->getRawValue();
+      
+    // If the length is zero, this returns 0.
+    switch (Len) {
+    case 0:
+      // memcmp(s1,s2,0) -> 0
+      CI->replaceAllUsesWith(Constant::getNullValue(CI->getType()));
+      CI->eraseFromParent();
+      return true;
+    case 1: {
+      // memcmp(S1,S2,1) -> *(ubyte*)S1 - *(ubyte*)S2
+      const Type *UCharPtr = PointerType::get(Type::UByteTy);
+      CastInst *Op1Cast = new CastInst(LHS, UCharPtr, LHS->getName(), CI);
+      CastInst *Op2Cast = new CastInst(RHS, UCharPtr, RHS->getName(), CI);
+      Value *S1V = new LoadInst(Op1Cast, LHS->getName()+".val", CI);
+      Value *S2V = new LoadInst(Op2Cast, RHS->getName()+".val", CI);
+      Value *RV = BinaryOperator::createSub(S1V, S2V, CI->getName()+".diff",CI);
+      if (RV->getType() != CI->getType())
+        RV = new CastInst(RV, CI->getType(), RV->getName(), CI);
+      CI->replaceAllUsesWith(RV);
+      CI->eraseFromParent();
+      return true;
+    }
+    case 2:
+      if (IsOnlyUsedInEqualsZeroComparison(CI)) {
+        // TODO: IF both are aligned, use a short load/compare.
+      
+        // memcmp(S1,S2,2) -> S1[0]-S2[0] | S1[1]-S2[1] iff only ==/!= 0 matters
+        const Type *UCharPtr = PointerType::get(Type::UByteTy);
+        CastInst *Op1Cast = new CastInst(LHS, UCharPtr, LHS->getName(), CI);
+        CastInst *Op2Cast = new CastInst(RHS, UCharPtr, RHS->getName(), CI);
+        Value *S1V1 = new LoadInst(Op1Cast, LHS->getName()+".val1", CI);
+        Value *S2V1 = new LoadInst(Op2Cast, RHS->getName()+".val1", CI);
+        Value *D1 = BinaryOperator::createSub(S1V1, S2V1,
+                                              CI->getName()+".d1", CI);
+        Constant *One = ConstantInt::get(Type::IntTy, 1);
+        Value *G1 = new GetElementPtrInst(Op1Cast, One, "next1v", CI);
+        Value *G2 = new GetElementPtrInst(Op2Cast, One, "next2v", CI);
+        Value *S1V2 = new LoadInst(G1, LHS->getName()+".val2", CI);
+        Value *S2V2 = new LoadInst(G1, RHS->getName()+".val2", CI);
+        Value *D2 = BinaryOperator::createSub(S1V2, S2V2,
+                                              CI->getName()+".d1", CI);
+        Value *Or = BinaryOperator::createOr(D1, D2, CI->getName()+".res", CI);
+        if (Or->getType() != CI->getType())
+          Or = new CastInst(Or, CI->getType(), Or->getName(), CI);
+        CI->replaceAllUsesWith(Or);
+        CI->eraseFromParent();
+        return true;
+      }
+      break;
+    default:
+      break;
+    }
+    
+    
+    
+    return false;
+  }
+} memcmpOptimizer;
+
+
+
+
+
 /// This LibCallOptimization will simplify a call to the memcpy library
 /// function by expanding it out to a single store of size 0, 1, 2, 4, or 8
 /// bytes depending on the length of the string and the alignment. Additional
@@ -992,8 +1106,6 @@ protected:
   LLVMMemCpyOptimization(const char* fname, const char* desc)
     : LibCallOptimization(fname, desc) {}
 public:
-  /// @brief Destructor
-  virtual ~LLVMMemCpyOptimization() {}
 
   /// @brief Make sure that the "memcpy" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& TD)
@@ -1078,8 +1190,6 @@ struct LLVMMemSetOptimization : public LibCallOptimization
       "Number of 'llvm.memset' calls simplified") {}
 
 public:
-  /// @brief Destructor
-  virtual ~LLVMMemSetOptimization() {}
 
   /// @brief Make sure that the "memset" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& TD)
@@ -1186,9 +1296,6 @@ public:
   PowOptimization() : LibCallOptimization("pow",
       "Number of 'pow' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~PowOptimization() {}
-
   /// @brief Make sure that the "pow" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
@@ -1263,9 +1370,6 @@ public:
   FPrintFOptimization() : LibCallOptimization("fprintf",
       "Number of 'fprintf' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~FPrintFOptimization() {}
-
   /// @brief Make sure that the "fprintf" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
@@ -1399,9 +1503,6 @@ public:
   SPrintFOptimization() : LibCallOptimization("sprintf",
       "Number of 'sprintf' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~SPrintFOptimization() {}
-
   /// @brief Make sure that the "fprintf" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
@@ -1536,9 +1637,6 @@ public:
   PutsOptimization() : LibCallOptimization("fputs",
       "Number of 'fputs' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~PutsOptimization() {}
-
   /// @brief Make sure that the "fputs" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
@@ -1602,17 +1700,12 @@ public:
 /// This LibCallOptimization will simplify calls to the "isdigit" library
 /// function. It simply does range checks the parameter explicitly.
 /// @brief Simplify the isdigit library function.
-struct IsDigitOptimization : public LibCallOptimization
-{
+struct isdigitOptimization : public LibCallOptimization {
 public:
-  /// @brief Default Constructor
-  IsDigitOptimization() : LibCallOptimization("isdigit",
+  isdigitOptimization() : LibCallOptimization("isdigit",
       "Number of 'isdigit' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~IsDigitOptimization() {}
-
-  /// @brief Make sure that the "fputs" function has the right prototype
+  /// @brief Make sure that the "isdigit" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
     // Just make sure this has 1 argument
@@ -1651,7 +1744,35 @@ public:
     ci->eraseFromParent();
     return true;
   }
-} IsDigitOptimizer;
+} isdigitOptimizer;
+
+struct isasciiOptimization : public LibCallOptimization {
+public:
+  isasciiOptimization()
+    : LibCallOptimization("isascii", "Number of 'isascii' calls simplified") {}
+  
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+    return F->arg_size() == 1 && F->arg_begin()->getType()->isInteger() && 
+           F->getReturnType()->isInteger();
+  }
+  
+  /// @brief Perform the isascii optimization.
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
+    // isascii(c)   -> (unsigned)c < 128
+    Value *V = CI->getOperand(1);
+    if (V->getType()->isSigned())
+      V = new CastInst(V, V->getType()->getUnsignedVersion(), V->getName(), CI);
+    Value *Cmp = BinaryOperator::createSetLT(V, ConstantUInt::get(V->getType(),
+                                                                  128),
+                                             V->getName()+".isascii", CI);
+    if (Cmp->getType() != CI->getType())
+      Cmp = new CastInst(Cmp, CI->getType(), Cmp->getName(), CI);
+    CI->replaceAllUsesWith(Cmp);
+    CI->eraseFromParent();
+    return true;
+  }
+} isasciiOptimizer;
+
 
 /// This LibCallOptimization will simplify calls to the "toascii" library
 /// function. It simply does the corresponding and operation to restrict the
@@ -1664,9 +1785,6 @@ public:
   ToAsciiOptimization() : LibCallOptimization("toascii",
       "Number of 'toascii' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~ToAsciiOptimization() {}
-
   /// @brief Make sure that the "fputs" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
@@ -1705,9 +1823,6 @@ public:
   FFSOptimization() : LibCallOptimization("ffs",
       "Number of 'ffs' calls simplified") {}
 
-  /// @brief Destructor
-  virtual ~FFSOptimization() {}
-
   /// @brief Make sure that the "fputs" function has the right prototype
   virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
   {
@@ -1949,12 +2064,6 @@ Value *CastToCStr(Value *V, Instruction &IP) {
 // exp, expf, expl:
 //   * exp(log(x))  -> x
 //
-// isascii:
-//   * isascii(c)    -> ((c & ~0x7f) == 0)
-//
-// isdigit:
-//   * isdigit(c)    -> (unsigned)(c) - '0' <= 9
-//
 // log, logf, logl:
 //   * log(exp(x))   -> x
 //   * log(x**y)     -> y*log(x)
@@ -1968,11 +2077,8 @@ Value *CastToCStr(Value *V, Instruction &IP) {
 //   * lround(cnst) -> cnst'
 //
 // memcmp:
-//   * memcmp(s1,s2,0) -> 0
-//   * memcmp(x,x,l)   -> 0
 //   * memcmp(x,y,l)   -> cnst
 //      (if all arguments are constant and strlen(x) <= l and strlen(y) <= l)
-//   * memcmp(x,y,1)   -> *x - *y
 //
 // memmove:
 //   * memmove(d,s,l,a) -> memcpy(d,s,l,a)