diff --git a/.gitignore b/.gitignore
index 3999e3b5dca..0d7e303042d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,3 +40,5 @@ xcuserdata/
 .swiftpm/
 *.xcworkspace/
 *.xcframework/
+
+playground/
diff --git a/comparison/export.py b/comparison/export.py
new file mode 100644
index 00000000000..45c00d27711
--- /dev/null
+++ b/comparison/export.py
@@ -0,0 +1,239 @@
+import torch
+import torch.nn as nn
+from models import model_dict
+import torch_mlir.fx
+import os
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.exir import to_edge_transform_and_lower
+from torch.export.experimental import _export_forward_backward
+from torch._decomp import get_decompositions, register_decomposition
+from torch_mlir.extras.fx_decomp_util import DEFAULT_DECOMPOSITIONS
+from torch.nn.attention import sdpa_kernel, SDPBackend
+import copy
+import subprocess
+import time
+from glob import glob
+from mlir_inject_resources import patch_file
+
+
+ph_input = torch.randn(1, 3, 224, 224)
+ph_label = torch.randint(0, 10, (1,))
+
+class TrainingNet(nn.Module):
+    def __init__(self, net):
+        super().__init__()
+        self.net = net
+        self.loss = nn.CrossEntropyLoss()
+
+    def forward(self, input, label):
+        pred = self.net(input)
+        return self.loss(pred, label), pred.detach().argmax(dim=1)
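A note on the `register_decomposition` hooks that follow: each one is added to PyTorch's global decomposition registry, so the `get_decompositions` tables built in `export_inference`/`export_training` below pick them up and rewrite the corresponding `prims.*`/`aten.*` ops. A minimal sketch of verifying one hook, assuming this module (`export.py`) has been imported so the registrations ran:

```python
# Hedged sketch: check that a prims decomposition registered in export.py is
# visible through get_decompositions and lowers to the expected aten op.
import torch
from torch._decomp import get_decompositions

import export  # importing export.py runs the @register_decomposition hooks

table = get_decompositions([torch.ops.prims.sum.default])
fn = table[torch.ops.prims.sum.default]   # this is _prims_sum from export.py
print(fn(torch.ones(2, 3), dims=[0]))     # tensor([2., 2., 2.]) via aten.sum
```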
+# prims.broadcast_in_dim -> aten unsqueeze/expand
+@register_decomposition(torch.ops.prims.broadcast_in_dim.default)
+def _bcast_dim_decomp(x, shape, b_dims):
+    for i, s in enumerate(shape):
+        if i not in b_dims:
+            x = x.unsqueeze(i)
+    return x.expand(shape).clone()
+
+# 1. prims.div -> aten.div
+@register_decomposition(torch.ops.prims.div.default)
+def _prims_div(a, b, rounding_mode=None):
+    return torch.div(a, b, rounding_mode=rounding_mode)
+
+# 2. prims.sum -> aten.sum
+@register_decomposition(torch.ops.prims.sum.default)
+def _prims_sum(x, dims, dtype=None):
+    return torch.sum(x, dim=dims, dtype=dtype)
+
+# 3. prims.var -> aten.var
+@register_decomposition(torch.ops.prims.var.default)
+def _prims_var(x, dims, correction=0, keepdim=False):
+    unbiased = correction != 0
+    return torch.var(x, dim=dims, unbiased=unbiased, keepdim=keepdim)
+
+# Remove bernoulli RNG: replace the random dropout mask with its constant
+# keep-scale 1/(1-p), so the exported graph carries no RNG state. Note this
+# changes semantics (dropout becomes a deterministic scaling).
+@register_decomposition(torch.ops.aten.bernoulli.p)
+def _bernoulli_p(x, p, generator=None):
+    scale = 1.0 / (1.0 - p)
+    return torch.full_like(x, scale).to(x.dtype)
+
+def canonicalize_empty_permute(ep):
+    # Rewrite aten.empty_permuted into aten.empty followed by aten.permute,
+    # which the downstream compilers can handle.
+    for node in ep.graph.nodes:
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.empty_permuted.out
+        ):
+            print("found empty_permute: ", node)
+            empty_permuted_node = node
+            with ep.graph.inserting_before(empty_permuted_node):
+                empty_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.empty.memory_format,
+                    (node.args[0],),
+                    empty_permuted_node.kwargs,
+                )
+                permute_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.permute,
+                    (empty_node, node.args[1]),
+                )
+            for user in empty_permuted_node.users.copy():
+                user.replace_input_with(empty_permuted_node, permute_node)
+        if (
+            node.op == "call_function"
+            and node.target == torch.ops.aten.empty_permuted.default
+        ):
+            print("found empty_permute default: ", node)
+            empty_permuted_node = node
+            with ep.graph.inserting_before(empty_permuted_node):
+                empty_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.empty.memory_format,
+                    (node.args[0],),
+                    empty_permuted_node.kwargs,
+                )
+                permute_node = ep.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.permute.default,
+                    (empty_node, node.args[1]),
+                )
+            for user in empty_permuted_node.users.copy():
+                user.replace_input_with(empty_permuted_node, permute_node)
+    return ep
+
+
+def compile_executorch(ep, prefix):
+    xnn_et = to_edge_transform_and_lower(ep, partitioner=[XnnpackPartitioner()]).to_executorch()
+    with open(f"{prefix}.pte", "wb") as f:
+        f.write(xnn_et.buffer)
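Before lowering, it can be worth confirming the exported program still matches eager PyTorch; `ExportedProgram.module()` returns a runnable module, so a quick sanity check (not part of the original flow, shown here as a hedged sketch) looks like:

```python
# Hedged sanity check: compare the ExportedProgram against eager execution
# before handing it to ExecuTorch or IREE.
import torch
from models import model_dict

model = model_dict["resnet18"](False)
ph_input = torch.randn(1, 3, 224, 224)
ep = torch.export.export(model, (ph_input,))
with torch.no_grad():
    torch.testing.assert_close(
        model(ph_input), ep.module()(ph_input), rtol=1e-4, atol=1e-4
    )
```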
Retrying...") + continue + + print("❌ Compilation failed after max retries.") + return + +def compile_iree(ep, prefix, decomposition_table=None): + print("Exporting MLIR") + mlir = torch_mlir.fx.export_and_import(ep, output_type="raw", decomposition_table=decomposition_table) + print(mlir, file=open(f"{prefix}.mlir", "w")) + cmd = ["/root/iree-latest-build/install/bin/iree-compile", + "--iree-input-type=torch", + "--iree-hal-target-device=local", + "--iree-hal-local-target-device-backends=llvm-cpu", + "--iree-llvmcpu-target-cpu=host", + "--mlir-elide-elementsattrs-if-larger=1", + "--mlir-elide-resource-strings-if-larger=1", + f"--iree-hal-dump-executable-benchmarks-to={prefix}_dispatch", + f"--dump-compilation-phases-to={prefix}_phases", + f"{prefix}.mlir", + "-o", + f"{prefix}.vmfb" + ] + print("\n\t".join(cmd)) + safe_cmd(cmd) + for dispatch in glob(f"{prefix}_dispatch/*.mlir"): + patch_file(dispatch, dispatch) + +def export_inference(model, prefix, et=True, iree=True): + ep = torch.export.export(model, (ph_input, )) + print(ep.graph_module.graph, file=open(f"{prefix}_graph.txt", "w")) + if et: + compile_executorch(ep, prefix) + if iree: + decomposition_table = get_decompositions([ + *DEFAULT_DECOMPOSITIONS, + torch.ops.aten.as_strided + ]) + compile_iree(ep, prefix, decomposition_table) + + +def export_training(model, prefix, et=True, iree=True): + model = TrainingNet(model) + if et: # faster than deepcopy + ep = torch.export.export(model, (ph_input, ph_label), strict=False) + ep = _export_forward_backward(ep) + ep = cannonicalize_empty_permute(ep) + print(ep.graph_module.graph, file=open(f"{prefix}_graph.txt", "w")) + print("Training Executorch - EP mutate done") + decomposition_table = get_decompositions([ + torch.ops.aten._native_batch_norm_legit_functional.default, + torch.ops.aten._native_batch_norm_legit_no_training.default, + torch.ops.aten.var_mean.correction, + torch.ops.prims.broadcast_in_dim.default, + torch.ops.prims.div.default, # ★ + torch.ops.prims.sum.default, # ★ + torch.ops.prims.var.default, + torch.ops.aten.bernoulli.p, + torch.ops.aten.empty_strided.out, + torch.ops.aten.convolution_backward + ]) + ep = ep.run_decompositions(decomposition_table) + compile_executorch(ep, prefix) + print("Training Executorch - compilation done") + if iree: + ep = torch.export.export(model, (ph_input, ph_label), strict=False) + ep = _export_forward_backward(ep) + ep = cannonicalize_empty_permute(ep) + print("Training IREE - EP mutate done") + decomposition_table = get_decompositions([ + *DEFAULT_DECOMPOSITIONS, + torch.ops.aten.convolution_backward, + torch.ops.aten.bernoulli.p, + torch.ops.aten.any.dim, + torch.ops.aten.as_strided + ]) + compile_iree(ep, prefix, decomposition_table) + print("Training IREE - compilation done") + +def export_model(model_name, train=False, et=True, iree=True): + print(f"Exporting {model_name} {'train' if train else 'inf'}") + model = model_dict[model_name](train) + prefix = f"./{model_name}/{'train' if train else 'inf'}" + if not os.path.exists(f"./{model_name}"): + os.makedirs(f"./{model_name}") + if not train: + export_inference(model, prefix, et, iree) + else: + export_training(model, prefix, et, iree) + +if __name__ == "__main__": + for train in [False, True]: + # export_model("vgg", train) + # export_model("alexnet", train) + export_model("mobilenet_v2", False, et=False) + # export_model("efficientnet_b0", train) + # export_model("resnet18", train) + with sdpa_kernel([SDPBackend.MATH]): + # export_model("vit_b_16", train) + # 
export_model("vit_tiny_patch16_224", train, et=False) + # export_model("mobilevit_s", train, et=False) + pass \ No newline at end of file diff --git a/comparison/inspector.py b/comparison/inspector.py new file mode 100644 index 00000000000..0b29e02ce80 --- /dev/null +++ b/comparison/inspector.py @@ -0,0 +1,55 @@ +# %% +from iree.compiler import ir +from collections import defaultdict +import os + + +os.chdir("/root/executorch/comparison") +def cvalue_to_int(cvalue): + """Operand 이름을 int로 변환하거나 원래 이름 반환.""" + name = str(cvalue.get_name()) + if name.startswith("%zero"): + return 0 + if name.startswith("%c"): + try: + return int(name[2:]) + except ValueError: + return name + return name + +def parse_dispatch_statistics(prefix): + """vm.call.variadic -> hal.command_buffer.dispatch 호출 통계를 수집.""" + train_or_test = os.path.basename(prefix) + mlir_path = f"{prefix}_phases/{train_or_test}.12.vm.mlir" + + with ir.Context(): + with open(mlir_path, "r") as f: + module = ir.Module.parse(f.read()) + + # 모듈 첫 번째 operation의 첫 번째 블록 + top_block = list(module.body.operations)[0].body.blocks[0] + + # vm.func 찾기 + func_op = next((op for op in top_block.operations if op.operation.name == "vm.func"), None) + if func_op is None: + raise RuntimeError("No vm.func found in module.") + + stats = defaultdict(int) + for op in func_op.body.blocks[0].operations: + op_name = op.operation.name + + if op_name == "vm.call.variadic": + target_attr = str(op.attributes[0].attr) + if target_attr == "@hal.command_buffer.dispatch": + dispatch_id = cvalue_to_int(op.operands[2]) + stats[dispatch_id] += 1 + + elif op_name == "vm.call": + target_attr = str(op.attributes[0].attr) + if target_attr == "@hal.command_buffer.execution_barrier": + # 아직 별도 처리 없음 (pass) + pass + + return stats + + diff --git a/comparison/mlir_inject_resources.py b/comparison/mlir_inject_resources.py new file mode 100644 index 00000000000..9a35fd11f08 --- /dev/null +++ b/comparison/mlir_inject_resources.py @@ -0,0 +1,47 @@ +import re + +def collect_required_resources(content): + """ 필요한 dense_resource 이름과 타입(사이즈) 추출 """ + pattern = r'dense_resource<([\w\\\.]+)>\s*:\s*tensor<(\d+)x(f\d+)>' + matches = re.findall(pattern, content) + resources = {} + for name, dim, dtype in matches: + print("- Found: ", name, dim, dtype) + dim = int(dim) + if dtype == "f32": + size = dim * 4 # float32 = 4 bytes + elif dtype == "f16": + size = dim * 2 # float16 = 2 bytes + elif dtype == "i32": + size = dim * 4 # int32 = 4 bytes + else: + raise ValueError(f"Unsupported dtype {dtype}") + resources[name] = size + return resources + +def generate_dummy_blob(size): + """ 크기에 맞는 00 blob 생성 (hex string) """ + return '00' * size + +def generate_dialect_resource_block(resource_sizes): + """ dialect_resources 블록 생성 """ + lines = ["{-#", " dialect_resources: {", " builtin: {"] + for key, size in resource_sizes.items(): + blob = generate_dummy_blob(size) + lines.append(f' {key}: "0x04000000{blob}",') + if lines[-1].endswith(','): + lines[-1] = lines[-1][:-1] # 마지막 콤마 제거 + lines += [" }", " }", "#-}"] + return "\n".join(lines) + +def patch_file(target_file, output_file): + with open(target_file, 'r') as f: + content = f.read() + resource_sizes = collect_required_resources(content) + if len(list(resource_sizes.values())) > 0: + dialect_block = generate_dialect_resource_block(resource_sizes) + + with open(output_file, 'a') as f: + f.write("\n\n" + dialect_block + "\n") + + print(f"Patched file saved to {output_file}") \ No newline at end of file diff --git a/comparison/models.py 
diff --git a/comparison/models.py b/comparison/models.py
new file mode 100644
index 00000000000..52100c19c43
--- /dev/null
+++ b/comparison/models.py
@@ -0,0 +1,94 @@
+import torchvision.models as models
+import torch.nn as nn
+import timm
+import torch
+
+def vgg(train=False):
+    model = models.vgg11(num_classes=10)
+    if train:
+        model.train()
+        model.avgpool = nn.MaxPool2d(1)
+        model.classifier[2] = nn.Identity()  # strip Dropout for export
+        model.classifier[5] = nn.Identity()  # strip Dropout for export
+    else:
+        model.eval()
+        model.avgpool = nn.AvgPool2d(1)
+    return model
+
+def alexnet(train=False):
+    model = models.alexnet(num_classes=10)
+    if train:
+        model.train()
+        model.avgpool = nn.MaxPool2d(1)
+        model.classifier[0] = nn.Identity()  # strip Dropout for export
+        model.classifier[3] = nn.Identity()  # strip Dropout for export
+    else:
+        model.eval()
+        model.avgpool = nn.AvgPool2d(1)
+    return model
+
+def mobilenet_v2(train=False):
+    model = models.mobilenet_v2(num_classes=10)
+    if train:
+        model.train()
+        model.classifier[0] = nn.Identity()  # strip Dropout for export
+    else:
+        model.eval()
+    return model
+
+def efficientnet_b0(train=False):
+    model = models.efficientnet_b0(num_classes=10)
+    if train:
+        model.train()
+        model.avgpool = nn.MaxPool2d(kernel_size=7)
+        model.classifier[0] = nn.Identity()  # strip Dropout for export
+    else:
+        model.eval()
+        model.avgpool = nn.AvgPool2d(kernel_size=7)
+    return model
+
+def resnet18(train=False):
+    model = models.resnet18(num_classes=10)
+    if train:
+        model.train()
+        model.avgpool = nn.MaxPool2d(kernel_size=7)
+    else:
+        model.eval()
+        model.avgpool = nn.AvgPool2d(kernel_size=7)
+    return model
+
+def vit_b_16(train=False):
+    model = models.vit_b_16(num_classes=10)
+    if train:
+        model.train()
+    else:
+        model.eval()
+    return model
+
+def vit_tiny_patch16_224(train=False):
+    model = timm.create_model("vit_tiny_patch16_224", pretrained=True)
+    if train:
+        model.train()
+    else:
+        model.eval()
+    return model
+
+def mobilevit_s(train=False):
+    model = timm.create_model("mobilevit_s", pretrained=True)
+    if train:
+        model.train()
+    else:
+        model.eval()
+    return model
+
+model_dict = {
+    "vgg": vgg,
+    "alexnet": alexnet,
+    "mobilenet_v2": mobilenet_v2,
+    "efficientnet_b0": efficientnet_b0,
+    "resnet18": resnet18,
+    "vit_b_16": vit_b_16,
+    "vit_tiny_patch16_224": vit_tiny_patch16_224,
+    "mobilevit_s": mobilevit_s,
+}
+
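Every constructor in `model_dict` above shares the `(train: bool)` signature, so the whole table can be smoke-tested in one loop before exporting; a hedged sketch:

```python
# Hedged smoke test: verify every model_dict entry accepts the 1x3x224x224
# placeholder input that export.py and run_torch.py use.
import torch
from models import model_dict

x = torch.randn(1, 3, 224, 224)
for name, ctor in model_dict.items():
    with torch.no_grad():
        out = ctor(False)(x)
    print(f"{name}: {tuple(out.shape)}")
```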
diff --git a/comparison/run_torch.py b/comparison/run_torch.py
new file mode 100644
index 00000000000..04a91d0f9df
--- /dev/null
+++ b/comparison/run_torch.py
@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+from torch.utils.benchmark import Timer
+import timm
+import traceback
+from models import model_dict
+
+
+def run_torch(model_name, train=False, multithreaded=True, repeat=100):
+    model = model_dict[model_name](train)
+    loss_fn = nn.CrossEntropyLoss()
+    ph_input = torch.randn(1, 3, 224, 224)
+    ph_label = torch.randint(0, 10, (1,))
+
+    def train_step(model, loss_fn, input, label):
+        model.zero_grad()
+        output = model(input)
+        loss = loss_fn(output, label)
+        loss.backward()
+        return loss
+
+    def eval_step(model, input):
+        model.eval()
+        output = model(input)
+        return output
+
+    threads = 48 if multithreaded else 1
+    torch.set_num_threads(threads)
+
+    if train:
+        step = train_step
+        model.train()
+        stmt = 'step(model, loss_fn, ph_input, ph_label)'
+        _globals = {
+            'step': step,
+            'model': model,
+            'loss_fn': loss_fn,
+            'ph_input': ph_input,
+            'ph_label': ph_label
+        }
+    else:
+        step = eval_step
+        model.eval()
+        stmt = 'step(model, ph_input)'
+        _globals = {
+            'step': step,
+            'model': model,
+            'ph_input': ph_input
+        }
+    try:
+        t_torch = Timer(
+            stmt=stmt,
+            globals=_globals,
+            num_threads=threads
+        ).timeit(repeat)
+        print(f"{model_name} - Torch: {t_torch.mean}")
+    except Exception as e:
+        print(f"{model_name} - Torch: Error")
+        traceback.print_exc(file=open("./torch_experiment_error.txt", "a"))
+
+    try:
+        compiled_model = torch.compile(model, backend='inductor', fullgraph=True)
+        # Swap the compiled model into the Timer globals (otherwise the eager
+        # model would be timed again), and warm up so compilation happens
+        # outside the timed region.
+        _globals['model'] = compiled_model
+        if train:
+            step(compiled_model, loss_fn, ph_input, ph_label)
+        else:
+            step(compiled_model, ph_input)
+        t_inductor = Timer(
+            stmt=stmt,
+            globals=_globals,
+            num_threads=threads
+        ).timeit(repeat)
+        print(f"{model_name} - Inductor: {t_inductor.mean}")
+    except Exception as e:
+        print(f"{model_name} - Inductor: Error")
+        traceback.print_exc(file=open("./torch_experiment_error.txt", "a"))
+
+if __name__ == "__main__":
+    run_torch("mobilevit_s", train=True, multithreaded=False)
+    # print()
+    # run_torch("mobilevit_s", train=False, multithreaded=True)
+    # print()
+    # run_torch("vit_b_16", train=False, multithreaded=True)
+    # run_torch("mobilenet_v2", train=False, multithreaded=True)
+    # run_torch("mobilevit_s", train=False, multithreaded=True)
+    # print()
+    # run_torch("vit_b_16", train=False, multithreaded=False)
+    # run_torch("mobilevit_s", train=False, multithreaded=False)
\ No newline at end of file
diff --git a/comparison/runner.py b/comparison/runner.py
new file mode 100644
index 00000000000..96d2b84d558
--- /dev/null
+++ b/comparison/runner.py
@@ -0,0 +1,106 @@
+import os, re
+from models import model_dict
+import glob
+import subprocess
+import json
+from inspector import parse_dispatch_statistics
+
+def make_input_args(fname):
+    with open(fname, "r") as f:
+        next(f)  # skip the first line
+        definition = next(f).split("->")[0].split("!")[1:]
+
+    print(len(definition))
+
+    def process_definition(d):
+        d = re.sub(r"(^.*<)|(>.*)", "", d)
+        shape, _type = d.split("]")
+        shape = re.sub(",", "x", shape)[1:] if shape else ""
+        _type = _type[1:]
+        _type = {"si64": "i64", "si32": "i32"}.get(_type, _type)
+
+        return f" --input={shape}x{_type}=0.5 \\" if shape else f" --input={_type}=0.5 \\"
+
+    return "\n".join(process_definition(d) for d in definition)
+
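`make_input_args` scrapes the second line of the exported `.mlir` (the function signature) and turns each `!torch.vtensor<[shape],dtype>` operand into an `--input=` flag for `iree-benchmark-module`. A worked illustration on a hypothetical signature line, reusing the same parsing logic:

```python
# Illustrative only: the signature string below is hypothetical; it shows the
# flags the parsing in make_input_args yields for two tensor operands.
import re

sig = "func.func @main(%arg0: !torch.vtensor<[1,3,224,224],f32>, %arg1: !torch.vtensor<[1],si64>) ->"
definition = sig.split("->")[0].split("!")[1:]

def process_definition(d):
    d = re.sub(r"(^.*<)|(>.*)", "", d)            # keep "[shape],dtype"
    shape, _type = d.split("]")
    shape = re.sub(",", "x", shape)[1:] if shape else ""
    _type = _type[1:]
    _type = {"si64": "i64", "si32": "i32"}.get(_type, _type)
    return f" --input={shape}x{_type}=0.5 \\" if shape else f" --input={_type}=0.5 \\"

print("\n".join(process_definition(d) for d in definition))
#  --input=1x3x224x224xf32=0.5 \
#  --input=1xi64=0.5 \
```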
+def run_executorch(model_name, train=False, xnn=True, num_executions=100, multithreaded=True, profiling=False):
+    prefix = f"./{model_name}/{'train' if train else 'inf'}"
+    model_path = f"{prefix}.pte"
+    threads = 48 if multithreaded else 1
+    app = "/root/executorch/cmake-out/executor_runner"
+    cmd = f"{app} --model_path={model_path} \
+        --num_executions={num_executions} \
+        --cpu_threads={threads}"
+    if profiling:
+        cmd += f" --etdump_path={prefix}_{threads}.etdp"
+    print(cmd)
+    os.system(cmd)
+    if profiling:
+        from executorch.devtools import Inspector
+        inspector = Inspector(etdump_path=f"{prefix}_{threads}.etdp")
+        inspector.save_data_to_tsv(f"{prefix}_{threads}.tsv", include_delegate_debug_data=True)
+
+def run_iree(model_name, train=False, num_executions=100, multithreaded=True):
+    prefix = f"./{model_name}/{'train' if train else 'inf'}"
+    app = "/root/iree-latest-build/install/bin/iree-benchmark-module"
+    cmd = f"{app} \
+        --module={prefix}.vmfb \
+        --device={'local-task' if multithreaded else 'local-sync'} \
+        --function=main \
+        --benchmark_repetitions={num_executions} \
+        {make_input_args(prefix + '.mlir')}"
+    print(cmd)
+    os.system(cmd)
+
+def run_iree_prof(model_name, train=False, num_executions=100, multithreaded=True):
+    prefix = f"./{model_name}/{'train' if train else 'inf'}"
+
+    files = list(glob.glob(f"{prefix}_dispatch/*.mlir"))
+    def get_idx(f):
+        return int(f.split("async_dispatch_")[1].split("_")[0])
+    files = [(get_idx(f), f) for f in files]
+    files.sort(key=lambda x: x[0])
+
+    stats = parse_dispatch_statistics(prefix)
+    real_time_sum = 0
+    cpu_time_sum = 0
+    for j, (i, f) in enumerate(files):
+        try:
+            result = subprocess.run(
+                f"/root/iree-latest-build/install/bin/iree-compile '{f}'"
+                " | /root/iree-latest-build/install/bin/iree-benchmark-module "
+                "--benchmark_format=json --benchmark_repetitions=1 --module=-"
+                f" --device={'local-task' if multithreaded else 'local-sync'}",
+                shell=True,
+                check=True,
+                capture_output=True,
+                text=True,
+            )
+            bench = json.loads(result.stdout)
+            for bench_item in bench["benchmarks"]:
+                name = bench_item["name"].split("async_dispatch_")[2].split("/process_time")[0]
+                real_time = bench_item["real_time"]
+                cpu_time = bench_item["cpu_time"]
+                # time_unit = bench_item["time_unit"]
+                rep = stats[j] if j in stats else 0
+                print(f"{i:<5}\t{name:<40}\t{real_time:>15.2f}\t{cpu_time:>15.2f}\t{rep}\t{real_time * rep:>15.2f}\t{cpu_time * rep:>15.2f}")
+                real_time_sum += real_time * rep
+                cpu_time_sum += cpu_time * rep
+        except subprocess.CalledProcessError as e:
+            print("ERR")
+    print(f"real_time_sum: {real_time_sum}, cpu_time_sum: {cpu_time_sum}")
+
+# for train in [True, False]:
+#     run_iree("vit_b_16", train=train, num_executions=10, multithreaded=True)
+#     run_iree("vit_tiny_patch16_224", train=train, num_executions=10, multithreaded=True)
+#     run_iree("mobilevit_s", train=train, num_executions=10, multithreaded=True)
+#     print()
+
+# run_executorch("vit_b_16", train=False, xnn=True, num_executions=100, multithreaded=False)
+# run_executorch("vit_tiny_patch16_224", train=False, xnn=True, num_executions=100, multithreaded=False)
+# run_executorch("mobilevit_s", train=False, xnn=True, num_executions=100, multithreaded=False)
+
+# next: single-threaded
+# run_iree("vit_tiny_patch16_224", train=False, multithreaded=False)
+# run_executorch("vit_tiny_patch16_224", train=False, multithreaded=True, profiling=True, num_executions=1)
+run_iree("resnet18", train=False, multithreaded=True, num_executions=10)
\ No newline at end of file
diff --git a/examples/portable/executor_runner/executor_runner.cpp b/examples/portable/executor_runner/executor_runner.cpp
index b5d008ff1e7..9487bbf3b76 100644
--- a/examples/portable/executor_runner/executor_runner.cpp
+++ b/examples/portable/executor_runner/executor_runner.cpp
@@ -295,9 +295,9 @@ int main(int argc, char** argv) {
   ET_CHECK(status == Error::Ok);
   // Print the first and last 100 elements of long lists of scalars.
   std::cout << executorch::extension::evalue_edge_items(100);
-  for (int i = 0; i < outputs.size(); ++i) {
-    std::cout << "Output " << i << ": " << outputs[i] << std::endl;
-  }
+  // for (int i = 0; i < outputs.size(); ++i) {
+  //   std::cout << "Output " << i << ": " << outputs[i] << std::endl;
+  // }
   if (tracer.get_event_tracer()) {
     // Dump ETDump data containing profiling/debugging data to file specified in