Skip to content

Commit

Permalink
[query] Add query_matrix_table, an analogue to query_table
Browse files Browse the repository at this point in the history
CHANGELOG: Add query_matrix_table, an analogue to query_table
  • Loading branch information
chrisvittal committed Jan 30, 2025
1 parent 5d9c642 commit ecaa96b
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 0 deletions.
10 changes: 10 additions & 0 deletions hail/hail/src/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,7 @@ package defs {
classOf[PartitionNativeReader],
classOf[PartitionNativeReaderIndexed],
classOf[PartitionNativeIntervalReader],
classOf[PartitionZippedNativeIntervalReader],
classOf[PartitionZippedNativeReader],
classOf[PartitionZippedIndexedNativeReader],
classOf[BgenPartitionReader],
Expand Down Expand Up @@ -1216,6 +1217,15 @@ package defs {
spec,
(jv \ "uidFieldName").extract[String],
)
case "PartitionZippedNativeIntervalReader" =>
val path = (jv \ "path").extract[String]
val spec = RelationalSpec.read(ctx.fs, path).asInstanceOf[AbstractMatrixTableSpec]
PartitionZippedNativeIntervalReader(
ctx.stateManager,
path,
spec,
(jv \ "uidFieldName").extract[String],
)
case "GVCFPartitionReader" =>
val header = VCFHeaderInfo.fromJSON((jv \ "header"))
val callFields = (jv \ "callFields").extract[Set[String]]
Expand Down
34 changes: 34 additions & 0 deletions hail/hail/src/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,40 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe
}
}

/** Partition reader that streams a matrix table partition by zipping its on-disk
  * `rows` and `entries` components, looked up by key interval.
  *
  * The matrix table spec must be indexed (`require(mtSpec.indexed)`), since both
  * underlying readers perform interval lookups against the index.
  *
  * @param sm state manager used to build the underlying interval readers
  * @param mtPath root path of the matrix table on disk
  * @param mtSpec spec of the matrix table; must be indexed
  * @param uidFieldName name of the UID field attached to the entries rows
  */
case class PartitionZippedNativeIntervalReader(
  sm: HailStateManager,
  mtPath: String,
  mtSpec: AbstractMatrixTableSpec,
  uidFieldName: String,
) extends PartitionReader {
  require(mtSpec.indexed)

  // XXX: rows and entries paths are hardcoded, see MatrixTableSpec
  // The rows reader's UID field is discarded ("__dummy"); only the entries
  // reader carries the caller-requested uidFieldName.
  private val rowsReader = PartitionNativeIntervalReader(sm, mtPath + "/rows", mtSpec.rowsSpec, "__dummy")
  private val entriesReader = PartitionNativeIntervalReader(sm, mtPath + "/entries", mtSpec.entriesSpec, uidFieldName)
  private val zippedReader = PartitionZippedNativeReader(rowsReader, entriesReader)

  // The external context is a single interval; type/requiredness/row type are
  // all delegated to the underlying readers.
  def contextType = rowsReader.contextType
  def fullRowType = zippedReader.fullRowType
  def rowRequiredness(requestedType: TStruct): RStruct = zippedReader.rowRequiredness(requestedType)
  def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)

  def emitStream(
    ctx: ExecuteContext,
    cb: EmitCodeBuilder,
    mb: EmitMethodBuilder[_],
    _context: EmitCode,
    requestedType: TStruct,
  ): IEmitCode = {
    // The zipped reader expects a two-field context (one per side); duplicate
    // the single incoming interval context into a stack struct so both the
    // rows and entries readers scan the same key interval.
    val zipContextType: TBaseStruct = tcoerce(zippedReader.contextType)
    val contexts = FastSeq(_context, _context)
    val st = SStackStruct(zipContextType, contexts.map(_.emitType))
    val context = EmitCode.present(mb, st.fromEmitCodes(cb, contexts))

    zippedReader.emitStream(ctx, cb, mb, context, requestedType)
  }
}

case class PartitionZippedIndexedNativeReader(
specLeft: AbstractTypedCodecSpec,
specRight: AbstractTypedCodecSpec,
Expand Down
2 changes: 2 additions & 0 deletions hail/python/hail/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@
qchisqtail,
qnorm,
qpois,
query_matrix_table,
query_table,
rand_beta,
rand_bool,
Expand Down Expand Up @@ -554,6 +555,7 @@
'_console_log',
'dnorm',
'dchisq',
'query_matrix_table',
'query_table',
'keyed_union',
'keyed_intersection',
Expand Down
66 changes: 66 additions & 0 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7069,6 +7069,72 @@ def coerce_endpoint(point):
aggregations=partition_interval._aggregations,
)

@typecheck(path=builtins.str, point_or_interval=expr_any, entries_name=builtins.str)
def query_matrix_table(path, point_or_interval, entries_name='entries_array'):
    """Query row and entry data from a matrix table on disk by key, without a
    full scan.

    Analogous to :func:`.query_table`. Returns an array of row structs whose
    key equals ``point_or_interval`` (or falls within it, if it is an
    interval), with the per-row entries attached as the array field
    ``entries_name``.

    Parameters
    ----------
    path : :class:`str`
        Path to an indexed Hail matrix table.
    point_or_interval
        A key value (or prefix of one) to look up, or an interval of keys.
    entries_name : :class:`str`
        Name of the output field holding the entries array; must not collide
        with any row field of the matrix table.

    Returns
    -------
    :class:`.ArrayExpression`
    """
    # This module shadows several builtins (e.g. a module-level `len`), so
    # bind the real builtin locally, as elsewhere in this file.
    len = builtins.len

    matrix_table = hl.read_matrix_table(path)
    if entries_name in matrix_table.row:
        raise ValueError(
            f'field "{entries_name}" is present in matrix table row fields, pick a different `entries_name`'
        )

    # The entries component of a matrix table is itself a table with exactly
    # one row field: the per-row array of entries (see MatrixTableSpec).
    entries_table = hl.read_table(path + '/entries')
    [entry_id] = list(entries_table.row)

    # Rows read back from disk carry the row fields followed by the entries field.
    full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
    key_typ = matrix_table.row_key.dtype
    key_names = list(key_typ)

    def coerce_endpoint(point):
        # Promote a bare value of the first key field's type to a one-field
        # key struct.
        if point.dtype == key_typ[0]:
            point = hl.struct(**{key_names[0]: point})
        ts = point.dtype
        if not isinstance(ts, tstruct):
            raise ValueError(
                f"query_matrix_table: key mismatch: cannot query a matrix table with key "
                f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}"
            )

        num_query_fields = len(ts)
        # Validate that the queried fields are a type-compatible prefix of the key.
        for i in builtins.range(num_query_fields):
            if i >= len(key_typ):
                raise ValueError(
                    f"query_matrix_table: queried with {num_query_fields} key field(s), but matrix table only has {len(key_typ)} key field(s)"
                )
            if key_typ[i] != ts[i]:
                raise ValueError(
                    f"query_matrix_table: key mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, table key type is {key_typ[i]}"
                )
        if num_query_fields == 0:
            raise ValueError("query_matrix_table: cannot query with empty key")

        # Pad the point out to the full key struct (trailing fields missing),
        # and carry the number of fields actually specified so the reader can
        # interpret the prefix correctly.
        return hl.tuple([
            hl.struct(**{
                key_names[i]: (point[i] if i < num_query_fields else hl.missing(key_typ[i]))
                for i in builtins.range(len(key_typ))
            }),
            hl.int32(num_query_fields),
        ])

    if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
        partition_interval = hl.interval(
            start=coerce_endpoint(point_or_interval.start),
            end=coerce_endpoint(point_or_interval.end),
            includes_start=point_or_interval.includes_start,
            includes_end=point_or_interval.includes_end,
        )
    else:
        # A single point is queried as the degenerate closed interval [point, point].
        point = coerce_endpoint(point_or_interval)
        partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)

    read_part_ir = ir.ReadPartition(
        partition_interval._ir,
        reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type),
    )
    stream_expr = construct_expr(
        read_part_ir,
        type=hl.tstream(full_row_type),
        indices=partition_interval._indices,
        aggregations=partition_interval._aggregations,
    )
    # Rename the on-disk entries field to the user-requested name.
    return stream_expr.map(lambda item: item.rename({entry_id: entries_name})).to_array()

@typecheck(msg=expr_str, result=expr_any)
def _console_log(msg, result):
Expand Down
2 changes: 2 additions & 0 deletions hail/python/hail/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
NDArraySVD,
NDArrayWrite,
PartitionNativeIntervalReader,
PartitionZippedNativeIntervalReader,
ProjectedTopLevelReference,
ReadPartition,
Recur,
Expand Down Expand Up @@ -527,6 +528,7 @@
'TableNativeFanoutWriter',
'ReadPartition',
'PartitionNativeIntervalReader',
'PartitionZippedNativeIntervalReader',
'GVCFPartitionReader',
'TableGen',
'Partitioner',
Expand Down
31 changes: 31 additions & 0 deletions hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3510,6 +3510,37 @@ def row_type(self):
return tstruct(**self.table_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class PartitionZippedNativeIntervalReader(PartitionReader):
    """Python IR wrapper for the backend's ``PartitionZippedNativeIntervalReader``,
    which reads zipped rows + entries of a matrix table by key interval.

    Only ``path`` and ``uidFieldName`` are serialized in :meth:`render`; the
    backend recovers the spec from ``path``.
    """

    def __init__(self, path, full_row_type, uid_field=None):
        # path: root path of the matrix table on disk.
        self.path = path
        # full_row_type: struct of row fields plus the entries-array field.
        self.full_row_type = full_row_type
        # uid_field: optional name of the UID field appended to each row.
        self.uid_field = uid_field

    def with_uid_field(self, uid_field):
        # Fixed: the original dropped full_row_type, so this call raised
        # TypeError (full_row_type is a required positional argument).
        return PartitionZippedNativeIntervalReader(
            path=self.path, full_row_type=self.full_row_type, uid_field=uid_field
        )

    def render(self):
        # The backend deserializer expects '__dummy' when no UID is requested.
        return escape_str(
            json.dumps({
                "name": "PartitionZippedNativeIntervalReader",
                "path": self.path,
                "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy',
            })
        )

    def _eq(self, other):
        # NOTE(review): full_row_type is intentionally not compared, matching
        # the fields serialized by render(); confirm against sibling readers.
        return (
            isinstance(other, PartitionZippedNativeIntervalReader)
            and self.path == other.path
            and self.uid_field == other.uid_field
        )

    def row_type(self):
        # With a UID field requested, each row gains a (int64, int64) tuple UID.
        if self.uid_field is None:
            return self.full_row_type
        return tstruct(**self.full_row_type, **{self.uid_field: ttuple(tint64, tint64)})


class ReadPartition(IR):
@typecheck_method(context=IR, reader=PartitionReader)
def __init__(self, context, reader):
Expand Down

0 comments on commit ecaa96b

Please sign in to comment.