1919from datetime import timezone
2020import json
2121import logging
22+ import pickle
2223from typing import Any
2324from typing import Optional
2425import uuid
2526
27+ from google .genai import types
28+ from sqlalchemy import Boolean
2629from sqlalchemy import delete
2730from sqlalchemy import Dialect
2831from sqlalchemy import event
2932from sqlalchemy import ForeignKeyConstraint
3033from sqlalchemy import func
31- from sqlalchemy import inspect
3234from sqlalchemy import select
3335from sqlalchemy import Text
3436from sqlalchemy .dialects import mysql
3941from sqlalchemy .ext .asyncio import AsyncSession as DatabaseSessionFactory
4042from sqlalchemy .ext .asyncio import create_async_engine
4143from sqlalchemy .ext .mutable import MutableDict
44+ from sqlalchemy .inspection import inspect
4245from sqlalchemy .orm import DeclarativeBase
4346from sqlalchemy .orm import Mapped
4447from sqlalchemy .orm import mapped_column
4548from sqlalchemy .orm import relationship
49+ from sqlalchemy .schema import MetaData
4650from sqlalchemy .types import DateTime
51+ from sqlalchemy .types import PickleType
4752from sqlalchemy .types import String
4853from sqlalchemy .types import TypeDecorator
4954from typing_extensions import override
5257from . import _session_util
5358from ..errors .already_exists_error import AlreadyExistsError
5459from ..events .event import Event
60+ from ..events .event_actions import EventActions
5561from .base_session_service import BaseSessionService
5662from .base_session_service import GetSessionConfig
5763from .base_session_service import ListSessionsResponse
58- from .migration import _schema_check
5964from .session import Session
6065from .state import State
6166
@@ -106,20 +111,39 @@ def load_dialect_impl(self, dialect):
106111 return self .impl
107112
108113
109- class Base ( DeclarativeBase ):
110- """Base class for database tables ."""
114+ class DynamicPickleType ( TypeDecorator ):
115+ """Represents a type that can be pickled ."""
111116
112- pass
117+ impl = PickleType
113118
119+ def load_dialect_impl (self , dialect ):
120+ if dialect .name == "mysql" :
121+ return dialect .type_descriptor (mysql .LONGBLOB )
122+ if dialect .name == "spanner+spanner" :
123+ from google .cloud .sqlalchemy_spanner .sqlalchemy_spanner import SpannerPickleType
114124
115- class StorageMetadata (Base ):
116- """Represents internal metadata stored in the database."""
125+ return dialect .type_descriptor (SpannerPickleType )
126+ return self .impl
127+
128+ def process_bind_param (self , value , dialect ):
129+ """Ensures the pickled value is a bytes object before passing it to the database dialect."""
130+ if value is not None :
131+ if dialect .name in ("spanner+spanner" , "mysql" ):
132+ return pickle .dumps (value )
133+ return value
134+
135+ def process_result_value (self , value , dialect ):
136+ """Ensures the raw bytes from the database are unpickled back into a Python object."""
137+ if value is not None :
138+ if dialect .name in ("spanner+spanner" , "mysql" ):
139+ return pickle .loads (value )
140+ return value
117141
118- __tablename__ = "adk_internal_metadata"
119- key : Mapped [ str ] = mapped_column (
120- String ( DEFAULT_MAX_KEY_LENGTH ), primary_key = True
121- )
122- value : Mapped [ str ] = mapped_column ( String ( DEFAULT_MAX_VARCHAR_LENGTH ))
142+
143+ class Base ( DeclarativeBase ):
144+ """Base class for database tables."""
145+
146+ pass
123147
124148
125149class StorageSession (Base ):
@@ -213,10 +237,46 @@ class StorageEvent(Base):
213237 )
214238
215239 invocation_id : Mapped [str ] = mapped_column (String (DEFAULT_MAX_VARCHAR_LENGTH ))
240+ author : Mapped [str ] = mapped_column (String (DEFAULT_MAX_VARCHAR_LENGTH ))
241+ actions : Mapped [MutableDict [str , Any ]] = mapped_column (DynamicPickleType )
242+ long_running_tool_ids_json : Mapped [Optional [str ]] = mapped_column (
243+ Text , nullable = True
244+ )
245+ branch : Mapped [str ] = mapped_column (
246+ String (DEFAULT_MAX_VARCHAR_LENGTH ), nullable = True
247+ )
216248 timestamp : Mapped [PreciseTimestamp ] = mapped_column (
217249 PreciseTimestamp , default = func .now ()
218250 )
219- event_data : Mapped [dict [str , Any ]] = mapped_column (DynamicJSON )
251+
252+ # === Fields from llm_response.py ===
253+ content : Mapped [dict [str , Any ]] = mapped_column (DynamicJSON , nullable = True )
254+ grounding_metadata : Mapped [dict [str , Any ]] = mapped_column (
255+ DynamicJSON , nullable = True
256+ )
257+ custom_metadata : Mapped [dict [str , Any ]] = mapped_column (
258+ DynamicJSON , nullable = True
259+ )
260+ usage_metadata : Mapped [dict [str , Any ]] = mapped_column (
261+ DynamicJSON , nullable = True
262+ )
263+ citation_metadata : Mapped [dict [str , Any ]] = mapped_column (
264+ DynamicJSON , nullable = True
265+ )
266+
267+ partial : Mapped [bool ] = mapped_column (Boolean , nullable = True )
268+ turn_complete : Mapped [bool ] = mapped_column (Boolean , nullable = True )
269+ error_code : Mapped [str ] = mapped_column (
270+ String (DEFAULT_MAX_VARCHAR_LENGTH ), nullable = True
271+ )
272+ error_message : Mapped [str ] = mapped_column (String (1024 ), nullable = True )
273+ interrupted : Mapped [bool ] = mapped_column (Boolean , nullable = True )
274+ input_transcription : Mapped [dict [str , Any ]] = mapped_column (
275+ DynamicJSON , nullable = True
276+ )
277+ output_transcription : Mapped [dict [str , Any ]] = mapped_column (
278+ DynamicJSON , nullable = True
279+ )
220280
221281 storage_session : Mapped [StorageSession ] = relationship (
222282 "StorageSession" ,
@@ -231,27 +291,102 @@ class StorageEvent(Base):
231291 ),
232292 )
233293
294+ @property
295+ def long_running_tool_ids (self ) -> set [str ]:
296+ return (
297+ set (json .loads (self .long_running_tool_ids_json ))
298+ if self .long_running_tool_ids_json
299+ else set ()
300+ )
301+
302+ @long_running_tool_ids .setter
303+ def long_running_tool_ids (self , value : set [str ]):
304+ if value is None :
305+ self .long_running_tool_ids_json = None
306+ else :
307+ self .long_running_tool_ids_json = json .dumps (list (value ))
308+
234309 @classmethod
235310 def from_event (cls , session : Session , event : Event ) -> StorageEvent :
236- """Creates a StorageEvent from an Event."""
237- return StorageEvent (
311+ storage_event = StorageEvent (
238312 id = event .id ,
239313 invocation_id = event .invocation_id ,
314+ author = event .author ,
315+ branch = event .branch ,
316+ actions = event .actions ,
240317 session_id = session .id ,
241318 app_name = session .app_name ,
242319 user_id = session .user_id ,
243320 timestamp = datetime .fromtimestamp (event .timestamp ),
244- event_data = event .model_dump (exclude_none = True , mode = "json" ),
321+ long_running_tool_ids = event .long_running_tool_ids ,
322+ partial = event .partial ,
323+ turn_complete = event .turn_complete ,
324+ error_code = event .error_code ,
325+ error_message = event .error_message ,
326+ interrupted = event .interrupted ,
245327 )
328+ if event .content :
329+ storage_event .content = event .content .model_dump (
330+ exclude_none = True , mode = "json"
331+ )
332+ if event .grounding_metadata :
333+ storage_event .grounding_metadata = event .grounding_metadata .model_dump (
334+ exclude_none = True , mode = "json"
335+ )
336+ if event .custom_metadata :
337+ storage_event .custom_metadata = event .custom_metadata
338+ if event .usage_metadata :
339+ storage_event .usage_metadata = event .usage_metadata .model_dump (
340+ exclude_none = True , mode = "json"
341+ )
342+ if event .citation_metadata :
343+ storage_event .citation_metadata = event .citation_metadata .model_dump (
344+ exclude_none = True , mode = "json"
345+ )
346+ if event .input_transcription :
347+ storage_event .input_transcription = event .input_transcription .model_dump (
348+ exclude_none = True , mode = "json"
349+ )
350+ if event .output_transcription :
351+ storage_event .output_transcription = (
352+ event .output_transcription .model_dump (exclude_none = True , mode = "json" )
353+ )
354+ return storage_event
246355
247356 def to_event (self ) -> Event :
248- """Converts the StorageEvent to an Event."""
249- return Event .model_validate ({
250- ** self .event_data ,
251- "id" : self .id ,
252- "invocation_id" : self .invocation_id ,
253- "timestamp" : self .timestamp .timestamp (),
254- })
357+ return Event (
358+ id = self .id ,
359+ invocation_id = self .invocation_id ,
360+ author = self .author ,
361+ branch = self .branch ,
362+ # This is needed as previous ADK version pickled actions might not have
363+ # value defined in the current version of the EventActions model.
364+ actions = EventActions ().model_copy (update = self .actions .model_dump ()),
365+ timestamp = self .timestamp .timestamp (),
366+ long_running_tool_ids = self .long_running_tool_ids ,
367+ partial = self .partial ,
368+ turn_complete = self .turn_complete ,
369+ error_code = self .error_code ,
370+ error_message = self .error_message ,
371+ interrupted = self .interrupted ,
372+ custom_metadata = self .custom_metadata ,
373+ content = _session_util .decode_model (self .content , types .Content ),
374+ grounding_metadata = _session_util .decode_model (
375+ self .grounding_metadata , types .GroundingMetadata
376+ ),
377+ usage_metadata = _session_util .decode_model (
378+ self .usage_metadata , types .GenerateContentResponseUsageMetadata
379+ ),
380+ citation_metadata = _session_util .decode_model (
381+ self .citation_metadata , types .CitationMetadata
382+ ),
383+ input_transcription = _session_util .decode_model (
384+ self .input_transcription , types .Transcription
385+ ),
386+ output_transcription = _session_util .decode_model (
387+ self .output_transcription , types .Transcription
388+ ),
389+ )
255390
256391
257392class StorageAppState (Base ):
@@ -328,6 +463,7 @@ def __init__(self, db_url: str, **kwargs: Any):
328463 logger .info ("Local timezone: %s" , local_timezone )
329464
330465 self .db_engine : AsyncEngine = db_engine
466+ self .metadata : MetaData = MetaData ()
331467
332468 # DB session factory method
333469 self .database_session_factory : async_sessionmaker [
@@ -347,46 +483,10 @@ async def _ensure_tables_created(self):
347483 async with self ._table_creation_lock :
348484 # Double-check after acquiring the lock
349485 if not self ._tables_created :
350- # Check schema version BEFORE creating tables.
351- # This prevents creating metadata table on a v0.1 DB.
352- async with self .database_session_factory () as sql_session :
353- version , is_v01 = await sql_session .run_sync (
354- _schema_check .get_version_and_v01_status_sync
355- )
356-
357- if is_v01 :
358- raise RuntimeError (
359- "Database schema appears to be v0.1, but"
360- f" { _schema_check .CURRENT_SCHEMA_VERSION } is required. Please"
361- " migrate the database using 'adk migrate session'."
362- )
363- elif version and version < _schema_check .CURRENT_SCHEMA_VERSION :
364- raise RuntimeError (
365- f"Database schema version is { version } , but current version is"
366- f" { _schema_check .CURRENT_SCHEMA_VERSION } . Please migrate"
367- " the database to the latest version using 'adk migrate"
368- " session'."
369- )
370-
371486 async with self .db_engine .begin () as conn :
372487 # Uncomment to recreate DB every time
373488 # await conn.run_sync(Base.metadata.drop_all)
374489 await conn .run_sync (Base .metadata .create_all )
375-
376- # If we are here, DB is either new or >= current version.
377- # If new or without metadata row, stamp it as current version.
378- async with self .database_session_factory () as sql_session :
379- metadata = await sql_session .get (
380- StorageMetadata , _schema_check .SCHEMA_VERSION_KEY
381- )
382- if not metadata :
383- sql_session .add (
384- StorageMetadata (
385- key = _schema_check .SCHEMA_VERSION_KEY ,
386- value = _schema_check .CURRENT_SCHEMA_VERSION ,
387- )
388- )
389- await sql_session .commit ()
390490 self ._tables_created = True
391491
392492 @override
@@ -623,9 +723,7 @@ async def append_event(self, session: Session, event: Event) -> Event:
623723 storage_session .state = storage_session .state | session_state_delta
624724
625725 if storage_session ._dialect_name == "sqlite" :
626- update_time = datetime .fromtimestamp (
627- event .timestamp , timezone .utc
628- ).replace (tzinfo = None )
726+ update_time = datetime .utcfromtimestamp (event .timestamp )
629727 else :
630728 update_time = datetime .fromtimestamp (event .timestamp )
631729 storage_session .update_time = update_time
0 commit comments