-
Notifications
You must be signed in to change notification settings - Fork 0
/
GraphNode.py
69 lines (52 loc) · 1.52 KB
/
GraphNode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, Any, List
class GraphNode(ABC):
    """Abstract base class for a node in a computation graph.

    Subclasses must implement :meth:`forwardPass` and :meth:`backwardPass`;
    the class cannot be instantiated directly. The base class only stores
    state: the node's value, its incoming edges (source nodes), two flags,
    and gradient bookkeeping (``gradients``, ``totalGradient``,
    ``paramGradients``).

    Args:
        isTrainable: Stored and exposed via the read-only ``isTrainable``
            property. The base class does not act on it.
        trackGradients: Initial value of the ``trackGradients`` property.
            The base class does not act on it.
    """

    def __init__(self, isTrainable=False, trackGradients=True):
        # Gradient bookkeeping; clearGradients() restores these initial values.
        # `gradients` is not read by the base class (only reset) —
        # NOTE(review): presumably populated by subclasses; confirm.
        self.gradients = []
        self.totalGradient = 0
        self._value = None
        self._inEdges = []
        self._isTrainable = isTrainable
        self._trackGradients = trackGradients
        self._paramGradients = []

    @abstractmethod
    def forwardPass(self):
        """Forward computation for this node; must be implemented by subclasses."""

    @abstractmethod
    def backwardPass(self):
        """Backward computation for this node; must be implemented by subclasses."""

    @property
    def isTrainable(self):
        """Trainable flag given at construction (read-only)."""
        return self._isTrainable

    @property
    def value(self):
        """The node's current value; ``None`` until assigned."""
        return self._value

    @value.setter
    def value(self, v):
        self._value = v

    @property
    def inEdges(self):
        """List of source nodes registered as inputs to this node.

        The setter stores the given list object itself (no copy), so later
        :meth:`registerInEdges` calls extend that same list in place.
        """
        return self._inEdges

    @inEdges.setter
    def inEdges(self, inEdges):
        self._inEdges = inEdges

    @property
    def trackGradients(self):
        """Gradient-tracking flag (settable); not read by the base class."""
        return self._trackGradients

    @trackGradients.setter
    def trackGradients(self, trackGradients):
        self._trackGradients = trackGradients

    @property
    def paramGradients(self):
        """List of parameter gradients.

        This is a read-only attribute: the list may be mutated in place, but
        the attribute itself cannot be reassigned.
        """
        return self._paramGradients

    def clearGradients(self):
        """Reset all gradient bookkeeping to its freshly constructed state."""
        # `paramGradients` is a getter-only property, so assigning through it
        # raised AttributeError; reset the backing field directly.
        self._paramGradients = []
        self.totalGradient = 0
        self.gradients = []

    def receiveGradient(self, grad):
        """Add ``grad`` to the running ``totalGradient``."""
        self.totalGradient += grad

    def registerInEdges(self, sourceNodes: list[GraphNode]):
        """Append ``sourceNodes`` to this node's incoming edges, in place."""
        self._inEdges += sourceNodes

    def addToParamValues(self, paramStep):
        """Apply a parameter step; a no-op in the base class.

        NOTE(review): presumably overridden by trainable subclasses — confirm.
        """