diff --git a/decls/haskell_rules.bzl b/decls/haskell_rules.bzl index c3d9a75eb..6fbd82da2 100644 --- a/decls/haskell_rules.bzl +++ b/decls/haskell_rules.bzl @@ -14,7 +14,6 @@ load("@prelude//linking:types.bzl", "Linkage") load(":common.bzl", "LinkableDepType", "buck", "prelude_rule") load(":haskell_common.bzl", "haskell_common") load(":native_common.bzl", "native_common") -load("@prelude//haskell/worker/worker.bzl", "worker_libs", "worker_srcs", "worker_flags") haskell_binary = prelude_rule( name = "haskell_binary", @@ -67,9 +66,7 @@ haskell_binary = prelude_rule( "linker_flags": attrs.list(attrs.arg(), default = []), "platform": attrs.option(attrs.string(), default = None), "platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []), - "_worker_srcs": attrs.list(attrs.source(), default = worker_srcs), - "_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]), - "_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags), + "allow_worker": attrs.bool(default = True), } ), ) @@ -192,9 +189,7 @@ haskell_library = prelude_rule( "linker_flags": attrs.list(attrs.arg(), default = []), "platform": attrs.option(attrs.string(), default = None), "platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []), - "_worker_srcs": attrs.list(attrs.source(), default = worker_srcs), - "_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]), - "_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags), + "allow_worker": attrs.bool(default = True), } ), ) diff --git a/haskell/compile.bzl b/haskell/compile.bzl index 9ec3655f3..71ce828bc 100644 --- a/haskell/compile.bzl +++ b/haskell/compile.bzl @@ -437,12 +437,24 @@ def _common_compile_module_args( external_tool_paths: list[RunInfo], sources: list[Artifact], direct_deps_info: list[HaskellLibraryInfoTSet], + allow_worker: bool, pkgname: str | None = None, ) -> CommonCompileModuleArgs: command = cmd_args(ghc_wrapper) command.add("--ghc", haskell_toolchain.compiler) command.add("--ghc-dir", haskell_toolchain.ghc_dir) + if allow_worker and haskell_toolchain.use_worker and haskell_toolchain.use_worker_multiplexer: + if haskell_toolchain.worker_multiplexer_plugin == None: + fail("'worker_multiplexer_plugin' must be set on the toolchain if 'use_worker_multiplexer' is true") + if pkgname == None: + warning("Module {} has no 'pkgname', worker multiplexer will break".format(label)) + else: + package_db = pkg_deps.providers[DynamicHaskellPackageDbInfo].packages + db = package_db[haskell_toolchain.worker_multiplexer_plugin[HaskellToolchainLibrary].name] + command.add("--plugin-db", db.value.db) + command.add("--worker-target-id", pkgname) + # Some rules pass in RTS (e.g. `+RTS ... -RTS`) options for GHC, which can't # be parsed when inside an argsfile. command.add(haskell_toolchain.compiler_flags) @@ -753,6 +765,7 @@ def _dynamic_do_compile_impl(actions, md_file, pkg_deps, arg, direct_deps_by_nam enable_profiling = arg.enable_profiling, link_style = arg.link_style, direct_deps_info = arg.direct_deps_info, + allow_worker = arg.allow_worker, pkgname = arg.pkgname, ) @@ -871,6 +884,7 @@ def compile( srcs_envs = ctx.attrs.srcs_envs, toolchain_deps_by_name = toolchain_deps_by_name, worker = worker, + allow_worker = ctx.attrs.allow_worker, ), )) diff --git a/haskell/haskell.bzl b/haskell/haskell.bzl index b2a218ea0..54cbcbe6d 100644 --- a/haskell/haskell.bzl +++ b/haskell/haskell.bzl @@ -1211,6 +1211,10 @@ def haskell_binary_impl(ctx: AnalysisContext) -> list[Provider]: md_file = target_metadata(ctx, sources = ctx.attrs.srcs) + # Provisional hack to have a worker ID + libname = repr(ctx.label.path).replace("//", "_").replace("/", "_") + "_" + ctx.label.name + pkgname = libname.replace("_", "-") + compiled = compile( ctx, link_style, @@ -1218,6 +1222,7 @@ def haskell_binary_impl(ctx: AnalysisContext) -> list[Provider]: enable_haddock = False, md_file = md_file, worker = _persistent_worker(ctx), + pkgname = pkgname, ) haskell_toolchain = ctx.attrs._haskell_toolchain[HaskellToolchainInfo] @@ -1498,9 +1503,8 @@ worker = anon_rule( "srcs_deps": attrs.dict(attrs.string(), attrs.dep(), default = {}), "srcs_envs": attrs.dict(attrs.string(), attrs.string(), default = {}), "template_deps": attrs.list(attrs.dep(), default = []), - # N.B. the _worker_* attrs are only treated by the call site of the anon_target - "_worker_deps": attrs.default_only(attrs.list(attrs.dep(), default = [])), - "_worker_srcs": attrs.default_only(attrs.list(attrs.source(), default = [])), + # N.B. allow_worker is only treated by the call site of the anon_target + "allow_worker": attrs.bool(), } | haskell_common.use_argsfile_at_link_arg() | native_common.link_style(), @@ -1510,10 +1514,11 @@ worker = anon_rule( ) def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None: - if ctx.label.cell == "prelude": + if not ctx.attrs.allow_worker: return None - if not ctx.attrs._haskell_toolchain[HaskellToolchainInfo].use_worker: + tc = ctx.attrs._haskell_toolchain[HaskellToolchainInfo] + if not tc.use_worker: return None worker_target = ctx.actions.anon_target( @@ -1523,14 +1528,14 @@ def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None: "_generate_target_metadata": ctx.attrs._generate_target_metadata, "_ghc_wrapper": ctx.attrs._ghc_wrapper, "_haskell_toolchain": ctx.attrs._haskell_toolchain, - "deps": ctx.attrs._worker_deps, + "deps": tc.worker_deps, "link_style": "shared", "name": "prelude//haskell:worker", - "srcs": ctx.attrs._worker_srcs, - "compiler_flags": ctx.attrs._worker_compiler_flags + [ + "srcs": tc.worker_srcs_multiplexer if tc.use_worker_multiplexer else tc.worker_srcs, + "compiler_flags": tc.worker_compiler_flags + [ "-O2", ], - "linker_flags": ctx.attrs._worker_compiler_flags + [ + "linker_flags": [ "-dynamic", "-rtsopts=all", "-with-rtsopts=-K512M -H -I5 -T", @@ -1538,6 +1543,7 @@ def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None: "-O2", ], "use_argsfile_at_link": False, + "allow_worker": False, }, ) return WorkerInfo(worker_target.artifact("worker")) diff --git a/haskell/toolchain.bzl b/haskell/toolchain.bzl index d4a3eae5b..80ef57b27 100644 --- a/haskell/toolchain.bzl +++ b/haskell/toolchain.bzl @@ -41,6 +41,12 @@ HaskellToolchainInfo = provider( "script_template_processor": provider_field(typing.Any, default = None), "packages": provider_field(typing.Any, default = None), "use_worker": provider_field(bool, default = False), + "use_worker_multiplexer": provider_field(bool, default = False), + "worker_multiplexer_plugin": provider_field(None | Dependency, default = None), + "worker_srcs": provider_field(typing.Any, default = []), + "worker_srcs_multiplexer": provider_field(typing.Any, default = []), + "worker_deps": provider_field(typing.Any, default = []), + "worker_compiler_flags": provider_field(typing.Any, default = []), "ghc_dir": provider_field(typing.Any, default = None), }, ) diff --git a/haskell/worker/BUCK b/haskell/worker/BUCK index 232b24429..9d8f52a26 100644 --- a/haskell/worker/BUCK +++ b/haskell/worker/BUCK @@ -1,3 +1,5 @@ load(":worker.bzl", "worker_libs") [haskell_toolchain_library(name = pkg, visibility = ["PUBLIC"]) for pkg in worker_libs] + +haskell_toolchain_library(name = "ghc-persistent-worker-plugin", visibility = ["PUBLIC"]) diff --git a/haskell/worker/impl b/haskell/worker/impl index dadddf43d..4e016ddf0 160000 --- a/haskell/worker/impl +++ b/haskell/worker/impl @@ -1 +1 @@ -Subproject commit dadddf43dac15881b203be822c3192c76366bb07 +Subproject commit 4e016ddf0be4b24cb2244e2b6d56f1efd24af274 diff --git a/haskell/worker/worker.bzl b/haskell/worker/worker.bzl index b02f56854..961366827 100644 --- a/haskell/worker/worker.bzl +++ b/haskell/worker/worker.bzl @@ -1,20 +1,26 @@ worker_libs = [ "base", + "binary", "bytestring", "containers", "deepseq", + "directory", "exceptions", "filepath", "ghc", "grpc-haskell", + "network", + "process", "proto3-suite", "proto3-wire", + "stm", "text", + "transformers", "vector", "unix", ] -worker_srcs = [ +worker_srcs_shared = [ "@prelude//haskell/worker/impl/plugin/src:Internal/AbiHash.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Args.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Cache.hs", @@ -22,20 +28,32 @@ worker_srcs = [ "@prelude//haskell/worker/impl/plugin/src:Internal/Error.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Log.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Session.hs", - "@prelude//haskell/worker/impl/buck-worker:Args.hs", + "@prelude//haskell/worker/impl/buck-worker/lib:BuckArgs.hs", + "@prelude//haskell/worker/impl/buck-worker/lib:BuckWorker.hs", +] + +worker_srcs = worker_srcs_shared + [ "@prelude//haskell/worker/impl/buck-worker:Main.hs", - "@prelude//haskell/worker/impl/buck-worker:BuckWorker.hs", +] + +worker_srcs_multiplexer = worker_srcs_shared + [ + "@prelude//haskell/worker/impl/comm/src:Message.hs", + "@prelude//haskell/worker/impl/server/lib:Server.hs", + "@prelude//haskell/worker/impl/server/lib:Pool.hs", + "@prelude//haskell/worker/impl/server/lib:Worker.hs", + "@prelude//haskell/worker/impl/buck-multiplex-worker:Main.hs", ] worker_flags = [ "-Wall", - "-XGHC2021", "-XBlockArguments", "-XDerivingStrategies", - "-XRecordWildCards", "-XDuplicateRecordFields", + "-XGHC2021", + "-XLambdaCase", + "-XOverloadedLists", "-XOverloadedRecordDot", + "-XOverloadedStrings", + "-XRecordWildCards", "-XStrictData", - "-XNoFieldSelectors", - "-XLambdaCase", ]