Skip to content

Commit

Permalink
Fix optionality of function parameters & return
Browse files Browse the repository at this point in the history
  • Loading branch information
atifaziz committed Oct 26, 2024
1 parent 048e011 commit 2ef29c4
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ from type in Token.EqualTo(PythonToken.Colon).Optional().Then(
from defaultValue in Token.EqualTo(PythonToken.Equal).Optional().Then(
_ => ConstantValueTokenizer.AssumeNotNull().OptionalOrDefault()
)
select new PythonFunctionParameter(arg.Name, type, defaultValue, arg.ParameterType))
select new PythonFunctionParameter(arg.Name, type,
// Force a default value for *args and **kwargs as null, otherwise the calling convention is strange
arg.ParameterType is PythonFunctionParameterType.Star or PythonFunctionParameterType.DoubleStar
&& defaultValue is null
? PythonConstant.None.Value
: defaultValue,
arg.ParameterType))
.Named("Parameter");

public static TokenListParser<PythonToken, PythonFunctionParameter?> ParameterOrSlash { get; } =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class PythonFunctionParameter(string name, PythonTypeSpec? type, PythonCo

public bool IsKeywordOnly { get; set; }

public PythonConstant? DefaultValue { get; set; } = defaultValue;
public PythonConstant? DefaultValue { get; } = defaultValue;

public PythonFunctionParameterType ParameterType { get; } = parameterType;

Expand Down
97 changes: 28 additions & 69 deletions src/CSnakes.SourceGeneration/Reflection/ArgumentReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ namespace CSnakes.Reflection;

public class ArgumentReflection
{
private static readonly PythonTypeSpec DictStrAny = new("dict", [new("str", []), PythonTypeSpec.Any]);
private static readonly TypeSyntax ArrayPyObject = SyntaxFactory.ParseTypeName("PyObject[]");
private static readonly TypeSyntax StarReflectedType = SyntaxFactory.ParseTypeName("PyObject[]?");
private static readonly TypeSyntax DoubleStarReflectedType = SyntaxFactory.ParseTypeName("IReadOnlyDictionary<string, PyObject>?");

public static ParameterSyntax? ArgumentSyntax(PythonFunctionParameter parameter)
{
Expand All @@ -19,82 +19,41 @@ public class ArgumentReflection

// Treat *args as list<Any>=None and **kwargs as dict<str, Any>=None
// TODO: Handle the user specifying *args with a type annotation like tuple[int, str]
TypeSyntax reflectedType = parameter.ParameterType switch
var (reflectedType, defaultValue) = parameter switch
{
PythonFunctionParameterType.Star => ArrayPyObject,
PythonFunctionParameterType.DoubleStar => TypeReflection.AsPredefinedType(DictStrAny, TypeReflection.ConversionDirection.ToPython),
PythonFunctionParameterType.Normal => TypeReflection.AsPredefinedType(parameter.Type, TypeReflection.ConversionDirection.ToPython),
{ ParameterType: PythonFunctionParameterType.Star } => (StarReflectedType, PythonConstant.None.Value),
{ ParameterType: PythonFunctionParameterType.DoubleStar } => (DoubleStarReflectedType, PythonConstant.None.Value),
{ ParameterType: PythonFunctionParameterType.Normal, Type: var pt, DefaultValue: var dv } =>
(TypeReflection.AsPredefinedType(pt, TypeReflection.ConversionDirection.ToPython), dv),
_ => throw new NotImplementedException()
};

// Force a default value for *args and **kwargs as null, otherwise the calling convention is strange
if ((parameter.ParameterType == PythonFunctionParameterType.Star ||
parameter.ParameterType == PythonFunctionParameterType.DoubleStar) &&
parameter.DefaultValue is null)

{
parameter.DefaultValue = PythonConstant.None.Value;
}

bool isNullableType = false;

LiteralExpressionSyntax? literalExpressionSyntax;

switch (parameter.DefaultValue)
var literalExpressionSyntax = defaultValue switch
{
case null:
literalExpressionSyntax = null;
break;
case PythonConstant.HexidecimalInteger { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0x{v:X}", v));
break;
case PythonConstant.BinaryInteger { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0b{Convert.ToString(v, 2)}", v));
break;
case PythonConstant.Integer { Value: var v and >= int.MinValue and <= int.MaxValue }:
null => null,
PythonConstant.HexidecimalInteger { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0x{v:X}", v)),
PythonConstant.BinaryInteger { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal($"0b{Convert.ToString(v, 2)}", v)),
PythonConstant.Integer { Value: var v and >= int.MinValue and <= int.MaxValue } =>
// Downcast long to int if the value is small as the code is more readable without the L suffix
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal((int)v));
break;
case PythonConstant.Integer { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal(v));
break;
case PythonConstant.String { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.StringLiteralExpression,
SyntaxFactory.Literal(v));
break;
case PythonConstant.Float { Value: var v }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(
SyntaxKind.NumericLiteralExpression,
SyntaxFactory.Literal(v));
break;
case PythonConstant.Bool { Value: true }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression);
break;
case PythonConstant.Bool { Value: false }:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression);
break;
case PythonConstant.None:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
isNullableType = true;
break;
default:
literalExpressionSyntax = SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression);
isNullableType = true;
break;
}
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal((int)v)),
PythonConstant.Integer { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(v)),
PythonConstant.String { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.StringLiteralExpression, SyntaxFactory.Literal(v)),
PythonConstant.Float { Value: var v } =>
SyntaxFactory.LiteralExpression(SyntaxKind.NumericLiteralExpression, SyntaxFactory.Literal(v)),
PythonConstant.Bool { Value: true } => SyntaxFactory.LiteralExpression(SyntaxKind.TrueLiteralExpression),
PythonConstant.Bool { Value: false } => SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression),
_ => SyntaxFactory.LiteralExpression(SyntaxKind.NullLiteralExpression)
};

return SyntaxFactory
.Parameter(SyntaxFactory.Identifier(Keywords.ValidIdentifier(parameter.Name.ToLowerPascalCase())))
.WithType(isNullableType ? SyntaxFactory.NullableType(reflectedType) : reflectedType)
.WithType(reflectedType)
.WithDefault(literalExpressionSyntax is not null ? SyntaxFactory.EqualsValueClause(literalExpressionSyntax) : null);
}

Expand Down
63 changes: 37 additions & 26 deletions src/CSnakes.SourceGeneration/Reflection/MethodReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,11 @@ public static MethodDefinition FromMethod(PythonFunctionDefinition function, str
// Step 3: Build arguments
List<(PythonFunctionParameter pythonParameter, ParameterSyntax cSharpParameter)> parameterList = ArgumentReflection.FunctionParametersAsParameterSyntaxPairs(function.Parameters);

List<GenericNameSyntax> parameterGenericArgs = [];
foreach (var (pythonParameter, cSharpParameter) in parameterList)
{
if (cSharpParameter.Type is GenericNameSyntax g)
{
parameterGenericArgs.Add(g);
}
}
var parameterGenericArgs =
parameterList.Select(e => e.cSharpParameter.Type)
.Append(returnSyntax)
.OfType<GenericNameSyntax>()
.ToList();

// Step 4: Build body
var pythonConversionStatements = new List<StatementSyntax>();
Expand Down Expand Up @@ -104,11 +101,12 @@ public static MethodDefinition FromMethod(PythonFunctionDefinition function, str
callExpression = GenerateKeywordCall(parameterList);
}

var resultIdentifierName = IdentifierName("__result_pyObject");
ReturnStatementSyntax returnExpression = returnSyntax switch
{
PredefinedTypeSyntax s when s.Keyword.IsKind(SyntaxKind.VoidKeyword) => ReturnStatement(null),
IdentifierNameSyntax { Identifier.ValueText: "PyObject" } => ReturnStatement(IdentifierName("__result_pyObject")),
_ => ProcessMethodWithReturnType(returnSyntax, parameterGenericArgs)
IdentifierNameSyntax { Identifier.ValueText: "PyObject" } => ReturnStatement(resultIdentifierName),
_ => ProcessMethodWithReturnType(resultIdentifierName, returnSyntax)
};

bool resultShouldBeDisposed = returnSyntax switch
Expand Down Expand Up @@ -210,23 +208,36 @@ PredefinedTypeSyntax s when s.Keyword.IsKind(SyntaxKind.VoidKeyword) => true,
return new(syntax, parameterGenericArgs);
}

private static ReturnStatementSyntax ProcessMethodWithReturnType(TypeSyntax returnSyntax, List<GenericNameSyntax> parameterGenericArgs)
private static ReturnStatementSyntax ProcessMethodWithReturnType(IdentifierNameSyntax identifierNameSyntax,
TypeSyntax typeSyntax)
{
ReturnStatementSyntax returnExpression;
if (returnSyntax is GenericNameSyntax rg)
{
parameterGenericArgs.Add(rg);
}

returnExpression = ReturnStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("__result_pyObject"),
GenericName(Identifier("As"))
.WithTypeArgumentList(TypeArgumentList(SeparatedList([returnSyntax])))))
.WithArgumentList(ArgumentList()));
return returnExpression;
(typeSyntax, var nullable) =
typeSyntax is NullableTypeSyntax { ElementType: var elementTypeSyntax }
? (elementTypeSyntax, true)
: (typeSyntax, false);

var conversionExpression =
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
identifierNameSyntax,
GenericName(Identifier("As"))
.WithTypeArgumentList(TypeArgumentList(SeparatedList([typeSyntax])))))
.WithArgumentList(ArgumentList());

ExpressionSyntax returnValueExpression =
nullable
? ConditionalExpression(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
identifierNameSyntax,
IdentifierName("IsNone"))),
LiteralExpression(SyntaxKind.NullLiteralExpression),
conversionExpression)
: conversionExpression;

return ReturnStatement(returnValueExpression);
}

private static InvocationExpressionSyntax GenerateParamsCall(IEnumerable<(PythonFunctionParameter pythonParameter, ParameterSyntax cSharpParameter)> parameterList)
Expand Down
2 changes: 1 addition & 1 deletion src/CSnakes.SourceGeneration/Reflection/TypeReflection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static TypeSyntax AsPredefinedType(PythonTypeSpec pythonType, ConversionD
"typing.Dict" or "Dict" => CreateDictionaryType(pythonType.Arguments[0], pythonType.Arguments[1], direction),
"typing.Mapping" or "Mapping" => CreateDictionaryType(pythonType.Arguments[0], pythonType.Arguments[1], direction),
"typing.Sequence" or "Sequence" => CreateListType(pythonType.Arguments[0], direction),
"typing.Optional" or "Optional" => AsPredefinedType(pythonType.Arguments[0], direction),
"typing.Optional" or "Optional" => SyntaxFactory.NullableType(AsPredefinedType(pythonType.Arguments[0], direction)),
"typing.Generator" or "Generator" => CreateGeneratorType(pythonType.Arguments[0], pythonType.Arguments[1], pythonType.Arguments[2], direction),
// Todo more types... see https://docs.python.org/3/library/stdtypes.html#standard-generic-classes
_ => SyntaxFactory.ParseTypeName("PyObject"),
Expand Down
3 changes: 2 additions & 1 deletion src/CSnakes.Tests/GeneratedSignatureTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public class GeneratedSignatureTests(TestEnvironment testEnv) : IClassFixture<Te
[InlineData("def hello_world(numbers: Sequence[float]) -> typing.Sequence[int]:\n ...\n", "IReadOnlyList<long> HelloWorld(IReadOnlyList<double> numbers)")]
[InlineData("def hello_world(value: tuple[int]) -> None:\n ...\n", "void HelloWorld(ValueTuple<long> value)")]
[InlineData("def hello_world(a: bool, b: str, c: list[tuple[int, float]]) -> bool: \n ...\n", "bool HelloWorld(bool a, string b, IReadOnlyList<(long, double)> c)")]
[InlineData("def hello_world(a: bool = True, b: str = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string? b = null)")]
[InlineData("def hello_world(a: bool = True, b: Optional[str] = None) -> bool: \n ...\n", "bool HelloWorld(bool a = true, string? b = null)")]
[InlineData("def hello_world(a: Optional[int], b: Optional[str]) -> Optional[bool]: \n ...\n", "bool? HelloWorld(long? a, string? b)")]
[InlineData("def hello_world(a: bytes, b: bool = False, c: float = 0.1) -> None: \n ...\n", "void HelloWorld(byte[] a, bool b = false, double c = 0.1)")]
[InlineData("def hello_world(a: str = 'default') -> None: \n ...\n", "void HelloWorld(string a = \"default\")")]
[InlineData("def hello_world(a: str, *args) -> None: \n ...\n", "void HelloWorld(string a, PyObject[]? args = null)")]
Expand Down
Loading

0 comments on commit 2ef29c4

Please sign in to comment.