|
13 | 13 | from collections import defaultdict
|
14 | 14 |
|
15 | 15 | import torch
|
16 |
| -from functorch import make_fx |
17 | 16 | from graph_net.torch import utils
|
18 | 17 |
|
19 | 18 |
|
@@ -376,62 +375,6 @@ def collect_op_stats_with_symbolic_trace(model, sample_inputs, device):
|
376 | 375 | return meta_executor.is_complete, meta_executor.op_stats
|
377 | 376 |
|
378 | 377 |
|
379 |
| -def collect_op_stats_with_make_fx(model, sample_inputs): |
380 |
| - # Use meta tensors as input to avoid actually running the model |
381 |
| - meta_input_list = convert_real_to_meta(sample_inputs) |
382 |
| - |
383 |
| - try: |
384 |
| - # Generate FX Graph, and automatically fill in meta information |
385 |
| - fx_model = make_fx(model)(*meta_input_list) |
386 |
| - except Exception: |
387 |
| - print("Failed to execute make_fx") |
388 |
| - return False, None |
389 |
| - |
390 |
| - is_complete = True |
391 |
| - op_stats = {} |
392 |
| - for node in fx_model.graph.nodes: |
393 |
| - op_name = None |
394 |
| - if node.op == "call_module": |
395 |
| - # classname of module |
396 |
| - submod = fx_model.get_submodule(node.target) |
397 |
| - op_name = submod.__class__.__name__ |
398 |
| - elif node.op == "call_function": |
399 |
| - op_name = node.target.__name__ |
400 |
| - elif node.op == "call_method": |
401 |
| - op_name = node.target |
402 |
| - elif node.op in ["placeholder", "output", "get_attr"]: |
403 |
| - op_name = node.op |
404 |
| - else: |
405 |
| - assert False, f"node.op: {node.op}" |
406 |
| - |
407 |
| - dtype = None |
408 |
| - if node.op not in ["placeholder", "output"]: |
409 |
| - if "tensor_meta" in node.meta: |
410 |
| - tensor_meta = node.meta["tensor_meta"] |
411 |
| - dtype = tensor_meta.dtype |
412 |
| - # print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}") |
413 |
| - else: |
414 |
| - print( |
415 |
| - f"node.op={node.op}, node.target={node.target} has no tensor_meta!" |
416 |
| - ) |
417 |
| - is_complete = False |
418 |
| - |
419 |
| - op_name = ( |
420 |
| - op_name.replace(".default", "") |
421 |
| - .replace(".Tensor", "") |
422 |
| - .replace(".Scalar", "") |
423 |
| - ) |
424 |
| - dtype_str = str(dtype).replace("torch.", "") |
425 |
| - if op_stats.get(op_name, None) is None: |
426 |
| - op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1) |
427 |
| - else: |
428 |
| - op_stats[op_name].op_dtypes[dtype_str] = ( |
429 |
| - op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1 |
430 |
| - ) |
431 |
| - op_stats[op_name].count = op_stats[op_name].count + 1 |
432 |
| - return is_complete, op_stats |
433 |
| - |
434 |
| - |
435 | 378 | def collect_op_stats(model, sample_inputs, device):
|
436 | 379 | is_complete_symbolic, op_stats_symbolic = collect_op_stats_with_symbolic_trace(
|
437 | 380 | model, sample_inputs, device
|
|
0 commit comments