torch-graph

Simple PyTorch graph capturing.

Instructions

Install Graphviz first (the command below is for Debian/Ubuntu and may need root privileges):

$ apt-get install graphviz
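
To confirm the installation before going further, a check along these lines can help. It is only a sketch: it assumes the graphs are rendered through the Graphviz `dot` executable on your PATH, which this README does not state explicitly, and `graphviz_available` is a hypothetical helper name, not part of torch-graph.

```python
import shutil


def graphviz_available(executable: str = "dot") -> bool:
    """Return True if the given Graphviz executable is found on PATH.

    Assumption: rendering relies on the `dot` binary that the
    graphviz system package installs.
    """
    return shutil.which(executable) is not None


if __name__ == "__main__":
    if graphviz_available():
        print("Graphviz found at:", shutil.which("dot"))
    else:
        print("Graphviz not found; install it before capturing graphs.")
```

If the check reports that Graphviz is missing, re-run the install command above or adjust your PATH.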

Clone this repository, then install the package from the repository root:

$ pip install .

Examples:

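import torch
import torch.nn as nn

from torchgraph import dispatch_capture, aot_capture, compile_capture

# A small convolutional block placed on the GPU (this example needs a CUDA device)
model = nn.Sequential(
    nn.Conv2d(16, 32, 3),
    nn.BatchNorm2d(32),
    nn.SiLU(),
).cuda()
# Input batch of shape (2, 16, 8, 8) that tracks gradients
x = torch.randn((2, 16, 8, 8), requires_grad=True, device="cuda")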

# Capture joint forward and backward graph through dispatch
dispatch_capture(model, x)

# Capture separate forward and backward graphs through PyTorch AOTAutograd
aot_capture(model, x)

# Capture forward graph through PyTorch compile
compile_capture(model, x)

The captured graphs are written as .svg files in the current working directory.
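
To see which graph files were produced, something like the following sketch may be useful. The output file names are not documented here, so it simply lists every .svg file in a folder; `list_svg_files` is a hypothetical helper, not part of torch-graph.

```python
from pathlib import Path
from typing import List


def list_svg_files(folder: str = ".") -> List[Path]:
    """Return the .svg files directly inside `folder`, sorted by name."""
    return sorted(p for p in Path(folder).glob("*.svg") if p.is_file())


if __name__ == "__main__":
    # Print each captured graph found in the current working directory.
    for svg in list_svg_files():
        print(svg.name)
```

Any browser or image viewer with SVG support can open the listed files.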