From 4df3c8565af1e7981f47328148a2a0a4443c5807 Mon Sep 17 00:00:00 2001 From: Jan Blunk Date: Mon, 28 Aug 2023 16:32:00 +0200 Subject: [PATCH] example notebook and network with added documentation --- README.md | 42 +++- dataset_utils.py | 37 +++ feature_steering_example.ipynb | 222 ++++++++++++++++++ ...miI_estimator.py => mixed_cmi_estimator.py | 0 regression_dataset.py | 4 +- regression_network.py | 176 ++++++++++++++ requirements.txt | 8 + teaser.png | Bin 0 -> 34850 bytes 8 files changed, 483 insertions(+), 6 deletions(-) create mode 100644 dataset_utils.py create mode 100644 feature_steering_example.ipynb rename mixed_cmiI_estimator.py => mixed_cmi_estimator.py (100%) create mode 100644 regression_network.py create mode 100644 requirements.txt create mode 100644 teaser.png diff --git a/README.md b/README.md index 8ec73e9..ff30c75 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,15 @@ # Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization -This repository provides code to use the method presented in our GCPR 2023 paper "Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization". If you use this method, please cite: +## Overview +This repository provides code to use the method presented in our GCPR 2023 paper **"Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization"**. If you want to get started, take a look at our [example network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and the corresponding [jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb). + +
+By measuring the feature usage, we can steer the model towards (not) using features that are specifically (un-)desired. +
+ +Our method generalizes from debiasing to the **encouragement and discouragement of arbitrary features**. That is, it not only aims at removing the influence of undesired features / biases but also at increasing the influence of features that are known to be well-established from domain knowledge. + +If you use our method, please cite: @inproceedings{Blunk23:FS, author = {Jan Blunk and Niklas Penzel and Paul Bodesheim and Joachim Denzler}, @@ -9,8 +18,33 @@ This repository provides code to use the method presented in our GCPR 2023 paper year = {2023}, } -This repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234) The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022). CMIh was published under the MIT license. - ## Installation -First, you have to install ACD as described in the [Repository for "Hierarchical interpretations for neural network predictions" by Singh et al.](https://github.com/csinva/hierarchical-dnn-interpretations) +**Install with pip, Python and PyTorch 2.0+** + + git clone https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing.git + cd beyond-debiasing + pip install -r requirements.txt + +First, create an environment with pip and Python first (Anaconda environment / Python virtual environment). We recommend to install [PyTorch with CUDA support](https://pytorch.org/get-started/locally/). Then, you can install all subsequent packages via pip as described above. + +## Usage in Python +Since our method relies on loss regularization, it is very simple to add to your own networks - you only need to modify your loss function. To help with that, we provide an [exemplary network](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/regression_network.py) and a [jupyter notebook](https://git.inf-cv.uni-jena.de/blunk/beyond-debiasing/src/main/feature_steering_example.ipynb) with example code. + +## Repository Organization +* Installation: + * [`requirements.txt`](requirements.txt): List of required packages for installation with pip +* Feature attribution: + * [`contextual_decomposition.py`](contextual_decomposition.py): Wrapper for contextual decomposition + * [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py): Python port of the CMIh estimator of the conditional +* Redundant regression dataset: + * [`algebra.py`](algebra.py): Generation of random orthogonal matrices + * [`make_regression.py`](make_regression.py): An adapted version of scikit-learns make_regression(...), where the coefficients are standard-uniform + * [`regression_dataset.py`](regression_dataset.py): Generation of the redundant regression dataset + * [`dataset_utils.py`](dataset_utils.py): Creation of torch dataset from numpy arrays + * [`tensor_utils.py`](tensor_utils.py): Some helpful functions for dealing with tensors +* Example: + * [`feature_steering_example.ipynb`](feature_steering_example.ipynb): Example for generating the dataset, creating and training the network with detailed comments + * [`regression_network.py`](regression_network.py): Neural network (PyTorch) used in the example notebook + +With [`mixed_cmi_estimator.py`](mixed_cmi_estimator.py) this repository includes a Python implementation of the hybrid CMI estimator CMIh presented by [Zan et al.](https://doi.org/10.3390/e24091234) The authors' original R implementation can be found [here](https://github.com/leizan/CMIh2022). diff --git a/dataset_utils.py b/dataset_utils.py new file mode 100644 index 0000000..ee6b039 --- /dev/null +++ b/dataset_utils.py @@ -0,0 +1,37 @@ +import torch + +def get_dataset_from_arrays(train_features, train_outputs, test_features, test_outputs, validation_features=None, validation_outputs=None, batch_size=1): + """ + Both test and train dataset are numpy arrays. Observations are represented + as rows, features as columns. + train_targets and test_targets are vectors, containing one value per row + (expected results). + """ + + train_inputs = torch.tensor(train_features.tolist()) + train_targets = torch.FloatTensor(train_outputs) + train_dataset = torch.utils.data.TensorDataset(train_inputs, train_targets) + + test_inputs = torch.tensor(test_features.tolist()) + test_targets = torch.FloatTensor(test_outputs) + test_dataset = torch.utils.data.TensorDataset(test_inputs, test_targets) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=1 + ) + + if not validation_features is None: + validation_inputs = torch.tensor(validation_features.tolist()) + validation_targets = torch.FloatTensor(validation_outputs) + validation_dataset = torch.utils.data.TensorDataset(validation_inputs, validation_targets) + + validation_loader = torch.utils.data.DataLoader( + validation_dataset, batch_size=batch_size, shuffle=False, num_workers=1 + ) + + return (train_dataset, train_loader, test_dataset, test_loader, validation_dataset, validation_loader) + else: + return (train_dataset, train_loader, test_dataset, test_loader) diff --git a/feature_steering_example.ipynb b/feature_steering_example.ipynb new file mode 100644 index 0000000..30a4017 --- /dev/null +++ b/feature_steering_example.ipynb @@ -0,0 +1,222 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from regression_dataset import make_regression_dataset\n", + "from regression_network import RegressionNetwork\n", + "\n", + "# Deterministic execution.\n", + "CUDA_LAUNCH_BLOCKING = 1\n", + "seed = 42\n", + "torch.manual_seed(seed)\n", + "np.random.seed(seed)\n", + "torch.backends.cudnn.benchmark = False\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example for the Application of the Feature Steering Method Presented in “Beyond Debiasing: Actively Steering Feature Selection via Loss Regularization”\n", + "\n", + "This jupyter notebook provides an example for our method for the redundant regression dataset presented in our paper.\n", + "\n", + "You can choose to generate feature attributions with the feature attribution method provided by Reimers et al. based on both **contextual decomposition** and **conditional mutual information**. Additionally, you can choose other hyperparameters such as the weight factor $\\lambda$ and the norm that is applied (L1 / L2 norm).\n", + "\n", + "## Dataset\n", + "We create a small regression dataset with redundant variables as described in our paper. That is, the created dataset has 9 input variables with a redundancy of 3 variables. In total, we generate 2000 samples, of which 1400 are used for training.\n", + "\n", + "*Note:* In the evaluations for our paper we not only generate one, but rather 9 datasets with different seeds." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration of the datasets.\n", + "high_dim_transform = True\n", + "normalize = True # Should only be set if HIGH_DIM_TRANSFORM\n", + "n_informative_low_dim = 6\n", + "n_high_dim = 9\n", + "n_train, n_test, n_validation = 1400, 300, 300\n", + "n_uninformative_low_dim = 0\n", + "dataset_seed = 42\n", + "batch_size = 100\n", + "n_datasets = 9\n", + "\n", + "noise_on_output = 0.0\n", + "noise_on_high_dim_snrdb = None" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Create and load the regression dataset.\n", + "train_dataloader, test_dataloader, validation_dataloader, _ = make_regression_dataset(\n", + " high_dim_transform=high_dim_transform,\n", + " n_features_low_dim=n_informative_low_dim,\n", + " n_uninformative_low_dim=n_uninformative_low_dim,\n", + " n_high_dim=n_high_dim,\n", + " noise_on_high_dim_snrdb=noise_on_high_dim_snrdb,\n", + " noise_on_output=noise_on_output,\n", + " n_train=n_train,\n", + " n_test=n_test,\n", + " n_validation=n_validation,\n", + " normalize=normalize,\n", + " seed=dataset_seed,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network\n", + "We follow the paper and create a network with a single hidden layer of size 9 and input size 9." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Network architecture.\n", + "input_size = n_high_dim\n", + "hidden_dim_size = n_high_dim\n", + "n_hidden_layers = 1\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# Create Network.\n", + "mlp = RegressionNetwork(\n", + " input_shape=input_size,\n", + " n_hidden_layers=n_hidden_layers,\n", + " hidden_dim_size=hidden_dim_size,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After creating the network, we can train it with the *feature steering loss*.\n", + "\n", + "Recall from the paper that our method to steer feature usage is implemented via loss regularization. Let $D$ refer to the set of features that should be discouraged and $E$ to the set of features that should be encouraged. With $c_i$ being a measure of the influence of feature $i$ on the model's prediction process, $\\lambda \\in \\mathbb{R}_{\\ge 0}$ as a weight factor and $\\mathcal{L}$ as the standard maximum-likelihood loss for network parameters $\\theta$, our model is trained with the following loss function:\n", + "\n", + "$$ \\mathcal{L}'(\\theta) = \\mathcal{L}(\\theta) + \\lambda \\left( \\sum_{i \\in D} || c_i || - \\sum_{i \\in E} || c_i || \\right) .$$\n", + "For $|| \\cdot ||$, we consider the L1 and L2 norms.\n", + "\n", + "**Parameters:**\n", + "\n", + "Our implementation allows you to choose several *hyperparameters* for the feature steering process. You can adapt the following aspects of the calculation of the loss function:\n", + "\n", + "* The feature attributions $c_i$ are generated based on the feature attribution method proposed by Reimers et al. For this, the attribution modes `cmi` for feature attribution based on the (transformed) conditional mutual information and `contextual_decomposition` for feature attribution performed with contextual decomposition are available.\n", + "* Feature steering can be performed with feature attributions weighted with L1 norm (`loss_l1`) and L2 norm (`loss_l2`). That is, this modifies the norm applied for $|| \\cdot ||$.\n", + "* The indices of the features that shall be encouraged or discouraged (defining $D$ and $E$) are passed as lists.\n", + "* The weight factor $\\lambda$ is specified as `lambda`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loss (per sample) after epoch 1: 4712089.607142857\n", + "Loss (per sample) after epoch 2: 4480544.142857143\n", + "Loss (per sample) after epoch 3: 4258867.017857143\n", + "Loss (per sample) after epoch 4: 4050848.214285714\n", + "Loss (per sample) after epoch 5: 3851129.5714285714\n", + "Loss (per sample) after epoch 6: 3662716.375\n", + "Loss (per sample) after epoch 7: 3484030.3214285714\n", + "Loss (per sample) after epoch 8: 3317511.035714286\n", + "Loss (per sample) after epoch 9: 3158147.5714285714\n", + "Loss (per sample) after epoch 10: 3006144.9821428573\n", + "Loss (per sample) after epoch 11: 2864496.8035714286\n", + "Loss (per sample) after epoch 12: 2727674.410714286\n", + "Loss (per sample) after epoch 13: 2597680.910714286\n", + "Loss (per sample) after epoch 14: 2478867.535714286\n", + "Loss (per sample) after epoch 15: 2361367.4553571427\n", + "Loss (per sample) after epoch 16: 2251085.125\n", + "Loss (per sample) after epoch 17: 2148403.6160714286\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[9], line 13\u001b[0m\n\u001b[1;32m 4\u001b[0m feat_steering_config \u001b[39m=\u001b[39m {\n\u001b[1;32m 5\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mattrib_mode\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mcmi\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 6\u001b[0m \u001b[39m\"\u001b[39m\u001b[39msteering_mode\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m\"\u001b[39m\u001b[39mloss_l2\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlambda\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39m100.0\u001b[39m, \u001b[39m# Adapt accordingly for CMI / CD\u001b[39;00m\n\u001b[1;32m 10\u001b[0m }\n\u001b[1;32m 12\u001b[0m \u001b[39m# Train the network.\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m mlp\u001b[39m.\u001b[39;49mtrain(train_dataloader, feat_steering_config, epochs, learning_rate)\n", + "File \u001b[0;32m~/OneDrive/Publikationen/2023 - Feature Steering GCPR/Offizielles Repository/beyond-debiasing/regression_network.py:169\u001b[0m, in \u001b[0;36mRegressionNetwork.train\u001b[0;34m(self, train_dataloader, feat_steering_config, epochs, learning_rate)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[39mif\u001b[39;00m loss\u001b[39m.\u001b[39misnan():\n\u001b[1;32m 167\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mThe loss of your model is nan. Thus, no reasonable gradient can be computed!\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 169\u001b[0m loss\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m 170\u001b[0m optimizer\u001b[39m.\u001b[39mstep()\n\u001b[1;32m 172\u001b[0m \u001b[39m# Print statistics.\u001b[39;00m\n", + "File \u001b[0;32m~/anaconda3/envs/featuresteering-minimal/lib/python3.11/site-packages/torch/_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m 478\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 479\u001b[0m Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m 480\u001b[0m (\u001b[39mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 485\u001b[0m inputs\u001b[39m=\u001b[39minputs,\n\u001b[1;32m 486\u001b[0m )\n\u001b[0;32m--> 487\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\n\u001b[1;32m 488\u001b[0m \u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs\n\u001b[1;32m 489\u001b[0m )\n", + "File \u001b[0;32m~/anaconda3/envs/featuresteering-minimal/lib/python3.11/site-packages/torch/autograd/__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 195\u001b[0m retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m 197\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 198\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 199\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 200\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward( \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 201\u001b[0m tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m 202\u001b[0m allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Training configuration.\n", + "learning_rate = 0.01\n", + "epochs = 90\n", + "feat_steering_config = {\n", + " \"attrib_mode\": \"cmi\",\n", + " \"steering_mode\": \"loss_l2\",\n", + " \"encourage\": [0, 1, 2],\n", + " \"discourage\": [],\n", + " \"lambda\": 100.0, # Adapt accordingly for CMI / CD\n", + "}\n", + "\n", + "# Train the network.\n", + "mlp.train(train_dataloader, feat_steering_config, epochs, learning_rate)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "featuresteering", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mixed_cmiI_estimator.py b/mixed_cmi_estimator.py similarity index 100% rename from mixed_cmiI_estimator.py rename to mixed_cmi_estimator.py diff --git a/regression_dataset.py b/regression_dataset.py index 766756c..68d7cd9 100644 --- a/regression_dataset.py +++ b/regression_dataset.py @@ -1,7 +1,7 @@ import os import numpy as np -from toy_data.algebra import random_orthogonal_matrix -from toy_data.sklearn_adaptions import make_regression +from algebra import random_orthogonal_matrix +from make_regression import make_regression from dataset_utils import get_dataset_from_arrays def make_regression_dataset(high_dim_transform=True, n_features_low_dim=4, n_uninformative_low_dim=0, n_high_dim = 128, noise_on_high_dim_snrdb=None, diff --git a/regression_network.py b/regression_network.py new file mode 100644 index 0000000..d19acb5 --- /dev/null +++ b/regression_network.py @@ -0,0 +1,176 @@ +import torch +from torch import nn +from contextual_decomposition import get_cd_1d_by_modules +from mixed_cmi_estimator import mixed_cmi_model + +class RegressionNetwork(nn.Module): + def __init__(self, input_shape, n_hidden_layers=2, hidden_dim_size=32, device='cpu'): + super().__init__() + self.device = device + + # The network always has at least one hidden layer (input_shape -> 32). + # Make sure that n_hidden_layers is valid. + if n_hidden_layers < 1: + raise ValueError("The network cannot have less than 1 hidden layer.") + + # Generate and initialize hidden layers. + # Note: we only need to generate n_hidden_layers-1 hidden layers! + lin_layers = [nn.Linear(in_features=hidden_dim_size, out_features=hidden_dim_size)] * (n_hidden_layers - 1) + for lin_layer in lin_layers: + nn.init.xavier_uniform_(lin_layer.weight) + nn.init.zeros_(lin_layer.bias) + relus = [nn.ReLU()] * len(lin_layers) + + # Generate and intialize first and last layer. + input_layer = nn.Linear(input_shape, hidden_dim_size) + output_layer = nn.Linear(hidden_dim_size, 1) + for lin_layer in [input_layer, output_layer]: + nn.init.xavier_uniform_(lin_layer.weight) + nn.init.zeros_(lin_layer.bias) + + # Combine layers to a model. + modules = [ nn.Flatten(), + input_layer, + nn.ReLU(), + *[z for tuple in zip(lin_layers, relus) for z in tuple], + output_layer, + ] + self.layers = nn.Sequential(*modules) + self.to(device) + + def forward(self, x): + return self.layers(x) + + def feat_steering_loss(self, inputs, targets, outputs, feat_steering_config=None): + # Get configuration for feature steering. + # Do not perform feature steering if it is not desired. + if feat_steering_config["steering_mode"] == "none": + return torch.tensor(0.0) + elif not feat_steering_config["steering_mode"] in ["loss_l1", "loss_l2"]: + raise ValueError("The feature steering mode is invalid.") + if not feat_steering_config["attrib_mode"] in ["contextual_decomposition", "cmi"]: + raise ValueError("The feature attribution mode is invalid.") + feat_to_encourage, feat_to_discourage = feat_steering_config["encourage"], feat_steering_config["discourage"] + + # Feature attribution. + if feat_steering_config["attrib_mode"] == "contextual_decomposition": + scores_feat_to_encourage, _ = get_cd_1d_by_modules(self, self.layers, inputs, feat_to_encourage, device=self.device) + scores_feat_to_discourage, _ = get_cd_1d_by_modules(self, self.layers, inputs, feat_to_discourage, device=self.device) + + elif feat_steering_config["attrib_mode"] == "cmi": + # Estimate CMI. + if len(feat_to_encourage) > 0: + scores_feat_to_encourage = torch.stack([mixed_cmi_model(inputs[:,feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_encourage], 0) + # scores_feat_to_encourage = torch.stack([get_continuous_cmi(inputs[:,feat], outputs, z=targets, knn=0.2, seed=42) for feat in feat_to_encourage], 0) + else: + scores_feat_to_encourage = torch.tensor([]).float() + if len(feat_to_discourage) > 0: + scores_feat_to_discourage = torch.stack([mixed_cmi_model(inputs[:,feat], outputs, targets, feature_is_categorical=False, target_is_categorical=False) for feat in feat_to_discourage], 0) + # scores_feat_to_discourage = torch.stack([get_continuous_cmi(inputs[:,feat], outputs, z=targets, knn=0.2, seed=42) for feat in feat_to_discourage], 0) + else: + scores_feat_to_discourage = torch.tensor([]).float() + + # Transform to [0,1]. + # NOTE: Even though in theory CMI >= 0, in practice our estimates + # can be smaler than zero. + # Make sure that sqrt does not receive values < 0. In analogy to the + # Straight-Through Estimators (STEs) we apply our transformation only + # to inputs >= 0 and use the identity transformation for inputs < 0. + scores_feat_to_encourage[scores_feat_to_encourage > 0] = torch.sqrt(1 - torch.exp(-2*scores_feat_to_encourage[scores_feat_to_encourage > 0])) + scores_feat_to_discourage[scores_feat_to_discourage > 0] = torch.sqrt(1 - torch.exp(-2*scores_feat_to_discourage[scores_feat_to_discourage > 0])) + + else: + raise NotImplementedError("The selected feature attribution mode is not yet implemented!") + + # Small corrections: + # If there are no features to en- or discourage, we can explicitly set their contribution to 0. + if len(feat_to_encourage) == 0: + scores_feat_to_encourage = torch.tensor(0) + if len(feat_to_discourage) == 0: + scores_feat_to_discourage = torch.tensor(0) + + # Feature steering. + if feat_steering_config["attrib_mode"] == "cmi": + # With the CMI estimates we can have negative values even though in theory CMI >= 0. + # L1 / L2 would emphasize them, but we want values < 0 to result in a smaller loss. + # L1-Loss: + # We know that our values should be almost > 0. Therefore, we apply the absolute value + # only to values >= 0 and the identity transformation to all others (analogous to + # Straight-Through Estimators, keeps gradients). + # In practice, this results in ignoring the absolute value. + # + # L2-Loss: + # Here, we also only square for values >= 0 and the identity transformation to all + # others. + if feat_steering_config["steering_mode"] == "loss_l2": + scores_feat_to_encourage[scores_feat_to_encourage >= 0] = torch.square(scores_feat_to_encourage[scores_feat_to_encourage >= 0]) + scores_feat_to_discourage[scores_feat_to_discourage >= 0] = torch.square(scores_feat_to_discourage[scores_feat_to_discourage >= 0]) + return feat_steering_config["lambda"] * (torch.sum(scores_feat_to_discourage) - torch.sum(scores_feat_to_encourage)) / inputs.size()[0] # Average over Batch + + if feat_steering_config["lambda"] == 0: + return torch.tensor(0.0) + elif feat_steering_config["steering_mode"] == "loss_l1": + feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.abs(scores_feat_to_discourage)) - torch.sum(torch.abs(scores_feat_to_encourage))) + elif feat_steering_config["steering_mode"] == "loss_l2": + feat_steering_loss = feat_steering_config["lambda"] * (torch.sum(torch.square(scores_feat_to_discourage)) - torch.sum(torch.square(scores_feat_to_encourage))) + else: + raise NotImplementedError("The selected feature steering mode is not yet implemented!") + + return feat_steering_loss / inputs.size()[0] # Average over Batch + + def loss(self, inputs, targets, outputs, feat_steering_config=None): + # For MSE make sure that outputs is a 1D tensor. That is, we need to + # prevent tensors of shape torch.Size([batch_size, 1]). + if len(outputs.size()) > 1: + outputs = outputs.squeeze(axis=1) + + # Compute default loss. + loss_func = nn.MSELoss() + loss = loss_func(outputs, targets) + + # No feature steering if in evaluation mode or explicitly specified. + if not self.training or feat_steering_config["steering_mode"] == "none": + return loss + else: + feat_steering_loss = self.feat_steering_loss(inputs, targets, outputs, feat_steering_config=feat_steering_config) + if feat_steering_loss.isnan(): + raise ValueError("The feature steering loss of your model is nan. Thus, no reasonable gradient can be computed! \ + The feature steering config was: " + str(feat_steering_config) + ".") + return loss + feat_steering_loss + + def train(self, train_dataloader, feat_steering_config, epochs=90, learning_rate=0.01): + optimizer = torch.optim.AdamW(self.layers.parameters(), lr=learning_rate) + + for epoch in range(epochs): + epoch_loss = 0.0 + + for inputs, targets in train_dataloader: + # Pass data to GPU / CPU if necessary. + inputs, targets = inputs.to(self.device), targets.to(self.device) + + # Zero the gradients. + optimizer.zero_grad() + + # Perform forward pass. + outputs = self(inputs) + if outputs.isnan().any(): + raise ValueError("The output of the model contains nan. Thus, no \ + reasonable loss can be computed!") + + # Calculate loss. + loss = self.loss(inputs, targets, outputs, feat_steering_config=feat_steering_config) + + # Perform backward pass and modify weights accordingly. + if loss == torch.inf: + raise ValueError("The loss of your model is inf. Thus, no reasonable gradient can be computed!") + if loss.isnan(): + raise ValueError("The loss of your model is nan. Thus, no reasonable gradient can be computed!") + + loss.backward() + optimizer.step() + + # Print statistics. + epoch_loss += loss.item() + print("Loss (per sample) after epoch " + str(epoch+1) + ": " + str(epoch_loss / len(train_dataloader))) + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2555d89 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +acd +jupyterlab +numpy +scikit-learn +torch +torchaudio +torchvision +torchmetrics \ No newline at end of file diff --git a/teaser.png b/teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..dbe147c2338590b23a392baed1533d34ec1120bd GIT binary patch literal 34850 zcmeFa2{@JA_dl$9G*HTLh>}8vW6CUJnPuh_84DRhGG=T-&LK%Mh0IQp6q!jJGH0I2 zF+>?Mlj+^}In>G1^Zowb>w16x_j=#!d9LfJTle03eb(A*uf6tKYu^Gi)D#Yp(2)=k z5gk%el+_|4+6N;dBIew`2mDKA>DdwR*Dgmb1!|0Y1xOmy85oJo`KQPFJs?kXXjQ+ z4nKAQtqcj96;1EkWyK^4*^=3+y3^Xx_gv*9eARw<-!np1Ui* zAH84vZlxG{KeOfiN67a~mjC6SANa5Ez&1=F1j*#maO_pYQe!fnB>I8BW}j3z4$-M3 z#uKAxK?eWA#Gfy3S=Uz$PT|o&J!AFe9P-4Xb`hyq0n%;}sqLP}rylMz0c-G z`Rc}cG#wokp%gUheS$}(JrCEj4`g3-FPxUjKL4b1!b^PS)C*@)NujtV^M`b^7K4^z zTI7nK_Md{yoO!Y0I_F8a0CLsMVMkT?L z)V`#RjcOq|(tbeRBL;(gJ-S-E%v2q$yVmP@A^L2&pL;~oq@vu(^Gx><73FE8wrG(O z`e*hF7jhhuic7uuNec$u+>Zv{bvS`ij^=6cOJcn7Xd9XK%-o^6^DJV<>E^<|nI~i3 z@l^GS$x`M%b&rTYq$KzGWY7^U^23Qy3||c0k63!0sin{Hu599>o3D#|?< z3OZ8zdXJr~pp+5@k)%?h&fjP0w(-Gh^oUbvPpMhmA$^`{qr?2!r_lS4Cv9;Zicl%` zjS6TuRX2a*`p4#K%4Xi&JXS68G*zYwzKPq%H0#xO^9Yh;gZ1BFlE^8iCps$RH}*`3 z8FuHMfHY@0Ssr;?k%@w^b!{Zytv${%)`uF8mHI~Ay8Giy{?*f4Q*(XB_(hbRbK4u%#S%vY+T5NG*LhJJWoLmdVLY;~_t(L>Gy{@oe(x5b) zyI4V#?`69T4~-+Ka4b#_J}xQ|tGDZwO^gM{&R!^A9!SoW&Z*AzMHn2&@cmRLA@fa9 z?(+q@Q$4IO@yj;EZxf~5ktSbOZzNK;t72e2B~!F*6I{ccvIY{~$WyQ?(0WVidoy5l z4>!ixGZCI#b;TcUoPy06zq^clVNhpVl=G|^Ki7(;p~vYuwy2Yp=Y6NW(T(lVvU`=B z(0)wYL6K97JnxbGZ>=5{wCbm2b>*aL^U})2p{i3ysZ^V=Y+M%tyvZcQ(# z0%bWMAWauN{3)-U#9YB={)sCb&NDjgXMN8eg+8he7?l70y(S#~!SkEq_4$k>#_xa8 zoqBq8-$6Z&2>SIVZFzLwBljU;4IWjDy>hl{(xFg8+Mi1{OAb%Gh7ATgnz-LFyPx$_ z7&=d}N7kANryhhL9Kdw=(nVKBKSsXKu)5DI&;B|s=WbUMgT%FGNYX6fjcPBsUN`6K zN($uY{Z0x@Nm)u{K{gpzT;>(jX;Tv@OrKmW|N5Q$1hC_GeiIR&lAN!u?!#Aq?@?iStZmt4ViR zx7dips0C)7!ena&WA7vIjrvM`q7`U#LG0YszKzy5SGINTb_(iIbo14oRVy-8111 ze7wRh=?2{Jmif2l7H*^$ZCbxpRJHV8zVJpkL!f`ns@@ES6A)w{p5t|Ggx|$aAaIj5 zS><}fXhU;hERS-?(}(#OL;OBNFeJ4JtI$|qW0UHY>k%Ul>1gb!bF61u0;?mANjnty zA^x&%*`IY#?8+znYJPJ>tlYx)PPcfDilfE;x4sO+42!0PEBd`rh`(O!Y$+2eOcv?O zqL2|MjH@@Wd|Lt= zRiiZ#*e`ui`FYB%siJoqylUC5>2zk}Zqq7jZcI=A29`=priat?Y4R+(6{wcmi!9GygQI*Z5O&9BL1R)_Sfr8TdY1oP7g=OMs-V0Rmx|L z1yM9b`eNA}Xgs~Iu@>gPedL1b_qYUB9L~||OxRlEyfEXkl&X+!$$`XY7;L9F)mz>r zS|re-nBNxjn+T-LkN^?Qv3q8m-uF$oMs56dmyZQ{7?DOzWsz z2)3HvTRlU%lU@vFe%Cc*<4`VJ$JNJ*S8hV<`sNuj%bD6Ho+}JFw~Ap1sFyUEstO*A zPFKd;buDszoYrb%sY<*cHz)Q&!`|hx6WhL>2@1DLfP?4#df&E#hi5r1wY`X8;?B00 z_YSw}zMuuC_jx?u_7u5rCIh&OtN#uO|KF3PbemM<@AJTlSP$Ly?6>G{a6DPRNBD1k zyM;&-`&w{^i~CB%kRnjMzI=2t7C)2<9Gd6&+w~tTm+Lm zU(NQwYt{5xR!^ETwL>Po9JC=D=yf|HgnNBNff7+QH`8 zTC3G>?fhM|XMA25x2>tu=S8i?4XTdK@&Z)~C1H7LWTMj*#@{~q$qxQPA2$BJZR|Dp zvr}Ft7;x|maVBeC5fg>I=q#{DK&J17;=Wh%{uzqIX z=I4jUtm=sNvcgn2CA6|q+%D4B6V<2nCX>E8w`t+F*}-U$j|=R1QYGWH-8~z<-V(?mB*#|E3^GuLDyM z?fcJ3cLS0Y;O$hIFA`~=(}oo~mBGd~4^D5=5l`iXR35x?L#wQ!nF(ft)M;xv_$jc) zu;4SvypP|KPtb1HUPaTrXZ{t#|5ywIOcGH6&PwoD&VQPY=kjG01VKLJjl?S^ zhL4e~sDGM{G|e+v+)9|xf7Hrf`@x?S26*k;M*dGK{r_0!>l4l4r}Zox%3x`pA9wqI z+dQa(S4iORZf{Nk<-N%~#w8on>Q}4njZ3Wy#dFb5U!2{1)a7%KN~FAoX7cXOh45>W z`E*b}yu>r8xba8*kw80su8WQvr<{48sxPgTx0P6i9bUp_T0IA=laLkNg89=feOCmJ z_6g~|7+sTjjFelpp|b=l@Ylxpoc5dH8;gV`{S9sTTwffHBop$DsJ&pB5lXGnI-HMz zg?m+1-C(%g1|2nbSP4!&h^I+lV4b3oJ$6*59Sg_Wt^cvnd=#=!nwJJ;m;5`uJHITt zOe#9>II>g~?aWAvvWxlWc5!GCQcR|?wp8NM%G*G+u2Xt8qJG7n8HW1aA@B)UqX}Zl z^lwCytc?ZqKU6TCLt{3h<H&SS+lLb^m`$28nzRJJFg!+Z zelTu12du+u8kLu?_GHJ2e{&h*g$+!{nftVxeDx|#^NSot-fWO!M(f@$O}PUM&6Qsp z%_mcMiQSseeJmSwc-+V(%Rk5`Zs6%iyLAk1|B^^CnV7EBIF3m;p`P4LD!TdT=fvOy zmHe(Dqdp@V>?(tAj06VEm_TjE>M)xxs^1;y=nNsX!L~ zit)dVS`$cb{O*>jeG*oQWP&escF^cQNya$g7#IVfuJ#{V}2JlJ+t!_7z>st<^Pc`&e)b*u9Ixa`Gc|e+YG)@E>FFuFLk+2 zs_}o+zCgKLKdr~xYzBs`EmX~+0UZ1ny>+#6V|8lj!-`Rb#|{5!?fnf?igJ0qiS>r@ zceUmUqfm|-V4?r>0yHa_FPVNjf(_SXNtxp5qx&B$r;pC#UT??N76rotjafn>)HHK1 zD7QWPZ)@s}VN^xJQuLg2rJ&WM{_HW)KN`s7cN-_hn-Ck3+Q6%{f_qmcsQ$1?|D(;F zWY6Sk%(nA4(snTBH1~d5p3(Mq^TDl^MT>VVv)Yr#t9)7V8HR79=ltE;+Nj@Je=yHx zKR|whz2pSs!M$`DudVFC47`!*8x#7pbLhHU#o!CScE?Z2T=~n+Bdi;JEdOEGddKaR zLp-F%V#GHT>+{iEYX()ax2OmE?c(QyDtJ8xA+zlh+NXn#;U(N_M3tQ#Jn120d53VQ zud~3~Le{C(%Ye$g>O8Fg>^i~mu)ecBc_ccNC*|Bh;3E^0jjMIGB%*iwD^EQ}R1AUm z-q!K+m%E|zU*2?)jy6swm->9w+&W7B^w7`bD{GMlTksq;)*rbSvRi15NAdZ#+p0+4 zym57mg}7n65jeOc-)o{judGXMeE6sP`$#ki^<*KAx$D(?S{7F&ICciw_-Z+51MQad zJ|LG(M1GzKT2GlKJv$MI8*y{#0@-EpmE9Isu*MI!cQ$Tz>Oi{zsMiI;UdFGayg7q+g)De@|BIbygdE#0nfdk9=9U>&9%j$fU>d0l;jqw&TDET<;BZ1 zjN7}i4R&>MQ`xtyAJ2KnAL=9>_2A4TwgA9mp1x)@@`NNNx=|lc@@_zY;y({SqUOCxmo@ zEg}K1mH@HV3$d2NZM&|$W?!G05Uy5PTs9W_;+H_uPe5+mD#>Wi5<^y}p+gQX%DOMnBbI>BEgT`92>YyFNqQ3bB~PeNH`HgiFnd+LY5X)ShW}oZ zgwiiDc8EcuC!=htBafd047$eY2~p>We|WD{|g?99)h4swo+WG_%;EZCtoT1OJ}`BI5kgG+aHk6&Hn{1(+L81oGn7vZfYEzIWTcRen|kr z_YuMejO*z|x6EV738rN7xf#E@@i_kMHZz_|oapfo68Tj~nwKKzD1z@ZqCs0RMBt zHSH#)jf6sjGX-sWqN+E;K@N8MTv^3UeK&W;qd7vry%kl#T&G|1&j z{|0G1weg;2J&2Ngx(ctHiuFu)!vF@%;IJ6Xk`@h9a>bUh!Rg0D3u`G%%@%5H3@2mY z_&RvGVPFV8cZwf9=GDvh82JThcoSM!dx$A(x#s4(GLPuk;G!}&H%{n89?pM^Y!EoA zgmJ`hpM_5OckTv^Jazl=fuZeF6y$01zyt?l*RrrM{#PrHzM+@HtB;{OQnAJ**9C%` zFiTa2g-^c{P_P{zul2FGz%lo-C%y+!Cs%O-T@Evo9Ij#aiSc>Yc@AzTK#HlbsI&_L z-C}+Mn;aUi4Ot#qbbGRC7yJtZQ5QJ)h(r) zin0mqnU4B4pncAE`x&m8@-?qyEGH#D&rALEL%V7`(1vZG!b|hh7IJ`M3ZO{Sg{LSGA}In}0yu=4 znytvDZpH@Kltz@)t~s((?R@Nj_IMZNNj1H;`sq<24)EXXue@%aN-9t;^v@$N<@F z2tz9vVdCb~fYJ;L?CDI+BZXr)J4b3h+tXXQr~nXT0t7le5O%Y0ff&i@+--)$pUYTvhTxYqCmAsRLGRp#B{ZfVELrFa@*0vD+{?Y_RmW4 zEG)+b3SG`_PTDPoH`%@fQj?+wb9}|7%##D~Ot1ko8@l^eL}rb-q$XQ)`m1<7$rGfY zUt;EpF|XoY!gK5K72-CUgeCF;E8uA{8RE4c;?*g<-*{}`#37s!GGBPa3d9|on%r`I zVC0TU(Xb$ii7`YL9Kj-^E8a3TIjpt-(~~|8nVuBoq1Vfk-*ncyi1{b4)l32x)9i4h z=YZ3a6mhS1j#E@d;j?Slc1j2JCN0CPTu5}dQ>$ZYA2YHCGDjsuxGkc9%A=?QYc|oH zgHZkPs5-iW$+c>#3&`%B;fXq&t}S*zvV%Mi8|(xmz8oMv!5h%}=}7M4k?TrZKs-yX zsx4bW^_|FeisCBlU99b+&QG-h9+)ZQf#FWpxx0Z9(?e^i_A|u6J7s^9tXhQl2KkGg zfTwDv10(ImQ|~@XXev8nB*lQKz#a0GA_2t`h+-Iy;y^M}R;MxiL@-bQ;1^^~%|5}4 zDsxLzSp`KVfKVSq2sH=@RbB`ut1z)ZyMfhZr=uqtWE<|`W4U<`Xzo|!y36*V7)}@P zQ0y)wmo4v>eI@9LphPnEG8nDcYFyx{VeHd@@FMy16Zz1V&`Da^acNM|iR}vi=H0)5 zE?pgKZZMc&fvH2f33j5(10w-oYp|iJDB0o$E}QT(VC9EaHg1+`4{uk1#~uq%yTfnL zGZJEK_V)ecdBA@I)PpcmgAdti`8#RebY~U?US4T!DF4jk4z>qmM8YW1b_kUp(iQge(~H1BVA)0fNE?KUhq=E<2`3Jmvq@c7$)^Xi z>qZ9Y6^QeB%ep^N*h;wkip33cTeu>7a%O|lPP$Y>I%G|y@>03S;bqSq#1m2|JRE`B z#nS&A2OQ>Bk`p%yy8EXvB;Fgwji)tZwiOzGu59c>N?lsB9jU{6^=)7m4o|{XqVSQU9Wl|}hKw-vOIwaNJ{Z!v6RsjA6t?ANk#@FEr-1~13p||v-+&>igyH!yeqTL))n@~(5SJZ`5D)(cp1v}}T5hs_zSdC) z176MlP5yyzziLMq7QVwB@~7y>R`^d*j{QcSU=IjfTBErMLRMs-yk~KGw%cF-?fRpn zgr{F;e}zu|N1fx#5xtxKPv=8PldzW`7jn*3y(W2fC*bcPz^icUuLHpdapjxgxwnGu z;!m9f(ZrK53={yF1Odz(gbv=LBnYoQLSljOtzQJ8T)h_@_?0q2oX8;Adx~J561bSLpkd&6zko-HU_)c8%`RvC{O%h6Qew>)}3a(jQ7UE|R9#`FmD-8GfvN<;npp zGo;SFtv+ijsk}5c&|2JHI~P_vq=L{P6V1u|Q>AO)x?Sw4~si4xr>Av zWEtNKua@$I{Rjs>NBQ0)#_lA9z;Hs3yz)rJtDmDpNt;Rt_6h-$u4Ot9C}xK3cFC-7 z5G$XtHA#Z;m%HFb8dx+OHM;xxlMO-wz zy*so^F#$Gxbhd6av=1fKJoUwWdGuDIPc-HUg13txp)uCO2GhCOl+mGeHtX27VR<4% z669%(bJC`0U7vvwqOQ9yA7o`LwR zLO*>>_5)X2Q{jIK)mhBILS4GNqk)9h_$!j0bc;AwNh`18lwt`@CgqJgD|L*>v)lT6I-^O5@NIE6?l5HDaU$T;Fv}4 zh>F3tOIRh+{4u>8fzSq7i@_5&rK7!SGJ335fnLvZg&@JZZZNOn0zc>VwDb@^jbZdt zwYOH4s@Jf3xsTOXKgxBaB(pm^n$6OhNK$K0$5H1$?*WRMU|)?`11}Q*hC4qoaM9Au z88~qPq4?1XLz+<|?(h`#N={jO4)M$@ljRs{T!Yi$eD3sPMDK`{R<==Vo>~uzu4O|z zrBQpRKs$EPM?dC=K=JoWp$1ODIRM?uv#^e_m+JZV-C-SG;;I*Or9peSpu+K+!(XC( zE-!QlZo);o6O_0Ms(@p#A9S-xqFL+g%=6Ubgzuul*PfuJLV#!W%mH;RLAXY4Nl{sQ=#&1#;0hhP$Hv(Q zdiUZqS||khBfR(ArXA(_x5?9sRS#)L%46ad$HP!|mz77>K-7=NrSWb@K5wRhQ!PMb zRml|<(9c2B_vUGe5qfz8S?}*_@T{~+^b2&_csD(pjX)$EqO#g^y8GnRTQc(}hh4q9 zDMk|Xk)%NnV%~mswo?y0lzU{ugRwkNUz}0`_lEzgISXYmPv&Ji4(!uE#qd#m~zRud<`{lAK1|B))Yv8nb$Ul~}PoUzN zvV2o{QR^UCXG!g;nT79^=rObdz51%gk`+5QP2Z-Cz3l^f2uF*tvI>?1t;=~v!qBrb$hb#c>hXUxw>Ad+lF3;5RMDIZl8ipE6}^|4d@-gQr3h~duXK|5TBHxQCFUmYKt$)SrrsYAKQ+T6!KJ5sDarpcw+D6Go0oqt)>c@`yB zdPOYGC4c!IsoklyTnhSE^4ygwMPCm#W}Aj*GGrV482-LcI5F=g&(TvEI6K&&zB>N{ zV*%7g-*a<@Y083BLRh_ru6KpFXC)hjQ7Pi)zwib7D4-)|@Ydo)v+=#pLv}OktfwiB zg`2zI@g{oa&9#immmX}DXnx$aL?19VK7Oma{c{6PXG5x?^~VYh{zNnG68@PQe(~B^ zLo)k#(Q;!|vPvIGju}UGy`C-hTiGbfC9}=uC8o^wO`iNopR|~Nrn`*lxN zzfjt6+2^)Yr;|vuVc|Rt_VlxRVp*EZEcUR}#_JbD%3era_-^$>K=4)1a`tlWp$*6R z=u8R0S1pSPmCoHp^BqUptvob`L~VJjpXInmmNE+#z64f*Wx%DM)xBFT_G%&K?Iq|n zSmjn8D~+3ori{oFYC#m4H3`8Tmcr|z)d$x%ss+|dYI@acg88#PGRN|Z8F16&C`teL zD>c$mN`{8cVP7q0>0~_!SwIKU{O=qF7TJUQGD;r$QJmn6M836hw}E@3Jh~KCKZ2B@ ztFHAZJFr(RDx7=)Up$X<$Dq1Tf-KOSia_k(a7Rqa!mNizakbar-Qo4jc7_z6>O=F> zLQSRpMwL}<5oPy)BgsT5oboAN1E1EXWR1mA#Alu2zs?&H@&n|X-IYcXF0pdPeg3i% z$#<9_4BbT9~D_WL3$8T2i!y`z0#lYzpICp!vMxrovtjTs6;CM#w2)P}k=)4GM zl=l|M*u}>nZP}*naLPi+TpsKh4RRS5mQ^q#bNVsE15#;Mp8kls@VVUvh|4vyZnfLu z3I!Bl`CKNc#Tyb3t!243*}^aoBo$T~){_ez_9(z1xLH!m+|eANib*zTMc%&!Pie5X zeLUcvc(kp!rMt9tBv@IepV&p;5*%mGg5&IAA!Upii`|9zjtp$@v^D$tNh2t1e-Ok# ztlkE8h0um+&+yZI^G^F3 zvfL*o!jK%$U|2Fze@3(=4K6WZmFd(K zLkrGRA~bD}&Lm;GSFRmC+mK&9EvI~R(=MV_HC&${#Geid6F!$cL8eFMo$~c#sAZcw z9Zlo<7YG1*N*u2zN@mx8Xy&&`e^vR(!9mvtTy#R8yjctipFZ|8Xf3}Vf&BI*Tx^(a z`H(&5-o(_jYuv@=xw{MN(4KiTKAIpoEe4Gk&WVc0K!4pk2tIN znB>YHHH7r;ePyUj%S*FxvMe=WWbLN^yutx*{$5rxqg>Q@sR&_k8;8^Y9>wP%Lo#lP zxoa$SsBZnd=XnTbzQNfjTJ{lb-q8eV%8fQR??{=Fsns`9#Y=-Ql)E_OuM(bePE1Gz zMk>NxI7?1Kv$^3e!vEMhDWOr0i+K&#+0RI{M|y@HRjczpgCIeH&(J%TRbb)kSvZLw zf#PcKT+wRqSXSqc9W)vXo!8y4H=_#vcCfm_9ieN0Nl*W9YKjp@*8KW(v!Rw|U)_3a z)L(tzqRp4B^G_WyB?^VOu7G2$zOz5MFv(}DNt;(%>=BZts0dXw6krFm+>PZ5&kYupy4SWima zKy}~!ljdJPnN`ROT6Tw$U|Ge=x$2lU%Ik1g@O=}0C^R_?#Yjf)4(8s=Wx+>+u1`~- zJDbJQK#C(0#9MH|Zd@4!+|U?v_y)tw2o#ty^$4e;8;_eye@1IV;V=Fe>2Qz5n~%xK zOc;=UOH=oq8t&tqNB*<%C2ihUnNOj%Dsl4aXpt2VhTgEBwExh^k>I*`m$qM;>ln&z z6Ux_FZ$V9$YJR>Dl$xMsstz~LJojS#b{)5j)ZGC-_L88UNMJ3$3_{?Kk;qF$4|?(} z4JPyyLMDY8it|m3X^Dr7v^x6LjWNM!+A~s24B|vW5v#k#4NL4`u50}T%(-G99^F_0 z1!{z->b{vbfg5K%m%3GBdvc&*j{F>ESoUQbE<^uh9Jxg6ZG#pW9-rN7=L@nD{ny6) z9nIOEitGt>gX7}0$yfZ;sa>0mz-fH&5?ZA9Ip+W;T>I;_%Ba>;&UA5i-aJOiu(31@ zKnvk<;qg}s=y*2_2%0%MKyfWq;yLfb&W9(NK`O*93|Veav|RBcG7D6YnB{%Zo^ie# zK&kYqlPAA+kP2iyQl~$$#iF3-+;xr!nXEJ~P)1swp*17u#kE%LOUh@e`t?!uv5f2R zk`1Lt7gTe>-la69g^nml?_ z_>_!P%q`9o?U`CJ%Fr(%MsM|RhQ~9cGEA>YbXSDp5N8lOo}|w42dRxb9e9cTe#4}q z52h5f=$Moas!;La^wQFC;kX(Tc)@7(%4z#gy>O6j%BdZ6{#5?rhW7oec&ccssREVp&|*OvT&oO}>9 z#&+S@czCxf8s+W|vB3!&pwyMFuhR{c92Or>nMJ3Ks^>e4CA%oWcNcyKeVosOIucRW7n9$H1DC`IE^%yETSsN(Dko+}H;kBDT#-wAGR z`Z1%U*-8r>QST6jGuI6-U`&40%T<8`w_ulyxe0F=vew5Ei?IdfI|*WeFr=wako^R; z@-fe(vYvqtSXgx4;{dsaGdao&G4?&5%YIdXpW}b%JpN7Ks1DN*lGb(PQMK-xB4(h( z`?5qSqeVpb$R}qB@!A!S>;TrGonh0e!g)EL_wA34hS%kV9lreeMzvu9EOxx6Ctt!E z1VwQ(NXYJX%`S)+=6(=cxhDq_ByBMDN@%4K0@^NlWpO?0*?9|}&j70$yOaKQ*fzBA zxTLQ5(yEZ1KL{M*6`ycMwrgxvrkgVks0GR8N#$*A1{PbYUcE%?Q3EGu&X=(br=*Oq`5m4~~%MJ1qI1nu@T%<#m6ov(#q{5}c}#Ncg?;nbQ1N(K(u1y^uk^bF?D!?2SJM@(sD_wVqCePR|*#AjBnRw;6udE zK)igadG)fslzAVSVT#FYW^2zQq{ViP1U}5!0OH85Nha6uvGI^`t%)1+f(xcA$MOxW z(7^U8VebDbC)?2&41Q@QdqB*qIxzb8!n-nBV#;q15ihr@f}0JI>&^`z#b)07pGTZ% z{&s1mPNT@AQk7l{omP5+mqv#>vrSZM#sD8W-!brGLpr-;wL=&VXkh2NwOvmQqrFn} z2^=f)O!5g6`@|1y=X+M zJQS9^wwYw?BWtwo}Myh8^2` zCp_wG}8AyyCK* z3r-@tbq|41SY(1Vi9sd5NJ;*tNbiwRE>DK48e3et9ao%S5N;B@WPJ^+N}TkiptiGG zyi2w-V#4kAaq?3KpmKy;koZAadwQz?4%D`#OPS2nZzh=tY~1nY@bEdogFhM!7dkOhuc~2=Yh^vy&YA>f zW@NZ}_keP`4dMfkW1XzL&ji>>kX!-PF^xuNdIk>l-P^i2_9?K9JE%0E-~ce!xLjRq z@fDXM`hq=MH^iRP`f*RzVcLPiW?%LYQMW#(6u|%Yl0H!WHS*2ACoyXQ!I|`a5EP8L z7V3Y#!Gd22f7;4lzlTyOvVXdPup2lq4wQsG$@YFtIsvK-RsXa@bG` zaOJglV}+~o#_yqnn&@T0jphhUV;?Ay#MQ@dnGS-Y``9q#jEa7wLnBlaFZ+j#me%JR zHn@uU{{t`Mq<+=cCnZpc^y2R^kmJobPEET-bA@v48VczBYTcVWLp7f(0|QBdn+_-u z_qQ_P?VMxp^i3`Y5mJJ@;?7d|Kje6Ebr$FhM;pg)qc3}|=Dnn;7p5f!wmSCZb~)j% z`vyC!$^VcM{*MU%OHDXPcmFMa{0|a|NlGZlXF$6VW>NpQxitgVU;4h~x}^ZCvR$V6 zpYAE3RV?$sjfVfpZ2;y%a2w#%f4&Xy2B-nuPZ#8mmkR!8!+;eMRHSO|zKt(1U=j%L!tA^X0H*zByyo^AB)qkyXfJQ;ku89OHH{5_eOE@HW_=mL1`a@CT_AR;janO74M5AZ6VJ zE`Gid#?wvw^P1)Fqd09M>OlR&M&tVm{58!*>bl(Nfzs)fE{x#MT6R3xUyWQ~b-4_f z3mJ7r_Ut8lAyLa2Yg%n4b)Tu~6gV?RH&&xb%ul#mdn@^3{dAprprX()$N~EUZ@>=}nD_RK^SKg1CbxQC==u+* znGgRksy>PN{KoiFaX`Xp;&G(i}MeRRa>jlL$VeanYOAVz;)+JH4!YG2% z!{0>>44|K|V)!&w%XB7LmgTBhlPy$}AN~)wX>k=U!(7HcVsGds`r|HL*X+2{dnf;j zpGR4aKE_vHI)k^^_v<@c1>--UjUbh`f1c96&&z9<^C>~+x0ddtZMLaUx#T|}I71Vd zz~xZBRB+KiuaqK<(oPIgq<^?y46v9uB=?`I?}L{xY9D?Z7<~tE++fE&Xebt~1Bd0f ztXz5w&P+y#KhN8HgELuq6vs;Y&l!@bX0#@94w;{X?ZX8!2owy8gNJo4FYJABLfs)# zSGFka&nk{Xc3P;troz=h5yOa;`v*0IBKI~cT+I`%wu%_(|IjRSwK_1brfG=Qn=kRf zKV8V)<5*_qO|`jT7+rCgf3!m_e-bK-{q2(ZZ@kdsPr7o5}i+l6kN*C{FNQ|uF!M|*t3%y0Y}&9;lx(5)|1d$*{`|6!+Fn1Nu2l42PL-o^fO)q!)T!C|fgr&yb4 z?~8L%&Y!J+NXbm*GHze1zs?rXVK@*c<)ptn*;6<AIk3zhB^`mcUGk#A; z@;!d=r2sBse)WQYuHxr<`zn{zp>~iR0_|jJ=}v)fV3=r^nVC3j>_tamjw6#242!aX z8hM8x*<^5ELbwi(823msWr(ADA@K&6sLRX%2U_DK#v1(R6M9I@*CwN~8Z(WS7FVVZvO`@wezIxn9YC=nYd zNtS8MY{Rl_?@+&(A=qR8{m0)wi9#^t|4Iq^M$EqgVM-cT8VQ!#SYG`cEY)pEE;2v$F|CRUf3#>V>P6-Bv1!Z$X64Y9i+x-V06?pnRN-2;QZ60~Ir zt!`+ol&{Uv?lH-VY9~p-?6yM?5^PesHbg%pwHzcI23A)F^^JRkT-2+|yS(C8D77Wp;H`vMZ^>osyFS$=@vD(mZ-XQlqRelPq>Hf2+j3frN~&hLU7; zJ_c0o?Y5I6q;e1F$04a!bxvyVLB@HF_d)v!;5qT4kD%(o!Q%(pXO{cW$CAcq&=2?o zq<=&p9$+p3ImA%iASIMTj0(KMgb5%sBFWoDh|GuuAUgz*2@`u36Cx9aklhE!;&zu9 z5F(3%_y<=US~PbVa}o02l$1b`b&^`yD&vYf{4)yo5vMI&qEUzqjUe44C|`h00`?t4 zh)qHjUV&Wz*j|w=>k(pm#RBtQ=@|iUN=TNq2;YBCU1Pat? zK&Ke?-ItI~j4E7TQg7`GgF8h-3lQq1_cSB~LY)^vU%z#DA(sl+@}g(bE`}SKsQL1`3QkdvA|0F+|fXcrMpX12*sGI%6HLJ6UB>K$6VR6QhMg+^B4>Qg-SPC>JcNsA+)As9 z8tM5$R~q!YxXYM>uwPTs8(u0U3 z@uRPkVLEbcHn22x@7E1=msxTPuIQqd54$SBvT{V;(~K z)k$xTAbpHWz1FG*HkVg9!~!t<(EL%|*h45~FBLc{K5xPBqRGr(VCq?kd+!oXJx%BZ zMydeJ55R>aN!|fME>5t7nXyrz+Q8HqQr$R4I4HdeJW_+1+6H`u(w*(IJmF&MjzheO z2lfFhKqrm&8H5W^KnnfxK4X@y4&cUnmoO6{H^2IHmDgj*;Zb`wa*h1Q zHZOkL^_BM0*v$|m{U>w)o#r!n#g(LQ(2;% z#E?fTYmca)CKIw{ztvgTWV}e$&9{nhBby(yH@+AJfp0!{?uv;@Bga_X7kGvglX2jJ zMfzvT!Hr~xVh|Izfp`1>Df=JLKA#DdAHgk}$p*nVM z2Z$&Y%%1p&RYZu!^<8K;@*r0N&Um3RpiztE0~nS^J4W2AJ)_8GJ_S2~5!(4`QENt* zjax$EB-OF=JBcY{?!}A5Iq$n~!m*Pn8?>7UIKcoMdw1gCgsJbxB4CT&FGU&V`$kpjtc`;M0Ey+$C zprHQ{rfDj|wMTj#-VdiD#GQM%#L?@N|+Xzrpei9Gn0Ee=`o>oPf|6pXw5Lv zO(&V`bFxm5r%EMnmWNXy8+Y;=r6?yobAIOtq-aGpNs#aPedbPriBUKO<4KrN`?Uf( zQW@iyIg_w*0{u!Dvawd#AT~9o;aLPT6@%Xn5gEh)HEP*8vncHl4bS5{z}~cl!0u$4 zTNQx4vjgnMFbFL94zT%WAh3)(!6pN+Dm%fxfxsenfNfR9NNvmnmu^aMk|Wi3uCf;? zCV|onsQXycovR`WEh0N8wI8#OX$Nf>sHY_v*pEp)x#PuxR*1%Zg&p(zVr6rlgy+XV z_x`1&(uF96%}hPb;B9X_iCUT*Ff6!z>Qg*@pAqO|)X(>1(&T%zW~{eQj*sB`<%%95 zb+&w_wL(8d;`@yKK#<^j^f(WHsMIR(i2|Dm7^Tw(-^_O{dashzea1>fjLmvw~ zMlwcNQ`UlpnUN=*7WuvqG(` zj6Y=97z1!m1>h+N_v8+q?%|$9!IM9}oirOG3f~Sq1>@UEvzZX%;5gmasMrF%FaE(r zJnci$gl=m4q@lS{EGpgX9NbfUV$3N}6&lre*+GF95D~Drn zQfysqhQdnTd-w(V)Ac^d$d5wAW$9gRIpf4n)A>@6=mJwisp&xo04r3ya zy`242xV8cr2sS?Bei**v%X4gGP)7-T+j0^=M#@oD{11f}?qkJZ`In1MB1w6Z)#pbtGWkT5Jp5mKaOwq?=oEGb17e)!o;@HTD zQnRlm^JR1#rN>iApMN}ZoxAfg9E>7`A0^Goid&QpsS>r-CrEEo*TEy-dm=O%({l>u za0LB8J|hKtApYaaTiac75WSi&NehR~XYBB#H96uqVkwa-(Ybg9393oO#%dX=$k|cb z`{AZeFFE{rlU4l3qFcgvx-Z7(g|gB$C?SD>uMI=OOXstXpNjlO5&yB`7C*k#%i((L zD>Qc`BvKN*NI9I}y4kVuIrfv{rD}7QdP&tO5&_Vt;n7@G)^u@a<8$8Xkv+j*=c4A= z1+|dJx3FyySw+jT@wxUNzz;GNX6e%-g#lkD7e+0e$LqlAIE}@$Z9*;uQYBUx{FViv z_=jIdK}h!EI`s}%v+;TK@5BEPcM6(4Qrzt2NyxDAd2Kg~WryAzz`qgGX5;he$18Df z3@`L%AO6i*Q#QUk{eR&{dv+FjLyUXlM%8vFjeHA1aDwO&K`fpPtR`Y#f_yLd#REVs z#gSjnwLUn8pD_hZ`Ui9lOgI_Nu)Wx$$z%sUcn?N+k;w+-k^f8N zUvd4bU9f-G$-idhUjzB?egZs^|8MS$LwnZeThvR}xRBu=c-FD&t+yT#?DOgrEgE0q zKF?+Kh!FS)aUX~i)OstfCB}AwOG4lbfO9rxc6tA>n%wJf*{50&=l9^ej{d#|+2iwJ zfxo|BwbsaH zHu3ZmNFY~}m$I0}M+1M-zGPmh)en!6AR6SqSa$vBbj{@1a8ySX<1Yp2EWWIUv3%3VsTb(F*>uKnT{AwDN?*52d9Rl6qZTe zSI0JIxv||pOdt!fS!GcNNgdzYB3~@N4IMVCV|`%8ZFH1C(M;9(vX(=L?40641MZ;S z2g?Zm_ErbH9UQ){*KG)=;XyGGOaPBh;p}>~dIpA2G?Aq_f*QwAgcHT@{Ygl$NBy4V zji&2_f~pMeChM0gHWwvW&4okS4|EOW3A83PRJSOB5U&4P??CH~j3cPcq0J^~YJzD= zJc2q@zfy~KGl8!MZ&oe3Oj#KD`L;b|#Q)sig}Vt*hdPJ?nU( zPF9QjBKX~N&|xJ@oi`xr0mg|fwyEqCrIH*1oT#9)N~M}0B`RM-{J6@r3NuU-oPzL6 z0=DP#tY-9KjH%A8X8}o!FHYi}G1FBF;rkbE+oW#fXo#Q5P=W3{!nU^P>qfJg3ghD3 zUNjfI9$B%;i4_Vr5+ZQ+e!|UKd&WC65L118ww5z2L85Cul5p^EA3UDamI)Vmbt!WG zu;Z-2E(TmkbCZ!EGmC#BXu$?OXTdkAN(diXoMs8BT}3EQ?=*?6#7;LJcQr$ z={}1n*#~<$!1*}fJ+Z*Xcf<>5^tkMeYrISSO1Bv zeyQ1gQm=k*!Slt@KPJ*IaYkS7f3Kq{+{-CC%y@(N&^dA4=&ZB`Rja<9n$oWdwdx>sWbvdSw)y29jh%tYyjAl>O&QEIt|GJ_mXZk(|hAOn9 zju|m%y+HJSh;4uW%W1!c{#(T&)NLr5Ig1$O{IrWiw^LU<_bkNK4j|TB)i^iiSH9T_ zdA6=ZqV#qVxkdRRoKy}7Z^rQHc)^_6DRURhV^P+7iNu3$e!h2_M#{Ubo7h-xeLNha z;*@=-UVx@;KapF+h(6|X_(I^So6DrDX1G0NVi37RgPhdX5`2iKZM{Kt>aRu58d^uq z!$d@lk4a&+Z?JO`b){`<4fSaI_`Q-eIZ*SX45*j(zTCrF^{$$Z_VL!C85=a{{Y1g@ zA(n}Oyb3)Z*3B`x$74PyoNF(#A|`U0X8I(r+FkLvTep>bef6E5=PB4OqGNT_ucRZ^ zDYWdAUm4P&29IMcBIjE2zlKi|5fvoOyV1U>!}O-;uDVUXF=1SvZ!A4jYX5oY&|v35 zqT6Lv60zMeK}+tWS-O+!wM!0?S|Pa_yNJx^5WPM=x!>hlqoxzcs0YXa*Vu|9R!r0@ zSICK8N52^w`FQLh=|CQu0p3~O>(`aK>!Dk2IJicvU#28e*tkz$k(dWa-JF(6<-Y0^Q8 zh+I8_1`rTLIw&L%k=~`L5Ebbi5h+1RM1)B1--dJVeZD{6pYO-Xue{mW*_qjy+1WR{ zGgS=pS8!v=AUqdfzgWVK4OjJvaU9xH>-sx%2JcyYtTw0W9Js6py)iL-ET}A6uo98< zvM^hsHd`19uI5`ddu0B~x9NSK3j|dqZ(dKbS%s@Ec0MBRhJN5rz+wu`Sy#IiB=jb#eTF(ba=MPiw4F)~qrBtoD11%K|s#m9NA( z&x^kDfj1UEaO56hN4W)KLB0ryr32tBdl}mykb#~bNk%WE!XT{n7 zXM_8{?&fudK+s_7i_>WMF39~hrIy=eAM|EMdQG8341&tDrkdLS3lNtn!TSaZc9xD* zD7j{!&IcH(i2F5_hMH!d>;`WWA$QqvP~i{)rAMJRVZT73KZv6dBN%@jATLlG2%3AL z$>F>|fr8pzbN1_)*pAKT zdv)G;0r5dPhBC|3x^{i`EHsi3%dLC&fYO5SMAWL-#X|uezytZ6*>hhSR?6BKGRz(Oaa7zXPSt5sEj|%v$aoOd zbrj7J5Y4h6lQ|t?6uGxvwol4aAhBAka>DJ(%hgQchUU|QT6-bQ1kblmD0E}^lKdyB z(q=-oAyO{6kuw{q#+5b(7?;&bmU|O5^zqRNkFwYssfld{7%h>mqeeUE&Gm(bEvZ8x zf%Z)r-U|`}^S6tS0{T!Dxj@R21ep^uRHpq8yIqdonMBQYFX)YglmOp*$7qu%@9)i) zu@uRLW6<0_%7ec=NZg-`=3Fy+o9lnpGw6QlpugTlnG@j>hgZY;KQ(g6#>&SbYW&!SQ#x(lD+}vN}I-E$5SB zxhAuiGW?Qh%xKVrkcOGy7^YHNd3E7+$v`R`zgrs1WuL{2GL@^;y^R2xk0l_%UqD*T zU1i0aR-lgGKIr_tU1Q&ffnSj(zDB@85@~PL&cZ@*FneCie;-R(nw!A=obTcnngn|I z{NLY0h{uLtSBDiUSr~f7bB=TfHaGPd5Bd|}Dn0eja_JSxY5WPWc$2A z;vh%n8PaD|GMSp40?8?qIA8KXx&ucIJNSj-+6GI;^={o6cmgOBC)_l>YNk0XHvX>m`|=lF$A&6s*TJ z${#LkyiQ_SaePjR+WBV`vzX2drSIJixPPOlI74N+A6SW&5S!JF$lrkLb!+gWQquxM z=Y4dvmV$zmO_pM?7%$H33KYF7%w%fxA3{(lUyhH9xQMQ$S*_FD=DduWg~^^T_d_plOdh`d4^Q)#AQd>Gj`boF}hTh!O(}F9fQ9heqi$E_|#thl&FAp ze8(WDf?l@bv%)tyTep>H5;v-+V1?q2R=!zPfI9B#@m04BSolBo2^QE~Yn!^P^#|~N zoTZhpxiOuOayHE~#Pmg9!9h?qYclF?-L$DU7P%$z*g$3%RDGDM$M%9=N(1vGUF}$` z`EJP4=e~d&?6?+>9LuoU?A3OeER>ExBaMI2E5snlYu*OP@9|IDM4@)Xb?x!;!yK!< zZ^b66hFOIkw`Wu7YMabM&`7kETYBE0jAIftqU6Sxmbg?!e|#g+_~-f)sXw3+gM^Xm z!`eM6Ah}_>153HJlJ%?fFg6raF3l3maCp^2ZZcN?)zVj4tlCcKhYb69-*mLV2L=8;*wcb`}LFxO^_zH~dog429TEf%j6m#xft8rfO zUm>-*2aAQo+_jdcP1%fw{`dgLafjkj3l0-C^`9pci@hv|^iQF2R`%b=w|ODRfr4;P zJ6!)m^m2GJ`PxXWdeCb32dy{*0UtPz0f+UUkPFT)`W>zaXX1s7NU@zp* z<1k@T3(9fV{tYk142d;2U_@R_c41Ykrk<6BNtWL^Js+e-8VmTf%1!2l)aAb!`lXLY z8fYH5lVJ9?An2G!8PRedBT_BE<<*DN_EaWWGP%vXn_?}f?QyLRIABMf$l_{knhx@D z0^UQ~GTCuBB*KVFM)C`iO=lhjIZHt2zbH9>Fc;STQ&Xr?97aIAHB*Q69K+7LiiA0t ztW5n0rC&h5YS2R)x73eLT1(_^Wr?h{DCFFuMtzB*c<7BDGHGPKccy1s9n7tp_ON7~ zz_ag|Z`yoZF)A#!eW^@IiY(^=0pI}5|L()Y6)AuFj2a8>?Qjof`;LtBp@h~Fg~rU0EN-zr zo^Uf)#ok{sfE#k?JM(sW)Y%m`2e}LvZg}e1jqP>MR&vLJVS<9)k3?0 zYaockM*G{@Z1vKcuOTQnPR6+ZbNR(mS81t|8w9BzScdN~7mZmk^A4UJW}hPf0@DCx z0|le z@B-v68aBW9?yv0^X5d%qbU}$kDX_qH`gk}vKS1&^Lk79lot#LPO*(`PP%Cwh?$$=` zI-kWiUrQ?2Ny|_c&3TQ*l**)>0>Oo9CM)2HLF?h;2i%T39i&&Zd+nggBI?-DJaWZA z)(NO0No@bZ9Cqe_e>%*_;5q27k0}KXq+O|{H%6Z;5K2%=^#N z?h7GhfdmaYsP+A(#i6J*L6?$o3#3!n9ei7Om-wWu_~pObYPiovQC+ON1AUOSm#;ph z9MHfY0O+(jEmKrmVu4lt68Q^svJj>%R<9c8P2}L&US^)QK!`S~&8tmbi)UimviwTV ziNm~(TMt;tf~||UK4B4G>rc|CGsRB>=_X(rKb}pcd{%;&^`UVbhz0wHp#~xmCz$Bm z%b*}UI7vl~IUb-JU{u!&ii`Ma z3<(vsTiGg>a^OnI;n;W(6^L`mCgx=t2oahJv)G<8-VC>BMU>w8<5ED+7BN|GiEK!BtX&`79hu2=9v0ocB~+IB>LMIS z;%kRajk}GIFc?s{a+TW#B!{;=ZIKa5=tQB;{nBmKVOSINV7(&0~OvMnl@Ep5$G zUqV^J9}fl^iky6rCS5bwq=QzhnwtV%8QwlaAF!EB@5#FDb3eN_Xyz@P3B3{jWpaM4 zo1BczaC##4$gGOsAy9n(ifhY#S2fG ztD#intv~_+^Lgu#-R4O&Ysd5t3Y2vJ5SiS8Bub&n6JHh&?%t^&L9;{WY+d=o%D|ev z0xxKtZQ5)n0VNYbYwz>B|7t^hJ literal 0 HcmV?d00001