|
| 1 | +#!/usr/bin/python3 |
| 2 | +#-*- encoding: Utf-8 -*- |
| 3 | +from google.protobuf.descriptor_pb2 import FileDescriptorProto, DescriptorProto |
| 4 | +from collections import defaultdict, OrderedDict |
| 5 | +from re import sub |
| 6 | + |
| 7 | +from utils.descpb_to_proto import descpb_to_proto |
| 8 | + |
"""
    When parsing output from e.g. the Java extractor, messages aren't
    nested, so they need to be nested back into one another.

    We only nest a message when exactly one message in the same
    package refers to it. The exception is when this would lead to
    disallowed patterns, such as mutually importing files; in that
    case more nesting is done.

    Also, ensure every message name starts with an uppercase letter.
    Once this is done, render the .proto files to ASCII using the
    existing module.
"""
| 20 | + |
def nest_and_print_to_files(msg_path_to_obj, msg_to_referrers):
    """
    Nest messages/groups/enums into their referrers, normalize names and
    yield one rendered .proto file per group of top-level definitions.

    This is a generator: nothing below runs until the result is iterated.

    Args:
        msg_path_to_obj: mapping of full dotted name -> descriptor object.
            Objects that are not `DescriptorProto` instances are treated
            as enums (their `.value` entries are read). The mapping and
            the objects in it are mutated in place.
        msg_to_referrers: mapping of full dotted name -> list of
            `(field_name, referrer_full_name, is_group)` tuples. Entries
            whose key is absent from `msg_path_to_obj` are deleted from
            the caller's mapping.

    Yields:
        `(name, text)` tuples, where `name` and the .proto body come from
        `descpb_to_proto()` and `text` is that body preceded by a comment
        header listing the definitions placed in the file.
    """
    # Shared state, also updated by merge_and_rename() / fix_naming():
    #   msg_to_topmost: original path of a nested definition -> path of
    #                   the top-level message that now contains it
    #   msg_to_newloc:  original path -> current (post-move) path
    #   newloc_to_msg:  current path -> original path
    #   msg_to_imports: path -> list of paths it references
    msg_to_topmost = OrderedDict()
    msg_to_newloc = {}
    newloc_to_msg = {}
    msg_to_imports = defaultdict(list)

    # Iterate over referred to messages/groups/enums.

    # A copy is iterated so that entries can be deleted from the original
    # mapping inside the loop.
    for msg, referrers in dict(msg_to_referrers).items():
        # Suppress references to unknown messages caused by
        # decompilation failures.
        if msg not in msg_path_to_obj:
            del msg_to_referrers[msg]
            for field, referrer, _ in referrers:
                # Resolve the field name to the field object of the referrer.
                field = next((i for i in msg_path_to_obj[referrer].field if i.name == field), None)

                # The target is unknown: drop the type reference and
                # degrade the field to raw bytes.
                field.ClearField('type_name')
                field.type = field.TYPE_BYTES
        else:
            for _, referrer, _ in referrers:
                msg_to_imports[referrer].append(msg)

    # Merge groups first:
    # The key negates the `is_group` flag of the first referrer, so entries
    # flagged as groups sort ahead of the others (sorted() is stable). This
    # rebinds the local name only; the caller's mapping is not reordered.
    msg_to_referrers = OrderedDict(sorted(msg_to_referrers.items(), key=lambda x: -x[1][0][2]))

    # mergeable:          msg -> candidate (field, referrer) pairs kept for
    #                     the second pass
    # enumfield_to_enums: "<pkg>.<VALUE_NAME>" -> enums declaring that value
    # enum_to_dupfields:  enum -> value names it shares with another enum
    mergeable = OrderedDict()
    enumfield_to_enums = defaultdict(set)
    enum_to_dupfields = defaultdict(set)

    for msg, referrers in msg_to_referrers.items():
        msg_pkg = get_pkg(msg)
        msg_obj = msg_path_to_obj[msg]

        # Check for duplicate enum fields in the same package.
        # Only definitions present in msg_to_referrers are examined here.
        if not isinstance(msg_obj, DescriptorProto):
            for enum_field in msg_obj.value:
                name = msg_pkg + '.' + enum_field.name
                enumfield_to_enums[name].add(msg)

                if len(enumfield_to_enums[name]) > 1:
                    for other_enum in enumfield_to_enums[name]:
                        enum_to_dupfields[other_enum].add(name)

        first_field = referrers[0]
        field, referrer, is_group = first_field

        # Check whether message/enum has exactly one reference in this
        # package.
        if not is_group:
            # A referrer is a nesting candidate when all of these hold:
            #  - it is in the same package as msg (or msg has no package);
            #  - it is not already nested under msg itself;
            #  - it is not a map entry;
            #  - if msg's path contains '$', the text before the first '$'
            #    in the last dotted component matches for both paths
            #    (presumably the same outer Java class - TODO confirm).
            in_pkg = [(field, referrer) for field, referrer, _ in referrers \
                      if (get_pkg(referrer) == msg_pkg or not msg_pkg) \
                      and msg_to_topmost.get(referrer, referrer) != msg \
                      and not msg_path_to_obj[referrer].options.map_entry \
                      and ('$' not in msg or msg.split('.')[-1].split('$')[0] == \
                      referrer.split('.')[-1].split('$')[0])]

            # Count distinct referrers, not distinct fields.
            if len({i for _, i in in_pkg}) != 1:
                # It doesn't. Keep for the next step
                if in_pkg:
                    mergeable[msg] = in_pkg
                continue
            else:
                field, referrer = in_pkg[0]

        else:
            # An entry flagged as a group must have a single referrer.
            assert len(referrers) == 1

        merge_and_rename(msg, referrer, msg_pkg, is_group,
            msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg)

    # Try to fix recursive (mutual) imports, and conflicting enum field names.
    for msg, in_pkg in mergeable.items():
        duplicate_enumfields = enum_to_dupfields.get(msg, set())

        # Candidates are tried from the shallowest current location (fewest
        # dots) to the deepest; at most one merge is done per msg (break).
        for field, referrer in sorted(in_pkg, key=lambda x: msg_to_newloc.get(x[1], x[1]).count('.')):
            top_referrer = msg_to_topmost.get(referrer, referrer)

            # Merge when msg and the referrer's top-level container import
            # each other (and that container is not msg itself), or
            # whenever msg still has conflicting enum value names.
            if (msg in msg_to_imports[top_referrer] and \
                top_referrer in msg_to_imports[msg] and \
                msg_to_topmost.get(referrer, referrer) != msg) or \
                duplicate_enumfields:

                merge_and_rename(msg, referrer, get_pkg(msg), False,
                    msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg)
                break

        # msg has now been handled: withdraw it from each conflicting value
        # name. If a single enum is left for a name, that name is no longer
        # a conflict for it either. This runs whether or not a merge
        # happened above.
        for dupfield in duplicate_enumfields:
            siblings = enumfield_to_enums[dupfield]
            siblings.remove(msg)
            if len(siblings) == 1:
                enum_to_dupfields[siblings.pop()].remove(dupfield)

    for msg, msg_obj in msg_path_to_obj.items():
        # If we're a top-level message, enforce name transforms anyway
        if msg not in msg_to_topmost:
            # Keep what follows the last '$' and capitalize the first letter.
            new_name = msg_obj.name.split('$')[-1]
            new_name = new_name[0].upper() + new_name[1:]

            msg_pkg = get_pkg(msg)
            if msg_pkg:
                msg_pkg += '.'

            if new_name != msg_obj.name:
                # Append '_' while the candidate full name (translated back
                # to an original path when it is a known new location)
                # is already a key of msg_path_to_obj.
                while newloc_to_msg.get(msg_pkg + new_name, msg_pkg + new_name) in msg_path_to_obj:
                    new_name += '_'
                msg_obj.name = new_name

            # top_path == msg here, i.e. the message is its own top level.
            fix_naming(msg_obj, msg_pkg + new_name, msg, msg,
                msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg)

    # Turn messages into individual files and stringify.

    # path_to_file:    output path -> FileDescriptorProto being built
    # path_to_defines: output path -> original paths listed in the header
    path_to_file = OrderedDict()
    path_to_defines = defaultdict(list)

    for msg, msg_obj in msg_path_to_obj.items():
        if msg not in msg_to_topmost:
            # The file path is the original path up to the first '$', with
            # dots turned into directory separators.
            path = msg.split('$')[0].replace('.', '/') + '.proto'

            if path not in path_to_file:
                path_to_file[path] = FileDescriptorProto()
                path_to_file[path].syntax = 'proto2'
                path_to_file[path].package = get_pkg(msg)
                path_to_file[path].name = path
            file_obj = path_to_file[path]

            # Declare a dependency for every import that lives in another
            # file and is still top-level, without duplicates.
            for imported in msg_to_imports[msg]:
                import_path = imported.split('$')[0].replace('.', '/') + '.proto'
                if import_path != path and imported not in msg_to_topmost:
                    if import_path not in file_obj.dependency:
                        file_obj.dependency.append(import_path)

            if isinstance(msg_obj, DescriptorProto):
                nested = file_obj.message_type.add()
            else:
                nested = file_obj.enum_type.add()
            nested.MergeFrom(msg_obj)

            # Header listing: the top-level definition itself, then every
            # definition nested under it, except keys containing '$map'.
            path_to_defines[path].append(msg)
            path_to_defines[path] += [k for k, v in msg_to_topmost.items() if v == msg and '$map' not in k]

    for path, file_obj in path_to_file.items():
        name, proto = descpb_to_proto(file_obj)
        # Build a "/** ... */" comment header in front of the rendered file.
        header_lines = ['/**', 'Messages defined in this file:\n']
        header_lines += path_to_defines[path]
        yield name, '\n * '.join(header_lines) + '\n */\n\n' + proto
| 167 | + |
def merge_and_rename(msg, referrer, msg_pkg, is_group,
    msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg):
    """
    Nest the definition `msg` inside `referrer` under a sanitized name.

    The name keeps only what follows the last '$', gets an uppercase
    first letter, and receives trailing underscores until it no longer
    conflicts with the referrer's members or with an import of the
    referrer's top-level container. A copy of the definition is then
    added to the referrer, the imports of `msg` are appended to those of
    the top-level container, and fix_naming() updates the state maps.

    Args:
        msg: original full path of the definition to nest.
        referrer: original full path of the message receiving it.
        msg_pkg: package of `msg`, without trailing dot (may be empty).
        is_group: whether the definition is a group; lowercase clashes
            are then rejected too.
        msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports,
        msg_path_to_obj, newloc_to_msg: shared state maps, mutated.
    """
    pkg_prefix = (msg_pkg + '.') if msg_pkg else msg_pkg

    child_obj = msg_path_to_obj[msg]
    parent_obj = msg_path_to_obj[referrer]
    top_path = msg_to_topmost.get(referrer, referrer)

    # Drop everything up to the last '$', then capitalize the first letter.
    short_name = child_obj.name.split('$')[-1]
    candidate = short_name[0].upper() + short_name[1:]

    # Names already used inside the referrer: its non-group fields, its
    # nested messages and its nested enums (snapshot taken before merging).
    taken = [member.name for member in parent_obj.field
             if member.type != member.TYPE_GROUP]
    taken += [member.name for member in parent_obj.nested_type]
    taken += [member.name for member in parent_obj.enum_type]

    def conflicts(name):
        # Direct clash with a member of the referrer.
        if name in taken:
            return True
        # For groups, the lowercase form must be free as well.
        if is_group and name.lower() in taken:
            return True
        # Clash with a still top-level definition imported by the
        # referrer's top-level container.
        qualified = pkg_prefix + name
        if qualified in msg_to_imports[top_path]:
            return qualified not in msg_to_topmost
        return False

    while conflicts(candidate):
        candidate += '_'
    child_obj.name = candidate

    # Copy the definition into the matching repeated field of the referrer.
    is_message = isinstance(child_obj, DescriptorProto)
    container = parent_obj.nested_type if is_message else parent_obj.enum_type
    nested = container.add()
    nested.MergeFrom(child_obj)

    # The new location is relative to wherever the referrer lives now.
    parent_location = msg_to_newloc.get(referrer, referrer)
    new_path = parent_location + '.' + nested.name

    # The top-level container inherits everything the nested one imported.
    msg_to_imports[top_path].extend(msg_to_imports[msg])

    # Rename references to the nested definition and to its children.
    fix_naming(nested, new_path, msg, top_path,
        msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg)
| 212 | + |
| 213 | +""" |
| 214 | + Recursively iterate over the children and references of a just |
| 215 | + merged nested message/group/enum, in order to make state variables |
| 216 | + coherent. |
| 217 | +""" |
def fix_naming(nested, new_path, prev_path, top_path,
    msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg):
    """
    Record that a definition now lives at `new_path` and repoint users.

    Updates the state maps for `nested`, rewrites the `type_name` of
    every field referring to it, registers `top_path` as an import of
    each referrer's top-level container, then recurses into the nested
    messages and enums of `nested`.

    Args:
        nested: the descriptor object at its new location.
        new_path: full dotted path the definition has now.
        prev_path: path it had before this call (original or previous
            new location).
        top_path: original path of the top-level message containing it.
        msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports,
        msg_path_to_obj, newloc_to_msg: shared state maps, mutated.
    """
    # All lookups are keyed by the original full name, so translate
    # prev_path back when it is itself a previously assigned location.
    original = newloc_to_msg.get(prev_path, prev_path)
    newloc_to_msg[new_path] = original

    # A top-level definition is not recorded as nested within itself.
    if original != top_path:
        msg_to_topmost[original] = top_path
    msg_to_newloc[original] = new_path
    msg_path_to_obj[original] = nested

    # Repoint every field that refers to this definition.
    for field_name, referrer, _ in msg_to_referrers.get(original, []):
        candidates = (f for f in msg_path_to_obj[referrer].field
                      if f.name == field_name)
        target = next(candidates, None)

        target.type_name = '.' + new_path

        # The file holding the referrer must import the file holding
        # this definition.
        importer = msg_to_topmost.get(referrer, referrer)
        msg_to_imports[importer].append(top_path)

    # Only messages have children to process.
    if not isinstance(nested, DescriptorProto):
        return

    children = list(nested.nested_type) + list(nested.enum_type)
    for child in children:
        child_new_path = new_path + '.' + child.name
        child_prev_path = prev_path + '.' + child.name
        fix_naming(child, child_new_path, child_prev_path, top_path,
            msg_to_referrers, msg_to_topmost, msg_to_newloc, msg_to_imports, msg_path_to_obj, newloc_to_msg)
| 248 | + |
def get_pkg(x):
    """
    Return the package part of the dotted path `x`.

    This is everything before the last '.', or an empty string when
    `x` contains no dot.
    """
    package, _, _ = x.rpartition('.')
    return package
0 commit comments