diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index bd5d60a0f..df0976755 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -1,7 +1,7 @@ import logging -from collections import UserDict +from collections import UserDict, defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import batched from threading import Lock from typing import Iterable, Sequence, TypeGuard @@ -33,6 +33,10 @@ class SlotOutOfRootsRange(Exception): ... class FrameCheckpoint: slot: SlotNumber # Slot for the state to get the trusted block roots from. duty_epochs: Sequence[EpochNumber] # NOTE: max 255 elements. + raw_attestations: defaultdict[ValidatorIndex, list[bool | None]] = field(init=False) + + def __post_init__(self): + self.raw_attestations = defaultdict(lambda: [None] * len(self.duty_epochs)) @dataclass @@ -158,7 +162,8 @@ def exec(self, checkpoint: FrameCheckpoint) -> int: duty_epoch: self._select_block_roots(block_roots, duty_epoch, checkpoint.slot) for duty_epoch in unprocessed_epochs } - self._process(block_roots, checkpoint.slot, unprocessed_epochs, duty_epochs_roots) + self._process(checkpoint, block_roots, unprocessed_epochs, duty_epochs_roots) + self._finalize(checkpoint) self.state.commit() return len(unprocessed_epochs) @@ -201,8 +206,8 @@ def _select_block_root_by_slot(block_roots: list[BlockRoot | None], checkpoint_s def _process( self, + checkpoint: FrameCheckpoint, checkpoint_block_roots: list[BlockRoot | None], - checkpoint_slot: SlotNumber, unprocessed_epochs: list[EpochNumber], duty_epochs_roots: dict[EpochNumber, tuple[list[SlotBlockRoot], list[SlotBlockRoot]]] ): @@ -211,8 +216,8 @@ def _process( futures = { executor.submit( self._check_duties, + checkpoint, checkpoint_block_roots, - checkpoint_slot, duty_epoch, *duty_epochs_roots[duty_epoch] ) @@ -228,11 +233,28 @@ def _process( executor.shutdown(wait=True, cancel_futures=True) logger.info({"msg": "The executor was shut down"}) + @timeit(lambda args, duration: logger.info({"msg": f"Checkpoint slot {args.checkpoint.slot} processing finalized in {duration:.2f} seconds"})) + def _finalize(self, checkpoint: FrameCheckpoint) -> None: + logger.info({"msg": "Finalizing checkpoint processing"}) + checkpoint_l_epoch = min(checkpoint.duty_epochs) + for validator_index, raw_attestations in checkpoint.raw_attestations.items(): + slashing_epoch = self.state.slashings.get(validator_index) + if slashing_epoch is not None: + if slashing_epoch <= checkpoint_l_epoch: + continue + stop_index = checkpoint.duty_epochs.index(slashing_epoch) + else: + stop_index = len(checkpoint.duty_epochs) + for epoch, att_result in zip(checkpoint.duty_epochs[:stop_index], raw_attestations[:stop_index]): + if att_result is None: + continue + self.state.increment_att_duty(epoch, validator_index, included=att_result) + @timeit(lambda args, duration: logger.info({"msg": f"Epoch {args.duty_epoch} processed in {duration:.2f} seconds"})) def _check_duties( self, + checkpoint: FrameCheckpoint, checkpoint_block_roots: list[BlockRoot | None], - checkpoint_slot: SlotNumber, duty_epoch: EpochNumber, duty_epoch_roots: list[SlotBlockRoot], next_epoch_roots: list[SlotBlockRoot], @@ -240,14 +262,18 @@ def _check_duties( logger.info({"msg": f"Processing epoch {duty_epoch}"}) att_committees = self._prepare_att_committees(duty_epoch) - propose_duties = self._prepare_propose_duties(duty_epoch, checkpoint_block_roots, checkpoint_slot) + propose_duties = self._prepare_propose_duties(duty_epoch, checkpoint_block_roots, checkpoint.slot) sync_committees = self._prepare_sync_committee(duty_epoch, duty_epoch_roots) + slashings = [] for slot, root in [*duty_epoch_roots, *next_epoch_roots]: missed_slot = root is None if missed_slot: continue - attestations, sync_aggregate = self.cc.get_block_attestations_and_sync(root) + ( + attestations, sync_aggregate, proposer_slashings, attester_slashings + ) = self.cc.get_block_attestations_and_sync(root) + slashings.extend(process_slashings(proposer_slashings, attester_slashings)) process_attestations(attestations, att_committees) if (slot, root) in duty_epoch_roots: propose_duties[slot].included = True @@ -256,13 +282,16 @@ def _check_duties( with lock: if duty_epoch not in self.state.unprocessed_epochs: raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed") + + for validator_index in slashings: + self.state.add_slashing(validator_index, duty_epoch) + # Because of slashing case we need to + # put aside attestation results for handling later (at the end of the checkpoint processing) + epoch_checkpoint_index = checkpoint.duty_epochs.index(duty_epoch) for att_committee in att_committees.values(): for att_duty in att_committee: - self.state.increment_att_duty( - duty_epoch, - att_duty.validator_index, - included=att_duty.included, - ) + checkpoint.raw_attestations[att_duty.validator_index][epoch_checkpoint_index] = att_duty.included + for sync_committee in sync_committees.values(): for sync_duty in sync_committee: self.state.increment_sync_duty( @@ -270,6 +299,7 @@ def _check_duties( sync_duty.validator_index, included=sync_duty.included, ) + for proposer_duty in propose_duties.values(): self.state.increment_prop_duty( duty_epoch, @@ -454,3 +484,17 @@ def _bytes_to_bool_list(bytes_: bytes, count: int | None = None) -> list[bool]: count = count if count is not None else len(bytes_) * 8 # copied from https://github.com/ethereum/py-ssz/blob/main/ssz/sedes/bitvector.py#L66 return [bool((bytes_[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(count)] + + +def process_slashings(proposer_slashings, attester_slashings): + slashed_indexes = [] + for proposer_slashing in proposer_slashings: + signed_header_1 = proposer_slashing['signed_header_1'] + slashed_indexes.append(signed_header_1['message']['proposer_index']) + + for attester_slashing in attester_slashings: + attestation_1 = attester_slashing['attestation_1'] + attestation_2 = attester_slashing['attestation_2'] + attesters = set(attestation_1['attesting_indices']).intersection(attestation_2['attesting_indices']) + slashed_indexes.extend(attesters) + return slashed_indexes diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 581197e5e..c3b4982ca 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -63,6 +63,7 @@ def merge(self, other: Self) -> None: type Frame = tuple[EpochNumber, EpochNumber] type StateData = dict[Frame, NetworkDuties] +type SlashingsData = dict[ValidatorIndex, EpochNumber] class State: @@ -77,6 +78,8 @@ class State: """ frames: list[Frame] data: StateData + # TODO: fetch slashings from l_epoch_first block state to have it in State initially + slashings: SlashingsData _epochs_to_process: tuple[EpochNumber, ...] _processed_epochs: set[EpochNumber] @@ -88,6 +91,7 @@ def __init__(self) -> None: self.data = {} self._epochs_to_process = tuple() self._processed_epochs = set() + self.slashings = {} EXTENSION = ".pkl" @@ -175,6 +179,9 @@ def increment_sync_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, inc frame = self.find_frame(epoch) self.data[frame].syncs[val_index].add_duty(included) + def add_slashing(self, val_index: ValidatorIndex, epoch: EpochNumber) -> None: + self.slashings[val_index] = epoch + def add_processed_epoch(self, epoch: EpochNumber) -> None: self._processed_epochs.add(epoch) diff --git a/src/providers/consensus/client.py b/src/providers/consensus/client.py index 36e7d9ef1..bdd4a6983 100644 --- a/src/providers/consensus/client.py +++ b/src/providers/consensus/client.py @@ -121,7 +121,7 @@ def get_block_details(self, state_id: SlotNumber | BlockRoot) -> BlockDetailsRes return BlockDetailsResponse.from_response(**data) @lru_cache(maxsize=variables.CSM_ORACLE_MAX_CONCURRENCY * 32 * 2) # threads count * blocks * epochs to check duties - def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> tuple[list[BlockAttestation], SyncAggregate]: + def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> tuple[list[BlockAttestation], SyncAggregate, list, list]: """Spec: https://ethereum.github.io/beacon-APIs/#/Beacon/getBlockV2""" data, _ = self._get( self.API_GET_BLOCK_DETAILS, @@ -133,8 +133,10 @@ def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> t attestations = [BlockAttestationResponse.from_response(**att) for att in data["message"]["body"]["attestations"]] sync = SyncAggregate.from_response(**data["message"]["body"]["sync_aggregate"]) + proposer_slashings = data["message"]["body"]["proposer_slashings"] + attester_slashings = data["message"]["body"]["attester_slashings"] - return cast(list[BlockAttestation], attestations), sync + return cast(list[BlockAttestation], attestations), sync, proposer_slashings, attester_slashings @list_of_dataclasses(SlotAttestationCommittee.from_response) def get_attestation_committees(