From 3ce3aa7ea8f6ef37798e9b514e3ec690682bd163 Mon Sep 17 00:00:00 2001 From: Anthony Shaw Date: Mon, 8 Jul 2024 15:06:03 +1000 Subject: [PATCH 1/4] Re-add mistral demo and add tests for double and float defaults --- .../ExamplePythonDependency.csproj | 5 ++--- ExamplePythonDependency/mistral_demo.py | 2 +- PythonSourceGenerator.Tests/TokenizerTests.cs | 22 +++++++++++++++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/ExamplePythonDependency/ExamplePythonDependency.csproj b/ExamplePythonDependency/ExamplePythonDependency.csproj index 56c4ab6b..3badf904 100644 --- a/ExamplePythonDependency/ExamplePythonDependency.csproj +++ b/ExamplePythonDependency/ExamplePythonDependency.csproj @@ -8,7 +8,6 @@ - @@ -23,9 +22,9 @@ Always - + Always diff --git a/ExamplePythonDependency/mistral_demo.py b/ExamplePythonDependency/mistral_demo.py index 5acc5117..b822923a 100644 --- a/ExamplePythonDependency/mistral_demo.py +++ b/ExamplePythonDependency/mistral_demo.py @@ -6,7 +6,7 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest -def invoke_mistral_inference(messages: list[str], lang: str = "en-US", temperature=0.0) -> str: +def invoke_mistral_inference(messages: list[str], lang: str = "en-US", temperature: float=0.0) -> str: tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tokenizer.model.v3") model = Transformer.from_folder(mistral_models_path) diff --git a/PythonSourceGenerator.Tests/TokenizerTests.cs b/PythonSourceGenerator.Tests/TokenizerTests.cs index aef45fd5..a5de9e34 100644 --- a/PythonSourceGenerator.Tests/TokenizerTests.cs +++ b/PythonSourceGenerator.Tests/TokenizerTests.cs @@ -149,6 +149,28 @@ public void ParseFunctionParameterDefaultDoubleQuotedString() Assert.False(result.Value.HasTypeAnnotation()); } + [Fact] + public void ParseFunctionParameterDefaultDouble() + { + var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: float = 0.0"); + var result = PythonSignatureParser.PythonParameterTokenizer.TryParse(tokens); + Assert.True(result.HasValue); + Assert.Equal("a", result.Value.Name); + Assert.Equal("0.0", result.Value.DefaultValue?.ToString()); + Assert.True(result.Value.HasTypeAnnotation()); + } + + [Fact] + public void ParseFunctionParameterDefaultInt() + { + var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: int = 1234"); + var result = PythonSignatureParser.PythonParameterTokenizer.TryParse(tokens); + Assert.True(result.HasValue); + Assert.Equal("a", result.Value.Name); + Assert.Equal("1234", result.Value.DefaultValue?.ToString()); + Assert.True(result.Value.HasTypeAnnotation()); + } + [Fact] public void ParseFunctionParameterListSingleGeneric() { From 375c590971906144f918ac1f020afa35e4530961 Mon Sep 17 00:00:00 2001 From: Anthony Shaw Date: Mon, 8 Jul 2024 15:09:52 +1000 Subject: [PATCH 2/4] Update test for a value that doesn't round to 0 --- PythonSourceGenerator.Tests/TokenizerTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PythonSourceGenerator.Tests/TokenizerTests.cs b/PythonSourceGenerator.Tests/TokenizerTests.cs index a5de9e34..e9b9796a 100644 --- a/PythonSourceGenerator.Tests/TokenizerTests.cs +++ b/PythonSourceGenerator.Tests/TokenizerTests.cs @@ -152,11 +152,11 @@ public void ParseFunctionParameterDefaultDoubleQuotedString() [Fact] public void ParseFunctionParameterDefaultDouble() { - var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: float = 0.0"); + var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: float = -1.1"); var result = PythonSignatureParser.PythonParameterTokenizer.TryParse(tokens); Assert.True(result.HasValue); Assert.Equal("a", result.Value.Name); - Assert.Equal("0.0", result.Value.DefaultValue?.ToString()); + Assert.Equal("-1.1", result.Value.DefaultValue?.ToString()); Assert.True(result.Value.HasTypeAnnotation()); } From 3ab8b5740cdfb99f1e58d5fb906657d129ac21fa Mon Sep 17 00:00:00 2001 From: Anthony Shaw Date: Mon, 8 Jul 2024 15:42:53 +1000 Subject: [PATCH 3/4] True False and None are all special tokens now --- PythonSourceGenerator.Tests/TokenizerTests.cs | 40 ++++++++++++++++++- .../Parser/PythonSignatureParser.cs | 17 ++++++-- .../Parser/PythonSignatureTokenizer.cs | 4 +- .../Parser/PythonSignatureTokens.cs | 3 ++ .../Parser/Types/PythonConstant.cs | 7 ++++ 5 files changed, 65 insertions(+), 6 deletions(-) diff --git a/PythonSourceGenerator.Tests/TokenizerTests.cs b/PythonSourceGenerator.Tests/TokenizerTests.cs index e9b9796a..4c3ad260 100644 --- a/PythonSourceGenerator.Tests/TokenizerTests.cs +++ b/PythonSourceGenerator.Tests/TokenizerTests.cs @@ -24,7 +24,7 @@ public void Tokenize() PythonSignatureTokens.PythonSignatureToken.Identifier, PythonSignatureTokens.PythonSignatureToken.CloseParenthesis, PythonSignatureTokens.PythonSignatureToken.Arrow, - PythonSignatureTokens.PythonSignatureToken.Identifier, + PythonSignatureTokens.PythonSignatureToken.None, PythonSignatureTokens.PythonSignatureToken.Colon, ], tokens.Select(t => t.Kind)); } @@ -39,6 +39,9 @@ public void Tokenize() [InlineData("abc123", PythonSignatureTokens.PythonSignatureToken.Identifier)] [InlineData("'hello'", PythonSignatureTokens.PythonSignatureToken.SingleQuotedString)] [InlineData("\"hello\"", PythonSignatureTokens.PythonSignatureToken.DoubleQuotedString)] + [InlineData("True", PythonSignatureTokens.PythonSignatureToken.True)] + [InlineData("False", PythonSignatureTokens.PythonSignatureToken.False)] + [InlineData("None", PythonSignatureTokens.PythonSignatureToken.None)] public void AssertTokenKinds(string code, PythonSignatureTokens.PythonSignatureToken expectedToken) { var tokens = PythonSignatureTokenizer.Instance.Tokenize(code); @@ -65,7 +68,7 @@ public void TokenizeWithDefaultValue() PythonSignatureTokens.PythonSignatureToken.SingleQuotedString, PythonSignatureTokens.PythonSignatureToken.CloseParenthesis, PythonSignatureTokens.PythonSignatureToken.Arrow, - PythonSignatureTokens.PythonSignatureToken.Identifier, + PythonSignatureTokens.PythonSignatureToken.None, PythonSignatureTokens.PythonSignatureToken.Colon, ], tokens.Select(t => t.Kind)); } @@ -171,6 +174,39 @@ public void ParseFunctionParameterDefaultInt() Assert.True(result.Value.HasTypeAnnotation()); } + [Fact] + public void ParseFunctionParameterDefaultBoolTrue() + { + var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: bool = True"); + var result = PythonSignatureParser.PythonParameterTokenizer.TryParse(tokens); + Assert.True(result.HasValue); + Assert.Equal("a", result.Value.Name); + Assert.Equal("True", result.Value.DefaultValue?.ToString()); + Assert.True(result.Value.HasTypeAnnotation()); + } + + [Fact] + public void ParseFunctionParameterDefaultBoolFalse() + { + var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: bool = False"); + var result = PythonSignatureParser.PythonParameterTokenizer.TryParse(tokens); + Assert.True(result.HasValue); + Assert.Equal("a", result.Value.Name); + Assert.Equal("False", result.Value.DefaultValue?.ToString()); + Assert.True(result.Value.HasTypeAnnotation()); + } + + [Fact] + public void ParseFunctionParameterDefaultNone() + { + var tokens = PythonSignatureTokenizer.Instance.Tokenize($"a: bool = None"); + var result = PythonSignatureParser.PythonParameterTokenizer.TryParse(tokens); + Assert.True(result.HasValue); + Assert.Equal("a", result.Value.Name); + Assert.Equal("None", result.Value.DefaultValue?.ToString()); + Assert.True(result.Value.HasTypeAnnotation()); + } + [Fact] public void ParseFunctionParameterListSingleGeneric() { diff --git a/PythonSourceGenerator/Parser/PythonSignatureParser.cs b/PythonSourceGenerator/Parser/PythonSignatureParser.cs index 8eceb41a..8359aa24 100644 --- a/PythonSourceGenerator/Parser/PythonSignatureParser.cs +++ b/PythonSourceGenerator/Parser/PythonSignatureParser.cs @@ -89,8 +89,8 @@ from exp in Character.EqualToIgnoreCase('e') } public static TokenListParser PythonTypeDefinitionTokenizer { get; } = - (from name in Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.Identifier) - from openBracket in Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.OpenBracket) + (from name in Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.Identifier).Or(Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.None)) + from openBracket in Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.OpenBracket) .Then(_ => PythonTypeDefinitionTokenizer.ManyDelimitedBy( Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.Comma), end: Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.CloseBracket) @@ -123,13 +123,24 @@ from openBracket in Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.Ope .Select(d => new PythonConstant { IsInteger = true, IntegerValue = d }) .Named("Integer Constant"); + public static TokenListParser BoolConstantTokenizer { get; } = + Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.True).Or(Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.False)) + .Select(d => new PythonConstant { IsBool = true, BoolValue = d.Kind == PythonSignatureTokens.PythonSignatureToken.True }) + .Named("Bool Constant"); + + public static TokenListParser NoneConstantTokenizer { get; } = + Token.EqualTo(PythonSignatureTokens.PythonSignatureToken.None) + .Select(d => new PythonConstant { IsNone = true }) + .Named("None Constant"); + // Any constant value public static TokenListParser ConstantValueTokenizer { get; } = DecimalConstantTokenizer.AsNullable() .Or(IntegerConstantTokenizer.AsNullable()) + .Or(BoolConstantTokenizer.AsNullable()) + .Or(NoneConstantTokenizer.AsNullable()) .Or(DoubleQuotedStringConstantTokenizer.AsNullable()) .Or(SingleQuotedStringConstantTokenizer.AsNullable()) - // TODO: Add None token .Named("Constant"); public static TokenListParser PythonParameterTokenizer { get; } = diff --git a/PythonSourceGenerator/Parser/PythonSignatureTokenizer.cs b/PythonSourceGenerator/Parser/PythonSignatureTokenizer.cs index 8b9a967c..fe2e2be7 100644 --- a/PythonSourceGenerator/Parser/PythonSignatureTokenizer.cs +++ b/PythonSourceGenerator/Parser/PythonSignatureTokenizer.cs @@ -22,11 +22,13 @@ public static class PythonSignatureTokenizer .Match(Span.EqualTo("def"), PythonSignatureTokens.PythonSignatureToken.Def, requireDelimiters: true) .Match(Span.EqualTo("async"), PythonSignatureTokens.PythonSignatureToken.Async, requireDelimiters: true) .Match(Span.EqualTo("..."), PythonSignatureTokens.PythonSignatureToken.Ellipsis) + .Match(Span.EqualTo("None"), PythonSignatureTokens.PythonSignatureToken.None, requireDelimiters: true) + .Match(Span.EqualTo("True"), PythonSignatureTokens.PythonSignatureToken.True, requireDelimiters: true) + .Match(Span.EqualTo("False"), PythonSignatureTokens.PythonSignatureToken.False, requireDelimiters: true) .Match(Identifier.CStyle, PythonSignatureTokens.PythonSignatureToken.Identifier, requireDelimiters: true) // TODO: Does this require delimiters? .Match(PythonSignatureParser.IntegerConstantToken, PythonSignatureTokens.PythonSignatureToken.Integer, requireDelimiters: true) .Match(PythonSignatureParser.DecimalConstantToken, PythonSignatureTokens.PythonSignatureToken.Decimal, requireDelimiters: true) .Match(PythonSignatureParser.DoubleQuotedStringConstantToken, PythonSignatureTokens.PythonSignatureToken.DoubleQuotedString) .Match(PythonSignatureParser.SingleQuotedStringConstantToken, PythonSignatureTokens.PythonSignatureToken.SingleQuotedString) - // TODO: Treat None as a special token .Build(); } diff --git a/PythonSourceGenerator/Parser/PythonSignatureTokens.cs b/PythonSourceGenerator/Parser/PythonSignatureTokens.cs index 48c1435f..ff640f76 100644 --- a/PythonSourceGenerator/Parser/PythonSignatureTokens.cs +++ b/PythonSourceGenerator/Parser/PythonSignatureTokens.cs @@ -50,6 +50,9 @@ public enum PythonSignatureToken Decimal, DoubleQuotedString, SingleQuotedString, + True, + False, + None, [Token(Example = "...")] Ellipsis diff --git a/PythonSourceGenerator/Parser/Types/PythonConstant.cs b/PythonSourceGenerator/Parser/Types/PythonConstant.cs index fd3d1382..2990ea7f 100644 --- a/PythonSourceGenerator/Parser/Types/PythonConstant.cs +++ b/PythonSourceGenerator/Parser/Types/PythonConstant.cs @@ -13,6 +13,9 @@ public class PythonConstant public bool IsFloat { get; set; } public double FloatValue { get; set; } + public bool IsBool { get; set; } + public bool BoolValue { get; set; } + public override string ToString() { if (IsInteger) @@ -31,6 +34,10 @@ public override string ToString() { return "None"; } + if (IsBool) + { + return BoolValue.ToString(); + } return "unknown"; } } From ce51f62fa24f04883582234e71b838ff00ed8af6 Mon Sep 17 00:00:00 2001 From: Anthony Shaw Date: Mon, 8 Jul 2024 15:54:39 +1000 Subject: [PATCH 4/4] Propagate bool constant default to method signatures --- PythonSourceGenerator.Tests/SignatureTests.cs | 1 + PythonSourceGenerator/Reflection/ArgumentReflection.cs | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/PythonSourceGenerator.Tests/SignatureTests.cs b/PythonSourceGenerator.Tests/SignatureTests.cs index 68ca3525..002af014 100644 --- a/PythonSourceGenerator.Tests/SignatureTests.cs +++ b/PythonSourceGenerator.Tests/SignatureTests.cs @@ -25,6 +25,7 @@ public BasicSmokeTest(TestEnvironment testEnv) [InlineData("def hello_world(name: str, age: int) -> str:\n ...\n", "string HelloWorld(string name, long age)")] [InlineData("def hello_world(numbers: list[float]) -> list[int]:\n ...\n", "IEnumerable HelloWorld(IEnumerable numbers)")] [InlineData("def hello_world(a: bool, b: str, c: list[tuple[int, float]]) -> bool: \n ...\n", "bool HelloWorld(bool a, string b, IEnumerable> c)")] + [InlineData("def hello_world(a: bool = True, b: str = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string b = null)")] public void TestGeneratedSignature(string code, string expected) { diff --git a/PythonSourceGenerator/Reflection/ArgumentReflection.cs b/PythonSourceGenerator/Reflection/ArgumentReflection.cs index 14b33ba4..47813740 100644 --- a/PythonSourceGenerator/Reflection/ArgumentReflection.cs +++ b/PythonSourceGenerator/Reflection/ArgumentReflection.cs @@ -31,6 +31,10 @@ public static ParameterSyntax ArgumentSyntax(PythonFunctionParameter parameter) literalExpressionSyntax = SyntaxFactory.LiteralExpression( SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(parameter.DefaultValue.FloatValue)); + else if (parameter.DefaultValue.IsBool && parameter.DefaultValue.BoolValue == true) + literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression); + else if (parameter.DefaultValue.IsBool && parameter.DefaultValue.BoolValue == false) + literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression); else if (parameter.DefaultValue.IsNone) literalExpressionSyntax = SyntaxFactory.LiteralExpression( SyntaxKind.NullLiteralExpression);