1
1
#!/usr/bin/env python3
2
+ # pyright: strict
3
+
2
4
"""
3
5
stubgen.py: nanobind stub generation tool
4
6
@@ -136,6 +138,11 @@ class NbType(Protocol):
136
138
137
139
@dataclass
138
140
class ReplacePattern :
141
+ """
142
+ A compiled query (regular expression) and replacement pattern. Patterns can
143
+ be loaded using the ``load_pattern_file()`` function dfined below
144
+ """
145
+
139
146
# A replacement patterns as produced by ``load_pattern_file()`` below
140
147
query : Pattern [str ]
141
148
lines : List [str ]
@@ -614,10 +621,23 @@ def process_general(m: Match[str]) -> str:
614
621
615
622
return s
616
623
617
- def apply_pattern (self , value : object , pattern : ReplacePattern , match : Match [ str ] ) -> None :
624
+ def apply_pattern (self , query : str , value : object ) -> bool :
618
625
"""
619
- Called when ``value`` matched an entry of a pattern file
626
+ Check if ``value`` matches an entry of a pattern file. Applies the
627
+ pattern and returns ``True`` in that case, otherwise returns ``False``.
620
628
"""
629
+
630
+ match : Optional [Match [str ]] = None
631
+ pattern : Optional [ReplacePattern ] = None
632
+
633
+ for pattern in self .patterns :
634
+ match = pattern .query .search (query )
635
+ if match :
636
+ break
637
+
638
+ if not match or not pattern :
639
+ return False
640
+
621
641
for line in pattern .lines :
622
642
ls = line .strip ()
623
643
if ls == "\\ doc" :
@@ -663,6 +683,9 @@ def apply_pattern(self, value: object, pattern: ReplacePattern, match: Match[str
663
683
line = line .replace (f"\\ { k } " , v )
664
684
self .write_ln (line )
665
685
686
+ # Success, pattern was applied
687
+ return True
688
+
666
689
def put (self , value : object , name : Optional [str ] = None , parent : Optional [object ] = None ) -> None :
667
690
old_prefix = self .prefix
668
691
@@ -675,13 +698,8 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
675
698
self .prefix = self .prefix + (("." + name ) if name else "" )
676
699
677
700
# Check if an entry in a provided pattern file matches
678
- if self .prefix :
679
- for pattern in self .patterns :
680
- match = pattern .query .search (self .prefix )
681
- if match :
682
- # If so, don't recurse
683
- self .apply_pattern (value , pattern , match )
684
- return
701
+ if self .apply_pattern (self .prefix , value ):
702
+ return
685
703
686
704
# Exclude various standard elements found in modules, classes, etc.
687
705
if name in SKIP_LIST :
@@ -713,8 +731,11 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
713
731
# Do not recurse into submodules, but include a directive to import them
714
732
self .import_object (value .__name__ , name = None , as_name = name )
715
733
return
716
- for name , child in getmembers (value ):
717
- self .put (child , name = name , parent = value )
734
+ else :
735
+ self .apply_pattern (self .prefix + ".__prefix__" , None )
736
+ for name , child in getmembers (value ):
737
+ self .put (child , name = name , parent = value )
738
+ self .apply_pattern (self .prefix + ".__suffix__" , None )
718
739
elif self .is_function (tp ):
719
740
value = cast (NbFunction , value )
720
741
self .put_function (value , name , parent )
@@ -996,7 +1017,10 @@ def get(self) -> str:
996
1017
if s :
997
1018
s += "\n "
998
1019
s += self .put_abstract_enum_class ()
1020
+
1021
+ # Append the main generated stub
999
1022
s += self .output
1023
+
1000
1024
return s .rstrip () + "\n "
1001
1025
1002
1026
def put_abstract_enum_class (self ) -> str :
@@ -1143,14 +1167,19 @@ def parse_options(args: List[str]) -> argparse.Namespace:
1143
1167
1144
1168
1145
1169
def load_pattern_file (fname : str ) -> List [ReplacePattern ]:
1170
+ """
1171
+ Load a pattern file from disk and return a list of pattern instances that
1172
+ includes precompiled versions of all of the contained regular expressions.
1173
+ """
1174
+
1146
1175
with open (fname , "r" ) as f :
1147
1176
f_lines = f .readlines ()
1148
1177
1149
1178
patterns : List [ReplacePattern ] = []
1150
1179
1151
1180
def add_pattern (query : str , lines : List [str ]):
1152
1181
# Exactly 1 empty line at the end
1153
- while lines and lines [- 1 ].isspace ():
1182
+ while lines and ( lines [- 1 ].isspace () or len ( lines [ - 1 ]) == 0 ):
1154
1183
lines .pop ()
1155
1184
lines .append ("" )
1156
1185
0 commit comments