diff --git a/src/CSnakes.SourceGeneration/Parser/PythonParser.Parameters.cs b/src/CSnakes.SourceGeneration/Parser/PythonParser.Parameters.cs index dce0cc28..ae8416f0 100644 --- a/src/CSnakes.SourceGeneration/Parser/PythonParser.Parameters.cs +++ b/src/CSnakes.SourceGeneration/Parser/PythonParser.Parameters.cs @@ -18,7 +18,13 @@ from type in Token.EqualTo(PythonToken.Colon).Optional().Then( from defaultValue in Token.EqualTo(PythonToken.Equal).Optional().Then( _ => ConstantValueTokenizer.AssumeNotNull().OptionalOrDefault() ) - select new PythonFunctionParameter(arg.Name, type, defaultValue, arg.ParameterType)) + select new PythonFunctionParameter(arg.Name, type, + // Force a default value for *args and **kwargs as null, otherwise the calling convention is strange + arg.ParameterType is PythonFunctionParameterType.Star or PythonFunctionParameterType.DoubleStar + && defaultValue is null + ? PythonConstant.None.Value + : defaultValue, + arg.ParameterType)) .Named("Parameter"); public static TokenListParser ParameterOrSlash { get; } = diff --git a/src/CSnakes.SourceGeneration/Parser/Types/PythonFunctionParameter.cs b/src/CSnakes.SourceGeneration/Parser/Types/PythonFunctionParameter.cs index af9b570f..81695171 100644 --- a/src/CSnakes.SourceGeneration/Parser/Types/PythonFunctionParameter.cs +++ b/src/CSnakes.SourceGeneration/Parser/Types/PythonFunctionParameter.cs @@ -9,7 +9,7 @@ public class PythonFunctionParameter(string name, PythonTypeSpec? type, PythonCo public bool IsKeywordOnly { get; set; } - public PythonConstant? DefaultValue { get; set; } = defaultValue; + public PythonConstant? DefaultValue { get; } = defaultValue; public PythonFunctionParameterType ParameterType { get; } = parameterType; diff --git a/src/CSnakes.SourceGeneration/Reflection/ArgumentReflection.cs b/src/CSnakes.SourceGeneration/Reflection/ArgumentReflection.cs index 84cb81e7..7e173e7d 100644 --- a/src/CSnakes.SourceGeneration/Reflection/ArgumentReflection.cs +++ b/src/CSnakes.SourceGeneration/Reflection/ArgumentReflection.cs @@ -6,8 +6,8 @@ namespace CSnakes.Reflection; public class ArgumentReflection { - private static readonly PythonTypeSpec DictStrAny = new("dict", [new("str", []), PythonTypeSpec.Any]); - private static readonly TypeSyntax ArrayPyObject = SyntaxFactory.ParseTypeName("PyObject[]"); + private static readonly TypeSyntax StarReflectedType = SyntaxFactory.ParseTypeName("PyObject[]?"); + private static readonly TypeSyntax DoubleStarReflectedType = SyntaxFactory.ParseTypeName("IReadOnlyDictionary?"); public static ParameterSyntax? ArgumentSyntax(PythonFunctionParameter parameter) { @@ -19,82 +19,41 @@ public class ArgumentReflection // Treat *args as list=None and **kwargs as dict=None // TODO: Handle the user specifying *args with a type annotation like tuple[int, str] - TypeSyntax reflectedType = parameter.ParameterType switch + var (reflectedType, defaultValue) = parameter switch { - PythonFunctionParameterType.Star => ArrayPyObject, - PythonFunctionParameterType.DoubleStar => TypeReflection.AsPredefinedType(DictStrAny, TypeReflection.ConversionDirection.ToPython), - PythonFunctionParameterType.Normal => TypeReflection.AsPredefinedType(parameter.Type, TypeReflection.ConversionDirection.ToPython), + { ParameterType: PythonFunctionParameterType.Star } => (StarReflectedType, PythonConstant.None.Value), + { ParameterType: PythonFunctionParameterType.DoubleStar } => (DoubleStarReflectedType, PythonConstant.None.Value), + { ParameterType: PythonFunctionParameterType.Normal, Type: var pt, DefaultValue: var dv } => + (TypeReflection.AsPredefinedType(pt, TypeReflection.ConversionDirection.ToPython), dv), _ => throw new NotImplementedException() }; - // Force a default value for *args and **kwargs as null, otherwise the calling convention is strange - if ((parameter.ParameterType == PythonFunctionParameterType.Star || - parameter.ParameterType == PythonFunctionParameterType.DoubleStar) && - parameter.DefaultValue is null) - - { - parameter.DefaultValue = PythonConstant.None.Value; - } - - bool isNullableType = false; - - LiteralExpressionSyntax? literalExpressionSyntax; - - switch (parameter.DefaultValue) + var literalExpressionSyntax = defaultValue switch { - case null: - literalExpressionSyntax = null; - break; - case PythonConstant.HexidecimalInteger { Value: var v }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression( - SyntaxKind.NumericLiteralExpression, - SyntaxFactory.Literal($"0x{v:X}", v)); - break; - case PythonConstant.BinaryInteger { Value: var v }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression( - SyntaxKind.NumericLiteralExpression, - SyntaxFactory.Literal($"0b{Convert.ToString(v, 2)}", v)); - break; - case PythonConstant.Integer { Value: var v and >= int.MinValue and <= int.MaxValue }: + null => null, + PythonConstant.HexidecimalInteger { Value: var v } => + SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, + SyntaxFactory.Literal($"0x{v:X}", v)), + PythonConstant.BinaryInteger { Value: var v } => + SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, + SyntaxFactory.Literal($"0b{Convert.ToString(v, 2)}", v)), + PythonConstant.Integer { Value: var v and >= int.MinValue and <= int.MaxValue } => // Downcast long to int if the value is small as the code is more readable without the L suffix - literalExpressionSyntax = SyntaxFactory.LiteralExpression( - SyntaxKind.NumericLiteralExpression, - SyntaxFactory.Literal((int)v)); - break; - case PythonConstant.Integer { Value: var v }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression( - SyntaxKind.NumericLiteralExpression, - SyntaxFactory.Literal(v)); - break; - case PythonConstant.String { Value: var v }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression( - SyntaxKind.StringLiteralExpression, - SyntaxFactory.Literal(v)); - break; - case PythonConstant.Float { Value: var v }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression( - SyntaxKind.NumericLiteralExpression, - SyntaxFactory.Literal(v)); - break; - case PythonConstant.Bool { Value: true }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression); - break; - case PythonConstant.Bool { Value: false }: - literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression); - break; - case PythonConstant.None: - literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression); - isNullableType = true; - break; - default: - literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression); - isNullableType = true; - break; - } + SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal((int)v)), + PythonConstant.Integer { Value: var v } => + SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(v)), + PythonConstant.String { Value: var v } => + SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(v)), + PythonConstant.Float { Value: var v } => + SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(v)), + PythonConstant.Bool { Value: true } => SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression), + PythonConstant.Bool { Value: false } => SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression), + _ => SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression) + }; return SyntaxFactory .Parameter(SyntaxFactory.Identifier(Keywords.ValidIdentifier(parameter.Name.ToLowerPascalCase()))) - .WithType(isNullableType ? SyntaxFactory.NullableType(reflectedType) : reflectedType) + .WithType(reflectedType) .WithDefault(literalExpressionSyntax is not null ? SyntaxFactory.EqualsValueClause(literalExpressionSyntax) : null); } diff --git a/src/CSnakes.SourceGeneration/Reflection/MethodReflection.cs b/src/CSnakes.SourceGeneration/Reflection/MethodReflection.cs index 82e3a41a..32e929b7 100644 --- a/src/CSnakes.SourceGeneration/Reflection/MethodReflection.cs +++ b/src/CSnakes.SourceGeneration/Reflection/MethodReflection.cs @@ -34,14 +34,11 @@ public static MethodDefinition FromMethod(PythonFunctionDefinition function, str // Step 3: Build arguments List<(PythonFunctionParameter pythonParameter, ParameterSyntax cSharpParameter)> parameterList = ArgumentReflection.FunctionParametersAsParameterSyntaxPairs(function.Parameters); - List parameterGenericArgs = []; - foreach (var (pythonParameter, cSharpParameter) in parameterList) - { - if (cSharpParameter.Type is GenericNameSyntax g) - { - parameterGenericArgs.Add(g); - } - } + var parameterGenericArgs = + parameterList.Select(e => e.cSharpParameter.Type) + .Append(returnSyntax) + .OfType() + .ToList(); // Step 4: Build body var pythonConversionStatements = new List(); @@ -104,11 +101,12 @@ public static MethodDefinition FromMethod(PythonFunctionDefinition function, str callExpression = GenerateKeywordCall(parameterList); } + var resultIdentifierName = IdentifierName("__result_pyObject"); ReturnStatementSyntax returnExpression = returnSyntax switch { PredefinedTypeSyntax s when s.Keyword.IsKind(SyntaxKind.VoidKeyword) => ReturnStatement(null), - IdentifierNameSyntax { Identifier.ValueText: "PyObject" } => ReturnStatement(IdentifierName("__result_pyObject")), - _ => ProcessMethodWithReturnType(returnSyntax, parameterGenericArgs) + IdentifierNameSyntax { Identifier.ValueText: "PyObject" } => ReturnStatement(resultIdentifierName), + _ => ProcessMethodWithReturnType(resultIdentifierName, returnSyntax) }; bool resultShouldBeDisposed = returnSyntax switch @@ -210,23 +208,36 @@ PredefinedTypeSyntax s when s.Keyword.IsKind(SyntaxKind.VoidKeyword) => true, return new(syntax, parameterGenericArgs); } - private static ReturnStatementSyntax ProcessMethodWithReturnType(TypeSyntax returnSyntax, List parameterGenericArgs) + private static ReturnStatementSyntax ProcessMethodWithReturnType(IdentifierNameSyntax identifierNameSyntax, + TypeSyntax typeSyntax) { - ReturnStatementSyntax returnExpression; - if (returnSyntax is GenericNameSyntax rg) - { - parameterGenericArgs.Add(rg); - } - - returnExpression = ReturnStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("__result_pyObject"), - GenericName(Identifier("As")) - .WithTypeArgumentList(TypeArgumentList(SeparatedList([returnSyntax]))))) - .WithArgumentList(ArgumentList())); - return returnExpression; + (typeSyntax, var nullable) = + typeSyntax is NullableTypeSyntax { ElementType: var elementTypeSyntax } + ? (elementTypeSyntax, true) + : (typeSyntax, false); + + var conversionExpression = + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + identifierNameSyntax, + GenericName(Identifier("As")) + .WithTypeArgumentList(TypeArgumentList(SeparatedList([typeSyntax]))))) + .WithArgumentList(ArgumentList()); + + ExpressionSyntax returnValueExpression = + nullable + ? ConditionalExpression( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + identifierNameSyntax, + IdentifierName("IsNone"))), + LiteralExpression(SyntaxKind.NullLiteralExpression), + conversionExpression) + : conversionExpression; + + return ReturnStatement(returnValueExpression); } private static InvocationExpressionSyntax GenerateParamsCall(IEnumerable<(PythonFunctionParameter pythonParameter, ParameterSyntax cSharpParameter)> parameterList) diff --git a/src/CSnakes.SourceGeneration/Reflection/TypeReflection.cs b/src/CSnakes.SourceGeneration/Reflection/TypeReflection.cs index ee9e2725..252267d7 100644 --- a/src/CSnakes.SourceGeneration/Reflection/TypeReflection.cs +++ b/src/CSnakes.SourceGeneration/Reflection/TypeReflection.cs @@ -30,7 +30,7 @@ public static TypeSyntax AsPredefinedType(PythonTypeSpec pythonType, ConversionD "typing.Dict" or "Dict" => CreateDictionaryType(pythonType.Arguments[0], pythonType.Arguments[1], direction), "typing.Mapping" or "Mapping" => CreateDictionaryType(pythonType.Arguments[0], pythonType.Arguments[1], direction), "typing.Sequence" or "Sequence" => CreateListType(pythonType.Arguments[0], direction), - "typing.Optional" or "Optional" => AsPredefinedType(pythonType.Arguments[0], direction), + "typing.Optional" or "Optional" => SyntaxFactory.NullableType(AsPredefinedType(pythonType.Arguments[0], direction)), "typing.Generator" or "Generator" => CreateGeneratorType(pythonType.Arguments[0], pythonType.Arguments[1], pythonType.Arguments[2], direction), // Todo more types... see https://docs.python.org/3/library/stdtypes.html#standard-generic-classes _ => SyntaxFactory.ParseTypeName("PyObject"), diff --git a/src/CSnakes.Tests/GeneratedSignatureTests.cs b/src/CSnakes.Tests/GeneratedSignatureTests.cs index a8a2f458..b6ab942a 100644 --- a/src/CSnakes.Tests/GeneratedSignatureTests.cs +++ b/src/CSnakes.Tests/GeneratedSignatureTests.cs @@ -27,7 +27,8 @@ public class GeneratedSignatureTests(TestEnvironment testEnv) : IClassFixture typing.Sequence[int]:\n ...\n", "IReadOnlyList HelloWorld(IReadOnlyList numbers)")] [InlineData("def hello_world(value: tuple[int]) -> None:\n ...\n", "void HelloWorld(ValueTuple value)")] [InlineData("def hello_world(a: bool, b: str, c: list[tuple[int, float]]) -> bool: \n ...\n", "bool HelloWorld(bool a, string b, IReadOnlyList<(long, double)> c)")] - [InlineData("def hello_world(a: bool = True, b: str = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string? b = null)")] + [InlineData("def hello_world(a: bool = True, b: Optional[str] = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string? b = null)")] + [InlineData("def hello_world(a: Optional[int], b: Optional[str]) -> Optional[bool]: \n ...\n", "bool? HelloWorld(long? a, string? b)")] [InlineData("def hello_world(a: bytes, b: bool = False, c: float = 0.1) -> None: \n ...\n", "void HelloWorld(byte[] a, bool b = false, double c = 0.1)")] [InlineData("def hello_world(a: str = 'default') -> None: \n ...\n", "void HelloWorld(string a = \"default\")")] [InlineData("def hello_world(a: str, *args) -> None: \n ...\n", "void HelloWorld(string a, PyObject[]? args = null)")] diff --git a/src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_optional.approved.txt b/src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_optional.approved.txt new file mode 100644 index 00000000..d5a3d00e --- /dev/null +++ b/src/CSnakes.Tests/PythonStaticGeneratorTests/FormatClassFromMethods.test_optional.approved.txt @@ -0,0 +1,130 @@ +// +#nullable enable + +using CSnakes.Runtime; +using CSnakes.Runtime.Python; + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Reflection.Metadata; + +using Microsoft.Extensions.Logging; + +[assembly: MetadataUpdateHandler(typeof(Python.Generated.Tests.TestClassExtensions))] + +namespace Python.Generated.Tests; + +public static class TestClassExtensions +{ + private static ITestClass? instance; + + private static ReadOnlySpan HotReloadHash => "5d9aaa2919c9c971ca72ce9f9c9c046a"u8; + + public static ITestClass TestClass(this IPythonEnvironment env) + { + if (instance is null) + { + instance = new TestClassInternal(env.Logger); + } + Debug.Assert(!env.IsDisposed()); + return instance; + } + + public static void UpdateApplication(Type[]? updatedTypes) + { + instance?.ReloadModule(); + } + + private class TestClassInternal : ITestClass + { + private PyObject module; + private readonly ILogger logger; + + private PyObject __func_test_int; + private PyObject __func_test_str; + private PyObject __func_test_any; + + internal TestClassInternal(ILogger logger) + { + this.logger = logger; + using (GIL.Acquire()) + { + logger.LogDebug("Importing module {ModuleName}", "test"); + module = Import.ImportModule("test"); + this.__func_test_int = module.GetAttr("test_int"); + this.__func_test_str = module.GetAttr("test_str"); + this.__func_test_any = module.GetAttr("test_any"); + } + } + + void IReloadableModuleImport.ReloadModule() + { + logger.LogDebug("Reloading module {ModuleName}", "test"); + using (GIL.Acquire()) + { + Import.ReloadModule(ref module); + // Dispose old functions + this.__func_test_int.Dispose(); + this.__func_test_str.Dispose(); + this.__func_test_any.Dispose(); + // Bind to new functions + this.__func_test_int = module.GetAttr("test_int"); + this.__func_test_str = module.GetAttr("test_str"); + this.__func_test_any = module.GetAttr("test_any"); + } + } + + public void Dispose() + { + logger.LogDebug("Disposing module {ModuleName}", "test"); + this.__func_test_int.Dispose(); + this.__func_test_str.Dispose(); + this.__func_test_any.Dispose(); + module.Dispose(); + } + + public long? TestInt(long? n) + { + using (GIL.Acquire()) + { + logger.LogDebug("Invoking Python function: {FunctionName}", "test_int"); + PyObject __underlyingPythonFunc = this.__func_test_int; + using PyObject n_pyObject = PyObject.From(n)!; + using PyObject __result_pyObject = __underlyingPythonFunc.Call(n_pyObject); + return __result_pyObject.IsNone() ? null : __result_pyObject.As(); + } + } + + public string? TestStr(string? s) + { + using (GIL.Acquire()) + { + logger.LogDebug("Invoking Python function: {FunctionName}", "test_str"); + PyObject __underlyingPythonFunc = this.__func_test_str; + using PyObject s_pyObject = PyObject.From(s)!; + using PyObject __result_pyObject = __underlyingPythonFunc.Call(s_pyObject); + return __result_pyObject.IsNone() ? null : __result_pyObject.As(); + } + } + + public PyObject? TestAny(PyObject? obj) + { + using (GIL.Acquire()) + { + logger.LogDebug("Invoking Python function: {FunctionName}", "test_any"); + PyObject __underlyingPythonFunc = this.__func_test_any; + using PyObject obj_pyObject = PyObject.From(obj)!; + using PyObject __result_pyObject = __underlyingPythonFunc.Call(obj_pyObject); + return __result_pyObject.IsNone() ? null : __result_pyObject.As(); + } + } + } +} + +public interface ITestClass : IReloadableModuleImport +{ + long? TestInt(long? n); + string? TestStr(string? s); + PyObject? TestAny(PyObject? obj); +} diff --git a/src/CSnakes.Tests/TokenizerTests.cs b/src/CSnakes.Tests/TokenizerTests.cs index 1aeca75a..6ba50d94 100644 --- a/src/CSnakes.Tests/TokenizerTests.cs +++ b/src/CSnakes.Tests/TokenizerTests.cs @@ -209,6 +209,20 @@ public void ParseFunctionParameterDefaultNone() Assert.True(result.Value.HasTypeAnnotation()); } + [Theory] + [InlineData("*args", "args", PythonFunctionParameterType.Star)] + [InlineData("**kwargs", "kwargs", PythonFunctionParameterType.DoubleStar)] + public void ParseFunctionSpecialParameter(string source, string expectedName, PythonFunctionParameterType expectedParameterType) + { + var tokens = PythonTokenizer.Instance.Tokenize(source); + var result = PythonParser.PythonParameterTokenizer.TryParse(tokens); + Assert.True(result.HasValue); + Assert.Equal(expectedName, result.Value.Name); + Assert.Equal(expectedParameterType, result.Value.ParameterType); + Assert.Equal("None", result.Value.DefaultValue?.ToString()); + Assert.False(result.Value.HasTypeAnnotation()); + } + [Fact] public void ParseFunctionParameterListSingleGeneric() { diff --git a/src/CSnakes.Tests/TypeReflectionTests.cs b/src/CSnakes.Tests/TypeReflectionTests.cs index d256e62d..e33eadec 100644 --- a/src/CSnakes.Tests/TypeReflectionTests.cs +++ b/src/CSnakes.Tests/TypeReflectionTests.cs @@ -41,8 +41,8 @@ public void AsPredefinedType(string pythonType, string expectedType) => [InlineData("Tuple[str, list[int]]", "(string,IReadOnlyList)")] [InlineData("Dict[str, int]", "IReadOnlyDictionary")] [InlineData("Tuple[int, int, Tuple[int, int]]", "(long,long,(long,long))")] - [InlineData("Optional[str]", "string")] - [InlineData("Optional[int]", "long")] + [InlineData("Optional[str]", "string?")] + [InlineData("Optional[int]", "long?")] [InlineData("Callable[[str], int]", "PyObject")] public void AsPredefinedTypeOldTypeNames(string pythonType, string expectedType) => ParsingTestInternal(pythonType, expectedType); diff --git a/src/Integration.Tests/TestOptional.cs b/src/Integration.Tests/TestOptional.cs new file mode 100644 index 00000000..70045839 --- /dev/null +++ b/src/Integration.Tests/TestOptional.cs @@ -0,0 +1,33 @@ +using CSnakes.Runtime.Python; + +namespace Integration.Tests; +public class TestOptional(PythonEnvironmentFixture fixture) : IntegrationTestBase(fixture) +{ + [Theory] + [InlineData(null)] + [InlineData(42)] + public void Int(int? input) + { + var mod = Env.TestOptional(); + Assert.Equal(input, mod.TestInt(input)); + } + + [Theory] + [InlineData(null)] + [InlineData("foobar")] + public void Str(string? input) + { + var mod = Env.TestOptional(); + Assert.Equal(input, mod.TestStr(input)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void Object(bool @null) + { + var input = @null ? null : PyObject.Zero; + var mod = Env.TestOptional(); + Assert.Equal(input, mod.TestAny(input)); + } +} diff --git a/src/Integration.Tests/python/test_optional.py b/src/Integration.Tests/python/test_optional.py new file mode 100644 index 00000000..533b9e87 --- /dev/null +++ b/src/Integration.Tests/python/test_optional.py @@ -0,0 +1,10 @@ +from typing import Any, Optional + +def test_int(n: Optional[int]) -> Optional[int]: + return n + +def test_str(s: Optional[str]) -> Optional[str]: + return s + +def test_any(obj: Optional[Any]) -> Optional[Any]: + return obj