Skip to content

Commit

Permalink
ENH: Prefer using visit_Constant (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhlegarreta authored Jan 14, 2025
1 parent 8d108e2 commit 34cb131
Showing 1 changed file with 40 additions and 33 deletions.
73 changes: 40 additions & 33 deletions tract_querier/query_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import numbers
from os import path
from copy import deepcopy
from operator import lt, gt
Expand Down Expand Up @@ -240,12 +241,6 @@ def visit_UnaryOp(self, node):
raise TractQuerierSyntaxError(
"Syntax error in query line %d" % node.lineno)

def visit_Str(self, node):
query_info = FiberQueryInfo()
for name in fnmatch.filter(self.evaluated_queries_info.keys(), node.s):
query_info.update(self.evaluated_queries_info[name])
return query_info

def visit_Call(self, node):
# Single string argument function
if (
Expand Down Expand Up @@ -558,31 +553,40 @@ def visit_Attribute(self, node):
(node.lineno, query_name)
)

def visit_Num(self, node):
if (
node.n in
self.tractography_spatial_indexing.crossing_labels_tracts
):
tracts = (
self.tractography_spatial_indexing.
crossing_labels_tracts[node.n]
def visit_Constant(self, node):
if isinstance(node.value, numbers.Number):
if (
node.n in
self.tractography_spatial_indexing.crossing_labels_tracts
):
tracts = (
self.tractography_spatial_indexing.
crossing_labels_tracts[node.n]
)
else:
tracts = set()

endpoints = (set(), set())
for i in (0, 1):
elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
if node.n in elt:
endpoints[i].update(elt[node.n])

labelset = set((node.n,))
query_info = FiberQueryInfo(
tracts, labelset,
endpoints
)
else:
tracts = set()

endpoints = (set(), set())
for i in (0, 1):
elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
if node.n in elt:
endpoints[i].update(elt[node.n])

labelset = set((node.n,))
tract_info = FiberQueryInfo(
tracts, labelset,
endpoints
)
elif isinstance(node.value, str):
query_info = FiberQueryInfo()
for name in fnmatch.filter(self.evaluated_queries_info.keys(),
node.s):
query_info.update(self.evaluated_queries_info[name])
else:
raise NotImplementedError(f"{node.value} not supported.")

return tract_info
return query_info

def visit_Expr(self, node):
if isinstance(node.value, ast.Name):
Expand Down Expand Up @@ -735,11 +739,14 @@ def visit_Name(self, node):
node
)

def visit_Str(self, node):
return ast.copy_location(
ast.Str(s=node.s.lower()),
node
)
def visit_Constant(self, node):
if isinstance(node.s, str):
return ast.copy_location(
ast.Constant(node.s.lower()),
node
)
else:
return self.generic_visit(node)

def visit_Import(self, node):
try:
Expand Down

0 comments on commit 34cb131

Please sign in to comment.