Skip to content

Commit

Permalink
Merge pull request #3 from AssemblyAI/niels/add-di
Browse files Browse the repository at this point in the history
Refactor to add plugin using DI
  • Loading branch information
Swimburger committed Feb 17, 2024
2 parents 5a0afd2 + da954e5 commit 6fd943b
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 167 deletions.
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Add the [AssemblyAI.SemanticKernel NuGet package](https://www.nuget.org/packages
dotnet add package AssemblyAI.SemanticKernel
```

Next, register the `TranscriptPlugin` into your kernel:
Next, register the `AssemblyAI` plugin into your kernel:

```csharp
using AssemblyAI.SemanticKernel;
Expand All @@ -37,6 +37,7 @@ string apiKey = Environment.GetEnvironmentVariable("ASSEMBLYAI_API_KEY")

kernel.ImportPluginFromObject(
new TranscriptPlugin(apiKey: apiKey)
TranscriptPlugin.PluginName
);
```

Expand All @@ -45,8 +46,8 @@ kernel.ImportPluginFromObject(
Get the `Transcribe` function from the transcript plugin and invoke it with the context variables.
```csharp
var result = await kernel.InvokeAsync(
nameof(TranscriptPlugin),
TranscriptPlugin.TranscribeFunctionName,
nameof(AssemblyAIPlugin),
AssemblyAIPlugin.TranscribeFunctionName,
new KernelArguments
{
["INPUT"] = "https://storage.googleapis.com/aai-docs-samples/espn.m4a"
Expand All @@ -58,7 +59,7 @@ Console.WriteLine(result.GetValue<string>());
You can get the transcript using `result.GetValue<string>()`.

You can also upload local audio and video file. To do this:
- Set the `TranscriptPlugin.AllowFileSystemAccess` property to `true`.
- Set the `AssemblyAI:Plugin:AllowFileSystemAccess` configuration to `true`.
- Configure the `INPUT` variable with a local file path.

```csharp
Expand All @@ -69,8 +70,8 @@ kernel.ImportPluginFromObject(
}
);
var result = await kernel.InvokeAsync(
nameof(TranscriptPlugin),
TranscriptPlugin.TranscribeFunctionName,
nameof(AssemblyAIPlugin),
AssemblyAIPlugin.TranscribeFunctionName,
new KernelArguments
{
["INPUT"] = "https://storage.googleapis.com/aai-docs-samples/espn.m4a"
Expand All @@ -84,7 +85,7 @@ You can also invoke the function from within a semantic function like this.
```csharp
const string prompt = """
Here is a transcript:
{{TranscriptPlugin.Transcribe "https://storage.googleapis.com/aai-docs-samples/espn.m4a"}}
{{AssemblyAIPlugin.Transcribe "https://storage.googleapis.com/aai-docs-samples/espn.m4a"}}
---
Summarize the transcript.
""";
Expand Down
12 changes: 9 additions & 3 deletions src/AssemblyAI.SemanticKernel/AssemblyAI.SemanticKernel.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
<PackageTags>SemanticKernel;AI;AssemblyAI;transcript</PackageTags>
<Company>AssemblyAI</Company>
<Product>AssemblyAI</Product>
<AssemblyVersion>1.0.3.0</AssemblyVersion>
<FileVersion>1.0.3.0</FileVersion>
<PackageVersion>1.0.3</PackageVersion>
<AssemblyVersion>1.1.0.0</AssemblyVersion>
<FileVersion>1.1.0.0</FileVersion>
<PackageVersion>1.1.0</PackageVersion>
<OutputType>Library</OutputType>
<PackageLicenseExpression>MIT</PackageLicenseExpression>
<PackageProjectUrl>https://github.com/AssemblyAI/assemblyai-semantic-kernel</PackageProjectUrl>
Expand All @@ -31,6 +31,12 @@
<ContinuousIntegrationBuild>true</ContinuousIntegrationBuild>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Options">
<Version>8.0.0</Version>
</PackageReference>
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions">
<Version>8.0.0</Version>
</PackageReference>
<PackageReference Include="Microsoft.SemanticKernel">
<Version>1.0.1</Version>
</PackageReference>
Expand Down
161 changes: 161 additions & 0 deletions src/AssemblyAI.SemanticKernel/AssemblyAIPlugin.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
using System;
using System.ComponentModel;
using System.IO;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;

namespace AssemblyAI.SemanticKernel
{
public class AssemblyAIPlugin
{
internal AssemblyAIPluginOptions Options { get; }

private string ApiKey => Options.ApiKey;

private bool AllowFileSystemAccess => Options.AllowFileSystemAccess;

public AssemblyAIPlugin(string apiKey)
{
Options = new AssemblyAIPluginOptions
{
ApiKey = apiKey
};
}

public AssemblyAIPlugin(string apiKey, bool allowFileSystemAccess)
{
Options = new AssemblyAIPluginOptions
{
ApiKey = apiKey,
AllowFileSystemAccess = allowFileSystemAccess
};
}

[ActivatorUtilitiesConstructor]
public AssemblyAIPlugin(IOptions<AssemblyAIPluginOptions> options)
{
Options = options.Value;
}

public const string TranscribeFunctionName = nameof(Transcribe);

[KernelFunction, Description("Transcribe an audio or video file to text.")]
public async Task<string> Transcribe(
[Description("The public URL or the local path of the audio or video file to transcribe.")]
string input
)
{
if (string.IsNullOrEmpty(input))
{
throw new Exception("The INPUT parameter is required.");
}

using (var httpClient = new HttpClient())
{
httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue(ApiKey);
string audioUrl;
if (TryGetPath(input, out var filePath))
{
if (AllowFileSystemAccess == false)
{
throw new Exception(
"You need to allow file system access to upload files. Set AssemblyAI:Plugin:AllowFileSystemAccess to true."
);
}

audioUrl = await UploadFileAsync(filePath, httpClient);
}
else
{
audioUrl = input;
}

var transcript = await CreateTranscriptAsync(audioUrl, httpClient);
transcript = await WaitForTranscriptToProcess(transcript, httpClient);
return transcript.Text ?? throw new Exception("Transcript text is null. This should not happen.");
}
}

private static bool TryGetPath(string input, out string filePath)
{
if (Uri.TryCreate(input, UriKind.Absolute, out var inputUrl))
{
if (inputUrl.IsFile)
{
filePath = inputUrl.LocalPath;
return true;
}

filePath = null;
return false;
}

filePath = input;
return true;
}

private static async Task<string> UploadFileAsync(string path, HttpClient httpClient)
{
using (var fileStream = File.OpenRead(path))
using (var fileContent = new StreamContent(fileStream))
{
fileContent.Headers.ContentType = new MediaTypeHeaderValue("application/octet-stream");
using (var response = await httpClient.PostAsync("https://api.assemblyai.com/v2/upload", fileContent))
{
response.EnsureSuccessStatusCode();
var jsonDoc = await response.Content.ReadFromJsonAsync<JsonDocument>();
return jsonDoc?.RootElement.GetProperty("upload_url").GetString();
}
}
}

private static async Task<Transcript> CreateTranscriptAsync(string audioUrl, HttpClient httpClient)
{
var jsonString = JsonSerializer.Serialize(new
{
audio_url = audioUrl
});

var content = new StringContent(jsonString, Encoding.UTF8, "application/json");
using (var response = await httpClient.PostAsync("https://api.assemblyai.com/v2/transcript", content))
{
response.EnsureSuccessStatusCode();
var transcript = await response.Content.ReadFromJsonAsync<Transcript>();
if (transcript.Status == "error") throw new Exception(transcript.Error);
return transcript;
}
}

private static async Task<Transcript> WaitForTranscriptToProcess(Transcript transcript, HttpClient httpClient)
{
var pollingEndpoint = $"https://api.assemblyai.com/v2/transcript/{transcript.Id}";

while (true)
{
var pollingResponse = await httpClient.GetAsync(pollingEndpoint);
pollingResponse.EnsureSuccessStatusCode();
transcript = (await pollingResponse.Content.ReadFromJsonAsync<Transcript>());
switch (transcript.Status)
{
case "processing":
case "queued":
await Task.Delay(TimeSpan.FromSeconds(3));
break;
case "completed":
return transcript;
case "error":
throw new Exception(transcript.Error);
default:
throw new Exception("This code shouldn't be reachable.");
}
}
}
}
}
26 changes: 26 additions & 0 deletions src/AssemblyAI.SemanticKernel/AssemblyAIPluginOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
namespace AssemblyAI.SemanticKernel
{
/// <summary>
/// Options to configure the AssemblyAI plugin with.
/// </summary>
public class AssemblyAIPluginOptions
{
/// <summary>
/// The name of the plugin registered into Semantic Kernel.
/// Defaults to "AssemblyAIPlugin".
/// </summary>
public string PluginName { get; set; }

/// <summary>
/// The AssemblyAI API key. Find your API key at https://www.assemblyai.com/app/account
/// </summary>
public string ApiKey { get; set; }

/// <summary>
/// If true, you can transcribe audio files from disk.
/// The file be uploaded to AssemblyAI's server to transcribe and deleted when transcription is completed.
/// If false, an exception will be thrown when trying to transcribe files from disk.
/// </summary>
public bool AllowFileSystemAccess { get; set; }
}
}
119 changes: 119 additions & 0 deletions src/AssemblyAI.SemanticKernel/Extensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
using System;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Microsoft.SemanticKernel;

// ReSharper disable UnusedMember.Global
// ReSharper disable MemberCanBePrivate.Global

namespace AssemblyAI.SemanticKernel
{
public static class Extensions
{
/// <summary>
/// Configure the AssemblyAI plugins using the specified configuration section path.
/// </summary>
/// <param name="builder"></param>
/// <param name="configuration">The configuration to bind options to</param>
/// <returns></returns>
public static IKernelBuilder AddAssemblyAIPlugin(
this IKernelBuilder builder,
IConfiguration configuration
)
{
var pluginConfigurationSection = configuration.GetSection("AssemblyAI:Plugin");
// if configuration exists at section, use that config, otherwise using section that was passed in.
if (pluginConfigurationSection.Exists())
{
configuration = pluginConfigurationSection;
}

var services = builder.Services;
var optionsBuilder = services.AddOptions<AssemblyAIPluginOptions>();
optionsBuilder.Bind(configuration);
ValidateOptions(optionsBuilder);
AddPlugin(builder);
return builder;
}

/// <summary>
/// Configure the AssemblyAI plugins using the specified options.
/// </summary>
/// <param name="builder"></param>
/// <param name="options">Options to configure plugin with</param>
/// <returns></returns>
public static IKernelBuilder AddAssemblyAIPlugin(
this IKernelBuilder builder,
AssemblyAIPluginOptions options
)
{
var services = builder.Services;
var optionsBuilder = services.AddOptions<AssemblyAIPluginOptions>();
optionsBuilder.Configure(optionsToConfigure =>
{
optionsToConfigure.ApiKey = options.ApiKey;
optionsToConfigure.AllowFileSystemAccess = options.AllowFileSystemAccess;
});
ValidateOptions(optionsBuilder);
AddPlugin(builder);
return builder;
}

/// <summary>
/// Configure the AssemblyAI plugins using the specified options.
/// </summary>
/// <param name="builder"></param>
/// <param name="configureOptions">Action to configure options</param>
/// <returns></returns>
public static IKernelBuilder AddAssemblyAIPlugin(
this IKernelBuilder builder,
Action<AssemblyAIPluginOptions> configureOptions
)
{
var services = builder.Services;
var optionsBuilder = services.AddOptions<AssemblyAIPluginOptions>();
optionsBuilder.Configure(configureOptions);
ValidateOptions(optionsBuilder);
AddPlugin(builder);
return builder;
}

/// <summary>
/// Configure the AssemblyAI plugins using the specified options.
/// </summary>
/// <param name="builder"></param>
/// <param name="configureOptions">Action to configure options</param>
/// <returns></returns>
public static IKernelBuilder AddAssemblyAIPlugin(
this IKernelBuilder builder,
Action<IServiceProvider, AssemblyAIPluginOptions> configureOptions
)
{
var services = builder.Services;
var optionsBuilder = services.AddOptions<AssemblyAIPluginOptions>();
optionsBuilder.Configure<IServiceProvider>((options, provider) => configureOptions(provider, options));
ValidateOptions(optionsBuilder);
AddPlugin(builder);
return builder;
}

private static void ValidateOptions(OptionsBuilder<AssemblyAIPluginOptions> optionsBuilder)
{
optionsBuilder.Validate(
options => !string.IsNullOrEmpty(options.ApiKey),
"AssemblyAI:Plugin:ApiKey must be configured."
);
}

private static void AddPlugin(IKernelBuilder builder)
{
using (var sp = builder.Services.BuildServiceProvider())
{
var config = sp.GetRequiredService<IOptions<AssemblyAIPluginOptions>>().Value;
var pluginName = string.IsNullOrEmpty(config.PluginName) ? null : config.PluginName;
builder.Plugins.AddFromType<AssemblyAIPlugin>(pluginName);
}
}
}
}
Loading

0 comments on commit 6fd943b

Please sign in to comment.