Skip to content

Commit 9bffdeb

Browse files
committed
2.10.0
1 parent 720ba18 commit 9bffdeb

5 files changed

Lines changed: 359 additions & 62 deletions

File tree

pomice/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class DiscordPyOutdated(Exception):
2020
"using 'pip install discord.py'",
2121
)
2222

23-
__version__ = "2.9.2"
23+
__version__ = "2.10.0"
2424
__title__ = "pomice"
2525
__author__ = "cloudwithax"
2626
__license__ = "GPL-3.0"

pomice/applemusic/client.py

Lines changed: 127 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import base64
45
import logging
56
import re
67
from datetime import datetime
8+
from typing import AsyncGenerator
79
from typing import Dict
810
from typing import List
11+
from typing import Optional
912
from typing import Union
1013

1114
import aiohttp
@@ -17,10 +20,10 @@
1720
__all__ = ("Client",)
1821

1922
AM_URL_REGEX = re.compile(
20-
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
23+
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
2124
)
2225
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
23-
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
26+
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
2427
)
2528

2629
AM_SCRIPT_REGEX = re.compile(r'<script.*?src="(/assets/index-.*?)"')
@@ -35,12 +38,14 @@ class Client:
3538
and translating it to a valid Lavalink track. No client auth is required here.
3639
"""
3740

38-
def __init__(self) -> None:
41+
def __init__(self, *, playlist_concurrency: int = 6) -> None:
3942
self.expiry: datetime = datetime(1970, 1, 1)
4043
self.token: str = ""
4144
self.headers: Dict[str, str] = {}
4245
self.session: aiohttp.ClientSession = None # type: ignore
4346
self._log = logging.getLogger(__name__)
47+
# Concurrency knob for parallel playlist page retrieval
48+
self._playlist_concurrency = max(1, playlist_concurrency)
4449

4550
async def _set_session(self, session: aiohttp.ClientSession) -> None:
4651
self.session = session
@@ -167,25 +172,127 @@ async def search(self, query: str) -> Union[Album, Playlist, Song, Artist]:
167172
"This playlist is empty and therefore cannot be queued.",
168173
)
169174

170-
_next = track_data.get("next")
171-
if _next:
172-
next_page_url = AM_BASE_URL + _next
173-
174-
while next_page_url is not None:
175-
resp = await self.session.get(next_page_url, headers=self.headers)
175+
# Apple Music uses cursor pagination with 'next'. We'll fetch subsequent pages
176+
# in rolling waves, following each page's 'next' cursor as it is discovered (one cursor per page).
177+
next_cursor = track_data.get("next")
178+
semaphore = asyncio.Semaphore(self._playlist_concurrency)
176179

180+
async def fetch_page(url: str) -> List[Song]:
181+
async with semaphore:
182+
resp = await self.session.get(url, headers=self.headers)
177183
if resp.status != 200:
178-
raise AppleMusicRequestException(
179-
f"Error while fetching results: {resp.status} {resp.reason}",
180-
)
184+
if self._log:
185+
self._log.warning(
186+
f"Apple Music page fetch failed {resp.status} {resp.reason} for {url}",
187+
)
188+
return []
189+
pj: dict = await resp.json(loads=json.loads)
190+
songs = [Song(track) for track in pj.get("data", [])]
191+
# Return the songs together with this page's 'next' cursor so the caller can keep following the chain
192+
return songs, pj.get("next") # type: ignore
193+
194+
# We'll implement a wave-based approach similar to Spotify but need to follow cursors.
195+
# Because we cannot know all cursors upfront, we'll iteratively fetch waves.
196+
waves: List[List[Song]] = []
197+
cursors: List[str] = []
198+
if next_cursor:
199+
cursors.append(next_cursor)
200+
201+
# Limit total waves to avoid infinite loops in malformed responses; once the limit is hit, any remaining pages are dropped without a warning
202+
max_waves = 50
203+
wave_size = self._playlist_concurrency * 2
204+
wave_counter = 0
205+
while cursors and wave_counter < max_waves:
206+
current = cursors[:wave_size]
207+
cursors = cursors[wave_size:]
208+
tasks = [
209+
fetch_page(AM_BASE_URL + cursor) for cursor in current # type: ignore[arg-type]
210+
]
211+
results = await asyncio.gather(*tasks, return_exceptions=True)
212+
for res in results:
213+
if isinstance(res, tuple): # (songs, next)
214+
songs, nxt = res
215+
if songs:
216+
waves.append(songs)
217+
if nxt:
218+
cursors.append(nxt)
219+
wave_counter += 1
220+
221+
for w in waves:
222+
album_tracks.extend(w)
223+
224+
return Playlist(data, album_tracks)
181225

182-
next_data: dict = await resp.json(loads=json.loads)
183-
album_tracks.extend(Song(track) for track in next_data["data"])
226+
async def iter_playlist_tracks(
227+
self,
228+
*,
229+
query: str,
230+
batch_size: int = 100,
231+
) -> AsyncGenerator[List[Song], None]:
232+
"""Stream Apple Music playlist tracks in batches.
233+
234+
Parameters
235+
----------
236+
query: str
237+
Apple Music playlist URL.
238+
batch_size: int
239+
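Maximum number of songs per yielded batch. Each fetched page is sliced on its own, so a batch never spans two pages and may be shorter than this.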
240+
"""
241+
if not self.token or datetime.utcnow() > self.expiry:
242+
await self.request_token()
184243

185-
_next = next_data.get("next")
186-
if _next:
187-
next_page_url = AM_BASE_URL + _next
188-
else:
189-
next_page_url = None
244+
result = AM_URL_REGEX.match(query)
245+
if not result or result.group("type") != "playlist":
246+
raise InvalidAppleMusicURL("Provided query is not a valid Apple Music playlist URL.")
190247

191-
return Playlist(data, album_tracks)
248+
country = result.group("country")
249+
playlist_id = result.group("id")
250+
request_url = AM_REQ_URL.format(country=country, type="playlist", id=playlist_id)
251+
resp = await self.session.get(request_url, headers=self.headers)
252+
if resp.status != 200:
253+
raise AppleMusicRequestException(
254+
f"Error while fetching results: {resp.status} {resp.reason}",
255+
)
256+
data: dict = await resp.json(loads=json.loads)
257+
playlist_data = data["data"][0]
258+
track_data: dict = playlist_data["relationships"]["tracks"]
259+
260+
first_page_tracks = [Song(track) for track in track_data["data"]]
261+
for i in range(0, len(first_page_tracks), batch_size):
262+
yield first_page_tracks[i : i + batch_size]
263+
264+
next_cursor = track_data.get("next")
265+
semaphore = asyncio.Semaphore(self._playlist_concurrency)
266+
267+
async def fetch(cursor: str) -> tuple[List[Song], Optional[str]]:
268+
url = AM_BASE_URL + cursor
269+
async with semaphore:
270+
r = await self.session.get(url, headers=self.headers)
271+
if r.status != 200:
272+
if self._log:
273+
self._log.warning(
274+
f"Skipping Apple Music page due to {r.status} {r.reason}",
275+
)
276+
return [], None
277+
pj: dict = await r.json(loads=json.loads)
278+
songs = [Song(track) for track in pj.get("data", [])]
279+
return songs, pj.get("next")
280+
281+
# Rolling waves of fetches following cursor chain
282+
max_waves = 50
283+
wave_size = self._playlist_concurrency * 2
284+
waves = 0
285+
cursors: List[str] = []
286+
if next_cursor:
287+
cursors.append(next_cursor)
288+
while cursors and waves < max_waves:
289+
current = cursors[:wave_size]
290+
cursors = cursors[wave_size:]
291+
results = await asyncio.gather(*[fetch(c) for c in current])
292+
for songs, nxt in results:
293+
if songs:
294+
for j in range(0, len(songs), batch_size):
295+
yield songs[j : j + batch_size]
296+
if nxt:
297+
cursors.append(nxt)
298+
waves += 1

pomice/enums.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ class SearchType(Enum):
3434
ytsearch = "ytsearch"
3535
ytmsearch = "ytmsearch"
3636
scsearch = "scsearch"
37+
other = "other"
38+
39+
@classmethod
40+
def _missing_(cls, value: object) -> "SearchType": # type: ignore[override]
41+
return cls.other
3742

3843
def __str__(self) -> str:
3944
return self.value
@@ -68,7 +73,7 @@ class TrackType(Enum):
6873
OTHER = "other"
6974

7075
@classmethod
71-
def _missing_(cls, _: object) -> "TrackType":
76+
def _missing_(cls, value: object) -> "TrackType": # type: ignore[override]
7277
return cls.OTHER
7378

7479
def __str__(self) -> str:
@@ -98,7 +103,7 @@ class PlaylistType(Enum):
98103
OTHER = "other"
99104

100105
@classmethod
101-
def _missing_(cls, _: object) -> "PlaylistType":
106+
def _missing_(cls, value: object) -> "PlaylistType": # type: ignore[override]
102107
return cls.OTHER
103108

104109
def __str__(self) -> str:
@@ -213,8 +218,12 @@ class URLRegex:
213218
214219
"""
215220

221+
# Spotify share links can include query parameters like ?si=XXXX, a trailing slash,
222+
# or an intl locale segment (e.g. /intl-en/). Broaden the regex so we still capture
223+
# the type and id while ignoring extra parameters. This prevents the URL from being
224+
# treated as a generic Lavalink identifier and ensures internal Spotify handling runs.
216225
SPOTIFY_URL = re.compile(
217-
r"https?://open.spotify.com/(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)",
226+
r"https?://open\.spotify\.com/(?:intl-[a-zA-Z-]+/)?(?P<type>album|playlist|track|artist)/(?P<id>[a-zA-Z0-9]+)(?:/)?(?:\?.*)?$",
218227
)
219228

220229
DISCORD_MP3_URL = re.compile(
@@ -235,14 +244,17 @@ class URLRegex:
235244
r"(?P<video>^.*?)(\?t|&start)=(?P<time>\d+)?.*",
236245
)
237246

247+
# Apple Music links sometimes append additional query parameters (e.g. &l=en, &uo=4).
248+
# Allow arbitrary query parameters so valid links are captured and parsed.
238249
AM_URL = re.compile(
239-
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/"
240-
r"(?P<type>album|playlist|song|artist)/(?P<name>.+)/(?P<id>[^?]+)",
250+
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/"
251+
r"(?P<type>album|playlist|song|artist)/(?P<name>.+?)/(?P<id>[^/?]+?)(?:/)?(?:\?.*)?$",
241252
)
242253

254+
# Single-in-album links may also carry extra query params beyond the ?i=<trackid> token.
243255
AM_SINGLE_IN_ALBUM_REGEX = re.compile(
244-
r"https?://music.apple.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
245-
r"(?P<name>.+)/(?P<id>.+)(\?i=)(?P<id2>.+)",
256+
r"https?://music\.apple\.com/(?P<country>[a-zA-Z]{2})/(?P<type>album|playlist|song|artist)/"
257+
r"(?P<name>.+)/(?P<id>[^/?]+)(\?i=)(?P<id2>[^&]+)(?:&.*)?$",
246258
)
247259

248260
SOUNDCLOUD_URL = re.compile(

pomice/pool.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@
2121
from discord import Client
2222
from discord.ext import commands
2323
from discord.utils import MISSING
24-
from websockets import client
24+
25+
try:
26+
from websockets.legacy import client # websockets >= 10.0
27+
except ImportError:
28+
import websockets.client as client # websockets < 10.0 # type: ignore
29+
2530
from websockets import exceptions
2631
from websockets import typing as wstype
2732

@@ -303,7 +308,7 @@ async def _configure_resuming(self) -> None:
303308
if not self._resume_key:
304309
return
305310

306-
data = {"timeout": self._resume_timeout}
311+
data: Dict[str, Union[int, str, bool]] = {"timeout": self._resume_timeout}
307312

308313
if self._version.major == 3:
309314
data["resumingKey"] = self._resume_key
@@ -444,7 +449,17 @@ async def connect(self, *, reconnect: bool = False) -> Node:
444449
start = time.perf_counter()
445450

446451
if not self._session:
447-
self._session = aiohttp.ClientSession()
452+
# Configure connection pooling for optimal concurrent request performance
453+
connector = aiohttp.TCPConnector(
454+
limit=100, # Total connection limit
455+
limit_per_host=30, # Per-host connection limit
456+
ttl_dns_cache=300, # DNS cache TTL in seconds
457+
)
458+
timeout = aiohttp.ClientTimeout(total=30, connect=10)
459+
self._session = aiohttp.ClientSession(
460+
connector=connector,
461+
timeout=timeout,
462+
)
448463

449464
try:
450465
if not reconnect:
@@ -463,7 +478,7 @@ async def connect(self, *, reconnect: bool = False) -> Node:
463478
f"Version check from Node {self._identifier} successful. Returned version {version}",
464479
)
465480

466-
self._websocket = await client.connect(
481+
self._websocket = await client.connect( # type: ignore
467482
f"{self._websocket_uri}/v{self._version.major}/websocket",
468483
extra_headers=self._headers,
469484
ping_interval=self._heartbeat,
@@ -560,7 +575,7 @@ async def get_tracks(
560575
query: str,
561576
*,
562577
ctx: Optional[commands.Context] = None,
563-
search_type: SearchType | None = SearchType.ytsearch,
578+
search_type: Optional[SearchType] = SearchType.ytsearch,
564579
filters: Optional[List[Filter]] = None,
565580
) -> Optional[Union[Playlist, List[Track]]]:
566581
"""Fetches tracks from the node's REST api to parse into Lavalink.
@@ -595,7 +610,7 @@ async def get_tracks(
595610
track_id=apple_music_results.id,
596611
ctx=ctx,
597612
track_type=TrackType.APPLE_MUSIC,
598-
search_type=search_type,
613+
search_type=search_type or SearchType.ytsearch,
599614
filters=filters,
600615
info={
601616
"title": apple_music_results.name,
@@ -617,7 +632,7 @@ async def get_tracks(
617632
track_id=track.id,
618633
ctx=ctx,
619634
track_type=TrackType.APPLE_MUSIC,
620-
search_type=search_type,
635+
search_type=search_type or SearchType.ytsearch,
621636
filters=filters,
622637
info={
623638
"title": track.name,
@@ -655,7 +670,7 @@ async def get_tracks(
655670
track_id=spotify_results.id,
656671
ctx=ctx,
657672
track_type=TrackType.SPOTIFY,
658-
search_type=search_type,
673+
search_type=search_type or SearchType.ytsearch,
659674
filters=filters,
660675
info={
661676
"title": spotify_results.name,
@@ -677,7 +692,7 @@ async def get_tracks(
677692
track_id=track.id,
678693
ctx=ctx,
679694
track_type=TrackType.SPOTIFY,
680-
search_type=search_type,
695+
search_type=search_type or SearchType.ytsearch,
681696
filters=filters,
682697
info={
683698
"title": track.name,

0 commit comments

Comments
 (0)