Skip to content

Commit 25efc21

Browse files
author
Killian Perlin
committed
Add QueryComprehension
Closes: #207
1 parent 32a51d2 commit 25efc21

File tree

7 files changed

+320
-1
lines changed

7 files changed

+320
-1
lines changed

lkql_jit/language/src/main/java/com/adacore/lkql_jit/langkit_translator/passes/LktPasses.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ private static boolean isStreamOpRhs(LktNode node) {
8080
);
8181
}
8282

83-
/** Whether "node" needs a frame to be introduced to contain inner bindings. */
83+
/**
84+
* Whether "node" needs a frame to be introduced to contain inner bindings.
85+
* NB: this does not include the Query node (which needs a concrete frame)
86+
* because its handled in its own separate function
87+
*/
8488
private static FrameKind needsFrame(LktNode node) {
8589
if (
8690
node instanceof FunDecl ||
@@ -116,6 +120,12 @@ private static String getBindingName(LktNode node) {
116120
* nodes and building the frame tree.
117121
*/
118122
private static void recurseBuildFrames(LktNode node, ScriptFramesBuilder builder) {
123+
// Queries need special handling
124+
if (node instanceof Liblktlang.Query query) {
125+
handleQuery(query, builder);
126+
return;
127+
}
128+
119129
var bindingName = getBindingName(node);
120130
if (bindingName != null) {
121131
builder.addBinding(bindingName);
@@ -152,6 +162,19 @@ public static ScriptFramesBuilder buildFrames(LktNode root) {
152162
recurseBuildFrames(root, builder);
153163
return builder;
154164
}
165+
166+
/**
167+
* Special handling for query comprehensions
168+
* since the source part is not included in the concrete frame.
169+
*/
170+
public static void handleQuery(Liblktlang.Query query, ScriptFramesBuilder builder) {
171+
recurseBuildFrames(query.fSource(), builder);
172+
builder.openFrame(query);
173+
recurseBuildFrames(query.fPattern(), builder);
174+
if (!query.fGuard().isNone()) recurseBuildFrames(query.fGuard(), builder);
175+
if (!query.fMapping().isNone()) recurseBuildFrames(query.fMapping(), builder);
176+
builder.closeFrame();
177+
}
155178
}
156179

157180
private static class TranslationPass extends BaseTranslationPass {
@@ -531,6 +554,29 @@ private Expr buildExpr(Liblktlang.Expr expr) {
531554
lambdaExpr.fBody(),
532555
lambdaExpr.fParams()
533556
);
557+
} else if (expr instanceof Liblktlang.Query query) {
558+
var source = buildExpr(query.fSource());
559+
560+
this.frames.enterFrame(query);
561+
562+
var pattern = buildPattern(query.fPattern());
563+
var guard = query.fGuard().isNone() ? null : buildExpr(query.fGuard());
564+
var result = query.fMapping().isNone() ? null : buildExpr(query.fMapping());
565+
566+
var frameDescriptor = this.frames.getFrameDescriptor();
567+
var closureDescriptor = this.frames.getClosureDescriptor();
568+
569+
this.frames.exitFrame();
570+
571+
return QueryComprehensionNodeGen.create(
572+
loc(query),
573+
frameDescriptor,
574+
closureDescriptor,
575+
pattern,
576+
guard,
577+
result,
578+
source
579+
);
534580
} else {
535581
throw LKQLRuntimeException.create(
536582
"Translation for " + expr.getKind() + " not implemented"
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//
2+
// Copyright (C) 2005-2025, AdaCore
3+
// SPDX-License-Identifier: GPL-3.0-or-later
4+
//
5+
6+
package com.adacore.lkql_jit.nodes.expressions;
7+
8+
import com.adacore.lkql_jit.LKQLLanguage;
9+
import com.adacore.lkql_jit.exception.LKQLRuntimeException;
10+
import com.adacore.lkql_jit.nodes.patterns.Pattern;
11+
import com.adacore.lkql_jit.nodes.root_nodes.QueryComprehensionRootNode;
12+
import com.adacore.lkql_jit.nodes.utils.CreateClosureNode;
13+
import com.adacore.lkql_jit.runtime.values.interfaces.Iterable;
14+
import com.adacore.lkql_jit.runtime.values.lists.LKQLQueryComprehension;
15+
import com.adacore.lkql_jit.utils.ClosureDescriptor;
16+
import com.adacore.lkql_jit.utils.LKQLTypesHelper;
17+
import com.oracle.truffle.api.dsl.Fallback;
18+
import com.oracle.truffle.api.dsl.NodeChild;
19+
import com.oracle.truffle.api.dsl.Specialization;
20+
import com.oracle.truffle.api.frame.FrameDescriptor;
21+
import com.oracle.truffle.api.frame.VirtualFrame;
22+
import com.oracle.truffle.api.source.SourceSection;
23+
24+
@NodeChild(value = "source", type = Expr.class)
25+
public abstract class QueryComprehension extends Expr {
26+
27+
// ----- Attributes -----
28+
29+
private final QueryComprehensionRootNode rootNode;
30+
31+
// ----- Children -----
32+
33+
@Child
34+
@SuppressWarnings("FieldMayBeFinal")
35+
private CreateClosureNode createClosureNode;
36+
37+
// ----- Constructors -----
38+
39+
@SuppressWarnings("this-escape")
40+
protected QueryComprehension(
41+
final SourceSection location,
42+
final FrameDescriptor frameDescriptor,
43+
final ClosureDescriptor closureDescriptor,
44+
final Pattern pattern,
45+
final Expr guard,
46+
final Expr result
47+
) {
48+
super(location);
49+
this.rootNode = new QueryComprehensionRootNode(
50+
LKQLLanguage.getLanguage(this),
51+
frameDescriptor,
52+
pattern,
53+
guard,
54+
result
55+
);
56+
this.createClosureNode = new CreateClosureNode(closureDescriptor);
57+
}
58+
59+
// ----- Execution methods -----
60+
61+
@Specialization
62+
protected LKQLQueryComprehension onIterable(VirtualFrame frame, Iterable source) {
63+
return new LKQLQueryComprehension(this.rootNode, createClosureNode.execute(frame), source);
64+
}
65+
66+
@Fallback
67+
protected void fallback(VirtualFrame frame, Object notIterable) {
68+
throw LKQLRuntimeException.wrongType(
69+
LKQLTypesHelper.LKQL_ITERABLE,
70+
LKQLTypesHelper.fromJava(notIterable),
71+
this.getSource()
72+
);
73+
}
74+
75+
// ----- Class methods -----
76+
77+
abstract Expr getSource();
78+
79+
// ----- Override methods -----
80+
81+
/**
82+
* @see com.adacore.lkql_jit.nodes.LKQLNode#toString(int)
83+
*/
84+
@Override
85+
public String toString(int indentLevel) {
86+
return this.nodeRepresentation(indentLevel);
87+
}
88+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//
2+
// Copyright (C) 2005-2025, AdaCore
3+
// SPDX-License-Identifier: GPL-3.0-or-later
4+
//
5+
6+
package com.adacore.lkql_jit.nodes.root_nodes;
7+
8+
import com.adacore.lkql_jit.nodes.expressions.Expr;
9+
import com.adacore.lkql_jit.nodes.expressions.LKQLToBoolean;
10+
import com.adacore.lkql_jit.nodes.expressions.LKQLToBooleanNodeGen;
11+
import com.adacore.lkql_jit.nodes.patterns.Pattern;
12+
import com.oracle.truffle.api.TruffleLanguage;
13+
import com.oracle.truffle.api.frame.FrameDescriptor;
14+
import com.oracle.truffle.api.frame.VirtualFrame;
15+
16+
public final class QueryComprehensionRootNode extends BaseRootNode {
17+
18+
@Child
19+
@SuppressWarnings("FieldMayBeFinal")
20+
private Pattern pattern;
21+
22+
@Child
23+
@SuppressWarnings("FieldMayBeFinal")
24+
private Expr guard;
25+
26+
@Child
27+
@SuppressWarnings("FieldMayBeFinal")
28+
private Expr result;
29+
30+
@Child
31+
private LKQLToBoolean toBoolean;
32+
33+
// ----- Constructors -----
34+
35+
public QueryComprehensionRootNode(
36+
TruffleLanguage<?> language,
37+
FrameDescriptor frameDescriptor,
38+
Pattern pattern,
39+
Expr guard,
40+
Expr result
41+
) {
42+
super(language, frameDescriptor);
43+
this.pattern = pattern;
44+
this.guard = guard;
45+
this.result = result;
46+
this.toBoolean = LKQLToBooleanNodeGen.create();
47+
}
48+
49+
// ----- Execution methods -----
50+
51+
/**
52+
* @see
53+
* com.adacore.lkql_jit.nodes.root_nodes.BaseRootNode#execute(com.oracle.truffle.api.frame.VirtualFrame)
54+
*/
55+
@Override
56+
public Object execute(VirtualFrame frame) {
57+
this.initFrame(frame);
58+
59+
// by convention only 2 args [closure, arg]
60+
// see QueryComprehension
61+
var arg = frame.getArguments()[1];
62+
63+
// pattern does not match -> early exit
64+
if (!this.pattern.executeValue(frame, arg)) return null;
65+
66+
// guard present and evaluates to false -> early exit
67+
if (this.guard != null && !toBoolean.execute(guard.executeGeneric(frame))) return null;
68+
69+
// result present -> evaluate result
70+
// else default to arg
71+
return this.result != null ? this.result.executeGeneric(frame) : arg;
72+
}
73+
74+
@Override
75+
public String toString() {
76+
return (
77+
"<querycomp>:" +
78+
this.result.getLocation().fileName() +
79+
":" +
80+
this.result.getLocation().startLine()
81+
);
82+
}
83+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//
2+
// Copyright (C) 2005-2025, AdaCore
3+
// SPDX-License-Identifier: GPL-3.0-or-later
4+
//
5+
6+
package com.adacore.lkql_jit.runtime.values.lists;
7+
8+
import com.adacore.lkql_jit.runtime.Closure;
9+
import com.adacore.lkql_jit.runtime.ListStorage;
10+
import com.adacore.lkql_jit.runtime.values.interfaces.Iterable;
11+
import com.adacore.lkql_jit.runtime.values.interfaces.Iterator;
12+
import com.oracle.truffle.api.CompilerDirectives;
13+
import com.oracle.truffle.api.nodes.DirectCallNode;
14+
import com.oracle.truffle.api.nodes.RootNode;
15+
16+
/**
17+
* This class represents either
18+
* - a list comprehension value
19+
* - a query expression
20+
* in the LKQL language.
21+
* It allows iterating on a source and
22+
* - pattern matching values from it
23+
* - checking against a boolean guard
24+
* - mapping to another value
25+
* and returns the result as a LazyList.
26+
*/
27+
public final class LKQLQueryComprehension extends BaseLKQLLazyList {
28+
29+
// ----- Attributes -----
30+
31+
private final DirectCallNode callNode;
32+
33+
private final Iterator iterator;
34+
35+
/**
36+
* argument[0] is the closure
37+
* argument[1] is the iterator result
38+
*/
39+
private final Object[] arguments = new Object[2];
40+
41+
// ----- Constructors -----
42+
43+
@CompilerDirectives.TruffleBoundary
44+
public LKQLQueryComprehension(
45+
final RootNode rootNode,
46+
final Closure closure,
47+
final Iterable source
48+
) {
49+
super(new ListStorage<>(1));
50+
this.callNode = DirectCallNode.create(rootNode.getCallTarget());
51+
this.iterator = source.iterator();
52+
this.arguments[0] = closure.getContent();
53+
}
54+
55+
// ----- Lazy list required methods -----
56+
57+
@Override
58+
protected void initCacheTo(long n) {
59+
while ((n < 0 || this.cache.size() <= n) && iterator.hasNext()) {
60+
this.arguments[1] = iterator.next();
61+
Object value = this.callNode.call(this.arguments);
62+
if (value != null) this.cache.append(value);
63+
}
64+
}
65+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# lkql version: 2
2+
3+
fun debug_id(i : Any) : Any = {
4+
val dummy = print("In debug_id: " & img(i));
5+
i
6+
}
7+
8+
# When created, the iterator should be unevaluated completely
9+
val itt = (from [1, 2, 3, 4, 5] match i select debug_id(i*i) if i > 2)
10+
11+
val _ = print("Evaluating itt up to index 2")
12+
val _ = print(itt[2])
13+
val _ = print("Querying element 1, already computed")
14+
val _ = print(itt[1])
15+
val _ = print("Calling to_list will consume itt entirely")
16+
val _ = print(itt.to_list)
17+
18+
fun dfs(this : Any) : Any = match this {
19+
case AdaNode => [this] & concat((from this.children match c select dfs(c)).to_list)
20+
case _ => []
21+
}
22+
23+
val root = units()[1].root
24+
val _ = print(dfs(root))
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Evaluating itt up to index 2
2+
In debug_id: 9
3+
In debug_id: 16
4+
16
5+
Querying element 1, already computed
6+
9
7+
Calling to_list will consume itt entirely
8+
In debug_id: 25
9+
[9, 16, 25]
10+
[<CompilationUnit main.adb:1:1-8:10>, <AdaNodeList main.adb:1:1-1:35>, <WithClause main.adb:1:1-1:18>, <LimitedAbsent main.adb:1:1-1:1>, <PrivateAbsent main.adb:1:1-1:1>, <NameList main.adb:1:6-1:17>, <DottedName main.adb:1:6-1:17>, <Id "Ada" main.adb:1:6-1:9>, <Id "Text_IO" main.adb:1:10-1:17>, <UsePackageClause main.adb:1:19-1:35>, <NameList main.adb:1:23-1:34>, <DottedName main.adb:1:23-1:34>, <Id "Ada" main.adb:1:23-1:26>, <Id "Text_IO" main.adb:1:27-1:34>, <LibraryItem main.adb:3:1-8:10>, <PrivateAbsent main.adb:1:35-1:35>, <SubpBody ["Main"] main.adb:3:1-8:10>, <OverridingUnspecified main.adb:1:35-1:35>, <SubpSpec main.adb:3:1-3:15>, <SubpKindProcedure main.adb:3:1-3:10>, <DefiningName "Main" main.adb:3:11-3:15>, <Id "Main" main.adb:3:11-3:15>, <DeclarativePart main.adb:3:18-6:1>, <AdaNodeList main.adb:4:4-5:66>, <ObjectDecl ["Message"] main.adb:4:4-4:48>, <DefiningNameList main.adb:4:4-4:11>, <DefiningName "Message" main.adb:4:4-4:11>, <Id "Message" main.adb:4:4-4:11>, <AliasedAbsent main.adb:4:21-4:21>, <ConstantAbsent main.adb:4:21-4:21>, <ModeDefault main.adb:4:21-4:21>, <SubtypeIndication main.adb:4:22-4:28>, <NotNullAbsent main.adb:4:21-4:21>, <Id "String" main.adb:4:22-4:28>, <Str ""Hello World !"" main.adb:4:32-4:47>, <ObjectDecl ["ConstantMessage"] main.adb:5:4-5:66>, <DefiningNameList main.adb:5:4-5:19>, <DefiningName "ConstantMessage" main.adb:5:4-5:19>, <Id "ConstantMessage" main.adb:5:4-5:19>, <AliasedAbsent main.adb:5:21-5:21>, <ConstantPresent main.adb:5:22-5:30>, <ModeDefault main.adb:5:30-5:30>, <SubtypeIndication main.adb:5:31-5:37>, <NotNullAbsent main.adb:5:30-5:30>, <Id "String" main.adb:5:31-5:37>, <Str ""Hello Constant world !"" main.adb:5:41-5:65>, <HandledStmts main.adb:6:6-8:1>, <StmtList main.adb:7:4-7:23>, <CallStmt main.adb:7:4-7:23>, <CallExpr main.adb:7:4-7:22>, <Id "Put_Line" main.adb:7:4-7:12>, <AssocList main.adb:7:14-7:21>, <ParamAssoc main.adb:7:14-7:21>, <Id "Message" main.adb:7:14-7:21>, <AdaNodeList main.adb:7:23-7:23>, <EndName main.adb:8:5-8:9>, <Id "Main" main.adb:8:5-8:9>, <PragmaNodeList main.adb:8:10-8:10>]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
driver: 'interpreter'
2+
project: 'default_project/default.gpr'
3+
lkt_refactor: False # this is already an LKQL_V2 test

0 commit comments

Comments
 (0)