diff --git a/python/dgllife/model/model_zoo/jtvae.py b/python/dgllife/model/model_zoo/jtvae.py index 563e8340..af386dcf 100644 --- a/python/dgllife/model/model_zoo/jtvae.py +++ b/python/dgllife/model/model_zoo/jtvae.py @@ -270,7 +270,7 @@ def forward(self, tree_graphs, tree_vec): # Traverse the tree and predict on children for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)): - eid = eid.to(device) + eid = eid.to(device=device, dtype=tree_graphs.idtype) p = p.to(device=device, dtype=tree_graphs.idtype) # Message passing excluding the target