Skip to content

Commit

Permalink
merge develop branch
Browse files Browse the repository at this point in the history
  • Loading branch information
wj-Mcat committed Apr 11, 2024
2 parents 88e60ef + 157ca7e commit 02f27a1
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 20 deletions.
53 changes: 43 additions & 10 deletions erniebot-agent/applications/erniebot_researcher/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,52 @@ wget https://paddlenlp.bj.bcebos.com/pipelines/fonts/SimSun.ttf

> 第四步:创建索引
下载实例数据
**数据准备**

我们支持docx、pdf、txt等格式的文件,用户可以把这些文件放到同一个文件夹下,然后运行下面的命令创建索引,后续会根据这些文件写报告。

为了方便测试,我们提供了样例数据。
样例数据:

```
wget https://paddlenlp.bj.bcebos.com/pipelines/erniebot_researcher_example.tar.gz
tar xvf erniebot_researcher_example.tar.gz
```

首先需要在[AI Studio星河社区](https://aistudio.baidu.com/index)注册并登录账号,然后在AI Studio的[访问令牌页面](https://aistudio.baidu.com/index/accessToken)获取`Access Token`,最后设置环境变量:
url数据:

如果用户有文件对应的url链接,可以传入存储url链接的txt。在txt中,每一行存储url链接和对应文件的路径,例如:
```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
https://zhuanlan.zhihu.com/p/659457816 erniebot_researcher_example/Ai_Agent的起源.md
```
如果用户不传入url文件,则默认文件的路径为其url链接

如果用户有url链接,你可以传入存储url链接的txt。
在txt中,每一行存储文件的路径和对应的url链接,例如:
'https://zhuanlan.zhihu.com/p/659457816 erniebot_researcher_example/Ai_Agent的起源.md'
摘要数据:

如果用户不传入url文件,则默认文件的路径为其url链接
用户可以利用path_abstract参数传入自己文件对应摘要的存储路径。
其中摘要需要用json文件存储。其中json文件内存储的是多个字典,每个字典有3组键值对,
- `page_content` : `str`, 文件摘要。
- `url` : `str`, 文件url链接。
- `name` : `str`, 文件名字。

例如:

```
[{"page_content":"文件摘要","url":"https://zhuanlan.zhihu.com/p/659457816","name":Ai_Agent的起源},
...]
```

如果用户没有摘要路径,则无需改变path_abstract的默认值,我们会利用ernie-4.0来自动生成摘要,生成的摘要存储路径为abstract.json。

**创建索引**

首先需要在[AI Studio星河社区](https://aistudio.baidu.com/index)注册并登录账号,然后在AI Studio的[访问令牌页面](https://aistudio.baidu.com/index/accessToken)获取`Access Token`,最后设置环境变量:

**有摘要有url链接**

用户可以自己传入文件摘要的存储路径。其中摘要需要用json文件存储。其中json文件内存储的是多个字典,每个字典有3组键值对,"page_content"存储文件的摘要,"url"是文件的url链接,"name"是文章的名字。例如:
[{"page_content":"文章摘要","url":"https://zhuanlan.zhihu.com/p/659457816","name":Ai_Agent的起源},...]
```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
python ./tools/preprocessing.py \
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text> \
Expand All @@ -97,6 +120,16 @@ python ./tools/preprocessing.py \
--path_abstract <the json path of your abstract text>
```

**无摘要无url链接**

```
export EB_AGENT_ACCESS_TOKEN=<aistudio-access-token>
export AISTUDIO_ACCESS_TOKEN=<aistudio-access-token>
python ./tools/preprocessing.py \
--index_name_full_text <the index name of your full text> \
--index_name_abstract <the index name of your abstract text> \
--path_full_text <the folder path of your full text>
```
> 第五步:运行

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ async def run(self, query: str):
for sub_query in sub_queries:
research_result = await self.run_search_summary(sub_query)
paragraphs_item.extend(research_result)

paragraphs = []
for item in paragraphs_item:
if item not in paragraphs:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from llama_index.schema import NodeWithScore, TextNode
from llama_index.core.schema import NodeWithScore, TextNode

from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool

Expand Down
1 change: 0 additions & 1 deletion erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def handle_response(cls, resp: EBResponse) -> EBResponse:
if "error_code" in resp and "error_msg" in resp:
ecode = resp["error_code"]
emsg = resp["error_msg"]
print(ecode)
if ecode in (4, 17):
raise errors.RequestLimitError(emsg, ecode=ecode)
elif ecode in (13, 15, 18):
Expand Down
17 changes: 17 additions & 0 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union

import erniebot.utils.logging as logging
from erniebot.api_types import APIType
from erniebot.backends.bce import QianfanLegacyBackend
from erniebot.response import EBResponse
Expand All @@ -29,6 +31,10 @@ class CustomBackend(EBBackend):

def __init__(self, config_dict: Dict[str, Any]) -> None:
super().__init__(config_dict=config_dict)
access_token = self._cfg.get("access_token", None)
if access_token is None:
access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
self._access_token = access_token

def request(
self,
Expand Down Expand Up @@ -71,6 +77,8 @@ async def arequest(
supplied_headers=headers,
params=params,
)
if self._access_token is not None:
headers = self._add_aistudio_fields_to_headers(headers)
return await self._client.asend_request(
method,
url,
Expand All @@ -83,3 +91,12 @@ async def arequest(
@classmethod
def handle_response(cls, resp: EBResponse) -> EBResponse:
return QianfanLegacyBackend.handle_response(resp)

def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
if "Authorization" in headers:
logging.warning(
"Key 'Authorization' already exists in `headers`: %r",
headers["Authorization"],
)
headers["Authorization"] = f"{self._access_token}"
return headers
2 changes: 1 addition & 1 deletion erniebot/src/erniebot/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
DEFAULT_REQUEST_TIMEOUT_SECS: Final[float] = 600

POLLING_INTERVAL_SECS: Final[float] = 5
POLLING_TIMEOUT_SECS: Final[float] = 20
POLLING_TIMEOUT_SECS: Final[float] = 600
1 change: 0 additions & 1 deletion erniebot/src/erniebot/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ def _interpret_response_line(
logging.debug("Decoded response body: %r", decoded_rbody)

response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders))

if rcode != http.HTTPStatus.OK:
raise errors.HTTPRequestError(
f"The status code is not {http.HTTPStatus.OK}.",
Expand Down
21 changes: 18 additions & 3 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
},
},
}
Expand Down Expand Up @@ -502,8 +511,14 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:

# params
params = {}
if model == "ernie-turbo":
for arg in ("functions", "stop", "disable_search", "enable_citation"):
if model in ("ernie-turbo", "ernie-speed"):
for arg in (
"functions",
"stop",
"disable_search",
"enable_citation",
"tool_choice",
):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model} model.")
params["messages"] = messages
Expand All @@ -529,7 +544,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:

# headers
headers: HeadersType = {}
if self.api_type is APIType.AISTUDIO:
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
headers["Content-Type"] = "application/json"
if "headers" in kwargs:
headers.update(kwargs["headers"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ class ChatCompletionWithPlugins(EBResource, CreatableWithStreaming):
)
_API_INFO_DICT: ClassVar[Dict[APIType, Dict[str, Any]]] = {
APIType.QIANFAN: {
"path": "/erniebot/plugin",
"path": "/erniebot/plugins",
},
APIType.CUSTOM: {
"path": "/erniebot/plugins_v3",
},
APIType.AISTUDIO: {
"path": "/erniebot/plugin",
"path": "/erniebot/plugins",
},
}

Expand Down

0 comments on commit 02f27a1

Please sign in to comment.