Skip to content

Commit eaa845a

Browse files
Merge remote-tracking branch 'github/main' into read_orc
2 parents 9af47af + af49ca2 commit eaa845a

File tree

14 files changed

+900
-233
lines changed

14 files changed

+900
-233
lines changed

bigframes/functions/_function_client.py

Lines changed: 39 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from __future__ import annotations
1717

18-
import inspect
1918
import logging
2019
import os
2120
import random
@@ -25,7 +24,7 @@
2524
import tempfile
2625
import textwrap
2726
import types
28-
from typing import Any, cast, Optional, Sequence, TYPE_CHECKING
27+
from typing import Any, cast, Optional, TYPE_CHECKING
2928
import warnings
3029

3130
import requests
@@ -87,7 +86,6 @@ def __init__(
8786
bq_location,
8887
bq_dataset,
8988
bq_client,
90-
bq_connection_id,
9189
bq_connection_manager,
9290
cloud_function_region=None,
9391
cloud_functions_client=None,
@@ -102,7 +100,6 @@ def __init__(
102100
self._bq_location = bq_location
103101
self._bq_dataset = bq_dataset
104102
self._bq_client = bq_client
105-
self._bq_connection_id = bq_connection_id
106103
self._bq_connection_manager = bq_connection_manager
107104
self._session = session
108105

@@ -114,12 +111,12 @@ def __init__(
114111
self._cloud_function_docker_repository = cloud_function_docker_repository
115112
self._cloud_build_service_account = cloud_build_service_account
116113

117-
def _create_bq_connection(self) -> None:
114+
def _create_bq_connection(self, connection_id: str) -> None:
118115
if self._bq_connection_manager:
119116
self._bq_connection_manager.create_bq_connection(
120117
self._gcp_project_id,
121118
self._bq_location,
122-
self._bq_connection_id,
119+
connection_id,
123120
"run.invoker",
124121
)
125122

@@ -174,7 +171,7 @@ def create_bq_remote_function(
174171
):
175172
"""Create a BigQuery remote function given the artifacts of a user defined
176173
function and the http endpoint of a corresponding cloud function."""
177-
self._create_bq_connection()
174+
self._create_bq_connection(udf_def.connection_id)
178175

179176
# Create BQ function
180177
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
@@ -202,7 +199,7 @@ def create_bq_remote_function(
202199
create_function_ddl = f"""
203200
CREATE OR REPLACE FUNCTION `{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name_escaped}({udf_def.signature.to_sql_input_signature()})
204201
RETURNS {udf_def.signature.with_devirtualize().output.sql_type}
205-
REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`
202+
REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{udf_def.connection_id}`
206203
OPTIONS ({remote_function_options_str})"""
207204

208205
logger.info(f"Creating BQ remote function: {create_function_ddl}")
@@ -212,26 +209,15 @@ def create_bq_remote_function(
212209

213210
def provision_bq_managed_function(
214211
self,
215-
func,
216-
input_types: Sequence[str],
217-
output_type: str,
218212
name: Optional[str],
219-
packages: Optional[Sequence[str]],
220-
max_batching_rows: Optional[int],
221-
container_cpu: Optional[float],
222-
container_memory: Optional[str],
223-
is_row_processor: bool,
224-
bq_connection_id,
225-
*,
226-
capture_references: bool = False,
213+
config: udf_def.ManagedFunctionConfig,
227214
):
228215
"""Create a BigQuery managed function."""
229216

230217
# TODO(b/406283812): Expose the capability to pass down
231218
# capture_references=True in the public udf API.
232-
# TODO(b/495508827): Include all config in the value hash.
233219
if (
234-
capture_references
220+
config.capture_references
235221
and (python_version := _utils.get_python_version())
236222
!= _MANAGED_FUNC_PYTHON_VERSION
237223
):
@@ -241,31 +227,26 @@ def provision_bq_managed_function(
241227
)
242228

243229
# Create BQ managed function.
244-
bq_function_args = []
245-
bq_function_return_type = output_type
246-
247-
input_args = inspect.getargs(func.__code__).args
248-
# We expect the input type annotations to be 1:1 with the input args.
249-
for name_, type_ in zip(input_args, input_types):
250-
bq_function_args.append(f"{name_} {type_}")
230+
bq_function_args = config.signature.to_sql_input_signature()
231+
bq_function_return_type = config.signature.with_devirtualize().output.sql_type
251232

252233
managed_function_options: dict[str, Any] = {
253234
"runtime_version": _MANAGED_FUNC_PYTHON_VERSION,
254235
"entry_point": "bigframes_handler",
255236
}
256-
if max_batching_rows:
257-
managed_function_options["max_batching_rows"] = max_batching_rows
258-
if container_cpu:
259-
managed_function_options["container_cpu"] = container_cpu
260-
if container_memory:
261-
managed_function_options["container_memory"] = container_memory
237+
if config.max_batching_rows:
238+
managed_function_options["max_batching_rows"] = config.max_batching_rows
239+
if config.container_cpu:
240+
managed_function_options["container_cpu"] = config.container_cpu
241+
if config.container_memory:
242+
managed_function_options["container_memory"] = config.container_memory
262243

263244
# Augment user package requirements with any internal package
264245
# requirements.
265246
packages = _utils.get_updated_package_requirements(
266-
packages or [],
267-
is_row_processor,
268-
capture_references,
247+
config.code.package_requirements or [],
248+
config.signature.is_row_processor,
249+
config.capture_references,
269250
ignore_package_version=True,
270251
)
271252
if packages:
@@ -276,40 +257,34 @@ def provision_bq_managed_function(
276257

277258
bq_function_name = name
278259
if not bq_function_name:
279-
# Compute a unique hash representing the user code.
280-
function_hash = _utils.get_hash(func, packages)
281-
bq_function_name = _utils.get_managed_function_name(
282-
function_hash,
283-
# session-scope in absensce of name from user
284-
# name indicates permanent allocation
285-
None if name else self._session.session_id,
260+
# Compute a unique hash representing the artifact definition.
261+
bq_function_name = get_managed_function_name(
262+
config, self._session.session_id
286263
)
287264

288265
persistent_func_id = (
289266
f"`{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name}"
290267
)
291268

292-
udf_name = func.__name__
293-
294269
with_connection_clause = (
295270
(
296-
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`"
271+
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{config.bq_connection_id}`"
297272
)
298-
if bq_connection_id
273+
if config.bq_connection_id
299274
else ""
300275
)
301276

302277
# Generate the complete Python code block for the managed Python UDF,
303278
# including the user's function, necessary imports, and the BigQuery
304279
# handler wrapper.
305280
python_code_block = bff_template.generate_managed_function_code(
306-
func, udf_name, is_row_processor, capture_references
281+
config.code, config.signature, config.capture_references
307282
)
308283

309284
create_function_ddl = (
310285
textwrap.dedent(
311286
f"""
312-
CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)})
287+
CREATE OR REPLACE FUNCTION {persistent_func_id}({bq_function_args})
313288
RETURNS {bq_function_return_type}
314289
LANGUAGE python
315290
{with_connection_clause}
@@ -590,6 +565,7 @@ def provision_bq_remote_function(
590565
cloud_function_memory_mib: int | None,
591566
cloud_function_cpus: float | None,
592567
cloud_function_ingress_settings: str,
568+
bq_connection_id: str,
593569
):
594570
"""Provision a BigQuery remote function."""
595571
# Augment user package requirements with any internal package
@@ -657,7 +633,7 @@ def provision_bq_remote_function(
657633

658634
intended_rf_spec = udf_def.RemoteFunctionConfig(
659635
endpoint=cf_endpoint,
660-
connection_id=self._bq_connection_id,
636+
connection_id=bq_connection_id,
661637
max_batching_rows=max_batching_rows or 1000,
662638
signature=func_signature,
663639
bq_metadata=func_signature.protocol_metadata,
@@ -731,6 +707,18 @@ def get_bigframes_function_name(
731707
return _BQ_FUNCTION_NAME_SEPERATOR.join(parts)
732708

733709

710+
def get_managed_function_name(
711+
function_def: udf_def.ManagedFunctionConfig,
712+
session_id: str | None = None,
713+
):
714+
"""Get a name for the bigframes managed function for the given user defined function."""
715+
parts = [_BIGFRAMES_FUNCTION_PREFIX]
716+
if session_id:
717+
parts.append(session_id)
718+
parts.append(function_def.stable_hash().hex())
719+
return _BQ_FUNCTION_NAME_SEPERATOR.join(parts)
720+
721+
734722
def _validate_routine_name(name: str) -> None:
735723
"""Validate that the given name is a valid BigQuery routine name."""
736724
# Routine IDs can contain only letters (a-z, A-Z), numbers (0-9), or underscores (_)

bigframes/functions/_function_session.py

Lines changed: 44 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -556,34 +556,13 @@ def wrapper(func):
556556
func,
557557
**signature_kwargs,
558558
)
559-
if input_types is not None:
560-
if not isinstance(input_types, collections.abc.Sequence):
561-
input_types = [input_types]
562-
if _utils.has_conflict_input_type(py_sig, input_types):
563-
msg = bfe.format_message(
564-
"Conflicting input types detected, using the one from the decorator."
565-
)
566-
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
567-
py_sig = py_sig.replace(
568-
parameters=[
569-
par.replace(annotation=itype)
570-
for par, itype in zip(py_sig.parameters.values(), input_types)
571-
]
572-
)
573-
if output_type:
574-
if _utils.has_conflict_output_type(py_sig, output_type):
575-
msg = bfe.format_message(
576-
"Conflicting return type detected, using the one from the decorator."
577-
)
578-
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
579-
py_sig = py_sig.replace(return_annotation=output_type)
559+
py_sig = _resolve_signature(py_sig, input_types, output_type)
580560

581561
remote_function_client = _function_client.FunctionClient(
582562
dataset_ref.project,
583563
bq_location,
584564
dataset_ref.dataset_id,
585565
bigquery_client,
586-
bq_connection_id,
587566
bq_connection_manager,
588567
cloud_function_region,
589568
cloud_functions_client,
@@ -618,6 +597,7 @@ def wrapper(func):
618597
cloud_function_memory_mib=cloud_function_memory_mib,
619598
cloud_function_cpus=cloud_function_cpus,
620599
cloud_function_ingress_settings=cloud_function_ingress_settings,
600+
bq_connection_id=bq_connection_id,
621601
)
622602

623603
bigframes_cloud_function = (
@@ -840,27 +820,7 @@ def wrapper(func):
840820
func,
841821
**signature_kwargs,
842822
)
843-
if input_types is not None:
844-
if not isinstance(input_types, collections.abc.Sequence):
845-
input_types = [input_types]
846-
if _utils.has_conflict_input_type(py_sig, input_types):
847-
msg = bfe.format_message(
848-
"Conflicting input types detected, using the one from the decorator."
849-
)
850-
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
851-
py_sig = py_sig.replace(
852-
parameters=[
853-
par.replace(annotation=itype)
854-
for par, itype in zip(py_sig.parameters.values(), input_types)
855-
]
856-
)
857-
if output_type:
858-
if _utils.has_conflict_output_type(py_sig, output_type):
859-
msg = bfe.format_message(
860-
"Conflicting return type detected, using the one from the decorator."
861-
)
862-
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
863-
py_sig = py_sig.replace(return_annotation=output_type)
823+
py_sig = _resolve_signature(py_sig, input_types, output_type)
864824

865825
# The function will actually be receiving a pandas Series, but allow
866826
# both BigQuery DataFrames and pandas object types for compatibility.
@@ -872,22 +832,22 @@ def wrapper(func):
872832
bq_location,
873833
dataset_ref.dataset_id,
874834
bigquery_client,
875-
bq_connection_id,
876835
bq_connection_manager,
877836
session=session, # type: ignore
878837
)
879-
880-
bq_function_name = managed_function_client.provision_bq_managed_function(
881-
func=func,
882-
input_types=tuple(arg.sql_type for arg in udf_sig.inputs),
883-
output_type=udf_sig.output.sql_type,
884-
name=name,
885-
packages=packages,
838+
config = udf_def.ManagedFunctionConfig(
839+
code=udf_def.CodeDef.from_func(func),
840+
signature=udf_sig,
886841
max_batching_rows=max_batching_rows,
887842
container_cpu=container_cpu,
888843
container_memory=container_memory,
889-
is_row_processor=udf_sig.is_row_processor,
890844
bq_connection_id=bq_connection_id,
845+
capture_references=False,
846+
)
847+
848+
bq_function_name = managed_function_client.provision_bq_managed_function(
849+
name=name,
850+
config=config,
891851
)
892852
full_rf_name = (
893853
managed_function_client.get_remote_function_fully_qualilfied_name(
@@ -907,12 +867,14 @@ def wrapper(func):
907867
if udf_sig.is_row_processor:
908868
msg = bfe.format_message("input_types=Series is in preview.")
909869
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
870+
assert session is not None # appease mypy
910871
return decorator(
911872
bq_functions.BigqueryCallableRowRoutine(
912873
udf_definition, session, local_func=func, is_managed=True
913874
)
914875
)
915876
else:
877+
assert session is not None # appease mypy
916878
return decorator(
917879
bq_functions.BigqueryCallableRoutine(
918880
udf_definition,
@@ -949,3 +911,33 @@ def deploy_udf(
949911
# TODO(tswast): If we update udf to defer deployment, update this method
950912
# to deploy immediately.
951913
return self.udf(**kwargs)(func)
914+
915+
916+
def _resolve_signature(
917+
py_sig: inspect.Signature,
918+
input_types: Union[None, type, Sequence[type]] = None,
919+
output_type: Optional[type] = None,
920+
) -> inspect.Signature:
921+
if input_types is not None:
922+
if not isinstance(input_types, collections.abc.Sequence):
923+
input_types = [input_types]
924+
if _utils.has_conflict_input_type(py_sig, input_types):
925+
msg = bfe.format_message(
926+
"Conflicting input types detected, using the one from the decorator."
927+
)
928+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
929+
py_sig = py_sig.replace(
930+
parameters=[
931+
par.replace(annotation=itype)
932+
for par, itype in zip(py_sig.parameters.values(), input_types)
933+
]
934+
)
935+
if output_type:
936+
if _utils.has_conflict_output_type(py_sig, output_type):
937+
msg = bfe.format_message(
938+
"Conflicting return type detected, using the one from the decorator."
939+
)
940+
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
941+
py_sig = py_sig.replace(return_annotation=output_type)
942+
943+
return py_sig

bigframes/functions/_utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -186,18 +186,6 @@ def routine_ref_to_string_for_query(routine_ref: bigquery.RoutineReference) -> s
186186
return f"`{routine_ref.project}.{routine_ref.dataset_id}`.{routine_ref.routine_id}"
187187

188188

189-
def get_managed_function_name(
190-
function_hash: str,
191-
session_id: str | None = None,
192-
):
193-
"""Get a name for the bigframes managed function for the given user defined function."""
194-
parts = [_BIGFRAMES_FUNCTION_PREFIX]
195-
if session_id:
196-
parts.append(session_id)
197-
parts.append(function_hash)
198-
return _BQ_FUNCTION_NAME_SEPERATOR.join(parts)
199-
200-
201189
# Deprecated: Use CodeDef.stable_hash() instead.
202190
def get_hash(def_, package_requirements=None):
203191
"Get hash (32 digits alphanumeric) of a function."

0 commit comments

Comments
 (0)