Skip to content

Commit c75bd20

Browse files
authored
Merge pull request #41 from microsoft/order_by_expressions_support
Order by expressions support
2 parents 8b3b466 + 7501aa1 commit c75bd20

File tree

10 files changed

+494
-30
lines changed

10 files changed

+494
-30
lines changed

flowquery-py/src/parsing/operations/order_by.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
"""Represents an ORDER BY operation that sorts results."""
22

3-
from typing import Any, Dict, List
3+
import functools
4+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
45

56
from .operation import Operation
67

8+
if TYPE_CHECKING:
9+
from ..expressions.expression import Expression
10+
711

812
class SortField:
9-
"""A single sort specification: field name and direction."""
13+
"""A single sort specification: expression and direction."""
1014

11-
def __init__(self, field: str, direction: str = "asc"):
12-
self.field = field
15+
def __init__(self, expression: 'Expression', direction: str = "asc"):
16+
self.expression = expression
1317
self.direction = direction
1418

1519

@@ -19,27 +23,63 @@ class OrderBy(Operation):
1923
Can be attached to a RETURN operation (sorting its results),
2024
or used as a standalone accumulating operation after a non-aggregate WITH.
2125
22-
Example:
26+
Supports both simple field references and arbitrary expressions:
27+
28+
Example::
29+
2330
RETURN x ORDER BY x DESC
31+
RETURN x ORDER BY toLower(x.name) ASC
32+
RETURN x ORDER BY string_distance(toLower(x.name), toLower('Thomas')) ASC
2433
"""
2534

2635
def __init__(self, fields: List[SortField]):
2736
super().__init__()
2837
self._fields = fields
2938
self._results: List[Dict[str, Any]] = []
39+
self._sort_keys: List[List[Any]] = []
3040

3141
@property
3242
def fields(self) -> List[SortField]:
3343
return self._fields
3444

35-
def sort(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
36-
"""Sorts an array of records according to the sort fields."""
37-
import functools
45+
def capture_sort_keys(self) -> None:
46+
"""Evaluate every sort-field expression against the current runtime
47+
context and store the resulting values. Must be called once per
48+
accumulated row (from ``Return.run()``)."""
49+
self._sort_keys.append([f.expression.value() for f in self._fields])
3850

39-
def compare(a: Dict[str, Any], b: Dict[str, Any]) -> int:
40-
for sf in self._fields:
41-
a_val = a.get(sf.field)
42-
b_val = b.get(sf.field)
51+
def sort(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
52+
"""Sort records using pre-computed sort keys captured during
53+
accumulation. When no keys have been captured (e.g. aggregated
54+
returns), falls back to looking up simple reference identifiers
55+
in each record."""
56+
from ..expressions.reference import Reference
57+
58+
use_keys = len(self._sort_keys) == len(records)
59+
keys = self._sort_keys
60+
61+
# Pre-compute fallback field names for when sort keys aren't
62+
# available (aggregated returns).
63+
fallback_fields: List[Optional[str]] = []
64+
for f in self._fields:
65+
root = f.expression.first_child()
66+
if isinstance(root, Reference) and f.expression.child_count() == 1:
67+
fallback_fields.append(root.identifier)
68+
else:
69+
fallback_fields.append(None)
70+
71+
indices = list(range(len(records)))
72+
73+
def compare(ai: int, bi: int) -> int:
74+
for f_idx, sf in enumerate(self._fields):
75+
if use_keys:
76+
a_val = keys[ai][f_idx]
77+
b_val = keys[bi][f_idx]
78+
elif fallback_fields[f_idx] is not None:
79+
a_val = records[ai].get(fallback_fields[f_idx]) # type: ignore[arg-type]
80+
b_val = records[bi].get(fallback_fields[f_idx]) # type: ignore[arg-type]
81+
else:
82+
continue
4383
cmp = 0
4484
if a_val is None and b_val is None:
4585
cmp = 0
@@ -55,7 +95,8 @@ def compare(a: Dict[str, Any], b: Dict[str, Any]) -> int:
5595
return -cmp if sf.direction == "desc" else cmp
5696
return 0
5797

58-
return sorted(records, key=functools.cmp_to_key(compare))
98+
indices.sort(key=functools.cmp_to_key(compare))
99+
return [records[i] for i in indices]
59100

60101
async def run(self) -> None:
61102
"""When used as a standalone operation, passes through to next."""
@@ -64,6 +105,7 @@ async def run(self) -> None:
64105

65106
async def initialize(self) -> None:
66107
self._results = []
108+
self._sort_keys = []
67109
if self.next:
68110
await self.next.initialize()
69111

flowquery-py/src/parsing/operations/return_op.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ async def run(self) -> None:
6868
# Deep copy objects to preserve their state
6969
value = copy.deepcopy(raw) if isinstance(raw, (dict, list)) else raw
7070
record[alias] = value
71+
# Capture sort-key values while expression bindings are still live.
72+
if self._order_by is not None:
73+
self._order_by.capture_sort_keys()
7174
self._results.append(record)
7275
if self._order_by is None and self._limit is not None:
7376
self._limit.increment()

flowquery-py/src/parsing/parser.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -767,10 +767,9 @@ def _parse_order_by(self) -> Optional[OrderBy]:
767767
self._expect_and_skip_whitespace_and_comments()
768768
fields: list[SortField] = []
769769
while True:
770-
if not self.token.is_identifier_or_keyword():
771-
raise ValueError("Expected field name in ORDER BY")
772-
field = self.token.value
773-
self.set_next_token()
770+
expression = self._parse_expression()
771+
if expression is None:
772+
raise ValueError("Expected expression in ORDER BY")
774773
self._skip_whitespace_and_comments()
775774
direction = "asc"
776775
if self.token.is_asc():
@@ -781,7 +780,7 @@ def _parse_order_by(self) -> Optional[OrderBy]:
781780
direction = "desc"
782781
self.set_next_token()
783782
self._skip_whitespace_and_comments()
784-
fields.append(SortField(field, direction))
783+
fields.append(SortField(expression, direction))
785784
if self.token.is_comma():
786785
self.set_next_token()
787786
self._skip_whitespace_and_comments()

flowquery-py/tests/compute/test_runner.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4408,6 +4408,133 @@ async def test_order_by_with_where(self):
44084408
assert results[3] == {"x": 4}
44094409
assert results[4] == {"x": 3}
44104410

4411+
@pytest.mark.asyncio
4412+
async def test_order_by_with_property_access_expression(self):
4413+
"""Test ORDER BY with property access expression."""
4414+
runner = Runner(
4415+
"unwind [{name: 'Charlie', age: 30}, {name: 'Alice', age: 25}, {name: 'Bob', age: 35}] as person "
4416+
"return person.name as name, person.age as age "
4417+
"order by person.name asc"
4418+
)
4419+
await runner.run()
4420+
results = runner.results
4421+
assert len(results) == 3
4422+
assert results[0] == {"name": "Alice", "age": 25}
4423+
assert results[1] == {"name": "Bob", "age": 35}
4424+
assert results[2] == {"name": "Charlie", "age": 30}
4425+
4426+
@pytest.mark.asyncio
4427+
async def test_order_by_with_function_expression(self):
4428+
"""Test ORDER BY with function expression."""
4429+
runner = Runner(
4430+
"unwind ['BANANA', 'apple', 'Cherry'] as fruit "
4431+
"return fruit "
4432+
"order by toLower(fruit)"
4433+
)
4434+
await runner.run()
4435+
results = runner.results
4436+
assert len(results) == 3
4437+
assert results[0] == {"fruit": "apple"}
4438+
assert results[1] == {"fruit": "BANANA"}
4439+
assert results[2] == {"fruit": "Cherry"}
4440+
4441+
@pytest.mark.asyncio
4442+
async def test_order_by_with_function_expression_descending(self):
4443+
"""Test ORDER BY with function expression descending."""
4444+
runner = Runner(
4445+
"unwind ['BANANA', 'apple', 'Cherry'] as fruit "
4446+
"return fruit "
4447+
"order by toLower(fruit) desc"
4448+
)
4449+
await runner.run()
4450+
results = runner.results
4451+
assert len(results) == 3
4452+
assert results[0] == {"fruit": "Cherry"}
4453+
assert results[1] == {"fruit": "BANANA"}
4454+
assert results[2] == {"fruit": "apple"}
4455+
4456+
@pytest.mark.asyncio
4457+
async def test_order_by_with_nested_function_expression(self):
4458+
"""Test ORDER BY with nested function expression."""
4459+
runner = Runner(
4460+
"unwind ['Alice', 'Bob', 'ALICE', 'bob'] as name "
4461+
"return name "
4462+
"order by string_distance(toLower(name), toLower('alice')) asc"
4463+
)
4464+
await runner.run()
4465+
results = runner.results
4466+
assert len(results) == 4
4467+
# 'Alice' and 'ALICE' have distance 0 from 'alice', should come first
4468+
assert results[0]["name"] == "Alice"
4469+
assert results[1]["name"] == "ALICE"
4470+
# 'Bob' and 'bob' have higher distance from 'alice'
4471+
assert results[2]["name"] == "Bob"
4472+
assert results[3]["name"] == "bob"
4473+
4474+
@pytest.mark.asyncio
4475+
async def test_order_by_with_arithmetic_expression(self):
4476+
"""Test ORDER BY with arithmetic expression."""
4477+
runner = Runner(
4478+
"unwind [{a: 3, b: 1}, {a: 1, b: 5}, {a: 2, b: 2}] as item "
4479+
"return item.a as a, item.b as b "
4480+
"order by item.a + item.b asc"
4481+
)
4482+
await runner.run()
4483+
results = runner.results
4484+
assert len(results) == 3
4485+
assert results[0] == {"a": 3, "b": 1} # sum = 4
4486+
assert results[1] == {"a": 2, "b": 2} # sum = 4
4487+
assert results[2] == {"a": 1, "b": 5} # sum = 6
4488+
4489+
@pytest.mark.asyncio
4490+
async def test_order_by_expression_does_not_leak_synthetic_keys(self):
4491+
"""Test ORDER BY expression does not leak synthetic keys."""
4492+
runner = Runner(
4493+
"unwind ['B', 'a', 'C'] as x "
4494+
"return x "
4495+
"order by toLower(x) asc"
4496+
)
4497+
await runner.run()
4498+
results = runner.results
4499+
assert len(results) == 3
4500+
# Results should only contain 'x', no extra keys
4501+
for r in results:
4502+
assert list(r.keys()) == ["x"]
4503+
assert results[0] == {"x": "a"}
4504+
assert results[1] == {"x": "B"}
4505+
assert results[2] == {"x": "C"}
4506+
4507+
@pytest.mark.asyncio
4508+
async def test_order_by_with_expression_and_limit(self):
4509+
"""Test ORDER BY with expression and limit."""
4510+
runner = Runner(
4511+
"unwind ['BANANA', 'apple', 'Cherry', 'date', 'ELDERBERRY'] as fruit "
4512+
"return fruit "
4513+
"order by toLower(fruit) asc "
4514+
"limit 3"
4515+
)
4516+
await runner.run()
4517+
results = runner.results
4518+
assert len(results) == 3
4519+
assert results[0] == {"fruit": "apple"}
4520+
assert results[1] == {"fruit": "BANANA"}
4521+
assert results[2] == {"fruit": "Cherry"}
4522+
4523+
@pytest.mark.asyncio
4524+
async def test_order_by_with_mixed_simple_and_expression_fields(self):
4525+
"""Test ORDER BY with mixed simple and expression fields."""
4526+
runner = Runner(
4527+
"unwind [{name: 'Alice', score: 3}, {name: 'Alice', score: 1}, {name: 'Bob', score: 2}] as item "
4528+
"return item.name as name, item.score as score "
4529+
"order by name asc, item.score desc"
4530+
)
4531+
await runner.run()
4532+
results = runner.results
4533+
assert len(results) == 3
4534+
assert results[0] == {"name": "Alice", "score": 3} # Alice, score 3 desc
4535+
assert results[1] == {"name": "Alice", "score": 1} # Alice, score 1 desc
4536+
assert results[2] == {"name": "Bob", "score": 2} # Bob
4537+
44114538
@pytest.mark.asyncio
44124539
async def test_delete_virtual_node_operation(self):
44134540
"""Test delete virtual node operation."""

flowquery-py/tests/parsing/test_parser.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,3 +1172,66 @@ def test_optional_without_match_throws_error(self):
11721172
parser = Parser()
11731173
with pytest.raises(Exception, match="Expected MATCH after OPTIONAL"):
11741174
parser.parse("OPTIONAL RETURN 1")
1175+
1176+
# ORDER BY expression tests
1177+
1178+
def test_order_by_simple_identifier(self):
1179+
"""Test ORDER BY with a simple identifier parses correctly."""
1180+
parser = Parser()
1181+
ast = parser.parse("unwind [1, 2] as x return x order by x")
1182+
assert ast is not None
1183+
1184+
def test_order_by_property_access(self):
1185+
"""Test ORDER BY with property access parses correctly."""
1186+
parser = Parser()
1187+
ast = parser.parse(
1188+
"unwind [{name: 'Bob'}, {name: 'Alice'}] as person "
1189+
"return person.name as name order by person.name asc"
1190+
)
1191+
assert ast is not None
1192+
1193+
def test_order_by_function_call(self):
1194+
"""Test ORDER BY with function call parses correctly."""
1195+
parser = Parser()
1196+
ast = parser.parse(
1197+
"unwind ['HELLO', 'WORLD'] as word "
1198+
"return word order by toLower(word) asc"
1199+
)
1200+
assert ast is not None
1201+
1202+
def test_order_by_nested_function_calls(self):
1203+
"""Test ORDER BY with nested function calls parses correctly."""
1204+
parser = Parser()
1205+
ast = parser.parse(
1206+
"unwind ['Alice', 'Bob'] as name "
1207+
"return name order by string_distance(toLower(name), toLower('alice')) asc"
1208+
)
1209+
assert ast is not None
1210+
1211+
def test_order_by_arithmetic_expression(self):
1212+
"""Test ORDER BY with arithmetic expression parses correctly."""
1213+
parser = Parser()
1214+
ast = parser.parse(
1215+
"unwind [{a: 3, b: 1}, {a: 1, b: 5}] as item "
1216+
"return item.a as a, item.b as b order by item.a + item.b desc"
1217+
)
1218+
assert ast is not None
1219+
1220+
def test_order_by_multiple_expression_fields(self):
1221+
"""Test ORDER BY with multiple expression fields parses correctly."""
1222+
parser = Parser()
1223+
ast = parser.parse(
1224+
"unwind [{a: 1, b: 2}] as item "
1225+
"return item.a as a, item.b as b "
1226+
"order by toLower(item.a) asc, item.b desc"
1227+
)
1228+
assert ast is not None
1229+
1230+
def test_order_by_expression_with_limit(self):
1231+
"""Test ORDER BY with expression and LIMIT parses correctly."""
1232+
parser = Parser()
1233+
ast = parser.parse(
1234+
"unwind ['c', 'a', 'b'] as x "
1235+
"return x order by toLower(x) asc limit 2"
1236+
)
1237+
assert ast is not None

0 commit comments

Comments
 (0)