-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #6 from tonybaloney/fix_no_empty
Update tracking for todos. Add more parser tests
- Loading branch information
Showing
19 changed files
with
294 additions
and
56 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | ||
|
||
def phi3_inference_demo(user_message: str, system_message: str = "You are a helpful AI assistant.", temperature: float = 0.0) -> str: | ||
torch.random.manual_seed(0) | ||
model = AutoModelForCausalLM.from_pretrained( | ||
"microsoft/Phi-3-mini-4k-instruct", | ||
device_map="cuda", | ||
torch_dtype="auto", | ||
trust_remote_code=True, | ||
) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") | ||
|
||
messages = [ | ||
{"role": "system", "content": system_message}, | ||
{"role": "user", "content": user_message}, | ||
] | ||
|
||
pipe = pipeline( | ||
"text-generation", | ||
model=model, | ||
tokenizer=tokenizer, | ||
) | ||
|
||
generation_args = { | ||
"max_new_tokens": 500, | ||
"return_full_text": False, | ||
"temperature": temperature, | ||
"do_sample": False, | ||
} | ||
|
||
output = pipe(messages, **generation_args) | ||
return output[0]['generated_text'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,7 @@ | ||
numpy | ||
scikit-learn | ||
scikit-learn | ||
wheel | ||
flash_attn==2.5.8 | ||
torch==2.3.1 | ||
accelerate==0.31.0 | ||
transformers==4.41.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
namespace PythonSourceGenerator.Tests; | ||
|
||
public class CaseHelperTests | ||
{ | ||
[Fact] | ||
public void VerifyToPascalCase() | ||
{ | ||
Assert.Equal("Hello", CaseHelper.ToPascalCase("hello")); | ||
Assert.Equal("HelloWorld", CaseHelper.ToPascalCase("hello_world")); | ||
Assert.Equal("Hello_", CaseHelper.ToPascalCase("hello_")); | ||
Assert.Equal("Hello_World", CaseHelper.ToPascalCase("hello__world")); | ||
Assert.Equal("_Hello_World", CaseHelper.ToPascalCase("_hello__world")); | ||
} | ||
|
||
[Fact] | ||
public void VerifyToLowerPascalCase() | ||
{ | ||
Assert.Equal("hello", CaseHelper.ToLowerPascalCase("hello")); | ||
Assert.Equal("helloWorld", CaseHelper.ToLowerPascalCase("hello_world")); | ||
Assert.Equal("hello_", CaseHelper.ToLowerPascalCase("hello_")); | ||
Assert.Equal("hello_World", CaseHelper.ToLowerPascalCase("hello__world")); | ||
// TODO: (track) This instance could arguably be _hello_World although the name is already weird | ||
Assert.Equal("_Hello_World", CaseHelper.ToLowerPascalCase("_hello__world")); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
using Microsoft.CodeAnalysis.CSharp; | ||
using Microsoft.CodeAnalysis; | ||
using Python.Runtime; | ||
using PythonSourceGenerator.Parser; | ||
using PythonSourceGenerator.Reflection; | ||
|
||
namespace PythonSourceGenerator.Tests; | ||
|
||
public class IntegrationTests : IClassFixture<TestEnvironment> | ||
{ | ||
TestEnvironment testEnv; | ||
|
||
public IntegrationTests(TestEnvironment testEnv) | ||
{ | ||
this.testEnv = testEnv; | ||
} | ||
|
||
private bool Compile(string code, string assemblyName) | ||
{ | ||
var tempName = string.Format("{0}_{1:N}", "test", Guid.NewGuid().ToString("N")); | ||
File.WriteAllText(Path.Combine(testEnv.TempDir, $"{tempName}.py"), code); | ||
|
||
// create a Python scope | ||
PythonSignatureParser.TryParseFunctionDefinitions(code, out var functions, out var errors); | ||
Assert.Empty(errors); | ||
var module = ModuleReflection.MethodsFromFunctionDefinitions(functions, "test"); | ||
var csharp = module.Select(m => m.Syntax).Compile(); | ||
|
||
// Check that the sample C# code compiles | ||
string compiledCode = PythonStaticGenerator.FormatClassFromMethods("Python.Generated.Tests", "TestClass", module); | ||
var tree = CSharpSyntaxTree.ParseText(compiledCode); | ||
var compilation = CSharpCompilation.Create(assemblyName, options: new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) | ||
.AddReferences(MetadataReference.CreateFromFile(typeof(object).Assembly.Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(typeof(IEnumerable<>).Assembly.Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(typeof(Enumerable).Assembly.Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(typeof(IReadOnlyDictionary<,>).Assembly.Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(typeof(PythonEnvironments.PythonEnvironment).Assembly.Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(typeof(Py).Assembly.Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "netstandard").Location)) // TODO: (track) Ensure 2.0 | ||
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Runtime").Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Collections").Location)) | ||
.AddReferences(MetadataReference.CreateFromFile(AppDomain.CurrentDomain.GetAssemblies().Single(a => a.GetName().Name == "System.Linq.Expressions").Location)) | ||
|
||
.AddSyntaxTrees(tree); | ||
var path = testEnv.TempDir + $"/{assemblyName}.dll"; | ||
var result = compilation.Emit(path); | ||
Assert.True(result.Success, compiledCode + "\n" + string.Join("\n", result.Diagnostics)); | ||
// Delete assembly | ||
File.Delete(path); | ||
return result.Success; | ||
} | ||
|
||
[Fact] | ||
public void TestBasicString() | ||
{ | ||
var code = """ | ||
def foo(in_: str) -> str: | ||
return in_.upper() | ||
"""; | ||
Assert.True(Compile(code, "stringFoo")); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
namespace PythonSourceGenerator; | ||
|
||
public class GeneratorError | ||
{ | ||
|
||
public string Message { get; } | ||
|
||
public int StartLine { get; } | ||
|
||
public int StartColumn { get; } | ||
|
||
public int EndLine { get; } | ||
|
||
public int EndColumn { get; } | ||
|
||
public string Code { get; } | ||
|
||
public GeneratorError(int startLine, int endLine, int startColumn, int endColumn, string message) | ||
{ | ||
Message = message; | ||
StartLine = startLine; | ||
StartColumn = startColumn; | ||
EndLine = endLine; | ||
EndColumn = endColumn; | ||
Code = "hello"; | ||
} | ||
} |
Oops, something went wrong.