Support/FileSystem: Implement recursive_directory_iterator and make
[oota-llvm.git] / lib / Support / Windows / PathV2.inc
index 6bd541e49cc885203702f9814f240bde04dda1aa..3872512e4faea7afb258495052d4d89ca3c36041 100644 (file)
@@ -1,4 +1,4 @@
-//===- llvm/Support/Win32/PathV2.cpp - Windows Path Impl --------*- C++ -*-===//
+//===- llvm/Support/Windows/PathV2.inc - Windows Path Impl ------*- C++ -*-===//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -17,7 +17,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "Windows.h"
-#include <WinCrypt.h>
+#include <wincrypt.h>
 #include <fcntl.h>
 #include <io.h>
 #include <sys/stat.h>
@@ -41,8 +41,7 @@ namespace {
     ::GetProcAddress(::GetModuleHandleA("kernel32.dll"),
                      "CreateSymbolicLinkW"));
 
-  error_code UTF8ToUTF16(const StringRef &utf8,
-                               SmallVectorImpl<wchar_t> &utf16) {
+  error_code UTF8ToUTF16(StringRef utf8, SmallVectorImpl<wchar_t> &utf16) {
     int len = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS,
                                     utf8.begin(), utf8.size(),
                                     utf16.begin(), 0);
@@ -118,14 +117,23 @@ namespace {
     return ::CryptReleaseContext(Provider, 0);
   }
 
-  typedef ScopedHandle<HCRYPTPROV, HCRYPTPROV(INVALID_HANDLE_VALUE),
+  typedef ScopedHandle<HCRYPTPROV, uintptr_t(-1),
                        BOOL (WINAPI*)(HCRYPTPROV), CryptReleaseContext>
     ScopedCryptContext;
+  bool is_separator(const wchar_t value) {
+    switch (value) {
+    case L'\\':
+    case L'/':
+      return true;
+    default:
+      return false;
+    }
+  }
 }
 
 namespace llvm {
 namespace sys  {
-namespace path {
+namespace fs {
 
 error_code current_path(SmallVectorImpl<char> &result) {
   SmallVector<wchar_t, 128> cur_path;
@@ -171,10 +179,6 @@ retry_cur_dir:
   return success;
 }
 
-} // end namespace path
-
-namespace fs {
-
 error_code copy_file(const Twine &from, const Twine &to, copy_option copt) {
   // Get arguments.
   SmallString<128> from_storage;
@@ -264,17 +268,31 @@ error_code remove(const Twine &path, bool &existed) {
   SmallString<128> path_storage;
   SmallVector<wchar_t, 128> path_utf16;
 
+  file_status st;
+  if (error_code ec = status(path, st))
+    return ec;
+
   if (error_code ec = UTF8ToUTF16(path.toStringRef(path_storage),
                                   path_utf16))
     return ec;
 
-  if (!::DeleteFileW(path_utf16.begin())) {
-    error_code ec = windows_error(::GetLastError());
-    if (ec != windows_error::file_not_found)
-      return ec;
-    existed = false;
-  } else
-    existed = true;
+  if (st.type() == file_type::directory_file) {
+    if (!::RemoveDirectoryW(c_str(path_utf16))) {
+      error_code ec = windows_error(::GetLastError());
+      if (ec != windows_error::file_not_found)
+        return ec;
+      existed = false;
+    } else
+      existed = true;
+  } else {
+    if (!::DeleteFileW(c_str(path_utf16))) {
+      error_code ec = windows_error(::GetLastError());
+      if (ec != windows_error::file_not_found)
+        return ec;
+      existed = false;
+    } else
+      existed = true;
+  }
 
   return success;
 }
@@ -292,7 +310,8 @@ error_code rename(const Twine &from, const Twine &to) {
   if (error_code ec = UTF8ToUTF16(f, wide_from)) return ec;
   if (error_code ec = UTF8ToUTF16(t, wide_to)) return ec;
 
-  if (!::MoveFileW(wide_from.begin(), wide_to.begin()))
+  if (!::MoveFileExW(wide_from.begin(), wide_to.begin(),
+                     MOVEFILE_COPY_ALLOWED | MOVEFILE_REPLACE_EXISTING))
     return windows_error(::GetLastError());
 
   return success;
@@ -426,11 +445,40 @@ error_code file_size(const Twine &path, uint64_t &result) {
   return success;
 }
 
+static bool isReservedName(StringRef path) {
+  // This list of reserved names comes from MSDN, at:
+  // http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx
+  static const char *sReservedNames[] = { "nul", "con", "prn", "aux",
+                              "com1", "com2", "com3", "com4", "com5", "com6",
+                              "com7", "com8", "com9", "lpt1", "lpt2", "lpt3",
+                              "lpt4", "lpt5", "lpt6", "lpt7", "lpt8", "lpt9" };
+
+  // First, check to see if this is a device namespace, which always
+  // starts with \\.\, since device namespaces are not legal file paths.
+  if (path.startswith("\\\\.\\"))
+    return true;
+
+  // Then compare against the list of ancient reserved names
+  for (size_t i = 0; i < sizeof(sReservedNames) / sizeof(const char *); ++i) {
+    if (path.equals_lower(sReservedNames[i]))
+      return true;
+  }
+
+  // The path isn't what we consider reserved.
+  return false;
+}
+
 error_code status(const Twine &path, file_status &result) {
   SmallString<128> path_storage;
   SmallVector<wchar_t, 128> path_utf16;
 
-  if (error_code ec = UTF8ToUTF16(path.toStringRef(path_storage),
+  StringRef path8 = path.toStringRef(path_storage);
+  if (isReservedName(path8)) {
+    result = file_status(file_type::character_file);
+    return success;
+  }
+
+  if (error_code ec = UTF8ToUTF16(path8,
                                   path_utf16))
     return ec;
 
@@ -475,7 +523,8 @@ handle_status_error:
 }
 
 error_code unique_file(const Twine &model, int &result_fd,
-                             SmallVectorImpl<char> &result_path) {
+                             SmallVectorImpl<char> &result_path,
+                             bool makeAbsolute) {
   // Use result_path as temp storage.
   result_path.set_size(0);
   StringRef m = model.toStringRef(result_path);
@@ -483,18 +532,19 @@ error_code unique_file(const Twine &model, int &result_fd,
   SmallVector<wchar_t, 128> model_utf16;
   if (error_code ec = UTF8ToUTF16(m, model_utf16)) return ec;
 
-  // Make model absolute by prepending a temp directory if it's not already.
-  bool absolute;
-  if (error_code ec = path::is_absolute(m, absolute)) return ec;
+  if (makeAbsolute) {
+    // Make model absolute by prepending a temp directory if it's not already.
+    bool absolute = path::is_absolute(m);
 
-  if (!absolute) {
-    SmallVector<wchar_t, 64> temp_dir;
-    if (error_code ec = TempDir(temp_dir)) return ec;
-    // Handle c: by removing it.
-    if (model_utf16.size() > 2 && model_utf16[1] == L':') {
-      model_utf16.erase(model_utf16.begin(), model_utf16.begin() + 2);
+    if (!absolute) {
+      SmallVector<wchar_t, 64> temp_dir;
+      if (error_code ec = TempDir(temp_dir)) return ec;
+      // Handle c: by removing it.
+      if (model_utf16.size() > 2 && model_utf16[1] == L':') {
+        model_utf16.erase(model_utf16.begin(), model_utf16.begin() + 2);
+      }
+      model_utf16.insert(model_utf16.begin(), temp_dir.begin(), temp_dir.end());
     }
-    model_utf16.insert(model_utf16.begin(), temp_dir.begin(), temp_dir.end());
   }
 
   // Replace '%' with random chars. From here on, DO NOT modify model. It may be
@@ -507,7 +557,7 @@ error_code unique_file(const Twine &model, int &result_fd,
                               NULL,
                               NULL,
                               PROV_RSA_FULL,
-                              0))
+                              CRYPT_VERIFYCONTEXT))
     return windows_error(::GetLastError());
   ScopedCryptContext CryptoProvider(HCPC);
 
@@ -555,7 +605,7 @@ retry_create_file:
       SmallString<64> dir_to_create;
       for (path::const_iterator i = path::begin(p),
                                 e = --path::end(p); i != e; ++i) {
-        if (error_code ec = path::append(dir_to_create, *i)) return ec;
+        path::append(dir_to_create, *i);
         bool Exists;
         if (error_code ec = exists(Twine(dir_to_create), Exists)) return ec;
         if (!Exists) {
@@ -598,6 +648,136 @@ retry_create_file:
   result_fd = fd;
   return success;
 }
+
+error_code get_magic(const Twine &path, uint32_t len,
+                     SmallVectorImpl<char> &result) {
+  SmallString<128> path_storage;
+  SmallVector<wchar_t, 128> path_utf16;
+  result.set_size(0);
+
+  // Convert path to UTF-16.
+  if (error_code ec = UTF8ToUTF16(path.toStringRef(path_storage),
+                                  path_utf16))
+    return ec;
+
+  // Open file.
+  HANDLE file = ::CreateFileW(c_str(path_utf16),
+                              GENERIC_READ,
+                              FILE_SHARE_READ,
+                              NULL,
+                              OPEN_EXISTING,
+                              FILE_ATTRIBUTE_READONLY,
+                              NULL);
+  if (file == INVALID_HANDLE_VALUE)
+    return windows_error(::GetLastError());
+
+  // Allocate buffer.
+  result.reserve(len);
+
+  // Get magic!
+  DWORD bytes_read = 0;
+  BOOL read_success = ::ReadFile(file, result.data(), len, &bytes_read, NULL);
+  error_code ec = windows_error(::GetLastError());
+  ::CloseHandle(file);
+  if (!read_success || (bytes_read != len)) {
+    // Set result size to the number of bytes read if it's valid.
+    if (bytes_read <= len)
+      result.set_size(bytes_read);
+    // ERROR_HANDLE_EOF is mapped to errc::value_too_large.
+    return ec;
+  }
+
+  result.set_size(len);
+  return success;
+}
+
+error_code detail::directory_iterator_construct(detail::DirIterState &it,
+                                                StringRef path){
+  SmallVector<wchar_t, 128> path_utf16;
+
+  if (error_code ec = UTF8ToUTF16(path,
+                                  path_utf16))
+    return ec;
+
+  // Convert path to the format that Windows is happy with.
+  if (path_utf16.size() > 0 &&
+      !is_separator(path_utf16[path.size() - 1]) &&
+      path_utf16[path.size() - 1] != L':') {
+    path_utf16.push_back(L'\\');
+    path_utf16.push_back(L'*');
+  } else {
+    path_utf16.push_back(L'*');
+  }
+
+  //  Get the first directory entry.
+  WIN32_FIND_DATAW FirstFind;
+  ScopedFindHandle FindHandle(::FindFirstFileW(c_str(path_utf16), &FirstFind));
+  if (!FindHandle)
+    return windows_error(::GetLastError());
+
+  size_t FilenameLen = ::wcslen(FirstFind.cFileName);
+  while ((FilenameLen == 1 && FirstFind.cFileName[0] == L'.') ||
+         (FilenameLen == 2 && FirstFind.cFileName[0] == L'.' &&
+                              FirstFind.cFileName[1] == L'.'))
+    if (!::FindNextFileW(FindHandle, &FirstFind)) {
+      error_code ec = windows_error(::GetLastError());
+      // Check for end.
+      if (ec == windows_error::no_more_files)
+        return detail::directory_iterator_destruct(it);
+      return ec;
+    } else
+      FilenameLen = ::wcslen(FirstFind.cFileName);
+
+  // Construct the current directory entry.
+  SmallString<128> directory_entry_name_utf8;
+  if (error_code ec = UTF16ToUTF8(FirstFind.cFileName,
+                                  ::wcslen(FirstFind.cFileName),
+                                  directory_entry_name_utf8))
+    return ec;
+
+  it.IterationHandle = intptr_t(FindHandle.take());
+  SmallString<128> directory_entry_path(path);
+  path::append(directory_entry_path, directory_entry_name_utf8.str());
+  it.CurrentEntry = directory_entry(directory_entry_path.str());
+
+  return success;
+}
+
+error_code detail::directory_iterator_destruct(detail::DirIterState &it) {
+  if (it.IterationHandle != 0)
+    // Closes the handle if it's valid.
+    ScopedFindHandle close(HANDLE(it.IterationHandle));
+  it.IterationHandle = 0;
+  it.CurrentEntry = directory_entry();
+  return success;
+}
+
+error_code detail::directory_iterator_increment(detail::DirIterState &it) {
+  WIN32_FIND_DATAW FindData;
+  if (!::FindNextFileW(HANDLE(it.IterationHandle), &FindData)) {
+    error_code ec = windows_error(::GetLastError());
+    // Check for end.
+    if (ec == windows_error::no_more_files)
+      return detail::directory_iterator_destruct(it);
+    return ec;
+  }
+
+  size_t FilenameLen = ::wcslen(FindData.cFileName);
+  if ((FilenameLen == 1 && FindData.cFileName[0] == L'.') ||
+      (FilenameLen == 2 && FindData.cFileName[0] == L'.' &&
+                           FindData.cFileName[1] == L'.'))
+    return directory_iterator_increment(it);
+
+  SmallString<128> directory_entry_path_utf8;
+  if (error_code ec = UTF16ToUTF8(FindData.cFileName,
+                                  ::wcslen(FindData.cFileName),
+                                  directory_entry_path_utf8))
+    return ec;
+
+  it.CurrentEntry.replace_filename(Twine(directory_entry_path_utf8));
+  return success;
+}
+
 } // end namespace fs
 } // end namespace sys
 } // end namespace llvm