Skip to content

Commit fb602a7

Browse files
author
Killian Perlin
committed
Refactoring for query comprehension
1 parent c349fe3 commit fb602a7

File tree

5 files changed

+241
-12
lines changed

5 files changed

+241
-12
lines changed

lkql_jit/options/src/main/java/com/adacore/lkql_jit/options/Refactorings/LKQLToLkt.java

Lines changed: 178 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import static com.adacore.liblkqllang.Liblkqllang.Token.textRange;
99

1010
import com.adacore.liblkqllang.Liblkqllang;
11+
import java.util.ArrayList;
12+
import java.util.stream.Collectors;
1113

1214
public class LKQLToLkt implements TreeBasedRefactoring {
1315

@@ -42,11 +44,13 @@ private String refactorNode(Liblkqllang.LkqlNode node) {
4244
case Liblkqllang.SelectorArm arm -> refactorArm(arm, arm.fPattern(), arm.fExpr());
4345
case Liblkqllang.SelectorDecl selectorDecl -> refactorSelectorDecl(selectorDecl);
4446
case Liblkqllang.RecExpr recExpr -> refactorRecExpr(recExpr);
47+
case Liblkqllang.Query query -> refactorQuery(query);
48+
case Liblkqllang.ListComprehension comprehension -> refactorListComprehension(
49+
comprehension
50+
);
4551
case Liblkqllang.BlockBodyExpr bbe -> "val _ = " + refactorGeneric(bbe);
4652
case Liblkqllang.UnitLiteral _ -> "Unit()";
47-
case Liblkqllang.Expr expr when (
48-
expr.parent() instanceof Liblkqllang.TopLevelList
49-
) -> "val _ = " + refactorGeneric(expr);
53+
case Liblkqllang.TopLevelList topLevel -> refactorTopLevelList(topLevel);
5054
case Liblkqllang.UniversalPattern _ -> "_";
5155
default -> refactorGeneric(node);
5256
};
@@ -55,6 +59,10 @@ case Liblkqllang.Expr expr when (
5559
/**
5660
* Copy all the text belonging to a node in the input source,
5761
* but recursively refactor the code of its children.
62+
* Ex:
63+
* - node = SomeNode(... # comment 1\n ... # comment 2\n ...)
64+
* - returns = "# comment 1\n#comment 2\n"
65+
*
5866
*/
5967
private String refactorGeneric(Liblkqllang.LkqlNode node) {
6068
if (node.isTokenNode()) return node.getText();
@@ -78,6 +86,43 @@ private String refactorGeneric(Liblkqllang.LkqlNode node) {
7886
return s.toString();
7987
}
8088

89+
/**
90+
* Takes a node and returns the concatenation of all its comments
91+
* as a block of text.
92+
*/
93+
private String getAllComments(Liblkqllang.LkqlNode node) {
94+
return Refactoring.streamFrom(node.tokenStart())
95+
.takeWhile(tok -> tok.tokenIndex < node.tokenEnd().tokenIndex)
96+
.filter(tok -> tok.isTrivia() && !tok.getText().isBlank())
97+
.map(tok -> tok.getText() + "\n")
98+
.collect(Collectors.joining());
99+
}
100+
101+
private String refactorTopLevelList(Liblkqllang.TopLevelList topLevel) {
102+
var s = new StringBuilder();
103+
var cursor = topLevel.tokenStart();
104+
105+
for (int i = 0; i < topLevel.getChildrenCount(); i++) {
106+
final var child = topLevel.getChild(i);
107+
if (child.isNone() || child.isGhost()) continue;
108+
// copy until child
109+
s.append(textRange(cursor, child.tokenStart().previous()));
110+
// copy child
111+
112+
if (child instanceof Liblkqllang.Expr) {
113+
s.append("val _ = ");
114+
}
115+
s.append(refactorNode(child));
116+
// fast forward token cursor after child
117+
cursor = child.tokenEnd().next();
118+
}
119+
120+
// copy until end
121+
s.append(textRange(cursor, topLevel.tokenEnd()));
122+
123+
return s.toString();
124+
}
125+
81126
/*
82127
*
83128
* fun <name> <funexpr>
@@ -249,10 +294,10 @@ private String refactorSelectorDecl(Liblkqllang.SelectorDecl selectorDecl) {
249294
*
250295
* 2) Case disjonction
251296
*
252-
* rec( <left>, <right>) --> <right> :: <selector>(<left>)
253-
* rec(*<left>, <right>) --> <right> :: <left>.iterator.flatMap(<selector>)
254-
* rec( <left>, *<right>) --> <right>.iterator ::: <selector>(<left>)
255-
* rec(*<left>, *<right>) --> <right>.iterator ::: <left>.iterator.flatMap(<selector>)
297+
* rec( <left>, <right>) --> <right> :: <selector>(<left>)
298+
* rec(*<left>, <right>) --> <right> :: <left>.flat_map(<selector>)
299+
* rec( <left>, *<right>) --> <right> ::: <selector>(<left>)
300+
* rec(*<left>, *<right>) --> <right> ::: <left>.flat_map(<selector>)
256301
*
257302
*/
258303
private String refactorRecExpr(Liblkqllang.RecExpr recExpr) {
@@ -264,7 +309,7 @@ private String refactorRecExpr(Liblkqllang.RecExpr recExpr) {
264309
final var left = recExpr.fRecurseExpr();
265310
final var right = hasRight ? recExpr.fResultExpr() : left;
266311

267-
var s = unpackRight ? refactorNode(right) + ".iterator :::" : refactorNode(right) + " ::";
312+
var s = unpackRight ? refactorNode(right) + " :::" : refactorNode(right) + " ::";
268313

269314
// try to preserve spacing after "," (any newline for example)
270315
if (hasRight && left.tokenEnd().next().getText().equals(",")) {
@@ -275,9 +320,133 @@ private String refactorRecExpr(Liblkqllang.RecExpr recExpr) {
275320
}
276321

277322
s += unpackLeft
278-
? refactorNode(left) + ".iterator.flatMap(" + currentSelector.fName().getText() + ")"
323+
? refactorNode(left) + ".flat_map(" + currentSelector.fName().getText() + ")"
279324
: currentSelector.fName().getText() + "(" + refactorNode(left) + ")";
280325

281326
return s;
282327
}
328+
329+
/*
330+
* select <pattern>
331+
* from all_nodes match <pattern>
332+
*
333+
* from <expr> through <selector> select <pattern>
334+
* from <selector(expr)> match <pattern>
335+
*
336+
* Heuristics:
337+
* from <expr> through <selector> select <pattern> (where <expr> is plural)
338+
* from <expr>.flat_map(<selector>) match <pattern>
339+
*
340+
* If first keyword:
341+
* from <expr> select first <pattern>
342+
* (from <expr> match <pattern>).head
343+
*
344+
*/
345+
private String refactorQuery(Liblkqllang.Query query) {
346+
final var fromNode = query.fFromExpr();
347+
final var throughNode = query.fThroughExpr();
348+
349+
final String source;
350+
351+
if (fromNode.isNone()) {
352+
source = throughNode.isNone()
353+
? "all_nodes"
354+
: "units().flat_map((unit) => " + refactorNode(throughNode) + "(unit.root))";
355+
} else {
356+
final var from = refactorNode(fromNode);
357+
final var through = throughNode.isNone() ? "children" : refactorNode(throughNode);
358+
359+
// best effort heuristic to cover common cases
360+
final var isPlural =
361+
switch (fromNode) {
362+
case Liblkqllang.ListLiteral _ -> true;
363+
case Liblkqllang.ListComprehension _ -> true;
364+
case Liblkqllang.DotAccess dot -> dot.fMember().getText().equals("children");
365+
default -> false;
366+
};
367+
368+
source = isPlural
369+
? "(" + from + ").flat_map(" + through + ")"
370+
: through + "(" + from + ")";
371+
}
372+
373+
var s = "from " + source + " match " + refactorNode(query.fPattern());
374+
375+
if (query.fQueryKind() instanceof Liblkqllang.QueryKindFirst) {
376+
s = "(" + s + ").head";
377+
}
378+
379+
return getAllComments(query) + s;
380+
}
381+
382+
/*
383+
*
384+
* [ <expr> for <binding> in <source> if <guard> ]
385+
* from <source> match <binding> select <expr> if <guard>
386+
*
387+
* Multiple generators is handled as follow:
388+
* [ <expr> for <x_1> in <src_1>, ..., <x_n> in <src_n> if <guard> ]
389+
* <src_1>.flat_map(<x_1> => ... <src_n>.flat_map(<x_n> => if <guard> then [<expr>] else []))
390+
*
391+
*/
392+
private String refactorListComprehension(Liblkqllang.ListComprehension comprehension) {
393+
final var hasGuard = !comprehension.fGuard().isNone();
394+
final var sb = new StringBuilder();
395+
396+
final int nbSources = comprehension.fGenerators().getChildrenCount();
397+
398+
final var generators = new ArrayList<Liblkqllang.ListCompAssoc>();
399+
comprehension.fGenerators().iterator().forEachRemaining(generators::add);
400+
401+
sb.append(getAllComments(comprehension));
402+
403+
// default case
404+
if (nbSources == 1) {
405+
sb.append("from ");
406+
sb.append(refactorNode(generators.get(0).fCollExpr()));
407+
sb.append(" match ");
408+
sb.append(refactorNode(generators.get(0).fBindingName()));
409+
sb.append(" select ");
410+
sb.append(refactorNode(comprehension.fExpr()));
411+
412+
if (hasGuard) {
413+
sb.append(" if ");
414+
sb.append(refactorNode(comprehension.fGuard()));
415+
}
416+
}
417+
// special handling for multiple sources
418+
else {
419+
// open lambda for each source
420+
for (final var generator : comprehension.fGenerators()) {
421+
// simple heuristic to reduce parenthesis bloat
422+
if (generator.fCollExpr().isTokenNode()) {
423+
sb.append(generator.fCollExpr().getText());
424+
} else {
425+
sb.append("(");
426+
sb.append(refactorNode(generator.fCollExpr()));
427+
sb.append(")");
428+
}
429+
sb.append(".flat_map((");
430+
sb.append(refactorNode(generator.fBindingName()));
431+
sb.append(") => ");
432+
}
433+
434+
if (hasGuard) {
435+
sb.append("if ");
436+
sb.append(refactorNode(comprehension.fGuard()));
437+
sb.append(" then ");
438+
}
439+
sb.append("[");
440+
sb.append(refactorNode(comprehension.fExpr()));
441+
sb.append("]");
442+
if (hasGuard) {
443+
sb.append(" else []");
444+
}
445+
446+
// balance parenthesis, closing lambdas
447+
sb.repeat(')', nbSources);
448+
}
449+
450+
return "(" + sb.toString() + ")";
451+
}
283452
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
[a * 2 for a in int_list if is_prime(a)]
2+
3+
print(
4+
[
5+
o.image & " " & st.image
6+
for o in objects, st in subtypes
7+
if (o.image & " " & st.image).length != 64
8+
].to_list
9+
)
10+
11+
# Will select all non null nodes
12+
select AdaNode
13+
14+
# Select all non null nodes starting from node a
15+
from a select AdaNode
16+
17+
# Select all non null nodes starting from all nodes in list
18+
from [a, b, c] select AdaNode
19+
20+
# Select first basic declaration
21+
select # useless comment 1
22+
first # useless comment 2
23+
BasicDecl # useless comment 3
24+
25+
# Selects the parents of the first basic declaration
26+
from (select first BasicDecl) through parent select *
27+
28+
# Selects all nodes following generic instantiations
29+
through follow_generics select *
30+
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# lkql version: 2
2+
3+
val _ = (from int_list match a select a * 2 if is_prime(a))
4+
5+
val _ = print(
6+
(objects.flat_map((o) => subtypes.flat_map((st) => if (o.image & " " & st.image).length != 64 then [o.image & " " & st.image] else []))).to_list
7+
)
8+
9+
# Will select all non null nodes
10+
val _ = from all_nodes match AdaNode
11+
12+
# Select all non null nodes starting from node a
13+
val _ = from children(a) match AdaNode
14+
15+
# Select all non null nodes starting from all nodes in list
16+
val _ = from ([a, b, c]).flat_map(children) match AdaNode
17+
18+
# Select first basic declaration
19+
val _ = # useless comment 1
20+
# useless comment 2
21+
(from all_nodes match BasicDecl).head # useless comment 3
22+
23+
# Selects the parents of the first basic declaration
24+
val _ = from parent((from all_nodes match BasicDecl).head) match _
25+
26+
# Selects all nodes following generic instantiations
27+
val _ = from units().flat_map((unit) => follow_generics(unit.root)) match _
28+
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
driver: refactor
2+
refactoring: TO_LKQL_V2

testsuite/tests/refactor/selectors/test.out

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@ fun parents(this : Any) : Any = match this {
77

88

99
fun super_types(this : Any) : Any = match this {
10-
case BaseTypeDecl => this.p_base_types().iterator ::: this.p_base_types().iterator.flatMap(super_types)
10+
case BaseTypeDecl => this.p_base_types() ::: this.p_base_types().flat_map(super_types)
1111
case _ => Unit()
1212
}
1313

1414

1515
fun children(this : Any) : Any = match this {
16-
case AdaNode => this :: this.children.iterator.flatMap(children)
16+
case AdaNode => this :: this.children.flat_map(children)
1717
case _ => Unit()
1818
}
1919

2020

2121
fun foo(this : Any) : Any = match this {
22-
case Foo => [s, o, m, e, t, h, i, n, g].iterator ::: foo(this)
22+
case Foo => [s, o, m, e, t, h, i, n, g] ::: foo(this)
2323
}
2424

2525

0 commit comments

Comments
 (0)