Skip to content

Commit

Permalink
[JIT] Add torch._C.ScriptList` (pytorch#52832)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#52832

**Summary**
This commit adds `torch._C.ScriptList`, a list type that has reference
semantics across the Python/TorchScript boundary. That is, modifications
made in TorchScript to instances of `torch._C.ScriptList`
are visible in Python even when it is not returned from the function.

`torch._C.ScriptList` is implemented using a modified version of pybind's
`stl_bind.h`-style bindings attached to `ScriptList` and `ScriptListIterator`,
wrapper classes around `c10::impl::GenericList` and
`c10::impl::GenericList::iterator`. These bindings allow instances of
`torch._C.ScriptList` to be used as if it were a
regular `list` in Python. Reference semantics are achieved by simply
retrieving the `IValue` contained in `ScriptList` in `toIValue` (invoked
when converting Python arguments to `IValues` before calling TorchScript
code).

**Test Plan**
This commit adds `TestScriptList` to `test_list_dict.py`, a set of tests
that check that all of the common list operations are supported
and that instances have reference semantics across the
Python/TorchScript boundary.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D29478121

Pulled By: SplitInfinity

fbshipit-source-id: 652cc25cfa37debe28db9527504846f22abd8b54
  • Loading branch information
Meghan Lele authored and facebook-github-bot committed Jul 2, 2021
1 parent 6e9e30c commit 4a2e8b5
Show file tree
Hide file tree
Showing 10 changed files with 941 additions and 8 deletions.
360 changes: 360 additions & 0 deletions test/jit/test_list_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2170,3 +2170,363 @@ def test_reference_semantics(self):
self.assertEqual(len(data), 2)
self.assertTrue(3 in data)
self.assertEqual(data[3], 4)


class TestScriptList(JitTestCase):
"""
This class contains a suite of tests for torch._C.ScriptList, a
function that returns a list-like object that has reference
semantics across the Python/TorchScript boundary. That is,
it can be passed to a TorchScript function that mutates it
and those modifications are visible in the scope of the Python
caller of said TorchScript function.
The vast majority of tests are for making sure that instances of
torch._C.ScriptList behave like lists do so that they are fungible
in almost all cirumstances with regular list.
"""
def _script_list_add(self, l: torch._C.ScriptList, e: int):
"""
This is a helper function that inserts the element e into the
list l in TorchScript. It is used for testing reference
semantics.
"""
@torch.jit.script
def list_add(l: List[int], e: int):
l.append(e)

list_add(l, e)

def _compare_eager_and_script(self, fn, input_list, script_input_list=None):
"""
This is a helper function that facilitates comparing behaviour between
Python lists and "scripted" lists.
Args:
fn: The function to test and compare the behaviour of.
input_list: The input list to use for the test (passed to fn).
script_input_list: The scripted input list to use for the tests.
If None, input_list is scripted with torch.jit.script
and used instead.
"""
# Create ScriptDict version of input_list if needed.
script_input_list = script_input_list or torch.jit.script(input_list)

# Run fn with both input_list and scripted_dict.
eager_raised, script_raised = False, False

try:
eager_out = fn(input_list)
except Exception as e:
eager_exception = e
eager_raised = True

try:
script_out = fn(script_input_list)
except Exception as e:
script_exception = e
script_raised = True

# Check that both calls raised or none of them raised.
self.assertEqual(eager_raised, script_raised)

if eager_raised:
# If fn raised an exception, it should be the same between
# regular and scripted lists.
self.assertEqual(type(eager_exception), type(script_exception))
else:
# Otherwise, make sure the outputs match and the lists
# match (the latter may not be the same as the output).
self.assertEqual(eager_out, script_out)
self.assertEqual(input_list, script_input_list)

def test_repr(self):
"""
Test the __repr__ method.
"""
self._compare_eager_and_script(lambda l: repr(l), [1])

def test_bool(self):
"""
Test the __bool__ method. This should return True
if the list is non-empty and False otherwise.
"""
self._compare_eager_and_script(lambda l: bool(l), [1])
self._compare_eager_and_script(lambda l: bool(l), [])

def test_iter(self):
"""
Test iteration over a list's elements.
"""
def sum_elements(input_list):
s = 0
for k in input_list:
s += k

return s

self._compare_eager_and_script(sum_elements, [1, 2, 3, 4])

def test_getitem(self):
"""
Test accessing list elements using the [] operator.
"""
data = [1, 2, 3, 4]

# Test regular indexing.
self._compare_eager_and_script(lambda l: l[1], data)
self._compare_eager_and_script(lambda l: l[3], data)
self._compare_eager_and_script(lambda l: l[-1], data)

# Test slicing.
self._compare_eager_and_script(lambda l: l[1:3], data)
self._compare_eager_and_script(lambda l: l[:], data)
self._compare_eager_and_script(lambda l: l[1:], data)
self._compare_eager_and_script(lambda l: l[:2], data)
self._compare_eager_and_script(lambda l: l[-1], data)
self._compare_eager_and_script(lambda l: l[-1::-1], data)

# Test errors.
self._compare_eager_and_script(lambda l: l[5], data)
self._compare_eager_and_script(lambda l: l[-7], data)
self._compare_eager_and_script(lambda l: l["key"], data)

def test_setitem(self):
"""
Test setting list elements using the [] operator.
"""
data = [1, 2, 3, 4]

# Test regular assignment.
def setitem(input_list):
input_list[1] = 10
input_list[3] = 11
input_list[-1] = 12

self._compare_eager_and_script(setitem, data.copy())

# Test slice assignment.
# TODO: Something like input_list[:1] = [1, 2, 3, 4, 5]
# is allowed in Python, but pybind11/stl_bind.h does not
# allow it. Should we?
def setitem_slice(input_list):
input_list[:4:2] = [10, 11]
input_list[-2:] = [15, 16]

self._compare_eager_and_script(setitem_slice, data)

# Test errors.
def out_of_range(input_list):
input_list[11] = 3

def out_of_range_negative(input_list):
input_list[-11] = 3

def wrong_index_type(input_list):
input_list["str"] = 3

self._compare_eager_and_script(out_of_range, data)
self._compare_eager_and_script(out_of_range_negative, data)
self._compare_eager_and_script(wrong_index_type, data)

# Check that using value of an incorrect type throws TypeError.
# _compare_eager_and_script cannot be used here since
# the following use of __setitem__ is valid in
# Python.
script_data = torch.jit.script(data)

with self.assertRaises(TypeError):
script_data[0] = "str"

def test_contains(self):
"""
Test membership checks (x in y, x not in y).
"""
data = [1, 2, 3, 4]

def fn(input_list):
return 1 in input_list, 2 not in input_list, 3 in input_list, 4 not in input_list

self._compare_eager_and_script(fn, data)

# Check that using a value of an incorrect type throws a TypeError.
script_data = torch.jit.script(data)

with self.assertRaises(TypeError):
a = "str" in script_data

def test_delitem(self):
"""
Test deletion.
"""
data = [1, 2, 3, 4]

def del_fn(input_list):
del input_list[1]

def del_fn_out_of_range(input_list):
del input_list[10]

def del_fn_wrong_type(input_list):
del input_list["str"]

self._compare_eager_and_script(del_fn, data.copy())
self._compare_eager_and_script(del_fn_out_of_range, data)
self._compare_eager_and_script(del_fn_wrong_type, data)

def test_len(self):
"""
Test len() builtin function.
"""
self._compare_eager_and_script(lambda l: len(l), [1, 2, 3, 4])
self._compare_eager_and_script(lambda l: len(l), [])

def test_count(self):
"""
Test count method.
"""
self._compare_eager_and_script(lambda l: l.count(3), [1, 2, 3, 3])

# Check that using a value of an incorrect type throws TypeError.
script_data = torch.jit.script([1])

with self.assertRaises(TypeError):
script_data.count("str")

def test_remove(self):
"""
Test remove method.
"""
self._compare_eager_and_script(lambda l: l.remove(1), [1, 2, 3])
self._compare_eager_and_script(lambda l: l.remove(10), [1, 2, 3])

# Check that using a value of an incorrect type throws TypeError.
script_data = torch.jit.script([1])

with self.assertRaises(TypeError):
script_data.remove("str")

def test_append(self):
"""
Test append method.
"""
self._compare_eager_and_script(lambda l: l.append(1), [4, 3, 2])

# Check that using a value of an incorrect type throws TypeError.
script_data = torch.jit.script([1])

with self.assertRaises(TypeError):
script_data.append("str")

def test_clear(self):
"""
Test clear.
"""
self._compare_eager_and_script(lambda l: l.clear(), [4, 3, 2])

def test_extend(self):
"""
Test extend.
"""
class Iterable(object):
def __init__(self, limit: int):
self.limit = limit
self.value = 0

def __iter__(self):
return self

def __next__(self):
if self.value == limit:
raise StopIteration()

ret = self.value
self.value += 1
return ret

data = [1, 2, 3]

def extend_list(input_list):
input_list.extend([4, 5, 6])

def extend_dict(input_list):
input_list.extend({4: 10, 5: 11, 6: 12})

def extend_iterable(input_list):
input_list.extend(Iterable(3))

self._compare_eager_and_script(extend_list, data.copy())
self._compare_eager_and_script(extend_dict, data.copy())
self._compare_eager_and_script(extend_iterable, data)

# Check that using a value of an incorrect type throws TypeError.
script_data = torch.jit.script([1])

with self.assertRaises(TypeError):
script_data.extend(["a"])

with self.assertRaises(TypeError):
script_data.extend({"a": 1})

def test_insert(self):
"""
Test insert.
"""
data = [1, 2, 4]

self._compare_eager_and_script(lambda l: l.insert(3, 3), data.copy())
self._compare_eager_and_script(lambda l: l.insert(0, 3), data.copy())
self._compare_eager_and_script(lambda l: l.insert(-2, 3), data)

# Check that using a value of an incorrect type throws TypeError.
script_data = torch.jit.script([1])

with self.assertRaises(TypeError):
script_data.insert((0, "str"))

def test_pop(self):
"""
Test pop.
"""
data = [1, 2, 3, 4, 5]

# Test normal cases.
self._compare_eager_and_script(lambda l: l.pop(), data.copy())
self._compare_eager_and_script(lambda l: l.pop(2), data.copy())
self._compare_eager_and_script(lambda l: l.pop(-3), data.copy())

# Test error cases.
self._compare_eager_and_script(lambda l: l.pop(10), data)

@unittest.skip("Cannot pass until all list returned from TorchScript are ScriptLists")
def test_nested(self):
"""
Test that reference semantics are honoured when the ScriptList that is
mutated using TorchScript is inside another.
"""
nested = torch.jit.script([[1], [2]], List[List[int]])

one = nested[0]
two = nested[1]

self._script_list_add(one, 3)
self._script_list_add(two, 4)

# The mutation should be visible in the original list, nested.
self.assertEqual(len(one), 2)
self.assertEqual(len(two), 2)
self.assertEqual(one[len(one) - 1], 3)
self.assertEqual(two[len(one) - 1], 4)
self.assertEqual(len(nested[0]), 2)
self.assertEqual(len(nested[1]), 2)

def test_reference_semantics(self):
"""
Test that reference semantics are honoured; that modifications made
to a ScriptList in TorchScript are visible in Python.
"""
l = torch.jit.script([1, 2])
self._script_list_add(l, 3)

self.assertEqual(len(l), 3)
self.assertTrue(3 in l)
self.assertEqual(l[2], 3)
2 changes: 1 addition & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jit.test_type_sharing import TestTypeSharing # noqa: F401
from jit.test_logging import TestLogging # noqa: F401
from jit.test_backends import TestBackends, TestBackendsWithCompiler # noqa: F401
from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict # noqa: F401
from jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList # noqa: F401
from jit.test_async import TestAsync # noqa: F401
from jit.test_data_parallel import TestDataParallel # noqa: F401
from jit.test_models import TestModels # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/python/python_dict.cpp",
"torch/csrc/jit/python/python_interpreter.cpp",
"torch/csrc/jit/python/python_ir.cpp",
"torch/csrc/jit/python/python_list.cpp",
"torch/csrc/jit/python/python_tracer.cpp",
"torch/csrc/jit/python/script_init.cpp",
"torch/csrc/jit/frontend/concrete_module_type.cpp",
Expand Down
Loading

0 comments on commit 4a2e8b5

Please sign in to comment.