Skip to content

Commit a588b0c

Browse files
Jian Sheng and chunit-quic authored
Revert "[Frontend] Add Span filling for frontends to Relay (apache#9723)" (apache#10072) (#246)
Reverted because of the failure of LSTM conversion from PyTorch. Co-authored-by: Chun-I Tsai <[email protected]>
1 parent 4fd713c commit a588b0c

File tree

13 files changed

+49
-237
lines changed

13 files changed

+49
-237
lines changed

python/tvm/relay/expr.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,10 @@ class TupleGetItem(ExprWithOp):
316316
317317
index: int
318318
The index.
319-
320-
span: Optional[tvm.relay.Span]
321-
Span that points to original source code
322319
"""
323320

324-
def __init__(self, tuple_value, index, span=None):
325-
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)
321+
def __init__(self, tuple_value, index):
322+
self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
326323

327324

328325
@tvm._ffi.register_object("relay.RefCreate")

python/tvm/relay/frontend/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from tvm.topi.utils import get_const_tuple
2626

2727
from .. import expr as _expr
28-
from ..expr_functor import ExprMutator
2928
from .. import function as _function
3029
from .. import transform as _transform
3130
from .. import op as _op
@@ -996,6 +995,7 @@ def visit_tuple(self, tup):
996995
def visit_tuple_getitem(self, op):
997996
if op.span is None:
998997
self.distance_from_leaf += 1
998+
# pylint: disable=too-many-function-args
999999
return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span())
10001000
return op
10011001

python/tvm/relay/frontend/pytorch.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
from .common import infer_value as _infer_value
4646
from .common import infer_value_simulated as _infer_value_simulated
4747
from .common import lstm_cell, try_infer_value, unbind
48-
from .common import set_span
4948
from .pytorch_utils import is_version_greater_than
5049

5150
__all__ = ["from_pytorch"]
@@ -3257,9 +3256,6 @@ def body(*current_vals):
32573256

32583257
def convert_operators(self, operators, outputs, ret_names):
32593258
"""Convert each Torch IR operators to Relay equivalent"""
3260-
# an op node might not belong to any of scope in trace info natively
3261-
# use a counter to prevent messing up its scope in span
3262-
empty_counter = 0
32633259
for node_name, op_node in operators:
32643260
operator = op_node.kind()
32653261
inputs = _get_op_inputs(op_node, outputs)
@@ -3309,9 +3305,6 @@ def convert_operators(self, operators, outputs, ret_names):
33093305
relay_out = relay_op(
33103306
inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
33113307
)
3312-
span_str, empty_counter = self._get_torch_span(op_node, empty_counter)
3313-
relay_out = set_span(relay_out, span_str)
3314-
33153308
self.record_output_type(relay_out)
33163309

33173310
if isinstance(relay_out, tuple):
@@ -3325,18 +3318,6 @@ def convert_operators(self, operators, outputs, ret_names):
33253318

33263319
return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
33273320

3328-
def _get_torch_span(self, node, empty_counter):
3329-
# torch span looks like
3330-
# %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file
3331-
# the scope part might not exist
3332-
if node.scopeName():
3333-
scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName()
3334-
else:
3335-
scope_name_str = "warning: no trace info " + str(empty_counter)
3336-
empty_counter += 1
3337-
span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str)
3338-
return span_str, empty_counter
3339-
33403321

33413322
def _pytorch_result_type(dtypes, non_tensor_inputs):
33423323
"""This promotes TVM dtypes like PyTorch would"""

python/tvm/relay/frontend/tensorflow.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from .common import infer_type as _infer_type
3838
from .common import infer_shape as _infer_shape
3939
from .common import infer_value as _infer_value
40-
from .common import set_span
4140

4241
from .tensorflow_ops import _convert_map
4342
from .tensorflow_ops import _need_prelude_for_shape_inference
@@ -1029,10 +1028,24 @@ def _convert_operator(
10291028
else:
10301029
raise NotImplementedError("Operator {} not implemented.".format(op_name))
10311030

1032-
sym = set_span(sym, node_name)
1031+
sym = self._set_span(sym, node_name)
10331032

10341033
return sym
10351034

1035+
@staticmethod
1036+
def _set_span(sym, node_name):
1037+
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
1038+
if isinstance(sym, _expr.Call) and sym.span is None:
1039+
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
1040+
elif isinstance(sym, _expr.TupleWrapper):
1041+
tuple_value = sym.tuple_value
1042+
if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
1043+
tuple_value = _expr.Call(
1044+
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
1045+
)
1046+
sym = _expr.TupleWrapper(tuple_value, sym.size)
1047+
return sym
1048+
10361049
def _licm_construct(self, loop_name, node_name):
10371050
"""Construct a node by considering whether it is
10381051
loop invariant with the given while loop. If yes, we

python/tvm/relay/frontend/tensorflow2.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from .. import function as _function
3737
from ..loops import while_loop as _while_loop
3838
from .common import infer_type as _infer_type
39-
from .common import set_span
4039

4140
from .tensorflow_ops import _convert_map as _convert_map_common
4241
from .tensorflow_ops import _get_more_static_shape_rank
@@ -59,6 +58,22 @@ def _infer_type_with_prelude(val, prelude):
5958
return body.checked_type
6059

6160

61+
def set_span(sym, node_name):
62+
"""set span of symbol"""
63+
64+
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
65+
if isinstance(sym, _expr.Call):
66+
sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
67+
elif isinstance(sym, _expr.TupleWrapper):
68+
tuple_value = sym.tuple_value
69+
if isinstance(tuple_value, _expr.Call):
70+
tuple_value = _expr.Call(
71+
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
72+
)
73+
sym = _expr.TupleWrapper(tuple_value, sym.size)
74+
return sym
75+
76+
6277
def is_tensor_list_constuctor(tf_node):
6378
"""Check whether is tensor list constructor node."""
6479
return tf_node.op == "TensorListReserve"

python/tvm/relay/frontend/tflite.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from .. import qnn as _qnn
3333
from .common import ExprTable
3434
from .common import infer_shape as _infer_shape
35-
from .common import set_span
3635
from .common import to_int_list
3736
from .tflite_flexbuffer import FlexBufferDecoder
3837

@@ -240,17 +239,12 @@ def convert_op_to_relay(self):
240239

241240
if len(output_tensors) == 1:
242241
tensor_idx = output_tensors[0].tensor_idx
243-
curr_output = get_tensor_name(self.subgraph, tensor_idx)
244-
ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
245-
self.exp_tab.set_expr(curr_output, ret)
242+
self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
246243
else:
247-
out_names = []
248-
for output_tensor in output_tensors:
249-
out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx))
250-
curr_output = ", ".join(out_names)
251-
ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
252-
for idx, out_name in enumerate(out_names):
253-
self.exp_tab.set_expr(out_name, ret[idx])
244+
for idx, output_tensor in enumerate(output_tensors):
245+
self.exp_tab.set_expr(
246+
get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
247+
)
254248

255249
def get_op_code_str(self, op):
256250
"""Get TFLite ops string representation"""

src/printer/relay_text_printer.cc

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -389,21 +389,12 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
389389
if (op->fields.size() == 1) {
390390
doc << ",";
391391
}
392-
doc << ")";
393-
if (op->span.defined()) {
394-
doc << " /* " << PrintSpan(op->span) << " */";
395-
}
396-
return doc;
392+
return doc << ")";
397393
}
398394

399395
Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
400396
Doc doc;
401-
doc << Print(op->tuple) << "." << op->index;
402-
403-
if (op->span.defined()) {
404-
doc << " /* " << PrintSpan(op->span) << " */";
405-
}
406-
return doc;
397+
return doc << Print(op->tuple) << "." << op->index;
407398
}
408399

409400
Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
@@ -977,13 +968,11 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>&
977968
return doc;
978969
}
979970

980-
Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) {
971+
Doc RelayTextPrinter::PrintSpan(const Span& span) {
981972
Doc doc;
982-
if (include_spans) {
983-
const auto* span_node = span.as<SpanNode>();
984-
ICHECK(span_node);
985-
doc << span_node->source_name->name;
986-
}
973+
const auto* span_node = span.as<SpanNode>();
974+
ICHECK(span_node);
975+
doc << span_node->source_name->name;
987976
return doc;
988977
}
989978

src/printer/text_printer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
113113
*/
114114
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
115115

116-
Doc PrintSpan(const Span& span, bool include_spans = true);
116+
Doc PrintSpan(const Span& span);
117117

118118
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
119119

src/relay/ir/expr.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
362362

363363
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
364364

365-
TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) {
366-
return TupleGetItem(tuple, index, span);
365+
TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) {
366+
return TupleGetItem(tuple, index);
367367
});
368368

369369
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

tests/python/frontend/pytorch/test_forward.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -247,53 +247,6 @@ def visit(op):
247247
torch.cuda.empty_cache()
248248

249249

250-
def verify_span(model_name, input_data=[], custom_convert_map={}):
251-
if isinstance(model_name, str):
252-
baseline_model, baseline_input = load_model(model_name)
253-
elif isinstance(input_data, list):
254-
baseline_model = model_name
255-
baseline_input = input_data
256-
elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
257-
baseline_model = model_name
258-
baseline_input = [input_data]
259-
else:
260-
assert False, "Unexpected input format"
261-
262-
trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
263-
if isinstance(baseline_model, torch.nn.Module):
264-
trace = trace.float().eval()
265-
266-
if torch.cuda.is_available():
267-
trace = trace.cuda()
268-
else:
269-
trace = trace.cpu()
270-
271-
input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
272-
input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
273-
mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
274-
275-
# collect fail cases for the convenience of further improvement
276-
fail_cases = []
277-
mod_main_start = False
278-
for line in str(mod.__str__).split("\n"):
279-
if "@main" in line:
280-
mod_main_start = True
281-
continue
282-
283-
if mod_main_start == True:
284-
if "}" == line:
285-
break
286-
elif not ("/*" in line and "*/" in line):
287-
fail_cases.append(line)
288-
289-
print(fail_cases)
290-
assert len(fail_cases) == 0
291-
292-
293-
def test_span():
294-
verify_span("resnet18")
295-
296-
297250
# Single operator tests
298251
@tvm.testing.uses_gpu
299252
def test_forward_pixel_shuffle():

0 commit comments

Comments
 (0)