Weight quantization #88

Open · wants to merge 8 commits into master
55 changes: 54 additions & 1 deletion ml_genn/ml_genn/compilers/compiled_network.py
@@ -1,11 +1,15 @@
import numpy as np

from typing import List, Mapping, Optional, Sequence, Union
from ..serialisers import Serialiser
from ..utils.callback_list import CallbackList
from ..utils.network import PopulationType

from ..utils.module import get_object
from ..utils.network import get_underlying_pop

from ..serialisers import default_serialisers

OutputType = Union[np.ndarray, List[np.ndarray]]


@@ -14,12 +18,61 @@ class CompiledNetwork:

def __init__(self, genn_model, neuron_populations,
connection_populations, communicator,
num_recording_timesteps=None):
num_recording_timesteps=None,
checkpoint_connection_vars: list = [],
checkpoint_population_vars: list = []):
self.genn_model = genn_model
self.neuron_populations = neuron_populations
self.connection_populations = connection_populations
self.communicator = communicator
self.num_recording_timesteps = num_recording_timesteps
self.checkpoint_connection_vars = checkpoint_connection_vars
self.checkpoint_population_vars = checkpoint_population_vars

# Build set of synapse groups with checkpoint variables
self.checkpoint_synapse_groups = set(
connection_populations[c]
for c, _ in self.checkpoint_connection_vars)

def save_connectivity(self, keys=(), serialiser="numpy"):
# Create serialiser
serialiser = get_object(serialiser, Serialiser, "Serialiser",
default_serialisers)

# Loop through connections and their corresponding synapse groups
for c, genn_pop in self.connection_populations.items():
# If synapse group has ragged connectivity, download
# connectivity and save pre and postsynaptic indices
if genn_pop.is_ragged:
genn_pop.pull_connectivity_from_device()
serialiser.serialise(keys + (c, "pre_ind"),
genn_pop.get_sparse_pre_inds())
serialiser.serialise(keys + (c, "post_ind"),
genn_pop.get_sparse_post_inds())

def save(self, keys=(), serialiser="numpy"):
# Create serialiser
serialiser = get_object(serialiser, Serialiser, "Serialiser",
default_serialisers)

# Loop through synapse groups with variables to be checkpointed
for genn_pop in self.checkpoint_synapse_groups:
# If synapse group has ragged connectivity, download
# connectivity so variables can be accessed correctly
if genn_pop.is_ragged:
genn_pop.pull_connectivity_from_device()

# Loop through connection variables to checkpoint
for c, v in self.checkpoint_connection_vars:
genn_pop = self.connection_populations[c]
genn_pop.pull_var_from_device(v)
serialiser.serialise(keys + (c, v), genn_pop.get_var_values(v))

# Loop through population variables to checkpoint
for p, v in self.checkpoint_population_vars:
genn_pop = self.neuron_populations[p]
genn_pop.pull_var_from_device(v)
serialiser.serialise(keys + (p, v), genn_pop.vars[v].view)
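For context, here is how the relocated checkpointing API might be used once a network is compiled. This is a hedged usage sketch, not part of the diff; the compile step, the training loop and the per-epoch key are illustrative assumptions:

# Illustrative usage sketch (assumes "network" and "compiler" are built
# elsewhere; ml_genn compiled networks are used as context managers)
compiled_net = compiler.compile(network)
with compiled_net:
    for epoch in range(10):
        # ... training happens here ...
        # Serialise sparse pre/post indices for ragged connections
        compiled_net.save_connectivity(keys=(epoch,), serialiser="numpy")
        # Serialise all checkpointed connection and population variables
        compiled_net.save(keys=(epoch,), serialiser="numpy")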

def set_input(self, inputs: dict):
# Loop through populations
54 changes: 3 additions & 51 deletions ml_genn/ml_genn/compilers/compiled_training_network.py
@@ -5,17 +5,15 @@
from ..callbacks import BatchProgressBar
from ..connectivity.sparse_base import SparseBase
from ..metrics import Metric
from ..serialisers import Serialiser
from ..utils.callback_list import CallbackList
from ..utils.data import MetricsType

from ..utils.data import (batch_dataset, get_dataset_size,
permute_dataset, split_dataset)
from ..utils.module import get_object, get_object_mapping
from ..utils.module import get_object_mapping
from ..utils.network import get_underlying_pop

from ..metrics import default_metrics
from ..serialisers import default_serialisers

class CompiledTrainingNetwork(CompiledNetwork):
def __init__(self, genn_model, neuron_populations,
@@ -28,23 +26,17 @@ def __init__(self, genn_model, neuron_populations,
reset_time_between_batches: bool = True):
super(CompiledTrainingNetwork, self).__init__(
genn_model, neuron_populations, connection_populations,
communicator, example_timesteps)
communicator, example_timesteps, checkpoint_connection_vars,
checkpoint_population_vars)

self.losses = losses
self.optimiser = optimiser
self.example_timesteps = example_timesteps
self.base_train_callbacks = base_train_callbacks
self.base_validate_callbacks = base_validate_callbacks
self.optimiser_custom_updates = optimiser_custom_updates
self.checkpoint_connection_vars = checkpoint_connection_vars
self.checkpoint_population_vars = checkpoint_population_vars
self.reset_time_between_batches = reset_time_between_batches

# Build set of synapse groups with checkpoint variables
self.checkpoint_synapse_groups = set(
connection_populations[c]
for c, _ in self.checkpoint_connection_vars)

def train(self, x: dict, y: dict, num_epochs: int,
start_epoch: int = 0, shuffle: bool = True,
metrics: MetricsType = "sparse_categorical_accuracy",
@@ -196,46 +188,6 @@ def train(self, x: dict, y: dict, num_epochs: int,
else:
return train_metrics, train_callback_list.get_data()

def save_connectivity(self, keys=(), serialiser="numpy"):
# Create serialiser
serialiser = get_object(serialiser, Serialiser, "Serialiser",
default_serialisers)

# Loop through connections and their corresponding synapse groups
for c, genn_pop in self.connection_populations.items():
# If synapse group has ragged connectivity, download
# connectivity and save pre and postsynaptic indices
if genn_pop.is_ragged:
genn_pop.pull_connectivity_from_device()
serialiser.serialise(keys + (c, "pre_ind"),
genn_pop.get_sparse_pre_inds())
serialiser.serialise(keys + (c, "post_ind"),
genn_pop.get_sparse_post_inds())

def save(self, keys=(), serialiser="numpy"):
# Create serialiser
serialiser = get_object(serialiser, Serialiser, "Serialiser",
default_serialisers)

# Loop through synapse groups with variables to be checkpointed
for genn_pop in self.checkpoint_synapse_groups:
# If synapse group has ragged connectivity, download
# connectivity so variables can be accessed correctly
if genn_pop.is_ragged:
genn_pop.pull_connectivity_from_device()

# Loop through connection variables to checkpoint
for c, v in self.checkpoint_connection_vars:
genn_pop = self.connection_populations[c]
genn_pop.pull_var_from_device(v)
serialiser.serialise(keys + (c, v), genn_pop.get_var_values(v))

# Loop through population variables to checkpoint
for p, v in self.checkpoint_population_vars:
genn_pop = self.neuron_populations[p]
genn_pop.pull_var_from_device(v)
serialiser.serialise(keys + (p, v), genn_pop.vars[v].view)

def _validate_batch(self, batch: int, x: dict, y: dict, metrics,
callback_list: CallbackList):
# Start batch
81 changes: 75 additions & 6 deletions ml_genn/ml_genn/compilers/event_prop_compiler.py
@@ -13,6 +13,7 @@

from .compiler import Compiler
from .compiled_training_network import CompiledTrainingNetwork
from .weight_quantisation import WeightQuantiseBatch, WeightQuantiseTrain
from ..callbacks import (BatchProgressBar, Callback, CustomUpdateOnBatchBegin,
CustomUpdateOnBatchEnd, CustomUpdateOnTimestepEnd)
from ..communicators import Communicator
@@ -34,7 +35,8 @@
from .compiler import create_reset_custom_update
from ..utils.module import get_object, get_object_mapping
from ..utils.network import get_underlying_pop
from ..utils.value import is_value_constant
from ..utils.quantisation import quantise_signed
from ..utils.value import is_value_constant, is_value_initializer

from pygenn.genn_wrapper import (SynapseMatrixType_DENSE_INDIVIDUALG,
SynapseMatrixType_SPARSE_INDIVIDUALG,
@@ -131,6 +133,8 @@ def __init__(self, losses, readouts, backend_name):
self.timestep_softmax_populations = []
self.feedback_connections = []
self.update_trial_pops = []
self.batch_quantise_connections = []
self.train_quantise_connections = []

def add_neuron_reset_vars(self, pop, reset_vars,
reset_event_ring, reset_v_ring):
@@ -332,6 +336,8 @@ def __init__(self, example_timesteps: int, losses, optimiser="adam",
reg_nu_upper: float = 0.0, max_spikes: int = 500,
strict_buffer_checking: bool = False,
per_timestep_loss: bool = False, dt: float = 1.0,
quantise_num_weight_bits: int = None,
quantise_weight_percentile: float = 99.0,
batch_size: int = 1, rng_seed: int = 0,
kernel_profiling: bool = False,
communicator: Communicator = None, **genn_kwargs):
@@ -352,6 +358,8 @@ def __init__(self, example_timesteps: int, losses, optimiser="adam",
self.max_spikes = max_spikes
self.strict_buffer_checking = strict_buffer_checking
self.per_timestep_loss = per_timestep_loss
self.quantise_num_weight_bits = quantise_num_weight_bits
self.quantise_weight_percentile = quantise_weight_percentile
self._optimiser = get_object(optimiser, Optimiser, "Optimiser",
default_optimisers)
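
For reference, a hedged sketch of how the new quantisation options could be passed when constructing the compiler; the loss name, timestep count and network details are placeholders, not taken from this PR:

# Illustrative sketch: compile with 8-bit signed weight quantisation
from ml_genn.compilers import EventPropCompiler

compiler = EventPropCompiler(
    example_timesteps=500,
    losses="sparse_categorical_crossentropy",
    optimiser="adam",
    quantise_num_weight_bits=8,        # None (the default) disables quantisation
    quantise_weight_percentile=99.0,   # clamp scale to 99th weight percentile
    batch_size=32)
compiled_net = compiler.compile(network)   # "network" built elsewhere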

@@ -750,6 +758,7 @@ def build_weight_update_model(self, conn, connect_snippet, compile_state):
"support heterogeneous delays")

# If this is some form of trainable connectivity
quantise_enabled = (self.quantise_num_weight_bits is not None)
if connect_snippet.trainable:
# **NOTE** this is probably not necessary as
# it's also checked in build_neuron_model
@@ -771,16 +780,58 @@ def build_weight_update_model(self, conn, connect_snippet, compile_state):
model=deepcopy(weight_update_model_atomic_cuda if use_atomic
else weight_update_model),
param_vals={"TauSyn": tau_syn},
var_vals={"g": connect_snippet.weight, "Gradient": 0.0})
var_vals={"Gradient": 0.0})
# Add weights to list of checkpoint vars
compile_state.checkpoint_connection_vars.append((conn, "g"))

# Add connection to list of connections to optimise
compile_state.weight_optimiser_connections.append(conn)

# If quantisation is enabled
if quantise_enabled:
# Initially zero the quantised forward weight
wum.var_vals["g"] = 0.0

# Add additional state variable to hold unquantised weight
wum.add_var("gBack", "scalar", connect_snippet.weight,
VarAccess_READ_ONLY)

# Also checkpoint this variable
compile_state.checkpoint_connection_vars.append(
(conn, "gBack"))

# Add connection to list of those
# requiring quantisation every batch
compile_state.batch_quantise_connections.append(conn)
# Otherwise, initialise forward weights directly
else:
wum.var_vals["g"] = connect_snippet.weight
# Otherwise, e.g. it's a pooling layer
else:
wum = WeightUpdateModel(model=deepcopy(static_weight_update_model),
param_vals={"g": connect_snippet.weight})
wum = WeightUpdateModel(
model=deepcopy(static_weight_update_model))

# If quantisation is enabled
if quantise_enabled:
# If weight is initialised on device
if is_value_initializer(connect_snippet.weight):
# Initially zero the quantised forward weight
# and force it to be implemented as a variable
wum.param_vals["g"] = 0.0
wum.make_param_var("g")

# Add connection to list of those
# requiring quantisation at start of training
compile_state.train_quantise_connections.append(conn)
# Otherwise, directly initialise
# forward weights with quantised version
else:
wum.param_vals["g"] = quantise_signed(
connect_snippet.weight, self.quantise_num_weight_bits,
self.quantise_weight_percentile)
# Otherwise, initialise forward weights directly
else:
wum.param_vals["g"] = connect_snippet.weight

# If source neuron isn't an input neuron
source_neuron = conn.source().neuron
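
The quantise_signed call above comes from the new ml_genn.utils.quantisation module, which this excerpt doesn't show. Below is a minimal sketch of what a percentile-clamped signed quantiser with that signature could look like, assuming symmetric rounding to 2**(num_bits - 1) - 1 levels either side of zero; the PR's actual implementation may differ:

import numpy as np

def quantise_signed_sketch(weights, num_bits, percentile):
    # Assumed behaviour: take the scale from a high percentile of
    # |weights| so a few outliers don't stretch the representable range
    weights = np.asarray(weights)
    max_val = np.percentile(np.abs(weights), percentile)
    if max_val == 0.0:
        return np.zeros_like(weights)
    # Symmetric signed grid with 2**(num_bits - 1) - 1 positive levels
    num_levels = 2 ** (num_bits - 1) - 1
    scale = max_val / num_levels
    # Clamp outliers, then round to the nearest representable level
    return np.round(np.clip(weights, -max_val, max_val) / scale) * scale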
@@ -789,8 +840,11 @@ def build_weight_update_model(self, conn, connect_snippet, compile_state):
compile_state.feedback_connections.append(conn)

# If it's LIF, add additional event code to backpropagate gradient
weight = ("gBack"
if quantise_enabled and connect_snippet.trainable
else "g")
if isinstance(source_neuron, LeakyIntegrateFire):
wum.append_event_code("$(addToPre, $(g) * ($(LambdaV_post) - $(LambdaI_post)));")
wum.append_event_code(f"$(addToPre, $({weight}) * ($(LambdaV_post) - $(LambdaI_post)));")

# Return weight update model
return wum
@@ -805,9 +859,10 @@ def create_compiled_network(self, genn_model, neuron_populations,
optimiser_custom_updates = []
for i, c in enumerate(compile_state.weight_optimiser_connections):
genn_pop = connection_populations[c]
weight = "g" if self.quantise_num_weight_bits is None else "gBack"
optimiser_custom_updates.append(
self._create_optimiser_custom_update(
f"Weight{i}", create_wu_var_ref(genn_pop, "g"),
f"Weight{i}", create_wu_var_ref(genn_pop, weight),
create_wu_var_ref(genn_pop, "Gradient"), genn_model))

# Add per-batch softmax custom updates for each population that requires them
@@ -867,6 +922,20 @@ def create_compiled_network(self, genn_model, neuron_populations,
base_train_callbacks.append(CustomUpdateOnBatchBegin("Reset"))
base_validate_callbacks.append(CustomUpdateOnBatchBegin("Reset"))

# Add batch quantisation callbacks for connections that require it
for c in compile_state.batch_quantise_connections:
base_train_callbacks.append(
WeightQuantiseBatch(connection_populations[c], "gBack", "g",
self.quantise_weight_percentile,
self.quantise_num_weight_bits))

# Add train quantisation callbacks for connections that require it
for c in compile_state.train_quantise_connections:
base_train_callbacks.append(
WeightQuantiseTrain(connection_populations[c], "g", "g",
self.quantise_weight_percentile,
self.quantise_num_weight_bits))
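
WeightQuantiseBatch and WeightQuantiseTrain live in the new weight_quantisation module, which isn't shown in this excerpt. A hedged sketch of what the batch variant might do, assuming the Callback base class provides an on_batch_begin hook and the usual PyGeNN pull/push variable methods:

from ml_genn.callbacks import Callback
from ml_genn.utils.quantisation import quantise_signed

class WeightQuantiseBatchSketch(Callback):
    """Illustrative sketch: refresh the quantised forward weights ("g")
    from the full-precision copy ("gBack") at the start of each batch."""

    def __init__(self, synapse_group, source_var, dest_var,
                 percentile, num_weight_bits):
        self.synapse_group = synapse_group
        self.source_var = source_var    # full-precision weights, e.g. "gBack"
        self.dest_var = dest_var        # quantised weights, e.g. "g"
        self.percentile = percentile
        self.num_weight_bits = num_weight_bits

    def on_batch_begin(self, batch):
        sg = self.synapse_group
        # Download the weights the optimiser updated last batch
        sg.pull_var_from_device(self.source_var)
        # Quantise into the forward weight view and upload for this batch
        sg.vars[self.dest_var].view[:] = quantise_signed(
            sg.vars[self.source_var].view, self.num_weight_bits,
            self.percentile)
        sg.push_var_to_device(self.dest_var)

If this reading is right, the scheme mirrors the usual straight-through approach to quantisation-aware training: gradients accumulate in full precision in gBack, while the forward pass always sees the quantised g.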

return CompiledTrainingNetwork(
genn_model, neuron_populations, connection_populations,
self.communicator, compile_state.losses, self._optimiser,