Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit e4e5347

Browse files
committed
Merge branch 'main' into shuowei-fix-compiler-syntax-guards
2 parents 05805aa + a000425 commit e4e5347

File tree

12 files changed

+57
-25
lines changed

12 files changed

+57
-25
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def if_(
745745
or pandas Series.
746746
connection_id (str, optional):
747747
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
748-
If not provided, the connection from the current session will be used.
748+
If not provided, the query uses your end-user credential.
749749
750750
Returns:
751751
bigframes.series.Series: A new series of bools.
@@ -756,7 +756,7 @@ def if_(
756756

757757
operator = ai_ops.AIIf(
758758
prompt_context=tuple(prompt_context),
759-
connection_id=_resolve_connection_id(series_list[0], connection_id),
759+
connection_id=connection_id,
760760
)
761761

762762
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -800,7 +800,7 @@ def classify(
800800
Categories to classify the input into.
801801
connection_id (str, optional):
802802
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
803-
If not provided, the connection from the current session will be used.
803+
If not provided, the query uses your end-user credential.
804804
805805
Returns:
806806
bigframes.series.Series: A new series of strings.
@@ -812,7 +812,7 @@ def classify(
812812
operator = ai_ops.AIClassify(
813813
prompt_context=tuple(prompt_context),
814814
categories=tuple(categories),
815-
connection_id=_resolve_connection_id(series_list[0], connection_id),
815+
connection_id=connection_id,
816816
)
817817

818818
return series_list[0]._apply_nary_op(operator, series_list[1:])
@@ -853,7 +853,7 @@ def score(
853853
or pandas Series.
854854
connection_id (str, optional):
855855
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
856-
If not provided, the connection from the current session will be used.
856+
If not provided, the query uses your end-user credential.
857857
858858
Returns:
859859
bigframes.series.Series: A new series of double (float) values.
@@ -864,7 +864,7 @@ def score(
864864

865865
operator = ai_ops.AIScore(
866866
prompt_context=tuple(prompt_context),
867-
connection_id=_resolve_connection_id(series_list[0], connection_id),
867+
connection_id=connection_id,
868868
)
869869

870870
return series_list[0]._apply_nary_op(operator, series_list[1:])

bigframes/core/bigframe_node.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,12 +330,32 @@ def top_down(
330330
"""
331331
Perform a top-down transformation of the BigFrameNode tree.
332332
"""
333+
results: Dict[BigFrameNode, BigFrameNode] = {}
334+
# Each stack entry is (node, t_node). t_node is None until transform(node) is called.
335+
stack: list[tuple[BigFrameNode, typing.Optional[BigFrameNode]]] = [(self, None)]
333336

334-
@functools.cache
335-
def recursive_transform(node: BigFrameNode) -> BigFrameNode:
336-
return transform(node).transform_children(recursive_transform)
337+
while stack:
338+
node, t_node = stack[-1]
339+
340+
if t_node is None:
341+
if node in results:
342+
stack.pop()
343+
continue
344+
t_node = transform(node)
345+
stack[-1] = (node, t_node)
346+
347+
all_done = True
348+
for child in reversed(t_node.child_nodes):
349+
if child not in results:
350+
stack.append((child, None))
351+
all_done = False
352+
break
353+
354+
if all_done:
355+
results[node] = t_node.transform_children(lambda x: results[x])
356+
stack.pop()
337357

338-
return recursive_transform(self)
358+
return results[self]
339359

340360
def bottom_up(
341361
self: BigFrameNode,

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
113113
)
114114
)
115115

116-
endpoit = op_args.get("endpoint", None)
117-
if endpoit is not None:
118-
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116+
endpoint = op_args.get("endpoint", None)
117+
if endpoint is not None:
118+
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoint)))
119119

120120
request_type = op_args.get("request_type", None)
121121
if request_type is not None:

bigframes/operations/ai_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class AIIf(base_ops.NaryOp):
123123
name: ClassVar[str] = "ai_if"
124124

125125
prompt_context: Tuple[str | None, ...]
126-
connection_id: str
126+
connection_id: str | None
127127

128128
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
129129
return dtypes.BOOL_DTYPE
@@ -135,7 +135,7 @@ class AIClassify(base_ops.NaryOp):
135135

136136
prompt_context: Tuple[str | None, ...]
137137
categories: tuple[str, ...]
138-
connection_id: str
138+
connection_id: str | None
139139

140140
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
141141
return dtypes.STRING_DTYPE
@@ -146,7 +146,7 @@ class AIScore(base_ops.NaryOp):
146146
name: ClassVar[str] = "ai_score"
147147

148148
prompt_context: Tuple[str | None, ...]
149-
connection_id: str
149+
connection_id: str | None
150150

151151
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
152152
return dtypes.FLOAT_DTYPE
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.CLASSIFY(input => (`string_col`), categories => ['greeting', 'rejection']) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_classify/bigframes-dev.us.bigframes-default-connection/out.sql

File renamed without changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.IF(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_if/bigframes-dev.us.bigframes-default-connection/out.sql

File renamed without changes.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
AI.SCORE(prompt => (`string_col`, ' is the same as ', `string_col`)) AS `result`
3+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/out.sql renamed to tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_score/bigframes-dev.us.bigframes-default-connection/out.sql

File renamed without changes.

0 commit comments

Comments
 (0)