Fixing pytorch lightning dependency
mortonjt committed Mar 8, 2022
1 parent ab48e4f commit 05aec99
Showing 2 changed files with 7 additions and 5 deletions.
deepblast/tests/test_trainer.py: 9 changes (5 additions & 4 deletions)
@@ -1,6 +1,7 @@
 import os
 import shutil
 import unittest
+import torch
 from deepblast.trainer import LightningAligner
 from deepblast.utils import get_data_path
 from deepblast.sim import hmm_alignments
@@ -39,7 +40,7 @@ def tearDown(self):
         if os.path.exists('valid.txt'):
             os.remove('valid.txt')
 
-    @unittest.skip("Can only run with GPU")
+    @unittest.skipUnless(torch.cuda.is_available(), 'No GPU was detected')
     def test_trainer_sim(self):
         output_dir = 'output'
         args = [
@@ -51,7 +52,7 @@ def test_trainer_sim(self):
             '--batch-size', '3',
             '--num-workers', '1',
             '--learning-rate', '1e-4',
-            '--clip-ends', 'False',
+            # '--clip-ends', 'False',
             '--visualization-fraction', '0.5',
             '--gpus', '1'
         ]
@@ -71,7 +72,7 @@ def test_trainer_sim(self):
         )
         trainer.fit(model)
 
-    @unittest.skip("Can only run with GPU")
+    @unittest.skipUnless(torch.cuda.is_available(), 'No GPU was detected')
     def test_trainer_struct(self):
         output_dir = 'output'
         args = [
@@ -83,7 +84,7 @@ def test_trainer_struct(self):
             '--batch-size', '1',
             '--num-workers', '1',
             '--learning-rate', '1e-4',
-            '--clip-ends', 'False',
+            # '--clip-ends', 'False',
             '--visualization-fraction', '1',
             '--gpus', '1'
         ]
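The test-suite change swaps the unconditional @unittest.skip decorators for @unittest.skipUnless(torch.cuda.is_available(), ...), so the GPU tests run whenever CUDA is available instead of always being skipped (hence the new import torch). A minimal, self-contained sketch of the same pattern; the class and test names here are illustrative, not from the repository:

import unittest

import torch


class GPUOnlyTest(unittest.TestCase):

    @unittest.skipUnless(torch.cuda.is_available(), 'No GPU was detected')
    def test_needs_gpu(self):
        # Runs only when CUDA is visible to PyTorch; otherwise the runner
        # reports the test as skipped with the reason string above.
        x = torch.ones(2, device='cuda')
        self.assertEqual(x.sum().item(), 2.0)

Run with `python -m unittest` on a CPU-only machine and this test is reported as skipped rather than failed.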
deepblast/trainer.py: 3 changes (2 additions & 1 deletion)
@@ -25,8 +25,9 @@ class LightningAligner(pl.LightningModule):
 
     def __init__(self, args):
         super(LightningAligner, self).__init__()
+        self._hparams = args
         self.tokenizer = UniprotTokenizer(pad_ends=False)
-        self.hparams = args
+        #self.hparams = args
         self.initialize_aligner()
         if self.hparams.loss == 'sse':
             self.loss_func = SoftAlignmentLoss()
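The trainer.py change is the actual dependency fix: newer PyTorch Lightning releases turned LightningModule.hparams into a read-only property backed by a private _hparams attribute, so the old assignment self.hparams = args raises AttributeError. Writing to self._hparams instead keeps later reads such as self.hparams.loss working. A hedged sketch of the idea, assuming a Lightning version with the property-based hparams (the module name is illustrative; Lightning's documented alternative is save_hyperparameters()):

import pytorch_lightning as pl


class ExampleModule(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        # self.hparams = args  # AttributeError on newer Lightning:
        #                      # the `hparams` property has no setter
        self._hparams = args   # workaround used in this commit
        # Attribute reads keep working because the `hparams` property
        # returns `self._hparams`, e.g.:
        # loss_name = self.hparams.loss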
