diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 94272673..6bfca4d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,6 +41,9 @@ jobs: - name: Run pre-commit run: uv run pre-commit run --all-files + - name: Run mypy + run: uv run mypy src/ + test: name: Test Python ${{ matrix.python-version }} runs-on: ubuntu-latest diff --git a/src/bedrock_agentcore/evaluation/integrations/strands_agents_evals/evaluator.py b/src/bedrock_agentcore/evaluation/integrations/strands_agents_evals/evaluator.py index 8f371f11..3fcfd59e 100644 --- a/src/bedrock_agentcore/evaluation/integrations/strands_agents_evals/evaluator.py +++ b/src/bedrock_agentcore/evaluation/integrations/strands_agents_evals/evaluator.py @@ -31,7 +31,7 @@ def _is_valid_adot_document(item: Any) -> bool: return isinstance(item, dict) and "scope" in item and "traceId" in item and "spanId" in item -def _validate_spans(spans): +def _validate_spans(spans: Any) -> bool: """Validate spans are OpenTelemetry Span objects.""" if not spans: return False @@ -127,14 +127,14 @@ def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> List[Eva ] # Check if spans are already in ADOT format or need conversion - if _is_adot_format(evaluation_case.actual_trajectory): + if _is_adot_format(evaluation_case.actual_trajectory): # type: ignore[arg-type] # Already in ADOT format (fetched from CloudWatch), use as-is spans = evaluation_case.actual_trajectory else: # Raw OTel spans from in-memory exporter, validate and convert if not _validate_spans(evaluation_case.actual_trajectory): return [EvaluationOutput(score=0.0, test_pass=False, reason="Invalid span objects")] - spans = convert_strands_to_adot(evaluation_case.actual_trajectory) + spans = convert_strands_to_adot(evaluation_case.actual_trajectory) # type: ignore[arg-type] request_payload = {"evaluatorId": self.evaluator_id, "evaluationInput": {"sessionSpans": spans}} @@ -165,7 +165,7 @@ async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) return await asyncio.to_thread(self.evaluate, evaluation_case) -def create_strands_evaluator(evaluator_id: str, **kwargs) -> StrandsEvalsAgentCoreEvaluator: +def create_strands_evaluator(evaluator_id: str, **kwargs: Any) -> StrandsEvalsAgentCoreEvaluator: """Create Strands-compatible evaluator backed by AgentCore Evaluation API. Args: diff --git a/src/bedrock_agentcore/evaluation/span_to_adot_serializer/adot_models.py b/src/bedrock_agentcore/evaluation/span_to_adot_serializer/adot_models.py index 7f336ada..9bd8c813 100644 --- a/src/bedrock_agentcore/evaluation/span_to_adot_serializer/adot_models.py +++ b/src/bedrock_agentcore/evaluation/span_to_adot_serializer/adot_models.py @@ -77,7 +77,7 @@ class SpanParser: """ @staticmethod - def extract_metadata(span) -> SpanMetadata: + def extract_metadata(span: Any) -> SpanMetadata: """Extract core span metadata.""" if not hasattr(span, "context") or not span.context: raise ValueError(f"Span '{getattr(span, 'name', 'unknown')}' missing required context") @@ -96,7 +96,7 @@ def extract_metadata(span) -> SpanMetadata: ) @staticmethod - def extract_resource_info(span) -> ResourceInfo: + def extract_resource_info(span: Any) -> ResourceInfo: """Extract resource and scope information.""" resource_attrs = {} if hasattr(span, "resource") and span.resource and hasattr(span.resource, "attributes"): @@ -115,7 +115,7 @@ def extract_resource_info(span) -> ResourceInfo: ) @staticmethod - def get_span_attributes(span) -> Dict[str, Any]: + def get_span_attributes(span: Any) -> Dict[str, Any]: """Safely extract span attributes.""" return dict(span.attributes) if hasattr(span, "attributes") and span.attributes else {} diff --git a/src/bedrock_agentcore/evaluation/span_to_adot_serializer/strands_converter.py b/src/bedrock_agentcore/evaluation/span_to_adot_serializer/strands_converter.py index 7ff3fc07..a918043d 100644 --- a/src/bedrock_agentcore/evaluation/span_to_adot_serializer/strands_converter.py +++ b/src/bedrock_agentcore/evaluation/span_to_adot_serializer/strands_converter.py @@ -120,13 +120,13 @@ def extract_tool_execution(cls, events: List[Any]) -> Optional[ToolExecution]: class StrandsToADOTConverter: """Convert Strands OTel spans to ADOT format.""" - def __init__(self): + def __init__(self) -> None: """Initialize converter with parsers and builder.""" self.span_parser = SpanParser() self.event_parser = StrandsEventParser() self.doc_builder = ADOTDocumentBuilder() - def convert_span(self, span) -> List[Dict[str, Any]]: + def convert_span(self, span: Any) -> List[Dict[str, Any]]: """Convert a single span to ADOT documents.""" documents = [] diff --git a/src/bedrock_agentcore/identity/auth.py b/src/bedrock_agentcore/identity/auth.py index 4bd99b35..944aec62 100644 --- a/src/bedrock_agentcore/identity/auth.py +++ b/src/bedrock_agentcore/identity/auth.py @@ -187,7 +187,7 @@ def _get_iam_jwt_token(region: str) -> str: try: response = sts_client.get_web_identity_token(**params) logger.info("Successfully obtained AWS IAM JWT token") - return response["WebIdentityToken"] + return response["WebIdentityToken"] # type: ignore[no-any-return] except ClientError as e: error_code = e.response.get("Error", {}).get("Code", "") if error_code in ["FeatureDisabledException", "FeatureDisabled"]: @@ -231,7 +231,7 @@ def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable: def decorator(func: Callable) -> Callable: client = IdentityClient(_get_region()) - async def _get_api_key(): + async def _get_api_key() -> str: return await client.get_api_key( provider_name=provider_name, agent_identity_token=await _get_workload_access_token(client), @@ -268,7 +268,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: return decorator -def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]): +def _get_oauth2_callback_url(user_provided_oauth2_callback_url: Optional[str]) -> Optional[str]: if user_provided_oauth2_callback_url: return user_provided_oauth2_callback_url @@ -298,7 +298,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str: config_path = Path(".agentcore.json") workload_identity_name = None - config = {} + config: dict[str, str] = {} if config_path.exists(): try: with open(config_path, "r", encoding="utf-8") as file: @@ -327,7 +327,7 @@ async def _set_up_local_auth(client: IdentityClient) -> str: except Exception: print("Warning: could not write the created workload identity to file") - return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"] + return client.get_workload_access_token(workload_identity_name, user_id=user_id)["workloadAccessToken"] # type: ignore[no-any-return] def _get_region() -> str: diff --git a/src/bedrock_agentcore/memory/client.py b/src/bedrock_agentcore/memory/client.py index 53bb52af..68b6ea5a 100644 --- a/src/bedrock_agentcore/memory/client.py +++ b/src/bedrock_agentcore/memory/client.py @@ -98,7 +98,7 @@ def __init__( self.gmdp_client.meta.region_name, ) - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: """Dynamically forward method calls to the appropriate boto3 client. This method enables access to all boto3 client methods without explicitly @@ -203,7 +203,7 @@ def create_or_get_memory( try: memory = self.create_memory_and_wait( name=name, - strategies=strategies, + strategies=strategies, # type: ignore[arg-type] description=description, event_expiry_days=event_expiry_days, memory_execution_role_arn=memory_execution_role_arn, @@ -213,7 +213,7 @@ def create_or_get_memory( except ClientError as e: if e.response["Error"]["Code"] == "ValidationException" and "already exists" in str(e): memories = self.list_memories() - memory = next((m for m in memories if m["id"].startswith(name)), None) + memory = next((m for m in memories if m["id"].startswith(name)), None) # type: ignore[arg-type] logger.info("Memory already exists. Using existing memory ID: %s", memory["id"]) return memory else: @@ -338,7 +338,7 @@ def retrieve_memories( memoryId=memory_id, namespace=namespace, searchCriteria={"searchQuery": query, "topK": top_k} ) - memories = response.get("memoryRecordSummaries", []) + memories: list[Dict[str, Any]] = response.get("memoryRecordSummaries", []) logger.info("Retrieved %d memories from namespace: %s", len(memories), namespace) return memories @@ -473,7 +473,7 @@ def create_event( response = self.gmdp_client.create_event(**params) - event = response["event"] + event: Dict[str, Any] = response["event"] logger.info("Created event: %s", event["eventId"]) return event @@ -539,7 +539,7 @@ def create_blob_event( response = self.gmdp_client.create_event(**params) - event = response["event"] + event: Dict[str, Any] = response["event"] logger.info("Created blob event: %s", event["eventId"]) return event @@ -635,7 +635,7 @@ def save_conversation( response = self.gmdp_client.create_event(**params) - event = response["event"] + event: Dict[str, Any] = response["event"] logger.info("Created event: %s", event["eventId"]) return event @@ -777,7 +777,7 @@ def list_events( ) """ try: - all_events = [] + all_events: List[Dict[str, Any]] = [] next_token = None while len(all_events) < max_results: @@ -793,7 +793,7 @@ def list_events( params["nextToken"] = next_token # Build filter map - filter_map = {} + filter_map: Dict[str, Any] = {} # Add branch filter if specified (but not for "main") if branch_name and branch_name != "main": @@ -937,7 +937,7 @@ def list_branch_events( params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}} response = self.gmdp_client.list_events(**params) - events = response.get("events", []) + events: list[Dict[str, Any]] = response.get("events", []) # Handle pagination next_token = response.get("nextToken") @@ -992,7 +992,11 @@ def get_conversation_tree(self, memory_id: str, actor_id: str, session_id: str) break # Build tree structure - tree = {"session_id": session_id, "actor_id": actor_id, "main_branch": {"events": [], "branches": {}}} + tree: Dict[str, Any] = { + "session_id": session_id, + "actor_id": actor_id, + "main_branch": {"events": [], "branches": {}}, + } # Group events by branch for event in all_events: @@ -1094,7 +1098,7 @@ def get_last_k_turns( Returns: List of turns, where each turn is a list of message dictionaries """ - base_params = { + base_params: Dict[str, Any] = { "memoryId": memory_id, "actorId": actor_id, "sessionId": session_id, @@ -1222,7 +1226,7 @@ def get_memory_status(self, memory_id: str) -> str: """Get current memory status.""" try: response = self.gmcp_client.get_memory(memoryId=memory_id) # Input uses old field name - return response["memory"]["status"] + return response["memory"]["status"] # type: ignore[no-any-return] except ClientError as e: logger.error("Failed to get memory status: %s", e) raise @@ -1265,7 +1269,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]: def delete_memory(self, memory_id: str) -> Dict[str, Any]: """Delete a memory resource.""" try: - response = self.gmcp_client.delete_memory( + response: Dict[str, Any] = self.gmcp_client.delete_memory( memoryId=memory_id, clientToken=str(uuid.uuid4()) ) # Input uses old field name logger.info("Deleted memory: %s", memory_id) diff --git a/src/bedrock_agentcore/memory/constants.py b/src/bedrock_agentcore/memory/constants.py index 6ab3a444..0b0d0f04 100644 --- a/src/bedrock_agentcore/memory/constants.py +++ b/src/bedrock_agentcore/memory/constants.py @@ -128,7 +128,7 @@ class ConversationalMessage: text: str role: MessageRole - def __post_init__(self): + def __post_init__(self) -> None: """Validate message fields after initialization.""" if not isinstance(self.text, str): raise ValueError("ConversationalMessage.text must be a string") diff --git a/src/bedrock_agentcore/memory/controlplane.py b/src/bedrock_agentcore/memory/controlplane.py index 64fb77ee..a733d795 100644 --- a/src/bedrock_agentcore/memory/controlplane.py +++ b/src/bedrock_agentcore/memory/controlplane.py @@ -92,7 +92,7 @@ def create_memory( try: response = self.client.create_memory(**params) - memory = response["memory"] + memory: Dict[str, Any] = response["memory"] memory_id = memory["id"] logger.info("Created memory: %s", memory_id) @@ -118,7 +118,7 @@ def get_memory(self, memory_id: str, include_strategies: bool = True) -> Dict[st """ try: response = self.client.get_memory(memoryId=memory_id) - memory = response["memory"] + memory: Dict[str, Any] = response["memory"] # Add strategy count strategies = memory.get("strategies", []) @@ -144,7 +144,7 @@ def list_memories(self, max_results: int = 100) -> List[Dict[str, Any]]: List of memory summaries """ try: - memories = [] + memories: List[Dict[str, Any]] = [] next_token = None while len(memories) < max_results: @@ -239,7 +239,7 @@ def update_memory( try: response = self.client.update_memory(**params) - memory = response["memory"] + memory: Dict[str, Any] = response["memory"] logger.info("Updated memory: %s", memory_id) if wait_for_active: @@ -300,7 +300,7 @@ def delete_memory( logger.warning("Error waiting for strategies to become ACTIVE: %s", e) # Now delete the memory - response = self.client.delete_memory(memoryId=memory_id, clientToken=str(uuid.uuid4())) + response: Dict[str, Any] = self.client.delete_memory(memoryId=memory_id, clientToken=str(uuid.uuid4())) logger.info("Initiated deletion of memory: %s", memory_id) @@ -399,7 +399,8 @@ def get_strategy(self, memory_id: str, strategy_id: str) -> Dict[str, Any]: for strategy in strategies: if strategy.get("strategyId") == strategy_id: - return strategy + result: Dict[str, Any] = strategy + return result raise ValueError(f"Strategy {strategy_id} not found in memory {memory_id}") @@ -567,7 +568,7 @@ def _wait_for_status( start_time = time.time() last_memory_status = None - strategy_statuses = {} + strategy_statuses: Dict[str, str] = {} while time.time() - start_time < max_wait: try: diff --git a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py index 90a26617..b6b1b985 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py +++ b/src/bedrock_agentcore/memory/integrations/strands/bedrock_converter.py @@ -79,7 +79,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: if "conversational" in payload_item: conv = payload_item["conversational"] session_msg = SessionMessage.from_dict(json.loads(conv["content"]["text"])) - session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) # type: ignore[assignment, arg-type] if session_msg.message.get("content"): messages.append(session_msg) elif "blob" in payload_item: @@ -88,7 +88,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: if isinstance(blob_data, (tuple, list)) and len(blob_data) == 2: try: session_msg = SessionMessage.from_dict(json.loads(blob_data[0])) - session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) + session_msg.message = AgentCoreMemoryConverter._filter_empty_text(session_msg.message) # type: ignore[assignment, arg-type] if session_msg.message.get("content"): messages.append(session_msg) except (json.JSONDecodeError, ValueError): diff --git a/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py b/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py index 2810f30b..ba7c26b8 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py +++ b/src/bedrock_agentcore/memory/integrations/strands/converters/openai.py @@ -59,7 +59,7 @@ def _bedrock_to_openai(message: dict) -> dict: } ) - result: dict[str, Any] = {"role": role} + result: dict[str, Any] = {"role": role} # type: ignore[no-redef] if tool_calls: result["content"] = "\n".join(text_parts) if text_parts else None @@ -144,7 +144,7 @@ def message_to_payload(session_message: SessionMessage) -> list[Tuple[str, str]] if not has_non_empty: return [] - openai_msg = _bedrock_to_openai(message) + openai_msg = _bedrock_to_openai(message) # type: ignore[arg-type] role = openai_msg.get("role", "user") return [(json.dumps(openai_msg), role)] @@ -177,7 +177,7 @@ def events_to_messages(events: list[dict[str, Any]]) -> list[SessionMessage]: if openai_msg and isinstance(openai_msg, dict): bedrock_msg = _openai_to_bedrock(openai_msg) if bedrock_msg.get("content"): - session_msg = SessionMessage(message=bedrock_msg, message_id=0) + session_msg = SessionMessage(message=bedrock_msg, message_id=0) # type: ignore[arg-type] messages.append(session_msg) return messages diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 68bc05b9..a24b5c9f 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -221,10 +221,10 @@ def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: # 1. Try new approach (metadata filter) event_metadata = [ - EventMetadataFilter.build_expression( - left_operand=LeftExpression.build(STATE_TYPE_KEY), + EventMetadataFilter.build_expression( # type: ignore[attr-defined] + left_operand=LeftExpression.build(STATE_TYPE_KEY), # type: ignore[attr-defined] operator=OperatorType.EQUALS_TO, - right_operand=RightExpression.build(StateType.SESSION.value), + right_operand=RightExpression.build(StateType.SESSION.value), # type: ignore[attr-defined] ) ] @@ -296,7 +296,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A # Cache the created_at timestamp to avoid re-fetching on updates if session_agent.created_at: - self._agent_created_at_cache[session_agent.agent_id] = session_agent.created_at + self._agent_created_at_cache[session_agent.agent_id] = session_agent.created_at # type: ignore[assignment] if self.config.batch_size > 1: # Buffer the agent state events @@ -355,15 +355,15 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ try: # 1. Try new approach (metadata filter) event_metadata = [ - EventMetadataFilter.build_expression( - left_operand=LeftExpression.build(STATE_TYPE_KEY), + EventMetadataFilter.build_expression( # type: ignore[attr-defined] + left_operand=LeftExpression.build(STATE_TYPE_KEY), # type: ignore[attr-defined] operator=OperatorType.EQUALS_TO, - right_operand=RightExpression.build(StateType.AGENT.value), + right_operand=RightExpression.build(StateType.AGENT.value), # type: ignore[attr-defined] ), - EventMetadataFilter.build_expression( - left_operand=LeftExpression.build(AGENT_ID_KEY), + EventMetadataFilter.build_expression( # type: ignore[attr-defined] + left_operand=LeftExpression.build(AGENT_ID_KEY), # type: ignore[attr-defined] operator=OperatorType.EQUALS_TO, - right_operand=RightExpression.build(agent_id), + right_operand=RightExpression.build(agent_id), # type: ignore[attr-defined] ), ] @@ -380,7 +380,7 @@ def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[ agent = SessionAgent.from_dict(agent_data) # Cache the created_at timestamp to avoid re-fetching on updates if agent.created_at: - self._agent_created_at_cache[agent_id] = agent.created_at + self._agent_created_at_cache[agent_id] = agent.created_at # type: ignore[assignment] return agent # 2. Fallback: check for legacy event and migrate @@ -431,13 +431,13 @@ def update_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A raise SessionException(f"Agent {agent_id} in session {session_id} does not exist") # Set created_at from cache before creating the update event - session_agent.created_at = self._agent_created_at_cache[agent_id] + session_agent.created_at = self._agent_created_at_cache[agent_id] # type: ignore[assignment] # Create a new agent event (AgentCore Memory is immutable) # create_agent will handle batching and caching appropriately self.create_agent(session_id, session_agent) - def create_message( + def create_message( # type: ignore[override] self, session_id: str, agent_id: str, session_message: SessionMessage, **kwargs: Any ) -> Optional[dict[str, Any]]: """Create a new message in AgentCore Memory. @@ -658,7 +658,7 @@ def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> Non created_message = self.create_message(self.session_id, agent.agent_id, SessionMessage.from_message(message, 0)) if created_message is None: return - session_message = SessionMessage.from_message(message, created_message.get("eventId")) + session_message = SessionMessage.from_message(message, created_message.get("eventId")) # type: ignore[arg-type] self._latest_agent_message[agent.agent_id] = session_message def retrieve_customer_context(self, event: MessageAddedEvent) -> None: @@ -668,7 +668,7 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: event (MessageAddedEvent): The message added event containing the agent and message data. """ messages = event.agent.messages - if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]: + if not messages or messages[-1].get("role") != "user" or "toolResult" in messages[-1].get("content")[0]: # type: ignore[index] return None if not self.config.retrieval_config: # Only retrieve LTM @@ -676,7 +676,7 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: user_query = messages[-1]["content"][0]["text"] - def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): + def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig) -> list[str]: """Helper function to retrieve memories for a single namespace.""" resolved_namespace = namespace.format( actorId=self.config.actor_id, @@ -739,7 +739,7 @@ def retrieve_for_namespace(namespace: str, retrieval_config: RetrievalConfig): logger.error("Failed to retrieve customer context: %s", e) @override - def register_hooks(self, registry: HookRegistry, **kwargs) -> None: + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register additional hooks. Args: @@ -751,7 +751,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs) -> None: # Only register AfterInvocationEvent hook when batching is enabled if self.config.batch_size > 1: - registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) + registry.add_callback(AfterInvocationEvent, lambda event: self._flush_messages()) # type: ignore[arg-type] @override def initialize(self, agent: "Agent", **kwargs: Any) -> None: @@ -988,7 +988,7 @@ def _start_flush_timer(self) -> None: # Schedule next flush self._flush_timer = threading.Timer( - self.config.flush_interval_seconds, + self.config.flush_interval_seconds, # type: ignore[arg-type] self._interval_flush_callback, ) self._flush_timer.daemon = True diff --git a/src/bedrock_agentcore/memory/models/DictWrapper.py b/src/bedrock_agentcore/memory/models/DictWrapper.py index 75dadbad..06348657 100644 --- a/src/bedrock_agentcore/memory/models/DictWrapper.py +++ b/src/bedrock_agentcore/memory/models/DictWrapper.py @@ -1,6 +1,7 @@ """Dictionary wrapper module for bedrock-agentcore memory models.""" -from typing import Any, Dict +from collections.abc import ItemsView, KeysView, ValuesView +from typing import Any, Dict, List class DictWrapper: @@ -30,26 +31,26 @@ def __contains__(self, key: str) -> bool: """Support 'in' operator for checking if key exists.""" return key in self._data - def keys(self): + def keys(self) -> KeysView[str]: """Return keys from the underlying dictionary.""" return self._data.keys() - def values(self): + def values(self) -> ValuesView[Any]: """Return values from the underlying dictionary.""" return self._data.values() - def items(self): + def items(self) -> ItemsView[str, Any]: """Return items from the underlying dictionary.""" return self._data.items() - def __dir__(self): + def __dir__(self) -> List[str]: """Enable tab completion and introspection of available attributes.""" return list(self._data.keys()) + ["get"] - def __repr__(self): + def __repr__(self) -> str: """Return a JSON-formatted string representation of the data.""" return self._data.__repr__() - def __str__(self): + def __str__(self) -> str: """Return a JSON-formatted string representation of the data.""" return self.__repr__() diff --git a/src/bedrock_agentcore/memory/models/filters.py b/src/bedrock_agentcore/memory/models/filters.py index 9ab25f7d..a354eff9 100644 --- a/src/bedrock_agentcore/memory/models/filters.py +++ b/src/bedrock_agentcore/memory/models/filters.py @@ -9,7 +9,7 @@ class StringValue(TypedDict): stringValue: str - @staticmethod + @staticmethod # type: ignore[misc] def build(value: str) -> "StringValue": """Build a StringValue from a string.""" return {"stringValue": value} @@ -34,7 +34,7 @@ class LeftExpression(TypedDict): metadataKey: MetadataKey - @staticmethod + @staticmethod # type: ignore[misc] def build(key: str) -> "LeftExpression": """Builds the `metadataKey` for `LeftExpression`.""" return {"metadataKey": key} @@ -63,7 +63,7 @@ class RightExpression(TypedDict): metadataValue: MetadataValue - @staticmethod + @staticmethod # type: ignore[misc] def build(value: str) -> "RightExpression": """Builds the `RightExpression` for `stringValue` type.""" return {"metadataValue": StringValue.build(value)} @@ -82,7 +82,7 @@ class EventMetadataFilter(TypedDict): operator: OperatorType right: Optional[RightExpression] - def build_expression( + def build_expression( # type: ignore[misc] left_operand: LeftExpression, operator: OperatorType, right_operand: Optional[RightExpression] = None, diff --git a/src/bedrock_agentcore/memory/session.py b/src/bedrock_agentcore/memory/session.py index bf88f278..8e6c5edc 100644 --- a/src/bedrock_agentcore/memory/session.py +++ b/src/bedrock_agentcore/memory/session.py @@ -192,7 +192,7 @@ def _build_client_config(self, boto_client_config: Optional[BotocoreConfig]) -> new_user_agent = f"{existing_user_agent} {sdk_user_agent}" else: new_user_agent = sdk_user_agent - return boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + return boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) # type: ignore[no-any-return] else: return BotocoreConfig(user_agent_extra=sdk_user_agent) @@ -205,7 +205,7 @@ def _configure_timestamp_serialization(self) -> None: """ original_serialize_timestamp = self._data_plane_client._serializer._serializer._serialize_type_timestamp - def serialize_timestamp_as_float(serialized, value, shape, name): + def serialize_timestamp_as_float(serialized: Any, value: Any, shape: Any, name: Any) -> None: if isinstance(value, datetime): serialized[name] = value.timestamp() # Convert to float (seconds since epoch with fractional seconds) else: @@ -213,7 +213,7 @@ def serialize_timestamp_as_float(serialized, value, shape, name): self._data_plane_client._serializer._serializer._serialize_type_timestamp = serialize_timestamp_as_float - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: """Dynamically forward method calls to the appropriate boto3 client. This method enables access to all data_plane boto3 client methods without explicitly @@ -409,7 +409,7 @@ def _retrieve_memories_for_llm( retrieved_memories.extend(memory_records) logger.info("Retrieved %d memories for LLM context", len(retrieved_memories)) - return retrieved_memories + return retrieved_memories # type: ignore[return-value] def _save_conversation_turn( self, @@ -432,7 +432,7 @@ def _save_conversation_turn( event_timestamp=event_timestamp, ) logger.info("Completed full conversation turn with LLM") - return event + return event # type: ignore[return-value] def add_turns( self, @@ -547,7 +547,7 @@ def fork_conversation( ) logger.info("Created branch '%s' from event %s", branch_name, root_event_id) - return event + return event # type: ignore[return-value] except ClientError as e: logger.error("Failed to fork conversation: %s", e) @@ -662,7 +662,7 @@ def list_events( # Add eventMetadata filter if specified if eventMetadata: - filterMap["eventMetadata"] = eventMetadata + filterMap["eventMetadata"] = eventMetadata # type: ignore[assignment] if filterMap: params["filter"] = filterMap @@ -760,7 +760,7 @@ def list_branches(self, actor_id: str, session_id: str) -> List[Branch]: # Only add main branch if there are actual events if main_branch_events: result.append( - { + { # type: ignore[arg-type] "name": "main", "rootEventId": None, "firstEventId": main_branch_events[0]["eventId"], @@ -770,10 +770,10 @@ def list_branches(self, actor_id: str, session_id: str) -> List[Branch]: ) # Add other branches - result.extend(list(branches.values())) + result.extend(list(branches.values())) # type: ignore[arg-type] logger.info("Found %d branches in session %s", len(result), session_id) - return [Branch(branch) for branch in result] + return [Branch(branch) for branch in result] # type: ignore[arg-type] except ClientError as e: logger.error("Failed to list branches: %s", e) @@ -807,7 +807,7 @@ def get_last_k_turns( } if branch_name and branch_name != "main": - base_params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}} + base_params["filter"] = {"branch": {"name": branch_name, "includeParentBranches": include_parent_branches}} # type: ignore[assignment] try: turns: List[List[EventMessage]] = [] @@ -875,7 +875,7 @@ def get_event(self, actor_id: str, session_id: str, event_id: str) -> Event: logger.error(" ❌ Error retrieving event: %s", e) raise - def delete_event(self, actor_id: str, session_id: str, event_id: str): + def delete_event(self, actor_id: str, session_id: str, event_id: str) -> None: """Deletes a specific event from short-term memory by its ID. Maps to: bedrock-agentcore.delete_event. @@ -895,7 +895,7 @@ def search_long_term_memories( query: str, namespace_prefix: str, top_k: int = 3, - strategy_id: str = None, + strategy_id: Optional[str] = None, max_results: int = 20, ) -> List[MemoryRecord]: """Performs a semantic search against the long-term memory for this actor. @@ -1000,7 +1000,7 @@ def get_memory_record(self, record_id: str) -> MemoryRecord: logger.error(" ❌ Error retrieving record: %s", e) raise - def delete_memory_record(self, record_id: str): + def delete_memory_record(self, record_id: str) -> None: """Deletes a specific long-term memory record by its ID. Maps to: bedrock-agentcore.delete_memory_record. @@ -1077,7 +1077,7 @@ def delete_all_long_term_memories_in_namespace(self, namespace: str) -> Dict[str return {"successfulRecords": all_successful, "failedRecords": all_failed} - def create_memory_session(self, actor_id: str, session_id: str = None) -> "MemorySession": + def create_memory_session(self, actor_id: str, session_id: Optional[str] = None) -> "MemorySession": """Creates a new MemorySession instance.""" session_id = session_id or str(uuid.uuid4()) logger.info("💬 Creating new conversation for actor '%s' in session '%s'...", actor_id, session_id) @@ -1128,7 +1128,7 @@ def fork_conversation( event_timestamp: Optional[datetime] = None, ) -> Event: """Delegates to manager.fork_conversation.""" - return self._manager.fork_conversation( + return self._manager.fork_conversation( # type: ignore[return-value] self._actor_id, self._session_id, root_event_id, branch_name, messages, metadata, event_timestamp ) @@ -1179,14 +1179,19 @@ def get_last_k_turns( ) -> List[List[EventMessage]]: """Delegates to manager.get_last_k_turns.""" return self._manager.get_last_k_turns( - self._actor_id, self._session_id, k, branch_name, include_parent_branches, max_results + self._actor_id, + self._session_id, + k, + branch_name, + include_parent_branches, # type: ignore[arg-type] + max_results, ) def get_event(self, event_id: str) -> Event: """Delegates to manager.get_event.""" return self._manager.get_event(self._actor_id, self._session_id, event_id) - def delete_event(self, event_id: str): + def delete_event(self, event_id: str) -> None: """Delegates to manager.delete_event.""" return self._manager.delete_event(self._actor_id, self._session_id, event_id) @@ -1194,7 +1199,7 @@ def get_memory_record(self, record_id: str) -> MemoryRecord: """Delegates to manager.get_memory_record.""" return self._manager.get_memory_record(record_id) - def delete_memory_record(self, record_id: str): + def delete_memory_record(self, record_id: str) -> None: """Delegates to manager.delete_memory_record.""" return self._manager.delete_memory_record(record_id) diff --git a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py index 782e2222..49d772ff 100644 --- a/src/bedrock_agentcore/runtime/agent_core_runtime_client.py +++ b/src/bedrock_agentcore/runtime/agent_core_runtime_client.py @@ -316,7 +316,7 @@ def generate_presigned_url( raise RuntimeError("Failed to generate presigned URL") # Convert back to wss:// for WebSocket connection - presigned_url = request.url.replace("https://", "wss://") + presigned_url: str = request.url.replace("https://", "wss://") self.logger.info("✓ Presigned URL generated (expires in %s seconds, Session: %s)", expires, session_id) return presigned_url diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index dd675b54..199a9a3d 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -70,7 +70,7 @@ def _restore_context(ctx: contextvars.Context) -> None: class RequestContextFormatter(logging.Formatter): """Formatter including request and session IDs.""" - def format(self, record): + def format(self, record: Any) -> str: """Format log record as AWS Lambda JSON.""" import json from datetime import datetime @@ -154,7 +154,7 @@ def entrypoint(self, func: Callable) -> Callable: The decorated function with added serve method """ self.handlers["main"] = func - func.run = lambda port=8080, host=None: self.run(port, host) + func.run = lambda port=8080, host=None: self.run(port, host) # type: ignore[attr-defined] return func def ping(self, func: Callable) -> Callable: @@ -197,7 +197,7 @@ def async_task(self, func: Callable) -> Callable: if not _is_async_callable(func): raise ValueError("@async_task can only be applied to async functions") - async def wrapper(*args, **kwargs): + async def wrapper(*args: Any, **kwargs: Any) -> Any: task_id = self.add_async_task(func.__name__) try: @@ -237,17 +237,17 @@ def get_current_ping_status(self) -> PingStatus: if current_status is None: current_status = PingStatus.HEALTHY_BUSY if self._active_tasks else PingStatus.HEALTHY - if not hasattr(self, "_last_known_status") or self._last_known_status != current_status: + if not hasattr(self, "_last_known_status") or self._last_known_status != current_status: # type: ignore[has-type] self._last_known_status = current_status self._last_status_update_time = time.time() return current_status - def force_ping_status(self, status: PingStatus): + def force_ping_status(self, status: PingStatus) -> None: """Force ping status to a specific value.""" self._forced_ping_status = status - def clear_forced_ping_status(self): + def clear_forced_ping_status(self) -> None: """Clear forced status and resume automatic.""" self._forced_ping_status = None @@ -327,7 +327,7 @@ def complete_async_task(self, task_id: int) -> bool: self.logger.warning("Attempted to complete unknown task ID: %s", task_id) return False - def _build_request_context(self, request) -> RequestContext: + def _build_request_context(self, request: Any) -> RequestContext: """Build request context and setup all context variables.""" try: headers = request.headers @@ -375,7 +375,7 @@ def _build_request_context(self, request) -> RequestContext: self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) request_id = str(uuid.uuid4()) BedrockAgentCoreContext.set_request_context(request_id, None) - return RequestContext(session_id=None, request=None) + return RequestContext(session_id=None, request=None) # type: ignore[call-arg] def _takes_context(self, handler: Callable) -> bool: try: @@ -384,7 +384,7 @@ def _takes_context(self, handler: Callable) -> bool: except Exception: return False - async def _handle_invocation(self, request): + async def _handle_invocation(self, request: Any) -> Response: request_context = self._build_request_context(request) start_time = time.time() @@ -459,7 +459,7 @@ async def _handle_invocation(self, request): self.logger.exception("Invocation failed (%.3fs)", duration) return JSONResponse({"error": str(e)}, status_code=500) - def _handle_ping(self, request): + def _handle_ping(self, request: Any) -> JSONResponse: try: status = self.get_current_ping_status() self.logger.debug("Ping request - status: %s", status.value) @@ -468,7 +468,7 @@ def _handle_ping(self, request): self.logger.exception("Ping endpoint failed") return JSONResponse({"status": PingStatus.HEALTHY.value, "time_of_last_update": int(time.time())}) - async def _handle_websocket(self, websocket: WebSocket): + async def _handle_websocket(self, websocket: WebSocket) -> None: """Handle WebSocket connections.""" request_context = self._build_request_context(websocket) @@ -491,7 +491,7 @@ async def _handle_websocket(self, websocket: WebSocket): except Exception: pass - def run(self, port: int = 8080, host: Optional[str] = None, **kwargs): + def run(self, port: int = 8080, host: Optional[str] = None, **kwargs: Any) -> None: """Start the Bedrock AgentCore server. Args: @@ -518,7 +518,7 @@ def run(self, port: int = 8080, host: Optional[str] = None, **kwargs): } uvicorn_params.update(kwargs) - uvicorn.run(self, **uvicorn_params) + uvicorn.run(self, **uvicorn_params) # type: ignore[arg-type] def _ensure_worker_loop(self) -> asyncio.AbstractEventLoop: """Lazily create and start a dedicated worker event loop in a background thread. @@ -542,7 +542,7 @@ def _ensure_worker_loop(self) -> asyncio.AbstractEventLoop: def _run_worker_loop(self) -> None: """Entry point for the worker loop background thread.""" asyncio.set_event_loop(self._worker_loop) - self._worker_loop.run_forever() + self._worker_loop.run_forever() # type: ignore[union-attr] @staticmethod async def _run_with_context(coro: Any, ctx: contextvars.Context) -> Any: @@ -629,18 +629,18 @@ def _handle_task_action(self, payload: dict) -> Optional[JSONResponse]: ), TASK_ACTION_JOB_STATUS: lambda: JSONResponse(self.get_async_task_info()), TASK_ACTION_FORCE_HEALTHY: lambda: ( - self.force_ping_status(PingStatus.HEALTHY), - self.logger.info("Ping status forced to Healthy"), + self.force_ping_status(PingStatus.HEALTHY), # type: ignore[func-returns-value] + self.logger.info("Ping status forced to Healthy"), # type: ignore[func-returns-value] JSONResponse({"forced_status": "Healthy"}), )[2], TASK_ACTION_FORCE_BUSY: lambda: ( - self.force_ping_status(PingStatus.HEALTHY_BUSY), - self.logger.info("Ping status forced to HealthyBusy"), + self.force_ping_status(PingStatus.HEALTHY_BUSY), # type: ignore[func-returns-value] + self.logger.info("Ping status forced to HealthyBusy"), # type: ignore[func-returns-value] JSONResponse({"forced_status": "HealthyBusy"}), )[2], TASK_ACTION_CLEAR_FORCED_STATUS: lambda: ( - self.clear_forced_ping_status(), - self.logger.info("Forced ping status cleared"), + self.clear_forced_ping_status(), # type: ignore[func-returns-value] + self.logger.info("Forced ping status cleared"), # type: ignore[func-returns-value] JSONResponse({"forced_status": "Cleared"}), )[2], } @@ -657,7 +657,7 @@ def _handle_task_action(self, payload: dict) -> Optional[JSONResponse]: self.logger.exception("Debug action '%s' failed", action) return JSONResponse({"error": "Debug action failed", "details": str(e)}, status_code=500) - async def _stream_with_error_handling(self, generator): + async def _stream_with_error_handling(self, generator: Any) -> Any: """Wrap async generator to handle errors and convert to SSE format.""" try: async for value in generator: @@ -671,7 +671,7 @@ async def _stream_with_error_handling(self, generator): } yield self._convert_to_sse(error_event) - def _safe_serialize_to_json_string(self, obj): + def _safe_serialize_to_json_string(self, obj: Any) -> str: """Safely serialize object directly to JSON string with progressive fallback handling. This method eliminates double JSON encoding by returning the JSON string directly, @@ -699,7 +699,7 @@ def _safe_serialize_to_json_string(self, obj): error_obj = {"error": "Serialization failed", "original_type": type(obj).__name__} return json.dumps(error_obj, ensure_ascii=False) - def _convert_to_sse(self, obj) -> bytes: + def _convert_to_sse(self, obj: Any) -> bytes: """Convert object to Server-Sent Events format using safe serialization. Args: @@ -712,7 +712,7 @@ def _convert_to_sse(self, obj) -> bytes: sse_data = f"data: {json_string}\n\n" return sse_data.encode("utf-8") - def _sync_stream_with_error_handling(self, generator): + def _sync_stream_with_error_handling(self, generator: Any) -> Any: """Wrap sync generator to handle errors and convert to SSE format.""" try: for value in generator: diff --git a/src/bedrock_agentcore/runtime/context.py b/src/bedrock_agentcore/runtime/context.py index b0c3090d..3a6fffc5 100644 --- a/src/bedrock_agentcore/runtime/context.py +++ b/src/bedrock_agentcore/runtime/context.py @@ -32,7 +32,7 @@ class BedrockAgentCoreContext: _request_headers: ContextVar[Optional[Dict[str, str]]] = ContextVar("request_headers") @classmethod - def set_workload_access_token(cls, token: str): + def set_workload_access_token(cls, token: str) -> None: """Set the workload access token in the context.""" cls._workload_access_token.set(token) @@ -45,7 +45,7 @@ def get_workload_access_token(cls) -> Optional[str]: return None @classmethod - def set_oauth2_callback_url(cls, workload_callback_url: str): + def set_oauth2_callback_url(cls, workload_callback_url: str) -> None: """Set the oauth2 callback url in the context.""" cls._oauth2_callback_url.set(workload_callback_url) @@ -58,7 +58,7 @@ def get_oauth2_callback_url(cls) -> Optional[str]: return None @classmethod - def set_request_context(cls, request_id: str, session_id: Optional[str] = None): + def set_request_context(cls, request_id: str, session_id: Optional[str] = None) -> None: """Set request-scoped identifiers.""" cls._request_id.set(request_id) cls._session_id.set(session_id) @@ -80,7 +80,7 @@ def get_session_id(cls) -> Optional[str]: return None @classmethod - def set_request_headers(cls, headers: Dict[str, str]): + def set_request_headers(cls, headers: Dict[str, str]) -> None: """Set request headers in the context.""" cls._request_headers.set(headers) diff --git a/src/bedrock_agentcore/runtime/utils.py b/src/bedrock_agentcore/runtime/utils.py index 351cdd0d..4660ed3f 100644 --- a/src/bedrock_agentcore/runtime/utils.py +++ b/src/bedrock_agentcore/runtime/utils.py @@ -16,7 +16,7 @@ def convert_complex_objects(obj: Any, _depth: int = 0) -> Any: # Handle dataclasses (like AgentResult) elif is_dataclass(obj): - return asdict(obj) + return asdict(obj) # type: ignore[arg-type] # Handle dictionaries recursively elif isinstance(obj, dict): diff --git a/src/bedrock_agentcore/services/identity.py b/src/bedrock_agentcore/services/identity.py index 5ffc8692..645095a6 100644 --- a/src/bedrock_agentcore/services/identity.py +++ b/src/bedrock_agentcore/services/identity.py @@ -83,19 +83,19 @@ def __init__(self, region: str): ) self.logger = logging.getLogger("bedrock_agentcore.identity_client") - def create_oauth2_credential_provider(self, req): + def create_oauth2_credential_provider(self, req: Dict[str, Any]) -> dict[Any, Any]: """Create an OAuth2 credential provider.""" self.logger.info("Creating OAuth2 credential provider...") - return self.cp_client.create_oauth2_credential_provider(**req) + return self.cp_client.create_oauth2_credential_provider(**req) # type: ignore[no-any-return] - def create_api_key_credential_provider(self, req): + def create_api_key_credential_provider(self, req: Dict[str, Any]) -> dict[Any, Any]: """Create an API key credential provider.""" self.logger.info("Creating API key credential provider...") - return self.cp_client.create_api_key_credential_provider(**req) + return self.cp_client.create_api_key_credential_provider(**req) # type: ignore[no-any-return] def get_workload_access_token( self, workload_name: str, user_token: Optional[str] = None, user_id: Optional[str] = None - ) -> Dict: + ) -> Dict[str, Any]: """Get a workload access token using workload name and optionally user token.""" if user_token: if user_id is not None: @@ -110,36 +110,36 @@ def get_workload_access_token( resp = self.dp_client.get_workload_access_token(workloadName=workload_name) self.logger.info("Successfully retrieved workload access token") - return resp + return resp # type: ignore[no-any-return] def create_workload_identity( self, name: Optional[str] = None, allowed_resource_oauth_2_return_urls: Optional[list[str]] = None - ) -> Dict: + ) -> Dict[str, Any]: """Create workload identity with optional name.""" self.logger.info("Creating workload identity...") if not name: name = f"workload-{uuid.uuid4().hex[:8]}" - return self.cp_client.create_workload_identity( + return self.cp_client.create_workload_identity( # type: ignore[no-any-return] name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls or [] ) - def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> Dict: + def update_workload_identity(self, name: str, allowed_resource_oauth_2_return_urls: list[str]) -> Dict[str, Any]: """Update an existing workload identity with allowed resource OAuth2 callback urls.""" self.logger.info( "Updating workload identity '%s' with callback urls: %s", name, allowed_resource_oauth_2_return_urls ) - return self.cp_client.update_workload_identity( + return self.cp_client.update_workload_identity( # type: ignore[no-any-return] name=name, allowedResourceOauth2ReturnUrls=allowed_resource_oauth_2_return_urls ) - def get_workload_identity(self, name: str) -> Dict: + def get_workload_identity(self, name: str) -> Dict[str, Any]: """Retrieves information about a workload identity.""" self.logger.info("Fetching workload identity '%s'", name) - return self.cp_client.get_workload_identity(name=name) + return self.cp_client.get_workload_identity(name=name) # type: ignore[no-any-return] def complete_resource_token_auth( self, session_uri: str, user_identifier: Union[UserTokenIdentifier, UserIdIdentifier] - ): + ) -> dict[Any, Any]: """Confirms the user authentication session for obtaining OAuth2.0 tokens for a resource.""" self.logger.info("Completing 3LO OAuth2 flow...") @@ -151,7 +151,7 @@ def complete_resource_token_auth( else: raise ValueError(f"Unexpected UserIdentifier: {user_identifier}") - return self.dp_client.complete_resource_token_auth(userIdentifier=user_identifier_value, sessionUri=session_uri) + return self.dp_client.complete_resource_token_auth(userIdentifier=user_identifier_value, sessionUri=session_uri) # type: ignore[no-any-return] async def get_token( self, @@ -192,7 +192,7 @@ async def get_token( self.logger.info("Getting OAuth2 token...") # Build parameters - req = { + req: Dict[str, Any] = { "resourceCredentialProviderName": provider_name, "scopes": scopes, "oauth2Flow": auth_flow, @@ -213,7 +213,7 @@ async def get_token( # If we got a token directly, return it if "accessToken" in response: - return response["accessToken"] + return response["accessToken"] # type: ignore[no-any-return] # If we got an authorization URL, handle the OAuth flow if "authorizationUrl" in response: @@ -245,4 +245,4 @@ async def get_api_key(self, *, provider_name: str, agent_identity_token: str) -> self.logger.info("Getting API key...") req = {"resourceCredentialProviderName": provider_name, "workloadIdentityToken": agent_identity_token} - return self.dp_client.get_resource_api_key(**req)["apiKey"] + return self.dp_client.get_resource_api_key(**req)["apiKey"] # type: ignore[no-any-return] diff --git a/src/bedrock_agentcore/tools/browser_client.py b/src/bedrock_agentcore/tools/browser_client.py index 0f6f8308..6a39e137 100644 --- a/src/bedrock_agentcore/tools/browser_client.py +++ b/src/bedrock_agentcore/tools/browser_client.py @@ -25,9 +25,9 @@ from .config import BrowserExtension, ProfileConfiguration, ProxyConfiguration, ViewportConfiguration -def _to_dict(value): +def _to_dict(value: Any) -> Dict[str, Any]: """Convert a dataclass or dict to a dict. Passes dicts through unchanged.""" - return value.to_dict() if hasattr(value, "to_dict") else value + return value.to_dict() if hasattr(value, "to_dict") else value # type: ignore[no-any-return] DEFAULT_IDENTIFIER = "aws.browser.v1" @@ -85,8 +85,8 @@ def __init__(self, region: str, integration_source: Optional[str] = None) -> Non config=client_config, ) - self._identifier = None - self._session_id = None + self._identifier: Optional[str] = None + self._session_id: Optional[str] = None @property def identifier(self) -> Optional[str]: @@ -94,7 +94,7 @@ def identifier(self) -> Optional[str]: return self._identifier @identifier.setter - def identifier(self, value: Optional[str]): + def identifier(self, value: Optional[str]) -> None: """Set the browser identifier.""" self._identifier = value @@ -104,7 +104,7 @@ def session_id(self) -> Optional[str]: return self._session_id @session_id.setter - def session_id(self, value: Optional[str]): + def session_id(self, value: Optional[str]) -> None: """Set the session ID.""" self._session_id = value @@ -178,7 +178,7 @@ def create_browser( """ self.logger.info("Creating browser: %s", name) - request_params = { + request_params: Dict[str, Any] = { "name": name, "executionRoleArn": execution_role_arn, "networkConfiguration": network_configuration or {"networkMode": "PUBLIC"}, @@ -201,7 +201,7 @@ def create_browser( request_params["clientToken"] = client_token response = self.control_plane_client.create_browser(**request_params) - return response + return response # type: ignore[no-any-return] def delete_browser(self, browser_id: str, client_token: Optional[str] = None) -> Dict: """Delete a custom browser. @@ -221,12 +221,12 @@ def delete_browser(self, browser_id: str, client_token: Optional[str] = None) -> """ self.logger.info("Deleting browser: %s", browser_id) - request_params = {"browserId": browser_id} + request_params: Dict[str, Any] = {"browserId": browser_id} if client_token: request_params["clientToken"] = client_token response = self.control_plane_client.delete_browser(**request_params) - return response + return response # type: ignore[no-any-return] def get_browser(self, browser_id: str) -> Dict: """Get detailed information about a browser. @@ -253,7 +253,7 @@ def get_browser(self, browser_id: str) -> Dict: """ self.logger.info("Getting browser: %s", browser_id) response = self.control_plane_client.get_browser(browserId=browser_id) - return response + return response # type: ignore[no-any-return] def list_browsers( self, @@ -281,14 +281,14 @@ def list_browsers( """ self.logger.info("Listing browsers (type=%s)", browser_type) - request_params = {"maxResults": max_results} + request_params: Dict[str, Any] = {"maxResults": max_results} if browser_type: request_params["type"] = browser_type if next_token: request_params["nextToken"] = next_token response = self.control_plane_client.list_browsers(**request_params) - return response + return response # type: ignore[no-any-return] def start( self, @@ -355,7 +355,7 @@ def start( """ self.logger.info("Starting browser session...") - request_params = { + request_params: Dict[str, Any] = { "browserIdentifier": identifier, "name": name or f"browser-session-{uuid.uuid4().hex[:8]}", "sessionTimeoutSeconds": session_timeout_seconds, @@ -431,7 +431,7 @@ def get_session(self, browser_id: Optional[str] = None, session_id: Optional[str self.logger.info("Getting session: %s", session_id) response = self.data_plane_client.get_browser_session(browserIdentifier=browser_id, sessionId=session_id) - return response + return response # type: ignore[no-any-return] def list_sessions( self, @@ -465,14 +465,14 @@ def list_sessions( self.logger.info("Listing sessions for browser: %s", browser_id) - request_params = {"browserIdentifier": browser_id, "maxResults": max_results} + request_params: Dict[str, Any] = {"browserIdentifier": browser_id, "maxResults": max_results} if status: request_params["status"] = status if next_token: request_params["nextToken"] = next_token response = self.data_plane_client.list_browser_sessions(**request_params) - return response + return response # type: ignore[no-any-return] def update_stream( self, @@ -599,9 +599,9 @@ def generate_live_view_url(self, expires: int = DEFAULT_LIVE_VIEW_PRESIGNED_URL_ if not request.url: raise RuntimeError("Failed to generate live view url") - return request.url + return request.url # type: ignore[no-any-return] - def take_control(self): + def take_control(self) -> None: """Take control of the browser by disabling automation stream.""" self.logger.info("Taking control of browser session...") @@ -613,7 +613,7 @@ def take_control(self): self.update_stream("DISABLED") - def release_control(self): + def release_control(self) -> None: """Release control by enabling automation stream.""" self.logger.info("Releasing control of browser session...") @@ -674,7 +674,7 @@ def browser_session( ... ws_url, headers = client.generate_ws_headers() """ client = BrowserClient(region) - start_kwargs = {} + start_kwargs: Dict[str, Any] = {} if viewport is not None: start_kwargs["viewport"] = viewport if identifier is not None: diff --git a/src/bedrock_agentcore/tools/code_interpreter_client.py b/src/bedrock_agentcore/tools/code_interpreter_client.py index 636b7eb5..106ce5f2 100644 --- a/src/bedrock_agentcore/tools/code_interpreter_client.py +++ b/src/bedrock_agentcore/tools/code_interpreter_client.py @@ -104,8 +104,8 @@ def __init__( config=data_config, ) - self._identifier = None - self._session_id = None + self._identifier: Optional[str] = None + self._session_id: Optional[str] = None self._file_descriptions: Dict[str, str] = {} @property @@ -114,7 +114,7 @@ def identifier(self) -> Optional[str]: return self._identifier @identifier.setter - def identifier(self, value: Optional[str]): + def identifier(self, value: Optional[str]) -> None: """Set the code interpreter identifier.""" self._identifier = value @@ -124,7 +124,7 @@ def session_id(self) -> Optional[str]: return self._session_id @session_id.setter - def session_id(self, value: Optional[str]): + def session_id(self, value: Optional[str]) -> None: """Set the session ID.""" self._session_id = value @@ -184,7 +184,7 @@ def create_code_interpreter( """ self.logger.info("Creating code interpreter: %s", name) - request_params = { + request_params: Dict[str, Any] = { "name": name, "executionRoleArn": execution_role_arn, "networkConfiguration": network_configuration or {"networkMode": "PUBLIC"}, @@ -200,7 +200,7 @@ def create_code_interpreter( request_params["clientToken"] = client_token response = self.control_plane_client.create_code_interpreter(**request_params) - return response + return response # type: ignore[no-any-return] def delete_code_interpreter(self, interpreter_id: str, client_token: Optional[str] = None) -> Dict: """Delete a custom code interpreter. @@ -220,12 +220,12 @@ def delete_code_interpreter(self, interpreter_id: str, client_token: Optional[st """ self.logger.info("Deleting code interpreter: %s", interpreter_id) - request_params = {"codeInterpreterId": interpreter_id} + request_params: Dict[str, Any] = {"codeInterpreterId": interpreter_id} if client_token: request_params["clientToken"] = client_token response = self.control_plane_client.delete_code_interpreter(**request_params) - return response + return response # type: ignore[no-any-return] def get_code_interpreter(self, interpreter_id: str) -> Dict: """Get detailed information about a code interpreter. @@ -248,7 +248,7 @@ def get_code_interpreter(self, interpreter_id: str) -> Dict: """ self.logger.info("Getting code interpreter: %s", interpreter_id) response = self.control_plane_client.get_code_interpreter(codeInterpreterId=interpreter_id) - return response + return response # type: ignore[no-any-return] def list_code_interpreters( self, @@ -276,14 +276,14 @@ def list_code_interpreters( """ self.logger.info("Listing code interpreters (type=%s)", interpreter_type) - request_params = {"maxResults": max_results} + request_params: Dict[str, Any] = {"maxResults": max_results} if interpreter_type: request_params["type"] = interpreter_type if next_token: request_params["nextToken"] = next_token response = self.control_plane_client.list_code_interpreters(**request_params) - return response + return response # type: ignore[no-any-return] def start( self, @@ -376,7 +376,7 @@ def get_session(self, interpreter_id: Optional[str] = None, session_id: Optional response = self.data_plane_client.get_code_interpreter_session( codeInterpreterIdentifier=interpreter_id, sessionId=session_id ) - return response + return response # type: ignore[no-any-return] def list_sessions( self, @@ -410,16 +410,16 @@ def list_sessions( self.logger.info("Listing sessions for interpreter: %s", interpreter_id) - request_params = {"codeInterpreterIdentifier": interpreter_id, "maxResults": max_results} + request_params: Dict[str, Any] = {"codeInterpreterIdentifier": interpreter_id, "maxResults": max_results} if status: request_params["status"] = status if next_token: request_params["nextToken"] = next_token response = self.data_plane_client.list_code_interpreter_sessions(**request_params) - return response + return response # type: ignore[no-any-return] - def invoke(self, method: str, params: Optional[Dict] = None): + def invoke(self, method: str, params: Optional[Dict] = None) -> Dict[str, Any]: r"""Invoke a method in the code interpreter sandbox. If no session is active, automatically starts a new session. @@ -442,7 +442,7 @@ def invoke(self, method: str, params: Optional[Dict] = None): if not self.session_id or not self.identifier: self.start() - return self.data_plane_client.invoke_code_interpreter( + return self.data_plane_client.invoke_code_interpreter( # type: ignore[no-any-return] codeInterpreterIdentifier=self.identifier, sessionId=self.session_id, name=method, @@ -546,8 +546,8 @@ def upload_files( if path.startswith("/"): raise ValueError(f"Path must be relative, not absolute. Got: {path}") - if isinstance(content, bytes): - file_contents.append({"path": path, "blob": base64.b64encode(content).decode("utf-8")}) + if isinstance(content, bytes): # type: ignore[unreachable] + file_contents.append({"path": path, "blob": base64.b64encode(content).decode("utf-8")}) # type: ignore[unreachable] else: file_contents.append({"path": path, "text": content}) @@ -630,7 +630,7 @@ def download_file( if content_item.get("type") == "resource": resource = content_item.get("resource", {}) if "text" in resource: - return resource["text"] + return resource["text"] # type: ignore[no-any-return] elif "blob" in resource: raw = base64.b64decode(resource["blob"]) try: diff --git a/src/bedrock_agentcore/tools/config.py b/src/bedrock_agentcore/tools/config.py index bbd062e2..a3dcd87a 100644 --- a/src/bedrock_agentcore/tools/config.py +++ b/src/bedrock_agentcore/tools/config.py @@ -5,7 +5,7 @@ """ from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional @dataclass @@ -37,7 +37,7 @@ class NetworkConfiguration: network_mode: str = "PUBLIC" vpc_config: Optional[VpcConfig] = None - def __post_init__(self): + def __post_init__(self) -> None: """Validate configuration.""" if self.network_mode not in ["PUBLIC", "VPC"]: raise ValueError(f"network_mode must be 'PUBLIC' or 'VPC', got '{self.network_mode}'") @@ -47,7 +47,7 @@ def __post_init__(self): def to_dict(self) -> Dict: """Convert to API-compatible dictionary.""" - config = {"networkMode": self.network_mode} + config: Dict[str, Any] = {"networkMode": self.network_mode} if self.vpc_config: config["vpcConfig"] = self.vpc_config.to_dict() return config @@ -105,7 +105,7 @@ class RecordingConfiguration: def to_dict(self) -> Dict: """Convert to API-compatible dictionary.""" - config = {"enabled": self.enabled} + config: Dict[str, Any] = {"enabled": self.enabled} if self.s3_location: config["s3Location"] = self.s3_location.to_dict() return config @@ -275,7 +275,7 @@ class ProxyConfiguration: def to_dict(self) -> Dict: """Convert to API-compatible dictionary.""" - config = {"proxies": [p.to_dict() for p in self.proxies]} + config: Dict[str, Any] = {"proxies": [p.to_dict() for p in self.proxies]} if self.bypass_patterns: config["bypass"] = {"domainPatterns": self.bypass_patterns} return config @@ -356,7 +356,7 @@ class SessionConfiguration: def to_dict(self) -> Dict: """Convert to API-compatible dictionary.""" - config = {} + config: Dict[str, Any] = {} if self.name is not None: config["name"] = self.name if self.viewport: @@ -396,7 +396,7 @@ class BrowserConfiguration: def to_dict(self) -> Dict: """Convert to API-compatible dictionary for create_browser.""" - config = { + config: Dict[str, Any] = { "name": self.name, "executionRoleArn": self.execution_role_arn, "networkConfiguration": self.network_configuration.to_dict(), @@ -437,7 +437,7 @@ class CodeInterpreterConfiguration: def to_dict(self) -> Dict: """Convert to API-compatible dictionary for create_code_interpreter.""" - config = { + config: Dict[str, Any] = { "name": self.name, "executionRoleArn": self.execution_role_arn, "networkConfiguration": self.network_configuration.to_dict(),