Commit da195e2 (feat/translate-fx-profiling-tutorial, PyTorchKorea#767)
Ssunbell committed Sep 3, 2023 (parent: 61dc8d1)
Showing 1 changed file with 106 additions and 123 deletions: intermediate_source/fx_profiling_tutorial.py
@@ -1,22 +1,21 @@
# -*- coding: utf-8 -*-
"""
(beta) Building a Simple CPU Performance Profiler with FX
*********************************************************
**Author**: `James Reed <https://github.com/jamesr66a>`_

**Translated by**: `์œ ์„ ์ข… <https://github.com/Ssunbell>`_

In this tutorial, we are going to use FX to do the following:

1) Capture PyTorch Python code in a way that we can inspect and gather
   statistics about the structure and execution of the code
2) Build out a small class that will serve as a simple performance "profiler",
   collecting runtime statistics about each part of the model from actual
   runs.
"""

######################################################################
# For this tutorial, we are going to use the torchvision ResNet18 model
# for demonstration purposes.

import torch
import torch.fx
@@ -26,211 +25,195 @@
rn18.eval()

######################################################################
# Now that we have our model, we want to inspect deeper into its
# performance. That is, for the following invocation, which parts
# of the model are taking the longest?
input = torch.randn(5, 3, 224, 224)
output = rn18(input)

######################################################################
# A common way of answering that question is to go through the program
# source, add code that collects timestamps at various points in the
# program, and compare the difference between those timestamps to see
# how long the regions between the timestamps take.
#
# That technique is certainly applicable to PyTorch code, however it
# would be nicer if we didn't have to copy over model code and edit it,
# especially code we haven't written (like this torchvision model).
# Instead, we are going to use FX to automate this "instrumentation"
# process without needing to modify any source.

######################################################################
# First, let's get some imports out of the way (we will be using all
# of these later in the code).

import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

######################################################################
# .. note::
#       ``tabulate`` is an external library that is not a dependency of PyTorch.
#       We will be using it to more easily visualize performance data. Please
#       make sure you've installed it from your favorite Python package source.

######################################################################
# Capturing the Model with Symbolic Tracing
# -----------------------------------------
# Next, we are going to use FX's symbolic tracing mechanism to capture
# the definition of our model in a data structure we can manipulate
# and examine.

traced_rn18 = torch.fx.symbolic_trace(rn18)
print(traced_rn18.graph)

######################################################################
# This gives us a Graph representation of the ResNet18 model. A Graph
# consists of a series of Nodes connected to each other. Each Node
# represents a call-site in the Python code (whether to a function,
# a module, or a method) and the edges (represented as ``args`` and ``kwargs``
# on each node) represent the values passed between these call-sites. More
# information about the Graph representation and the rest of FX's APIs can
# be found at the FX documentation https://pytorch.org/docs/master/fx.html.
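#
# For example, we can walk the graph and print a few of its nodes (a quick
# inspection sketch; the slice just keeps the output short):

for node in list(traced_rn18.graph.nodes)[:5]:
    print(node.op, node.name, node.target, node.args)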


######################################################################
# Creating a Profiling Interpreter
# --------------------------------
# Next, we are going to create a class that inherits from ``torch.fx.Interpreter``.
# Though the ``GraphModule`` that ``symbolic_trace`` produces compiles Python code
# that is run when you call a ``GraphModule``, an alternative way to run a
# ``GraphModule`` is by executing each ``Node`` in the ``Graph`` one by one. That is
# the functionality that ``Interpreter`` provides: it interprets the graph
# node-by-node.
#
# By inheriting from ``Interpreter``, we can override various functionality and
# install the profiling behavior we want. The goal is to have an object to which
# we can pass a model, invoke the model 1 or more times, then get statistics about
# how long the model and each part of the model took during those runs.
#
# Let's define our ``ProfilingInterpreter`` class:

class ProfilingInterpreter(Interpreter):
    def __init__(self, mod : torch.nn.Module):
        # Rather than have the user symbolically trace their model,
        # we're going to do it in the constructor. As a result, the
        # user can pass in any ``Module`` without having to worry about
        # symbolic tracing APIs
        gm = torch.fx.symbolic_trace(mod)
        super().__init__(gm)

        # We are going to store away two things here:
        #
        # 1. A list of total runtimes for ``mod``. In other words, we are
        #    storing away the time ``mod(...)`` took each time this
        #    interpreter is called.
        self.total_runtime_sec : List[float] = []
        # 2. A map from ``Node`` to a list of times (in seconds) that
        #    node took to run. This can be seen as similar to (1) but
        #    for specific sub-parts of the model.
        self.runtimes_sec : Dict[torch.fx.Node, List[float]] = {}

    ######################################################################
    # Next, let's override our first method: ``run()``. ``Interpreter``'s ``run``
    # method is the top-level entry point for execution of the model. We will
    # want to intercept this so that we can record the total runtime of the
    # model.

    def run(self, *args) -> Any:
        # Record the time we started running the model
        t_start = time.time()
        # Run the model by delegating back into Interpreter.run()
        return_val = super().run(*args)
        # Record the time we finished running the model
        t_end = time.time()
        # Store the total elapsed time this model execution took in the
        # ``ProfilingInterpreter``
        self.total_runtime_sec.append(t_end - t_start)
        return return_val

    ######################################################################
    # Now, let's override ``run_node``. ``Interpreter`` calls ``run_node`` each
    # time it executes a single node. We will intercept this so that we
    # can measure and record the time taken for each individual call in
    # the model.

    def run_node(self, n : torch.fx.Node) -> Any:
        # Record the time we started running the op
        t_start = time.time()
        # Run the op by delegating back into Interpreter.run_node()
        return_val = super().run_node(n)
        # Record the time we finished running the op
        t_end = time.time()
        # If we don't have an entry for this node in our runtimes_sec
        # data structure, add one with an empty list value.
        self.runtimes_sec.setdefault(n, [])
        # Record the total elapsed time for this single invocation
        # in the runtimes_sec data structure
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    ######################################################################
    # Finally, we are going to define a method (one which doesn't override
    # any ``Interpreter`` method) that provides us a nice, organized view of
    # the data we have collected.

    def summary(self, should_sort : bool = False) -> str:
        # Build up a list of summary information for each node
        node_summaries : List[List[Any]] = []
        # Calculate the mean runtime for the whole network. Because the
        # network may have been called multiple times during profiling,
        # we need to summarize the runtimes. We choose to use the
        # arithmetic mean for this.
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # For each node, record summary statistics
        for node, runtimes in self.runtimes_sec.items():
            # Similarly, compute the mean runtime for ``node``
            mean_runtime = statistics.mean(runtimes)
            # For easier understanding, we also compute the percentage
            # time each node took with respect to the whole network.
            pct_total = mean_runtime / mean_total_runtime * 100
            # Record the node's type, name of the node, mean runtime, and
            # percent runtime.
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total])

        # One of the most important questions to answer when doing performance
        # profiling is "Which op(s) took the longest?". We can make this easy
        # to see by providing sorting functionality in our summary view
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers : List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        return tabulate.tabulate(node_summaries, headers=headers)

######################################################################
# .. note::
#       We use Python's ``time.time`` function to pull wall clock
#       timestamps and compare them. This is not the most accurate
#       way to measure performance, and will only give us a first-
#       order approximation. We use this simple technique only for the
#       purpose of demonstration in this tutorial.

######################################################################
# Investigating the Performance of ResNet18
# -----------------------------------------
# We can now use ``ProfilingInterpreter`` to inspect the performance
# characteristics of our ResNet18 model:

interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))
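
######################################################################
# Because ``ProfilingInterpreter`` accumulates timings across calls, running
# the model a few more times gives ``statistics.mean`` more samples to
# average over (a small sketch; the table format is unchanged):

for _ in range(10):
    interp.run(input)
print(interp.summary(True))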

######################################################################
# There are two things we should call out here:
#
# * ``MaxPool2d`` takes up the most time. This is a known issue:
#   https://github.com/pytorch/pytorch/issues/51393
# * ``BatchNorm2d`` also takes up significant time. We can continue this
#   line of thinking and optimize this in the Conv-BN Fusion with FX
#   `tutorial <https://tutorials.pytorch.kr/intermediate/fx_conv_bn_fuser.html>`_.
#
#
# Conclusion
# ----------
# As we can see, using FX we can easily capture PyTorch programs (even
# ones we don't have the source code for!) in a machine-interpretable
# format and use that for analysis, such as the performance analysis
# we've done here. FX opens up an exciting world of possibilities for
# working with PyTorch programs.
#
# Finally, since FX is still in beta, we would be happy to hear any
# feedback you have about using it. Please feel free to use the
# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
# (https://github.com/pytorch/pytorch/issues) to provide any feedback
# you might have.
