Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 132 additions & 9 deletions src/windows/common/Distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,34 +210,157 @@ std::optional<TDistribution> LookupDistributionInManifest(const DistributionList
return *it;
}

// Helper function to merge distributions from multiple manifests
void MergeDistributionLists(DistributionList& target, const DistributionList& source)
{
// Merge legacy distributions
if (source.Distributions.has_value())
{
if (!target.Distributions.has_value())
{
target.Distributions = std::vector<Distribution>{};
}

for (const auto& dist : *source.Distributions)
{
// Check if distribution already exists (avoid duplicates)
auto it = std::find_if(target.Distributions->begin(), target.Distributions->end(),
[&](const Distribution& d) {
return d.Name == dist.Name
&& d.Version == dist.Version
&& d.Architecture == dist.Architecture;
});

if (it == target.Distributions->end())
{
target.Distributions->push_back(dist);
}
}
}

// Merge modern distributions
if (source.ModernDistributions.has_value())
{
if (!target.ModernDistributions.has_value())
{
target.ModernDistributions = std::map<std::string, std::vector<ModernDistributionVersion>>{};
}

for (const auto& [distroName, versions] : *source.ModernDistributions)
{
auto& targetVersions = (*target.ModernDistributions)[distroName];

for (const auto& version : versions)
{
// Check if version already exists
auto it = std::find_if(targetVersions.begin(), targetVersions.end(),
[&](const ModernDistributionVersion& v) {
return v.Name == version.Name
&& v.Version == version.Version
&& v.Architecture == version.Architecture;
});

if (it == targetVersions.end())
{
targetVersions.push_back(version);
}
}
}
}

// Update default if source has one and target doesn't
if (source.Default.has_value() && !target.Default.has_value())
{
target.Default = source.Default;
}
}

} // namespace

AvailableDistributions wsl::windows::common::distribution::GetAvailable()
{
AvailableDistributions distributions{};

// Determine the base manifest URL
// Priority: HKCU > HKLM > Default
std::wstring url = c_defaultDistroListUrl;
std::optional<std::wstring> appendUrl;
std::vector<std::wstring> appendUrls;

try
{
const auto registryKey = registry::OpenLxssMachineKey();
url = registry::ReadString(registryKey.get(), nullptr, c_distroUrlRegistryValue, c_defaultDistroListUrl);
// First check HKEY_LOCAL_MACHINE
const auto machineKey = registry::OpenLxssMachineKey();
url = registry::ReadString(machineKey.get(), nullptr, c_distroUrlRegistryValue, c_defaultDistroListUrl);

// Read HKLM append URLs (supports REG_MULTI_SZ)
auto hklmAppendUrls = registry::ReadWideStringSet(machineKey.get(), nullptr, c_distroUrlAppendRegistryValue, {});
appendUrls.insert(appendUrls.end(), hklmAppendUrls.begin(), hklmAppendUrls.end());

if (url != c_defaultDistroListUrl)
{
WSL_LOG("Found custom URL for distribution list", TraceLoggingValue(url.c_str(), "url"));
WSL_LOG("Found custom URL for distribution list in HKLM", TraceLoggingValue(url.c_str(), "url"));
}

if (!appendUrls.empty())
{
WSL_LOG("Found append URLs in HKLM", TraceLoggingValue(static_cast<UINT32>(appendUrls.size()), "count"));
}
}
CATCH_LOG()

appendUrl = registry::ReadOptionalString(registryKey.get(), nullptr, c_distroUrlAppendRegistryValue);
try
{
// Then check HKEY_CURRENT_USER (takes precedence)
const auto userKey = registry::OpenLxssUserKey();

// Check if user has overridden the base URL
auto userUrl = registry::ReadOptionalString(userKey.get(), nullptr, c_distroUrlRegistryValue);
if (userUrl.has_value())
{
url = userUrl.value();
WSL_LOG("Found custom URL for distribution list in HKCU (overriding)", TraceLoggingValue(url.c_str(), "url"));
}

// Read HKCU append URLs (supports REG_MULTI_SZ) - these are added to HKLM append URLs
auto hkcuAppendUrls = registry::ReadWideStringSet(userKey.get(), nullptr, c_distroUrlAppendRegistryValue, {});
appendUrls.insert(appendUrls.end(), hkcuAppendUrls.begin(), hkcuAppendUrls.end());

if (!hkcuAppendUrls.empty())
{
WSL_LOG("Found append URLs in HKCU", TraceLoggingValue(static_cast<UINT32>(hkcuAppendUrls.size()), "count"));
}
}
CATCH_LOG()

// Load the base manifest
distributions.Manifest = ReadFromManifest(url);

if (appendUrl.has_value())
// Load and merge all append manifests
if (!appendUrls.empty())
{
WSL_LOG("Found append URL for distribution list", TraceLoggingValue(appendUrl->c_str(), "url"));

distributions.OverrideManifest = ReadFromManifest(appendUrl.value());
for (const auto& appendUrl : appendUrls)
{
try
{
WSL_LOG("Loading append manifest", TraceLoggingValue(appendUrl.c_str(), "url"));
auto appendManifest = ReadFromManifest(appendUrl);

// Merge into override manifest if it exists, otherwise create it
if (!distributions.OverrideManifest.has_value())
{
distributions.OverrideManifest = appendManifest;
}
else
{
MergeDistributionLists(*distributions.OverrideManifest, appendManifest);
}
}
catch (...)
{
// Log the error but continue with other sources
LOG_CAUGHT_EXCEPTION_MSG("Failed to load append manifest from %ls", appendUrl.c_str());
}
}
}

return distributions;
Expand Down
14 changes: 13 additions & 1 deletion src/windows/common/Distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,24 @@ Module Name:

namespace wsl::windows::common::distribution {

// Represents a file to be injected into the distribution during installation
struct InjectedFile
{
std::wstring Source; // "url" or "inline"
std::optional<std::wstring> Url; // URL to download from (if Source == "url")
std::optional<std::wstring> Sha256; // SHA256 hash for verification (if Source == "url")
std::optional<std::wstring> Contents; // Inline file contents (if Source == "inline")

NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(InjectedFile, Source, Url, Sha256, Contents);
};

struct DistributionArchive
{
std::wstring Url;
std::wstring Sha256;
std::optional<std::map<std::string, InjectedFile>> Files; // Map of file paths to inject

NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(DistributionArchive, Url, Sha256);
NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(DistributionArchive, Url, Sha256, Files);
};

struct ModernDistributionVersion
Expand Down
103 changes: 103 additions & 0 deletions src/windows/common/WslInstall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,5 +324,108 @@ std::pair<std::wstring, GUID> WslInstall::InstallModernDistribution(
fixedVhd ? LXSS_IMPORT_DISTRO_FLAGS_FIXED_VHD : 0,
vhdSize);

// Inject files if specified in the distribution metadata
if (downloadInfo->Files.has_value() && !downloadInfo->Files->empty())
{
try
{
PrintMessage(L"Injecting configuration files...", stdout);

for (const auto& [targetPath, fileSpec] : *downloadInfo->Files)
{
const auto targetPathWide = wsl::shared::string::MultiByteToWide(targetPath);

if (wsl::windows::common::string::IsEqual(fileSpec.Source, L"inline", true))
{
// Inline content - write directly using base64 encoding to avoid shell escaping issues
if (!fileSpec.Contents.has_value())
{
LOG_HR_MSG(E_INVALIDARG, "Inline file source specified but no contents provided for %s", targetPath.c_str());
continue;
}

// Convert content to base64 to safely pass through shell
const auto contentBase64 = wsl::shared::string::Base64EncodeFromWide(fileSpec.Contents->c_str());

// Create parent directory and decode base64 content into file
const auto command = std::format(
L"/bin/sh -c \"mkdir -p $(dirname '{}') && echo '{}' | base64 -d > '{}'\"",
targetPathWide,
wsl::shared::string::MultiByteToWide(contentBase64),
targetPathWide);

LPCWSTR argv[] = {L"/bin/sh", L"-c", command.c_str()};
const auto exitCode = service.LaunchProcess(&id, L"/bin/sh", 3, argv, LXSS_LAUNCH_FLAGS_NONE, nullptr, nullptr, 30000);

if (exitCode != 0)
{
LOG_HR_MSG(E_FAIL, "Failed to inject inline file %s, exit code: %d", targetPath.c_str(), exitCode);
}
}
else if (wsl::windows::common::string::IsEqual(fileSpec.Source, L"url", true))
{
// URL-based file - download and inject
if (!fileSpec.Url.has_value() || !fileSpec.Sha256.has_value())
{
LOG_HR_MSG(E_INVALIDARG, "URL file source specified but no URL or SHA256 provided for %s", targetPath.c_str());
continue;
}

// Download file to temp location with UUID to prevent collisions
GUID uniqueId{};
THROW_IF_FAILED(CoCreateGuid(&uniqueId));
const auto tempFileName = std::format(L"injected_file_{}.tmp",
wsl::shared::string::GuidToString<wchar_t>(uniqueId, wsl::shared::string::GuidToStringFlags::None));
const auto tempFilePath = DownloadFile(*fileSpec.Url, tempFileName);

// Verify hash
wil::unique_handle tempFile{CreateFile(tempFilePath.c_str(), GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, 0, nullptr)};
if (tempFile)
{
try
{
EnforceFileHash(tempFile.get(), *fileSpec.Sha256);
tempFile.reset();

// Read file content and inject (using base64 for safety)
const auto fileContent = wsl::shared::string::ReadFile<char, char>(tempFilePath.c_str());
const auto contentBase64 = wsl::shared::string::Base64Encode(fileContent);

const auto command = std::format(
L"/bin/sh -c \"mkdir -p $(dirname '{}') && echo '{}' | base64 -d > '{}'\"",
targetPathWide,
wsl::shared::string::MultiByteToWide(contentBase64),
targetPathWide);

LPCWSTR argv[] = {L"/bin/sh", L"-c", command.c_str()};
const auto exitCode = service.LaunchProcess(&id, L"/bin/sh", 3, argv, LXSS_LAUNCH_FLAGS_NONE, nullptr, nullptr, 30000);

if (exitCode != 0)
{
LOG_HR_MSG(E_FAIL, "Failed to inject URL-based file %s, exit code: %d", targetPath.c_str(), exitCode);
}
}
catch (...)
{
LOG_CAUGHT_EXCEPTION_MSG("Failed to inject file from URL for %s", targetPath.c_str());
}

// Clean up temp file
DeleteFileW(tempFilePath.c_str());
}
}
else
{
LOG_HR_MSG(E_INVALIDARG, "Unknown file source type: %ls for %s", fileSpec.Source.c_str(), targetPath.c_str());
}
}
}
catch (...)
{
// Log but don't fail installation if file injection fails
LOG_CAUGHT_EXCEPTION_MSG("File injection failed, but installation will continue");
}
}

return {installedName.get(), id};
}
44 changes: 44 additions & 0 deletions src/windows/common/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,50 @@ std::vector<std::string> wsl::windows::common::registry::ReadStringSet(
return Values;
}

std::vector<std::wstring> wsl::windows::common::registry::ReadWideStringSet(
_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName, const std::vector<std::wstring>& Default)
{
//
// Detect if the key exists and determine how large of a buffer is needed.
// If the key does not exist, return the default value.
//

LONG Result;
DWORD Size = 0;
Result = RegGetValueW(Key, KeyName, ValueName, RRF_RT_REG_MULTI_SZ, nullptr, nullptr, &Size);
if ((Result == ERROR_PATH_NOT_FOUND) || (Result == ERROR_FILE_NOT_FOUND) || (Size == 0))
{
return Default;
}

ReportErrorIfFailed(Result, Key, KeyName, ValueName);

//
// Allocate a buffer to hold the value and two NULL terminators.
//

std::vector<WCHAR> Buffer((Size / sizeof(WCHAR)) + 2);

//
// Read the value.
//

Result = RegGetValueW(Key, KeyName, ValueName, RRF_RT_REG_MULTI_SZ, nullptr, Buffer.data(), &Size);
ReportErrorIfFailed(Result, Key, KeyName, ValueName);

//
// Convert the reg value into a vector of wide strings.
//

std::vector<std::wstring> Values{};
for (auto Current = Buffer.data(); UNICODE_NULL != *Current; Current += wcslen(Current) + 1)
{
Values.push_back(Current);
}

return Values;
}

void wsl::windows::common::registry::WriteDword(_In_ HKEY Key, _In_ LPCWSTR SubKey, _In_ LPCWSTR ValueName, _In_ DWORD Value)
{
const auto Result = RegSetKeyValueW(Key, SubKey, ValueName, REG_DWORD, &Value, sizeof(Value));
Expand Down
1 change: 1 addition & 0 deletions src/windows/common/registry.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ std::wstring ReadString(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWST
std::optional<std::wstring> ReadOptionalString(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName);

std::vector<std::string> ReadStringSet(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName, _In_ const std::vector<std::string>& Default);
std::vector<std::wstring> ReadWideStringSet(_In_ HKEY Key, _In_opt_ LPCWSTR KeyName, _In_opt_ LPCWSTR ValueName, _In_ const std::vector<std::wstring>& Default);

void WriteDword(_In_ HKEY Key, _In_ LPCWSTR SubKey, _In_ LPCWSTR KeyName, _In_ DWORD Value);

Expand Down