Skip to content

Commit

Permalink
Implement variant filtering with include or exclude
Browse files Browse the repository at this point in the history
  • Loading branch information
Will-Tyler authored and tomwhite committed Sep 30, 2024
1 parent e0b9c30 commit e039823
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 7 deletions.
2 changes: 2 additions & 0 deletions tests/test_bcftools_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def test_vcf_output_with_output_option(tmp_path, args, vcf_file):
(r"query -f '%QUAL\n'", "sample.vcf.gz"),
(r"query -f '%FILTER\n'", "sample.vcf.gz"),
(r"query --format '%FILTER\n'", "1kg_2020_chrM.vcf.gz"),
(r"query -f '%POS\n' -i 'POS=112'", "sample.vcf.gz"),
(r"query -f '%POS\n' -e 'POS=112'", "sample.vcf.gz")
],
)
def test_output(tmp_path, args, vcf_name):
Expand Down
64 changes: 57 additions & 7 deletions vcztools/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pyparsing as pp
import zarr

from vcztools.filter import FilterExpressionEvaluator, FilterExpressionParser
from vcztools.utils import open_file_like, vcf_name_to_vcz_name


Expand Down Expand Up @@ -54,15 +55,23 @@ def __call__(self, *args, **kwargs):


class QueryFormatGenerator:
def __init__(
    self,
    query_format: Union[str, pp.ParseResults],
    *,
    include: Optional[str] = None,
    exclude: Optional[str] = None,
):
    """Prepare the generator that renders one output string per variant.

    Args:
        query_format: Either the raw query format string, which is parsed
            here with ``QueryFormatParser``, or parse results that were
            already produced.
        include: Optional filter expression; forwarded unchanged to
            ``_compose_generator``.
        exclude: Optional filter expression; forwarded unchanged to
            ``_compose_generator``.
    """
    if not isinstance(query_format, str):
        # Anything that is not a string must already be parsed.
        assert isinstance(query_format, pp.ParseResults)
        parsed_format = query_format
    else:
        parsed_format = QueryFormatParser()(query_format)

    self._generator = self._compose_generator(
        parsed_format, include=include, exclude=exclude
    )

def __call__(self, *args, **kwargs):
assert len(args) == 1
Expand Down Expand Up @@ -155,17 +164,58 @@ def generate(root):

return generate

def _compose_generator(self, parse_results: pp.ParseResults) -> Callable:
def _compose_filter_generator(
    self, *, include: Optional[str] = None, exclude: Optional[str] = None
) -> Callable:
    """Build a callable that yields a keep/skip indicator stream for a root.

    Args:
        include: Filter expression to apply as-is. Mutually exclusive with
            ``exclude``.
        exclude: Filter expression to apply inverted (the evaluator is
            constructed with ``invert=True``). Mutually exclusive with
            ``include``.

    Returns:
        A function ``generate(root)``. With no filter it yields ``True``
        once per entry along the first axis of ``root["variant_position"]``.
        With a filter it yields whatever the filter evaluator produces for
        each chunk along that axis, in chunk order (presumably one
        indicator per variant -- confirm against ``FilterExpressionEvaluator``).
    """
    assert not (include and exclude)

    if not include and not exclude:

        def generate(root):
            variant_count = root["variant_position"].shape[0]
            yield from itertools.repeat(True, variant_count)

        return generate

    parser = FilterExpressionParser()
    parse_results = parser(include or exclude)[0]
    filter_evaluator = FilterExpressionEvaluator(
        parse_results, invert=bool(exclude)
    )

    def generate(root):
        variant_chunk_count = root["variant_position"].cdata_shape[0]

        # Pass ``root`` on every call instead of rebinding the shared
        # evaluator to a partial: rebinding the closed-over name would nest
        # a new partial on each invocation and break any second call.
        for variant_chunk_index in range(variant_chunk_count):
            yield from filter_evaluator(root, variant_chunk_index)

    return generate

def _compose_generator(
    self,
    parse_results: pp.ParseResults,
    *,
    include: Optional[str] = None,
    exclude: Optional[str] = None,
) -> Callable:
    """Build the callable that renders the filtered output rows for a root.

    Args:
        parse_results: Parsed query format; one element generator is
            composed per element via ``_compose_element_generator``.
        include: Optional filter expression, forwarded to
            ``_compose_filter_generator``.
        exclude: Optional filter expression, forwarded to
            ``_compose_filter_generator``.

    Returns:
        A function ``generate(root)`` that walks all element generators and
        the filter stream in lockstep and yields, for every position whose
        filter indicator is truthy, the concatenation of the stringified
        element values.
    """
    # Materialise as a list: a generator expression here would be exhausted
    # by the first ``generate`` call, leaving later calls with no elements
    # and therefore no output.
    generators = [
        self._compose_element_generator(element) for element in parse_results
    ]
    filter_generator = self._compose_filter_generator(
        include=include, exclude=exclude
    )

    def generate(root):
        iterables = (generator(root) for generator in generators)
        filter_iterable = filter_generator(root)

        for results, filter_indicator in zip(zip(*iterables), filter_iterable):
            if filter_indicator:
                yield "".join(map(str, results))

    return generate

Expand All @@ -184,7 +234,7 @@ def write_query(
)

root = zarr.open(vcz, mode="r")
generator = QueryFormatGenerator(query_format)
generator = QueryFormatGenerator(query_format, include=include, exclude=exclude)

with open_file_like(output) as output:
for result in generator(root):
Expand Down

0 comments on commit e039823

Please sign in to comment.