Skip to content

Commit

Permalink
chore: avoid implicit optional in type annotations of method (langgen…
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 authored and lau-td committed Oct 23, 2024
1 parent 605d655 commit cf1a7f0
Show file tree
Hide file tree
Showing 37 changed files with 91 additions and 71 deletions.
2 changes: 1 addition & 1 deletion api/core/agent/cot_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit])
return message

def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
self, current_session_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
Expand Down
2 changes: 1 addition & 1 deletion api/core/agent/cot_chat_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _organize_system_prompt(self) -> SystemPromptMessage:

return SystemPromptMessage(content=system_prompt)

def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
Expand Down
3 changes: 2 additions & 1 deletion api/core/agent/cot_completion_agent_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Optional

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
Expand All @@ -21,7 +22,7 @@ def _organize_instruction_prompt(self) -> str:

return system_prompt

def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
def _organize_historic_prompt(self, current_session_messages: Optional[list[PromptMessage]] = None) -> str:
"""
Organize historic prompt
"""
Expand Down
6 changes: 3 additions & 3 deletions api/core/agent/fc_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
Expand Down Expand Up @@ -370,7 +370,7 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list
return tool_calls

def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
) -> list[PromptMessage]:
"""
Initialize system message
Expand All @@ -385,7 +385,7 @@ def _init_system_message(

return prompt_messages

def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
Expand Down
4 changes: 2 additions & 2 deletions api/core/indexing_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def indexing_estimate(
tenant_id: str,
extract_settings: list[ExtractSetting],
tmp_processing_rule: dict,
doc_form: str = None,
doc_form: Optional[str] = None,
doc_language: str = "English",
dataset_id: str = None,
dataset_id: Optional[str] = None,
indexing_technique: str = "economy",
) -> dict:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _code_block_mode_wrapper(
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/bedrock/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _code_block_mode_wrapper(
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def _num_tokens_from_messages(
model: str,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
credentials: dict = None,
credentials: Optional[dict] = None,
) -> int:
"""
Approximate num tokens with GPT2 tokenizer.
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/openai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _code_block_mode_wrapper(
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def _num_tokens_from_messages(
model: str,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
credentials: dict = None,
credentials: Optional[dict] = None,
) -> int:
"""
Approximate num tokens with GPT2 tokenizer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""
pass

def _detect_lang_code(self, content: str, map_dict: dict = None):
def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}

response = self.comprehend_client.detect_dominant_language(Text=content)
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/wenxin/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _code_block_mode_wrapper(
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: list[Callback] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def conversation(
conversation_id: Optional[str] = None,
attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
metadata: dict | None = None,
request_id: str = None,
user_id: str = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -72,9 +72,9 @@ def conversation(
def query_support(
self,
*,
assistant_id_list: list[str] = None,
request_id: str = None,
user_id: str = None,
assistant_id_list: Optional[list[str]] = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
Expand All @@ -99,8 +99,8 @@ def query_conversation_usage(
page: int = 1,
page_size: int = 10,
*,
request_id: str = None,
user_id: str = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal, Optional, cast

import httpx

Expand Down Expand Up @@ -34,11 +34,11 @@ def __init__(self, client: ZhipuAI) -> None:
def create(
self,
*,
file: FileTypes = None,
upload_detail: list[UploadDetail] = None,
file: Optional[FileTypes] = None,
upload_detail: Optional[list[UploadDetail]] = None,
purpose: Literal["fine-tune", "retrieval", "batch"],
knowledge_id: str = None,
sentence_size: int = None,
knowledge_id: Optional[str] = None,
sentence_size: Optional[int] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def __init__(self, client: ZhipuAI) -> None:
def create(
self,
*,
file: FileTypes = None,
file: Optional[FileTypes] = None,
custom_separator: Optional[list[str]] = None,
upload_detail: list[UploadDetail] = None,
upload_detail: Optional[list[UploadDetail]] = None,
purpose: Literal["retrieval"],
knowledge_id: str = None,
sentence_size: int = None,
knowledge_id: Optional[str] = None,
sentence_size: Optional[int] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def generations(
self,
model: str,
*,
prompt: str = None,
image_url: str = None,
prompt: Optional[str] = None,
image_url: Optional[str] = None,
sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
request_id: str = None,
user_id: str = None,
request_id: Optional[str] = None,
user_id: Optional[str] = None,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/relyt/relyt_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def get_ids_by_metadata_field(self, key: str, value: str):
else:
return None

def delete_by_uuids(self, ids: list[str] = None):
def delete_by_uuids(self, ids: Optional[list[str]] = None):
"""Delete by vector IDs.
Args:
Expand Down
6 changes: 3 additions & 3 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Optional

from configs import dify_config
from core.embedding.cached_embedding import CacheEmbedding
Expand All @@ -25,7 +25,7 @@ def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> dict


class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
def __init__(self, dataset: Dataset, attributes: Optional[list] = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self._dataset = dataset
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

def create(self, texts: list = None, **kwargs):
def create(self, texts: Optional[list] = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions api/core/rag/extractor/extract_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import tempfile
from pathlib import Path
from typing import Union
from typing import Optional, Union
from urllib.parse import unquote

from configs import dify_config
Expand Down Expand Up @@ -84,7 +84,7 @@ def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Docume

@classmethod
def extract(
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str = None
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: Optional[str] = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE.value:
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Optional

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
Expand All @@ -17,7 +18,7 @@ class UnstructuredEpubExtractor(BaseExtractor):
def __init__(
self,
file_path: str,
api_url: str = None,
api_url: Optional[str] = None,
):
"""Initialize with file path."""
self._file_path = file_path
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/entities/tool_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def set_text(self, tool_name: str, name: str, value: str) -> None:

self.pool.append(variable)

def set_file(self, tool_name: str, value: str, name: str = None) -> None:
def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
"""
set an image variable
Expand Down
4 changes: 2 additions & 2 deletions api/core/tools/provider/builtin/aws/tools/sagemaker_tts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from enum import Enum
from typing import Any, Union
from typing import Any, Optional, Union

import boto3

Expand All @@ -21,7 +21,7 @@ class SageMakerTTSTool(BuiltinTool):
s3_client: Any = None
comprehend_client: Any = None

def _detect_lang_code(self, content: str, map_dict: dict = None):
def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
map_dict = {"zh": "<|zh|>", "en": "<|en|>", "ja": "<|jp|>", "zh-TW": "<|yue|>", "ko": "<|ko|>"}

response = self.comprehend_client.detect_dominant_language(Text=content)
Expand Down
4 changes: 3 additions & 1 deletion api/core/tools/tool/builtin_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.entities.tool_entities import ToolProviderType
Expand Down Expand Up @@ -124,7 +126,7 @@ def summarize(content: str) -> str:

return result

def get_url(self, url: str, user_agent: str = None) -> str:
def get_url(self, url: str, user_agent: Optional[str] = None) -> str:
"""
get url
"""
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as)

def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = "") -> ToolInvokeMessage:
def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
"""
create a blob message
Expand Down
4 changes: 2 additions & 2 deletions api/core/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import Any, Union
from typing import Any, Optional, Union

from configs import dify_config
from core.agent.entities import AgentToolEntity
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:

@classmethod
def get_tool(
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None
cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: Optional[str] = None
) -> Union[BuiltinTool, ApiTool]:
"""
get the tool
Expand Down
9 changes: 8 additions & 1 deletion api/core/tools/utils/feishu_api_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import httpx

from core.tools.errors import ToolProviderCredentialValidationError
Expand Down Expand Up @@ -32,7 +34,12 @@ def tenant_access_token(self):
return res.get("tenant_access_token")

def _send_request(
self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None
self,
url: str,
method: str = "post",
require_token: bool = True,
payload: Optional[dict] = None,
params: Optional[dict] = None,
):
headers = {
"Content-Type": "application/json",
Expand Down
Loading

0 comments on commit cf1a7f0

Please sign in to comment.