Skip to content

Commit

Permalink
Use latest PMAT with db_cache (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Nov 6, 2024
1 parent 8df988f commit e1450b5
Show file tree
Hide file tree
Showing 11 changed files with 920 additions and 844 deletions.
1,705 changes: 905 additions & 800 deletions poetry.lock

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions prediction_prophet/autonolas/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import math
import tenacity
from datetime import timedelta
from sklearn.metrics.pairwise import cosine_similarity
from typing import Any, Dict, Generator, List, Optional, Tuple, TypedDict
from datetime import datetime, timezone
Expand Down Expand Up @@ -31,9 +32,10 @@

from dateutil import parser
from prediction_prophet.functions.utils import check_not_none
from prediction_market_agent_tooling.gtypes import Probability
from prediction_market_agent_tooling.tools.utils import secret_str_from_env
from prediction_prophet.functions.cache import persistent_inmemory_cache
from prediction_market_agent_tooling.gtypes import Probability
from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.caches.db_cache import db_cache
from prediction_prophet.functions.parallelism import par_map
from prediction_market_agent_tooling.config import APIKeys
from pydantic.types import SecretStr
Expand Down Expand Up @@ -359,7 +361,7 @@ def fields_dict_to_bullet_list(fields_dict: Dict[str, str]) -> str:
return bullet_list

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True)
@persistent_inmemory_cache
@db_cache(max_age=timedelta(days=1))
def search_google(query: str, num: int = 3) -> List[str]:
"""Search Google using a custom search engine."""
service = build("customsearch", "v1", developerKey=os.getenv("GOOGLE_SEARCH_API_KEY"))
Expand Down Expand Up @@ -694,7 +696,7 @@ def concatenate_short_sentences(sentences: list[str], len_sentence_threshold: in
return modified_sentences


@persistent_inmemory_cache
@db_cache
def openai_embedding_cached(text: str, model: str = "text-embedding-ada-002") -> list[float]:
emb = OpenAIEmbeddings(model=model)
vector: list[float] = emb.embed_query(text)
Expand Down
3 changes: 0 additions & 3 deletions prediction_prophet/benchmark/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def __init__(
max_results_per_search: int = 5,
min_scraped_sites: int = 5,
max_workers: t.Optional[int] = None,
tavily_storage: TavilyStorage | None = None,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
):
super().__init__(agent_name=agent_name, max_workers=max_workers)
Expand All @@ -183,7 +182,6 @@ def __init__(
self.subqueries_limit = subqueries_limit
self.max_results_per_search = max_results_per_search
self.min_scraped_sites = min_scraped_sites
self.tavily_storage = tavily_storage
self.logger = logger

def is_predictable(self, market_question: str) -> bool:
Expand All @@ -205,7 +203,6 @@ def research(self, market_question: str) -> Research:
subqueries_limit=self.subqueries_limit,
max_results_per_search=self.max_results_per_search,
min_scraped_sites=self.min_scraped_sites,
tavily_storage=self.tavily_storage,
logger=self.logger,
)

Expand Down
19 changes: 0 additions & 19 deletions prediction_prophet/functions/cache.py

This file was deleted.

2 changes: 0 additions & 2 deletions prediction_prophet/functions/prepare_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from prediction_prophet.functions.cache import persistent_inmemory_cache
from prediction_prophet.functions.utils import trim_to_n_tokens
from prediction_market_agent_tooling.config import APIKeys
from pydantic.types import SecretStr
from prediction_market_agent_tooling.gtypes import secretstr_to_v1_secretstr
from prediction_market_agent_tooling.tools.langfuse_ import get_langfuse_langchain_config, observe

@persistent_inmemory_cache
@observe()
def prepare_summary(goal: str, content: str, model: str, api_key: SecretStr | None = None, trim_content_to_tokens: t.Optional[int] = None) -> str:
if api_key == None:
Expand Down
3 changes: 0 additions & 3 deletions prediction_prophet/functions/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from pydantic.types import SecretStr
from pydantic import BaseModel
from prediction_market_agent_tooling.tools.langfuse_ import observe
from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage

if t.TYPE_CHECKING:
from loguru import Logger
Expand Down Expand Up @@ -43,7 +42,6 @@ def research(
openai_api_key: SecretStr | None = None,
tavily_api_key: SecretStr | None = None,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
tavily_storage: TavilyStorage | None = None,
) -> Research:
# Validate args
if min_scraped_sites > max_results_per_search * subqueries_limit:
Expand All @@ -70,7 +68,6 @@ def research(
queries,
lambda result: not result.url.startswith("https://www.youtube"),
tavily_api_key=tavily_api_key,
tavily_storage=tavily_storage,
max_results_per_search=max_results_per_search,
)

Expand Down
7 changes: 2 additions & 5 deletions prediction_prophet/functions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from prediction_prophet.functions.web_search import WebSearchResult, web_search
from concurrent.futures import ThreadPoolExecutor, as_completed
from pydantic.types import SecretStr
from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage

if t.TYPE_CHECKING:
from loguru import Logger
Expand All @@ -13,11 +12,10 @@ def safe_web_search(
query: str,
max_results: int = 5,
tavily_api_key: SecretStr | None = None,
tavily_storage: TavilyStorage | None = None,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
) -> t.Optional[list[WebSearchResult]]:
try:
return web_search(query, max_results, tavily_api_key, tavily_storage)
return web_search(query, max_results, tavily_api_key)
except Exception as e:
logger.warning(f"Error when searching for `{query}` in web_search: {e}")
return None
Expand All @@ -27,7 +25,6 @@ def search(
queries: list[str],
filter: t.Callable[[WebSearchResult], bool] = lambda x: True,
tavily_api_key: SecretStr | None = None,
tavily_storage: TavilyStorage | None = None,
max_results_per_search: int = 5,
logger: t.Union[logging.Logger, "Logger"] = logging.getLogger(),
) -> list[tuple[str, WebSearchResult]]:
Expand All @@ -36,7 +33,7 @@ def search(
# Each result will have a query associated with it
# We only want to keep the results that are unique
with ThreadPoolExecutor(max_workers=5) as executor:
futures = {executor.submit(safe_web_search, query, max_results_per_search, tavily_api_key, tavily_storage) for query in queries}
futures = {executor.submit(safe_web_search, query, max_results_per_search, tavily_api_key) for query in queries}
for future in as_completed(futures):
maybe_results.append(future.result())

Expand Down
4 changes: 2 additions & 2 deletions prediction_prophet/functions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime
from typing import NoReturn, Type, TypeVar, Optional
from googleapiclient.discovery import build
from prediction_prophet.functions.cache import persistent_inmemory_cache
from prediction_market_agent_tooling.tools.caches.db_cache import db_cache

T = TypeVar("T")

Expand Down Expand Up @@ -58,7 +58,7 @@ def trim_to_n_tokens(content: str, n: int, model: str) -> str:
return encoder.decode(encoder.encode(content)[:n])


@persistent_inmemory_cache
@db_cache
def url_is_older_than(url: str, older_than: datetime) -> bool:
service = build("customsearch", "v1", developerKey=os.environ["GOOGLE_SEARCH_API_KEY"])
date_restrict = f"d{(datetime.now().date() - older_than.date()).days}" # {d,w,m,y}N to restrict the search to the last N days, weeks, months or years.
Expand Down
5 changes: 3 additions & 2 deletions prediction_prophet/functions/web_scrape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
from bs4 import BeautifulSoup
from requests import Response
import tenacity
from prediction_prophet.functions.cache import persistent_inmemory_cache
from datetime import timedelta
from prediction_market_agent_tooling.tools.caches.db_cache import db_cache


@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_fixed(1), reraise=True)
@persistent_inmemory_cache
def fetch_html(url: str, timeout: int) -> Response:
headers = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:107.0) Gecko/20100101 Firefox/107.0"
}
response = requests.get(url, headers=headers, timeout=timeout)
return response

@db_cache(max_age=timedelta(days=1))
def web_scrape_strict(url: str, timeout: int = 10) -> str:
response = fetch_html(url=url, timeout=timeout)

Expand Down
4 changes: 1 addition & 3 deletions prediction_prophet/functions/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@

from prediction_prophet.models.WebSearchResult import WebSearchResult
from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage
from prediction_market_agent_tooling.tools.tavily.tavily_search import tavily_search



def web_search(query: str, max_results: int = 5, tavily_api_key: SecretStr | None = None, tavily_storage: TavilyStorage | None = None) -> list[WebSearchResult]:
def web_search(query: str, max_results: int = 5, tavily_api_key: SecretStr | None = None) -> list[WebSearchResult]:
response = tavily_search(
query=query,
search_depth="advanced",
max_results=max_results,
include_raw_content=True,
tavily_storage=tavily_storage,
api_keys=APIKeys(TAVILY_API_KEY=tavily_api_key) if tavily_api_key else None,
)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ scikit-learn = "^1.4.0"
typer = ">=0.9.0,<1.0.0"
types-requests = "^2.31.0.20240125"
types-python-dateutil = "^2.9.0"
prediction-market-agent-tooling = { version = ">=0.55.0,<1", extras = ["langchain", "google"] }
prediction-market-agent-tooling = { version = ">=0.56.0,<1", extras = ["langchain", "google"] }
langchain-community = "^0.2.6"
memory-profiler = "^0.61.0"
matplotlib = "^3.8.3"
Expand Down

0 comments on commit e1450b5

Please sign in to comment.