Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix optionality of function parameters & return #299

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ from type in Token.EqualTo(PythonToken.Colon).Optional().Then(
from defaultValue in Token.EqualTo(PythonToken.Equal).Optional().Then(
_ => ConstantValueTokenizer.AssumeNotNull().OptionalOrDefault()
)
select new PythonFunctionParameter(arg.Name, type, defaultValue, arg.ParameterType))
select new PythonFunctionParameter(arg.Name, type,
// Force a default value for *args and **kwargs as null, otherwise the calling convention is strange
arg.ParameterType is PythonFunctionParameterType.Star or PythonFunctionParameterType.DoubleStar
&& defaultValue is null
? PythonConstant.None.Value
: defaultValue,
arg.ParameterType))
.Named("Parameter");

public static TokenListParser<PythonToken, PythonFunctionParameter?> ParameterOrSlash { get; } =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class PythonFunctionParameter(string name, PythonTypeSpec? type, PythonCo

public bool IsKeywordOnly { get; set; }

public PythonConstant? DefaultValue { get; set; } = defaultValue;
public PythonConstant? DefaultValue { get; } = defaultValue;

public PythonFunctionParameterType ParameterType { get; } = parameterType;

Expand Down
97 changes: 28 additions & 69 deletions src/CSnakes.SourceGeneration/Reflection/ArgumentReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ namespace CSnakes.Reflection;

public class ArgumentReflection
{
private static readonly PythonTypeSpec DictStrAny = new("dict", [new("str", []), PythonTypeSpec.Any]);
private static readonly TypeSyntax ArrayPyObject = SyntaxFactory.ParseTypeName("PyObject[]");
private static readonly TypeSyntax StarReflectedType = SyntaxFactory.ParseTypeName("PyObject[]?");
private static readonly TypeSyntax DoubleStarReflectedType = SyntaxFactory.ParseTypeName("IReadOnlyDictionary<string, PyObject>?");

public static ParameterSyntax? ArgumentSyntax(PythonFunctionParameter parameter)
{
Expand All @@ -19,82 +19,41 @@ public class ArgumentReflection

// Treat *args as list<Any>=None and **kwargs as dict<str, Any>=None
// TODO: Handle the user specifying *args with a type annotation like tuple[int, str]
TypeSyntax reflectedType = parameter.ParameterType switch
var (reflectedType, defaultValue) = parameter switch
{
PythonFunctionParameterType.Star => ArrayPyObject,
PythonFunctionParameterType.DoubleStar => TypeReflection.AsPredefinedType(DictStrAny, TypeReflection.ConversionDirection.ToPython),
PythonFunctionParameterType.Normal => TypeReflection.AsPredefinedType(parameter.Type, TypeReflection.ConversionDirection.ToPython),
{ ParameterType: PythonFunctionParameterType.Star } => (StarReflectedType, PythonConstant.None.Value),
{ ParameterType: PythonFunctionParameterType.DoubleStar } => (DoubleStarReflectedType, PythonConstant.None.Value),
{ ParameterType: PythonFunctionParameterType.Normal, Type: var pt, DefaultValue: var dv } =>
(TypeReflection.AsPredefinedType(pt, TypeReflection.ConversionDirection.ToPython), dv),
_ => throw new NotImplementedException()
};

// Force a default value for *args and **kwargs as null, otherwise the calling convention is strange
if ((parameter.ParameterType == PythonFunctionParameterType.Star ||
parameter.ParameterType == PythonFunctionParameterType.DoubleStar) &&
parameter.DefaultValue is null)

{
parameter.DefaultValue = PythonConstant.None.Value;
}

bool isNullableType = false;

LiteralExpressionSyntax? literalExpressionSyntax;

switch (parameter.DefaultValue)
var literalExpressionSyntax = defaultValue switch
{
case null:
literalExpressionSyntax = null;
break;
case PythonConstant.HexidecimalInteger { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0x{v:X}", v));
break;
case PythonConstant.BinaryInteger { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0b{Convert.ToString(v, 2)}", v));
break;
case PythonConstant.Integer { Value: var v and >= int.MinValue and <= int.MaxValue }:
null => null,
PythonConstant.HexidecimalInteger { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0x{v:X}", v)),
PythonConstant.BinaryInteger { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0b{Convert.ToString(v, 2)}", v)),
PythonConstant.Integer { Value: var v and >= int.MinValue and <= int.MaxValue } =>
// Downcast long to int if the value is small as the code is more readable without the L suffix
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal((int)v));
break;
case PythonConstant.Integer { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal(v));
break;
case PythonConstant.String { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.StringLiteralExpression,
SyntaxFactory.Literal(v));
break;
case PythonConstant.Float { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal(v));
break;
case PythonConstant.Bool { Value: true }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression);
break;
case PythonConstant.Bool { Value: false }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression);
break;
case PythonConstant.None:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
isNullableType = true;
break;
default:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
isNullableType = true;
break;
}
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal((int)v)),
PythonConstant.Integer { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(v)),
PythonConstant.String { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(v)),
PythonConstant.Float { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(v)),
PythonConstant.Bool { Value: true } => SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression),
PythonConstant.Bool { Value: false } => SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression),
_ => SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)
};

return SyntaxFactory
.Parameter(SyntaxFactory.Identifier(Keywords.ValidIdentifier(parameter.Name.ToLowerPascalCase())))
.WithType(isNullableType ? SyntaxFactory.NullableType(reflectedType) : reflectedType)
.WithType(reflectedType)
.WithDefault(literalExpressionSyntax is not null ? SyntaxFactory.EqualsValueClause(literalExpressionSyntax) : null);
}

Expand Down
63 changes: 37 additions & 26 deletions src/CSnakes.SourceGeneration/Reflection/MethodReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,11 @@ public static MethodDefinition FromMethod(PythonFunctionDefinition function, str
// Step 3: Build arguments
List<(PythonFunctionParameter pythonParameter, ParameterSyntax cSharpParameter)> parameterList = ArgumentReflection.FunctionParametersAsParameterSyntaxPairs(function.Parameters);

List<GenericNameSyntax> parameterGenericArgs = [];
foreach (var (pythonParameter, cSharpParameter) in parameterList)
{
if (cSharpParameter.Type is GenericNameSyntax g)
{
parameterGenericArgs.Add(g);
}
}
var parameterGenericArgs =
parameterList.Select(e => e.cSharpParameter.Type)
.Append(returnSyntax)
.OfType<GenericNameSyntax>()
.ToList();

// Step 4: Build body
var pythonConversionStatements = new List<StatementSyntax>();
Expand Down Expand Up @@ -104,11 +101,12 @@ public static MethodDefinition FromMethod(PythonFunctionDefinition function, str
callExpression = GenerateKeywordCall(parameterList);
}

var resultIdentifierName = IdentifierName("__result_pyObject");
ReturnStatementSyntax returnExpression = returnSyntax switch
{
PredefinedTypeSyntax s when s.Keyword.IsKind(SyntaxKind.VoidKeyword) => ReturnStatement(null),
IdentifierNameSyntax { Identifier.ValueText: "PyObject" } => ReturnStatement(IdentifierName("__result_pyObject")),
_ => ProcessMethodWithReturnType(returnSyntax, parameterGenericArgs)
IdentifierNameSyntax { Identifier.ValueText: "PyObject" } => ReturnStatement(resultIdentifierName),
_ => ProcessMethodWithReturnType(resultIdentifierName, returnSyntax)
};

bool resultShouldBeDisposed = returnSyntax switch
Expand Down Expand Up @@ -210,23 +208,36 @@ PredefinedTypeSyntax s when s.Keyword.IsKind(SyntaxKind.VoidKeyword) => true,
return new(syntax, parameterGenericArgs);
}

private static ReturnStatementSyntax ProcessMethodWithReturnType(TypeSyntax returnSyntax, List<GenericNameSyntax> parameterGenericArgs)
private static ReturnStatementSyntax ProcessMethodWithReturnType(IdentifierNameSyntax identifierNameSyntax,
TypeSyntax typeSyntax)
{
ReturnStatementSyntax returnExpression;
if (returnSyntax is GenericNameSyntax rg)
{
parameterGenericArgs.Add(rg);
}

returnExpression = ReturnStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("__result_pyObject"),
GenericName(Identifier("As"))
.WithTypeArgumentList(TypeArgumentList(SeparatedList([returnSyntax])))))
.WithArgumentList(ArgumentList()));
return returnExpression;
(typeSyntax, var nullable) =
typeSyntax is NullableTypeSyntax { ElementType: var elementTypeSyntax }
? (elementTypeSyntax, true)
: (typeSyntax, false);

var conversionExpression =
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
identifierNameSyntax,
GenericName(Identifier("As"))
.WithTypeArgumentList(TypeArgumentList(SeparatedList([typeSyntax])))))
.WithArgumentList(ArgumentList());

ExpressionSyntax returnValueExpression =
nullable
? ConditionalExpression(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
identifierNameSyntax,
IdentifierName("IsNone"))),
LiteralExpression(SyntaxKind.NullLiteralExpression),
conversionExpression)
: conversionExpression;

return ReturnStatement(returnValueExpression);
}

private static InvocationExpressionSyntax GenerateParamsCall(IEnumerable<(PythonFunctionParameter pythonParameter, ParameterSyntax cSharpParameter)> parameterList)
Expand Down
2 changes: 1 addition & 1 deletion src/CSnakes.SourceGeneration/Reflection/TypeReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static TypeSyntax AsPredefinedType(PythonTypeSpec pythonType, ConversionD
"typing.Dict" or "Dict" => CreateDictionaryType(pythonType.Arguments[0], pythonType.Arguments[1], direction),
"typing.Mapping" or "Mapping" => CreateDictionaryType(pythonType.Arguments[0], pythonType.Arguments[1], direction),
"typing.Sequence" or "Sequence" => CreateListType(pythonType.Arguments[0], direction),
"typing.Optional" or "Optional" => AsPredefinedType(pythonType.Arguments[0], direction),
"typing.Optional" or "Optional" => SyntaxFactory.NullableType(AsPredefinedType(pythonType.Arguments[0], direction)),
"typing.Generator" or "Generator" => CreateGeneratorType(pythonType.Arguments[0], pythonType.Arguments[1], pythonType.Arguments[2], direction),
// Todo more types... see https://docs.python.org/3/library/stdtypes.html#standard-generic-classes
_ => SyntaxFactory.ParseTypeName("PyObject"),
Expand Down
3 changes: 2 additions & 1 deletion src/CSnakes.Tests/GeneratedSignatureTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public class GeneratedSignatureTests(TestEnvironment testEnv) : IClassFixture<Te
[InlineData("def hello_world(numbers: Sequence[float]) -> typing.Sequence[int]:\n ...\n", "IReadOnlyList<long> HelloWorld(IReadOnlyList<double> numbers)")]
[InlineData("def hello_world(value: tuple[int]) -> None:\n ...\n", "void HelloWorld(ValueTuple<long> value)")]
[InlineData("def hello_world(a: bool, b: str, c: list[tuple[int, float]]) -> bool: \n ...\n", "bool HelloWorld(bool a, string b, IReadOnlyList<(long, double)> c)")]
[InlineData("def hello_world(a: bool = True, b: str = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string? b = null)")]
[InlineData("def hello_world(a: bool = True, b: Optional[str] = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string? b = null)")]
[InlineData("def hello_world(a: Optional[int], b: Optional[str]) -> Optional[bool]: \n ...\n", "bool? HelloWorld(long? a, string? b)")]
[InlineData("def hello_world(a: bytes, b: bool = False, c: float = 0.1) -> None: \n ...\n", "void HelloWorld(byte[] a, bool b = false, double c = 0.1)")]
[InlineData("def hello_world(a: str = 'default') -> None: \n ...\n", "void HelloWorld(string a = \"default\")")]
[InlineData("def hello_world(a: str, *args) -> None: \n ...\n", "void HelloWorld(string a, PyObject[]? args = null)")]
Expand Down
Loading
Loading