1515
1616from __future__ import annotations
1717
18- import inspect
1918import logging
2019import os
2120import random
2524import tempfile
2625import textwrap
2726import types
28- from typing import Any , cast , Optional , Sequence , TYPE_CHECKING
27+ from typing import Any , cast , Optional , TYPE_CHECKING
2928import warnings
3029
3130import requests
@@ -87,7 +86,6 @@ def __init__(
8786 bq_location ,
8887 bq_dataset ,
8988 bq_client ,
90- bq_connection_id ,
9189 bq_connection_manager ,
9290 cloud_function_region = None ,
9391 cloud_functions_client = None ,
@@ -102,7 +100,6 @@ def __init__(
102100 self ._bq_location = bq_location
103101 self ._bq_dataset = bq_dataset
104102 self ._bq_client = bq_client
105- self ._bq_connection_id = bq_connection_id
106103 self ._bq_connection_manager = bq_connection_manager
107104 self ._session = session
108105
@@ -114,12 +111,12 @@ def __init__(
114111 self ._cloud_function_docker_repository = cloud_function_docker_repository
115112 self ._cloud_build_service_account = cloud_build_service_account
116113
117- def _create_bq_connection (self ) -> None :
114+ def _create_bq_connection (self , connection_id : str ) -> None :
118115 if self ._bq_connection_manager :
119116 self ._bq_connection_manager .create_bq_connection (
120117 self ._gcp_project_id ,
121118 self ._bq_location ,
122- self . _bq_connection_id ,
119+ connection_id ,
123120 "run.invoker" ,
124121 )
125122
@@ -174,7 +171,7 @@ def create_bq_remote_function(
174171 ):
175172 """Create a BigQuery remote function given the artifacts of a user defined
176173 function and the http endpoint of a corresponding cloud function."""
177- self ._create_bq_connection ()
174+ self ._create_bq_connection (udf_def . connection_id )
178175
179176 # Create BQ function
180177 # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
@@ -202,7 +199,7 @@ def create_bq_remote_function(
202199 create_function_ddl = f"""
203200 CREATE OR REPLACE FUNCTION `{ self ._gcp_project_id } .{ self ._bq_dataset } `.{ bq_function_name_escaped } ({ udf_def .signature .to_sql_input_signature ()} )
204201 RETURNS { udf_def .signature .with_devirtualize ().output .sql_type }
205- REMOTE WITH CONNECTION `{ self ._gcp_project_id } .{ self ._bq_location } .{ self . _bq_connection_id } `
202+ REMOTE WITH CONNECTION `{ self ._gcp_project_id } .{ self ._bq_location } .{ udf_def . connection_id } `
206203 OPTIONS ({ remote_function_options_str } )"""
207204
208205 logger .info (f"Creating BQ remote function: { create_function_ddl } " )
@@ -212,26 +209,15 @@ def create_bq_remote_function(
212209
213210 def provision_bq_managed_function (
214211 self ,
215- func ,
216- input_types : Sequence [str ],
217- output_type : str ,
218212 name : Optional [str ],
219- packages : Optional [Sequence [str ]],
220- max_batching_rows : Optional [int ],
221- container_cpu : Optional [float ],
222- container_memory : Optional [str ],
223- is_row_processor : bool ,
224- bq_connection_id ,
225- * ,
226- capture_references : bool = False ,
213+ config : udf_def .ManagedFunctionConfig ,
227214 ):
228215 """Create a BigQuery managed function."""
229216
230217 # TODO(b/406283812): Expose the capability to pass down
231218 # capture_references=True in the public udf API.
232- # TODO(b/495508827): Include all config in the value hash.
233219 if (
234- capture_references
220+ config . capture_references
235221 and (python_version := _utils .get_python_version ())
236222 != _MANAGED_FUNC_PYTHON_VERSION
237223 ):
@@ -241,31 +227,26 @@ def provision_bq_managed_function(
241227 )
242228
243229 # Create BQ managed function.
244- bq_function_args = []
245- bq_function_return_type = output_type
246-
247- input_args = inspect .getargs (func .__code__ ).args
248- # We expect the input type annotations to be 1:1 with the input args.
249- for name_ , type_ in zip (input_args , input_types ):
250- bq_function_args .append (f"{ name_ } { type_ } " )
230+ bq_function_args = config .signature .to_sql_input_signature ()
231+ bq_function_return_type = config .signature .with_devirtualize ().output .sql_type
251232
252233 managed_function_options : dict [str , Any ] = {
253234 "runtime_version" : _MANAGED_FUNC_PYTHON_VERSION ,
254235 "entry_point" : "bigframes_handler" ,
255236 }
256- if max_batching_rows :
257- managed_function_options ["max_batching_rows" ] = max_batching_rows
258- if container_cpu :
259- managed_function_options ["container_cpu" ] = container_cpu
260- if container_memory :
261- managed_function_options ["container_memory" ] = container_memory
237+ if config . max_batching_rows :
238+ managed_function_options ["max_batching_rows" ] = config . max_batching_rows
239+ if config . container_cpu :
240+ managed_function_options ["container_cpu" ] = config . container_cpu
241+ if config . container_memory :
242+ managed_function_options ["container_memory" ] = config . container_memory
262243
263244 # Augment user package requirements with any internal package
264245 # requirements.
265246 packages = _utils .get_updated_package_requirements (
266- packages or [],
267- is_row_processor ,
268- capture_references ,
247+ config . code . package_requirements or [],
248+ config . signature . is_row_processor ,
249+ config . capture_references ,
269250 ignore_package_version = True ,
270251 )
271252 if packages :
@@ -276,40 +257,34 @@ def provision_bq_managed_function(
276257
277258 bq_function_name = name
278259 if not bq_function_name :
279- # Compute a unique hash representing the user code.
280- function_hash = _utils .get_hash (func , packages )
281- bq_function_name = _utils .get_managed_function_name (
282- function_hash ,
283- # session-scope in absensce of name from user
284- # name indicates permanent allocation
285- None if name else self ._session .session_id ,
260+ # Compute a unique hash representing the artifact definition.
261+ bq_function_name = get_managed_function_name (
262+ config , self ._session .session_id
286263 )
287264
288265 persistent_func_id = (
289266 f"`{ self ._gcp_project_id } .{ self ._bq_dataset } `.{ bq_function_name } "
290267 )
291268
292- udf_name = func .__name__
293-
294269 with_connection_clause = (
295270 (
296- f"WITH CONNECTION `{ self ._gcp_project_id } .{ self ._bq_location } .{ self . _bq_connection_id } `"
271+ f"WITH CONNECTION `{ self ._gcp_project_id } .{ self ._bq_location } .{ config . bq_connection_id } `"
297272 )
298- if bq_connection_id
273+ if config . bq_connection_id
299274 else ""
300275 )
301276
302277 # Generate the complete Python code block for the managed Python UDF,
303278 # including the user's function, necessary imports, and the BigQuery
304279 # handler wrapper.
305280 python_code_block = bff_template .generate_managed_function_code (
306- func , udf_name , is_row_processor , capture_references
281+ config . code , config . signature , config . capture_references
307282 )
308283
309284 create_function_ddl = (
310285 textwrap .dedent (
311286 f"""
312- CREATE OR REPLACE FUNCTION { persistent_func_id } ({ ',' . join ( bq_function_args ) } )
287+ CREATE OR REPLACE FUNCTION { persistent_func_id } ({ bq_function_args } )
313288 RETURNS { bq_function_return_type }
314289 LANGUAGE python
315290 { with_connection_clause }
@@ -590,6 +565,7 @@ def provision_bq_remote_function(
590565 cloud_function_memory_mib : int | None ,
591566 cloud_function_cpus : float | None ,
592567 cloud_function_ingress_settings : str ,
568+ bq_connection_id : str ,
593569 ):
594570 """Provision a BigQuery remote function."""
595571 # Augment user package requirements with any internal package
@@ -657,7 +633,7 @@ def provision_bq_remote_function(
657633
658634 intended_rf_spec = udf_def .RemoteFunctionConfig (
659635 endpoint = cf_endpoint ,
660- connection_id = self . _bq_connection_id ,
636+ connection_id = bq_connection_id ,
661637 max_batching_rows = max_batching_rows or 1000 ,
662638 signature = func_signature ,
663639 bq_metadata = func_signature .protocol_metadata ,
@@ -731,6 +707,18 @@ def get_bigframes_function_name(
731707 return _BQ_FUNCTION_NAME_SEPERATOR .join (parts )
732708
733709
710+ def get_managed_function_name (
711+ function_def : udf_def .ManagedFunctionConfig ,
712+ session_id : str | None = None ,
713+ ):
714+ """Get a name for the bigframes managed function for the given user defined function."""
715+ parts = [_BIGFRAMES_FUNCTION_PREFIX ]
716+ if session_id :
717+ parts .append (session_id )
718+ parts .append (function_def .stable_hash ().hex ())
719+ return _BQ_FUNCTION_NAME_SEPERATOR .join (parts )
720+
721+
734722def _validate_routine_name (name : str ) -> None :
735723 """Validate that the given name is a valid BigQuery routine name."""
736724 # Routine IDs can contain only letters (a-z, A-Z), numbers (0-9), or underscores (_)
0 commit comments