Skip to content

Commit

Permalink
Merge branch 'main' into niels9001/homepage-tweaks
Browse files (browse the repository at this point in the history)
  • Loading branch information
nmetulev authored Nov 18, 2024
2 parents aceb8a9 + ee8c64a commit 93d220f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 117 deletions.
5 changes: 3 additions & 2 deletions AIDevGallery.Utils/AIDevGallery.Utils.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
</PropertyGroup>

<ItemGroup Condition="$(TargetFramework) == 'netstandard2.0'">
<PackageReference Include="System.Text.Json"/>
<PackageReference Include="System.Memory"/>
<PackageReference Include="System.Text.Json" />
<PackageReference Include="System.Memory" />
<PackageReference Include="System.Threading.Tasks.Dataflow" />
<PackageReference Include="PolySharp">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>build; analyzers</IncludeAssets>
Expand Down
66 changes: 38 additions & 28 deletions AIDevGallery.Utils/ModelInformationHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;

namespace AIDevGallery.Utils
{
Expand Down Expand Up @@ -122,43 +123,52 @@ public static async Task<List<ModelFileDetails>> GetDownloadFilesFromHuggingFace
{
var baseUrl = $"https://huggingface.co/api/models/{hfUrl.Organization}/{hfUrl.Repo}/tree/{hfUrl.Ref}";

var httpClient = httpMessageHandler != null ? new HttpClient(httpMessageHandler) : new HttpClient();
var semaphore = new SemaphoreSlim(4, 4);
using var httpClient = httpMessageHandler != null ? new HttpClient(httpMessageHandler) : new HttpClient();

while (hfFiles.Any(f => f.Type == "directory"))
{
var folders = hfFiles.Where(f => f.Type == "directory").ToList();
List<Task> tasks = [];
foreach (var folder in folders)
ActionBlock<string> actionBlock = null!;
actionBlock = new ActionBlock<string>(
async (string path) =>
{
hfFiles.Remove(folder);
tasks.Add(Task.Run(
async () =>
var response = await httpClient.GetAsync($"{baseUrl}/{path}", cancellationToken);
#if NET8_0_OR_GREATER
var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
#else
var stream = await response.Content.ReadAsStreamAsync();
#endif
var files = await JsonSerializer.DeserializeAsync(stream, SourceGenerationContext.Default.ListHuggingFaceModelFileDetails, cancellationToken);
if (files != null)
{
lock (hfFiles)
{
await semaphore.WaitAsync(cancellationToken);
var response = await httpClient.GetAsync($"{baseUrl}/{folder.Path}", cancellationToken);
var responseContent = await response.Content.ReadAsStringAsync();

var files = JsonSerializer.Deserialize(responseContent, SourceGenerationContext.Default.ListHuggingFaceModelFileDetails);
if (files != null)
foreach (var file in files.Where(f => f.Type != "directory"))
{
hfFiles.AddRange(files);
hfFiles.Add(file);
}
}

semaphore.Release();
#if NET8_0_OR_GREATER
},
cancellationToken));
#else
}));
#endif
}
foreach (var folder in files.Where(f => f.Type == "directory" && f.Path != null))
{
actionBlock.Post(folder.Path!);
}
}

if (actionBlock.InputCount == 0)
{
actionBlock.Complete();
}
},
new ExecutionDataflowBlockOptions
{
MaxDegreeOfParallelism = 4,
CancellationToken = cancellationToken
});

await Task.WhenAll(tasks);
foreach (var folder in hfFiles.Where(f => f.Type == "directory" && f.Path != null))
{
actionBlock.Post(folder.Path!);
}

semaphore.Dispose();
httpClient.Dispose();
await actionBlock.Completion;
}

return hfFiles.Select(f =>
Expand Down
194 changes: 107 additions & 87 deletions AIDevGallery/Pages/AddModelPage.xaml.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
using Microsoft.UI.Xaml.Controls;
using Microsoft.UI.Xaml.Navigation;
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;

namespace AIDevGallery.Pages;

Expand Down Expand Up @@ -84,6 +84,104 @@ private async Task SearchModels(string query, CancellationToken cancellationToke

if (results != null && results.Count > 0)
{
ActionBlock<(HFSearchResult Result, Sibling Config, string? ReadmeUrl)> actionBlock = null!;
actionBlock = new ActionBlock<(HFSearchResult Result, Sibling Config, string? ReadmeUrl)>(
async (item) =>
{
var (result, config, readmeUrl) = item;
var configContents = await HuggingFaceApi.GetContentsOfTextFile(result.Id, config.RFilename);
GenAIConfig? genAIConfig = null;
try
{
genAIConfig = JsonSerializer.Deserialize(configContents, SourceGenerationContext.Default.GenAIConfig);
}
catch (JsonException)
{
}

if (genAIConfig != null &&
(
genAIConfig.Model.Decoder.SessionOptions.ProviderOptions.Length == 0 ||
genAIConfig.Model.Decoder.SessionOptions.ProviderOptions.Any(p => p.Dml != null))
)
{
var pathComponents = config.RFilename.Split("/");
string modelPath = string.Empty;
if (pathComponents.Length > 1)
{
modelPath = string.Join("/", pathComponents.Take(pathComponents.Length - 1));
}

var modelUrl = $"https://huggingface.co/{result.Id}/tree/main/{modelPath}";

bool isDmlModel = genAIConfig.Model.Decoder.SessionOptions.ProviderOptions.Any(p => p.Dml != null);

var curratedModel = ModelTypeHelpers.ModelDetails.Values.Where(m => m.Url == modelUrl).FirstOrDefault();

var filesToDownload = await ModelInformationHelper.GetDownloadFilesFromHuggingFace(new HuggingFaceUrl(modelUrl));

var details = curratedModel ?? new ModelDetails()
{
Id = "useradded-languagemodel-" + Guid.NewGuid().ToString(),
Name = result.Id + " " + (isDmlModel ? "DML" : "CPU"),
Url = modelUrl,
Description = "TODO",
HardwareAccelerators = [isDmlModel ? HardwareAccelerator.DML : HardwareAccelerator.CPU],
IsUserAdded = true,
PromptTemplate = GetTemplateFromName(result.Id),
Size = filesToDownload.Sum(f => f.Size),
ReadmeUrl = readmeUrl != null ? $"https://huggingface.co/{result.Id}/blob/main/{readmeUrl}" : null,
};

string? licenseKey = null;
if (result.Tags != null)
{
var licenseTag = result.Tags.Where(t => t.StartsWith("license:", StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault();
if (licenseTag != null)
{
licenseKey = licenseTag.Split(":").Last();
}
}

if (curratedModel == null)
{
details.License = licenseKey;
}

ResultState state = ResultState.NotDownloaded;

if (App.ModelCache.IsModelCached(details.Url))
{
state = ResultState.Downloaded;
}
else if (App.ModelCache.DownloadQueue.GetDownload(details.Url) != null)
{
state = ResultState.Downloading;
}

DispatcherQueue.TryEnqueue(() =>
{
this.results.Add(new Result
{
Details = details,
SearchResult = result,
License = LicenseInfo.GetLicenseInfo(licenseKey),
State = state
});
});
}

if (actionBlock.InputCount == 0)
{
actionBlock.Complete();
}
},
new ExecutionDataflowBlockOptions
{
MaxDegreeOfParallelism = 4,
CancellationToken = cancellationToken
});

foreach (var result in results)
{
if (result.Siblings == null)
Expand All @@ -92,8 +190,6 @@ private async Task SearchModels(string query, CancellationToken cancellationToke
}

var configs = result.Siblings.Where(r => r.RFilename.EndsWith("genai_config.json", StringComparison.InvariantCultureIgnoreCase));
List<Task> tasks = [];
using var semaphore = new SemaphoreSlim(4, 4);

var readmeSiblings = result.Siblings.Where(r => r.RFilename.EndsWith("readme.md", StringComparison.InvariantCultureIgnoreCase));
string? readmeUrl = null;
Expand All @@ -103,94 +199,18 @@ private async Task SearchModels(string query, CancellationToken cancellationToke
readmeUrl = readmeSiblings.First().RFilename;
}

foreach (var config in configs)
if (!configs.Any())
{
tasks.Add(Task.Run(
async () =>
{
await semaphore.WaitAsync(cancellationToken);

var configContents = await HuggingFaceApi.GetContentsOfTextFile(result.Id, config.RFilename);
var genAIConfig = JsonSerializer.Deserialize(configContents, SourceGenerationContext.Default.GenAIConfig);
if (genAIConfig != null &&
(
genAIConfig.Model.Decoder.SessionOptions.ProviderOptions.Length == 0 ||
genAIConfig.Model.Decoder.SessionOptions.ProviderOptions.Any(p => p.Dml != null))
)
{
var pathComponents = config.RFilename.Split("/");
string modelPath = string.Empty;
if (pathComponents.Length > 1)
{
modelPath = string.Join("/", pathComponents.Take(pathComponents.Length - 1));
}

var modelUrl = $"https://huggingface.co/{result.Id}/tree/main/{modelPath}";

bool isDmlModel = genAIConfig.Model.Decoder.SessionOptions.ProviderOptions.Any(p => p.Dml != null);

var curratedModel = ModelTypeHelpers.ModelDetails.Values.Where(m => m.Url == modelUrl).FirstOrDefault();

var filesToDownload = await ModelInformationHelper.GetDownloadFilesFromHuggingFace(new HuggingFaceUrl(modelUrl));

var details = curratedModel ?? new ModelDetails()
{
Id = "useradded-languagemodel-" + Guid.NewGuid().ToString(),
Name = result.Id + " " + (isDmlModel ? "DML" : "CPU"),
Url = modelUrl,
Description = "TODO",
HardwareAccelerators = [isDmlModel ? HardwareAccelerator.DML : HardwareAccelerator.CPU],
IsUserAdded = true,
PromptTemplate = GetTemplateFromName(result.Id),
Size = filesToDownload.Sum(f => f.Size),
ReadmeUrl = readmeUrl != null ? $"https://huggingface.co/{result.Id}/blob/main/{readmeUrl}" : null,
};

string? licenseKey = null;
if (result.Tags != null)
{
var licenseTag = result.Tags.Where(t => t.StartsWith("license:", StringComparison.InvariantCultureIgnoreCase)).FirstOrDefault();
if (licenseTag != null)
{
licenseKey = licenseTag.Split(":").Last();
}
}

if (curratedModel == null)
{
details.License = licenseKey;
}

ResultState state = ResultState.NotDownloaded;

if (App.ModelCache.IsModelCached(details.Url))
{
state = ResultState.Downloaded;
}
else if (App.ModelCache.DownloadQueue.GetDownload(details.Url) != null)
{
state = ResultState.Downloading;
}

DispatcherQueue.TryEnqueue(() =>
{
this.results.Add(new Result()
{
Details = details,
SearchResult = result,
License = LicenseInfo.GetLicenseInfo(licenseKey),
State = state
});
});
}

semaphore.Release();
},
cancellationToken));
continue;
}

await Task.WhenAll(tasks);
foreach (var config in configs)
{
actionBlock.Post((result, config, readmeUrl));
}
}

await actionBlock.Completion;
}
}

Expand Down
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<PackageVersion Include="Microsoft.SemanticKernel.Connectors.InMemory" Version="1.29.0-preview" />
<PackageVersion Include="Microsoft.Xaml.Behaviors.WinUI.Managed" Version="2.0.9" />
<PackageVersion Include="System.Text.Json" Version="9.0.0" />
<PackageVersion Include="System.Threading.Tasks.Dataflow" Version="9.0.0" />
<PackageVersion Include="System.Memory" Version="4.6.0" />
<PackageVersion Include="PolySharp" Version="1.14.1" />
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
Expand Down

0 comments on commit 93d220f

Please sign in to comment.