Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,21 @@
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import ru.tinkoff.kora.annotation.processor.common.TestUtils.ProcessorOptions;
import ru.tinkoff.kora.annotation.processor.common.compile.ByteArrayJavaFileObject;
import ru.tinkoff.kora.annotation.processor.common.compile.KoraCompileTestJavaFileManager;
import ru.tinkoff.kora.application.graph.*;

import javax.annotation.processing.Processor;
import javax.tools.Diagnostic;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileObject;
import javax.tools.ToolProvider;
import java.io.IOException;
import java.io.StringWriter;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;
Expand All @@ -39,9 +35,6 @@

@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public abstract class AbstractAnnotationProcessorTest {

private final JavaCompiler javaCompiler = ToolProvider.getSystemJavaCompiler();

protected TestInfo testInfo;
protected CompileResult compileResult;

Expand Down Expand Up @@ -83,71 +76,58 @@ protected String commonImports() {
}

protected CompileResult compile(List<Processor> processors, @Language("java") String... sources) {
return compile(processors, Collections.emptyList(), sources);
}

protected CompileResult compile(List<Processor> processors, List<ProcessorOptions> processorOptions, @Language("java") String... sources) {
var w = new StringWriter();
var diagnostic = new ArrayList<Diagnostic<? extends JavaFileObject>>();
var testPackage = testPackage();
var testClass = this.testInfo.getTestClass().get();
var testMethod = this.testInfo.getTestMethod().get();
var commonImports = this.commonImports();
var sourceList = Arrays.stream(sources).map(s -> "package %s;\n%s\n/**\n* @see %s#%s \n*/\n".formatted(testPackage, commonImports, testClass.getCanonicalName(), testMethod.getName()) + s)
.map(s -> {
var prefixes = List.of("class ", "interface ", "@interface ", "record ", "enum ");
var firstClass = prefixes.stream()
.map(p -> Map.entry(s.indexOf(p), p.length()))
.filter(e -> e.getKey() >= 0)
.map(e -> e.getKey() + e.getValue())
.min(Comparator.comparing(Function.identity()))
.map(classStart -> {
var firstSpace = s.indexOf(" ", classStart + 1);
var firstBracket = s.indexOf("(", classStart + 1);
var firstSquareBracket = s.indexOf("{", classStart + 1);
var classEnd = IntStream.of(firstSpace, firstBracket, firstSquareBracket)
.filter(i -> i >= 0)
.min()
.getAsInt();
var className = s.substring(classStart, classEnd).trim();
int generic = className.indexOf('<');
if (generic == -1) {
return className;
} else {
return className.substring(0, generic);
}
})
.get();

return new ByteArrayJavaFileObject(JavaFileObject.Kind.SOURCE, testPackage + "." + firstClass, s.getBytes(StandardCharsets.UTF_8));
})
.toList();

try (var delegate = javaCompiler.getStandardFileManager(diagnostic::add, Locale.US, StandardCharsets.UTF_8);
var manager = new KoraCompileTestJavaFileManager(this.testInfo, delegate, sourceList.toArray(ByteArrayJavaFileObject[]::new))) {

var defaultOptions = new LinkedHashSet<>(List.of("--release", "17", "-XprintRounds"));
defaultOptions.addAll(processorOptions.stream().map(o -> o.value).toList());

var task = javaCompiler.getTask(
w,
manager,
diagnostic::add,
defaultOptions,
null,
sourceList
);
task.setProcessors(processors);
task.setLocale(Locale.US);
task.call();
w.close();
return this.compileResult = new CompileResult(testPackage, diagnostic, manager);
} catch (RuntimeException e) {
if (e.getCause() instanceof RuntimeException er) {
throw er;
var sourceList = new ArrayList<Path>();
for (var source : sources) {
var string = "package %s;\n%s\n/**\n* @see %s#%s \n*/\n".formatted(testPackage, commonImports, testClass.getCanonicalName(), testMethod.getName()) + source;
var prefixes = List.of("class ", "interface ", "@interface ", "record ", "enum ");
var firstClass = prefixes.stream()
.map(p -> Map.entry(string.indexOf(p), p.length()))
.filter(e -> e.getKey() >= 0)
.map(e -> e.getKey() + e.getValue())
.min(Comparator.comparing(Function.identity()))
.map(classStart -> {
var firstSpace = string.indexOf(" ", classStart + 1);
var firstBracket = string.indexOf("(", classStart + 1);
var firstSquareBracket = string.indexOf("{", classStart + 1);
var classEnd = IntStream.of(firstSpace, firstBracket, firstSquareBracket)
.filter(i -> i >= 0)
.min()
.getAsInt();
var className = string.substring(classStart, classEnd).trim();
int generic = className.indexOf('<');
if (generic == -1) {
return className;
} else {
return className.substring(0, generic);
}
})
.get();
var className = testPackage + "." + firstClass;
var path = Paths.get(".", "build", "in-test-generated", "sources").resolve(className.replace('.', '/') + ".java");
try {
Files.createDirectories(path.getParent());
Files.write(path, string.getBytes(StandardCharsets.UTF_8));
} catch (IOException e) {
throw new RuntimeException(e);
}
sourceList.add(path);
}

try {
var jc = new JavaCompilation()
.withSources(sourceList)
.withProcessors(processors);
var cl = jc.compile();
return this.compileResult = new CompileResult(testPackage, jc.diagnostics(), cl);
} catch (TestUtils.CompilationErrorException e) {
return this.compileResult = new CompileResult(testPackage, e.diagnostics, null);
} catch (RuntimeException e) {
throw e;
} catch (IOException e) {
} catch (Exception e) {
throw new RuntimeException(e);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
package ru.tinkoff.kora.annotation.processor.common;


import ru.tinkoff.kora.annotation.processor.common.compile.KoraCompileTestJavaFileManager;

import javax.tools.Diagnostic;
import javax.tools.FileObject;
import javax.tools.JavaFileObject;
import javax.tools.StandardLocation;
import java.io.IOException;
import java.io.StringWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;

public record CompileResult(String testPackage, List<Diagnostic<? extends JavaFileObject>> diagnostic, KoraCompileTestJavaFileManager manager) {
public record CompileResult(String testPackage, List<Diagnostic<? extends JavaFileObject>> diagnostic, ClassLoader cl) {
public boolean isFailed() {
return this.diagnostic.stream()
.anyMatch(d -> d.getKind() == Diagnostic.Kind.ERROR);
Expand All @@ -37,13 +35,9 @@ public List<Diagnostic<? extends JavaFileObject>> errors() {
}


public FileObject generatedSourceFile(String className) throws IOException {
return this.manager.getFileForInput(StandardLocation.SOURCE_OUTPUT, this.testPackage, className);
}

public Class<?> loadClass(String className) {
try {
return this.manager.getClassLoader(StandardLocation.CLASS_OUTPUT).loadClass(this.testPackage + "." + className);
return cl.loadClass(this.testPackage + "." + className);
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
}
Expand All @@ -56,21 +50,23 @@ public CompilationFailedException(String message) {
}

public RuntimeException compilationException() {
var diagnosticMap = new IdentityHashMap<JavaFileObject, Map<Long, List<Diagnostic<? extends JavaFileObject>>>>();
var diagnosticMap = new HashMap<Path, Map<Long, List<Diagnostic<? extends JavaFileObject>>>>();
for (var d : this.diagnostic) {
var map = diagnosticMap.computeIfAbsent(d.getSource(), o -> new HashMap<>());
var map = diagnosticMap.computeIfAbsent(Path.of(d.getSource().toUri()).toAbsolutePath(), o -> new HashMap<>());
map.computeIfAbsent(d.getLineNumber(), l -> new ArrayList<>()).add(d);
}

try {
var j = new StringJoiner("\n", "\n", "\n");
for (var javaFileObject : this.manager.list(StandardLocation.SOURCE_OUTPUT, "", Set.of(JavaFileObject.Kind.SOURCE), true)) {
var diagnostic = diagnosticMap.getOrDefault(javaFileObject, Map.of());
j.add(javaFileObject.getName()).add(javaFileToString(javaFileObject, diagnostic));
var generatedSources = Files.walk(Path.of("build/in-test-generated/sources")).filter(Files::isRegularFile).toList();
for (var src : generatedSources) {
var diagnostic = diagnosticMap.getOrDefault(src.toAbsolutePath(), Map.of());
j.add(src.toString()).add(javaFileToString(src, diagnostic));
}
for (var javaFileObject : this.manager.list(StandardLocation.SOURCE_PATH, "", Set.of(JavaFileObject.Kind.SOURCE), true)) {
var sources = Files.walk(Paths.get(".", "build", "in-test-generated", "sources")).filter(Files::isRegularFile).toList();
for (var javaFileObject : sources) {
var diagnostic = diagnosticMap.getOrDefault(javaFileObject, Map.of());
j.add(javaFileObject.getName()).add(javaFileToString(javaFileObject, diagnostic));
j.add(javaFileObject.toString()).add(javaFileToString(javaFileObject, diagnostic));
}

var errors = this.diagnostic.stream()
Expand All @@ -84,20 +80,16 @@ public RuntimeException compilationException() {

}

private static String javaFileToString(JavaFileObject object, Map<Long, List<Diagnostic<? extends JavaFileObject>>> diagnostic) throws IOException {
private static String javaFileToString(Path object, Map<Long, List<Diagnostic<? extends JavaFileObject>>> diagnostic) throws IOException {
var j = new StringJoiner("\n", "\n", "\n");
try (var r = object.openReader(true);
var sw = new StringWriter()) {
r.transferTo(sw);
sw.flush();
var lines = sw.toString().lines().toList();
for (int i = 0; i < lines.size(); i++) {
var lineDiagnostic = diagnostic.getOrDefault((long) i + 1, List.of());
j.add("%03d | %s".formatted(i, lines.get(i)));
for (var d : lineDiagnostic) {
var diagnosticString = " ".repeat(((int) d.getColumnNumber()) - 1) + "^ " + d.getMessage(Locale.US);
j.add(diagnosticString.indent(6));
}
var lines = Files.readAllLines(object);

for (int i = 0; i < lines.size(); i++) {
var lineDiagnostic = diagnostic.getOrDefault((long) i + 1, List.of());
j.add("%03d | %s".formatted(i, lines.get(i)));
for (var d : lineDiagnostic) {
var diagnosticString = " ".repeat(((int) d.getColumnNumber()) - 1) + "^ " + d.getMessage(Locale.US);
j.add(diagnosticString.indent(6));
}
}
return j.toString();
Expand Down
Loading