Commit: Add coding style
andrewdalpino committed Oct 29, 2024
1 parent 2f334bf commit 897d58f
Showing 8 changed files with 95 additions and 79 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
@@ -13,6 +13,7 @@ jobs:
         os: [ubuntu-latest]
         python:
           - "3.10"
+          - "3.13"

     runs-on: ${{ matrix.os }}

@@ -30,3 +31,6 @@ jobs:

       - name: Run tests
         run: python -m unittest
+
+      - name: Coding style
+        run: python -m black --check ./
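
The new step runs Black in check mode, which fails the job when a file would be reformatted but never rewrites anything. A local workflow sketch, assuming Black is installed (it is added to the project extras below):

    python -m black --check ./   # what CI runs; non-zero exit on style drift
    python -m black ./           # actually reformat the files in place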
18 changes: 9 additions & 9 deletions examples/covid.py
@@ -8,20 +8,20 @@

 tokenizer = Canonical(Kmer(6))

-with open('covid-19-virus.fasta', 'r') as file:
-    for record in SeqIO.parse(file, 'fasta'):
+with open("covid-19-virus.fasta", "r") as file:
+    for record in SeqIO.parse(file, "fasta"):
         for token in tokenizer.tokenize(str(record.seq)):
             hash_table.increment(token)

 for sequence, count in hash_table.top(25):
-    print(f'{sequence}: {count}')
+    print(f"{sequence}: {count}")

-print(f'Total sequences: {hash_table.num_sequences}')
-print(f'# of unique sequences: {hash_table.num_unique_sequences}')
-print(f'# of singletons: {hash_table.num_singletons}')
+print(f"Total sequences: {hash_table.num_sequences}")
+print(f"# of unique sequences: {hash_table.num_unique_sequences}")
+print(f"# of singletons: {hash_table.num_singletons}")

 plt.hist(list(hash_table.counts.values()), bins=20)
-plt.title('SARS-CoV-2 Genome')
-plt.xlabel('Counts')
-plt.ylabel('Frequency')
+plt.title("SARS-CoV-2 Genome")
+plt.xlabel("Counts")
+plt.ylabel("Frequency")
 plt.show()
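
A minimal, self-contained sketch of how the pieces in this example fit together (the real script builds hash_table and its imports in the collapsed lines above; the read below is invented for illustration):

    from dna_hash import DNAHash, Kmer, Canonical

    hash_table = DNAHash()
    tokenizer = Canonical(Kmer(6))

    # Count canonical 6-mers from a toy read instead of a FASTA record.
    for token in tokenizer.tokenize("ACTGAATTCCGGACTGAATT"):
        hash_table.increment(token)

    for sequence, count in hash_table.top(5):
        print(f"{sequence}: {count}")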
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -17,8 +17,8 @@ readme = "README.md"
 license = {text = "MIT"}

 [project.optional-dependencies]
-dev = ["mypy", "biopython", "matplotlib"]
-test = ["mypy"]
+dev = ["mypy", "black", "biopython", "matplotlib"]
+test = ["mypy", "black"]

 [project.urls]
 Homepage = "https://github.com/andrewdalpino/DNAHash"
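
With black in both extras, contributors can pull it in through pip's extras syntax from a local checkout (either group brings in the formatter):

    pip install -e ".[dev]"    # mypy, black, biopython, matplotlib
    pip install -e ".[test]"   # mypy and black only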
10 changes: 5 additions & 5 deletions src/dna_hash/__init__.py
@@ -11,11 +11,11 @@
     Fragment,
 )

-__version__ = '0.0.2'
+__version__ = "0.0.2"

 __all__ = [
-    'DNAHash',
-    'Kmer',
-    'Canonical',
-    'Fragment',
+    "DNAHash",
+    "Kmer",
+    "Canonical",
+    "Fragment",
 ]
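
All four public names are re-exported at the package root, so downstream imports stay flat:

    import dna_hash
    from dna_hash import DNAHash, Kmer, Canonical, Fragment

    print(dna_hash.__version__)  # 0.0.2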
44 changes: 24 additions & 20 deletions src/dna_hash/dna_hash.py
@@ -7,28 +7,33 @@

 from dna_hash.tokenizers import Fragment

+
 class DNAHash(object):
     """A specialized datastructure for counting genetic sequences for use in Bioinformatics."""

     UP_BIT = 1

     BASE_ENCODE_MAP = {
-        'A': 0,
-        'C': 1,
-        'T': 2,
-        'G': 3,
+        "A": 0,
+        "C": 1,
+        "T": 2,
+        "G": 3,
     }

     BITS_PER_BASE = max(BASE_ENCODE_MAP.values()).bit_length()

-    MAX_FRAGMENT_LENGTH = math.ceil(math.log(sys.maxsize, 2) / BITS_PER_BASE) - UP_BIT.bit_length()
+    MAX_FRAGMENT_LENGTH = (
+        math.ceil(math.log(sys.maxsize, 2) / BITS_PER_BASE) - UP_BIT.bit_length()
+    )

     BASE_DECODE_MAP = {encoding: base for base, encoding in BASE_ENCODE_MAP.items()}

-    def __init__(self,
-                 max_false_positive_rate: float = 0.01,
-                 num_hashes: int = 4,
-                 layer_size: int = 32000000) -> None:
+    def __init__(
+        self,
+        max_false_positive_rate: float = 0.01,
+        num_hashes: int = 4,
+        layer_size: int = 32000000,
+    ) -> None:

         self.filter = okbloomer.BloomFilter(
             max_false_positive_rate=max_false_positive_rate,
@@ -58,7 +63,7 @@ def num_non_singletons(self) -> int:
     def insert(self, sequence: str, count: int) -> None:
         """Insert a sequence count into the hash table."""
         if count < 1:
-            raise ValueError(f'Count cannot be less than 1, {count} given.')
+            raise ValueError(f"Count cannot be less than 1, {count} given.")

         exists = self.filter.exists_or_insert(sequence)

@@ -69,7 +74,7 @@ def insert(self, sequence: str, count: int) -> None:
                 self.num_singletons -= 1

             self.counts[hashes] = count
-
+
         elif not exists:
             self.num_singletons += 1

@@ -82,7 +87,7 @@ def increment(self, sequence: str) -> None:

         if hashes in self.counts:
             self.counts[hashes] += 1
-
+
         else:
             self.num_singletons -= 1

@@ -101,12 +106,12 @@ def argmax(self) -> str:

         return self._decode(hashes)

-    def get(self, sequence: str) ->int:
+    def get(self, sequence: str) -> int:
         """Return the count for a sequence."""
         exists = self.filter.exists(sequence)

         if not exists:
-            raise ValueError('Sequence not found in hash table.')
+            raise ValueError("Sequence not found in hash table.")

         hashes = self._encode(sequence)

@@ -116,7 +121,7 @@ def get(self, sequence: str) ->int:
             return 1

     def top(self, k: int = 10) -> Iterator[Tuple[str, int]]:
-        """ Return the k sequences with the highest counts."""
+        """Return the k sequences with the highest counts."""
         counts = sorted(self.counts.items(), key=lambda item: item[1], reverse=True)

         for hashes, count in counts[0:k]:
@@ -143,8 +148,8 @@ def _encode(self, sequence: str) -> Tuple[int, ...]:

     def _decode(self, hashes: Tuple[int, ...]) -> str:
         """Decode an up2bit representation into a variable-length sequence."""
-        sequence = ''
+        sequence = ""

         for hash in hashes:
             if hash == self.UP_BIT:
                 continue
@@ -161,7 +166,6 @@ def __setitem__(self, sequence: str, count: int) -> None:

     def __getitem__(self, sequence: str) -> int:
         return self.get(sequence)
-

     def __len__(self) -> int:
         return self.num_unique_sequences
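
As a quick check of the MAX_FRAGMENT_LENGTH arithmetic reformatted above, here is the same computation with the intermediate values spelled out (a sketch assuming a typical 64-bit build, where sys.maxsize is 2**63 - 1):

    import math
    import sys

    UP_BIT = 1
    BASE_ENCODE_MAP = {"A": 0, "C": 1, "T": 2, "G": 3}

    # The largest code is 3 (0b11), so each base takes two bits.
    BITS_PER_BASE = max(BASE_ENCODE_MAP.values()).bit_length()  # 2

    # log2(2**63 - 1) is 63 to double precision, ceil(63 / 2) == 32,
    # and one position is reserved for the leading "up" bit: 32 - 1 == 31.
    MAX_FRAGMENT_LENGTH = (
        math.ceil(math.log(sys.maxsize, 2) / BITS_PER_BASE) - UP_BIT.bit_length()
    )

    print(BITS_PER_BASE, MAX_FRAGMENT_LENGTH)  # 2 31

So on such a build a fragment of up to 31 bases fits in one machine-sized integer alongside the up bit.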
44 changes: 25 additions & 19 deletions src/dna_hash/tokenizers.py
@@ -3,7 +3,8 @@
 from abc import ABC, abstractmethod
 from typing import Iterator

-INVALID_BASE_REGEX = r'[^ACTG]'
+INVALID_BASE_REGEX = r"[^ACTG]"
+

 class Tokenizer(ABC):
     """Base tokenizer class"""
@@ -12,12 +13,13 @@ class Tokenizer(ABC):
     def tokenize(self, sequence: str) -> Iterator[str]:
         pass

+
 class Kmer(Tokenizer):
     """Generates tokens of length k from reads."""

     def __init__(self, k: int, skip_invalid: bool = False) -> None:
         if k < 1:
-            raise ValueError(f'K cannot be less than 1, {k} given.')
+            raise ValueError(f"K cannot be less than 1, {k} given.")

         self.k = k
         self.skip_invalid = skip_invalid
@@ -29,16 +31,17 @@ def tokenize(self, sequence: str) -> Iterator[str]:
         i = 0

         while i < len(sequence) - self.k:
-            token = sequence[i:i + self.k]
+            token = sequence[i : i + self.k]

             invalid_token = self.invalid_base.search(token)

             if invalid_token:
                 if not self.skip_invalid:
                     offset = i + invalid_token.start()

-                    raise ValueError('Invalid base detected at'
-                        + f' offset {offset} in sequence.')
-
+                    raise ValueError(
+                        "Invalid base detected at" + f" offset {offset} in sequence."
+                    )
+
                 else:
                     skip = 1 + invalid_token.start()
@@ -50,29 +53,30 @@ def tokenize(self, sequence: str) -> Iterator[str]:
                 continue

             i += 1
-
+
             yield token

+
 class Canonical(Tokenizer):
     """Tokenize sequences in their canonical form."""

     BASE_COMPLIMENT_MAP = {
-        'A': 'T',
-        'T': 'A',
-        'C': 'G',
-        'G': 'C',
+        "A": "T",
+        "T": "A",
+        "C": "G",
+        "G": "C",
     }

     @classmethod
     def reverse_complement(cls, sequence: str) -> str:
         """Return the reverse complement of a sequence."""
-        complement = ''
+        complement = ""

         for i in range(len(sequence) - 1, -1, -1):
             base = sequence[i]

             if base not in cls.BASE_COMPLIMENT_MAP:
-                raise ValueError('Invalid base {base} given.')
+                raise ValueError("Invalid base {base} given.")

             complement += cls.BASE_COMPLIMENT_MAP[base]

@@ -88,12 +92,13 @@ def tokenize(self, sequence: str) -> Iterator[str]:
         for token in tokens:
             yield min(token, self.reverse_complement(token))

+
 class Fragment(Tokenizer):
     """Generates a non-overlapping fragment of length n from a sequence."""

     def __init__(self, n: int, skip_invalid: bool = False) -> None:
         if n < 1:
-            raise ValueError(f'N must be greater than 1, {n} given.')
+            raise ValueError(f"N must be greater than 1, {n} given.")

         self.n = n
         self.skip_invalid = skip_invalid
@@ -110,16 +115,17 @@ def tokenize(self, sequence: str) -> Iterator[str]:
                 return

         for i in range(0, m, self.n):
-            token = sequence[i:i + self.n]
+            token = sequence[i : i + self.n]

             invalid_token = self.invalid_base.search(token)

             if invalid_token:
                 if not self.skip_invalid:
                     offset = i + invalid_token.start()

-                    raise ValueError('Invalid base detected at'
-                        + f' offset {offset} in sequence.')
-
+                    raise ValueError(
+                        "Invalid base detected at" + f" offset {offset} in sequence."
+                    )
+
                 else:
                     self.dropped += 1
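
A side-by-side sketch of the three tokenizers on an invented read; the behavior is inferred from the code above (Kmer slides one base at a time, Canonical maps each token to the lexical minimum of itself and its reverse complement, Fragment steps in non-overlapping windows):

    from dna_hash.tokenizers import Kmer, Canonical, Fragment

    read = "ACTGAATTCCGG"

    # Overlapping 4-mers.
    print(list(Kmer(4).tokenize(read)))

    # Canonical 4-mers: a strand and its reverse complement count together.
    print(list(Canonical(Kmer(4)).tokenize(read)))

    # Non-overlapping fragments of length 4.
    print(list(Fragment(4).tokenize(read)))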
34 changes: 17 additions & 17 deletions tests/test_dna_hash.py
@@ -4,12 +4,13 @@

 from dna_hash import DNAHash

+
 class TestDNAHash(TestCase):
-    BASES = ['A', 'C', 'T', 'G']
+    BASES = ["A", "C", "T", "G"]

     @classmethod
     def random_read(cls, k: int) -> str:
-        return ''.join(cls.BASES[random.randint(0, 3)] for i in range(0, k))
+        return "".join(cls.BASES[random.randint(0, 3)] for i in range(0, k))

     def test_increment(self):
         hash_table = DNAHash()
@@ -18,39 +19,39 @@ def test_increment(self):
         self.assertEqual(hash_table.num_sequences, 0)
         self.assertEqual(hash_table.num_unique_sequences, 0)

-        hash_table.increment('ACTG')
+        hash_table.increment("ACTG")

         self.assertEqual(hash_table.num_singletons, 1)
         self.assertEqual(hash_table.num_sequences, 1)
         self.assertEqual(hash_table.num_unique_sequences, 1)
-        self.assertEqual(hash_table['ACTG'], 1)
+        self.assertEqual(hash_table["ACTG"], 1)

-        hash_table.increment('ACTG')
+        hash_table.increment("ACTG")

         self.assertEqual(hash_table.num_singletons, 0)
         self.assertEqual(hash_table.num_sequences, 2)
         self.assertEqual(hash_table.num_unique_sequences, 1)
-        self.assertEqual(hash_table['ACTG'], 2)
+        self.assertEqual(hash_table["ACTG"], 2)

         self.assertEqual(hash_table.max(), 2)
-        self.assertEqual(hash_table.argmax(), 'ACTG')
+        self.assertEqual(hash_table.argmax(), "ACTG")

     def test_top_k(self):
         hash_table = DNAHash()

-        hash_table['CTGA'] = 1
-        hash_table['ACTG'] = 10
-        hash_table['GCGC'] = 4
-        hash_table['AAAA'] = 9
-        hash_table['AAAT'] = 2
+        hash_table["CTGA"] = 1
+        hash_table["ACTG"] = 10
+        hash_table["GCGC"] = 4
+        hash_table["AAAA"] = 9
+        hash_table["AAAT"] = 2

         top = list(hash_table.top(3))

         self.assertEqual(len(top), 3)

-        self.assertEqual(top[0], ('ACTG', 10))
-        self.assertEqual(top[1], ('AAAA', 9))
-        self.assertEqual(top[2], ('GCGC', 4))
+        self.assertEqual(top[0], ("ACTG", 10))
+        self.assertEqual(top[1], ("AAAA", 9))
+        self.assertEqual(top[2], ("GCGC", 4))

     def test_large_dataset(self):
         random.seed(1)
@@ -65,7 +66,7 @@ def test_large_dataset(self):
         self.assertEqual(hash_table.num_unique_sequences, 51243)

         self.assertEqual(hash_table.max(), 9)
-        self.assertEqual(hash_table.argmax(), 'AGACTAAA')
+        self.assertEqual(hash_table.argmax(), "AGACTAAA")

     def test_long_sequences(self):
         random.seed(1)
@@ -84,4 +85,3 @@ def test_long_sequences(self):

         self.assertEqual(argmax, sequence)
         self.assertEqual(len(argmax), 500)
-
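Both CI gates can be reproduced locally with the same two commands the workflow uses:

    python -m unittest
    python -m black --check ./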