Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 18 additions & 192 deletions dotnet/src/VectorData/AzureAISearch/AzureAISearchFilterTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
Expand All @@ -13,26 +10,17 @@

namespace Microsoft.SemanticKernel.Connectors.AzureAISearch;

internal class AzureAISearchFilterTranslator
{
private CollectionModel _model = null!;
private ParameterExpression _recordParameter = null!;
#pragma warning disable MEVD9001 // Experimental: filter translation base types

internal class AzureAISearchFilterTranslator : FilterTranslatorBase
{
private readonly StringBuilder _filter = new();

private static readonly char[] s_searchInDefaultDelimiter = [' ', ','];

internal string Translate(LambdaExpression lambdaExpression, CollectionModel model)
{
Debug.Assert(this._filter.Length == 0);

this._model = model;

Debug.Assert(lambdaExpression.Parameters.Count == 1);
this._recordParameter = lambdaExpression.Parameters[0];

var preprocessor = new FilterTranslationPreprocessor { SupportsParameterization = false };
var preprocessedExpression = preprocessor.Preprocess(lambdaExpression.Body);
var preprocessedExpression = this.PreprocessFilter(lambdaExpression, model, new FilterPreprocessingOptions());

this.Translate(preprocessedExpression);

Expand Down Expand Up @@ -161,52 +149,25 @@ private void TranslateMember(MemberExpression memberExpression)

private void TranslateMethodCall(MethodCallExpression methodCall)
{
switch (methodCall)
// Dictionary access for dynamic mapping (r => r["SomeString"] == "foo")
if (this.TryBindProperty(methodCall, out var property))
{
// Dictionary access for dynamic mapping (r => r["SomeString"] == "foo")
case MethodCallExpression when this.TryBindProperty(methodCall, out var property):
// OData identifiers cannot be escaped; storage names are validated during model building.
this._filter.Append(property.StorageName);
return;

// Enumerable.Contains()
case { Method.Name: nameof(Enumerable.Contains), Arguments: [var source, var item] } contains
when contains.Method.DeclaringType == typeof(Enumerable):
this.TranslateContains(source, item);
return;

// List.Contains()
case
{
Method:
{
Name: nameof(Enumerable.Contains),
DeclaringType: { IsGenericType: true } declaringType
},
Object: Expression source,
Arguments: [var item]
} when declaringType.GetGenericTypeDefinition() == typeof(List<>):
this.TranslateContains(source, item);
return;
// OData identifiers cannot be escaped; storage names are validated during model building.
this._filter.Append(property.StorageName);
return;
}

// C# 14 made changes to overload resolution to prefer Span-based overloads when those exist ("first-class spans");
// this makes MemoryExtensions.Contains() be resolved rather than Enumerable.Contains() (see above).
// MemoryExtensions.Contains() also accepts a Span argument for the source, adding an implicit cast we need to remove.
// See https://github.com/dotnet/runtime/issues/109757 for more context.
// Note that MemoryExtensions.Contains has an optional 3rd ComparisonType parameter; we only match when
// it's null.
case { Method.Name: nameof(MemoryExtensions.Contains), Arguments: [var spanArg, var item, ..] } contains
when contains.Method.DeclaringType == typeof(MemoryExtensions)
&& (contains.Arguments.Count is 2
|| (contains.Arguments.Count is 3 && contains.Arguments[2] is ConstantExpression { Value: null }))
&& TryUnwrapSpanImplicitCast(spanArg, out var source):
switch (methodCall)
{
// Enumerable.Contains(), List.Contains(), MemoryExtensions.Contains()
case var _ when TryMatchContains(methodCall, out var source, out var item):
this.TranslateContains(source, item);
return;

// Enumerable.Any() with a Contains predicate (r => r.Strings.Any(s => array.Contains(s)))
case { Method.Name: nameof(Enumerable.Any), Arguments: [var source, LambdaExpression lambda] } any
case { Method.Name: nameof(Enumerable.Any), Arguments: [var anySource, LambdaExpression lambda] } any
when any.Method.DeclaringType == typeof(Enumerable):
this.TranslateAny(source, lambda);
this.TranslateAny(anySource, lambda);
return;

default:
Expand Down Expand Up @@ -254,35 +215,12 @@ private void TranslateAny(Expression source, LambdaExpression lambda)
// We only support the pattern: r.ArrayField.Any(x => values.Contains(x))
// Translates to: Field/any(t: search.in(t, 'value1, value2, value3'))
if (!this.TryBindProperty(source, out var property)
|| lambda.Body is not MethodCallExpression { Method.Name: "Contains" } containsCall)
|| lambda.Body is not MethodCallExpression { Method.Name: "Contains" } containsCall
|| !TryMatchContains(containsCall, out var valuesExpression, out var itemExpression))
{
throw new NotSupportedException("Unsupported method call: Enumerable.Any");
}

// Match Enumerable.Contains(source, item), List<T>.Contains(item), or MemoryExtensions.Contains
var (valuesExpression, itemExpression) = containsCall switch
{
// Enumerable.Contains(source, item)
{ Method.Name: nameof(Enumerable.Contains), Arguments: [var src, var item] }
when containsCall.Method.DeclaringType == typeof(Enumerable)
=> (src, item),

// List<T>.Contains(item)
{ Method: { Name: nameof(Enumerable.Contains), DeclaringType: { IsGenericType: true } declaringType }, Object: Expression src, Arguments: [var item] }
when declaringType.GetGenericTypeDefinition() == typeof(List<>)
=> (src, item),

// MemoryExtensions.Contains (C# 14 first-class spans)
{ Method.Name: nameof(MemoryExtensions.Contains), Arguments: [var spanArg, var item, ..] }
when containsCall.Method.DeclaringType == typeof(MemoryExtensions)
&& (containsCall.Arguments.Count is 2
|| (containsCall.Arguments.Count is 3 && containsCall.Arguments[2] is ConstantExpression { Value: null }))
&& TryUnwrapSpanImplicitCast(spanArg, out var unwrappedSource)
=> (unwrappedSource, item),

_ => throw new NotSupportedException("Unsupported method call: Enumerable.Any"),
};

// Verify that the item is the lambda parameter
if (itemExpression != lambda.Parameters[0])
{
Expand Down Expand Up @@ -390,65 +328,6 @@ private void GenerateSearchInValues(IEnumerable values)
return result;
}

/// <summary>
/// Attempts to strip the implicit conversion to <c>Span&lt;T&gt;</c> or <c>ReadOnlySpan&lt;T&gt;</c> that wraps
/// <paramref name="expression"/>, together with any method-less Convert nodes directly beneath it.
/// </summary>
/// <param name="expression">The expression that may be wrapped in an implicit span cast.</param>
/// <param name="result">The unwrapped inner expression on success; otherwise <see langword="null"/>.</param>
/// <returns><see langword="true"/> if a span cast was recognized and removed; otherwise <see langword="false"/>.</returns>
private static bool TryUnwrapSpanImplicitCast(Expression expression, [NotNullWhen(true)] out Expression? result)
{
    Expression? candidate = null;
    Type? castOwner = null;

    // Different compiler versions appear to emit slightly different expression tree shapes for this implicit cast,
    // so several representations are recognized.
    if (expression is UnaryExpression { NodeType: ExpressionType.Convert } convert)
    {
        if (convert.Method is { } convertMethod)
        {
            // Convert node that carries the op_Implicit operator declared on a generic type.
            if (convertMethod.Name == "op_Implicit" && convertMethod.DeclaringType is { IsGenericType: true } operatorOwner)
            {
                candidate = convert.Operand;
                castOwner = operatorOwner;
            }
        }
        else if (convert.Type.IsGenericType && IsSpanDefinition(convert.Type.GetGenericTypeDefinition()))
        {
            // After the preprocessor runs, the Convert node may have lost its Method because the visitor recreated
            // the node with a different operand type; fall back to checking the conversion's target type.
            candidate = convert.Operand;
            castOwner = convert.Type;
        }
    }
    else if (expression is MethodCallExpression
        {
            Method: { Name: "op_Implicit", DeclaringType: { IsGenericType: true } callOwner },
            Arguments: [var onlyArgument]
        })
    {
        // Direct call to the op_Implicit operator taking exactly one argument.
        candidate = onlyArgument;
        castOwner = callOwner;
    }

    // For the dynamic case there is a Convert node representing an up-cast to object[]; peel that off too, along
    // with any Convert nodes the preprocessor added back to the array type.
    while (candidate is UnaryExpression { NodeType: ExpressionType.Convert, Method: null } upcast)
    {
        candidate = upcast.Operand;
    }

    // Only succeed when the cast really targets Span<T> or ReadOnlySpan<T>.
    if (candidate is null || castOwner is null || !IsSpanDefinition(castOwner.GetGenericTypeDefinition()))
    {
        result = null;
        return false;
    }

    result = candidate;
    return true;

    static bool IsSpanDefinition(Type definition)
        => definition == typeof(Span<>) || definition == typeof(ReadOnlySpan<>);
}

private void TranslateUnary(UnaryExpression unary)
{
switch (unary.NodeType)
Expand Down Expand Up @@ -485,57 +364,4 @@ private void TranslateUnary(UnaryExpression unary)
throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType);
}
}

/// <summary>
/// Attempts to resolve <paramref name="expression"/> to a property of the model, accepting either a member
/// access on the record parameter or a string-keyed dictionary lookup on it.
/// </summary>
/// <param name="expression">The expression to bind; it may be wrapped in Convert nodes.</param>
/// <param name="property">The bound property on success; otherwise <see langword="null"/>.</param>
/// <returns>
/// <see langword="true"/> if the expression refers to a model property; <see langword="false"/> if the expression
/// has neither of the recognized shapes.
/// </returns>
/// <exception cref="InvalidOperationException">The referenced name is not present in the model's property map.</exception>
/// <exception cref="InvalidCastException">A wrapping Convert node targets a type other than the property's type or <see cref="object"/>.</exception>
private bool TryBindProperty(Expression expression, [NotNullWhen(true)] out PropertyModel? property)
{
    // Peel off every wrapping Convert node to reach the underlying access expression.
    var core = expression;
    while (core is UnaryExpression { NodeType: ExpressionType.Convert, Operand: var inner })
    {
        core = inner;
    }

    string? modelName = null;
    if (core is MemberExpression member)
    {
        // Regular member access for strongly-typed POCO binding (e.g. r => r.SomeInt == 8)
        if (member.Expression == this._recordParameter)
        {
            modelName = member.Member.Name;
        }
    }
    else if (core is MethodCallExpression
        {
            Method.Name: "get_Item",
            Arguments: [ConstantExpression { Value: string keyName }]
        } indexerCall
        && indexerCall.Object == this._recordParameter
        && indexerCall.Method.DeclaringType == typeof(Dictionary<string, object?>))
    {
        // Dictionary lookup for weakly-typed dynamic binding (e.g. r => r["SomeInt"] == 8)
        modelName = keyName;
    }

    if (modelName is null)
    {
        property = null;
        return false;
    }

    if (!this._model.PropertyMap.TryGetValue(modelName, out property))
    {
        throw new InvalidOperationException($"Property name '{modelName}' provided as part of the filter clause is not a valid property name.");
    }

    // With the property known, walk the wrapping Convert nodes a second time and make sure each one targets either
    // the property's own type (ignoring nullability) or object.
    var expectedType = Nullable.GetUnderlyingType(property.Type) ?? property.Type;
    var node = expression;
    while (node is UnaryExpression { NodeType: ExpressionType.Convert } cast)
    {
        var castType = Nullable.GetUnderlyingType(cast.Type) ?? cast.Type;
        var compatible = castType == expectedType || castType == typeof(object);
        if (!compatible)
        {
            throw new InvalidCastException($"Property '{property.ModelName}' is being cast to type '{cast.Type.Name}', but its configured type is '{property.Type.Name}'.");
        }

        node = cast.Operand;
    }

    return true;
}
}
Loading
Loading