Commit 43fffc3

DeanChensj authored and ShaharKatz committed
feat!: Rollback the DB migration as it is breaking
Co-authored-by: Shangjie Chen <[email protected]>
PiperOrigin-RevId: 839818479
1 parent 06a9b95 commit 43fffc3

9 files changed: +342 −939 lines changed

src/google/adk/cli/cli_tools_click.py

Lines changed: 0 additions & 36 deletions
@@ -36,7 +36,6 @@
 from . import cli_deploy
 from .. import version
 from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
-from ..sessions.migration import migration_runner
 from .cli import run_cli
 from .fast_api import get_fast_api_app
 from .utils import envs
@@ -1500,41 +1499,6 @@ def cli_deploy_cloud_run(
     click.secho(f"Deploy failed: {e}", fg="red", err=True)


-@main.group()
-def migrate():
-  """Migrate ADK database schemas."""
-  pass
-
-
-@migrate.command("session", cls=HelpfulCommand)
-@click.option(
-    "--source_db_url",
-    required=True,
-    help="SQLAlchemy URL of source database.",
-)
-@click.option(
-    "--dest_db_url",
-    required=True,
-    help="SQLAlchemy URL of destination database.",
-)
-@click.option(
-    "--log_level",
-    type=LOG_LEVELS,
-    default="INFO",
-    help="Optional. Set the logging level",
-)
-def cli_migrate_session(
-    *, source_db_url: str, dest_db_url: str, log_level: str
-):
-  """Migrates a session database to the latest schema version."""
-  logs.setup_adk_logger(getattr(logging, log_level.upper()))
-  try:
-    migration_runner.upgrade(source_db_url, dest_db_url)
-    click.secho("Migration check and upgrade process finished.", fg="green")
-  except Exception as e:
-    click.secho(f"Migration failed: {e}", fg="red", err=True)
-
-
 @deploy.command("agent_engine")
 @click.option(
     "--api_key",

src/google/adk/sessions/database_session_service.py

Lines changed: 160 additions & 62 deletions
@@ -19,16 +19,18 @@
 from datetime import timezone
 import json
 import logging
+import pickle
 from typing import Any
 from typing import Optional
 import uuid

+from google.genai import types
+from sqlalchemy import Boolean
 from sqlalchemy import delete
 from sqlalchemy import Dialect
 from sqlalchemy import event
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
-from sqlalchemy import inspect
 from sqlalchemy import select
 from sqlalchemy import Text
 from sqlalchemy.dialects import mysql
@@ -39,11 +41,14 @@
 from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
+from sqlalchemy.schema import MetaData
 from sqlalchemy.types import DateTime
+from sqlalchemy.types import PickleType
 from sqlalchemy.types import String
 from sqlalchemy.types import TypeDecorator
 from typing_extensions import override
@@ -52,10 +57,10 @@
 from . import _session_util
 from ..errors.already_exists_error import AlreadyExistsError
 from ..events.event import Event
+from ..events.event_actions import EventActions
 from .base_session_service import BaseSessionService
 from .base_session_service import GetSessionConfig
 from .base_session_service import ListSessionsResponse
-from .migration import _schema_check
 from .session import Session
 from .state import State

@@ -106,20 +111,39 @@ def load_dialect_impl(self, dialect):
     return self.impl


-class Base(DeclarativeBase):
-  """Base class for database tables."""
+class DynamicPickleType(TypeDecorator):
+  """Represents a type that can be pickled."""

-  pass
+  impl = PickleType

+  def load_dialect_impl(self, dialect):
+    if dialect.name == "mysql":
+      return dialect.type_descriptor(mysql.LONGBLOB)
+    if dialect.name == "spanner+spanner":
+      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType

-class StorageMetadata(Base):
-  """Represents internal metadata stored in the database."""
+      return dialect.type_descriptor(SpannerPickleType)
+    return self.impl
+
+  def process_bind_param(self, value, dialect):
+    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
+    if value is not None:
+      if dialect.name in ("spanner+spanner", "mysql"):
+        return pickle.dumps(value)
+    return value
+
+  def process_result_value(self, value, dialect):
+    """Ensures the raw bytes from the database are unpickled back into a Python object."""
+    if value is not None:
+      if dialect.name in ("spanner+spanner", "mysql"):
+        return pickle.loads(value)
+    return value

-  __tablename__ = "adk_internal_metadata"
-  key: Mapped[str] = mapped_column(
-      String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
-  )
-  value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+
+class Base(DeclarativeBase):
+  """Base class for database tables."""
+
+  pass


 class StorageSession(Base):
@@ -213,10 +237,46 @@ class StorageEvent(Base):
   )

   invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
+  actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
+  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
+      Text, nullable=True
+  )
+  branch: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
   timestamp: Mapped[PreciseTimestamp] = mapped_column(
       PreciseTimestamp, default=func.now()
   )
-  event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
+
+  # === Fields from llm_response.py ===
+  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
+  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  custom_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  usage_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  citation_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+
+  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  error_code: Mapped[str] = mapped_column(
+      String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
+  )
+  error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
+  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  input_transcription: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  output_transcription: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )

   storage_session: Mapped[StorageSession] = relationship(
       "StorageSession",
@@ -231,27 +291,102 @@ class StorageEvent(Base):
       ),
   )

+  @property
+  def long_running_tool_ids(self) -> set[str]:
+    return (
+        set(json.loads(self.long_running_tool_ids_json))
+        if self.long_running_tool_ids_json
+        else set()
+    )
+
+  @long_running_tool_ids.setter
+  def long_running_tool_ids(self, value: set[str]):
+    if value is None:
+      self.long_running_tool_ids_json = None
+    else:
+      self.long_running_tool_ids_json = json.dumps(list(value))
+
   @classmethod
   def from_event(cls, session: Session, event: Event) -> StorageEvent:
-    """Creates a StorageEvent from an Event."""
-    return StorageEvent(
+    storage_event = StorageEvent(
         id=event.id,
         invocation_id=event.invocation_id,
+        author=event.author,
+        branch=event.branch,
+        actions=event.actions,
         session_id=session.id,
         app_name=session.app_name,
         user_id=session.user_id,
         timestamp=datetime.fromtimestamp(event.timestamp),
-        event_data=event.model_dump(exclude_none=True, mode="json"),
+        long_running_tool_ids=event.long_running_tool_ids,
+        partial=event.partial,
+        turn_complete=event.turn_complete,
+        error_code=event.error_code,
+        error_message=event.error_message,
+        interrupted=event.interrupted,
     )
+    if event.content:
+      storage_event.content = event.content.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.grounding_metadata:
+      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.custom_metadata:
+      storage_event.custom_metadata = event.custom_metadata
+    if event.usage_metadata:
+      storage_event.usage_metadata = event.usage_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.citation_metadata:
+      storage_event.citation_metadata = event.citation_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.input_transcription:
+      storage_event.input_transcription = event.input_transcription.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.output_transcription:
+      storage_event.output_transcription = (
+          event.output_transcription.model_dump(exclude_none=True, mode="json")
+      )
+    return storage_event

   def to_event(self) -> Event:
-    """Converts the StorageEvent to an Event."""
-    return Event.model_validate({
-        **self.event_data,
-        "id": self.id,
-        "invocation_id": self.invocation_id,
-        "timestamp": self.timestamp.timestamp(),
-    })
+    return Event(
+        id=self.id,
+        invocation_id=self.invocation_id,
+        author=self.author,
+        branch=self.branch,
+        # This is needed as previous ADK version pickled actions might not have
+        # value defined in the current version of the EventActions model.
+        actions=EventActions().model_copy(update=self.actions.model_dump()),
+        timestamp=self.timestamp.timestamp(),
+        long_running_tool_ids=self.long_running_tool_ids,
+        partial=self.partial,
+        turn_complete=self.turn_complete,
+        error_code=self.error_code,
+        error_message=self.error_message,
+        interrupted=self.interrupted,
+        custom_metadata=self.custom_metadata,
+        content=_session_util.decode_model(self.content, types.Content),
+        grounding_metadata=_session_util.decode_model(
+            self.grounding_metadata, types.GroundingMetadata
+        ),
+        usage_metadata=_session_util.decode_model(
+            self.usage_metadata, types.GenerateContentResponseUsageMetadata
+        ),
+        citation_metadata=_session_util.decode_model(
+            self.citation_metadata, types.CitationMetadata
+        ),
+        input_transcription=_session_util.decode_model(
+            self.input_transcription, types.Transcription
+        ),
+        output_transcription=_session_util.decode_model(
+            self.output_transcription, types.Transcription
+        ),
+    )


 class StorageAppState(Base):
@@ -328,6 +463,7 @@ def __init__(self, db_url: str, **kwargs: Any):
     logger.info("Local timezone: %s", local_timezone)

     self.db_engine: AsyncEngine = db_engine
+    self.metadata: MetaData = MetaData()

     # DB session factory method
     self.database_session_factory: async_sessionmaker[
@@ -347,46 +483,10 @@ async def _ensure_tables_created(self):
     async with self._table_creation_lock:
       # Double-check after acquiring the lock
       if not self._tables_created:
-        # Check schema version BEFORE creating tables.
-        # This prevents creating metadata table on a v0.1 DB.
-        async with self.database_session_factory() as sql_session:
-          version, is_v01 = await sql_session.run_sync(
-              _schema_check.get_version_and_v01_status_sync
-          )
-
-          if is_v01:
-            raise RuntimeError(
-                "Database schema appears to be v0.1, but"
-                f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please"
-                " migrate the database using 'adk migrate session'."
-            )
-          elif version and version < _schema_check.CURRENT_SCHEMA_VERSION:
-            raise RuntimeError(
-                f"Database schema version is {version}, but current version is"
-                f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate"
-                " the database to the latest version using 'adk migrate"
-                " session'."
-            )
-
         async with self.db_engine.begin() as conn:
           # Uncomment to recreate DB every time
           # await conn.run_sync(Base.metadata.drop_all)
           await conn.run_sync(Base.metadata.create_all)
-
-        # If we are here, DB is either new or >= current version.
-        # If new or without metadata row, stamp it as current version.
-        async with self.database_session_factory() as sql_session:
-          metadata = await sql_session.get(
-              StorageMetadata, _schema_check.SCHEMA_VERSION_KEY
-          )
-          if not metadata:
-            sql_session.add(
-                StorageMetadata(
-                    key=_schema_check.SCHEMA_VERSION_KEY,
-                    value=_schema_check.CURRENT_SCHEMA_VERSION,
-                )
-            )
-          await sql_session.commit()
         self._tables_created = True

   @override
@@ -623,9 +723,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
       storage_session.state = storage_session.state | session_state_delta

       if storage_session._dialect_name == "sqlite":
-        update_time = datetime.fromtimestamp(
-            event.timestamp, timezone.utc
-        ).replace(tzinfo=None)
+        update_time = datetime.utcfromtimestamp(event.timestamp)
       else:
         update_time = datetime.fromtimestamp(event.timestamp)
       storage_session.update_time = update_time
src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py renamed to src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py

File renamed without changes.
