Skip to content

Commit

Permalink
Update debug_compare (#2612)
Browse files Browse the repository at this point in the history
This PR fixes a bug of the debug_compare.py script.
  • Loading branch information
Hzfengsy committed Jul 2, 2024
1 parent 0575b92 commit c09b108
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions python/mlc_llm/testing/debug_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,31 +67,34 @@ def __init__( # pylint: disable=too-many-arguments, unused-argument
self,
mod: runtime.Module,
device: runtime.Device,
debug_dir: Path,
debug_out: Path,
time_eval: bool = True,
rtol: float = 1e-2,
atol: float = 1,
skip_rounds: int = 0,
):
super().__init__(mod, device, True, rtol, atol)
self.debug_out = debug_out
self.time_eval = time_eval
self.time_eval_results: Dict[str, Tuple[float, int]] = {}
self.visited: Set[str] = set([])
self.skip_rounds = skip_rounds
self.counter = 0
debug_out.mkdir(exist_ok=True, parents=True)

def reset(self, debug_dir: Path): # pylint: disable=unused-argument
def reset(self, debug_out: Path): # pylint: disable=unused-argument
"""Reset the state of the Instrument class
Note
----
`debug_dir` is not used in this class.
`debug_out` is not used in this class.
Parameters
----------
debug_out : Path
the directory to dump the .npz files
"""
self.debug_out = debug_out
_print_as_table(
sorted(
self.time_eval_results.items(),
Expand All @@ -101,6 +104,7 @@ def reset(self, debug_dir: Path): # pylint: disable=unused-argument
self.time_eval_results = {}
self.visited = set([])
self.counter = 0
debug_out.mkdir(exist_ok=True, parents=True)

def skip_instrument(self, func, name, before_run, ret_val, *args):
if name.startswith("shape_func"):
Expand Down Expand Up @@ -128,7 +132,12 @@ def compare(

if self.time_eval and name not in self.time_eval_results:
res = self.mod.time_evaluator(
name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6
name,
self.device,
number=20,
repeat=3,
min_repeat_ms=100,
# cache_flush_bytes=256 * 10**6
)(*new_args)
self.time_eval_results[name] = (res.mean, 1)
print(f"Time-eval result {name} on {self.device}:\n {res}")
Expand Down Expand Up @@ -159,19 +168,14 @@ def get_instrument(args):
lib = sess.load_module(os.path.basename(args.cmp_lib_path))
cmp_device = sess.cl(0)
else:
lib = tvm.runtime.load_module(
os.path.join(
args.artifact_path,
f"{args.model}-{args.quantization.name}-{args.cmp_device}.so",
)
)
lib = tvm.runtime.load_module(args.cmp_lib_path)
cmp_device = tvm.device(args.cmp_device)

return LibCompare(
lib,
cmp_device,
time_eval=args.time_eval,
debug_dir=Path(args.debug_dir),
debug_out=Path(args.debug_dir),
)


Expand Down

0 comments on commit c09b108

Please sign in to comment.