Skip to content

Commit

Permalink
feat(secure aggregation): added init tensorkeys method for the runner
Browse files Browse the repository at this point in the history
Signed-off-by: Pant, Akshay <[email protected]>
  • Loading branch information
theakshaypant committed Jan 17, 2025
1 parent fc28386 commit fa5d8d3
Showing 1 changed file with 49 additions and 6 deletions.
55 changes: 49 additions & 6 deletions openfl/federated/task/runner_sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np

from openfl.federated.task.runner import TaskRunner
from openfl.utilities import TensorKey
from openfl.utilities.secagg import (
create_ciphertext,
Expand All @@ -18,7 +19,45 @@
)


class SATaskRunner:
class SATaskRunner(TaskRunner):
def __init__(self, device: str = None, loss_fn=None, optimizer=None, **kwargs):
super().__init__(self, **kwargs)
self.required_tensorkeys_for_function = {}

def initialize_tensorkeys_for_functions(self, with_opt_vars=False):
"""
Initialize the required TensorKeys for various functions.
This method sets up the required TensorKeys for the functions
"generate_keys", "generate_ciphertexts", and "decrypt_ciphertexts".
The TensorKeys specify the name, scope, version, and other attributes
for the tensors used in these functions.
Args:
with_opt_vars (bool): If True, includes optional variables in the
TensorKeys. Defaults to False.
TensorKeys:
- "generate_keys": No TensorKeys required.
- "generate_ciphertexts":
- public_key: A global tensor key for the public key.
- public_key_local: A local tensor key for the public key.
- "decrypt_ciphertexts":
- ciphertext: A global tensor key for the ciphertext.
- ciphertext_local: A local tensor key for the ciphertext.
- index: A local tensor key for the index.
"""
self.required_tensorkeys_for_function["generate_keys"] = []
self.required_tensorkeys_for_function["generate_ciphertexts"] = [
TensorKey("public_key", "GLOBAL", 1, False, ("public_key")),
TensorKey("public_key_local", "LOCAL", 1, False, ("public_key")),
]
self.required_tensorkeys_for_function["decrypt_ciphertexts"] = [
TensorKey("ciphertext", "GLOBAL", 1, False, ("ciphertext")),
TensorKey("ciphertext_local", "LOCAL", 1, False, ("ciphertext_local")),
TensorKey("index", "LOCAL", 1, False, ())
]

def generate_keys(
self,
col_name,
Expand Down Expand Up @@ -51,11 +90,15 @@ def generate_keys(
private_key1,
private_key2,
],
TensorKey("public_key_local", col_name, round_number, False, ("public_key")): [
public_key1,
public_key2,
],
TensorKey("private_seed", col_name, round_number, False, ()): [random.random()],
}

global_tensor_dict = {
TensorKey("public_key_local", col_name, round_number, False, ("public_key")): [
TensorKey("public_key", col_name, round_number, False, ("public_key")): [
public_key1,
public_key2,
]
Expand All @@ -80,7 +123,7 @@ def generate_ciphertexts(
Required tensors for the task include:
- GLOBAL public_key
- public_key_local
- LOCAL public_key_local
Args:
col_name (str): The column name for the tensor key.
Expand Down Expand Up @@ -176,10 +219,10 @@ def decrypt_ciphertexts(
Required tensors for the task include:
- GLOBAL ciphertext.
- index
- ciphertext_local
- LOCAL ciphertext_local
- LOCAL index
Args:
Args:
col_name (str): The name of the column.
round_number (int): The current round number.
input_tensor_dict (dict): A dictionary containing the required
Expand Down

0 comments on commit fa5d8d3

Please sign in to comment.