22In-memory implementation of the ADK BaseArtifactService for testing purposes.
33"""
44
5+ import time
56from collections import defaultdict
67from typing import Dict , List , Optional , Tuple , cast
78
89from google .adk .artifacts import BaseArtifactService
10+ from google .adk .artifacts .base_artifact_service import ArtifactVersion
911from google .genai import types as adk_types
1012from typing_extensions import override
1113
@@ -17,12 +19,12 @@ class TestInMemoryArtifactService(BaseArtifactService):
1719 An in-memory artifact service for testing.
1820
1921 Stores artifacts in a nested dictionary structure:
20- _artifacts_data[app_name][user_id][session_id_or_user_namespace_key][filename_key][version] = (content_bytes, mime_type)
22+ _artifacts_data[app_name][user_id][session_id_or_user_namespace_key][filename_key][version] = (content_bytes, mime_type, create_time )
2123 """
2224
2325 def __init__ (self ):
2426 self ._artifacts_data : Dict [
25- str , Dict [str , Dict [str , Dict [str , Dict [int , Tuple [bytes , str ]]]]]
27+ str , Dict [str , Dict [str , Dict [str , Dict [int , Tuple [bytes , str , float ]]]]]
2628 ] = defaultdict (
2729 lambda : defaultdict (lambda : defaultdict (lambda : defaultdict (dict )))
2830 )
@@ -70,8 +72,9 @@ async def save_artifact(
7072
7173 content_bytes = artifact .inline_data .data
7274 mime_type = artifact .inline_data .mime_type or "application/octet-stream"
75+ create_time = time .time ()
7376
74- versions_dict [new_version ] = (content_bytes , mime_type )
77+ versions_dict [new_version ] = (content_bytes , mime_type , create_time )
7578 return new_version
7679
7780 @override
@@ -103,7 +106,7 @@ async def load_artifact(
103106 if target_version not in versions_dict :
104107 return None
105108
106- content_bytes , mime_type = versions_dict [target_version ]
109+ content_bytes , mime_type , _ = versions_dict [target_version ]
107110 return adk_types .Part (
108111 inline_data = adk_types .Blob (mime_type = mime_type , data = content_bytes )
109112 )
@@ -154,6 +157,81 @@ async def list_versions(
154157 return []
155158 return sorted (list (versions_dict .keys ()))
156159
160+ @override
161+ async def list_artifact_versions (
162+ self ,
163+ * ,
164+ app_name : str ,
165+ user_id : str ,
166+ filename : str ,
167+ session_id : str ,
168+ ) -> List [ArtifactVersion ]:
169+ """Lists all versions and their metadata for a specific artifact."""
170+ app_key , user_key , effective_session_key , fn_key = self ._get_path_keys (
171+ app_name , user_id , session_id , filename
172+ )
173+ versions_dict = self ._artifacts_data [app_key ][user_key ][
174+ effective_session_key
175+ ].get (fn_key )
176+ if not versions_dict :
177+ return []
178+
179+ artifact_versions = []
180+ for version_num , (_ , mime_type , create_time ) in versions_dict .items ():
181+ artifact_version = ArtifactVersion (
182+ version = version_num ,
183+ canonical_uri = f"memory://{ app_key } /{ user_key } /{ effective_session_key } /{ fn_key } /{ version_num } " ,
184+ mime_type = mime_type ,
185+ create_time = create_time ,
186+ custom_metadata = {},
187+ )
188+ artifact_versions .append (artifact_version )
189+
190+ # Sort by version number
191+ artifact_versions .sort (key = lambda av : av .version )
192+ return artifact_versions
193+
194+ @override
195+ async def get_artifact_version (
196+ self ,
197+ * ,
198+ app_name : str ,
199+ user_id : str ,
200+ filename : str ,
201+ session_id : str ,
202+ version : Optional [int ] = None ,
203+ ) -> Optional [ArtifactVersion ]:
204+ """Gets the metadata for a specific version of an artifact."""
205+ app_key , user_key , effective_session_key , fn_key = self ._get_path_keys (
206+ app_name , user_id , session_id , filename
207+ )
208+ versions_dict = self ._artifacts_data [app_key ][user_key ][
209+ effective_session_key
210+ ].get (fn_key )
211+ if not versions_dict :
212+ return None
213+
214+ # Determine which version to get
215+ load_version = version
216+ if load_version is None :
217+ if not versions_dict :
218+ return None
219+ load_version = max (versions_dict .keys ())
220+
221+ if load_version not in versions_dict :
222+ return None
223+
224+ _ , mime_type , create_time = versions_dict [load_version ]
225+
226+ artifact_version = ArtifactVersion (
227+ version = load_version ,
228+ canonical_uri = f"memory://{ app_key } /{ user_key } /{ effective_session_key } /{ fn_key } /{ load_version } " ,
229+ mime_type = mime_type ,
230+ create_time = create_time ,
231+ custom_metadata = {},
232+ )
233+ return artifact_version
234+
157235 async def get_artifact_details (
158236 self , app_name : str , user_id : str , session_id : str , filename : str , version : int
159237 ) -> Optional [Tuple [bytes , str ]]:
@@ -169,7 +247,11 @@ async def get_artifact_details(
169247 .get (fn_key , {})
170248 .get (version )
171249 )
172- return cast (Optional [Tuple [bytes , str ]], artifact_data )
250+ if artifact_data is None :
251+ return None
252+ # Return only bytes and mime_type, discarding create_time for backward compatibility
253+ content_bytes , mime_type , _ = artifact_data
254+ return (content_bytes , mime_type )
173255
174256 async def get_all_artifacts_for_session (
175257 self , app_name : str , user_id : str , session_id : str
@@ -180,8 +262,16 @@ async def get_all_artifacts_for_session(
180262 """
181263 app_data = self ._artifacts_data .get (app_name , {})
182264 user_data = app_data .get (user_id , {})
183- session_data = user_data .get (session_id , {})
184- return cast (Dict [str , Dict [int , Tuple [bytes , str ]]], session_data )
265+ session_data_raw = user_data .get (session_id , {})
266+
267+ # Convert 3-tuples to 2-tuples for backward compatibility
268+ session_data = {}
269+ for filename , versions in session_data_raw .items ():
270+ session_data [filename ] = {
271+ version : (content_bytes , mime_type )
272+ for version , (content_bytes , mime_type , _ ) in versions .items ()
273+ }
274+ return session_data
185275
186276 async def get_all_user_artifacts (
187277 self , app_name : str , user_id : str
@@ -192,8 +282,16 @@ async def get_all_user_artifacts(
192282 """
193283 app_data = self ._artifacts_data .get (app_name , {})
194284 user_data = app_data .get (user_id , {})
195- user_namespace_data = user_data .get (_USER_NAMESPACE_KEY , {})
196- return cast (Dict [str , Dict [int , Tuple [bytes , str ]]], user_namespace_data )
285+ user_namespace_data_raw = user_data .get (_USER_NAMESPACE_KEY , {})
286+
287+ # Convert 3-tuples to 2-tuples for backward compatibility
288+ user_namespace_data = {}
289+ for filename , versions in user_namespace_data_raw .items ():
290+ user_namespace_data [filename ] = {
291+ version : (content_bytes , mime_type )
292+ for version , (content_bytes , mime_type , _ ) in versions .items ()
293+ }
294+ return user_namespace_data
197295
198296 async def clear_all_artifacts (self ) -> None :
199297 """Clears all artifacts from the in-memory store."""
0 commit comments