diff --git a/pydra/engine/task.py b/pydra/engine/task.py index 4872445a0c..03b8963414 100644 --- a/pydra/engine/task.py +++ b/pydra/engine/task.py @@ -209,6 +209,153 @@ def _run_task(self): self.output_[output_names[0]] = output +class BoutiquesTask(FunctionTask): + """Wrap a Boutiques callable as a task element.""" + + global boutiques_func + + def boutiques_func(descriptor, args, **kwargs): + from boutiques.descriptor2func import function + + tool = function(descriptor) + ret = tool(*args, **kwargs) + + # print formatted output + print(ret) + + if ret.exit_code: + raise RuntimeError(ret.stderr) + + return ret.output_files + + def __init__( + self, + descriptor: ty.Text, + audit_flags: AuditFlag = AuditFlag.NONE, + cache_dir=None, + cache_locations=None, + input_spec: ty.Optional[SpecInfo] = None, + messenger_args=None, + messengers=None, + name=None, + output_spec: ty.Optional[BaseSpec] = None, + rerun=False, + bosh_args=[], + **kwargs, + ): + """ + Initialize this task. + + Parameters + ---------- + descriptor : :obj:`str` + The filename or zenodo ID of the boutiques descriptor + audit_flags: :obj:`pydra.utils.messenger.AuditFlag` + Auditing configurations + cache_dir : :obj:`os.pathlike` + Cache directory + cache_locations : :obj:`list` of :obj:`os.pathlike` + List of alternative cache locations. + input_spec: :obj:`pydra.engine.specs.SpecInfo` + Specification of inputs. + messenger_args : + TODO + messengers : + TODO + name : :obj:`str` + Name of this task. + output_spec : :obj:`pydra.engine.specs.BaseSpec` + Specification of inputs. + bosh_args : :object:`list` of :obj:`str` + List of arguments to pass to Boutiques + + """ + + func = boutiques_func + self.func = func + + default_fields = [ + ( + val.name, + attr.ib( + default=val.default, + type=val.annotation, + metadata={ + "help_string": f"{val.name} parameter from {func.__name__}" + }, + ), + ) + if val.default is not inspect.Signature.empty + else ( + val.name, + attr.ib(type=val.annotation, metadata={"help_string": val.name}), + ) + for val in inspect.signature(func).parameters.values() + if val.name != "kwargs" + ] + + # Adding kwargs here because Boutiques also validates inputs + if input_spec is None: + input_spec = SpecInfo( + name="Inputs", + fields=[ + (k, attr.ib(type=type(v), metadata={"help_string": k}),) + for k, v in kwargs.items() + ], + bases=(BaseSpec,), + ) + + # users shouldn't have to add "descriptor" and "args" in their input_spec + input_spec.fields.extend(default_fields) + + fmt_kwargs = {"descriptor": descriptor, "args": bosh_args} + fmt_kwargs.update(kwargs) + + super(BoutiquesTask, self).__init__( + func, + name=name, + audit_flags=audit_flags, + messengers=messengers, + messenger_args=messenger_args, + cache_dir=cache_dir, + cache_locations=cache_locations, + input_spec=input_spec, + rerun=rerun, + **fmt_kwargs, + ) + + if output_spec is None: + output_spec = SpecInfo( + name="Output", fields=[("out", ty.Any)], bases=(BaseSpec,) + ) + self.output_spec = output_spec + + def _run_task(self): + + inputs = attr.asdict(self.inputs) + del inputs["_func"] + self.output_ = None + + output = cp.loads(self.inputs._func)(**inputs) + if output is not None: + output_names = [el[0] for el in self.output_spec.fields] + self.output_ = {} + if len(output_names) > 1 or output_names[0] != "out": + self.output_ = { + f.boutiques_name: f.file_name + for f in output + if f.boutiques_name in output_names + } + if len(output_names) != len(self.output_.keys()): + raise Exception( + f"expected {len(self.output_spec.fields)} elements, " + f"but {len(output)} were returned" + ) + else: + # Return all filenames if not specified by the user what to return + self.output_[output_names[0]] = [out.file_name for out in output] + + class ShellCommandTask(TaskBase): """Wrap a shell command as a task element."""