Skip to content

Commit c14fed5

Browse files
committed
Remove the make_fx implementation.
1 parent a738b6b commit c14fed5

File tree

1 file changed

+0
-57
lines changed

1 file changed

+0
-57
lines changed

graph_net/torch/collect_stats.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from collections import defaultdict
1414

1515
import torch
16-
from functorch import make_fx
1716
from graph_net.torch import utils
1817

1918

@@ -376,62 +375,6 @@ def collect_op_stats_with_symbolic_trace(model, sample_inputs, device):
376375
return meta_executor.is_complete, meta_executor.op_stats
377376

378377

379-
def collect_op_stats_with_make_fx(model, sample_inputs):
380-
# Use meta tensors as input to avoid actually running the model
381-
meta_input_list = convert_real_to_meta(sample_inputs)
382-
383-
try:
384-
# Generate FX Graph, and automatically fill in meta information
385-
fx_model = make_fx(model)(*meta_input_list)
386-
except Exception:
387-
print("Failed to execute make_fx")
388-
return False, None
389-
390-
is_complete = True
391-
op_stats = {}
392-
for node in fx_model.graph.nodes:
393-
op_name = None
394-
if node.op == "call_module":
395-
# classname of module
396-
submod = fx_model.get_submodule(node.target)
397-
op_name = submod.__class__.__name__
398-
elif node.op == "call_function":
399-
op_name = node.target.__name__
400-
elif node.op == "call_method":
401-
op_name = node.target
402-
elif node.op in ["placeholder", "output", "get_attr"]:
403-
op_name = node.op
404-
else:
405-
assert False, f"node.op: {node.op}"
406-
407-
dtype = None
408-
if node.op not in ["placeholder", "output"]:
409-
if "tensor_meta" in node.meta:
410-
tensor_meta = node.meta["tensor_meta"]
411-
dtype = tensor_meta.dtype
412-
# print(f"node.op={node.op}, node.target={node.target}, dtype={tensor_meta.dtype}")
413-
else:
414-
print(
415-
f"node.op={node.op}, node.target={node.target} has no tensor_meta!"
416-
)
417-
is_complete = False
418-
419-
op_name = (
420-
op_name.replace(".default", "")
421-
.replace(".Tensor", "")
422-
.replace(".Scalar", "")
423-
)
424-
dtype_str = str(dtype).replace("torch.", "")
425-
if op_stats.get(op_name, None) is None:
426-
op_stats[op_name] = OpStat(op_name, {dtype_str: 1}, 1)
427-
else:
428-
op_stats[op_name].op_dtypes[dtype_str] = (
429-
op_stats[op_name].op_dtypes.get(dtype_str, 0) + 1
430-
)
431-
op_stats[op_name].count = op_stats[op_name].count + 1
432-
return is_complete, op_stats
433-
434-
435378
def collect_op_stats(model, sample_inputs, device):
436379
is_complete_symbolic, op_stats_symbolic = collect_op_stats_with_symbolic_trace(
437380
model, sample_inputs, device

0 commit comments

Comments
 (0)