Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 57 additions & 13 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import UserDict
from collections import UserDict, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from dataclasses import dataclass, field
from itertools import batched
from threading import Lock
from typing import Iterable, Sequence, TypeGuard
Expand Down Expand Up @@ -33,6 +33,10 @@ class SlotOutOfRootsRange(Exception): ...
class FrameCheckpoint:
slot: SlotNumber # Slot for the state to get the trusted block roots from.
duty_epochs: Sequence[EpochNumber] # NOTE: max 255 elements.
raw_attestations: defaultdict[ValidatorIndex, list[bool | None]] = field(init=False)

def __post_init__(self):
self.raw_attestations = defaultdict(lambda: [None] * len(self.duty_epochs))


@dataclass
Expand Down Expand Up @@ -158,7 +162,8 @@ def exec(self, checkpoint: FrameCheckpoint) -> int:
duty_epoch: self._select_block_roots(block_roots, duty_epoch, checkpoint.slot)
for duty_epoch in unprocessed_epochs
}
self._process(block_roots, checkpoint.slot, unprocessed_epochs, duty_epochs_roots)
self._process(checkpoint, block_roots, unprocessed_epochs, duty_epochs_roots)
self._finalize(checkpoint)
self.state.commit()
return len(unprocessed_epochs)

Expand Down Expand Up @@ -201,8 +206,8 @@ def _select_block_root_by_slot(block_roots: list[BlockRoot | None], checkpoint_s

def _process(
self,
checkpoint: FrameCheckpoint,
checkpoint_block_roots: list[BlockRoot | None],
checkpoint_slot: SlotNumber,
unprocessed_epochs: list[EpochNumber],
duty_epochs_roots: dict[EpochNumber, tuple[list[SlotBlockRoot], list[SlotBlockRoot]]]
):
Expand All @@ -211,8 +216,8 @@ def _process(
futures = {
executor.submit(
self._check_duties,
checkpoint,
checkpoint_block_roots,
checkpoint_slot,
duty_epoch,
*duty_epochs_roots[duty_epoch]
)
Expand All @@ -228,26 +233,47 @@ def _process(
executor.shutdown(wait=True, cancel_futures=True)
logger.info({"msg": "The executor was shut down"})

@timeit(lambda args, duration: logger.info({"msg": f"Checkpoint slot {args.checkpoint.slot} processing finalized in {duration:.2f} seconds"}))
def _finalize(self, checkpoint: FrameCheckpoint) -> None:
logger.info({"msg": "Finalizing checkpoint processing"})
checkpoint_l_epoch = min(checkpoint.duty_epochs)
for validator_index, raw_attestations in checkpoint.raw_attestations.items():
slashing_epoch = self.state.slashings.get(validator_index)
if slashing_epoch is not None:
if slashing_epoch <= checkpoint_l_epoch:
continue
stop_index = checkpoint.duty_epochs.index(slashing_epoch)
else:
stop_index = len(checkpoint.duty_epochs)
for epoch, att_result in zip(checkpoint.duty_epochs[:stop_index], raw_attestations[:stop_index]):
if att_result is None:
continue
self.state.increment_att_duty(epoch, validator_index, included=att_result)

@timeit(lambda args, duration: logger.info({"msg": f"Epoch {args.duty_epoch} processed in {duration:.2f} seconds"}))
def _check_duties(
self,
checkpoint: FrameCheckpoint,
checkpoint_block_roots: list[BlockRoot | None],
checkpoint_slot: SlotNumber,
duty_epoch: EpochNumber,
duty_epoch_roots: list[SlotBlockRoot],
next_epoch_roots: list[SlotBlockRoot],
):
logger.info({"msg": f"Processing epoch {duty_epoch}"})

att_committees = self._prepare_att_committees(duty_epoch)
propose_duties = self._prepare_propose_duties(duty_epoch, checkpoint_block_roots, checkpoint_slot)
propose_duties = self._prepare_propose_duties(duty_epoch, checkpoint_block_roots, checkpoint.slot)
sync_committees = self._prepare_sync_committee(duty_epoch, duty_epoch_roots)
slashings = []

for slot, root in [*duty_epoch_roots, *next_epoch_roots]:
missed_slot = root is None
if missed_slot:
continue
attestations, sync_aggregate = self.cc.get_block_attestations_and_sync(root)
(
attestations, sync_aggregate, proposer_slashings, attester_slashings
) = self.cc.get_block_attestations_and_sync(root)
slashings.extend(process_slashings(proposer_slashings, attester_slashings))
process_attestations(attestations, att_committees)
if (slot, root) in duty_epoch_roots:
propose_duties[slot].included = True
Expand All @@ -256,20 +282,24 @@ def _check_duties(
with lock:
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")

for validator_index in slashings:
self.state.add_slashing(validator_index, duty_epoch)
# Because of slashing case we need to
# put aside attestation results for handling later (at the end of the checkpoint processing)
epoch_checkpoint_index = checkpoint.duty_epochs.index(duty_epoch)
for att_committee in att_committees.values():
for att_duty in att_committee:
self.state.increment_att_duty(
duty_epoch,
att_duty.validator_index,
included=att_duty.included,
)
checkpoint.raw_attestations[att_duty.validator_index][epoch_checkpoint_index] = att_duty.included

for sync_committee in sync_committees.values():
for sync_duty in sync_committee:
self.state.increment_sync_duty(
duty_epoch,
sync_duty.validator_index,
included=sync_duty.included,
)

for proposer_duty in propose_duties.values():
self.state.increment_prop_duty(
duty_epoch,
Expand Down Expand Up @@ -454,3 +484,17 @@ def _bytes_to_bool_list(bytes_: bytes, count: int | None = None) -> list[bool]:
count = count if count is not None else len(bytes_) * 8
# copied from https://github.com/ethereum/py-ssz/blob/main/ssz/sedes/bitvector.py#L66
return [bool((bytes_[bit_index // 8] >> bit_index % 8) % 2) for bit_index in range(count)]


def process_slashings(proposer_slashings, attester_slashings):
slashed_indexes = []
for proposer_slashing in proposer_slashings:
signed_header_1 = proposer_slashing['signed_header_1']
slashed_indexes.append(signed_header_1['message']['proposer_index'])

for attester_slashing in attester_slashings:
attestation_1 = attester_slashing['attestation_1']
attestation_2 = attester_slashing['attestation_2']
attesters = set(attestation_1['attesting_indices']).intersection(attestation_2['attesting_indices'])
slashed_indexes.extend(attesters)
return slashed_indexes
7 changes: 7 additions & 0 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def merge(self, other: Self) -> None:

type Frame = tuple[EpochNumber, EpochNumber]
type StateData = dict[Frame, NetworkDuties]
type SlashingsData = dict[ValidatorIndex, EpochNumber]


class State:
Expand All @@ -77,6 +78,8 @@ class State:
"""
frames: list[Frame]
data: StateData
# TODO: fetch slashings from l_epoch_first block state to have it in State initially
slashings: SlashingsData

_epochs_to_process: tuple[EpochNumber, ...]
_processed_epochs: set[EpochNumber]
Expand All @@ -88,6 +91,7 @@ def __init__(self) -> None:
self.data = {}
self._epochs_to_process = tuple()
self._processed_epochs = set()
self.slashings = {}

EXTENSION = ".pkl"

Expand Down Expand Up @@ -175,6 +179,9 @@ def increment_sync_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, inc
frame = self.find_frame(epoch)
self.data[frame].syncs[val_index].add_duty(included)

def add_slashing(self, val_index: ValidatorIndex, epoch: EpochNumber) -> None:
self.slashings[val_index] = epoch

def add_processed_epoch(self, epoch: EpochNumber) -> None:
self._processed_epochs.add(epoch)

Expand Down
6 changes: 4 additions & 2 deletions src/providers/consensus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_block_details(self, state_id: SlotNumber | BlockRoot) -> BlockDetailsRes
return BlockDetailsResponse.from_response(**data)

@lru_cache(maxsize=variables.CSM_ORACLE_MAX_CONCURRENCY * 32 * 2) # threads count * blocks * epochs to check duties
def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> tuple[list[BlockAttestation], SyncAggregate]:
def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> tuple[list[BlockAttestation], SyncAggregate, list, list]:
"""Spec: https://ethereum.github.io/beacon-APIs/#/Beacon/getBlockV2"""
data, _ = self._get(
self.API_GET_BLOCK_DETAILS,
Expand All @@ -133,8 +133,10 @@ def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> t

attestations = [BlockAttestationResponse.from_response(**att) for att in data["message"]["body"]["attestations"]]
sync = SyncAggregate.from_response(**data["message"]["body"]["sync_aggregate"])
proposer_slashings = data["message"]["body"]["proposer_slashings"]
attester_slashings = data["message"]["body"]["attester_slashings"]

return cast(list[BlockAttestation], attestations), sync
return cast(list[BlockAttestation], attestations), sync, proposer_slashings, attester_slashings

@list_of_dataclasses(SlotAttestationCommittee.from_response)
def get_attestation_committees(
Expand Down
Loading