Skip to content
This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Fix new services require bot restart #48

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ __pycache__
auslander.py
burgerbot_dev.py
log.txt
.vscode
.vscode
.venv/
84 changes: 45 additions & 39 deletions burgerbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import sys
from dataclasses import dataclass, asdict
from typing import List
from typing import Any, List
from datetime import datetime

from telegram import ParseMode
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(self, chat_id, services=[120686]):
self.chat_id = chat_id
self.services = services if len(services) > 0 else [120686]

def marshall_user(self) -> str:
def marshall_user(self) -> dict[str, Any]:
self.services = list(
set([s for s in self.services if s in list(service_map.keys())])
)
Expand All @@ -75,9 +75,9 @@ def __init__(self) -> None:
self.updater = Updater(os.environ["TELEGRAM_API_KEY"])
self.__init_chats()
self.users = self.__get_chats()
self.services = self.__get_uq_services()
self.parser = Parser(self.services)
self.parser = Parser()
self.dispatcher = self.updater.dispatcher
assert self.dispatcher is not None
self.dispatcher.add_handler(CommandHandler("help", self.__help))
self.dispatcher.add_handler(CommandHandler("start", self.__start))
self.dispatcher.add_handler(CommandHandler("stop", self.__stop))
Expand All @@ -89,39 +89,42 @@ def __init__(self) -> None:
self.dispatcher.add_handler(CommandHandler("services", self.__services))
self.cache: List[Message] = []

def __get_uq_services(self) -> List[int]:
services = []
for u in self.users:
services.extend(u.services)
services = filter(lambda x: x in service_map.keys(), services)
return list(set(services))
def __get_uq_services(self) -> set[int]:
return {
service
for user in self.users.values()
for service in user.services
if service in service_map
}

def __init_chats(self) -> None:
if not os.path.exists(CHATS_FILE):
with open(CHATS_FILE, "w") as f:
f.write("[]")

def __get_chats(self) -> List[User]:
def __get_chats(self) -> dict[int, User]:
with open(CHATS_FILE, "r") as f:
users = [User(u["chat_id"], u["services"]) for u in json.load(f)]
f.close()
print(users)
return users
logging.info(users)
return {u.chat_id: u for u in users}

def __persist_chats(self) -> None:
with open(CHATS_FILE, "w") as f:
json.dump([u.marshall_user() for u in self.users], f)
f.close()
marshalled_users = [u.marshall_user() for u in self.users.values()]
json.dump(marshalled_users, f)

def __add_chat(self, chat_id: int) -> None:
if chat_id not in [u.chat_id for u in self.users]:
logging.info("adding new user")
self.users.append(User(chat_id))
self.__persist_chats()
if chat_id in self.users:
logging.info(f"attempted to add user {chat_id} but it already exists")
return

logging.info(f"adding new user {chat_id}")
self.users[chat_id] = User(chat_id)
self.__persist_chats()

def __remove_chat(self, chat_id: int) -> None:
logging.info("removing the chat " + str(chat_id))
self.users = [u for u in self.users if u.chat_id != chat_id]
self.users.pop(chat_id)
self.__persist_chats()

def __services(self, update: Update, _: CallbackContext) -> None:
Expand Down Expand Up @@ -157,33 +160,33 @@ def __stop(self, update: Update, _: CallbackContext) -> None:
update.message.reply_text("Thanks for using me! Bye!")

def __my_services(self, update: Update, _: CallbackContext) -> None:
chat_id = update.message.chat_id
try:
service_ids = set(
service_id
for u in self.users
for service_id in u.services
if u.chat_id == update.message.chat_id
)
user = self.users[chat_id]
service_ids = set(user.services)
msg = (
"\n".join([f" - {service_id}" for service_id in service_ids])
or " - (none)"
)
update.message.reply_text(
"The following services are on your list:\n" + msg
)
except KeyError:
logging.warning(f"user {chat_id} not found")
except Exception as e:
logging.error(f"error occured when listing user services, {e}")

def __add_service(self, update: Update, _: CallbackContext) -> None:
logging.info(f"adding service {update.message}")
chat_id = update.message.chat_id
try:
user = self.users[chat_id]
service_id = int(update.message.text.split(" ")[1])
for u in self.users:
if u.chat_id == update.message.chat_id:
u.services.append(int(service_id))
self.__persist_chats()
break
user.services.append(service_id)
self.__persist_chats()
update.message.reply_text("Service added")
except KeyError:
logging.warning(f"user {chat_id} not found")
except Exception as e:
update.message.reply_text(
"Failed to add service, have you specified the service id?"
Expand All @@ -192,14 +195,15 @@ def __add_service(self, update: Update, _: CallbackContext) -> None:

def __remove_service(self, update: Update, _: CallbackContext) -> None:
logging.info(f"removing service {update.message}")
chat_id = update.message.chat_id
try:
user = self.users[chat_id]
service_id = int(update.message.text.split(" ")[1])
for u in self.users:
if u.chat_id == update.message.chat_id:
u.services.remove(int(service_id))
self.__persist_chats()
break
user.services.remove(service_id)
self.__persist_chats()
update.message.reply_text("Service removed")
except KeyError:
logging.warning(f"user {chat_id} not found")
except IndexError:
update.message.reply_text(
"Wrong usage. Please type '/remove_service 123456'"
Expand All @@ -215,7 +219,9 @@ def __poll(self) -> None:

def __parse(self) -> None:
while True:
slots = self.parser.parse()
services = self.__get_uq_services()
logging.info(f"services are: {services}")
slots = self.parser.parse(services)
for slot in slots:
self.__send_message(slot)
time.sleep(30)
Expand All @@ -226,7 +232,7 @@ def __send_message(self, slot: Slot) -> None:
return
self.__add_msg_to_cache(slot.msg)
md_msg = f"There are slots on {self.__date_from_msg(slot.msg)} available for booking for {service_map[slot.service_id]}, click [here]({build_url(slot.service_id)}) to check it out"
users = [u for u in self.users if slot.service_id in u.services]
users = [u for u in self.users.values() if slot.service_id in u.services]
for u in users:
logging.debug(f"sending msg to {str(u.chat_id)}")
try:
Expand Down
25 changes: 11 additions & 14 deletions parser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import time
import logging
from dataclasses import dataclass
from typing import List
from re import S
from typing import Iterable, List, Optional

import requests
from bs4 import BeautifulSoup
Expand All @@ -25,10 +24,8 @@ class Slot:


class Parser:
def __init__(self, services: List[int]) -> None:
self.services = services
def __init__(self) -> None:
self.proxy_on: bool = False
self.parse()

def __get_url(self, url) -> requests.Response:
logging.debug(url)
Expand All @@ -47,7 +44,7 @@ def __get_url(self, url) -> requests.Response:
def __toggle_proxy(self) -> None:
self.proxy_on = not self.proxy_on

def __parse_page(self, page, service_id) -> List[str]:
def __parse_page(self, page, service_id) -> Optional[List[Slot]]:
try:
if page.status_code == 428 or page.status_code == 429:
logging.info("exceeded rate limit. Sleeping for a while")
Expand All @@ -68,13 +65,13 @@ def __parse_page(self, page, service_id) -> List[str]:
logging.error(f"error occured during page parsing, {e}")
self.__toggle_proxy()

def add_service(self, service_id: int) -> None:
self.services.append(service_id)

def parse(self) -> List[str]:
slots = []
logging.info("services are: " + str(self.services))
for svc in self.services:
def parse(self, services: Iterable[int]) -> List[Slot]:
slots: list[Slot] = []
for svc in services:
page = self.__get_url(build_url(svc))
slots += self.__parse_page(page, svc)
parsed_slots = self.__parse_page(page, svc)
if parsed_slots is None:
continue

slots += parsed_slots
return slots