From 2198e7c8df19bbe57942eaf1ef0cb5ca03182952 Mon Sep 17 00:00:00 2001 From: Tex Riddell Date: Thu, 29 May 2025 12:59:30 -0700 Subject: [PATCH] Fix unicode conversion bugs for *nix --- include/dxc/WinAdapter.h | 32 ++++-- lib/DxcSupport/Unicode.cpp | 102 +++++++++++++------- tools/clang/unittests/HLSL/CompilerTest.cpp | 66 ++++++------- 3 files changed, 124 insertions(+), 76 deletions(-) diff --git a/include/dxc/WinAdapter.h b/include/dxc/WinAdapter.h index d02ad1ac38..986149217f 100644 --- a/include/dxc/WinAdapter.h +++ b/include/dxc/WinAdapter.h @@ -916,19 +916,35 @@ unsigned int SysStringLen(const BSTR bstrString); // RAII style mechanism for setting/unsetting a locale for the specified Windows // codepage class ScopedLocale { - const char *m_prevLocale; + locale_t Utf8Locale = nullptr; + locale_t PrevLocale = nullptr; public: - explicit ScopedLocale(uint32_t codePage) - : m_prevLocale(setlocale(LC_ALL, nullptr)) { - assert((codePage == CP_UTF8) && + explicit ScopedLocale(uint32_t CodePage) { + assert((CodePage == CP_UTF8) && "Support for Linux only handles UTF8 code pages"); - setlocale(LC_ALL, "en_US.UTF-8"); + Utf8Locale = newlocale(LC_CTYPE_MASK, "C.UTF-8", NULL); + if (!Utf8Locale) + Utf8Locale = newlocale(LC_CTYPE_MASK, "C.utf8", NULL); + if (!Utf8Locale) + Utf8Locale = newlocale(LC_CTYPE_MASK, "en_US.UTF-8", NULL); + assert(Utf8Locale && "Failed to create UTF-8 locale"); + if (!Utf8Locale) + return; + PrevLocale = uselocale(Utf8Locale); + assert(PrevLocale && "Failed to set locale to UTF-8"); + if (!PrevLocale) { + freelocale(Utf8Locale); + Utf8Locale = nullptr; + } } ~ScopedLocale() { - if (m_prevLocale != nullptr) { - setlocale(LC_ALL, m_prevLocale); - } + if (PrevLocale != nullptr) + uselocale(PrevLocale); + if (Utf8Locale) + freelocale(Utf8Locale); + PrevLocale = nullptr; + Utf8Locale = nullptr; } }; diff --git a/lib/DxcSupport/Unicode.cpp b/lib/DxcSupport/Unicode.cpp index 1481ae27ff..c5a756dfce 100644 --- a/lib/DxcSupport/Unicode.cpp +++ b/lib/DxcSupport/Unicode.cpp @@ -42,18 +42,17 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, ++cbMultiByte; } // If zero is given as the destination size, this function should - // return the required size (including the null-terminating character). + // return the required size (including or excluding the null-terminating + // character depending on whether the input included the null-terminator). // This is the behavior of mbstowcs when the target is null. if (cchWideChar == 0) { lpWideCharStr = nullptr; - } else if (cchWideChar < cbMultiByte) { - SetLastError(ERROR_INSUFFICIENT_BUFFER); - return 0; } + ScopedLocale utf8_locale_scope(CP_UTF8); + + bool isNullTerminated = false; size_t rv; - const char *prevLocale = setlocale(LC_ALL, nullptr); - setlocale(LC_ALL, "en_US.UTF-8"); if (lpMultiByteStr[cbMultiByte - 1] != '\0') { char *srcStr = (char *)malloc((cbMultiByte + 1) * sizeof(char)); strncpy(srcStr, lpMultiByteStr, cbMultiByte); @@ -62,14 +61,22 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, free(srcStr); } else { rv = mbstowcs(lpWideCharStr, lpMultiByteStr, cchWideChar); + isNullTerminated = true; } - if (prevLocale) - setlocale(LC_ALL, prevLocale); + if (rv == (size_t)-1) { + // mbstowcs returns -1 on error. + SetLastError(ERROR_INVALID_PARAMETER); + return 0; + } - if (rv == (size_t)cbMultiByte) - return rv; - return rv + 1; // mbstowcs excludes the terminating character + // Return value of mbstowcs (rv) excludes the terminating character. + // Matching MultiByteToWideChar requires returning the size written including + // the null terminator if the input was null-terminated, otherwise it + // returns the size written excluding the null terminator. + if (isNullTerminated) + return rv + 1; + return rv; } // WideCharToMultiByte is a Windows-specific method. @@ -98,18 +105,17 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, ++cchWideChar; } // If zero is given as the destination size, this function should - // return the required size (including the null-terminating character). + // return the required size (including or excluding the null-terminating + // character depending on whether the input included the null-terminator). // This is the behavior of wcstombs when the target is null. if (cbMultiByte == 0) { lpMultiByteStr = nullptr; - } else if (cbMultiByte < cchWideChar) { - SetLastError(ERROR_INSUFFICIENT_BUFFER); - return 0; } + ScopedLocale utf8_locale_scope(CP_UTF8); + + bool isNullTerminated = false; size_t rv; - const char *prevLocale = setlocale(LC_ALL, nullptr); - setlocale(LC_ALL, "en_US.UTF-8"); if (lpWideCharStr[cchWideChar - 1] != L'\0') { wchar_t *srcStr = (wchar_t *)malloc((cchWideChar + 1) * sizeof(wchar_t)); wcsncpy(srcStr, lpWideCharStr, cchWideChar); @@ -118,14 +124,22 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/, free(srcStr); } else { rv = wcstombs(lpMultiByteStr, lpWideCharStr, cbMultiByte); + isNullTerminated = true; } - if (prevLocale) - setlocale(LC_ALL, prevLocale); + if (rv == (size_t)-1) { + // wcstombs returns -1 on error. + SetLastError(ERROR_INVALID_PARAMETER); + return 0; + } - if (rv == (size_t)cchWideChar) - return rv; - return rv + 1; // mbstowcs excludes the terminating character + // Return value of wcstombs (rv) excludes the terminating character. + // Matching MultiByteToWideChar requires returning the size written including + // the null terminator if the input was null-terminated, otherwise it + // returns the size written excluding the null terminator. + if (isNullTerminated) + return rv + 1; + return rv; } #endif // _WIN32 @@ -133,6 +147,7 @@ namespace Unicode { bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp, DWORD flags, std::string *pValue, bool *lossy) { + DXASSERT_NOMSG(cWide == (size_t)-1 || cWide < INT32_MAX); BOOL usedDefaultChar; LPBOOL pUsedDefaultChar = (lossy == nullptr) ? nullptr : &usedDefaultChar; if (lossy != nullptr) @@ -147,16 +162,24 @@ bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp, return true; } - int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, nullptr, 0, - nullptr, pUsedDefaultChar); + int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast(cWide), + nullptr, 0, nullptr, pUsedDefaultChar); if (cbUTF8 == 0) return false; pValue->resize(cbUTF8); - cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, &(*pValue)[0], - pValue->size(), nullptr, pUsedDefaultChar); + cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast(cWide), + &(*pValue)[0], pValue->size(), nullptr, + pUsedDefaultChar); DXASSERT(cbUTF8 > 0, "otherwise contents have changed"); + if ((cWide == (size_t)-1 || text[cWide - 1] == L'\0') && + (*pValue)[pValue->size() - 1] == '\0') { + // When the input is null-terminated, the output includes the null + // terminator. Reduce the size by 1 to remove the embedded null terminator + // inside the string. + pValue->resize(cbUTF8 - 1); + } DXASSERT((*pValue)[pValue->size()] == '\0', "otherwise string didn't null-terminate after resize() call"); @@ -166,12 +189,12 @@ bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp, } bool UTF8ToWideString(const char *pUTF8, std::wstring *pWide) { - size_t cbUTF8 = (pUTF8 == nullptr) ? 0 : strlen(pUTF8); - return UTF8ToWideString(pUTF8, cbUTF8, pWide); + return UTF8ToWideString(pUTF8, -1, pWide); } bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) { DXASSERT_NOMSG(pWide != nullptr); + DXASSERT_NOMSG(cbUTF8 == (size_t)-1 || cbUTF8 < INT32_MAX); // Handle zero-length as a special case; it's a special value to indicate // errors in MultiByteToWideChar. @@ -181,15 +204,23 @@ bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) { } int cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8, - cbUTF8, nullptr, 0); + static_cast(cbUTF8), nullptr, 0); if (cWide == 0) return false; pWide->resize(cWide); - cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8, cbUTF8, - &(*pWide)[0], pWide->size()); + cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8, + static_cast(cbUTF8), &(*pWide)[0], + pWide->size()); DXASSERT(cWide > 0, "otherwise contents changed"); + if ((cbUTF8 == (size_t)-1 || pUTF8[cbUTF8 - 1] == '\0') && + (*pWide)[pWide->size() - 1] == '\0') { + // When the input is null-terminated, the output includes the null + // terminator. Reduce the size by 1 to remove the embedded null terminator + // inside the string. + pWide->resize(cWide - 1); + } DXASSERT((*pWide)[pWide->size()] == L'\0', "otherwise wstring didn't null-terminate after resize() call"); return true; @@ -213,11 +244,12 @@ bool UTF8ToConsoleString(const char *text, size_t textLen, std::string *pValue, if (!UTF8ToWideString(text, textLen, &text16)) { return false; } - return WideToConsoleString(text16.c_str(), text16.length(), pValue, lossy); + return WideToConsoleString(text16.c_str(), text16.length() + 1, pValue, + lossy); } bool UTF8ToConsoleString(const char *text, std::string *pValue, bool *lossy) { - return UTF8ToConsoleString(text, strlen(text), pValue, lossy); + return UTF8ToConsoleString(text, (size_t)-1, pValue, lossy); } bool WideToConsoleString(const wchar_t *text, size_t textLen, @@ -230,7 +262,7 @@ bool WideToConsoleString(const wchar_t *text, size_t textLen, bool WideToConsoleString(const wchar_t *text, std::string *pValue, bool *lossy) { - return WideToConsoleString(text, wcslen(text), pValue, lossy); + return WideToConsoleString(text, (size_t)-1, pValue, lossy); } bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) { @@ -242,7 +274,7 @@ bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) { bool WideToUTF8String(const wchar_t *pWide, std::string *pUTF8) { DXASSERT_NOMSG(pWide != nullptr); DXASSERT_NOMSG(pUTF8 != nullptr); - return WideToEncodedString(pWide, wcslen(pWide), CP_UTF8, 0, pUTF8, nullptr); + return WideToEncodedString(pWide, (size_t)-1, CP_UTF8, 0, pUTF8, nullptr); } std::string WideToUTF8StringOrThrow(const wchar_t *pWide) { diff --git a/tools/clang/unittests/HLSL/CompilerTest.cpp b/tools/clang/unittests/HLSL/CompilerTest.cpp index 3f4fb30d58..07252e3d04 100644 --- a/tools/clang/unittests/HLSL/CompilerTest.cpp +++ b/tools/clang/unittests/HLSL/CompilerTest.cpp @@ -207,6 +207,13 @@ class CompilerTest : public ::testing::Test { void TestEncodingImpl(const void *sourceData, size_t sourceSize, UINT32 codePage, const void *includedData, size_t includedSize, const WCHAR *encoding = nullptr); + template + void TestEncodingImpl(std::basic_string source, UINT32 codePage, + std::basic_string included, + const WCHAR *encoding = nullptr) { + TestEncodingImpl(source.data(), source.size() * sizeof(T1), codePage, + included.data(), included.size() * sizeof(T2), encoding); + } TEST_METHOD(CompileWithEncodeFlagTestSource) #if _ITERATOR_DEBUG_LEVEL == 0 @@ -3636,54 +3643,47 @@ void CompilerTest::TestEncodingImpl(const void *sourceData, size_t sourceSize, TEST_F(CompilerTest, CompileWithEncodeFlagTestSource) { - std::string sourceUtf8 = "#include \"include.hlsl\"\r\n" - "float4 main() : SV_Target { return 0; }"; - std::string includeUtf8 = "// Comment\n"; + std::string SourceUtf8 = "#include \"include.hlsl\"\n" + "float4 main() : SV_Target { return Buf[0]; }"; + std::string IncludeUtf8 = "Buffer Buf;\n"; std::string utf8BOM = "\xEF" "\xBB" "\xBF"; // UTF-8 BOM - std::string includeUtf8BOM = utf8BOM + includeUtf8; + std::string IncludeUtf8BOM = utf8BOM + IncludeUtf8; - std::wstring sourceWide = L"#include \"include.hlsl\"\r\n" - L"float4 main() : SV_Target { return 0; }"; - std::wstring includeWide = L"// Comments\n"; - std::wstring utf16BOM = L"\xFEFF"; // UTF-16 LE BOM - std::wstring includeUtf16BOM = utf16BOM + includeWide; + std::wstring SourceWide = L"#include \"include.hlsl\"\n" + L"float4 main() : SV_Target { return Buf[0]; }"; + std::wstring IncludeWide = L"Buffer Buf;\n"; - // Included files interpreted with encoding option if no BOM - TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8, - includeUtf8.data(), includeUtf8.size(), L"utf8"); + // Windows: UTF-16 BOM is '\xFEFF' + // *nix: UTF-32 BOM is L'\x0000FEFF' + // Thus, BOM wide character value is identical for UTF-16 and UTF-32. + // Endianess will be native, since we are using wide strings directly. + std::wstring WideBOM = L"\xFEFF"; + + std::wstring IncludeWideBOM = WideBOM + IncludeWide; - TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'), - DXC_CP_WIDE, includeWide.data(), - includeWide.size() * sizeof(L'A'), L"wide"); + // Included files interpreted with encoding option if no BOM + TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8, L"utf8"); + TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWide, L"wide"); // Encoding option ignored if BOM present - TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8, - includeUtf8BOM.data(), includeUtf8BOM.size(), L"wide"); + TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8BOM, L"wide"); + TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWideBOM, L"utf8"); - TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'), - DXC_CP_WIDE, includeUtf16BOM.data(), - includeUtf16BOM.size() * sizeof(L'A'), L"utf8"); + // Encoding option ignored if BOM present - different encoding for source + TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8BOM, L"wide"); + TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWideBOM, L"utf8"); // Source file interpreted according to DxcBuffer encoding if not CP_ACP // Included files interpreted with encoding option if no BOM - TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8, - includeWide.data(), includeWide.size() * sizeof(L'A'), - L"wide"); - - TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'), - DXC_CP_WIDE, includeUtf8.data(), includeUtf8.size(), - L"utf8"); + TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWide, L"wide"); + TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8, L"utf8"); // Source file interpreted by encoding option if source DxcBuffer encoding = // CP_ACP (default) - TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_ACP, - includeUtf8.data(), includeUtf8.size(), L"utf8"); - - TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'), - DXC_CP_ACP, includeWide.data(), - includeWide.size() * sizeof(L'A'), L"wide"); + TestEncodingImpl(SourceUtf8, DXC_CP_ACP, IncludeUtf8, L"utf8"); + TestEncodingImpl(SourceWide, DXC_CP_ACP, IncludeWide, L"wide"); } TEST_F(CompilerTest, CompileWhenODumpThenOptimizerMatch) {