Fix several nasty bugs in the strchr optimizer, this fixes
author Chris Lattner <sabre@nondot.org>
Fri, 6 Apr 2007 23:38:55 +0000 (23:38 +0000)
committer Chris Lattner <sabre@nondot.org>
Fri, 6 Apr 2007 23:38:55 +0000 (23:38 +0000)
SimplifyLibCalls/2007-04-06-strchr-miscompile.ll and PR1307

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@35706 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/IPO/SimplifyLibCalls.cpp

index eca78d72f8eae1f5d5b6bcb31f76247d7ae781db..17149e2ec64d0d4c6b0310244379ef544a32780d 100644 (file)
@@ -532,74 +532,80 @@ public:
       "Number of 'strchr' calls simplified") {}
 
   /// @brief Make sure that the "strchr" function has the right prototype
-  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){
-    if (f->getReturnType() == PointerType::get(Type::Int8Ty) &&
-        f->arg_size() == 2)
-      return true;
-    return false;
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+    const FunctionType *FT = F->getFunctionType();
+    return FT->getNumParams() == 2 &&
+           FT->getReturnType() == PointerType::get(Type::Int8Ty) &&
+           FT->getParamType(0) == FT->getReturnType() &&
+           isa<IntegerType>(FT->getParamType(1));
   }
 
   /// @brief Perform the strchr optimizations
-  virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) {
-    // If there aren't three operands, bail
-    if (ci->getNumOperands() != 3)
-      return false;
-
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
     // Check that the first argument to strchr is a constant array of sbyte.
     // If it is, get the length and data, otherwise return false.
-    uint64_t len, StartIdx;
-    ConstantArray* CA = 0;
-    if (!GetConstantStringInfo(ci->getOperand(1), CA, len, StartIdx))
+    uint64_t StrLength, StartIdx;
+    ConstantArray *CA = 0;
+    if (!GetConstantStringInfo(CI->getOperand(1), CA, StrLength, StartIdx))
       return false;
 
-    // Check that the second argument to strchr is a constant int. If it isn't
-    // a constant integer, we can try an alternate optimization
-    ConstantInt* CSI = dyn_cast<ConstantInt>(ci->getOperand(2));
+    // If the second operand is not constant, just lower this to memchr since we
+    // know the length of the input string.
+    ConstantInt *CSI = dyn_cast<ConstantInt>(CI->getOperand(2));
     if (!CSI) {
-      // The second operand is not constant just lower this to 
-      // memchr since we know the length of the string since it is constant.
-      Constant *f = SLC.get_memchr();
-      Value* args[3] = {
-        ci->getOperand(1),
-        ci->getOperand(2),
-        ConstantInt::get(SLC.getIntPtrType(), len)
+      Value *Args[3] = {
+        CI->getOperand(1),
+        CI->getOperand(2),
+        ConstantInt::get(SLC.getIntPtrType(), StrLength+1)
       };
-      ci->replaceAllUsesWith(new CallInst(f, args, 3, ci->getName(), ci));
-      ci->eraseFromParent();
+      CI->replaceAllUsesWith(new CallInst(SLC.get_memchr(), Args, 3,
+                                          CI->getName(), CI));
+      CI->eraseFromParent();
       return true;
     }
 
     // Get the character we're looking for
-    int64_t chr = CSI->getSExtValue();
-
+    int64_t CharValue = CSI->getSExtValue();
+
+    if (StrLength == 0) {
+      // If the length of the string is zero, and we are searching for zero,
+      // return the input pointer.
+      if (CharValue == 0) {
+        CI->replaceAllUsesWith(CI->getOperand(1));
+      } else {
+        // Otherwise, char wasn't found.
+        CI->replaceAllUsesWith(Constant::getNullValue(CI->getType()));
+      }
+      CI->eraseFromParent();
+      return true;
+    }
+    
     // Compute the offset
-    uint64_t offset = 0;
-    bool char_found = false;
-    for (uint64_t i = 0; i < len; ++i) {
-      if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i))) {
-        // Check for the null terminator
-        if (CI->isZero())
-          break; // we found end of string
-        else if (CI->getSExtValue() == chr) {
-          char_found = true;
-          offset = i;
+    uint64_t i = 0;
+    while (1) {
+      assert(i <= StrLength && "Didn't find null terminator?");
+      if (ConstantInt *C = dyn_cast<ConstantInt>(CA->getOperand(i+StartIdx))) {
+        // Did we find our match?
+        if (C->getSExtValue() == CharValue)
           break;
+        if (C->isZero()) {
+          // We found the end of the string.  strchr returns null.
+          CI->replaceAllUsesWith(Constant::getNullValue(CI->getType()));
+          CI->eraseFromParent();
+          return true;
         }
       }
+      ++i;
     }
 
-    // strchr(s,c)  -> offset_of_in(c,s)
+    // strchr(s+n,c)  -> gep(s+n+i,c)
     //    (if c is a constant integer and s is a constant string)
-    if (char_found) {
-      Value* Idx = ConstantInt::get(Type::Int64Ty,offset);
-      GetElementPtrInst* GEP = new GetElementPtrInst(ci->getOperand(1), Idx, 
-          ci->getOperand(1)->getName()+".strchr",ci);
-      ci->replaceAllUsesWith(GEP);
-    } else {
-      ci->replaceAllUsesWith(
-          ConstantPointerNull::get(PointerType::get(Type::Int8Ty)));
-    }
-    ci->eraseFromParent();
+    Value *Idx = ConstantInt::get(Type::Int64Ty, i);
+    Value *GEP = new GetElementPtrInst(CI->getOperand(1), Idx, 
+                                       CI->getOperand(1)->getName() +
+                                       ".strchr", CI);
+    CI->replaceAllUsesWith(GEP);
+    CI->eraseFromParent();
     return true;
   }
 } StrChrOptimizer;