From 18fa4133c2e3c37cf74f4befd803df5fd15e5bba Mon Sep 17 00:00:00 2001 From: PatReis Date: Tue, 2 Jan 2024 17:18:41 +0100 Subject: [PATCH] Draft of dict input template. --- kgcnn/models/casting.py | 199 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 179 insertions(+), 20 deletions(-) diff --git a/kgcnn/models/casting.py b/kgcnn/models/casting.py index dc6816cd..4cf588b6 100644 --- a/kgcnn/models/casting.py +++ b/kgcnn/models/casting.py @@ -36,14 +36,16 @@ def template_cast_output(model_outputs, output_tensor_type, input_tensor_type, cast_disjoint_kwargs): - """ + """Template to cast graph, node or edge output to the desired tensor representation. Args: - model_outputs: - output_embedding: - output_tensor_type: - input_tensor_type: - cast_disjoint_kwargs: + model_outputs (list): List of output and additional ID tensors. The list must always be + [model_output, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges] + but can have None in place of ID tensors if not required. + output_embedding (str): Embedding of the graph output, either "graph", "node" or "edge". + output_tensor_type (str): The tensor representation of model output. + input_tensor_type (str): The tensor representation of model input. + cast_disjoint_kwargs (dict): Kwargs for casting layers. Returns: Tensor: Keras output tensor. @@ -81,7 +83,7 @@ def template_cast_output(model_outputs, else: out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out) else: - raise NotImplementedError() + raise NotImplementedError("Unknown output embedding choice.") return out @@ -167,21 +169,27 @@ def template_cast_output(model_outputs, """ -def template_cast_list_input(model_inputs, - input_tensor_type, - cast_disjoint_kwargs, +def template_cast_list_input(model_inputs: list, + input_tensor_type: str, + cast_disjoint_kwargs: dict, mask_assignment: list = None, index_assignment: list = None, return_sub_id: bool = True): - """ + r"""Template to cast a list of model inputs to a list of disjoint tensors. The number of model inputs can be + variable. The system is preset and explained by :obj:`template_cast_list_input_docs` . + + The ID information from the mask in tensor form is appended to the returned list. + Ragged Tensor do not require a mask but use the mask information to generate joint ID tensors. Args: - model_inputs: - input_tensor_type: - cast_disjoint_kwargs: - mask_assignment: - index_assignment: - return_sub_id: + model_inputs (list): List of Keras inputs. + input_tensor_type (str): Input tensor type. Either "padded", "ragged" or "disjoint". + cast_disjoint_kwargs (dict): Kwargs for casting layers. + mask_assignment (list): List that assigns Tensors to their mask. + Inputs that do not require a mask must be marked by Nones. Different inputs can use the same mask. + index_assignment (list): List that assigns index Tensors to their target to which the index refer to. + Inputs that are not indices must be marked by Nones. + return_sub_id (bool): Whether the returned list contains the sub-graph ID tensors. Returns: list: List of Keras Tensors for disjoint model. @@ -210,7 +218,7 @@ def template_cast_list_input(model_inputs, if index_assignment is None: index_assignment = [None for _ in range(len(values_input))] if len(index_assignment) != len(mask_assignment): - raise ValueError() + raise ValueError("Number of provided mask tensors does not match template specification.") out_tensor = [None for _ in range(len(values_input))] out_batch_id = [None for _ in range(num_mask)] @@ -229,7 +237,7 @@ def template_cast_list_input(model_inputs, o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastBatchedIndicesToDisjoint( **cast_disjoint_kwargs)([ref, x, ref_mask, x_mask]) out_tensor[i] = o_x - # Important to no overwrite indices with simple values here. + # Important to not overwrite indices with simple values here! if out_tensor[i_ref] is None: out_tensor[i_ref] = o_ref out_batch_id[m] = b_x @@ -281,7 +289,7 @@ def template_cast_list_input(model_inputs, o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastRaggedIndicesToDisjoint( **cast_disjoint_kwargs)([ref, x]) out_tensor[i] = o_x - # Important to no overwrite indices with simple values here. + # Important to no overwrite indices with simple values here! This is the case if indices reference indices. if out_tensor[i_ref] is None: out_tensor[i_ref] = o_ref out_batch_id[m] = b_x @@ -316,3 +324,154 @@ def template_cast_list_input(model_inputs, out = out + out_totals return out + + +def template_cast_dict_input(model_inputs: dict, + input_tensor_type: str, + cast_disjoint_kwargs: dict, + mask_assignment: dict = None, + index_assignment: dict = None, + return_sub_id: bool = True, + rename_mask_to_id: dict = None): + """Template to cast a dictionary of model inputs to a dict of disjoint tensors. The number of model inputs can be + variable. The system is rather flexible and explained by :obj:`template_cast_dict_input_docs` . + + The ID information from the mask in tensor form is appended to the returned dictionary. + Ragged Tensor do not require a mask but use the mask information to generate joint ID tensors. + + Args: + model_inputs (dict): Dictionary of Keras inputs. + input_tensor_type (str): Input tensor type. Either "padded", "ragged" or "disjoint". + cast_disjoint_kwargs (dict): Kwargs for casting layers. + mask_assignment (dict): Dictionary of mask name for each input that requires a mask and is cast + to a disjoint tensor representation. + index_assignment (dict): Dictionary of assigning indices to the name of their target tensors to which + the indices refer to. + return_sub_id (bool): Whether the returned dict contains the sub-graph ID tensors. + rename_mask_to_id (dict): Mapping of mask names to ID names. + + Returns: + dict: Model input tensors in disjoint representation. + """ + + is_already_disjoint = False + out_tensor = {} + out_batch_id = {} + out_graph_id = {} + out_totals = {} + + if input_tensor_type in ["padded", "masked"]: + if mask_assignment is None or not isinstance(mask_assignment, dict): + raise ValueError("Mask assignment information is required or invalid.") + + reduced_mask = list(set(list(mask_assignment.values()))) + + values_input = {key: value for key, value in model_inputs.items() if key not in reduced_mask} + mask_input = {key: value for key, value in model_inputs.items() if key in reduced_mask} + + if index_assignment is None: + index_assignment = {} + + for name_index, name_ref in index_assignment.items(): + ref = values_input[name_ref] + x = values_input[name_index] + m, m_ref = mask_assignment[name_index], mask_assignment[name_ref] + ref_mask = mask_input[m_ref] + x_mask = mask_input[m] + o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastBatchedIndicesToDisjoint( + **cast_disjoint_kwargs)([ref, x, ref_mask, x_mask]) + out_tensor[name_index] = o_x + # Important to not overwrite indices with simple values here! This is the case if indices reference indices. + if name_ref not in out_tensor.keys(): + out_tensor[name_ref] = o_ref + out_batch_id[m] = b_x + out_batch_id[m_ref] = b_r + out_graph_id[m] = g_x + out_graph_id[m_ref] = g_r + out_totals[m] = t_x + out_totals[m_ref] = t_r + + for name_i, x in values_input.items(): + + if name_i in out_tensor.keys(): + continue + + if name_i not in mask_assignment.keys(): + out_tensor[name_i] = CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(x) + continue + + m = mask_assignment[name_i] + x_mask = mask_input[m] + o_x, bi, gi, tot = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, x_mask]) + out_tensor[name_i] = o_x + out_batch_id[m] = bi + out_graph_id[m] = gi + out_totals[m] = tot + + elif input_tensor_type in ["ragged", "jagged"]: + if mask_assignment is None or not isinstance(mask_assignment, dict): + raise ValueError("Mask assignment information is required or invalid.") + + if index_assignment is None: + index_assignment = {} + + reduced_mask = list(set(list(mask_assignment.values()))) + + values_input = {key: value for key, value in model_inputs.items() if key not in reduced_mask} + + for name_index, name_ref in index_assignment.items(): + ref = values_input[name_ref] + x = values_input[name_index] + m, m_ref = mask_assignment[name_index], mask_assignment[name_ref] + o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastRaggedIndicesToDisjoint( + **cast_disjoint_kwargs)([ref, x]) + out_tensor[name_index] = o_x + # Important to no overwrite indices with simple values here! This is the case if indices reference indices. + if name_ref not in out_tensor.keys(): + out_tensor[name_ref] = o_ref + out_batch_id[m] = b_x + out_batch_id[m_ref] = b_r + out_graph_id[m] = g_x + out_graph_id[m_ref] = g_r + out_totals[m] = t_x + out_totals[m_ref] = t_r + + for name_i, x in values_input.items(): + if name_i in out_tensor.keys(): + continue + + if name_i not in mask_assignment.keys(): + out_tensor[name_i] = x + continue + + m = mask_assignment[name_i] + o_x, bi, gi, tot = CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x) + out_tensor[name_i] = o_x + out_batch_id[m] = bi + out_graph_id[m] = gi + out_totals[m] = tot + + else: + is_already_disjoint = True + + # Rename IDs. + def map_mask_key(k): + mapping = {} if rename_mask_to_id is None else rename_mask_to_id + if k in mapping.keys(): + return mapping[k] + return k + out_batch_id = {"graph_id_%s" % map_mask_key(key): value for key, value in out_batch_id.items()} + out_graph_id = {"%s_id" % map_mask_key(key): value for key, value in out_graph_id.items()} + out_totals = {"%s_count" % map_mask_key(key): value for key, value in out_totals.items()} + + if is_already_disjoint: + out = model_inputs + else: + out = {} + out.update(out_tensor) + out.update(out_batch_id) + if return_sub_id: + out.update(out_graph_id) + out.update(out_totals) + + return out