From 0abbf187018db14aae3b5ee8e2ada108c246cfad Mon Sep 17 00:00:00 2001 From: Nadav Samet Date: Sun, 2 Aug 2020 10:14:56 -0700 Subject: [PATCH] Add support for JVM plugin --- .../src/main/scala/scalapb/ScalaPBC.scala | 131 ++++++++++++++---- 1 file changed, 102 insertions(+), 29 deletions(-) diff --git a/scalapbc/src/main/scala/scalapb/ScalaPBC.scala b/scalapbc/src/main/scala/scalapb/ScalaPBC.scala index d6382ecf5..d87a3d231 100644 --- a/scalapbc/src/main/scala/scalapb/ScalaPBC.scala +++ b/scalapbc/src/main/scala/scalapb/ScalaPBC.scala @@ -6,6 +6,11 @@ import protocbridge.{ProtocBridge, ProtocCodeGenerator} import coursier.parse.DependencyParser import coursier.core.Configuration import com.github.ghik.silencer.silent +import coursier.core.Dependency +import java.net.URLClassLoader +import java.util.jar.JarInputStream +import java.io.FileInputStream +import protocbridge.SandboxedJvmGenerator case class Config( version: String = "-v" + scalapb.compiler.Version.protobufVersion, @@ -13,7 +18,8 @@ case class Config( args: Seq[String] = Seq.empty, customProtocLocation: Option[String] = None, namedGenerators: Seq[(String, ProtocCodeGenerator)] = Seq("scala" -> ScalaPbCodeGenerator), - artifacts: Seq[String] = Seq.empty + executableArtifacts: Seq[String] = Seq.empty, + jvmPlugins: Seq[(String, String)] = Seq.empty ) class ScalaPbcException(msg: String) extends RuntimeException(msg) @@ -22,6 +28,7 @@ object ScalaPBC { private val CustomPathArgument = "--protoc=" private val CustomGenArgument = "--custom-gen=" private val PluginArtifactArgument = "--plugin-artifact=" + private val JvmPluginArgument = "--jvm-plugin=" def processArgs(args: Array[String]): Config = { case class State(cfg: Config, passThrough: Boolean) @@ -40,6 +47,11 @@ object ScalaPBC { cfg = state.cfg.copy(namedGenerators = state.cfg.namedGenerators :+ (genName -> gen)) ) + case (false, p) if p.startsWith(JvmPluginArgument) => + val Array(genName, artifactName) = p.substring(JvmPluginArgument.length).split('=') + state.copy( + cfg = state.cfg.copy(jvmPlugins = state.cfg.jvmPlugins :+ (genName -> artifactName)) + ) case (false, p) if p.startsWith(CustomPathArgument) => state.copy( cfg = state.cfg @@ -48,8 +60,8 @@ object ScalaPBC { case (false, p) if p.startsWith(PluginArtifactArgument) => state.copy(cfg = state.cfg - .copy(artifacts = - state.cfg.artifacts :+ p.substring(PluginArtifactArgument.length()) + .copy(executableArtifacts = + state.cfg.executableArtifacts :+ p.substring(PluginArtifactArgument.length()) ) ) case (false, v) if v.startsWith("-v") => state.copy(cfg = state.cfg.copy(version = v)) @@ -61,36 +73,97 @@ object ScalaPBC { } @silent("method right in class Either is deprecated") - def main(args: Array[String]): Unit = { + def fetchArtifact(artifact: String): Either[String, (Dependency, Seq[File])] = { import coursier._ - val config = processArgs(args) + for { + dep <- DependencyParser + .dependency( + artifact, + scala.util.Properties.versionNumberString, + Configuration.empty + ) + .right + runResult = Fetch().addDependencies(dep).run() + outcome <- if (runResult.isEmpty) Left(s"Could not find artifact for $artifact") + else Right(runResult) + } yield (dep, outcome) + } + + def fetchArtifacts( + artifacts: Seq[(String, String)] + ): Either[String, Seq[(String, (Dependency, Seq[File]))]] = + artifacts.foldLeft[Either[String, Seq[(String, (Dependency, Seq[File]))]]](Right(Seq())) { + case (Left(error), _) => Left(error) + case (Right(result), (name, artifact)) => + fetchArtifact(artifact) match { + case Right((dep, files)) => Right(result :+ ((name, (dep, files)))) + case Left(error) => Left(error) + } + } + + def findMainClass(f: File): Either[String, String] = { + val jin = new JarInputStream(new FileInputStream(f)) + try { + val manifest = jin.getManifest() + Option(manifest.getMainAttributes().getValue("Main-Class")) + .toRight("Could not find main class for plugin") + .map(_ + "$") + } finally { + jin.close() + } + } - val arts: Seq[Either[String, (Dependency, java.io.File)]] = for { - art <- config.artifacts - maybeDep = DependencyParser.dependency( - art, - scala.util.Properties.versionNumberString, - Configuration.empty + @silent("method right in class Either is deprecated") + def main(args: Array[String]): Unit = { + val config = processArgs(args) + if (config.namedGenerators + .map(_._1) + .toSet + .intersect(config.jvmPlugins.map(_._1).toSet) + .nonEmpty) { + throw new RuntimeException( + s"Same plugin name provided by $PluginArtifactArgument and $JvmPluginArgument" ) - } yield for { - dep <- maybeDep.right - elm <- Fetch().addDependencies(dep).run() match { - case List(elm) => Right(elm) - case _ => Left(s"Could not find artifact for $art") + } + + def fatalError(err: String): Nothing = { + if (config.throwException) { + throw new ScalaPbcException(s"Error: $err") + } else { + System.err.println(err) + sys.exit(1) } - } yield dep -> elm - - val pluginArgs: Seq[String] = arts.map { - case Left(err) => - if (!config.throwException) { - System.err.println(err) - sys.exit(1) - } else { - throw new ScalaPbcException(s"Error: $err") + } + + val jvmGenerators = fetchArtifacts( + config.jvmPlugins + ) match { + case Left(error) => fatalError(error) + case Right(arts) => + arts.map { + case (name, (_, files)) => + val urls = files.map(_.toURI().toURL()).toArray + val loader = new URLClassLoader(urls, null) + val mainClass = findMainClass(files.head) match { + case Right(v) => v + case Left(err) => fatalError(err) + } + name -> SandboxedJvmGenerator.load(mainClass, loader) + } + } + + val pluginArgs = fetchArtifacts( + config.executableArtifacts.map(a => ("", a)) + ) match { + case Left(error) => fatalError(error) + case Right(arts) => + arts.map { + case (_, (dep, file :: Nil)) => + file.setExecutable(true) + s"--plugin=${dep.module.name.value}=${file.getAbsolutePath()}" + case (_, (dep, files)) => + fatalError(s"Got ${files.length} files for dependency $dep. Only one expected.") } - case Right((dep, file)) => - file.setExecutable(true) - s"--plugin=${dep.module.name.value}=${file.getAbsolutePath()}" } val code = ProtocBridge.runWithGenerators( @@ -103,7 +176,7 @@ object ScalaPBC { case None => a => com.github.os72.protocjar.Protoc.runProtoc(config.version +: a.toArray) }, - namedGenerators = config.namedGenerators, + namedGenerators = config.namedGenerators ++ jvmGenerators, params = config.args ++ pluginArgs )