Skip to content

Commit 4240a97

Browse files
committed
stubgen: added the ability to include a custom prefix/suffix in generated stubs
1 parent 0d191b4 commit 4240a97

File tree

4 files changed

+55
-12
lines changed

4 files changed

+55
-12
lines changed

docs/typing.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,3 +634,7 @@ you may use the special ``\from`` escape code to import them:
634634
\from typing import Optional as _Opt, Literal
635635
def lookup(array: Array[T], index: Literal[0] = 0) -> _Opt[T]:
636636
\doc
637+
638+
You may also add free-form text the beginning or the end of the generated stub.
639+
To do so, add an entry that matches on ``module_name.__prefix__`` or
640+
``module_name.__suffix__``.

src/stubgen.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#!/usr/bin/env python3
2+
# pyright: strict
3+
24
"""
35
stubgen.py: nanobind stub generation tool
46
@@ -136,6 +138,11 @@ class NbType(Protocol):
136138

137139
@dataclass
138140
class ReplacePattern:
141+
"""
142+
A compiled query (regular expression) and replacement pattern. Patterns can
143+
be loaded using the ``load_pattern_file()`` function dfined below
144+
"""
145+
139146
# A replacement patterns as produced by ``load_pattern_file()`` below
140147
query: Pattern[str]
141148
lines: List[str]
@@ -614,10 +621,23 @@ def process_general(m: Match[str]) -> str:
614621

615622
return s
616623

617-
def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str]) -> None:
624+
def apply_pattern(self, query: str, value: object) -> bool:
618625
"""
619-
Called when ``value`` matched an entry of a pattern file
626+
Check if ``value`` matches an entry of a pattern file. Applies the
627+
pattern and returns ``True`` in that case, otherwise returns ``False``.
620628
"""
629+
630+
match: Optional[Match[str]] = None
631+
pattern: Optional[ReplacePattern] = None
632+
633+
for pattern in self.patterns:
634+
match = pattern.query.search(query)
635+
if match:
636+
break
637+
638+
if not match or not pattern:
639+
return False
640+
621641
for line in pattern.lines:
622642
ls = line.strip()
623643
if ls == "\\doc":
@@ -663,6 +683,9 @@ def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str
663683
line = line.replace(f"\\{k}", v)
664684
self.write_ln(line)
665685

686+
# Success, pattern was applied
687+
return True
688+
666689
def put(self, value: object, name: Optional[str] = None, parent: Optional[object] = None) -> None:
667690
old_prefix = self.prefix
668691

@@ -675,13 +698,8 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
675698
self.prefix = self.prefix + (("." + name) if name else "")
676699

677700
# Check if an entry in a provided pattern file matches
678-
if self.prefix:
679-
for pattern in self.patterns:
680-
match = pattern.query.search(self.prefix)
681-
if match:
682-
# If so, don't recurse
683-
self.apply_pattern(value, pattern, match)
684-
return
701+
if self.apply_pattern(self.prefix, value):
702+
return
685703

686704
# Exclude various standard elements found in modules, classes, etc.
687705
if name in SKIP_LIST:
@@ -713,8 +731,11 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
713731
# Do not recurse into submodules, but include a directive to import them
714732
self.import_object(value.__name__, name=None, as_name=name)
715733
return
716-
for name, child in getmembers(value):
717-
self.put(child, name=name, parent=value)
734+
else:
735+
self.apply_pattern(self.prefix + ".__prefix__", None)
736+
for name, child in getmembers(value):
737+
self.put(child, name=name, parent=value)
738+
self.apply_pattern(self.prefix + ".__suffix__", None)
718739
elif self.is_function(tp):
719740
value = cast(NbFunction, value)
720741
self.put_function(value, name, parent)
@@ -996,7 +1017,10 @@ def get(self) -> str:
9961017
if s:
9971018
s += "\n"
9981019
s += self.put_abstract_enum_class()
1020+
1021+
# Append the main generated stub
9991022
s += self.output
1023+
10001024
return s.rstrip() + "\n"
10011025

10021026
def put_abstract_enum_class(self) -> str:
@@ -1143,14 +1167,19 @@ def parse_options(args: List[str]) -> argparse.Namespace:
11431167

11441168

11451169
def load_pattern_file(fname: str) -> List[ReplacePattern]:
1170+
"""
1171+
Load a pattern file from disk and return a list of pattern instances that
1172+
includes precompiled versions of all of the contained regular expressions.
1173+
"""
1174+
11461175
with open(fname, "r") as f:
11471176
f_lines = f.readlines()
11481177

11491178
patterns: List[ReplacePattern] = []
11501179

11511180
def add_pattern(query: str, lines: List[str]):
11521181
# Exactly 1 empty line at the end
1153-
while lines and lines[-1].isspace():
1182+
while lines and (lines[-1].isspace() or len(lines[-1]) == 0):
11541183
lines.pop()
11551184
lines.append("")
11561185

tests/pattern_file.nb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ tweak_me:
1111
# Apply a pattern to multiple places
1212
__(lt|gt)__:
1313
def __\1__(self, arg: int, /) -> bool: ...
14+
15+
test_typing_ext.__prefix__:
16+
# a prefix
17+
18+
test_typing_ext.__suffix__:
19+
# a suffix

tests/test_typing_ext.pyi.ref

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ from .submodule import F as F, f as f2
33
from collections.abc import Iterable
44
from typing import Self, Optional, TypeAlias, TypeVar, Generic
55

6+
# a prefix
7+
68
@my_decorator
79
class CustomSignature(Iterable[int]):
810
@my_decorator
@@ -52,3 +54,5 @@ def tweak_me(arg: int):
5254
prior docstring
5355
remains preserved
5456
"""
57+
58+
# a suffix

0 commit comments

Comments
 (0)