Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Play with what a contrib module could look like #312

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions example/contriblike/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import ast
import sys
from pathlib import Path
from typing import Any, Sequence, Set

import black


def get_tree(path: Path) -> ast.Module:
    """Parse the Python source file at ``path`` into an AST.

    The file is handed to the parser as raw bytes so that source decoding
    follows Python's own rules (UTF-8 by default, or an explicit coding
    cookie) instead of the platform's locale encoding.

    Args:
        path: Location of the Python source file.

    Returns:
        The parsed module node.

    Raises:
        OSError: If the file cannot be read.
        SyntaxError: If the file is not valid Python.
    """
    # filename makes syntax errors point at the offending file.
    return ast.parse(path.read_bytes(), filename=str(path))


def write_tree(tree, path):
    """Unparse ``tree``, format it with black, and write it to ``path``.

    Args:
        tree: AST to serialise; missing source locations are filled in
            before unparsing.
        path: ``pathlib.Path`` of the output file (overwritten if present).
    """
    new_src = ast.unparse(ast.fix_missing_locations(tree))
    new_src = black.format_str(new_src, mode=black.FileMode(line_length=120))
    # Python source is UTF-8 by default; do not depend on the locale encoding.
    path.write_text(new_src, encoding="utf-8")


# turn op function defs into async function defs
class OpTransformer(ast.NodeTransformer):
    """Rewrite every ``def`` in the visited tree as ``async def``."""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AsyncFunctionDef:
        """any function in ops is an operator. make them async

        Nested functions are converted too, because children are visited
        before the node itself is replaced.
        """
        self.generic_visit(node)
        # FunctionDef and AsyncFunctionDef share the same fields. Copying them
        # generically keeps newer ones (e.g. ``type_params`` on Python 3.12+)
        # that a hand-written field list would silently drop.
        fields = {name: getattr(node, name) for name in node._fields if hasattr(node, name)}
        # Keep the original line/column information on the replacement node.
        return ast.copy_location(ast.AsyncFunctionDef(**fields), node)


class ImportTransformer(ast.NodeTransformer):
def __init__(self, *, module_allow_list: Sequence[str] = tuple()):
self.known_ops = set()
self.module_allow_list = frozenset(
{mn for mn in sys.stdlib_module_names if not mn.startswith("_")} | set(module_allow_list)
)

def visit_Import(self, node: ast.Import) -> ast.Import:
for alias in node.names:
if alias.name.split(".")[0] not in self.module_allow_list:
raise ValueError(f"Invalid 'import {alias.name}'.")

return node

def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
if node.module == "ops":
for alias_node in node.names:
op_name = alias_node.name
self.known_ops.add(op_name)
if alias_node.asname is not None:
raise ValueError(
f"Please import operator names without 'as', i.e. use '{op_name}' instead of '{alias_node.asname}'."
)
node.module = "compiled_ops"
elif node.module.split(".")[0] in self.module_allow_list:
pass
else:
raise ValueError(f"Unsupported import from {node.module}")

return node


class OpCallTransformer(ast.NodeTransformer):
    """Make workflow functions async and await calls to known operators."""

    def __init__(self, known_ops: Set[str]):
        """Store the operator names whose direct calls must be awaited."""
        self.known_ops = known_ops

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AsyncFunctionDef:
        """any function in wfs is a workflow. make them async"""
        self.generic_visit(node)
        # FunctionDef and AsyncFunctionDef share the same fields. Copying them
        # generically keeps newer ones (e.g. ``type_params`` on Python 3.12+)
        # that a hand-written field list would silently drop.
        fields = {name: getattr(node, name) for name in node._fields if hasattr(node, name)}
        return ast.copy_location(ast.AsyncFunctionDef(**fields), node)

    def visit_Call(self, node: ast.Call) -> Any:
        """await any operator call

        Only calls through a bare name found in ``known_ops`` are wrapped.
        Other name calls and attribute (method) calls are returned unchanged.

        Raises:
            NotImplementedError: If the call target is neither a name nor an
                attribute access.
        """
        self.generic_visit(node)
        func = node.func
        if isinstance(func, ast.Name):
            if func.id in self.known_ops:
                # Give the new Await node the location of the call it wraps.
                return ast.copy_location(ast.Await(node), node)
            return node
        if isinstance(func, ast.Attribute):
            return node  # e.g. method call
        raise NotImplementedError(node.func)


# Step 1: compile the operators. Every function in ops.py becomes async and
# the result is written next to it as compiled_ops.py.
# NOTE: paths are relative to the current working directory.
ops_path = Path("ops.py")
ops_tree = get_tree(ops_path)
ops_tree = OpTransformer().visit(ops_tree)
write_tree(ops_tree, ops_path.with_name("compiled_" + ops_path.name))

# Step 2: compile the workflows in wfs.py.
wfs_path = Path("wfs.py")
wfs_tree = get_tree(wfs_path)

allowed_module_names = []  # todo: add modules from appropriate env to allow-list
# The import pass must run first: it collects the operator names that the
# call pass then wraps in ``await``.
import_transformer = ImportTransformer(module_allow_list=allowed_module_names)
wfs_tree = import_transformer.visit(wfs_tree)
wfs_tree = OpCallTransformer(known_ops=import_transformer.known_ops).visit(wfs_tree)

# Written next to wfs.py as compiled_wfs.py.
write_tree(wfs_tree, wfs_path.with_name("compiled_" + wfs_path.name))
25 changes: 25 additions & 0 deletions example/contriblike/compiled_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Tuple


async def my_op(a: int) -> str:
    """Return ``a`` centred in a 10-character field padded with ``~``."""
    centered = f"{a:~^10}"
    return centered


async def heavy_compute(p):
    """Write the text "done" to the file at ``p``."""
    target = Path(p)
    target.write_text("done")


async def parallel_op(*, max_thread_workers: int) -> Tuple[str, str]:
    """Copy src1.txt/src2.txt to dest1.txt/dest2.txt using a thread pool.

    The futures returned by ``submit`` are discarded, so a failed copy is not
    reported to the caller.

    Args:
        max_thread_workers: Size of the thread pool.

    Returns:
        The two destination file names.
    """
    dests = ("dest1.txt", "dest2.txt")
    copy_jobs = zip(("src1.txt", "src2.txt"), dests)
    # Leaving the ``with`` block waits for all submitted copies to finish.
    with ThreadPoolExecutor(max_workers=max_thread_workers) as pool:
        for job in copy_jobs:
            pool.submit(shutil.copy, *job)
    return dests


if __name__ == "__main__":
    # my_op is a coroutine function here; printing the bare call would show a
    # coroutine object (and warn that it was never awaited), so run it first.
    import asyncio

    print(asyncio.run(my_op(5)))
20 changes: 20 additions & 0 deletions example/contriblike/compiled_wfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, Tuple
from compiled_ops import heavy_compute, my_op, parallel_op


async def my_wf(a: int) -> Dict[str, str]:
    """Await ``my_op(a)`` and return its result under the key "centered"."""
    return {"centered": await my_op(a)}


async def wf_with_p_ops(*, max_thread_workers: int, max_process_workers: int) -> Tuple[str, str]:
    """Submit ``heavy_compute`` for both source files, then await ``parallel_op``.

    Args:
        max_thread_workers: Forwarded to ``parallel_op``.
        max_process_workers: Size of the process pool.

    Returns:
        Whatever ``parallel_op`` returns.
    """
    # NOTE(review): heavy_compute is imported from compiled_ops, where it is an
    # ``async def``; submitting it only creates a coroutine in the worker.
    # Confirm this is intended. The futures are discarded either way.
    with ProcessPoolExecutor(max_workers=max_process_workers) as pool:
        for src_name in ("src1.txt", "src2.txt"):
            pool.submit(heavy_compute, src_name)
    result = await parallel_op(max_thread_workers=max_thread_workers)
    return result


if __name__ == "__main__":
    # Both workflows are coroutine functions; printing the bare calls would
    # show coroutine objects that are never awaited, so run each one first.
    import asyncio

    print(asyncio.run(my_wf(5)))
    print(asyncio.run(wf_with_p_ops(max_thread_workers=2, max_process_workers=2)))
26 changes: 26 additions & 0 deletions example/contriblike/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import shutil
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Tuple


def my_op(a: int) -> str:
    """Return ``a`` centred in a 10-character field padded with ``~``."""
    centered = f"{a:~^10}"
    return centered


async def heavy_compute(p):
    """Write the text "done" to the file at ``p``."""
    out_file = Path(p)
    out_file.write_text("done")


def parallel_op(*, max_thread_workers: int) -> Tuple[str, str]:
    """Copy src1.txt/src2.txt to dest1.txt/dest2.txt using a thread pool.

    The futures returned by ``submit`` are discarded, so a failed copy is not
    reported to the caller.

    Args:
        max_thread_workers: Size of the thread pool.

    Returns:
        The two destination file names.
    """
    sources = ("src1.txt", "src2.txt")
    dests = ("dest1.txt", "dest2.txt")
    # Leaving the ``with`` block waits for all submitted copies to finish.
    with ThreadPoolExecutor(max_workers=max_thread_workers) as pool:
        for pair in zip(sources, dests):
            pool.submit(shutil.copy, *pair)

    return dests


if __name__ == "__main__":
    # Manual smoke check when the module is run as a script.
    demo_output = my_op(5)
    print(demo_output)
22 changes: 22 additions & 0 deletions example/contriblike/wfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, Tuple

from ops import heavy_compute, my_op, parallel_op


def my_wf(a: int) -> Dict[str, str]:
    """Call ``my_op(a)`` and return its result under the key "centered"."""
    return {"centered": my_op(a)}


def wf_with_p_ops(*, max_thread_workers: int, max_process_workers: int) -> Tuple[str, str]:
    """Submit ``heavy_compute`` for both source files, then run ``parallel_op``.

    Args:
        max_thread_workers: Forwarded to ``parallel_op``.
        max_process_workers: Size of the process pool.

    Returns:
        Whatever ``parallel_op`` returns.
    """
    # The futures are discarded, so failures in the workers are not reported.
    with ProcessPoolExecutor(max_workers=max_process_workers) as pool:
        for src_name in ("src1.txt", "src2.txt"):
            pool.submit(heavy_compute, src_name)

    result = parallel_op(max_thread_workers=max_thread_workers)
    return result


if __name__ == "__main__":
    # Manual smoke run of both workflows when executed as a script.
    centered_result = my_wf(5)
    print(centered_result)
    copied_files = wf_with_p_ops(max_thread_workers=2, max_process_workers=2)
    print(copied_files)