Skip to content

Commit

Permalink
Merge pull request #6 from tonybaloney/fix_no_empty
Browse files Browse the repository at this point in the history
Update tracking for todos. Add more parser tests
  • Loading branch information
tonybaloney authored Jul 19, 2024
2 parents 9df2dbb + bda6863 commit bd04f9e
Show file tree
Hide file tree
Showing 19 changed files with 294 additions and 56 deletions.
2 changes: 0 additions & 2 deletions AutoUpdateAssemblyName.txt

This file was deleted.

2 changes: 1 addition & 1 deletion ExamplePythonDependency/ExamplePythonDependency.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<AdditionalFiles Include="type_demos.py">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</AdditionalFiles>
<AdditionalFiles Include="mistral_demo.py">
<AdditionalFiles Include="phi3_demo.py">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</AdditionalFiles>
<AdditionalFiles Include="kmeans_example.py">
Expand Down
18 changes: 0 additions & 18 deletions ExamplePythonDependency/mistral_demo.py

This file was deleted.

34 changes: 34 additions & 0 deletions ExamplePythonDependency/phi3_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

def phi3_inference_demo(user_message: str, system_message: str = "You are a helpful AI assistant.", temperature: float = 0.0) -> str:
torch.random.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
device_map="cuda",
torch_dtype="auto",
trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": user_message},
]

pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)

generation_args = {
"max_new_tokens": 500,
"return_full_text": False,
"temperature": temperature,
"do_sample": False,
}

output = pipe(messages, **generation_args)
return output[0]['generated_text']
7 changes: 6 additions & 1 deletion ExamplePythonDependency/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
numpy
scikit-learn
scikit-learn
wheel
flash_attn==2.5.8
torch==2.3.1
accelerate==0.31.0
transformers==4.41.2
6 changes: 3 additions & 3 deletions PythonEnvironments/PythonEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ private static string TryLocatePython(string version)
var versionPath = MapVersion(version);
var windowsStorePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), "Programs", "Python", "Python" + versionPath);
var officialInstallerPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ProgramFiles), "Python", MapVersion(version, "."));
// TODO: Locate from PATH
// TODO: Add standard paths for Linux and MacOS
// TODO: (track) Locate from PATH
// TODO: (track) Add standard paths for Linux and MacOS
if (Directory.Exists(windowsStorePath))
{
return windowsStorePath;
Expand Down Expand Up @@ -115,7 +115,7 @@ public PythonEnvironmentInternal(string pythonLocation, string versionPath, stri
}
else
{
// TODO: C extension path for linux/macos
// TODO: (track) C extension path for linux/macos
}

if (extraPath.Length > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public BasicSmokeTest(TestEnvironment testEnv)
[InlineData("def hello_world(numbers: list[float]) -> list[int]:\n ...\n", "IEnumerable<long> HelloWorld(IEnumerable<double> numbers)")]
[InlineData("def hello_world(a: bool, b: str, c: list[tuple[int, float]]) -> bool: \n ...\n", "bool HelloWorld(bool a, string b, IEnumerable<Tuple<long, double>> c)")]
[InlineData("def hello_world(a: bool = True, b: str = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string b = null)")]
[InlineData("def hello_world(a: str, *args) -> None: \n ...\n", "void HelloWorld(string a, Tuple<PyObject> args = null)")]
[InlineData("def hello_world(a: str, *args, **kwargs) -> None: \n ...\n", "void HelloWorld(string a, Tuple<PyObject> args = null, IReadOnlyDictionary<string, PyObject> kwargs = null)")]
public void TestGeneratedSignature(string code, string expected)
{

Expand All @@ -34,7 +36,7 @@ public void TestGeneratedSignature(string code, string expected)

// create a Python scope
PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);

Assert.Empty(errors);
var module = ModuleReflection.MethodsFromFunctionDefinitions(functions, "test");
var csharp = module.Select(m => m.Syntax).Compile();
Assert.Contains(expected, csharp);
Expand All @@ -49,7 +51,7 @@ public void TestGeneratedSignature(string code, string expected)
.AddReferences(MetadataReference.CreateFromFile(typeof(IReadOnlyDictionary<,>).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(PythonEnvironments.PythonEnvironment).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(Py).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "netstandard").Location)) // TODO: Ensure 2.0
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "netstandard").Location)) // TODO: (track) Ensure 2.0
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Runtime").Location))
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Collections").Location))
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Linq.Expressions").Location))
Expand Down
25 changes: 25 additions & 0 deletions PythonSourceGenerator.Tests/CaseHelperTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace PythonSourceGenerator.Tests;

public class CaseHelperTests
{
[Fact]
public void VerifyToPascalCase()
{
Assert.Equal("Hello", CaseHelper.ToPascalCase("hello"));
Assert.Equal("HelloWorld", CaseHelper.ToPascalCase("hello_world"));
Assert.Equal("Hello_", CaseHelper.ToPascalCase("hello_"));
Assert.Equal("Hello_World", CaseHelper.ToPascalCase("hello__world"));
Assert.Equal("_Hello_World", CaseHelper.ToPascalCase("_hello__world"));
}

[Fact]
public void VerifyToLowerPascalCase()
{
Assert.Equal("hello", CaseHelper.ToLowerPascalCase("hello"));
Assert.Equal("helloWorld", CaseHelper.ToLowerPascalCase("hello_world"));
Assert.Equal("hello_", CaseHelper.ToLowerPascalCase("hello_"));
Assert.Equal("hello_World", CaseHelper.ToLowerPascalCase("hello__world"));
// TODO: (track) This instance could arguably be _hello_World although the name is already weird
Assert.Equal("_Hello_World", CaseHelper.ToLowerPascalCase("_hello__world"));
}
}
62 changes: 62 additions & 0 deletions PythonSourceGenerator.Tests/IntegrationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis;
using Python.Runtime;
using PythonSourceGenerator.Parser;
using PythonSourceGenerator.Reflection;

namespace PythonSourceGenerator.Tests;

public class IntegrationTests : IClassFixture<TestEnvironment>
{
TestEnvironment testEnv;

public IntegrationTests(TestEnvironment testEnv)
{
this.testEnv = testEnv;
}

private bool Compile(string code, string assemblyName)
{
var tempName = string.Format("{0}_{1:N}", "test", Guid.NewGuid().ToString("N"));
File.WriteAllText(Path.Combine(testEnv.TempDir, $"{tempName}.py"), code);

// create a Python scope
PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);
Assert.Empty(errors);
var module = ModuleReflection.MethodsFromFunctionDefinitions(functions, "test");
var csharp = module.Select(m => m.Syntax).Compile();

// Check that the sample C# code compiles
string compiledCode = PythonStaticGenerator.FormatClassFromMethods("Python.Generated.Tests", "TestClass", module);
var tree = CSharpSyntaxTree.ParseText(compiledCode);
var compilation = CSharpCompilation.Create(assemblyName, options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary))
.AddReferences(MetadataReference.CreateFromFile(typeof(object).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(IEnumerable<>).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(IReadOnlyDictionary<,>).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(PythonEnvironments.PythonEnvironment).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(typeof(Py).Assembly.Location))
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "netstandard").Location)) // TODO: (track) Ensure 2.0
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Runtime").Location))
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Collections").Location))
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Linq.Expressions").Location))

.AddSyntaxTrees(tree);
var path = testEnv.TempDir + $"/{assemblyName}.dll";
var result = compilation.Emit(path);
Assert.True(result.Success, compiledCode + "\n" + string.Join("\n", result.Diagnostics));
// Delete assembly
File.Delete(path);
return result.Success;
}

[Fact]
public void TestBasicString()
{
var code = """
def foo(in_: str) -> str:
return in_.upper()
""";
Assert.True(Compile(code, "stringFoo"));
}
}
64 changes: 61 additions & 3 deletions PythonSourceGenerator.Tests/TokenizerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def baz(c: float, d: bool) -> None:
xyz = 1
""";
_ = PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);

Assert.Empty(errors);
Assert.NotNull(functions);
Assert.Equal(2, functions.Length);
Assert.Equal("bar", functions[0].Name);
Expand All @@ -332,7 +332,7 @@ def baz(c: float, d: bool) -> None:
[Fact]
public void ParseMultiLineFunctionDefinition()
{
var code = @"""
var code = @"
import foo
def bar(a: int,
Expand All @@ -343,9 +343,52 @@ def bar(a: int,
if __name__ == '__main__':
xyz = 1
""";
";
_ = PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);
Assert.Empty(errors);
Assert.NotNull(functions);
Assert.Single(functions);
Assert.Equal("bar", functions[0].Name);
Assert.Equal("a", functions[0].Parameters[0].Name);
Assert.Equal("int", functions[0].Parameters[0].Type.Name);
Assert.Equal("b", functions[0].Parameters[1].Name);
Assert.Equal("str", functions[0].Parameters[1].Type.Name);
Assert.Equal("None", functions[0].ReturnType.Name);
}

[Fact]
public void ParseFunctionWithTrailingComment()
{
var code = @"def bar(a: int, b: str) -> None: # this is a comment
pass";
_ = PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);
Assert.Empty(errors);
Assert.NotNull(functions);
Assert.Single(functions);
Assert.Equal("bar", functions[0].Name);
}

[Fact]
public void ParseFunctionTrailingSpaceAfterColon()
{
var code = @"def bar(a: int,
b: str) -> None:
pass"; // There is a trailing space after None:
_ = PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);
Assert.Empty(errors);
Assert.NotNull(functions);
Assert.Single(functions);
Assert.Equal("bar", functions[0].Name);
}

[Fact]
public void ParseFunctionNoBlankLineAtEnd()
{
var code = @"def bar(a: int,
b: str) -> None:
pass";
_ = PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);
Assert.Empty(errors);
Assert.NotNull(functions);
Assert.Single(functions);
Assert.Equal("bar", functions[0].Name);
Expand All @@ -355,4 +398,19 @@ def bar(a: int,
Assert.Equal("str", functions[0].Parameters[1].Type.Name);
Assert.Equal("None", functions[0].ReturnType.Name);
}

[Fact]
public void VerifyErrors()
{
var code = @"
def bar(a: int, b:= str) -> None:
pass";
_ = PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors);
Assert.NotEmpty(errors);
Assert.Equal(4, errors[0].StartLine);
Assert.Equal(4, errors[0].EndLine);
}
}
2 changes: 1 addition & 1 deletion PythonSourceGenerator/CaseHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ public static class CaseHelper
{
public static string ToPascalCase(this string snakeCase)
{
return string.Join("", snakeCase.Split('_').Select(s => char.ToUpperInvariant(s[0]) + s.Substring(1)));
return string.Join("", snakeCase.Split('_').Select(s => s.Length > 1 ? char.ToUpperInvariant(s[0]) + s.Substring(1): "_"));
}

public static string ToLowerPascalCase(this string snakeCase)
Expand Down
27 changes: 27 additions & 0 deletions PythonSourceGenerator/GeneratorError.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
namespace PythonSourceGenerator;

public class GeneratorError
{

public string Message { get; }

public int StartLine { get; }

public int StartColumn { get; }

public int EndLine { get; }

public int EndColumn { get; }

public string Code { get; }

public GeneratorError(int startLine, int endLine, int startColumn, int endColumn, string message)
{
Message = message;
StartLine = startLine;
StartColumn = startColumn;
EndLine = endLine;
EndColumn = endColumn;
Code = "hello";
}
}
Loading

0 comments on commit bd04f9e

Please sign in to comment.