diff --git a/Pipfile b/Pipfile index c0aebbe..86915c0 100644 --- a/Pipfile +++ b/Pipfile @@ -5,7 +5,9 @@ name = "pypi" [packages] orjson = "*" +pydantic = ">=2" "discord.py" = {extras = ["voice"], version = "*"} +websockets = "*" [dev-packages] mypy = "*" diff --git a/examples/advanced.py b/examples/advanced.py index b8d2d3a..ef14061 100644 --- a/examples/advanced.py +++ b/examples/advanced.py @@ -125,7 +125,7 @@ def is_privileged(self, ctx: commands.Context): return player.dj == ctx.author or ctx.author.guild_permissions.kick_members - # The following are events from pomice.events + # The following are events from pomice.models.events # We are using these so that if the track either stops or errors, # we can just skip to the next track diff --git a/pomice/__init__.py b/pomice/__init__.py index 84b4718..32b2a97 100644 --- a/pomice/__init__.py +++ b/pomice/__init__.py @@ -27,7 +27,7 @@ class DiscordPyOutdated(Exception): __copyright__ = "Copyright (c) 2023, cloudwithax" from .enums import * -from .events import * +from .models import * from .exceptions import * from .filters import * from .objects import * diff --git a/pomice/applemusic/client.py b/pomice/applemusic/client.py index 1a82a52..2e38083 100644 --- a/pomice/applemusic/client.py +++ b/pomice/applemusic/client.py @@ -11,17 +11,12 @@ import aiohttp import orjson as json -from .exceptions import * -from .objects import * +from pomice.applemusic.exceptions import * +from pomice.applemusic.objects import * +from pomice.enums import URLRegex __all__ = ("Client",) -AM_URL_REGEX = re.compile( - r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P[^?]+)", -) -AM_SINGLE_IN_ALBUM_REGEX = re.compile( - r"https?://music.apple.com/(?P[a-zA-Z]{2})/(?Palbum|playlist|song|artist)/(?P.+)/(?P.+)(\?i=)(?P.+)", -) AM_SCRIPT_REGEX = re.compile(r' Union[Album, Playlist, Song, Artist]: if not self.token or datetime.utcnow() > self.expiry: await self.request_token() - result = AM_URL_REGEX.match(query) + result = URLRegex.AM_URL.match(query) if not result: raise InvalidAppleMusicURL( "The Apple Music link provided is not valid.", @@ -113,7 +108,7 @@ async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]: type = result.group("type") id = result.group("id") - if type == "album" and (sia_result := AM_SINGLE_IN_ALBUM_REGEX.match(query)): + if type == "album" and (sia_result := URLRegex.AM_SINGLE_IN_ALBUM_REGEX.match(query)): # apple music likes to generate links for singles off an album # by adding a param at the end of the url # so we're gonna scan for that and correct it diff --git a/pomice/enums.py b/pomice/enums.py index 1e963de..89689fc 100644 --- a/pomice/enums.py +++ b/pomice/enums.py @@ -1,6 +1,7 @@ import re from enum import Enum from enum import IntEnum +from enum import unique __all__ = ( "SearchType", @@ -15,7 +16,13 @@ ) -class SearchType(Enum): +class BaseStrEnum(str, Enum): + def __str__(self): + return self.value + + +@unique +class SearchType(BaseStrEnum): """ The enum for the different search types for Pomice. This feature is exclusively for the Spotify search feature of Pomice. @@ -31,15 +38,13 @@ class SearchType(Enum): which is an alternative to YouTube or YouTube Music. """ - ytsearch = "ytsearch" - ytmsearch = "ytmsearch" - scsearch = "scsearch" - - def __str__(self) -> str: - return self.value + YTSEARCH = "ytsearch" + YTMSEARCH = "ytmsearch" + SCSEARCH = "scsearch" -class TrackType(Enum): +@unique +class TrackType(BaseStrEnum): """ The enum for the different track types for Pomice. @@ -64,11 +69,9 @@ class TrackType(Enum): HTTP = "http" LOCAL = "local" - def __str__(self) -> str: - return self.value - -class PlaylistType(Enum): +@unique +class PlaylistType(BaseStrEnum): """ The enum for the different playlist types for Pomice. @@ -87,11 +90,9 @@ class PlaylistType(Enum): SPOTIFY = "spotify" APPLE_MUSIC = "apple_music" - def __str__(self) -> str: - return self.value - -class NodeAlgorithm(Enum): +@unique +class NodeAlgorithm(BaseStrEnum): """ The enum for the different node algorithms in Pomice. @@ -111,11 +112,9 @@ class NodeAlgorithm(Enum): by_ping = "BY_PING" by_players = "BY_PLAYERS" - def __str__(self) -> str: - return self.value - -class LoopMode(Enum): +@unique +class LoopMode(BaseStrEnum): """ The enum for the different loop modes. This feature is exclusively for the queue utility of pomice. @@ -124,18 +123,15 @@ class LoopMode(Enum): LoopMode.TRACK sets the queue loop to the current track. LoopMode.QUEUE sets the queue loop to the whole queue. - """ # We don't have to define anything special for these, since these just serve as flags TRACK = "track" QUEUE = "queue" - def __str__(self) -> str: - return self.value - -class RouteStrategy(Enum): +@unique +class RouteStrategy(BaseStrEnum): """ The enum for specifying the route planner strategy for Lavalink. This feature is exclusively for the RoutePlanner class. @@ -153,7 +149,6 @@ class RouteStrategy(Enum): RouteStrategy.ROTATING_NANO_SWITCH specifies that the node is switching between IPs every CPU clock cycle and is rotating between IP blocks on ban. - """ ROTATE_ON_BAN = "RotatingIpRoutePlanner" @@ -162,7 +157,8 @@ class RouteStrategy(Enum): ROTATING_NANO_SWITCH = "RotatingNanoIpRoutePlanner" -class RouteIPType(Enum): +@unique +class RouteIPType(BaseStrEnum): """ The enum for specifying the route planner IP block type for Lavalink. This feature is exclusively for the RoutePlanner class. @@ -177,9 +173,43 @@ class RouteIPType(Enum): IPV6 = "Inet6Address" +@unique +class LogLevel(IntEnum): + """ + The enum for specifying the logging level within Pomice. + This class serves as shorthand for logging. + This enum is exclusively for the logging feature in Pomice. + If you are not using this feature, this class is not necessary. + + + LogLevel.DEBUG sets the logging level to "debug". + + LogLevel.INFO sets the logging level to "info". + + LogLevel.WARN sets the logging level to "warn". + + LogLevel.ERROR sets the logging level to "error". + + LogLevel.CRITICAL sets the logging level to "CRITICAL". + """ + + DEBUG = 10 + INFO = 20 + WARN = 30 + ERROR = 40 + CRITICAL = 50 + + @classmethod + def from_str(cls, level_str): + try: + return cls[level_str.upper()] + except KeyError: + raise ValueError(f"No such log level: {level_str}") + + class URLRegex: """ - The enum for all the URL Regexes in use by Pomice. + The class for all the URL Regexes in use by Pomice. URLRegex.SPOTIFY_URL returns the Spotify URL Regex. @@ -196,7 +226,6 @@ class URLRegex: URLRegex.SOUNDCLOUD_URL returns the SoundCloud URL Regex. URLRegex.BASE_URL returns the standard URL Regex. - """ SPOTIFY_URL = re.compile( @@ -246,37 +275,3 @@ class URLRegex: LAVALINK_SEARCH = re.compile(r"(?Pytm?|sc)search:") BASE_URL = re.compile(r"https?://(?:www\.)?.+") - - -class LogLevel(IntEnum): - """ - The enum for specifying the logging level within Pomice. - This class serves as shorthand for logging. - This enum is exclusively for the logging feature in Pomice. - If you are not using this feature, this class is not necessary. - - - LogLevel.DEBUG sets the logging level to "debug". - - LogLevel.INFO sets the logging level to "info". - - LogLevel.WARN sets the logging level to "warn". - - LogLevel.ERROR sets the logging level to "error". - - LogLevel.CRITICAL sets the logging level to "CRITICAL". - - """ - - DEBUG = 10 - INFO = 20 - WARN = 30 - ERROR = 40 - CRITICAL = 50 - - @classmethod - def from_str(cls, level_str): - try: - return cls[level_str.upper()] - except KeyError: - raise ValueError(f"No such log level: {level_str}") diff --git a/pomice/events.py b/pomice/events.py deleted file mode 100644 index 062c2e4..0000000 --- a/pomice/events.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from typing import Any -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING - -from discord import Client -from discord import Guild -from discord.ext import commands - -from .objects import Track -from .pool import NodePool - -if TYPE_CHECKING: - from .player import Player - -__all__ = ( - "PomiceEvent", - "TrackStartEvent", - "TrackEndEvent", - "TrackStuckEvent", - "TrackExceptionEvent", - "WebSocketClosedPayload", - "WebSocketClosedEvent", - "WebSocketOpenEvent", -) - - -class PomiceEvent(ABC): - """The base class for all events dispatched by a node. - Every event must be formatted within your bot's code as a listener. - i.e: If you want to listen for when a track starts, the event would be: - ```py - @bot.listen - async def on_pomice_track_start(self, event): - ``` - """ - - name = "event" - handler_args: Tuple - - def dispatch(self, bot: Client) -> None: - bot.dispatch(f"pomice_{self.name}", *self.handler_args) - - -class TrackStartEvent(PomiceEvent): - """Fired when a track has successfully started. - Returns the player associated with the event and the pomice.Track object. - """ - - name = "track_start" - - __slots__ = ( - "player", - "track", - ) - - def __init__(self, data: dict, player: Player): - self.player: Player = player - self.track: Optional[Track] = self.player._current - - # on_pomice_track_start(player, track) - self.handler_args = self.player, self.track - - def __repr__(self) -> str: - return f"" - - -class TrackEndEvent(PomiceEvent): - """Fired when a track has successfully ended. - Returns the player associated with the event along with the pomice.Track object and reason. - """ - - name = "track_end" - - __slots__ = ("player", "track", "reason") - - def __init__(self, data: dict, player: Player): - self.player: Player = player - self.track: Optional[Track] = self.player._ending_track - self.reason: str = data["reason"] - - # on_pomice_track_end(player, track, reason) - self.handler_args = self.player, self.track, self.reason - - def __repr__(self) -> str: - return ( - f"" - ) - - -class TrackStuckEvent(PomiceEvent): - """Fired when a track is stuck and cannot be played. Returns the player - associated with the event along with the pomice.Track object - to be further parsed by the end user. - """ - - name = "track_stuck" - - __slots__ = ("player", "track", "threshold") - - def __init__(self, data: dict, player: Player): - self.player: Player = player - self.track: Optional[Track] = self.player._ending_track - self.threshold: float = data["thresholdMs"] - - # on_pomice_track_stuck(player, track, threshold) - self.handler_args = self.player, self.track, self.threshold - - def __repr__(self) -> str: - return ( - f"" - ) - - -class TrackExceptionEvent(PomiceEvent): - """Fired when a track error has occured. - Returns the player associated with the event along with the error code and exception. - """ - - name = "track_exception" - - __slots__ = ("player", "track", "exception") - - def __init__(self, data: dict, player: Player): - self.player: Player = player - self.track: Optional[Track] = self.player._ending_track - # Error is for Lavalink <= 3.3 - self.exception: str = data.get( - "error", - "", - ) or data.get("exception", "") - - # on_pomice_track_exception(player, track, error) - self.handler_args = self.player, self.track, self.exception - - def __repr__(self) -> str: - return f"" - - -class WebSocketClosedPayload: - __slots__ = ("guild", "code", "reason", "by_remote") - - def __init__(self, data: dict): - self.guild: Optional[Guild] = NodePool.get_node().bot.get_guild(int(data["guildId"])) - self.code: int = data["code"] - self.reason: str = data["code"] - self.by_remote: bool = data["byRemote"] - - def __repr__(self) -> str: - return ( - f"" - ) - - -class WebSocketClosedEvent(PomiceEvent): - """Fired when a websocket connection to a node has been closed. - Returns the reason and the error code. - """ - - name = "websocket_closed" - - __slots__ = ("payload",) - - def __init__(self, data: dict, _: Any) -> None: - self.payload: WebSocketClosedPayload = WebSocketClosedPayload(data) - - # on_pomice_websocket_closed(payload) - self.handler_args = (self.payload,) - - def __repr__(self) -> str: - return f"" - - -class WebSocketOpenEvent(PomiceEvent): - """Fired when a websocket connection to a node has been initiated. - Returns the target and the session SSRC. - """ - - name = "websocket_open" - - __slots__ = ("target", "ssrc") - - def __init__(self, data: dict, _: Any) -> None: - self.target: str = data["target"] - self.ssrc: int = data["ssrc"] - - # on_pomice_websocket_open(target, ssrc) - self.handler_args = self.target, self.ssrc - - def __repr__(self) -> str: - return f"" diff --git a/pomice/exceptions.py b/pomice/exceptions.py index 4019e3b..6ec69c5 100644 --- a/pomice/exceptions.py +++ b/pomice/exceptions.py @@ -61,7 +61,7 @@ class NoNodesAvailable(PomiceException): pass -class TrackInvalidPosition(PomiceException): +class TrackInvalidPosition(PomiceException, ValueError): """An invalid position was chosen for a track.""" pass @@ -73,19 +73,19 @@ class TrackLoadError(PomiceException): pass -class FilterInvalidArgument(PomiceException): +class FilterInvalidArgument(PomiceException, ValueError): """An invalid argument was passed to a filter.""" pass -class FilterTagInvalid(PomiceException): +class FilterTagInvalid(PomiceException, ValueError): """An invalid tag was passed or Pomice was unable to find a filter tag""" pass -class FilterTagAlreadyInUse(PomiceException): +class FilterTagAlreadyInUse(PomiceException, ValueError): """A filter with a tag is already in use by another filter""" pass @@ -97,7 +97,7 @@ class InvalidSpotifyClientAuthorization(PomiceException): pass -class AppleMusicNotEnabled(PomiceException): +class AppleMusicNotEnabled(PomiceException, ValueError): """An Apple Music Link was passed in when Apple Music functionality was not enabled.""" pass diff --git a/pomice/filters.py b/pomice/filters.py index f0df953..e4266d4 100644 --- a/pomice/filters.py +++ b/pomice/filters.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections from typing import Any from typing import Dict @@ -5,7 +7,7 @@ from typing import Optional from typing import Tuple -from .exceptions import FilterInvalidArgument +from pomice.exceptions import FilterInvalidArgument __all__ = ( "Filter", @@ -84,7 +86,7 @@ def __eq__(self, __value: object) -> bool: return self.raw == __value.raw @classmethod - def flat(cls) -> "Equalizer": + def flat(cls) -> Equalizer: """Equalizer preset which represents a flat EQ board, with all levels set to their default values. """ @@ -109,7 +111,7 @@ def flat(cls) -> "Equalizer": return cls(tag="flat", levels=levels) @classmethod - def boost(cls) -> "Equalizer": + def boost(cls) -> Equalizer: """Equalizer preset which boosts the sound of a track, making it sound fun and energetic by increasing the bass and the highs. @@ -135,7 +137,7 @@ def boost(cls) -> "Equalizer": return cls(tag="boost", levels=levels) @classmethod - def metal(cls) -> "Equalizer": + def metal(cls) -> Equalizer: """Equalizer preset which increases the mids of a track, preferably one of the metal genre, to make it sound more full and concert-like. @@ -162,7 +164,7 @@ def metal(cls) -> "Equalizer": return cls(tag="metal", levels=levels) @classmethod - def piano(cls) -> "Equalizer": + def piano(cls) -> Equalizer: """Equalizer preset which increases the mids and highs of a track, preferably a piano based one, to make it stand out. @@ -215,7 +217,7 @@ def __init__(self, *, tag: str, speed: float = 1.0, pitch: float = 1.0, rate: fl } @classmethod - def vaporwave(cls) -> "Timescale": + def vaporwave(cls) -> Timescale: """Timescale preset which slows down the currently playing track, giving it the effect of a half-speed record/casette playing. @@ -225,7 +227,7 @@ def vaporwave(cls) -> "Timescale": return cls(tag="vaporwave", speed=0.8, pitch=0.8) @classmethod - def nightcore(cls) -> "Timescale": + def nightcore(cls) -> Timescale: """Timescale preset which speeds up the currently playing track, which matches up to nightcore, a genre of sped-up music diff --git a/pomice/models/__init__.py b/pomice/models/__init__.py new file mode 100644 index 0000000..3959143 --- /dev/null +++ b/pomice/models/__init__.py @@ -0,0 +1,23 @@ +import pydantic +from pydantic import ConfigDict + +from .events import * +from .music import * +from .payloads import * +from .version import * + + +class BaseModel(pydantic.BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + + def model_dump(self, *args, **kwargs) -> dict: + by_alias = kwargs.pop("by_alias", True) + mode = kwargs.pop("mode", "json") + return super().model_dump(*args, **kwargs, by_alias=by_alias, mode=mode) + + +class VersionedModel(BaseModel): + version: LavalinkVersionType + + def model_dump(self, *args, **kwargs) -> dict: + return super().model_dump(*args, **kwargs, exclude={"version"}) diff --git a/pomice/models/events.py b/pomice/models/events.py new file mode 100644 index 0000000..b206fb5 --- /dev/null +++ b/pomice/models/events.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import abc +from enum import Enum +from enum import unique +from typing import Literal +from typing import TYPE_CHECKING + +from discord import Guild +from pydantic import computed_field +from pydantic import Field + +from pomice.models import BaseModel +from pomice.objects import Track +from pomice.player import Player +from pomice.pool import NodePool + +if TYPE_CHECKING: + from discord import Client + +__all__ = ( + "PomiceEvent", + "TrackStartEvent", + "TrackEndEvent", + "TrackStuckEvent", + "TrackExceptionEvent", + "WebSocketClosedPayload", + "WebSocketClosedEvent", + "WebSocketOpenEvent", +) + + +class PomiceEvent(BaseModel, abc.ABC): + """The base class for all events dispatched by a node. + Every event must be formatted within your bot's code as a listener. + i.e: If you want to listen for when a track starts, the event would be: + ```py + @bot.listen + async def on_pomice_track_start(self, event): + ``` + """ + + name: str + + @abc.abstractmethod + def dispatch(self, bot: Client) -> None: + ... + + +class TrackStartEvent(PomiceEvent): + """Fired when a track has successfully started. + Returns the player associated with the event and the pomice.Track object. + """ + + name: Literal["track_start"] + player: Player + track: Track + + def dispatch(self, bot: Client) -> None: + bot.dispatch(f"pomice_{self.name}", self.player, self.track) + + def __repr__(self) -> str: + return f"" + + +@unique +class TrackEndEventReason(str, Enum): + FINISHED = "finished" + LOAD_FAILED = "loadfailed" + STOPPED = "stopped" + REPLACED = "replaced" + CLEANUP = "cleanup" + + @classmethod + def _missing_(cls, value: object) -> TrackEndEventReason: + if isinstance(value, str): + return TrackEndEventReason(value.casefold()) + + +class TrackEndEvent(PomiceEvent): + """Fired when a track has successfully ended. + Returns the player associated with the event along with the pomice.Track object and reason. + """ + + name: Literal["track_end"] + player: Player + track: Track + reason: TrackEndEventReason + + def dispatch(self, bot: Client) -> None: + bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.reason) + + def __repr__(self) -> str: + return f"" + + +class TrackStuckEvent(PomiceEvent): + """Fired when a track has been stuck for a while. + Returns the player associated with the event along with the pomice.Track object and threshold. + """ + + name: Literal["track_stuck"] + player: Player + track: Track + threshold: float = Field(alias="thresholdMs") + + def dispatch(self, bot: Client) -> None: + bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.threshold) + + def __repr__(self) -> str: + return f"" + + +class TrackExceptionEvent(PomiceEvent): + """Fired when there is an exception while playing a track. + Returns the player associated with the event along with the pomice.Track object and exception. + """ + + name: Literal["track_exception"] + player: Player + track: Track + exception: str = Field(alias="error") + + def dispatch(self, bot: Client) -> None: + bot.dispatch(f"pomice_{self.name}", self.player, self.track, self.exception) + + def __repr__(self) -> str: + return f"" + + +class WebSocketClosedPayload(BaseModel): + """The payload for the WebSocketClosedEvent.""" + + guild_id: int = Field(alias="guildId") + code: int + reason: str + by_remote: bool = Field(alias="byRemote") + + @computed_field + @property + def guild(self) -> Guild: + return NodePool.get_node().bot.get_guild(self.guild_id) + + def __repr__(self) -> str: + return ( + f"" + ) + + +class WebSocketClosedEvent(PomiceEvent): + """Fired when the websocket connection to the node is closed. + Returns the player associated with the event and the code and reason for the closure. + """ + + name: Literal["websocket_closed"] + payload: WebSocketClosedPayload + + def dispatch(self, bot: Client) -> None: + bot.dispatch(f"pomice_{self.name}", self.payload) + + def __repr__(self) -> str: + return f"" + + +class WebSocketOpenEvent(PomiceEvent): + """Fired when the websocket connection to the node is opened. + Returns the player associated with the event. + """ + + name: Literal["websocket_open"] + target: str + ssrc: str + + def dispatch(self, bot: Client) -> None: + bot.dispatch(f"pomice_{self.name}", self.target, self.ssrc) + + def __repr__(self) -> str: + return f"" diff --git a/pomice/models/music.py b/pomice/models/music.py new file mode 100644 index 0000000..51bb225 --- /dev/null +++ b/pomice/models/music.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +from typing import List +from typing import Literal +from typing import Optional +from typing import Union + +from discord.ext.commands import Context +from discord.user import _UserTag +from pydantic import Field +from pydantic import model_validator +from pydantic import TypeAdapter + +from pomice.enums import PlaylistType +from pomice.enums import SearchType +from pomice.enums import TrackType +from pomice.filters import Filter +from pomice.models import BaseModel + +__all__ = ( + "Track", + "TrackInfo", + "Playlist", + "PlaylistInfo", + "PlaylistExtended", + "PlaylistModelAdapter", +) + + +class TrackInfo(BaseModel): + identifier: str + title: str + author: str + length: int + position: int = 0 + is_stream: bool = Field(default=False, alias="isStream") + is_seekable: bool = Field(default=False, alias="isSeekable") + uri: Optional[str] = None + isrc: Optional[str] = None + source_name: Optional[str] = Field(default=None, alias="sourceName") + artwork_url: Optional[str] = Field(default=None, alias="artworkUrl") + + +class Track(BaseModel): + """The base track object. Returns critical track information needed for parsing by Lavalink. + You can also pass in commands.Context to get a discord.py Context object in your track. + """ + + track_id: str = Field(alias="encoded") + track_type: TrackType + info: TrackInfo + search_type: SearchType = SearchType.YTSEARCH + filters: List[Filter] = Field(default_factory=list) + timestamp: Optional[float] = None + playlist: Optional[Playlist] = None + original: Optional[Track] = None + ctx: Optional[Context] = None + requester: Optional[_UserTag] = None + + @property + def title(self) -> str: + return self.info.title + + @property + def author(self) -> str: + return self.info.author + + @property + def uri(self) -> Optional[str]: + return self.info.uri + + @property + def identifier(self) -> str: + return self.info.identifier + + @property + def isrc(self) -> Optional[str]: + return self.info.isrc + + @property + def thumbnail(self) -> Optional[str]: + return self.info.artwork_url + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Track): + return False + + return self.track_id == other.track_id + + def __str__(self) -> str: + return self.info.title + + def __repr__(self) -> str: + return f" length={self.info.length}>" + + @model_validator(mode="after") + def _set_thumbnail_url(self) -> Track: + if self.track_type is TrackType.YOUTUBE and not self.info.artwork_url: + self.info.artwork_url = ( + f"https://img.youtube.com/vi/{self.info.identifier}/mqdefault.jpg" + ) + return self + + +class PlaylistInfo(BaseModel): + name: str + selected_track: int = Field(default=0, alias="selectedTrack") + + +class Playlist(BaseModel): + """The base playlist object. + Returns critical playlist information needed for parsing by Lavalink. + """ + + info: PlaylistInfo + tracks: List[Track] + playlist_type: PlaylistType + + @property + def name(self) -> str: + return self.info.name + + @property + def selected_track(self) -> Optional[Track]: + if self.track_count <= 0: + return None + + return self.tracks[self.info.selected_track] + + @property + def track_count(self) -> int: + return len(self.tracks) + + def __str__(self) -> str: + return self.info.name + + def __repr__(self) -> str: + return f"" + + @model_validator(mode="after") + def _set_playlist(self) -> Playlist: + for track in self.tracks: + track.playlist = self + return self + + +class PlaylistExtended(Playlist): + """Playlist object with additional information for external services.""" + + playlist_type: Union[Literal[PlaylistType.APPLE_MUSIC, PlaylistType.SPOTIFY]] + uri: str + artwork_url: str + + @property + def thumbnail(self) -> Optional[str]: + return self.artwork_url + + +PlaylistModelType = Union[Playlist, PlaylistExtended] +PlaylistModelAdapter = lambda **kwargs: TypeAdapter(PlaylistModelType).validate_python(kwargs) diff --git a/pomice/models/payloads.py b/pomice/models/payloads.py new file mode 100644 index 0000000..96d97ef --- /dev/null +++ b/pomice/models/payloads.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Optional +from typing import Union + +from pydantic import AliasPath +from pydantic import Field +from pydantic import field_validator +from pydantic import model_validator +from pydantic import TypeAdapter + +from pomice.models import BaseModel +from pomice.models import LavalinkVersion3Type +from pomice.models import LavalinkVersion4Type +from pomice.models import VersionedModel + +__all__ = ( + "VoiceUpdatePayload", + "TrackStartPayload", + "TrackUpdatePayload", + "ResumePayloadType", + "ResumePayloadTypeAdapter", +) + + +class VoiceUpdatePayload(BaseModel): + token: str = Field(validation_alias=AliasPath("event", "token")) + endpoint: str = Field(validation_alias=AliasPath("event", "endpoint")) + session_id: str = Field(alias="sessionId") + + +class TrackUpdatePayload(BaseModel): + encoded_track: Optional[str] = Field(default=None, alias="encodedTrack") + position: float + + +class TrackStartPayload(VersionedModel): + encoded_track: Optional[str] = Field(default=None, alias="encodedTrack") + position: float + end_time: str = Field(default="0", alias="endTime") + + @field_validator("end_time", mode="before") + @classmethod + def cast_end_time(cls, value: object) -> str: + return str(value) + + @model_validator(mode="after") + def adjust_end_time(self) -> TrackStartPayload: + if self.version >= LavalinkVersion3Type(3, 7, 5): + self.end_time = None + + +class ResumePayload(VersionedModel): + timeout: int + + +class ResumePayloadV3(ResumePayload): + version: LavalinkVersion3Type + resuming_key: str = Field(alias="resumingKey") + + +class ResumePayloadV4(ResumePayload): + version: LavalinkVersion4Type + resuming: bool = True + + +ResumePayloadType = Union[ResumePayloadV3, ResumePayloadV4] +ResumePayloadTypeAdapter = lambda **kwargs: TypeAdapter(ResumePayloadType).validate_python(kwargs) diff --git a/pomice/models/version.py b/pomice/models/version.py new file mode 100644 index 0000000..8ca1ded --- /dev/null +++ b/pomice/models/version.py @@ -0,0 +1,51 @@ +from typing import Literal +from typing import NamedTuple +from typing import Union + +__all__ = ( + "LavalinkVersion", + "LavalinkVersion3Type", + "LavalinkVersion4Type", + "LavalinkVersionType", +) + + +class LavalinkVersion(NamedTuple): + major: int + minor: int + fix: int + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LavalinkVersion): + return False + + return ( + (self.major == other.major) and (self.minor == other.minor) and (self.fix == other.fix) + ) + + def __lt__(self, other: object) -> bool: + if not isinstance(other, LavalinkVersion): + return False + + if self.major > other.major: + return False + if self.minor > other.minor: + return False + if self.fix > other.fix: + return False + return True + + +class LavalinkVersion3Type(LavalinkVersion): + major: Literal[3] + minor: int + fix: int + + +class LavalinkVersion4Type(LavalinkVersion): + major: Literal[4] + minor: int + fix: int + + +LavalinkVersionType = Union[LavalinkVersion3Type, LavalinkVersion4Type, LavalinkVersion] diff --git a/pomice/objects.py b/pomice/objects.py deleted file mode 100644 index b251906..0000000 --- a/pomice/objects.py +++ /dev/null @@ -1,167 +0,0 @@ -from __future__ import annotations - -from typing import List -from typing import Optional -from typing import Union - -from discord import ClientUser -from discord import Member -from discord import User -from discord.ext import commands - -from .enums import PlaylistType -from .enums import SearchType -from .enums import TrackType -from .filters import Filter - -__all__ = ( - "Track", - "Playlist", -) - - -class Track: - """The base track object. Returns critical track information needed for parsing by Lavalink. - You can also pass in commands.Context to get a discord.py Context object in your track. - """ - - __slots__ = ( - "track_id", - "info", - "track_type", - "filters", - "timestamp", - "original", - "_search_type", - "playlist", - "title", - "author", - "uri", - "identifier", - "isrc", - "thumbnail", - "length", - "ctx", - "requester", - "is_stream", - "is_seekable", - "position", - ) - - def __init__( - self, - *, - track_id: str, - info: dict, - ctx: Optional[commands.Context] = None, - track_type: TrackType, - search_type: SearchType = SearchType.ytsearch, - filters: Optional[List[Filter]] = None, - timestamp: Optional[float] = None, - requester: Optional[Union[Member, User, ClientUser]] = None, - ): - self.track_id: str = track_id - self.info: dict = info - self.track_type: TrackType = track_type - self.filters: Optional[List[Filter]] = filters - self.timestamp: Optional[float] = timestamp - - if self.track_type == TrackType.SPOTIFY or self.track_type == TrackType.APPLE_MUSIC: - self.original: Optional[Track] = None - else: - self.original = self - self._search_type: SearchType = search_type - - self.playlist: Optional[Playlist] = None - - self.title: str = info.get("title", "Unknown Title") - self.author: str = info.get("author", "Unknown Author") - self.uri: str = info.get("uri", "") - self.identifier: str = info.get("identifier", "") - self.isrc: Optional[str] = info.get("isrc", None) - self.thumbnail: Optional[str] = info.get("thumbnail") - - if self.uri and self.track_type is TrackType.YOUTUBE: - self.thumbnail = f"https://img.youtube.com/vi/{self.identifier}/mqdefault.jpg" - - self.length: int = info.get("length", 0) - self.is_stream: bool = info.get("isStream", False) - self.is_seekable: bool = info.get("isSeekable", False) - self.position: int = info.get("position", 0) - - self.ctx: Optional[commands.Context] = ctx - self.requester: Optional[Union[Member, User, ClientUser]] = requester - if not self.requester and self.ctx: - self.requester = self.ctx.author - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Track): - return False - - return other.track_id == self.track_id - - def __str__(self) -> str: - return self.title - - def __repr__(self) -> str: - return f" length={self.length}>" - - -class Playlist: - """The base playlist object. - Returns critical playlist information needed for parsing by Lavalink. - You can also pass in commands.Context to get a discord.py Context object in your tracks. - """ - - __slots__ = ( - "playlist_info", - "tracks", - "name", - "playlist_type", - "_thumbnail", - "_uri", - "selected_track", - "track_count", - ) - - def __init__( - self, - *, - playlist_info: dict, - tracks: list, - playlist_type: PlaylistType, - thumbnail: Optional[str] = None, - uri: Optional[str] = None, - ): - self.playlist_info: dict = playlist_info - self.tracks: List[Track] = tracks - self.name: str = playlist_info.get("name", "Unknown Playlist") - self.playlist_type: PlaylistType = playlist_type - - self._thumbnail: Optional[str] = thumbnail - self._uri: Optional[str] = uri - - for track in self.tracks: - track.playlist = self - - self.selected_track: Optional[Track] = None - if (index := playlist_info.get("selectedTrack", -1)) != -1: - self.selected_track = self.tracks[index] - - self.track_count: int = len(self.tracks) - - def __str__(self) -> str: - return self.name - - def __repr__(self) -> str: - return f"" - - @property - def uri(self) -> Optional[str]: - """Returns either an Apple Music/Spotify URL/URI, or None if its neither of those.""" - return self._uri - - @property - def thumbnail(self) -> Optional[str]: - """Returns either an Apple Music/Spotify album/playlist thumbnail, or None if its neither of those.""" - return self._thumbnail diff --git a/pomice/player.py b/pomice/player.py index e062456..d15db08 100644 --- a/pomice/player.py +++ b/pomice/player.py @@ -14,23 +14,24 @@ from discord import VoiceProtocol from discord.ext import commands -from . import events -from .enums import SearchType -from .events import PomiceEvent -from .events import TrackEndEvent -from .events import TrackStartEvent -from .exceptions import FilterInvalidArgument -from .exceptions import FilterTagAlreadyInUse -from .exceptions import FilterTagInvalid -from .exceptions import TrackInvalidPosition -from .exceptions import TrackLoadError -from .filters import Filter -from .filters import Timescale -from .objects import Playlist -from .objects import Track -from .pool import Node -from .pool import NodePool -from pomice.utils import LavalinkVersion +from pomice import events +from pomice.enums import SearchType +from pomice.exceptions import FilterInvalidArgument +from pomice.exceptions import FilterTagAlreadyInUse +from pomice.exceptions import FilterTagInvalid +from pomice.exceptions import TrackInvalidPosition +from pomice.exceptions import TrackLoadError +from pomice.filters import Filter +from pomice.filters import Timescale +from pomice.models.events import PomiceEvent +from pomice.models.events import TrackEndEvent +from pomice.models.events import TrackStartEvent +from pomice.models.music import Playlist +from pomice.models.music import Track +from pomice.models.payloads import TrackUpdatePayload +from pomice.models.payloads import VoiceUpdatePayload +from pomice.pool import Node +from pomice.pool import NodePool if TYPE_CHECKING: from discord.types.voice import VoiceServerUpdate @@ -200,10 +201,10 @@ def __repr__(self) -> str: @property def position(self) -> float: """Property which returns the player's position in a track in milliseconds""" - if not self.is_playing: + if not self.is_playing or not self._current: return 0 - current: Track = self._current # type: ignore + current: Track = self._current if current.original: current = current.original @@ -230,10 +231,10 @@ def adjusted_position(self) -> float: @property def adjusted_length(self) -> float: """Property which returns the player's track length in milliseconds adjusted for rate""" - if not self.is_playing: + if not self.is_playing or not self._current: return 0 - return self.current.length / self.rate # type: ignore + return self.current.length / self.rate @property def is_playing(self) -> bool: @@ -287,12 +288,6 @@ def is_dead(self) -> bool: """ return self.guild.id not in self._node._players - def _adjust_end_time(self) -> Optional[str]: - if self._node._version >= LavalinkVersion(3, 7, 5): - return None - - return "0" - async def _update_state(self, data: dict) -> None: state: dict = data.get("state", {}) self._last_update = int(state.get("time", 0)) @@ -301,23 +296,18 @@ async def _update_state(self, data: dict) -> None: if self._log: self._log.debug(f"Got player update state with data {state}") - async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = None) -> None: - if {"sessionId", "event"} != self._voice_state.keys(): - return - + async def _dispatch_voice_update(self, voice_data: Dict[str, Union[str, int]]) -> None: state = voice_data or self._voice_state + if {"sessionId", "event"} != state.keys(): + return - data = { - "token": state["event"]["token"], - "endpoint": state["event"]["endpoint"], - "sessionId": state["sessionId"], - } + data = VoiceUpdatePayload.model_validate(state) await self._node.send( method="PATCH", path=self._player_endpoint_uri, guild_id=self._guild.id, - data={"voice": data}, + data={"voice": data.model_dump()}, ) if self._log: @@ -327,44 +317,39 @@ async def _dispatch_voice_update(self, voice_data: Optional[Dict[str, Any]] = No async def on_voice_server_update(self, data: VoiceServerUpdate) -> None: self._voice_state.update({"event": data}) - await self._dispatch_voice_update(self._voice_state) + await self._dispatch_voice_update() async def on_voice_state_update(self, data: GuildVoiceState) -> None: - self._voice_state.update({"sessionId": data.get("session_id")}) - - channel_id = data.get("channel_id") - if not channel_id: - await self.disconnect() - self._voice_state.clear() - return + self._voice_state.update({"sessionId": data["session_id"]}) + channel_id = data["session_id"] channel = self.guild.get_channel(int(channel_id)) - if self.channel != channel: - self.channel = channel - if not channel: await self.disconnect() self._voice_state.clear() return + if self.channel != channel: + self.channel = channel + if not data.get("token"): return + self._voice_state.update({"event": data}) await self._dispatch_voice_update({**self._voice_state, "event": data}) async def _dispatch_event(self, data: dict) -> None: event_type: str = data["type"] - event: PomiceEvent = getattr(events, event_type)(data, self) + event: PomiceEvent = getattr(events, event_type)(player=self, **data) - if isinstance(event, TrackEndEvent) and event.reason not in ("REPLACED", "replaced"): + if isinstance(event, TrackEndEvent) and event.reason != "replaced": self._current = None - - event.dispatch(self._bot) - if isinstance(event, TrackStartEvent): self._ending_track = self._current + event.dispatch(self._bot) + if self._log: self._log.debug(f"Dispatched event {data['type']} to player.") @@ -373,7 +358,10 @@ async def _refresh_endpoint_uri(self, session_id: Optional[str]) -> None: async def _swap_node(self, *, new_node: Node) -> None: if self.current: - data: dict = {"position": self.position, "encodedTrack": self.current.track_id} + data: dict = TrackUpdatePayload( + encoded_track=self.current.track_id, + position=self.position, + ).model_dump() del self._node._players[self._guild.id] self._node = new_node @@ -396,7 +384,7 @@ async def get_tracks( query: str, *, ctx: Optional[commands.Context] = None, - search_type: SearchType = SearchType.ytsearch, + search_type: SearchType = SearchType.YTSEARCH, filters: Optional[List[Filter]] = None, ) -> Optional[Union[List[Track], Playlist]]: """Fetches tracks from the node's REST api to parse into Lavalink. @@ -629,6 +617,9 @@ async def set_pause(self, pause: bool) -> bool: async def set_volume(self, volume: int) -> int: """Sets the volume of the player as an integer. Lavalink accepts values from 0 to 500.""" + if volume < 0 or volume > 500: + raise ValueError("Volume must be between 0 and 500") + await self._node.send( method="PATCH", path=self._player_endpoint_uri, diff --git a/pomice/pool.py b/pomice/pool.py index 64564af..a0696e7 100644 --- a/pomice/pool.py +++ b/pomice/pool.py @@ -20,35 +20,34 @@ import orjson as json from discord import Client from discord.ext import commands -from discord.utils import MISSING from websockets import client from websockets import exceptions -from websockets import typing as wstype - -from . import __version__ -from . import applemusic -from . import spotify -from .enums import * -from .enums import LogLevel -from .exceptions import InvalidSpotifyClientAuthorization -from .exceptions import LavalinkVersionIncompatible -from .exceptions import NodeConnectionFailure -from .exceptions import NodeCreationError -from .exceptions import NodeNotAvailable -from .exceptions import NodeRestException -from .exceptions import NoNodesAvailable -from .exceptions import TrackLoadError -from .filters import Filter -from .objects import Playlist -from .objects import Track -from .routeplanner import RoutePlanner -from .utils import ExponentialBackoff -from .utils import LavalinkVersion -from .utils import NodeStats -from .utils import Ping + +from pomice import __version__ +from pomice import applemusic +from pomice import spotify +from pomice.enums import * +from pomice.exceptions import InvalidSpotifyClientAuthorization +from pomice.exceptions import LavalinkVersionIncompatible +from pomice.exceptions import NodeConnectionFailure +from pomice.exceptions import NodeCreationError +from pomice.exceptions import NodeNotAvailable +from pomice.exceptions import NodeRestException +from pomice.exceptions import NoNodesAvailable +from pomice.exceptions import TrackLoadError +from pomice.filters import Filter +from pomice.models.music import Playlist +from pomice.models.music import Track +from pomice.models.payloads import ResumePayloadTypeAdapter +from pomice.models.payloads import ResumePayloadV4 +from pomice.models.version import LavalinkVersion +from pomice.routeplanner import RoutePlanner +from pomice.utils import ExponentialBackoff +from pomice.utils import NodeStats +from pomice.utils import Ping if TYPE_CHECKING: - from .player import Player + from pomice.player import Player __all__ = ( "Node", @@ -167,20 +166,14 @@ def __init__( self._spotify_client: Optional[spotify.Client] = None self._apple_music_client: Optional[applemusic.Client] = None - self._spotify_client_id: Optional[str] = spotify_client_id - self._spotify_client_secret: Optional[str] = spotify_client_secret - - if self._spotify_client_id and self._spotify_client_secret: + if spotify_client_id and spotify_client_secret: self._spotify_client = spotify.Client( - self._spotify_client_id, - self._spotify_client_secret, + spotify_client_id, + spotify_client_secret, ) - if apple_music: self._apple_music_client = applemusic.Client() - self._bot.add_listener(self._update_handler, "on_socket_response") - def __repr__(self) -> str: return ( f" None: if self._apple_music_client: await self._apple_music_client._set_session(session=session) - async def _update_handler(self, data: dict) -> None: - await self._bot.wait_until_ready() - - if not data: - return - - if data["t"] == "VOICE_SERVER_UPDATE": - guild_id = int(data["d"]["guild_id"]) - try: - player = self._players[guild_id] - await player.on_voice_server_update(data["d"]) - except KeyError: - return - - elif data["t"] == "VOICE_STATE_UPDATE": - if int(data["d"]["user_id"]) != self._bot_user.id: - return - - guild_id = int(data["d"]["guild_id"]) - try: - player = self._players[guild_id] - await player.on_voice_state_update(data["d"]) - except KeyError: - return - async def _handle_node_switch(self) -> None: nodes = [node for node in self.pool._nodes.copy().values() if node.is_connected] new_node = random.choice(nodes) @@ -303,14 +271,15 @@ async def _configure_resuming(self) -> None: if not self._resume_key: return - data = {"timeout": self._resume_timeout} + data = ResumePayloadTypeAdapter( + version=self._version, + timeout=self._resume_timeout, + resuming_key=self._resume_key, + ).model_dump() - if self._version.major == 3: - data["resumingKey"] = self._resume_key - elif self._version.major == 4: + if isinstance(data, ResumePayloadV4): if self._log: self._log.warning("Using a resume key with Lavalink v4 is deprecated.") - data["resuming"] = True await self.send( method="PATCH", @@ -557,7 +526,7 @@ async def get_tracks( query: str, *, ctx: Optional[commands.Context] = None, - search_type: SearchType = SearchType.ytsearch, + search_type: SearchType = SearchType.YTSEARCH, filters: Optional[List[Filter]] = None, ) -> Optional[Union[Playlist, List[Track]]]: """Fetches tracks from the node's REST api to parse into Lavalink. diff --git a/pomice/queue.py b/pomice/queue.py index 3ea6e8b..0a54edf 100644 --- a/pomice/queue.py +++ b/pomice/queue.py @@ -8,11 +8,11 @@ from typing import Optional from typing import Union -from .enums import LoopMode -from .exceptions import QueueEmpty -from .exceptions import QueueException -from .exceptions import QueueFull -from .objects import Track +from pomice.enums import LoopMode +from pomice.exceptions import QueueEmpty +from pomice.exceptions import QueueException +from pomice.exceptions import QueueFull +from pomice.models.music import Track __all__ = ("Queue",) diff --git a/pomice/routeplanner.py b/pomice/routeplanner.py index 9a3d06e..8cf17ed 100644 --- a/pomice/routeplanner.py +++ b/pomice/routeplanner.py @@ -3,9 +3,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .pool import Node + from pomice.pool import Node -from .utils import RouteStats +from pomice.utils import RouteStats __all__ = ("RoutePlanner",) diff --git a/pomice/spotify/__init__.py b/pomice/spotify/__init__.py index e28be28..84da421 100644 --- a/pomice/spotify/__init__.py +++ b/pomice/spotify/__init__.py @@ -1,4 +1,5 @@ """Spotify module for Pomice, made possible by cloudwithax 2023""" -from .client import Client +from .client import * from .exceptions import * +from .models import * from .objects import * diff --git a/pomice/spotify/client.py b/pomice/spotify/client.py index 58b8647..3ed7cc7 100644 --- a/pomice/spotify/client.py +++ b/pomice/spotify/client.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import re import time from base64 import b64encode from typing import Dict @@ -13,18 +12,15 @@ import aiohttp import orjson as json -from .exceptions import InvalidSpotifyURL -from .exceptions import SpotifyRequestException -from .objects import * +from pomice.enums import URLRegex +from pomice.spotify.exceptions import * +from pomice.spotify.models import * __all__ = ("Client",) GRANT_URL = "https://accounts.spotify.com/api/token" REQUEST_URL = "https://api.spotify.com/v1/{type}s/{id}" -SPOTIFY_URL_REGEX = re.compile( - r"https?://open.spotify.com/(?Palbum|playlist|track|artist)/(?P[a-zA-Z0-9]+)", -) class Client: @@ -34,15 +30,12 @@ class Client: """ def __init__(self, client_id: str, client_secret: str) -> None: - self._client_id: str = client_id - self._client_secret: str = client_secret - self.session: aiohttp.ClientSession = None # type: ignore self._bearer_token: Optional[str] = None self._expiry: float = 0.0 self._auth_token = b64encode( - f"{self._client_id}:{self._client_secret}".encode(), + f"{client_id}:{client_secret}".encode(), ) self._grant_headers = { "Authorization": f"Basic {self._auth_token.decode()}", @@ -77,7 +70,7 @@ async def search(self, *, query: str) -> Union[Track, Album, Artist, Playlist]: if not self._bearer_token or time.time() >= self._expiry: await self._fetch_bearer_token() - result = SPOTIFY_URL_REGEX.match(query) + result = URLRegex.SPOTIFY_URL.match(query) if not result: raise InvalidSpotifyURL("The Spotify link provided is not valid.") @@ -151,7 +144,7 @@ async def get_recommendations(self, *, query: str) -> List[Track]: if not self._bearer_token or time.time() >= self._expiry: await self._fetch_bearer_token() - result = SPOTIFY_URL_REGEX.match(query) + result = URLRegex.SPOTIFY_URL.match(query) if not result: raise InvalidSpotifyURL("The Spotify link provided is not valid.") diff --git a/pomice/spotify/models.py b/pomice/spotify/models.py new file mode 100644 index 0000000..5bcc437 --- /dev/null +++ b/pomice/spotify/models.py @@ -0,0 +1,53 @@ +from typing import Dict +from typing import List +from typing import Optional + +from discord.ext.commands import Context +from discord.user import _UserTag +from pydantic import Field + +from pomice.enums import SearchType +from pomice.enums import TrackType +from pomice.filters import Filter +from pomice.models import BaseModel +from pomice.models.music import Track +from pomice.models.music import TrackInfo + + +class SpotifyTrackRaw(BaseModel): + id: str + name: str + artists: List[Dict[str, str]] + duration_ms: float + external_ids: Dict[str, str] = Field(default_factory=dict) + external_urls: Dict[str, str] = Field(default_factory=dict) + album: Dict[str, List[Dict[str, str]]] = Field(default_factory=dict) + + def build_track( + self, + image: Optional[str] = None, + filters: Optional[List[Filter]] = None, + ctx: Optional[Context] = None, + requester: Optional[_UserTag] = None, + ) -> Track: + if self.album: + image = self.album["images"][0]["url"] + + return Track( + track_id=self.id, + track_type=TrackType.SPOTIFY, + search_type=SearchType.YTMSEARCH, + filters=filters, + ctx=ctx, + requester=requester, + info=TrackInfo( + identifier=self.id, + title=self.name, + author=", ".join(artist["name"] for artist in self.artists), + length=self.duration_ms, + is_seekable=True, + uri=self.external_urls.get("spotify", ""), + artwork_url=image, + isrc=self.external_ids.get("isrc"), + ), + ) diff --git a/pomice/utils.py b/pomice/utils.py index e18b937..aa1a623 100644 --- a/pomice/utils.py +++ b/pomice/utils.py @@ -8,11 +8,10 @@ from typing import Callable from typing import Dict from typing import Iterable -from typing import NamedTuple from typing import Optional -from .enums import RouteIPType -from .enums import RouteStrategy +from pomice.enums import RouteIPType +from pomice.enums import RouteStrategy __all__ = ( "ExponentialBackoff", @@ -20,7 +19,6 @@ "FailingIPBlock", "RouteStats", "Ping", - "LavalinkVersion", ) @@ -226,53 +224,3 @@ def get_ping(self) -> float: s_runtime = 1000 * (cost_time) return s_runtime - - -class LavalinkVersion(NamedTuple): - major: int - minor: int - fix: int - - def __eq__(self, other: object) -> bool: - if not isinstance(other, LavalinkVersion): - return False - - return ( - (self.major == other.major) and (self.minor == other.minor) and (self.fix == other.fix) - ) - - def __ne__(self, other: object) -> bool: - if not isinstance(other, LavalinkVersion): - return False - - return not (self == other) - - def __lt__(self, other: object) -> bool: - if not isinstance(other, LavalinkVersion): - return False - - if self.major > other.major: - return False - if self.minor > other.minor: - return False - if self.fix > other.fix: - return False - return True - - def __gt__(self, other: object) -> bool: - if not isinstance(other, LavalinkVersion): - return False - - return not (self < other) - - def __le__(self, other: object) -> bool: - if not isinstance(other, LavalinkVersion): - return False - - return (self < other) or (self == other) - - def __ge__(self, other: object) -> bool: - if not isinstance(other, LavalinkVersion): - return False - - return (self > other) or (self == other) diff --git a/setup.py b/setup.py index 81fa4f5..bb7d5ef 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import setuptools version = "" -requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets"] +requirements = ["aiohttp>=3.7.4,<4", "orjson", "websockets", "pydantic>=2"] with open("pomice/__init__.py") as f: version = re.search( r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',