Skip to content

Commit

Permalink
Fix classifiers in setup.cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
4ment committed Oct 28, 2021
1 parent 016a555 commit 09f087c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
12 changes: 8 additions & 4 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torchtree.evolution.alignment import Alignment, Sequence
from torchtree.evolution.coalescent import ConstantCoalescent
from torchtree.evolution.datatype import NucleotideDataType
from torchtree.evolution.io import read_tree, read_tree_and_alignment
from torchtree.evolution.site_pattern import compress_alignment
from torchtree.evolution.substitution_model import JC69
from torchtree.evolution.taxa import Taxa, Taxon
Expand All @@ -21,7 +22,6 @@
ReparameterizedTimeTreeModel,
heights_from_branch_lengths,
)
from torchtree.io import read_tree, read_tree_and_alignment


def benchmark(f):
Expand Down Expand Up @@ -473,10 +473,13 @@ def fn_grad(ratios_root_height):
def ratio_transform(args):
replicates = args.replicates
tree = read_tree(args.tree, True, True)
taxa_count = len(tree.taxon_namespace)
taxa = []
for node in tree.leaf_node_iter():
taxa.append(Taxon(node.label, {'date': node.date}))
ratios_root_height = Parameter("internal_heights", torch.tensor([0.5] * 67 + [10]))
ratios_root_height = Parameter(
"internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [10])
)
tree_model = ReparameterizedTimeTreeModel(
"tree", tree, Taxa('taxa', taxa), ratios_root_height
)
Expand Down Expand Up @@ -632,11 +635,12 @@ def fn3_grad_jit(ratios_root_height):

def constant_coalescent(args):
tree = read_tree(args.tree, True, True)
taxa_count = len(tree.taxon_namespace)
taxa = []
for node in tree.leaf_node_iter():
taxa.append(Taxon(node.label, {'date': node.date}))
ratios_root_height = Parameter(
"internal_heights", torch.tensor([0.5] * 67 + [20.0])
"internal_heights", torch.tensor([0.5] * (taxa_count - 2) + [20.0])
)
tree_model = ReparameterizedTimeTreeModel(
"tree", tree, Taxa('taxa', taxa), ratios_root_height
Expand Down Expand Up @@ -743,7 +747,7 @@ def fn3_grad(tree_model, ratios_root_height, pop_size):
return log_p

x, counts = torch.unique(tree_model.sampling_times, return_counts=True)
counts = torch.cat((counts, torch.tensor([-1] * 68)))
counts = torch.cat((counts, torch.tensor([-1] * (taxa_count - 1))))

with torch.no_grad():
total_time, log_p = fn3(args.replicates, tree_model, pop_size)
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ long_description = file: README.md
license = GPL3
license_file = LICENSE
classifiers =
Intended Audience :: Science/Research",
Intended Audience :: Science/Research
License :: OSI Approved :: GNU General Public License v3 (GPLv3)
Operating System :: OS Independent
Programming Language :: Python :: 3.5
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Topic :: Scientific/Engineering :: Bio-Informatics

[options]
Expand Down

0 comments on commit 09f087c

Please sign in to comment.