Skip to content

Commit

Permalink
Draft of dict input template.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Jan 2, 2024
1 parent a574b22 commit 18fa413
Showing 1 changed file with 179 additions and 20 deletions.
199 changes: 179 additions & 20 deletions kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ def template_cast_output(model_outputs,
output_tensor_type,
input_tensor_type,
cast_disjoint_kwargs):
"""
"""Template to cast graph, node or edge output to the desired tensor representation.
Args:
model_outputs:
output_embedding:
output_tensor_type:
input_tensor_type:
cast_disjoint_kwargs:
model_outputs (list): List of output and additional ID tensors. The list must always be
[model_output, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges]
but can have None in place of ID tensors if not required.
output_embedding (str): Embedding of the graph output, either "graph", "node" or "edge".
output_tensor_type (str): The tensor representation of model output.
input_tensor_type (str): The tensor representation of model input.
cast_disjoint_kwargs (dict): Kwargs for casting layers.
Returns:
Tensor: Keras output tensor.
Expand Down Expand Up @@ -81,7 +83,7 @@ def template_cast_output(model_outputs,
else:
out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
else:
raise NotImplementedError()
raise NotImplementedError("Unknown output embedding choice.")

return out

Expand Down Expand Up @@ -167,21 +169,27 @@ def template_cast_output(model_outputs,
"""


def template_cast_list_input(model_inputs,
input_tensor_type,
cast_disjoint_kwargs,
def template_cast_list_input(model_inputs: list,
input_tensor_type: str,
cast_disjoint_kwargs: dict,
mask_assignment: list = None,
index_assignment: list = None,
return_sub_id: bool = True):
"""
r"""Template to cast a list of model inputs to a list of disjoint tensors. The number of model inputs can be
variable. The system is preset and explained by :obj:`template_cast_list_input_docs` .
The ID information from the mask in tensor form is appended to the returned list.
Ragged tensors do not require a mask but use the mask information to generate joint ID tensors.
Args:
model_inputs:
input_tensor_type:
cast_disjoint_kwargs:
mask_assignment:
index_assignment:
return_sub_id:
model_inputs (list): List of Keras inputs.
input_tensor_type (str): Input tensor type. Either "padded", "ragged" or "disjoint".
cast_disjoint_kwargs (dict): Kwargs for casting layers.
mask_assignment (list): List that assigns Tensors to their mask.
Inputs that do not require a mask must be marked by Nones. Different inputs can use the same mask.
index_assignment (list): List that assigns index Tensors to the target Tensors to which the indices refer.
Inputs that are not indices must be marked by Nones.
return_sub_id (bool): Whether the returned list contains the sub-graph ID tensors.
Returns:
list: List of Keras Tensors for disjoint model.
Expand Down Expand Up @@ -210,7 +218,7 @@ def template_cast_list_input(model_inputs,
if index_assignment is None:
index_assignment = [None for _ in range(len(values_input))]
if len(index_assignment) != len(mask_assignment):
raise ValueError()
raise ValueError("Number of provided mask tensors does not match template specification.")

out_tensor = [None for _ in range(len(values_input))]
out_batch_id = [None for _ in range(num_mask)]
Expand All @@ -229,7 +237,7 @@ def template_cast_list_input(model_inputs,
o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([ref, x, ref_mask, x_mask])
out_tensor[i] = o_x
# Important to no overwrite indices with simple values here.
# Important to not overwrite indices with simple values here!
if out_tensor[i_ref] is None:
out_tensor[i_ref] = o_ref
out_batch_id[m] = b_x
Expand Down Expand Up @@ -281,7 +289,7 @@ def template_cast_list_input(model_inputs,
o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = CastRaggedIndicesToDisjoint(
**cast_disjoint_kwargs)([ref, x])
out_tensor[i] = o_x
# Important to no overwrite indices with simple values here.
# Important to not overwrite indices with simple values here! This is the case if indices reference indices.
if out_tensor[i_ref] is None:
out_tensor[i_ref] = o_ref
out_batch_id[m] = b_x
Expand Down Expand Up @@ -316,3 +324,154 @@ def template_cast_list_input(model_inputs,
out = out + out_totals

return out


def template_cast_dict_input(model_inputs: dict,
                             input_tensor_type: str,
                             cast_disjoint_kwargs: dict,
                             mask_assignment: dict = None,
                             index_assignment: dict = None,
                             return_sub_id: bool = True,
                             rename_mask_to_id: dict = None):
    r"""Template to cast a dictionary of model inputs to a dict of disjoint tensors.

    The number of model inputs can be variable. The system is rather flexible and explained by
    :obj:`template_cast_dict_input_docs` .
    The ID information derived from the masks is added to the returned dictionary in tensor form.
    Ragged tensors do not require a mask tensor, but the mask names in `mask_assignment` are still used
    to group and name the generated ID tensors.

    For every mask name ``m`` (after optional renaming via `rename_mask_to_id`) the returned dictionary
    contains the keys ``"graph_id_<m>"`` and ``"<m>_count"``, and additionally ``"<m>_id"`` if
    `return_sub_id` is set.

    Args:
        model_inputs (dict): Dictionary of Keras inputs.
        input_tensor_type (str): Input tensor type. Either "padded" (or "masked"), "ragged" (or "jagged").
            Any other value is treated as already disjoint and `model_inputs` is returned unchanged.
        cast_disjoint_kwargs (dict): Kwargs for casting layers.
        mask_assignment (dict): Dictionary of mask name for each input that requires a mask and is cast
            to a disjoint tensor representation. Inputs that are named as a mask here are not cast
            themselves. Required for padded and ragged input.
        index_assignment (dict): Dictionary assigning the name of index tensors to the name of the target
            tensors to which the indices refer. Both names must have an entry in `mask_assignment` .
        return_sub_id (bool): Whether the returned dict contains the sub-graph ID tensors.
        rename_mask_to_id (dict): Mapping of mask names to ID names. Mask names without an entry are
            used as they are.

    Raises:
        ValueError: If the input is padded or ragged and `mask_assignment` is missing or not a dict.

    Returns:
        dict: Model input tensors in disjoint representation.
    """
    is_padded = input_tensor_type in ("padded", "masked")
    is_ragged = input_tensor_type in ("ragged", "jagged")

    if not (is_padded or is_ragged):
        # Every other tensor type is treated as already disjoint and passed through untouched.
        return model_inputs

    if not isinstance(mask_assignment, dict):
        raise ValueError("Mask assignment information is required or invalid.")
    if index_assignment is None:
        index_assignment = {}

    # Names of all inputs that act as a mask. A mask can be shared by several inputs.
    mask_names = set(mask_assignment.values())
    values_input = {key: value for key, value in model_inputs.items() if key not in mask_names}
    # Only the padded representation reads the actual mask tensors.
    mask_input = {key: value for key, value in model_inputs.items() if key in mask_names} if is_padded else {}

    out_tensor = {}
    out_batch_id = {}
    out_graph_id = {}
    out_totals = {}

    # First pass: index tensors are cast together with the tensor they reference.
    for name_index, name_ref in index_assignment.items():
        ref = values_input[name_ref]
        x = values_input[name_index]
        m, m_ref = mask_assignment[name_index], mask_assignment[name_ref]
        if is_padded:
            ref_mask = mask_input[m_ref]
            x_mask = mask_input[m]
            cast_result = CastBatchedIndicesToDisjoint(**cast_disjoint_kwargs)([ref, x, ref_mask, x_mask])
        else:
            cast_result = CastRaggedIndicesToDisjoint(**cast_disjoint_kwargs)([ref, x])
        o_ref, o_x, b_r, b_x, g_r, g_x, t_r, t_x = cast_result

        out_tensor[name_index] = o_x
        # Important to not overwrite indices with simple values here! This is the case if indices reference indices.
        if name_ref not in out_tensor:
            out_tensor[name_ref] = o_ref
        out_batch_id[m] = b_x
        out_batch_id[m_ref] = b_r
        out_graph_id[m] = g_x
        out_graph_id[m_ref] = g_r
        out_totals[m] = t_x
        out_totals[m_ref] = t_r

    # Second pass: all remaining inputs that were not handled as index or index target above.
    for name_i, x in values_input.items():
        if name_i in out_tensor:
            continue

        if name_i not in mask_assignment:
            # Inputs without a mask: padded input goes through the graph-state cast layer,
            # ragged input is passed on as it is.
            if is_padded:
                out_tensor[name_i] = CastBatchedGraphStateToDisjoint(**cast_disjoint_kwargs)(x)
            else:
                out_tensor[name_i] = x
            continue

        m = mask_assignment[name_i]
        if is_padded:
            x_mask = mask_input[m]
            o_x, bi, gi, tot = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([x, x_mask])
        else:
            o_x, bi, gi, tot = CastRaggedAttributesToDisjoint(**cast_disjoint_kwargs)(x)
        out_tensor[name_i] = o_x
        out_batch_id[m] = bi
        out_graph_id[m] = gi
        out_totals[m] = tot

    # Rename IDs. Mask names without an explicit mapping keep their own name.
    id_names = {} if rename_mask_to_id is None else rename_mask_to_id

    def map_mask_key(k):
        return id_names.get(k, k)

    out = {}
    out.update(out_tensor)
    out.update({f"graph_id_{map_mask_key(key)}": value for key, value in out_batch_id.items()})
    if return_sub_id:
        out.update({f"{map_mask_key(key)}_id": value for key, value in out_graph_id.items()})
    out.update({f"{map_mask_key(key)}_count": value for key, value in out_totals.items()})

    return out

0 comments on commit 18fa413

Please sign in to comment.