Skip to content

Commit d8e2ae3

Browse files
author
twata
committed
Make test run
1 parent 9900664 commit d8e2ae3

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,8 +426,16 @@ def __init__(self):
426426
super(Net, self).__init__()
427427
self.conv = torch.nn.Conv2d(1, 1, 3)
428428

429+
def map_f(self, u):
430+
return u + 1
431+
429432
def forward(self, x):
430-
y = self.conv(x)
431-
return list(ppe.map(lambda u: u + 1, y))[0]
433+
y1 = self.conv(x)
434+
y2 = self.conv(x)
435+
y = [{"u" : y1}, {"u": y2}]
436+
return list(ppe.map(self.map_f, y))[0]
432437

433-
run_model_test(Net(), (torch.rand(1, 1, 112, 112),), rtol=1e-03)
438+
model = Net()
439+
ppe.to(model, device="cpu")
440+
441+
run_model_test(model, (torch.rand(1, 1, 112, 112),), rtol=1e-03)

0 commit comments

Comments
 (0)