Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:HIEST适应多维输入,将双联通向量搜寻从递归改为循环 #423

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions libcity/data/dataset/dataset_subclass/hiest_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,56 @@ def addEdge(self, u, v):
rooted with current vertex
st -- >> To store visited edges'''

def BCCUtilLoop(self, u, parent, low, disc, st):
stack = [(u, self.graph[u], 0, False, -1)]

while stack:
# u, self.graph[u], children, 标记是否在if disc[v] == -1分支中断了,v
u, g, children, flag, v = stack.pop(-1)
if flag:
# 继续把break后没走完的路走完
# Check if the subtree rooted with v has a connection to
# one of the ancestors of u
# Case 1 -- per Strongly Connected Components Article
low[u] = min(low[u], low[v])

# If u is an articulation point, pop
# all edges from stack till (u, v)
if parent[u] == -1 and children > 1 or parent[u] != -1 and low[v] >= disc[u]:
self.count += 1 # increment count
w = -1
tmp = []
while w != (u, v):
w = st.pop()
tmp.append(w)
self.res.append(tmp)
disc[u] = self.Time
low[u] = self.Time
self.Time += 1
while g:
v = g.pop(-1)
# If v is not visited yet, then make it a child of u
# in DFS tree and recur for it
if disc[v] == -1:
parent[v] = u
children += 1
st.append((u, v)) # store the edge in stack
# 把没遍历完的点入栈,下次继续遍历
stack.append((u, g, children, True, v))
# 把子节点入栈,继续向下搜索
stack.append((v, self.graph[v], 0, False, -1))
break

elif v != parent[u] and low[u] > disc[v]:
'''Update low value of 'u' only of 'v' is still in stack
(i.e. it's a back edge, not cross edge).
Case 2
-- per Strongly Connected Components Article'''

low[u] = min(low[u], disc[v])

st.append((u, v))

def BCCUtil(self, u, parent, low, disc, st):

# Count of children in current node
Expand Down Expand Up @@ -103,12 +153,12 @@ def BCC(self):
# in DFS tree rooted with vertex 'i'
for i in range(self.V):
if disc[i] == -1:
self.BCCUtil(i, parent, low, disc, st)
self.BCCUtilLoop(i, parent, low, disc, st)

# If stack is not empty, pop all edges from stack
if st:
self.count = self.count + 1
tmp=[]
tmp = []
while st:
w = st.pop()
tmp.append(w)
Expand Down Expand Up @@ -172,6 +222,6 @@ def get_data_feature(self):
Returns:
dict: 包含数据集的相关特征的字典
"""
return {"scaler": self.scaler, "adj_mx": self.adj_mx ,"ext_dim": self.ext_dim,
"num_nodes": self.num_nodes, "feature_dim": self.feature_dim,"regional_nodes": self.regional_nodes,
"Mor_mx": self.Mor_mx,"output_dim": self.output_dim, "num_batches": self.num_batches}
return {"scaler": self.scaler, "adj_mx": self.adj_mx, "ext_dim": self.ext_dim,
"num_nodes": self.num_nodes, "feature_dim": self.feature_dim, "regional_nodes": self.regional_nodes,
"Mor_mx": self.Mor_mx, "output_dim": self.output_dim, "num_batches": self.num_batches}
11 changes: 5 additions & 6 deletions libcity/model/traffic_speed_prediction/HIEST.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(self, config, data_feature):
# adaptive layer numbers
self.apt_layer = config.get('apt_layer', True)
if self.apt_layer:
self.layers = np.int(
self.layers = np.int64(
np.round(np.log((((self.input_window - 1) / (self.blocks * (self.kernel_size - 1))) + 1)) / np.log(2)))
print('# of layers change to %s' % self.layers)

Expand Down Expand Up @@ -318,6 +318,7 @@ def forward(self, batch):
# ---------------------------------------> + -------------> *skip*
# (dilation, init_dilation) = self.dilations[i]
# residual = dilation_func(x, dilation, init_dilation, i)

residual = x
# (batch_size, residual_channels, num_nodes, self.receptive_field)
# dilated convolution
Expand Down Expand Up @@ -359,9 +360,8 @@ def ortLoss(self, hr):
The implementation of L_{ort}
'''
# hr = (batch_size, end_channels, num_nodes, self.output_dim)
hr = hr.squeeze(3)
hr = hr.permute(2, 0, 1)
hr = torch.reshape(hr, (self.global_nodes, -1))
hr = hr.permute(2, 0, 1, 3)
hr = torch.reshape(hr, (self.global_nodes, -1, self.output_dim))
# print('hr.size')
# print(hr.size())
tmpLoss = 0
Expand All @@ -370,9 +370,8 @@ def ortLoss(self, hr):
for j in range(i + 1, self.global_nodes):
# print(hr[i].size())
tmpLoss += cos(hr[i], hr[j])

tmpLoss = tmpLoss / (self.global_nodes * (self.global_nodes - 1) / 2)

tmpLoss = torch.mean(tmpLoss)
return tmpLoss

def cal_adj(self, adjtype):
Expand Down