Skip to content

Commit

Permalink
delay protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolayBlagoev committed Apr 2, 2024
1 parent 9163a04 commit 4d7646b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
27 changes: 27 additions & 0 deletions deccom/protocols/delayprotocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import asyncio
from typing import Callable, Union
from deccom.peers.peer import Peer
from deccom.protocols.abstractprotocol import AbstractProtocol
from deccom.protocols.streamprotocol import StreamProtocol
from deccom.protocols.wrappers import bindfrom, bindto
from deccom.utils.common import *

class DelayProtocol(AbstractProtocol):
def __init__(self, delay_map, submodule=None, callback: Callable[[tuple[str, int], bytes], None] = ...):
self.stream_callback = lambda data, node_id,addr: ...
self.delay_map = delay_map
super().__init__(submodule, callback)

@bindfrom("stream_callback")
def process_data(self,data,node_id,addr):
p = self.get_peer(node_id)
loop = asyncio.get_event_loop()
dl = self.delay_map(p.pub_key, self.peer.pub_key)

loop.call_later(dl[0]/1000 + len(data)/(1024*1024*dl[1]), self.stream_callback, data, node_id, addr)
#self.stream_callback(data,node_id,addr)


@bindto("get_peer")
def get_peer(self, id: bytes) -> Union[Peer,None]:
return None
2 changes: 1 addition & 1 deletion swarmprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def process_datagram(self, addr: tuple[str, int], data: bytes):
print("ADDING TO SAME STAGE")
elif stage == (self.stage + 1) % self.pipeline_size:
if self.next_stage.get(nodeid) == None:
self.next_stage[nodeid] = PeerClassification(nodeid, 15000.0)
self.next_stage[nodeid] = PeerClassification(nodeid, 0.5)
return
elif data[0] == SwarmProtocol.COMPLETE:
# print("COMPLETE")
Expand Down
31 changes: 29 additions & 2 deletions swarmtrainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from deccom.protocols.delayprotocol import DelayProtocol
from deccom.protocols.peerdiscovery.kademliadiscovery import KademliaDiscovery
from gpt_distributed import GPTStageFirst, GPTStageLast, GPTStageMiddle
from sys import argv
Expand All @@ -7,6 +8,7 @@
from deccom.protocols.defaultprotocol import DefaultProtocol
from deccom.peers import Peer
from deccom.protocols.streamprotocol import StreamProtocol
from swarmprotocol import SwarmProtocol
from trainingnode import TrainingNode
from faultprotocol import FaultProtocol
from task_datasets.qqp import get_glue_qqp_train_data_loader
Expand Down Expand Up @@ -84,13 +86,38 @@ def forward(self, x):
return F.log_softmax(x)


delay_bandwidth_dict = {
"0-1": (143, 0.007),
"0-2": (172, 0.006),
"0-3": (11, 0.007),
"0-4": (100, 0.004),
"0-5": (86, 0.010),
"1-2": (34, 0.010),
"1-3": (130, 0.006),
"1-4": (223, 0.002),
"1-5": (210, 0.002),
"2-3": (159, 0.005),
"2-4": (235, 0.003),
"2-5": (238, 0.010),
"3-4": (99, 0.003),
"3-5": (86, 0.010),
"4-5": (14, 0.011),



}
def delay_map(p1,p2):
if delay_bandwidth_dict.get(p1+"-"+p2) != None:
return delay_bandwidth_dict.get(p1+"-"+p2)
else:
delay_bandwidth_dict.get(p2+"-"+p1)
protocol = DefaultProtocol()
gossip = KademliaDiscovery([],interval=12)
gossip.set_lower(protocol)
stream = StreamProtocol(False)
stream.set_lower(gossip)
delayer= DelayProtocol(delay_map=delay_map)
delayer.set_lower(stream)
net = None
train_loader = None
n = Peer(("127.0.0.1", 10015))
Expand All @@ -115,8 +142,8 @@ def forward(self, x):

optimizer = optim.SGD(net.parameters(), lr=learning_rate,
momentum=momentum)
training = FaultProtocol(int(argv[1]) % 6,net,optimizer,train_loader)
training.set_lower(stream)
training = SwarmProtocol(pipeline_size=3, stage=int(argv[1])%3, net = net, optimizer=optimizer, max_iterations=10, microbatches=3)
training.set_lower(delayer)
me = TrainingNode( Peer(None, pub_key=argv[1]), training,"127.0.0.1", 10015 if argv[1] == "0" else None)
print( "TCP", me.tcp_port)

Expand Down

0 comments on commit 4d7646b

Please sign in to comment.