TensorFlow Keras Participant Example

This is an example of a TensorFlow Keras implementation of a Participant for federated learning.

We cover the requirements of the Participant abstract base class, give ideas on how to handle a TF Keras model and TF Keras data in the Participant, and show how to implement a federated learning TF Keras training round. You can find the complete source code here. The example code uses type hints to be precise about the expected data types.

Participant Abstract Base Class

The SDK provides an abstract base class for Participants which can be imported as

from xain_sdk.participant import Participant as ABCParticipant

A custom Participant should inherit from the abstract base class, like

class Participant(ABCParticipant):

and must implement the init_weights() and train_round() methods to execute a round of federated learning, where each round consists of a certain number of epochs. These methods adhere to the function signatures

init_weights(self) -> ndarray
train_round(self, weights: ndarray, epochs: int, epoch_base: int) -> Tuple[ndarray, int]

The expected arguments are:

  • weights (ndarray): A Numpy array containing the flattened weights of the global model.

  • epochs (int): The number of epochs to be trained during the federated learning round. Can be any non-negative integer, including zero (in which case no local training takes place).

  • epoch_base (int): A global training epoch number (e.g. for epoch-dependent learning rate schedules and metrics logging).

The expected return values are:

  • ndarray: The flattened weights of the local model, resulting either from initialization or from training the global model for the given number of epochs on local data.

  • int: The number of samples in the training dataset, used by the aggregation strategies.
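To make the contract concrete, here is a minimal sketch of the two required methods. It does not import xain_sdk; instead, SketchParticipant is a hypothetical stand-in for the SDK's abstract base class built with Python's abc module, and ConstantParticipant is a hypothetical toy subclass whose "model" is just four weights that are never trained.

```python
from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np
from numpy import ndarray


class SketchParticipant(ABC):
    """Hypothetical stand-in for xain_sdk's Participant base class (not the SDK's code)."""

    @abstractmethod
    def init_weights(self) -> ndarray:
        """Return freshly initialized, flattened model weights."""

    @abstractmethod
    def train_round(self, weights: ndarray, epochs: int, epoch_base: int) -> Tuple[ndarray, int]:
        """Train on local data; return (flattened weights, number of training samples)."""


class ConstantParticipant(SketchParticipant):
    """Hypothetical toy participant: a 'model' of 4 weights that is never trained."""

    def __init__(self, train_samples: int = 8) -> None:
        self.train_samples = train_samples

    def init_weights(self) -> ndarray:
        return np.zeros(4, dtype=np.float32)

    def train_round(self, weights: ndarray, epochs: int, epoch_base: int) -> Tuple[ndarray, int]:
        # No real training: hand the global weights back unchanged, plus the sample count.
        return weights.copy(), self.train_samples


participant = ConstantParticipant()
new_weights, samples = participant.train_round(participant.init_weights(), epochs=0, epoch_base=0)
```

Because both methods are abstract, forgetting to implement either one makes the subclass impossible to instantiate, which surfaces the mistake early.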

The Participant’s base class provides utility methods to set the weights of the local model according to the given flat weights vector, by

set_tensorflow_weights(weights: ndarray, shapes: List[Tuple[int, ...]], model: Model) -> None

and to get a flattened weights vector from the local model, by

get_tensorflow_weights(model: Model) -> ndarray

as well as the original shapes of the weights of the local model, by

get_tensorflow_shapes(model: Model) -> List[Tuple[int, ...]]
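The exact implementation of these SDK helpers is not shown here. As a rough illustration, assuming they operate on the list of NumPy arrays that a Keras model's get_weights() returns (and that set_weights() accepts), flattening and un-flattening could look like the following sketch; flatten_weights, get_shapes and unflatten_weights are hypothetical names, not SDK functions.

```python
from typing import List, Tuple

import numpy as np
from numpy import ndarray


def flatten_weights(weights: List[ndarray]) -> ndarray:
    """Concatenate a list of weight arrays into one flat vector (hypothetical helper)."""
    return np.concatenate([w.ravel() for w in weights])


def get_shapes(weights: List[ndarray]) -> List[Tuple[int, ...]]:
    """Record the original shape of every weight array (hypothetical helper)."""
    return [w.shape for w in weights]


def unflatten_weights(flat: ndarray, shapes: List[Tuple[int, ...]]) -> List[ndarray]:
    """Split a flat vector back into arrays with the given shapes (hypothetical helper)."""
    sizes = [int(np.prod(shape)) for shape in shapes]
    if flat.size != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} values, got {flat.size}")
    # Split points are the running totals of the sizes, excluding the last one.
    chunks = np.split(flat, np.cumsum(sizes)[:-1])
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]


kernel = np.arange(6, dtype=np.float32).reshape(2, 3)
bias = np.array([10.0, 20.0, 30.0], dtype=np.float32)
flat = flatten_weights([kernel, bias])
restored = unflatten_weights(flat, get_shapes([kernel, bias]))
```

The shapes are needed because a flat vector alone does not say where one array ends and the next begins, which is why the weight-setting helper takes them as an argument.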

Also, metrics of the current training epoch can be sent to a time-series database via the coordinator by

update_metrics(epoch, epoch_base, MetricName=metric_value, ...)

for any number of metrics.
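The "any number of metrics" part relies on Python keyword arguments. The following sketch mimics only the call shape with a hypothetical MetricsRecorder class; it does not talk to a coordinator, and combining the two epoch arguments as epoch_base + epoch is an assumption made for illustration, not documented SDK behavior.

```python
from typing import Dict, List


class MetricsRecorder:
    """Hypothetical stand-in that mimics the call shape of update_metrics()."""

    def __init__(self) -> None:
        self.records: List[Dict[str, float]] = []

    def update_metrics(self, epoch: int, epoch_base: int, **metrics: float) -> None:
        # Every keyword argument becomes a named metric. The global epoch is
        # ASSUMED to be epoch_base + epoch here (not confirmed by the SDK docs).
        record = {"global_epoch": float(epoch_base + epoch)}
        record.update({name: float(value) for name, value in metrics.items()})
        self.records.append(record)


recorder = MetricsRecorder()
recorder.update_metrics(0, 10, Loss=0.7, Accuracy=0.5)
recorder.update_metrics(1, 10, Loss=0.6, Accuracy=0.55, F1=0.5)
```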

TF Keras Model

A TF Keras model definition might either be loaded from a file, generated during the initialization of the Participant, or even generated on the fly. Here, we present a simple dense neural network for classification generated during the Participant’s initialization, which is wrapped in the init_model() helper function.

The following attributes only serve to make the model configurable

self.features: int
self.units: int
self.categories: int

The example model consists of an input layer holding self.features values per sample, as

input_layer: Tensor = Input(shape=(self.features,), dtype="float32")

Next, it has a fully connected hidden layer with self.units ReLU-activated units, as

hidden_layer: Tensor = Dense(units=self.units, activation="relu")(input_layer)

Finally, it has a fully connected output layer with self.categories softmax-activated units, as

output_layer: Tensor = Dense(units=self.categories, activation="softmax")(hidden_layer)

The model gets compiled with an Adam optimizer, the categorical crossentropy loss function and the categorical accuracy metric, like

self.model: Model = Model(inputs=[input_layer], outputs=[output_layer])
self.model.compile(optimizer="Adam", loss="categorical_crossentropy", metrics=["categorical_accuracy"])
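To show what this network computes without requiring TensorFlow, here is a NumPy sketch of the same forward pass (input, ReLU hidden layer, softmax output) together with the categorical crossentropy loss and categorical accuracy metric. It is written in the spirit of the Keras definitions, not as a copy of them; the sizes and the random weights are illustrative assumptions.

```python
import numpy as np
from numpy import ndarray


def relu(x: ndarray) -> ndarray:
    return np.maximum(x, 0.0)


def softmax(logits: ndarray) -> ndarray:
    # Subtract each row's maximum before exponentiating for numerical stability.
    z = np.exp(logits - logits.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def forward(x: ndarray, kernel1: ndarray, bias1: ndarray, kernel2: ndarray, bias2: ndarray) -> ndarray:
    hidden = relu(x @ kernel1 + bias1)        # shape (batch, units)
    return softmax(hidden @ kernel2 + bias2)  # shape (batch, categories)


def categorical_crossentropy(y_true: ndarray, probs: ndarray, eps: float = 1e-7) -> float:
    # Mean over samples of -sum_k y_k * log(p_k); clipping avoids log(0).
    return float(np.mean(-np.sum(y_true * np.log(np.clip(probs, eps, 1.0)), axis=1)))


def categorical_accuracy(y_true: ndarray, probs: ndarray) -> float:
    # Fraction of samples whose most probable class matches the one-hot label.
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(probs, axis=1)))


features, units, categories = 10, 6, 3  # illustrative sizes only
rng = np.random.default_rng(0)
x = np.ones((4, features), dtype=np.float32)
probs = forward(
    x,
    rng.normal(size=(features, units)), np.zeros(units),
    rng.normal(size=(units, categories)), np.zeros(categories),
)
```

A uniform prediction over three categories yields a crossentropy of log(3), roughly 1.0986, which is a handy sanity check for an untrained model.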

The utility method for setting the model weights requires the original shapes of the weights, obtainable as

self.model_shapes: List[Tuple[int, ...]] = self.get_tensorflow_shapes(model=self.model)
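In Keras, each Dense layer holds a kernel of shape (inputs, units) followed by a bias of shape (units,), and the model lists them layer by layer. Assuming get_tensorflow_shapes() returns shapes in that order, the sketch below derives the expected shapes for this model and the length of the flat weights vector; dense_model_shapes and flat_size are hypothetical helpers, and the sizes are illustrative.

```python
import math
from typing import List, Tuple


def dense_model_shapes(features: int, units: int, categories: int) -> List[Tuple[int, ...]]:
    """Expected weight shapes of Input -> Dense(units) -> Dense(categories) (hypothetical helper)."""
    return [(features, units), (units,), (units, categories), (categories,)]


def flat_size(shapes: List[Tuple[int, ...]]) -> int:
    # Each array contributes the product of its dimensions to the flat vector.
    return sum(math.prod(shape) for shape in shapes)


shapes = dense_model_shapes(features=10, units=6, categories=3)
total = flat_size(shapes)  # 10*6 + 6 + 6*3 + 3
```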

TF Keras Data

The data on which the model is trained can be loaded from a data source (e.g. file, bucket, database) either during the initialization of the Participant or on the fly in train_round(). Here, we employ constant placeholder data as an example (all features are ones, and the one-hot labels cycle through the categories), which is wrapped in the init_datasets() helper function. This is by no means a meaningful dataset, but it should be sufficient to convey the overall idea.

The following attributes only serve to make the dataset configurable

self.train_samples: int
self.val_samples: int
self.test_samples: int
self.batch_size: int

The training dataset is shuffled and batched, like

self.trainset: Dataset = Dataset.from_tensor_slices(
    tensors=(
        np.ones(shape=(self.train_samples, self.features), dtype=np.float32),
        np.tile(np.eye(self.categories, dtype=np.float32), reps=(int(np.ceil(self.train_samples / self.categories)), 1))[0 : self.train_samples, :],
    )
).shuffle(buffer_size=self.train_samples).batch(batch_size=self.batch_size)  # buffer size chosen to shuffle the whole set
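The label construction is the least obvious part: the identity matrix is tiled often enough to cover every sample and then cut to size, so the one-hot labels cycle through the categories. The following NumPy sketch reproduces just that step with illustrative sizes.

```python
import numpy as np

train_samples, features, categories = 7, 4, 3  # illustrative sizes only

x = np.ones(shape=(train_samples, features), dtype=np.float32)
# Repeat the identity matrix ceil(7 / 3) = 3 times (9 rows), then keep the first 7 rows,
# so the labels cycle through the categories 0, 1, 2, 0, 1, 2, 0.
reps = int(np.ceil(train_samples / categories))
y = np.tile(np.eye(categories, dtype=np.float32), reps=(reps, 1))[0:train_samples, :]
```

Since every feature vector is identical, no model can do better than predicting the most frequent class on this data, which is fine for a placeholder whose only job is to exercise the training loop.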

while the datasets for validation and testing are only batched, like

self.valset: Dataset = Dataset.from_tensor_slices(
    tensors=(
        np.ones(shape=(self.val_samples, self.features), dtype=np.float32),
        np.tile(np.eye(self.categories, dtype=np.float32), reps=(int(np.ceil(self.val_samples / self.categories)), 1))[0 : self.val_samples, :],
    )
).batch(batch_size=self.batch_size)
self.testset: Dataset = Dataset.from_tensor_slices(
    tensors=(
        np.ones(shape=(self.test_samples, self.features), dtype=np.float32),
        np.tile(np.eye(self.categories, dtype=np.float32), reps=(int(np.ceil(self.test_samples / self.categories)), 1))[0 : self.test_samples, :],
    )
).batch(batch_size=self.batch_size)
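Batching determines how many optimizer steps one epoch takes. The sketch below is a plain NumPy illustration (batch_indices is a hypothetical helper) of what batching with tf.data's default drop_remainder=False amounts to: consecutive batches, with a smaller final batch when the sample count is not a multiple of the batch size. Passing an rng first permutes the indices, which corresponds to a shuffle whose buffer holds the whole dataset.

```python
import math
from typing import List, Optional

import numpy as np


def batch_indices(num_samples: int, batch_size: int, rng: Optional[np.random.Generator] = None) -> List[np.ndarray]:
    """Split sample indices into consecutive batches; the last batch may be smaller."""
    # Without an rng the order is kept (validation/test); with one it is permuted (training).
    indices = np.arange(num_samples) if rng is None else rng.permutation(num_samples)
    return [indices[start:start + batch_size] for start in range(0, num_samples, batch_size)]


batches = batch_indices(num_samples=10, batch_size=4)
shuffled = batch_indices(num_samples=10, batch_size=4, rng=np.random.default_rng(0))
steps_per_epoch = math.ceil(10 / 4)
```

Note that a tf.data shuffle buffer smaller than the dataset only shuffles locally, so it is not equivalent to a full permutation.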

TF Keras Training Round

Whenever the coordinator needs freshly initialized model weights, e.g. in the 0-th round of training, the init_weights() method is called, which consists of two main steps. First, new model weights are initialized according to the model definition; then, these weights are returned without further training, as

return self.get_tensorflow_weights(model=self.model)
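By default, a Keras Dense layer initializes its kernel with Glorot uniform values and its bias with zeros. As a NumPy-only illustration of what freshly initialized flat weights look like for this model, the sketch below applies the same scheme to a list of shapes; init_flat_weights is a hypothetical helper, and treating every 2-D shape as a kernel and every 1-D shape as a bias is an assumption that holds only for this dense-only model.

```python
from typing import List, Tuple

import numpy as np
from numpy import ndarray


def init_flat_weights(shapes: List[Tuple[int, ...]], seed: int = 0) -> ndarray:
    """Glorot-uniform kernels and zero biases, flattened into one vector (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    parts = []
    for shape in shapes:
        if len(shape) == 2:
            # Glorot uniform: U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
            fan_in, fan_out = shape
            limit = np.sqrt(6.0 / (fan_in + fan_out))
            parts.append(rng.uniform(-limit, limit, size=shape).ravel())
        else:
            # 1-D arrays are treated as biases and start at zero.
            parts.append(np.zeros(shape).ravel())
    return np.concatenate(parts).astype(np.float32)


shapes = [(10, 6), (6,), (6, 3), (3,)]
flat = init_flat_weights(shapes)
```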

The implementation of the actual train_round() method consists of three main steps. First, the provided weights of the global model are loaded into the local model, as

self.set_tensorflow_weights(weights=weights, shapes=self.model_shapes, model=self.model)

Next, the local model is trained on the local data for the given number of epochs, and the validation metrics are gathered after each epoch, as

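for epoch in range(epochs):
    # One pass over the already shuffled and batched training set (shuffle is ignored for tf.data inputs).
    self.model.fit(x=self.trainset, verbose=2, shuffle=False)
    # evaluate() returns scalars [loss, categorical_accuracy], in the order given to compile().
    metrics: List[float] = self.model.evaluate(x=self.valset, verbose=0)
    self.update_metrics(epoch, epoch_base, Loss=metrics[0], Accuracy=metrics[1])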
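The control flow of this loop can be checked without TensorFlow. In the sketch below, StubModel is a hypothetical stand-in for the compiled Keras model whose fit() only counts calls and whose evaluate() returns made-up scalars in [loss, accuracy] order; run_epochs mirrors the loop by collecting what would be passed to update_metrics(). It shows that each epoch triggers exactly one fit and one evaluation, and that epochs=0 trains and logs nothing.

```python
from typing import Dict, List, Tuple


class StubModel:
    """Hypothetical stand-in for a compiled Keras model; fit() only counts calls."""

    def __init__(self) -> None:
        self.fit_calls = 0

    def fit(self) -> None:
        self.fit_calls += 1

    def evaluate(self) -> List[float]:
        # Made-up metrics: loss shrinks and accuracy grows with every fit() call.
        return [1.0 / (1 + self.fit_calls), 1.0 - 1.0 / (1 + self.fit_calls)]


def run_epochs(model: StubModel, epochs: int, epoch_base: int) -> List[Tuple[int, int, Dict[str, float]]]:
    """Mirror the loop: one fit() and one evaluate() per epoch, metrics collected each time."""
    logged = []
    for epoch in range(epochs):
        model.fit()
        loss, accuracy = model.evaluate()
        logged.append((epoch, epoch_base, {"Loss": loss, "Accuracy": accuracy}))
    return logged


log = run_epochs(StubModel(), epochs=3, epoch_base=6)
```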

Finally, the updated weights of the local model and the number of samples of the train dataset are returned, as

return self.get_tensorflow_weights(model=self.model), self.train_samples
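The sample count returned alongside the weights is what lets the coordinator weight each participant's contribution. Which aggregation strategy the coordinator actually applies is not specified here; assuming a FedAvg-style weighted mean, the aggregation step could look like this sketch, where weighted_average is a hypothetical helper.

```python
from typing import List, Tuple

import numpy as np
from numpy import ndarray


def weighted_average(updates: List[Tuple[ndarray, int]]) -> ndarray:
    """Average flat weight vectors, weighting each by its number of training samples."""
    total = sum(samples for _, samples in updates)
    if total == 0:
        raise ValueError("at least one participant must report training samples")
    return sum(weights * (samples / total) for weights, samples in updates)


updates = [
    (np.array([1.0, 1.0]), 100),  # participant A: 100 samples, weight 0.25
    (np.array([3.0, 5.0]), 300),  # participant B: 300 samples, weight 0.75
]
global_weights = weighted_average(updates)  # -> [2.5, 4.0]
```

Weighting by sample count keeps a participant with very little data from pulling the global model as strongly as one with much more data.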