Skip to content

Commit

Permalink
More pyright fixes and rename classes
Browse files Browse the repository at this point in the history
  • Loading branch information
leonghui committed Nov 3, 2024
1 parent 0045b42 commit 450b337
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 45 deletions.
54 changes: 33 additions & 21 deletions amazon_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from datetime import datetime
from urllib.parse import quote_plus, urlencode, urlparse

import nh3
from bs4 import BeautifulSoup
from flask import abort
from nh3 import nh3
from requests.exceptions import JSONDecodeError, RequestException
from requests_cache import AnyResponse

from amazon_feed_data import BOT_PATTERN, AmazonItemQuery, AmazonListingQuery
from amazon_feed_data import (BOT_PATTERN, AmazonAsinQuery, AmazonKeywordQuery,
FilterableQuery)
from json_feed_data import JSONFEED_VERSION_URL, JsonFeedItem, JsonFeedTopLevel

ITEM_QUANTITY = 1
Expand All @@ -28,12 +30,12 @@
}


def reset_query_session(query):
query.config.useragent = None
def reset_query_session(query: FilterableQuery):
query.config.useragent = ""
query.config.session.cookies.clear()


def handle_response(response, query):
def handle_response(response: AnyResponse, query: FilterableQuery):
logger = query.config.logger

try:
Expand All @@ -44,7 +46,7 @@ def handle_response(response, query):
return None


def get_response_dict(url, query):
def get_response_dict(url: str, query: FilterableQuery):
logger = query.config.logger
session = query.config.session

Expand Down Expand Up @@ -80,7 +82,7 @@ def get_response_dict(url, query):
return handle_response(response, query)


def get_search_url(base_url, query, is_xhr=True):
def get_search_url(base_url: str, query: FilterableQuery, is_xhr: bool = True):
search_uri = f"{base_url}/s/query?" if is_xhr else f"{base_url}/s?"

search_dict = {"k": quote_plus(query.query_str)}
Expand All @@ -104,25 +106,27 @@ def get_search_url(base_url, query, is_xhr=True):
return search_uri + urlencode(search_dict)


def get_item_url(base_url, item_id):
def get_item_url(base_url: str, item_id: str):
return base_url + "/gp/product/" + item_id


def get_top_level_feed(base_url, query, feed_items):
def get_top_level_feed(
base_url: str, query: FilterableQuery, feed_items: list[JsonFeedItem]
):
parse_object = urlparse(base_url)
domain = parse_object.netloc

title_strings = [domain, query.query_str]

filters = []

if isinstance(query, AmazonListingQuery):
if isinstance(query, AmazonKeywordQuery):
home_page_url = get_search_url(base_url, query, is_xhr=False)

if query.strict:
filters.append("strict")

elif isinstance(query, AmazonItemQuery):
elif isinstance(query, AmazonAsinQuery):
home_page_url = get_item_url(base_url, query.query_str)

if query.min_price:
Expand All @@ -145,7 +149,13 @@ def get_top_level_feed(base_url, query, feed_items):
return json_feed


def generate_item(base_url, item_id, item_title, item_price_text, item_thumbnail_url):
def generate_item(
base_url: str,
item_id: str,
item_title: str,
item_price_text: str,
item_thumbnail_url: str,
):
item_title_text = item_title.strip() if item_title else item_id

item_thumbnail_html = f'<img src="{item_thumbnail_url}" />'
Expand Down Expand Up @@ -185,7 +195,7 @@ def generate_item(base_url, item_id, item_title, item_price_text, item_thumbnail
return feed_item


def get_search_results(search_query):
def get_keyword_results(search_query: AmazonKeywordQuery):
logger = search_query.config.logger

base_url = "https://" + search_query.locale.domain
Expand All @@ -204,6 +214,8 @@ def get_search_results(search_query):
else {}
)

term_list: list[str] = []

if search_query.strict:
term_list = set([term.lower() for term in search_query.query_str.split()])
logger.debug(
Expand All @@ -215,7 +227,7 @@ def get_search_results(search_query):
generated_items = []

for result in results_dict.values():
item_id = result.get("asin")
item_id: str = result.get("asin")
item_soup = BeautifulSoup(result.get("html"), features="html.parser")

# select product title, use wildcard CSS selector for better international compatibility
Expand Down Expand Up @@ -268,11 +280,11 @@ def get_search_results(search_query):
return json_feed


def get_dimension_url(listing_query, item_id):
def get_dimension_url(query: AmazonAsinQuery, item_id: str):
# Call the "dimension" endpoint which is used on mobile pages
# to display price and optionally availability for product variants

locale_data = listing_query.locale
locale_data = query.locale
base_url = "https://" + locale_data.domain
dimension_endpoint = base_url + "/gp/product/ajax?"

Expand All @@ -286,7 +298,7 @@ def get_dimension_url(listing_query, item_id):
return dimension_endpoint + urlencode(query_dict)


def get_item_listing(query):
def get_item_listing(query: AmazonAsinQuery):
logger = query.config.logger

item_id = query.query_str
Expand All @@ -295,14 +307,14 @@ def get_item_listing(query):

json_dict = get_response_dict(item_dimension_url, query)

item_price = None
item_price: str = ""

if json_dict:
# Assume one item is returned per response
matching_result = (
result = (
json_dict.get("Value", {}).get("content", {}).get("twisterSlotJson", {})
)
item_price = matching_result.get("price")
item_price = result.get("price")

json_feed = get_top_level_feed(base_url, query, [])

Expand All @@ -325,7 +337,7 @@ def get_item_listing(query):

formatted_price = query.locale.currency + item_price

feed_item = generate_item(base_url, item_id, None, formatted_price, None)
feed_item = generate_item(base_url, item_id, "", formatted_price, "")

json_feed = get_top_level_feed(base_url, query, [feed_item])

Expand Down
12 changes: 6 additions & 6 deletions amazon_feed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class _PriceFilter:


@dataclass
class _BaseQueryWithPriceFilter(_PriceFilter, _BaseQuery):
class FilterableQuery(_PriceFilter, _BaseQuery):
def validate_price_filters(self):
if self.max_price and not self.max_price.isnumeric():
self.status.errors.append("Invalid max price")
Expand All @@ -98,7 +98,7 @@ def validate_price_filters(self):


@dataclass
class _AmazonSearchFilter:
class _AmazonKeywordFilter:
strict_str: str = "False"
strict: bool = False

Expand All @@ -108,13 +108,13 @@ def validate_amazon_search_filters(self):


@dataclass
class AmazonListingQuery(_AmazonSearchFilter, _BaseQueryWithPriceFilter):
class AmazonKeywordQuery(_AmazonKeywordFilter, FilterableQuery):
query_str: str = "AMD"

def from_item_query(self):
assert isinstance(self, AmazonItemQuery)
assert isinstance(self, AmazonAsinQuery)

listing_query = AmazonListingQuery(
listing_query = AmazonKeywordQuery(
status=self.status,
query_str=self.query_str,
config=self.config,
Expand All @@ -138,7 +138,7 @@ def __post_init__(self):


@dataclass
class AmazonItemQuery(_BaseQueryWithPriceFilter):
class AmazonAsinQuery(FilterableQuery):
query_str: str = "B08166SLDF" # AMD Ryzen 5 5600X Processor

def __post_init__(self):
Expand Down
10 changes: 5 additions & 5 deletions mozilla_devices.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from amazon_feed_data import FeedConfig

CATALOG_URL = "https://code.cdn.mozilla.net/devices/devices.json"

Expand All @@ -9,14 +10,13 @@ class DeviceType(Enum):
LAPTOPS = "laptops"
TELEVISIONS = "televisions"


def get_useragent_list(device_type, config):
def get_useragent_list(device_type: DeviceType, config: FeedConfig) -> list[str]:
config.logger.debug(f"Querying endpoint: {CATALOG_URL}")
catalog_response = config.session.get(CATALOG_URL)
catalog_json = catalog_response.json() if catalog_response.ok else None
catalog_json: dict = catalog_response.json() if catalog_response.ok else None

if catalog_response.ok:
useragent_list = [
useragent_list: list[str] = [
device["userAgent"] for device in catalog_json[device_type.value]
]
config.logger.info(
Expand All @@ -26,4 +26,4 @@ def get_useragent_list(device_type, config):

else:
config.logger.warning("Unable to get useragent list.")
return None
return []
28 changes: 15 additions & 13 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from flask.logging import create_logger
from requests_cache import CachedSession

from amazon_feed import get_item_listing, get_search_results
from amazon_feed_data import (AmazonItemQuery, AmazonListingQuery, FeedConfig,
QueryStatus)
from amazon_feed import get_item_listing, get_keyword_results
from amazon_feed_data import (AmazonAsinQuery, AmazonKeywordQuery, FeedConfig,
FilterableQuery, QueryStatus)
from mozilla_devices import DeviceType, get_useragent_list

CACHE_EXPIRATION_SEC = 60
Expand Down Expand Up @@ -35,15 +35,17 @@ def set_useragent():
config.logger.debug(f"Using user-agent: {config.useragent}")


def generate_response(query):
def generate_response(query: FilterableQuery):
if not query.status.ok:
abort(400, description="Errors found: " + ", ".join(query.status.errors))

config.logger.debug(query) # log values

if isinstance(query, AmazonListingQuery):
output = get_search_results(query)
elif isinstance(query, AmazonItemQuery):
output = None

if isinstance(query, AmazonKeywordQuery):
output = get_keyword_results(query)
elif isinstance(query, AmazonAsinQuery):
output = get_item_listing(query)
return jsonify(output)

Expand All @@ -52,8 +54,8 @@ def generate_response(query):
@app.route("/search", methods=["GET"])
def process_listing():
list_request_dict = {
"query_str": request.args.get("query") or AmazonListingQuery.query_str,
"country": request.args.get("country") or AmazonListingQuery.country,
"query_str": request.args.get("query") or AmazonKeywordQuery.query_str,
"country": request.args.get("country") or AmazonKeywordQuery.country,
"min_price": request.args.get("min_price"),
"max_price": request.args.get("max_price"),
"strict_str": request.args.get("strict"),
Expand All @@ -62,7 +64,7 @@ def process_listing():
if not config.useragent:
set_useragent()

listing_query = AmazonListingQuery(
listing_query = AmazonKeywordQuery(
status=QueryStatus(), config=config, **list_request_dict
)

Expand All @@ -72,16 +74,16 @@ def process_listing():
@app.route("/item", methods=["GET"])
def process_item():
item_request_dict = {
"query_str": request.args.get("id") or AmazonItemQuery.query_str,
"country": request.args.get("country") or AmazonItemQuery.country,
"query_str": request.args.get("id") or AmazonAsinQuery.query_str,
"country": request.args.get("country") or AmazonAsinQuery.country,
"min_price": None,
"max_price": request.args.get("max_price"),
}

if not config.useragent:
set_useragent()

item_query = AmazonItemQuery(
item_query = AmazonAsinQuery(
status=QueryStatus(), config=config, **item_request_dict
)

Expand Down

0 comments on commit 450b337

Please sign in to comment.