Skip to content

Commit e61be89

Browse files
tonybaloneyatifazizCopilot
authored
Optionally embed source code for modules in generated classes (#451)
* Add an import from string mechanism * Decode the base64 in the module loader * Make configurable and default to off * Use const instead of static * Update snapshots * Apply suggestions from code review Co-authored-by: Copilot <[email protected]> Co-authored-by: Atif Aziz <[email protected]> * Test more weird arguments * Clean up the import function with PyObject classes * Fix syntax * Add relative path --------- Co-authored-by: Atif Aziz <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 35db15a commit e61be89

28 files changed

+407
-43
lines changed

src/CSnakes.Runtime.Tests/Python/ImportTests.cs

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,66 @@ public class ImportTests : RuntimeTestBase
66
[Fact]
77
public void TestImportModule()
88
{
9-
using (GIL.Acquire())
10-
{
11-
using PyObject sys = Import.ImportModule("sys");
12-
Assert.NotNull(sys);
13-
Assert.Equal("<module 'sys' (built-in)>", sys.ToString());
14-
}
9+
using PyObject sys = Import.ImportModule("sys");
10+
Assert.NotNull(sys);
11+
Assert.Equal("<module 'sys' (built-in)>", sys.ToString());
1512
}
1613

1714
[Fact]
1815
public void TestReloadModule()
1916
{
20-
using (GIL.Acquire())
21-
{
22-
PyObject sys = Import.ImportModule("sys");
23-
Assert.NotNull(sys);
24-
Import.ReloadModule(ref sys);
25-
Assert.Equal("<module 'sys' (built-in)>", sys.ToString());
26-
}
17+
PyObject sys = Import.ImportModule("sys");
18+
Assert.NotNull(sys);
19+
Import.ReloadModule(ref sys);
20+
Assert.Equal("<module 'sys' (built-in)>", sys.ToString());
2721
}
2822

2923
[Fact]
3024
public void TestReloadModuleThatIsntAModule()
3125
{
32-
using (GIL.Acquire())
33-
{
34-
PyObject sys = PyObject.From(42); // definitely not a module
35-
Assert.NotNull(sys);
36-
Assert.Throws<PythonInvocationException>(() => Import.ReloadModule(ref sys));
37-
}
26+
PyObject sys = PyObject.From(42); // definitely not a module
27+
Assert.NotNull(sys);
28+
Assert.Throws<PythonInvocationException>(() => Import.ReloadModule(ref sys));
29+
}
30+
31+
[Fact]
32+
public void TestImportFromString()
33+
{
34+
string source = "print('hello world')"; // Python code string that prints "hello world"
35+
string path = Environment.CurrentDirectory;
36+
using PyObject module = Import.ImportModule("test_module", source, path);
37+
Assert.NotNull(module);
38+
Assert.StartsWith("<module 'test_module' ", module.ToString());
39+
}
40+
41+
[Fact]
42+
public void TestImportFromStringWithInvalidCode()
43+
{
44+
string source = "print('hello world'"; // Invalid Python code
45+
string path = Environment.CurrentDirectory;
46+
Assert.Throws<PythonInvocationException>(() => Import.ImportModule("test_module", source, path));
47+
}
48+
49+
[Fact]
50+
public void TestImportFromStringWithInvalidPath()
51+
{
52+
string source = "print('hello world')"; // Valid Python code
53+
Assert.Throws<ArgumentException>(() => Import.ImportModule("", source, ""));
54+
}
55+
56+
[Fact]
57+
public void TestImportFromStringWithNullValues()
58+
{
59+
#pragma warning disable CS8625 // Cannot convert null literal to non-nullable reference type.
60+
Assert.Throws<ArgumentNullException>(() => Import.ImportModule("test_module", null, null));
61+
#pragma warning restore CS8625 // Cannot convert null literal to non-nullable reference type.
62+
}
63+
64+
[Fact]
65+
public void TestImportFromStringWithEmptyValues()
66+
{
67+
string source = ""; // Invalid Python code
68+
string path = ""; // Invalid path
69+
Assert.Throws<ArgumentException>(() => Import.ImportModule("test_module", source, path));
3870
}
3971
}

src/CSnakes.Runtime/CPython/Import.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ namespace CSnakes.Runtime.CPython;
55

66
internal unsafe partial class CPythonAPI
77
{
8+
internal enum OptimizationLevel
9+
{
10+
Default = -1,
11+
None = 0,
12+
RemoveAssertions = 1,
13+
RemoveAssertionsAndDocStrings = 2,
14+
}
15+
816
/// <summary>
917
/// Import a module and return a reference to it.
1018
/// </summary>
@@ -18,6 +26,18 @@ internal static PyObject Import(string name)
1826
return PyObject.Create(module);
1927
}
2028

29+
internal static PyObject Import(string name, string code, string path, OptimizationLevel optimizationLevel = OptimizationLevel.Default)
30+
{
31+
using var pyPath = PyObject.From(path);
32+
33+
using var codeObject = PyObject.Create(Py_CompileStringObject(code, pyPath, InputType.Py_file_input, IntPtr.Zero, optimizationLevel));
34+
35+
using var pyName = PyObject.From(name);
36+
using var pyCode = PyObject.From(code);
37+
38+
return PyObject.Create(PyImport_ExecCodeModuleObject(pyName, codeObject, pyPath, pyPath));
39+
}
40+
2141
internal static PyObject ReloadModule(PyObject module)
2242
{
2343
nint reloaded = PyImport_ReloadModule(module);
@@ -48,6 +68,9 @@ protected static nint GetBuiltin(string name)
4868
internal static partial nint PyImport_Import(nint name);
4969

5070

71+
[LibraryImport(PythonLibraryName)]
72+
private static partial nint PyImport_ExecCodeModuleObject(PyObject name, PyObject co, PyObject pathname, PyObject cpathname);
73+
5174
/// <summary>
5275
/// Reload a module. Return a new reference to the reloaded module, or NULL with an exception set on failure (the module still exists in this case).
5376
/// </summary>
@@ -56,4 +79,6 @@ protected static nint GetBuiltin(string name)
5679
[LibraryImport(PythonLibraryName)]
5780
internal static partial nint PyImport_ReloadModule(PyObject module);
5881

82+
[LibraryImport(PythonLibraryName, StringMarshalling = StringMarshalling.Custom, StringMarshallingCustomType = typeof(NonFreeUtf8StringMarshaller))]
83+
private static partial nint Py_CompileStringObject(string code, PyObject filename, InputType start, nint flags = 0, OptimizationLevel opt = OptimizationLevel.Default);
5984
}

src/CSnakes.Runtime/Python/Import.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,18 @@ public static PyObject ImportModule(string module)
1010
return CPythonAPI.Import(module);
1111
}
1212

13+
public static PyObject ImportModule(string module, string source, string path)
14+
{
15+
ArgumentException.ThrowIfNullOrEmpty(module);
16+
ArgumentException.ThrowIfNullOrEmpty(source);
17+
ArgumentException.ThrowIfNullOrEmpty(path);
18+
19+
using (GIL.Acquire())
20+
{
21+
return CPythonAPI.Import(module, source, path);
22+
}
23+
}
24+
1325
public static void ReloadModule(ref PyObject module)
1426
{
1527
using (GIL.Acquire())

src/CSnakes.SourceGeneration/PythonStaticGenerator.cs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using CSnakes.Parser;
22
using CSnakes.Parser.Types;
33
using CSnakes.Reflection;
4+
using CSnakes.SourceGeneration;
45
using Microsoft.CodeAnalysis;
56
using Microsoft.CodeAnalysis.CSharp;
67
using Microsoft.CodeAnalysis.Text;
@@ -17,20 +18,33 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
1718
var pythonFilesPipeline = context.AdditionalTextsProvider
1819
.Where(static text => Path.GetExtension(text.Path) == ".py");
1920

20-
context.RegisterSourceOutput(pythonFilesPipeline, static (sourceContext, file) =>
21+
// Get analyser config options
22+
var embedPythonSource = context.AnalyzerConfigOptionsProvider.Select(static (options, cancellationToken) =>
23+
options.GlobalOptions.TryGetValue("csnakes_embed_source", out var embedSourceSwitch)
24+
? embedSourceSwitch.Equals("true", StringComparison.InvariantCultureIgnoreCase)
25+
: false); // Default
26+
27+
context.RegisterSourceOutput(pythonFilesPipeline.Combine(embedPythonSource), static (sourceContext, opts) =>
2128
{
29+
var file = opts.Left;
30+
var embedSourceSwitch = opts.Right;
31+
2232
// Add environment path
2333
var @namespace = "CSnakes.Runtime";
2434

2535
var fileName = Path.GetFileNameWithoutExtension(file.Path);
2636

2737
// Convert snake_case to PascalCase
28-
var pascalFileName = string.Join("", fileName.Split('_').Select(s => char.ToUpperInvariant(s[0]) + s.Substring(1)));
38+
var pascalFileName = string.Join("", fileName.Split('_').Select(s => char.ToUpperInvariant(s[0]) + s[1..]));
39+
2940
// Read the file
3041
var code = file.GetText(sourceContext.CancellationToken);
3142

3243
if (code is null) return;
3344

45+
// Decide whether to embed the source based on project settings
46+
var embedSource = embedSourceSwitch ? code.ToBaseUTF864() : string.Empty;
47+
3448
// Calculate hash of code
3549
var hash = code.GetContentHash();
3650

@@ -47,14 +61,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
4761
if (result)
4862
{
4963
var methods = ModuleReflection.MethodsFromFunctionDefinitions(functions, fileName).ToImmutableArray();
50-
string source = FormatClassFromMethods(@namespace, pascalFileName, methods, fileName, functions, hash);
64+
string source = FormatClassFromMethods(@namespace, pascalFileName, methods, fileName, functions, hash, embedSource);
5165
sourceContext.AddSource($"{pascalFileName}.py.cs", source);
5266
sourceContext.ReportDiagnostic(Diagnostic.Create(new DiagnosticDescriptor("PSG002", "PythonStaticGenerator", $"Generated {pascalFileName}.py.cs", "PythonStaticGenerator", DiagnosticSeverity.Info, true), Location.None));
5367
}
5468
});
5569
}
5670

57-
public static string FormatClassFromMethods(string @namespace, string pascalFileName, ImmutableArray<MethodDefinition> methods, string fileName, PythonFunctionDefinition[] functions, ImmutableArray<byte> hash)
71+
public static string FormatClassFromMethods(string @namespace, string pascalFileName, ImmutableArray<MethodDefinition> methods, string fileName, PythonFunctionDefinition[] functions, ImmutableArray<byte> hash, string? base64Value = null)
5872
{
5973
var paramGenericArgs = methods
6074
.Select(m => m.ParameterGenericArgs)
@@ -77,6 +91,7 @@ public static string FormatClassFromMethods(string @namespace, string pascalFile
7791
using System.Collections.Generic;
7892
using System.Diagnostics;
7993
using System.Reflection.Metadata;
94+
using System.Text;
8095
using System.Threading;
8196
using System.Threading.Tasks;
8297
@@ -92,6 +107,8 @@ public static class {{pascalFileName}}Extensions
92107
93108
private static ReadOnlySpan<byte> HotReloadHash => "{{HexString(hash.AsSpan())}}"u8;
94109
110+
private const string encodedSource = "{{base64Value}}";
111+
95112
public static I{{pascalFileName}} {{pascalFileName}}(this IPythonEnvironment env)
96113
{
97114
if (instance is null)
@@ -122,7 +139,9 @@ from f in functionNames
122139
using (GIL.Acquire())
123140
{
124141
logger?.LogDebug("Importing module {ModuleName}", "{{fileName}}");
125-
module = Import.ImportModule("{{fileName}}");
142+
this.module = !string.IsNullOrEmpty(encodedSource)
143+
? Import.ImportModule("{{fileName}}", Encoding.UTF8.GetString(Convert.FromBase64String(encodedSource)), "{{fileName}}.py")
144+
: Import.ImportModule("{{fileName}}");
126145
{{ Lines(IndentationLevel.Four,
127146
from f in functionNames
128147
select $"this.{f.Field} = module.GetAttr(\"{f.Attr}\");") }}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using Microsoft.CodeAnalysis.Text;
2+
using System.Text;
3+
4+
namespace CSnakes.SourceGeneration;
5+
internal static class SourceFileUtils
6+
{
7+
internal static string ToBaseUTF864(this SourceText text)
8+
{
9+
var bytes = Encoding.UTF8.GetBytes(text.ToString());
10+
return Convert.ToBase64String(bytes);
11+
}
12+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Get all .received.txt files in the current directory and subdirectories
2+
$receivedFiles = Get-ChildItem -Path . -Filter *.received.txt -Recurse
3+
4+
foreach ($file in $receivedFiles) {
5+
# Construct the path for the corresponding .approved.txt file
6+
$approvedFile = $file.FullName -replace '\.received\.txt$', '.approved.txt'
7+
8+
# Copy the .received.txt file to .approved.txt, overwriting if it exists
9+
Copy-Item -Path $file.FullName -Destination $approvedFile -Force
10+
}
11+
12+
Write-Host "All .received.txt files have been copied to .approved.txt."

src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_args.approved.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using System;
1010
using System.Collections.Generic;
1111
using System.Diagnostics;
1212
using System.Reflection.Metadata;
13+
using System.Text;
1314
using System.Threading;
1415
using System.Threading.Tasks;
1516

@@ -25,6 +26,8 @@ public static class TestClassExtensions
2526

2627
private static ReadOnlySpan<byte> HotReloadHash => "fb3f77ef96b716122ce4f1a0f0115065"u8;
2728

29+
private const string encodedSource = "";
30+
2831
public static ITestClass TestClass(this IPythonEnvironment env)
2932
{
3033
if (instance is null)
@@ -58,7 +61,9 @@ public static class TestClassExtensions
5861
using (GIL.Acquire())
5962
{
6063
logger?.LogDebug("Importing module {ModuleName}", "test");
61-
module = Import.ImportModule("test");
64+
this.module = !string.IsNullOrEmpty(encodedSource)
65+
? Import.ImportModule("test", Encoding.UTF8.GetString(Convert.FromBase64String(encodedSource)), "test.py")
66+
: Import.ImportModule("test");
6267
this.__func_positional_only_args = module.GetAttr("positional_only_args");
6368
this.__func_collect_star_args = module.GetAttr("collect_star_args");
6469
this.__func_keyword_only_args = module.GetAttr("keyword_only_args");

src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_args_underscore.approved.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using System;
1010
using System.Collections.Generic;
1111
using System.Diagnostics;
1212
using System.Reflection.Metadata;
13+
using System.Text;
1314
using System.Threading;
1415
using System.Threading.Tasks;
1516

@@ -25,6 +26,8 @@ public static class TestClassExtensions
2526

2627
private static ReadOnlySpan<byte> HotReloadHash => "b689cac043d82a1e3830504a67fe21c6"u8;
2728

29+
private const string encodedSource = "";
30+
2831
public static ITestClass TestClass(this IPythonEnvironment env)
2932
{
3033
if (instance is null)
@@ -53,7 +56,9 @@ public static class TestClassExtensions
5356
using (GIL.Acquire())
5457
{
5558
logger?.LogDebug("Importing module {ModuleName}", "test");
56-
module = Import.ImportModule("test");
59+
this.module = !string.IsNullOrEmpty(encodedSource)
60+
? Import.ImportModule("test", Encoding.UTF8.GetString(Convert.FromBase64String(encodedSource)), "test.py")
61+
: Import.ImportModule("test");
5762
this.__func_test_with_underscore = module.GetAttr("test_with_underscore");
5863
}
5964
}

src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_basic.approved.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using System;
1010
using System.Collections.Generic;
1111
using System.Diagnostics;
1212
using System.Reflection.Metadata;
13+
using System.Text;
1314
using System.Threading;
1415
using System.Threading.Tasks;
1516

@@ -25,6 +26,8 @@ public static class TestClassExtensions
2526

2627
private static ReadOnlySpan<byte> HotReloadHash => "c9f266d12c416006de17b285369f898f"u8;
2728

29+
private const string encodedSource = "";
30+
2831
public static ITestClass TestClass(this IPythonEnvironment env)
2932
{
3033
if (instance is null)
@@ -63,7 +66,9 @@ public static class TestClassExtensions
6366
using (GIL.Acquire())
6467
{
6568
logger?.LogDebug("Importing module {ModuleName}", "test");
66-
module = Import.ImportModule("test");
69+
this.module = !string.IsNullOrEmpty(encodedSource)
70+
? Import.ImportModule("test", Encoding.UTF8.GetString(Convert.FromBase64String(encodedSource)), "test.py")
71+
: Import.ImportModule("test");
6772
this.__func_test_int_float = module.GetAttr("test_int_float");
6873
this.__func_test_int_int = module.GetAttr("test_int_int");
6974
this.__func_test_float_float = module.GetAttr("test_float_float");

src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_buffer.approved.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using System;
1010
using System.Collections.Generic;
1111
using System.Diagnostics;
1212
using System.Reflection.Metadata;
13+
using System.Text;
1314
using System.Threading;
1415
using System.Threading.Tasks;
1516

@@ -25,6 +26,8 @@ public static class TestClassExtensions
2526

2627
private static ReadOnlySpan<byte> HotReloadHash => "cabb6a0854c73088e04547b226cc6c79"u8;
2728

29+
private const string encodedSource = "";
30+
2831
public static ITestClass TestClass(this IPythonEnvironment env)
2932
{
3033
if (instance is null)
@@ -83,7 +86,9 @@ public static class TestClassExtensions
8386
using (GIL.Acquire())
8487
{
8588
logger?.LogDebug("Importing module {ModuleName}", "test");
86-
module = Import.ImportModule("test");
89+
this.module = !string.IsNullOrEmpty(encodedSource)
90+
? Import.ImportModule("test", Encoding.UTF8.GetString(Convert.FromBase64String(encodedSource)), "test.py")
91+
: Import.ImportModule("test");
8792
this.__func_test_bool_buffer = module.GetAttr("test_bool_buffer");
8893
this.__func_test_int8_buffer = module.GetAttr("test_int8_buffer");
8994
this.__func_test_uint8_buffer = module.GetAttr("test_uint8_buffer");

0 commit comments

Comments
 (0)