From ecaa96bc4c869fe3b2b4474e7f04a12958e78c0b Mon Sep 17 00:00:00 2001
From: Christopher Vittal
Date: Thu, 30 Jan 2025 15:43:19 -0500
Subject: [PATCH] [query] Add query_matrix_table an analogue to query_table

CHANGELOG: Add query_matrix_table an analogue to query_table
---
 hail/hail/src/is/hail/expr/ir/IR.scala      | 10 ++++
 hail/hail/src/is/hail/expr/ir/TableIR.scala | 34 +++++++++++
 hail/python/hail/expr/__init__.py           |  2 +
 hail/python/hail/expr/functions.py          | 66 +++++++++++++++++++++
 hail/python/hail/ir/__init__.py             |  2 +
 hail/python/hail/ir/ir.py                   | 31 ++++++++++
 6 files changed, 145 insertions(+)

diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala
index 32dcf26f3f6..470c338f521 100644
--- a/hail/hail/src/is/hail/expr/ir/IR.scala
+++ b/hail/hail/src/is/hail/expr/ir/IR.scala
@@ -1186,6 +1186,7 @@ package defs {
       classOf[PartitionNativeReader],
       classOf[PartitionNativeReaderIndexed],
       classOf[PartitionNativeIntervalReader],
+      classOf[PartitionZippedNativeIntervalReader],
       classOf[PartitionZippedNativeReader],
       classOf[PartitionZippedIndexedNativeReader],
       classOf[BgenPartitionReader],
@@ -1216,6 +1217,15 @@ package defs {
           spec,
           (jv \ "uidFieldName").extract[String],
         )
+      case "PartitionZippedNativeIntervalReader" =>
+        val path = (jv \ "path").extract[String]
+        val spec = RelationalSpec.read(ctx.fs, path).asInstanceOf[AbstractMatrixTableSpec]
+        PartitionZippedNativeIntervalReader(
+          ctx.stateManager,
+          path,
+          spec,
+          (jv \ "uidFieldName").extract[String],
+        )
       case "GVCFPartitionReader" =>
         val header = VCFHeaderInfo.fromJSON((jv \ "header"))
         val callFields = (jv \ "callFields").extract[Set[String]]
diff --git a/hail/hail/src/is/hail/expr/ir/TableIR.scala b/hail/hail/src/is/hail/expr/ir/TableIR.scala
index 27404c71c01..4672b232abd 100644
--- a/hail/hail/src/is/hail/expr/ir/TableIR.scala
+++ b/hail/hail/src/is/hail/expr/ir/TableIR.scala
@@ -1509,6 +1509,40 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionReader)
   }
 }
 
+case class PartitionZippedNativeIntervalReader(
+  sm: HailStateManager,
+  mtPath: String,
+  mtSpec: AbstractMatrixTableSpec,
+  uidFieldName: String,
+) extends PartitionReader {
+  require(mtSpec.indexed)
+
+  // XXX: rows and entries paths are hardcoded, see MatrixTableSpec
+  private val rowsReader = PartitionNativeIntervalReader(sm, mtPath + "/rows", mtSpec.rowsSpec, "__dummy")
+  private val entriesReader = PartitionNativeIntervalReader(sm, mtPath + "/entries", mtSpec.entriesSpec, uidFieldName)
+  private val zippedReader = PartitionZippedNativeReader(rowsReader, entriesReader)
+
+  def contextType = rowsReader.contextType
+  def fullRowType = zippedReader.fullRowType
+  def rowRequiredness(requestedType: TStruct): RStruct = zippedReader.rowRequiredness(requestedType)
+  def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats)
+
+  def emitStream(
+    ctx: ExecuteContext,
+    cb: EmitCodeBuilder,
+    mb: EmitMethodBuilder[_],
+    _context: EmitCode,
+    requestedType: TStruct,
+  ): IEmitCode = {
+    val zipContextType: TBaseStruct = tcoerce(zippedReader.contextType)
+    val contexts = FastSeq(_context, _context)
+    val st = SStackStruct(zipContextType, contexts.map(_.emitType))
+    val context = EmitCode.present(mb, st.fromEmitCodes(cb, contexts))
+
+    zippedReader.emitStream(ctx, cb, mb, context, requestedType)
+  }
+}
+
 case class PartitionZippedIndexedNativeReader(
   specLeft: AbstractTypedCodecSpec,
   specRight: AbstractTypedCodecSpec,
diff --git a/hail/python/hail/expr/__init__.py b/hail/python/hail/expr/__init__.py
index bff07932fd9..f3c8b09e658 100644
--- a/hail/python/hail/expr/__init__.py
+++ b/hail/python/hail/expr/__init__.py
@@ -210,6 +210,7 @@
     qchisqtail,
     qnorm,
     qpois,
+    query_matrix_table,
     query_table,
     rand_beta,
     rand_bool,
@@ -554,6 +555,7 @@
     '_console_log',
     'dnorm',
     'dchisq',
+    'query_matrix_table',
     'query_table',
     'keyed_union',
     'keyed_intersection',
diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py
index 0558be5d0f5..dfb271907e2 100644
--- a/hail/python/hail/expr/functions.py
+++ b/hail/python/hail/expr/functions.py
@@ -7069,6 +7069,72 @@ def coerce_endpoint(point):
         aggregations=partition_interval._aggregations,
     )
 
+
+@typecheck(path=builtins.str, point_or_interval=expr_any, entries_name=builtins.str)
+def query_matrix_table(path, point_or_interval, entries_name='entries_array'):
+    matrix_table = hl.read_matrix_table(path)
+    if entries_name in matrix_table.row:
+        raise ValueError(f'field "{entries_name}" is present in matrix table row fields, pick a different `entries_name`')
+
+    entries_table = hl.read_table(path + '/entries')
+    [entry_id] = list(entries_table.row)
+
+    full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
+    key_typ = matrix_table.row_key.dtype
+    key_names = list(key_typ)
+    len = builtins.len
+
+    def coerce_endpoint(point):
+        if point.dtype == key_typ[0]:
+            point = hl.struct(**{key_names[0]: point})
+        ts = point.dtype
+        if isinstance(ts, tstruct):
+            i = 0
+            while i < len(ts):
+                if i >= len(key_typ):
+                    raise ValueError(
+                        f"query_matrix_table: queried with {len(ts)} key field(s), but matrix table only has {len(key_typ)} key field(s)"
+                    )
+                if key_typ[i] != ts[i]:
+                    raise ValueError(
+                        f"query_matrix_table: key mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, matrix table key type is {key_typ[i]}"
+                    )
+                i += 1
+
+            if i == 0:
+                raise ValueError("query_matrix_table: cannot query with empty key")
+
+            point_size = builtins.len(point.dtype)
+            return hl.tuple([
+                hl.struct(**{
+                    key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i]))
+                    for i in builtins.range(builtins.len(key_typ))
+                }),
+                hl.int32(point_size),
+            ])
+        else:
+            raise ValueError(
+                f"query_matrix_table: key mismatch: cannot query a matrix table with key "
+                f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}"
+            )
+
+    if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
+        partition_interval = hl.interval(
+            start=coerce_endpoint(point_or_interval.start),
+            end=coerce_endpoint(point_or_interval.end),
+            includes_start=point_or_interval.includes_start,
+            includes_end=point_or_interval.includes_end,
+        )
+    else:
+        point = coerce_endpoint(point_or_interval)
+        partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
+    read_part_ir = ir.ReadPartition(partition_interval._ir, reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type))
+    stream_expr = construct_expr(
+        read_part_ir,
+        type=hl.tstream(full_row_type),
+        indices=partition_interval._indices,
+        aggregations=partition_interval._aggregations,
+    )
+    return stream_expr.map(lambda item: item.rename({entry_id: entries_name})).to_array()
+
 
 @typecheck(msg=expr_str, result=expr_any)
 def _console_log(msg, result):
diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py
index 2325598a9f0..c4dcdff3ae5 100644
--- a/hail/python/hail/ir/__init__.py
+++ b/hail/python/hail/ir/__init__.py
@@ -113,6 +113,7 @@
     NDArraySVD,
     NDArrayWrite,
     PartitionNativeIntervalReader,
+    PartitionZippedNativeIntervalReader,
     ProjectedTopLevelReference,
     ReadPartition,
     Recur,
@@ -527,6 +528,7 @@
     'TableNativeFanoutWriter',
     'ReadPartition',
     'PartitionNativeIntervalReader',
+    'PartitionZippedNativeIntervalReader',
     'GVCFPartitionReader',
     'TableGen',
     'Partitioner',
diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py
index 2bef587fc1d..851a86cd5e3 100644
--- a/hail/python/hail/ir/ir.py
+++ b/hail/python/hail/ir/ir.py
@@ -3510,6 +3510,37 @@ def row_type(self):
         return tstruct(**self.table_row_type, **{self.uid_field: ttuple(tint64, tint64)})
 
 
+class PartitionZippedNativeIntervalReader(PartitionReader):
+    def __init__(self, path, full_row_type, uid_field=None):
+        self.path = path
+        self.full_row_type = full_row_type
+        self.uid_field = uid_field
+
+    def with_uid_field(self, uid_field):
+        return PartitionZippedNativeIntervalReader(path=self.path, full_row_type=self.full_row_type, uid_field=uid_field)
+
+    def render(self):
+        return escape_str(
+            json.dumps({
+                "name": "PartitionZippedNativeIntervalReader",
+                "path": self.path,
+                "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy',
+            })
+        )
+
+    def _eq(self, other):
+        return (
+            isinstance(other, PartitionZippedNativeIntervalReader)
+            and self.path == other.path
+            and self.uid_field == other.uid_field
+        )
+
+    def row_type(self):
+        if self.uid_field is None:
+            return self.full_row_type
+        return tstruct(**self.full_row_type, **{self.uid_field: ttuple(tint64, tint64)})
+
+
 class ReadPartition(IR):
     @typecheck_method(context=IR, reader=PartitionReader)
     def __init__(self, context, reader):
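
A usage sketch, not part of the patch: how the new hl.query_matrix_table is
expected to be called, by analogy with hl.query_table. The dataset path and
the interval below are hypothetical, and this assumes a matrix table keyed by
locus (the common case for Hail datasets).

    import hail as hl

    mt_path = 'gs://my-bucket/dataset.mt'  # hypothetical path to an indexed matrix table

    # Point query: rows whose key equals the given locus. Each returned row
    # struct carries the matrix entries in an array field named by
    # `entries_name` (default 'entries_array').
    rows = hl.eval(
        hl.query_matrix_table(mt_path, hl.locus('chr1', 123456, reference_genome='GRCh38'))
    )

    # Interval query: all rows whose key falls within the interval, with the
    # entries array field renamed to 'entries'.
    interval = hl.parse_locus_interval('chr1:100000-200000', reference_genome='GRCh38')
    rows = hl.eval(hl.query_matrix_table(mt_path, interval, entries_name='entries'))

Because PartitionZippedNativeIntervalReader zips the matrix table's rows and
entries components into a single stream, each result row is the row struct
widened with the entries array; this is also why query_matrix_table rejects an
entries_name that collides with an existing row field.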