Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 74e07d8

Browse files
aivanoufacebook-github-bot
authored andcommittedOct 16, 2021
Make docstring optional (#259)
Summary: Pull Request resolved: #259 * Refactor docstring functions: combines two functions that retrieve docstring into one * Make docstring optional * Remove docstring validator Git issue: #253 Reviewed By: kiukchung Differential Revision: D31671125 fbshipit-source-id: 2da71fcecf0d05f03c04dcc29b44ec43ab919eaa
1 parent bd1c36d commit 74e07d8

File tree

7 files changed

+210
-224
lines changed

7 files changed

+210
-224
lines changed
 

‎docs/source/component_best_practices.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ others to understand how to use it.
7474
return AppDef(roles=[Role(..., num_replicas=num_replicas)])
7575
7676
77+
Documentation
78+
^^^^^^^^^^^^^^^^^^^^^
79+
80+
The documentation is optional, but it is the best practice to keep component functions documented,
81+
especially if you want to share your components. See :ref:Component Authoring<components/overview:Authoring>
82+
for more details.
83+
84+
7785
Named Resources
7886
-----------------
7987

‎torchx/specs/api.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Generic,
2020
Iterator,
2121
List,
22-
Mapping,
2322
Optional,
2423
Tuple,
2524
Type,
@@ -28,8 +27,7 @@
2827
)
2928

3029
import yaml
31-
from pyre_extensions import none_throws
32-
from torchx.specs.file_linter import parse_fn_docstring
30+
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
3331
from torchx.util.types import decode_from_string, decode_optional, is_bool, is_primitive
3432

3533

@@ -748,22 +746,21 @@ def get_argparse_param_type(parameter: inspect.Parameter) -> Callable[[str], obj
748746
return str
749747

750748

751-
def _create_args_parser(
752-
fn_name: str,
753-
parameters: Mapping[str, inspect.Parameter],
754-
function_desc: str,
755-
args_desc: Dict[str, str],
756-
) -> argparse.ArgumentParser:
749+
def _create_args_parser(app_fn: Callable[..., AppDef]) -> argparse.ArgumentParser:
750+
parameters = inspect.signature(app_fn).parameters
751+
function_desc, args_desc = get_fn_docstring(app_fn)
757752
script_parser = argparse.ArgumentParser(
758-
prog=f"torchx run ...torchx_params... {fn_name} ",
759-
description=f"App spec: {function_desc}",
753+
prog=f"torchx run <<torchx_params>> {app_fn.__name__} ",
754+
description=f"AppDef for {function_desc}",
755+
formatter_class=TorchXArgumentHelpFormatter,
760756
)
761757

762758
remainder_arg = []
763759

764760
for param_name, parameter in parameters.items():
761+
param_desc = args_desc[parameter.name]
765762
args: Dict[str, Any] = {
766-
"help": args_desc[param_name],
763+
"help": param_desc,
767764
"type": get_argparse_param_type(parameter),
768765
}
769766
if parameter.default != inspect.Parameter.empty:
@@ -788,20 +785,15 @@ def _create_args_parser(
788785
def _get_function_args(
789786
app_fn: Callable[..., AppDef], app_args: List[str]
790787
) -> Tuple[List[object], List[str], Dict[str, object]]:
791-
docstring = none_throws(inspect.getdoc(app_fn))
792-
function_desc, args_desc = parse_fn_docstring(docstring)
793-
794-
parameters = inspect.signature(app_fn).parameters
795-
script_parser = _create_args_parser(
796-
app_fn.__name__, parameters, function_desc, args_desc
797-
)
788+
script_parser = _create_args_parser(app_fn)
798789

799790
parsed_args = script_parser.parse_args(app_args)
800791

801792
function_args = []
802793
var_arg = []
803794
kwargs = {}
804795

796+
parameters = inspect.signature(app_fn).parameters
805797
for param_name, parameter in parameters.items():
806798
arg_value = getattr(parsed_args, param_name)
807799
parameter_type = parameter.annotation

‎torchx/specs/file_linter.py

Lines changed: 56 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import abc
9+
import argparse
910
import ast
11+
import inspect
1012
from dataclasses import dataclass
11-
from typing import Dict, List, Optional, Tuple, cast
13+
from typing import Dict, List, Optional, Tuple, cast, Callable
1214

1315
from docstring_parser import parse
1416
from pyre_extensions import none_throws
@@ -18,53 +20,66 @@
1820
# pyre-ignore-all-errors[16]
1921

2022

21-
def get_arg_names(app_specs_func_def: ast.FunctionDef) -> List[str]:
22-
arg_names = []
23-
fn_args = app_specs_func_def.args
24-
for arg_def in fn_args.args:
25-
arg_names.append(arg_def.arg)
26-
if fn_args.vararg:
27-
arg_names.append(fn_args.vararg.arg)
28-
for arg in fn_args.kwonlyargs:
29-
arg_names.append(arg.arg)
30-
return arg_names
23+
def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]:
24+
parameters = inspect.signature(fn).parameters
25+
args_decs = {}
26+
for parameter_name in parameters.keys():
27+
# The None or Empty string values getting ignored during help command by argparse
28+
args_decs[parameter_name] = " "
29+
return args_decs
3130

3231

33-
def parse_fn_docstring(func_description: str) -> Tuple[str, Dict[str, str]]:
32+
class TorchXArgumentHelpFormatter(argparse.HelpFormatter):
33+
"""Help message formatter which adds default values and required to argument help.
34+
35+
If the argument is required, the class appends `(required)` at the end of the help message.
36+
If the argument has default value, the class appends `(default: $DEFAULT)` at the end.
37+
The formatter is designed to be used only for the torchx components functions.
38+
These functions do not have both required and default arguments.
3439
"""
35-
Given a docstring in a google-style format, returns the function description and
36-
description of all arguments.
37-
See: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
40+
41+
def _get_help_string(self, action: argparse.Action) -> str:
42+
help = action.help or ""
43+
# Only `--help` will have be SUPPRESS, so we ignore it
44+
if action.default is argparse.SUPPRESS:
45+
return help
46+
if action.required:
47+
help += " (required)"
48+
else:
49+
help += f" (default: {action.default})"
50+
return help
51+
52+
53+
def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
3854
"""
39-
args_description = {}
55+
Parses the function and arguments description from the provided function. Docstring should be in
56+
`google-style format <https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
57+
58+
If function has no docstring, the function description will be the name of the function, TIP
59+
on how to improve the help message and arguments descriptions will be names of the arguments.
60+
61+
The arguments that are not present in the docstring will contain default/required information
62+
63+
Args:
64+
fn: Function with or without docstring
65+
66+
Returns:
67+
function description, arguments description where key is the name of the argument and value
68+
if the description
69+
"""
70+
default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
71+
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
72+
args_description = _get_default_arguments_descriptions(fn)
73+
func_description = inspect.getdoc(fn)
74+
if not func_description:
75+
return default_fn_desc, args_description
4076
docstring = parse(func_description)
4177
for param in docstring.params:
4278
args_description[param.arg_name] = param.description
43-
short_func_description = docstring.short_description
44-
return (short_func_description or "", args_description)
45-
46-
47-
def _get_fn_docstring(
48-
source: str, function_name: str
49-
) -> Optional[Tuple[str, Dict[str, str]]]:
50-
module = ast.parse(source)
51-
for expr in module.body:
52-
if type(expr) == ast.FunctionDef:
53-
func_def = cast(ast.FunctionDef, expr)
54-
if func_def.name == function_name:
55-
docstring = ast.get_docstring(func_def)
56-
if not docstring:
57-
return None
58-
return parse_fn_docstring(docstring)
59-
return None
60-
61-
62-
def get_short_fn_description(path: str, function_name: str) -> Optional[str]:
63-
source = read_conf_file(path)
64-
docstring = _get_fn_docstring(source, function_name)
65-
if not docstring:
66-
return None
67-
return docstring[0]
79+
short_func_description = docstring.short_description or default_fn_desc
80+
if docstring.long_description:
81+
short_func_description += " ..."
82+
return (short_func_description or default_fn_desc, args_description)
6883

6984

7085
@dataclass
@@ -91,38 +106,6 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
91106
)
92107

93108

94-
class TorchxDocstringValidator(TorchxFunctionValidator):
95-
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
96-
"""
97-
Validates the docstring of the `get_app_spec` function. Criteria:
98-
* There mast be google-style docstring
99-
* If there are more than zero arguments, there mast be a `Args:` section defined
100-
with all arguments included.
101-
"""
102-
docsting = ast.get_docstring(app_specs_func_def)
103-
lineno = app_specs_func_def.lineno
104-
if not docsting:
105-
desc = (
106-
f"`{app_specs_func_def.name}` is missing a Google Style docstring, please add one. "
107-
"For more information on the docstring format see: "
108-
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
109-
)
110-
return [self._gen_linter_message(desc, lineno)]
111-
112-
arg_names = get_arg_names(app_specs_func_def)
113-
_, docstring_arg_defs = parse_fn_docstring(docsting)
114-
missing_args = [
115-
arg_name for arg_name in arg_names if arg_name not in docstring_arg_defs
116-
]
117-
if len(missing_args) > 0:
118-
desc = (
119-
f"`{app_specs_func_def.name}` not all function arguments are present"
120-
f" in the docstring. Missing args: {missing_args}"
121-
)
122-
return [self._gen_linter_message(desc, lineno)]
123-
return []
124-
125-
126109
class TorchxFunctionArgsValidator(TorchxFunctionValidator):
127110
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
128111
linter_errors = []
@@ -149,7 +132,6 @@ def _validate_arg_def(
149132
)
150133
]
151134
if isinstance(arg_def.annotation, ast.Name):
152-
# TODO(aivanou): add support for primitive type check
153135
return []
154136
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
155137
if complex_type_def.value.id == "Optional":
@@ -239,12 +221,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
239221
Visitor that finds the component_function and runs registered validators on it.
240222
Current registered validators:
241223
242-
* TorchxDocstringValidator - validates the docstring of the function.
243-
Criteria:
244-
* There format should be google-python
245-
* If there are more than zero arguments defined, there
246-
should be obligatory `Args:` section that describes each argument on a new line.
247-
248224
* TorchxFunctionArgsValidator - validates arguments of the function.
249225
Criteria:
250226
* Each argument should be annotated with the type
@@ -260,7 +236,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
260236

261237
def __init__(self, component_function_name: str) -> None:
262238
self.validators = [
263-
TorchxDocstringValidator(),
264239
TorchxFunctionArgsValidator(),
265240
TorchxReturnValidator(),
266241
]

‎torchx/specs/finder.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pyre_extensions import none_throws
1919
from torchx.specs import AppDef
20-
from torchx.specs.file_linter import get_short_fn_description, validate
20+
from torchx.specs.file_linter import get_fn_docstring, validate
2121
from torchx.util import entrypoints
2222
from torchx.util.io import read_conf_file
2323

@@ -40,14 +40,15 @@ class _Component:
4040
Args:
4141
name: The name of the component, which usually MODULE_PATH.FN_NAME
4242
description: The description of the component, taken from the desrciption
43-
of the function that creates component
43+
of the function that creates component. In case of no docstring, description
44+
will be the same as name
4445
fn_name: Function name that creates component
4546
fn: Function that creates component
4647
validation_errors: Validation errors
4748
"""
4849

4950
name: str
50-
description: Optional[str]
51+
description: str
5152
fn_name: str
5253
fn: Callable[..., AppDef]
5354
validation_errors: List[str]
@@ -150,7 +151,7 @@ def _get_components_from_module(
150151
module_path = os.path.abspath(module.__file__)
151152
for function_name, function in functions:
152153
linter_errors = validate(module_path, function_name)
153-
component_desc = get_short_fn_description(module_path, function_name)
154+
component_desc, _ = get_fn_docstring(function)
154155
component_def = _Component(
155156
name=self._get_component_name(
156157
base_module, module.__name__, function_name
@@ -197,7 +198,6 @@ def find(self) -> List[_Component]:
197198
validation_errors = self._get_validation_errors(
198199
self._filepath, self._function_name
199200
)
200-
fn_desc = get_short_fn_description(self._filepath, self._function_name)
201201

202202
file_source = read_conf_file(self._filepath)
203203
namespace = globals()
@@ -207,6 +207,7 @@ def find(self) -> List[_Component]:
207207
f"Function {self._function_name} does not exist in file {self._filepath}"
208208
)
209209
app_fn = namespace[self._function_name]
210+
fn_desc, _ = get_fn_docstring(app_fn)
210211
return [
211212
_Component(
212213
name=f"{self._filepath}:{self._function_name}",

‎torchx/specs/test/api_test.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import argparse
89
import sys
910
import unittest
1011
from dataclasses import asdict
11-
from typing import Dict, List, Optional, Tuple, Union
12+
from typing import Dict, List, Optional, Tuple, Union, Any
1213
from unittest.mock import MagicMock, patch
1314

1415
import torchx.specs.named_resources_aws as named_resources_aws
16+
from pyre_extensions import none_throws
1517
from torchx.specs import named_resources
1618
from torchx.specs.api import (
1719
_TERMINAL_STATES,
@@ -33,6 +35,7 @@
3335
make_app_handle,
3436
parse_app_handle,
3537
runopts,
38+
_create_args_parser,
3639
)
3740

3841

@@ -463,11 +466,6 @@ def _test_complex_fn(
463466
app_name: AppDef name
464467
containers: List of containers
465468
roles_scripts: Dict role_name -> role_script
466-
num_cpus: List of cpus per role
467-
num_gpus: Dict role_name -> gpus used for role
468-
nnodes: Num replicas per role
469-
first_arg: First argument to the user script
470-
roles_args: Roles args
471469
"""
472470
num_roles = len(roles_scripts)
473471
if not num_cpus:
@@ -710,3 +708,43 @@ def test_varargs_only_arg_first(self) -> None:
710708
_TEST_VAR_ARGS_FIRST,
711709
(("fooval", "--foo", "barval", "arg1", "arg2"), "asdf"),
712710
)
711+
712+
# pyre-ignore[3]
713+
def _get_argument_help(
714+
self, parser: argparse.ArgumentParser, name: str
715+
) -> Optional[Tuple[str, Any]]:
716+
actions = parser._actions
717+
for action in actions:
718+
if action.dest == name:
719+
return action.help or "", action.default
720+
return None
721+
722+
def test_argparster_complex_fn_partial(self) -> None:
723+
parser = _create_args_parser(_test_complex_fn)
724+
self.assertTupleEqual(
725+
("AppDef name", None),
726+
none_throws(self._get_argument_help(parser, "app_name")),
727+
)
728+
self.assertTupleEqual(
729+
("List of containers", None),
730+
none_throws(self._get_argument_help(parser, "containers")),
731+
)
732+
self.assertTupleEqual(
733+
("Dict role_name -> role_script", None),
734+
none_throws(self._get_argument_help(parser, "roles_scripts")),
735+
)
736+
self.assertTupleEqual(
737+
(" ", None), none_throws(self._get_argument_help(parser, "num_cpus"))
738+
)
739+
self.assertTupleEqual(
740+
(" ", None), none_throws(self._get_argument_help(parser, "num_gpus"))
741+
)
742+
self.assertTupleEqual(
743+
(" ", 4), none_throws(self._get_argument_help(parser, "nnodes"))
744+
)
745+
self.assertTupleEqual(
746+
(" ", None), none_throws(self._get_argument_help(parser, "first_arg"))
747+
)
748+
self.assertTupleEqual(
749+
(" ", None), none_throws(self._get_argument_help(parser, "roles_args"))
750+
)

‎torchx/specs/test/file_linter_test.py

Lines changed: 63 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import ast
7+
import argparse
88
import os
99
import unittest
10-
from typing import Dict, List, Optional, cast
10+
from typing import Dict, List, Optional
1111
from unittest.mock import patch
1212

13-
from pyre_extensions import none_throws
1413
from torchx.specs.file_linter import (
15-
get_short_fn_description,
16-
_get_fn_docstring,
17-
parse_fn_docstring,
14+
get_fn_docstring,
1815
validate,
16+
TorchXArgumentHelpFormatter,
1917
)
2018

2119

@@ -41,35 +39,24 @@ def _test_fn_return_int() -> int:
4139
return 0
4240

4341

44-
def _test_docstring_empty(arg: str) -> "AppDef":
45-
""" """
46-
pass
42+
def _test_docstring(arg0: str, arg1: int, arg2: Dict[int, str]) -> "AppDef":
43+
"""Short Test description
4744
45+
Long funct description
4846
49-
def _test_docstring_func_desc() -> "AppDef":
50-
"""
51-
Function description
47+
Args:
48+
arg0: arg0 desc
49+
arg1: arg1 desc
5250
"""
5351
pass
5452

5553

56-
def _test_docstring_no_args(arg: str) -> "AppDef":
57-
"""
58-
Test description
59-
"""
54+
def _test_docstring_short() -> "AppDef":
55+
"""Short Test description"""
6056
pass
6157

6258

63-
def _test_docstring_correct(arg0: str, arg1: int, arg2: Dict[int, str]) -> "AppDef":
64-
"""Short Test description
65-
66-
Long funct description
67-
68-
Args:
69-
arg0: arg0 desc
70-
arg1: arg1 desc
71-
arg2: arg2 desc
72-
"""
59+
def _test_without_docstring(arg0: str) -> "AppDef":
7360
pass
7461

7562

@@ -129,10 +116,6 @@ def setUp(self) -> None:
129116
source = fp.read()
130117
self._file_content = source
131118

132-
def test_validate_docstring_func_desc(self) -> None:
133-
linter_errors = validate(self._path, "_test_docstring_func_desc")
134-
self.assertEqual(0, len(linter_errors))
135-
136119
def test_syntax_error(self) -> None:
137120
content = "!!foo====bar"
138121
with patch("torchx.specs.file_linter.read_conf_file") as read_conf_file_mock:
@@ -146,10 +129,9 @@ def test_validate_varargs_kwargs_fn(self) -> None:
146129
self._path,
147130
"_test_invalid_fn_with_varags_and_kwargs",
148131
)
149-
self.assertEqual(2, len(linter_errors))
150-
self.assertTrue("Missing args: ['id']" in linter_errors[0].description)
132+
self.assertEqual(1, len(linter_errors))
151133
self.assertTrue(
152-
"Arg args missing type annotation", linter_errors[1].description
134+
"Arg args missing type annotation", linter_errors[0].description
153135
)
154136

155137
def test_validate_no_return(self) -> None:
@@ -172,45 +154,6 @@ def test_validate_incorrect_return(self) -> None:
172154

173155
def test_validate_empty_fn(self) -> None:
174156
linter_errors = validate(self._path, "_test_empty_fn")
175-
self.assertEqual(1, len(linter_errors))
176-
linter_error = linter_errors[0]
177-
self.assertEqual("TorchxFunctionValidator", linter_error.name)
178-
179-
expected_desc = (
180-
"`_test_empty_fn` is missing a Google Style docstring, please add one. "
181-
"For more information on the docstring format see: "
182-
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
183-
)
184-
self.assertEquals(expected_desc, linter_error.description)
185-
# TODO(aivanou): change this test to validate fn from another file to avoid changing lineno
186-
# on each file change
187-
self.assertEqual(24, linter_error.line)
188-
189-
def test_validate_docstring_empty(self) -> None:
190-
linter_errors = validate(self._path, "_test_docstring_empty")
191-
self.assertEqual(1, len(linter_errors))
192-
linter_error = linter_errors[0]
193-
self.assertEqual("TorchxFunctionValidator", linter_error.name)
194-
expected_desc = (
195-
"`_test_docstring_empty` is missing a Google Style docstring, please add one. "
196-
"For more information on the docstring format see: "
197-
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
198-
)
199-
self.assertEquals(expected_desc, linter_error.description)
200-
201-
def test_validate_docstring_no_args(self) -> None:
202-
linter_errors = validate(self._path, "_test_docstring_no_args")
203-
self.assertEqual(1, len(linter_errors))
204-
linter_error = linter_errors[0]
205-
self.assertEqual("TorchxFunctionValidator", linter_error.name)
206-
expected_desc = (
207-
"`_test_docstring_no_args` not all function arguments"
208-
" are present in the docstring. Missing args: ['arg']"
209-
)
210-
self.assertEqual(expected_desc, linter_error.description)
211-
212-
def test_validate_docstring_correct(self) -> None:
213-
linter_errors = validate(self._path, "_test_docstring_correct")
214157
self.assertEqual(0, len(linter_errors))
215158

216159
def test_validate_args_no_type_defs(self) -> None:
@@ -229,66 +172,74 @@ def test_validate_args_no_type_defs_complex(self) -> None:
229172
self._path,
230173
"_test_args_dict_list_complex_types",
231174
)
232-
self.assertEqual(6, len(linter_errors))
233-
expected_desc = (
234-
"`_test_args_dict_list_complex_types` not all function arguments"
235-
" are present in the docstring. Missing args: ['arg4']"
236-
)
175+
self.assertEqual(5, len(linter_errors))
237176
self.assertEqual(
238-
expected_desc,
239-
linter_errors[0].description,
240-
)
241-
self.assertEqual(
242-
"Arg arg0 missing type annotation", linter_errors[1].description
177+
"Arg arg0 missing type annotation", linter_errors[0].description
243178
)
244179
self.assertEqual(
245-
"Arg arg1 missing type annotation", linter_errors[2].description
180+
"Arg arg1 missing type annotation", linter_errors[1].description
246181
)
247182
self.assertEqual(
248-
"Dict can only have primitive types", linter_errors[3].description
183+
"Dict can only have primitive types", linter_errors[2].description
249184
)
250185
self.assertEqual(
251-
"List can only have primitive types", linter_errors[4].description
186+
"List can only have primitive types", linter_errors[3].description
252187
)
253188
self.assertEqual(
254189
"`_test_args_dict_list_complex_types` allows only Dict, List as complex types.Argument `arg4` has: Optional",
255-
linter_errors[5].description,
190+
linter_errors[4].description,
256191
)
257192

258-
def _get_function_def(self, function_name: str) -> ast.FunctionDef:
259-
module: ast.Module = ast.parse(self._file_content)
260-
for expr in module.body:
261-
if type(expr) == ast.FunctionDef:
262-
func_def = cast(ast.FunctionDef, expr)
263-
if func_def.name == function_name:
264-
return func_def
265-
raise RuntimeError(f"No function found: {function_name}")
266-
267-
def test_validate_docstring_full(self) -> None:
268-
func_def = self._get_function_def("_test_docstring_correct")
269-
docstring = none_throws(ast.get_docstring(func_def))
270-
271-
func_desc, param_desc = parse_fn_docstring(docstring)
272-
self.assertEqual("Short Test description", func_desc)
193+
def test_validate_docstring(self) -> None:
194+
func_desc, param_desc = get_fn_docstring(_test_docstring)
195+
self.assertEqual("Short Test description ...", func_desc)
273196
self.assertEqual("arg0 desc", param_desc["arg0"])
274197
self.assertEqual("arg1 desc", param_desc["arg1"])
275-
self.assertEqual("arg2 desc", param_desc["arg2"])
198+
self.assertEqual(" ", param_desc["arg2"])
276199

277-
def test_get_fn_docstring(self) -> None:
278-
function_desc, _ = none_throws(
279-
_get_fn_docstring(self._file_content, "_test_args_dict_list_complex_types")
280-
)
281-
self.assertEqual("Test description", function_desc)
200+
def test_validate_docstring_short(self) -> None:
201+
func_desc, param_desc = get_fn_docstring(_test_docstring_short)
202+
self.assertEqual("Short Test description", func_desc)
203+
204+
def test_validate_docstring_no_docs(self) -> None:
205+
func_desc, param_desc = get_fn_docstring(_test_without_docstring)
206+
expected_fn_desc = """_test_without_docstring TIP: improve this help string by adding a docstring
207+
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
208+
self.assertEqual(expected_fn_desc, func_desc)
209+
self.assertEqual(" ", param_desc["arg0"])
282210

283-
def test_unknown_function(self) -> None:
211+
def test_validate_unknown_function(self) -> None:
284212
linter_errors = validate(self._path, "unknown_function")
285213
self.assertEqual(1, len(linter_errors))
286214
self.assertEqual(
287215
"Function unknown_function not found", linter_errors[0].description
288216
)
289217

290-
def test_get_short_fn_description(self) -> None:
291-
fn_short_desc = none_throws(
292-
get_short_fn_description(self._path, "_test_args_dict_list_complex_types")
218+
def test_formatter(self) -> None:
219+
parser = argparse.ArgumentParser(
220+
prog="test prog",
221+
description="test desc",
222+
)
223+
parser.add_argument(
224+
"--foo",
225+
type=int,
226+
required=True,
227+
help="foo",
228+
)
229+
parser.add_argument(
230+
"--bar",
231+
type=int,
232+
help="bar",
233+
default=1,
234+
)
235+
formatter = TorchXArgumentHelpFormatter(prog="test")
236+
self.assertEqual(
237+
"show this help message and exit",
238+
formatter._get_help_string(parser._actions[0]),
239+
)
240+
self.assertEqual(
241+
"foo (required)", formatter._get_help_string(parser._actions[1])
242+
)
243+
self.assertEqual(
244+
"bar (default: 1)", formatter._get_help_string(parser._actions[2])
293245
)
294-
self.assertEqual("Test description", fn_short_desc)

‎torchx/specs/test/finder_test.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def _test_component(name: str, role_name: str = "worker") -> AppDef:
4040
)
4141

4242

43+
def _test_component_without_docstring(name: str, role_name: str = "worker") -> AppDef:
44+
return AppDef(
45+
name, roles=[Role(name=role_name, image="test_image", entrypoint="main.py")]
46+
)
47+
48+
4349
# pyre-ignore[2]
4450
def invalid_component(name, role_name: str = "worker") -> AppDef:
4551
return AppDef(
@@ -87,7 +93,7 @@ def test_get_invalid_component(self) -> None:
8793
entrypoints_mock.load_group.return_value = test_torchx_group
8894
components = _load_components()
8995
foobar_component = components["foobar.finder_test.invalid_component"]
90-
self.assertEqual(2, len(foobar_component.validation_errors))
96+
self.assertEqual(1, len(foobar_component.validation_errors))
9197

9298
def test_get_entrypoints_components(self) -> None:
9399
test_torchx_group = {"foobar": sys.modules[__name__]}
@@ -151,6 +157,21 @@ def test_find_components(self) -> None:
151157
self.assertEqual("_test_component", component.fn_name)
152158
self.assertListEqual([], component.validation_errors)
153159

160+
def test_find_components_without_docstring(self) -> None:
161+
components = CustomComponentsFinder(
162+
current_file_path(), "_test_component_without_docstring"
163+
).find()
164+
self.assertEqual(1, len(components))
165+
component = components[0]
166+
self.assertEqual(
167+
f"{current_file_path()}:_test_component_without_docstring", component.name
168+
)
169+
exprected_desc = """_test_component_without_docstring TIP: improve this help string by adding a docstring
170+
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
171+
self.assertEqual(exprected_desc, component.description)
172+
self.assertEqual("_test_component_without_docstring", component.fn_name)
173+
self.assertListEqual([], component.validation_errors)
174+
154175
def test_get_component(self) -> None:
155176
component = get_component(f"{current_file_path()}:_test_component")
156177
self.assertEqual(f"{current_file_path()}:_test_component", component.name)

0 commit comments

Comments
 (0)
Please sign in to comment.