diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e1de29426bf3..63338016744a0 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -214,6 +214,13 @@ def input_type(self) -> Optional[DataType]: return None return _parse_datatype_json_string(input_type) + @property + def table_arg_offsets(self) -> Optional[list[int]]: + offsets = self.get("table_arg_offsets", None) + if offsets is None: + return None + return [int(x) for x in offsets.split(",") if x] + def report_times(outfile, boot, init, finish, processing_time_ms): write_int(SpecialLengths.TIMING_DATA, outfile) @@ -1258,15 +1265,14 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): # It expects the UDTF to be in a specific format and performs various checks to # ensure the UDTF is valid. This function also prepares a mapper function for applying # the UDTF logic to input rows. -def read_udtf(pickleSer, infile, eval_type, runner_conf): +def read_udtf(pickleSer, infile, eval_type, runner_conf, eval_conf): if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF: - input_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if runner_conf.use_legacy_pandas_udtf_conversion: # NOTE: if timezone is set here, that implies respectSessionTimeZone is True ser = ArrowStreamPandasUDTFSerializer( timezone=runner_conf.timezone, safecheck=runner_conf.safecheck, - input_type=input_type, + input_type=eval_conf.input_type, prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype, int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled, ) @@ -1274,10 +1280,8 @@ def read_udtf(pickleSer, infile, eval_type, runner_conf): ser = ArrowStreamUDTFSerializer() elif eval_type == PythonEvalType.SQL_ARROW_UDTF: # Read the table argument offsets - num_table_arg_offsets = read_int(infile) - table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)] # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support - ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets) + ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=eval_conf.table_arg_offsets) else: # Each row is a group so do not batch but send one by one. ser = BatchedSerializer(CPickleSerializer(), 1) @@ -2185,7 +2189,7 @@ def mapper(_, it): none_on_identity=True, binary_as_bytes=runner_conf.binary_as_bytes, ) - for f in input_type + for f in eval_conf.input_type ] for a in it: pylist = [ @@ -3525,7 +3529,7 @@ def main(infile, outfile): PythonEvalType.SQL_ARROW_UDTF, ): func, profiler, deserializer, serializer = read_udtf( - pickleSer, infile, eval_type, runner_conf + pickleSer, infile, eval_type, runner_conf, eval_conf ) else: func, profiler, deserializer, serializer = read_udfs( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index 818a05cbdd667..a3c9f2f459266 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -53,21 +53,22 @@ class ArrowPythonUDTFRunner( override protected def runnerConf: Map[String, String] = super.runnerConf ++ pythonRunnerConf - override protected def writeUDF(dataOut: DataOutputStream): Unit = { - // For arrow-optimized Python UDTFs (@udtf(useArrow=True)), we need to write - // the schema to the worker to support UDT (user-defined type). - // Currently, UDT is not supported in PyArrow native UDTFs (arrow_udf) + override protected def evalConf: Map[String, String] = { if (evalType == PythonEvalType.SQL_ARROW_TABLE_UDF) { - PythonWorkerUtils.writeUTF(schema.json, dataOut) - } - // Write the table argument offsets for Arrow UDTFs. - else if (evalType == PythonEvalType.SQL_ARROW_UDTF) { + super.evalConf ++ Map( + "input_type" -> schema.json + ) + } else { val tableArgOffsets = argMetas.collect { case ArgumentMetadata(offset, _, isTableArg) if isTableArg => offset } - dataOut.writeInt(tableArgOffsets.length) - tableArgOffsets.foreach(dataOut.writeInt(_)) + super.evalConf ++ Map( + "table_arg_offsets" -> tableArgOffsets.mkString(",") + ) } + } + + override protected def writeUDF(dataOut: DataOutputStream): Unit = { PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) }