Skip to content

Commit

Permalink
Add persistent worker for haskell
Browse files Browse the repository at this point in the history
  • Loading branch information
avdv authored and tek committed Nov 8, 2024
1 parent 3197230 commit cc71ebd
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "haskell/worker/impl"]
path = haskell/worker/impl
url = [email protected]:MercuryTechnologies/ghc-persistent-worker
1 change: 1 addition & 0 deletions decls/haskell_common.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def _scripts_arg():
providers = [RunInfo],
default = "prelude//haskell/tools:ghc_wrapper",
),
"_worker": attrs.option(attrs.exec_dep(providers = [WorkerInfo]), default = None),
}

def _external_tools_arg():
Expand Down
7 changes: 7 additions & 0 deletions decls/haskell_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ load("@prelude//linking:types.bzl", "Linkage")
load(":common.bzl", "LinkableDepType", "buck", "prelude_rule")
load(":haskell_common.bzl", "haskell_common")
load(":native_common.bzl", "native_common")
load("@prelude//haskell/worker/worker.bzl", "worker_libs", "worker_srcs", "worker_flags")

haskell_binary = prelude_rule(
name = "haskell_binary",
Expand Down Expand Up @@ -66,6 +67,9 @@ haskell_binary = prelude_rule(
"linker_flags": attrs.list(attrs.arg(), default = []),
"platform": attrs.option(attrs.string(), default = None),
"platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []),
"_worker_srcs": attrs.list(attrs.source(), default = worker_srcs),
"_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]),
"_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags),
}
),
)
Expand Down Expand Up @@ -188,6 +192,9 @@ haskell_library = prelude_rule(
"linker_flags": attrs.list(attrs.arg(), default = []),
"platform": attrs.option(attrs.string(), default = None),
"platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []),
"_worker_srcs": attrs.list(attrs.source(), default = worker_srcs),
"_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]),
"_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags),
}
),
)
Expand Down
9 changes: 9 additions & 0 deletions haskell/compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def _common_compile_module_args(
) -> CommonCompileModuleArgs:
command = cmd_args(ghc_wrapper)
command.add("--ghc", haskell_toolchain.compiler)
command.add("--ghc-dir", haskell_toolchain.ghc_dir)

# Some rules pass in RTS (e.g. `+RTS ... -RTS`) options for GHC, which can't
# be parsed when inside an argsfile.
Expand Down Expand Up @@ -579,6 +580,7 @@ def _compile_module(
aux_deps: None | list[Artifact],
src_envs: None | dict[str, ArgLike],
source_prefixes: list[str],
worker: None | WorkerInfo,
) -> CompiledModuleTSet:
# These compiler arguments can be passed in a response file.
compile_args_for_file = cmd_args(common_args.args_for_file, hidden = aux_deps or [])
Expand Down Expand Up @@ -699,6 +701,7 @@ def _compile_module(
compile_cmd.add("-fwrite-if-simplified-core")
if enable_th:
compile_cmd.add("-fprefer-byte-code")
compile_cmd.add("-fpackage-db-byte-code")

compile_cmd.add(cmd_args(dependency_modules.reduce("packagedb_deps").keys(), prepend = "--buck2-package-db"))

Expand All @@ -709,6 +712,8 @@ def _compile_module(
compile_cmd.add("--buck2-dep", tagged_dep_file)
compile_cmd.add("--abi-out", outputs[module.hash])

worker_args = dict() if worker == None else dict(exe = WorkerRunInfo(worker = worker))

actions.run(
compile_cmd, category = "haskell_compile_" + artifact_suffix.replace("-", "_"), identifier = module_name,
dep_files = {
Expand All @@ -717,6 +722,7 @@ def _compile_module(
},
# explicit turn this on for local_only actions to upload their results.
allow_cache_upload = True,
**worker_args,
)

module_tset = actions.tset(
Expand Down Expand Up @@ -783,6 +789,7 @@ def _dynamic_do_compile_impl(actions, md_file, pkg_deps, arg, direct_deps_by_nam
direct_deps_by_name = direct_deps_by_name,
toolchain_deps_by_name = arg.toolchain_deps_by_name,
source_prefixes = source_prefixes,
worker = arg.worker,
)

return [DynamicCompileResultInfo(modules = module_tsets)]
Expand All @@ -807,6 +814,7 @@ def compile(
enable_profiling: bool,
enable_haddock: bool,
md_file: Artifact,
worker: WorkerInfo | None = None,
pkgname: str | None = None) -> CompileResultInfo:
artifact_suffix = get_artifact_suffix(link_style, enable_profiling)

Expand Down Expand Up @@ -862,6 +870,7 @@ def compile(
sources_deps = ctx.attrs.srcs_deps,
srcs_envs = ctx.attrs.srcs_envs,
toolchain_deps_by_name = toolchain_deps_by_name,
worker = worker,
),
))

Expand Down
68 changes: 68 additions & 0 deletions haskell/haskell.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

load("@prelude//utils:arglike.bzl", "ArgLike")
load("@prelude//:paths.bzl", "paths")
load("@prelude//cxx:link_groups_types.bzl", "LINK_GROUP_MAP_ATTR")
load("@prelude//cxx:archive.bzl", "make_archive")
load(
"@prelude//cxx:cxx.bzl",
Expand Down Expand Up @@ -151,6 +152,8 @@ load(
load("@prelude//utils:argfile.bzl", "at_argfile")
load("@prelude//utils:set.bzl", "set")
load("@prelude//utils:utils.bzl", "filter_and_map_idx", "flatten")
load("@prelude//decls:native_common.bzl", "native_common")
load("@prelude//decls:haskell_common.bzl", "haskell_common")

HaskellIndexingTSet = transitive_set()

Expand Down Expand Up @@ -644,6 +647,7 @@ def _build_haskell_lib(
enable_haddock = enable_haddock,
md_file = md_file,
pkgname = pkgname,
worker = _persistent_worker(ctx),
)
solibs = {}
artifact_suffix = get_artifact_suffix(link_style, enable_profiling)
Expand Down Expand Up @@ -1213,6 +1217,7 @@ def haskell_binary_impl(ctx: AnalysisContext) -> list[Provider]:
enable_profiling = enable_profiling,
enable_haddock = False,
md_file = md_file,
worker = _persistent_worker(ctx),
)

haskell_toolchain = ctx.attrs._haskell_toolchain[HaskellToolchainInfo]
Expand Down Expand Up @@ -1474,3 +1479,66 @@ def _haskell_module_sub_targets(*, compiled, link_style, enable_profiling):
if o.extension[1:] == osuf
})],
}

worker = anon_rule(
impl = haskell_binary_impl,
attrs = {
"_cxx_toolchain": attrs.dep(),
"_generate_target_metadata": attrs.dep(providers = [RunInfo]),
"_ghc_wrapper": attrs.dep(providers = [RunInfo]),
"_haskell_toolchain": attrs.dep(providers = [HaskellToolchainInfo]),
"compiler_flags": attrs.list(attrs.string(), default = []),
"deps": attrs.list(attrs.dep()),
"enable_profiling": attrs.default_only(attrs.bool(default = False)),
"external_tools": attrs.list(attrs.dep(), default = []),
"link_group_map": LINK_GROUP_MAP_ATTR,
"linker_flags": attrs.list(attrs.string(), default = []),
"platform_deps": attrs.list(attrs.dep(), default = []),
"srcs": attrs.list(attrs.source()),
"srcs_deps": attrs.dict(attrs.string(), attrs.dep(), default = {}),
"srcs_envs": attrs.dict(attrs.string(), attrs.string(), default = {}),
"template_deps": attrs.list(attrs.dep(), default = []),
# N.B. the _worker_* attrs are only treated by the call site of the anon_target
"_worker_deps": attrs.default_only(attrs.list(attrs.dep(), default = [])),
"_worker_srcs": attrs.default_only(attrs.list(attrs.source(), default = [])),
}
| haskell_common.use_argsfile_at_link_arg()
| native_common.link_style(),
artifact_promise_mappings = {
"worker": lambda x: x[DefaultInfo].default_outputs[0],
},
)

def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None:
if ctx.label.cell == "prelude":
return None

if not ctx.attrs._haskell_toolchain[HaskellToolchainInfo].use_worker:
return None

worker_target = ctx.actions.anon_target(
worker,
{
"_cxx_toolchain": ctx.attrs._cxx_toolchain,
"_generate_target_metadata": ctx.attrs._generate_target_metadata,
"_ghc_wrapper": ctx.attrs._ghc_wrapper,
"_haskell_toolchain": ctx.attrs._haskell_toolchain,
"deps": ctx.attrs._worker_deps,
"link_style": "shared",
"name": "prelude//haskell:worker",
"srcs": ctx.attrs._worker_srcs,
"compiler_flags": ctx.attrs._worker_compiler_flags + [
"-O2",
],
"linker_flags": ctx.attrs._worker_compiler_flags + [
"-dynamic",
"-rtsopts=all",
"-with-rtsopts=-K512M -H -I5 -T",
"-threaded",
"-O2",
],
"use_argsfile_at_link": False,
},
)
return WorkerInfo(worker_target.artifact("worker"))

2 changes: 2 additions & 0 deletions haskell/toolchain.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ HaskellToolchainInfo = provider(
"cache_links": provider_field(typing.Any, default = None),
"script_template_processor": provider_field(typing.Any, default = None),
"packages": provider_field(typing.Any, default = None),
"use_worker": provider_field(bool, default = False),
"ghc_dir": provider_field(typing.Any, default = None),
},
)

Expand Down
3 changes: 3 additions & 0 deletions haskell/tools/ghc_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def main():
parser.add_argument(
"--ghc", required=True, type=str, help="Path to the Haskell compiler GHC."
)
parser.add_argument(
"--ghc-dir", type=str, help="Worker option"
)
parser.add_argument(
"--abi-out",
required=True,
Expand Down
3 changes: 3 additions & 0 deletions haskell/worker/BUCK
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
load(":worker.bzl", "worker_libs")

[haskell_toolchain_library(name = pkg, visibility = ["PUBLIC"]) for pkg in worker_libs]
1 change: 1 addition & 0 deletions haskell/worker/impl
Submodule impl added at dadddf
41 changes: 41 additions & 0 deletions haskell/worker/worker.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
worker_libs = [
"base",
"bytestring",
"containers",
"deepseq",
"exceptions",
"filepath",
"ghc",
"grpc-haskell",
"proto3-suite",
"proto3-wire",
"text",
"vector",
"unix",
]

worker_srcs = [
"@prelude//haskell/worker/impl/plugin/src:Internal/AbiHash.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Args.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Cache.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Compile.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Error.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Log.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Session.hs",
"@prelude//haskell/worker/impl/buck-worker:Args.hs",
"@prelude//haskell/worker/impl/buck-worker:Main.hs",
"@prelude//haskell/worker/impl/buck-worker:BuckWorker.hs",
]

worker_flags = [
"-Wall",
"-XGHC2021",
"-XBlockArguments",
"-XDerivingStrategies",
"-XRecordWildCards",
"-XDuplicateRecordFields",
"-XOverloadedRecordDot",
"-XStrictData",
"-XNoFieldSelectors",
"-XLambdaCase",
]

0 comments on commit cc71ebd

Please sign in to comment.