2025-03-08 18:16:37 +01:00
|
|
|
# Copyright (c) Paillat-dev
|
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
2025-03-10 23:31:49 +01:00
|
|
|
import functools
|
2025-03-08 18:16:37 +01:00
|
|
|
import logging
|
2025-03-13 09:22:16 +01:00
|
|
|
import warnings
|
2025-03-08 18:16:37 +01:00
|
|
|
from collections.abc import Callable, Coroutine
|
|
|
|
|
from functools import cached_property
|
|
|
|
|
from typing import Any, Never, override
|
|
|
|
|
|
|
|
|
|
import aiohttp
|
|
|
|
|
import discord
|
|
|
|
|
import uvicorn
|
2025-03-09 16:24:36 +01:00
|
|
|
from discord import Entitlement, Interaction, InteractionType
|
|
|
|
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
|
2025-03-08 18:16:37 +01:00
|
|
|
from nacl.exceptions import BadSignatureError
|
|
|
|
|
from nacl.signing import VerifyKey
|
|
|
|
|
|
2025-03-09 16:24:36 +01:00
|
|
|
from .models import EventType, WebhookEventPayload, WebhookType
|
|
|
|
|
|
2025-03-08 18:16:37 +01:00
|
|
|
# Module-level logger: every record emitted by this file goes through the "pycord.rest" logger.
logger = logging.getLogger("pycord.rest")
|
|
|
|
|
|
|
|
|
|
|
2025-03-09 16:24:36 +01:00
|
|
|
class ApplicationAuthorizedEvent:
    """Value object describing an application authorization.

    The constructor arguments are stored unchanged as the ``type``, ``user`` and
    ``guild`` attributes; ``guild`` may be ``None``.
    """

    def __init__(self, user: discord.User, guild: discord.Guild | None, type: discord.IntegrationType) -> None:  # noqa: A002
        self.type: discord.IntegrationType = type
        self.user: discord.User = user
        self.guild: discord.Guild | None = guild

    @override
    def __repr__(self) -> str:
        """Return ``<ApplicationAuthorizedEvent type=... user=...[ guild=...]>``."""
        pieces = [f"<ApplicationAuthorizedEvent type={self.type} user={self.user}"]
        # The guild segment is only shown when a (truthy) guild is attached.
        if self.guild:
            pieces.append(f" guild={self.guild}")
        pieces.append(">")
        return "".join(pieces)
|
|
|
|
|
|
|
|
|
|
|
2025-03-09 18:04:00 +01:00
|
|
|
class PycordRestError(discord.DiscordException):
    """Base class for the exceptions defined in this module."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InvalidCredentialsError(PycordRestError):
    """Raised when a required credential (bot token or public key) has not been provided."""
|
|
|
|
|
|
|
|
|
|
|
2025-03-10 23:31:49 +01:00
|
|
|
def not_supported[T, U](func: Callable[[T], U]) -> Callable[[T], U]:
|
|
|
|
|
@functools.wraps(func)
|
|
|
|
|
def inner(*args: T, **kwargs: T) -> U:
|
|
|
|
|
logger.warning(f"{func.__qualname__} is not supported by REST apps.")
|
2025-03-13 09:22:16 +01:00
|
|
|
warnings.warn(
|
|
|
|
|
f"{func.__qualname__} is not supported by REST apps.",
|
|
|
|
|
SyntaxWarning,
|
|
|
|
|
stacklevel=2,
|
|
|
|
|
)
|
2025-03-10 23:31:49 +01:00
|
|
|
return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
return inner
|
|
|
|
|
|
|
|
|
|
|
2025-03-08 18:16:37 +01:00
|
|
|
class App(discord.Bot):
|
2025-03-13 08:47:21 +01:00
|
|
|
_UvicornConfig: type[uvicorn.Config] = uvicorn.Config
|
|
|
|
|
_UvicornServer: type[uvicorn.Server] = uvicorn.Server
|
|
|
|
|
_FastAPI: type[FastAPI] = FastAPI
|
|
|
|
|
_APIRouter: type[APIRouter] = APIRouter
|
|
|
|
|
|
2025-03-08 18:16:37 +01:00
|
|
|
def __init__(self, *args: Any, **options: Any) -> None: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
super().__init__(*args, **options) # pyright: ignore [reportUnknownMemberType]
|
2025-03-13 08:47:21 +01:00
|
|
|
self._app: FastAPI = self._FastAPI(openapi_url=None, docs_url=None, redoc_url=None)
|
|
|
|
|
self.router: APIRouter = self._APIRouter()
|
2025-03-09 18:04:00 +01:00
|
|
|
self._public_key: str | None = None
|
2025-03-08 18:16:37 +01:00
|
|
|
|
2025-03-10 23:31:49 +01:00
|
|
|
@property
|
|
|
|
|
@override
|
|
|
|
|
@not_supported
|
|
|
|
|
def latency(self) -> float:
|
|
|
|
|
return 0.0
|
|
|
|
|
|
2025-03-08 18:16:37 +01:00
|
|
|
@cached_property
|
|
|
|
|
def _verify_key(self) -> VerifyKey:
|
2025-03-09 18:04:00 +01:00
|
|
|
if self._public_key is None:
|
|
|
|
|
raise InvalidCredentialsError("No public key provided")
|
|
|
|
|
return VerifyKey(bytes.fromhex(self._public_key))
|
2025-03-08 18:16:37 +01:00
|
|
|
|
|
|
|
|
async def _dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None:
|
|
|
|
|
# Code taken from ViewStore.dispatch
|
|
|
|
|
self._connection._view_store._ViewStore__verify_integrity() # noqa: SLF001 # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue, reportPrivateUsage]
|
|
|
|
|
message_id: int | None = interaction.message and interaction.message.id
|
|
|
|
|
key = (component_type, message_id, custom_id)
|
|
|
|
|
value = self._connection._view_store._views.get(key) or self._connection._view_store._views.get( # pyright: ignore [reportUnknownVariableType, reportUnknownMemberType, reportPrivateUsage] # noqa: SLF001
|
|
|
|
|
(component_type, None, custom_id)
|
|
|
|
|
)
|
|
|
|
|
if value is None:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
view, item = value # pyright: ignore [reportUnknownVariableType]
|
|
|
|
|
item.refresh_state(interaction)
|
|
|
|
|
|
|
|
|
|
# Code taken from View._dispatch_item
|
|
|
|
|
if view._View__stopped.done(): # noqa: SLF001 # pyright: ignore [reportAttributeAccessIssue, reportUnknownMemberType]
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if interaction.message:
|
|
|
|
|
view.message = interaction.message
|
|
|
|
|
|
|
|
|
|
await view._scheduled_task(item, interaction) # noqa: SLF001 # pyright: ignore [reportPrivateUsage, reportUnknownMemberType]
|
|
|
|
|
|
|
|
|
|
async def _verify_request(self, request: Request) -> None:
|
|
|
|
|
signature = request.headers["X-Signature-Ed25519"]
|
|
|
|
|
timestamp = request.headers["X-Signature-Timestamp"]
|
|
|
|
|
body = (await request.body()).decode("utf-8")
|
|
|
|
|
try:
|
|
|
|
|
_ = self._verify_key.verify(f"{timestamp}{body}".encode(), bytes.fromhex(signature))
|
|
|
|
|
except BadSignatureError as e:
|
|
|
|
|
raise HTTPException(status_code=401, detail="Invalid request signature") from e
|
|
|
|
|
|
|
|
|
|
async def _process_interaction(self, request: Request) -> dict[str, Any]: # pyright: ignore [reportExplicitAny]
|
2025-03-08 19:53:04 +01:00
|
|
|
# Code taken from ConnectionState.parse_interaction_create
|
2025-03-08 18:16:37 +01:00
|
|
|
data = await request.json()
|
|
|
|
|
interaction = Interaction(data=data, state=self._connection)
|
2025-03-08 19:53:04 +01:00
|
|
|
match interaction.type:
|
|
|
|
|
case InteractionType.component:
|
|
|
|
|
custom_id: str = interaction.data["custom_id"] # pyright: ignore [reportGeneralTypeIssues, reportOptionalSubscript, reportUnknownVariableType]
|
|
|
|
|
component_type = interaction.data["component_type"] # pyright: ignore [reportGeneralTypeIssues, reportOptionalSubscript, reportUnknownVariableType]
|
|
|
|
|
await self._dispatch_view(component_type, custom_id, interaction) # pyright: ignore [reportUnknownArgumentType]
|
|
|
|
|
case InteractionType.modal_submit:
|
|
|
|
|
user_id, custom_id = ( # pyright: ignore [reportUnknownVariableType]
|
|
|
|
|
interaction.user.id, # pyright: ignore [reportOptionalMemberAccess]
|
|
|
|
|
interaction.data["custom_id"], # pyright: ignore [reportGeneralTypeIssues, reportOptionalSubscript]
|
|
|
|
|
)
|
|
|
|
|
await self._connection._modal_store.dispatch(user_id, custom_id, interaction) # pyright: ignore [reportUnknownArgumentType, reportPrivateUsage] # noqa: SLF001
|
|
|
|
|
case InteractionType.ping:
|
|
|
|
|
return {"type": 1}
|
|
|
|
|
case InteractionType.application_command | InteractionType.auto_complete:
|
|
|
|
|
await self.process_application_commands(interaction)
|
|
|
|
|
self.dispatch("interaction", interaction)
|
2025-03-08 18:16:37 +01:00
|
|
|
return {"ok": True}
|
|
|
|
|
|
2025-03-08 19:53:04 +01:00
|
|
|
@override
|
|
|
|
|
async def on_interaction(self, *args: Never, **kwargs: Never) -> None:
|
|
|
|
|
pass
|
|
|
|
|
|
2025-03-08 18:16:37 +01:00
|
|
|
@override
|
|
|
|
|
async def process_application_commands( # noqa: PLR0912
|
|
|
|
|
self, interaction: Interaction, auto_sync: bool | None = None
|
|
|
|
|
) -> None:
|
2025-03-09 18:04:00 +01:00
|
|
|
# Code taken from super().process_application_commands
|
2025-03-08 18:16:37 +01:00
|
|
|
if auto_sync is None:
|
|
|
|
|
auto_sync = self._bot.auto_sync_commands # pyright: ignore [reportUnknownVariableType, reportUnknownMemberType]
|
|
|
|
|
# TODO: find out why the isinstance check below doesn't stop the type errors below # noqa: FIX002, TD002, TD003
|
|
|
|
|
if interaction.type not in (
|
|
|
|
|
InteractionType.application_command,
|
|
|
|
|
InteractionType.auto_complete,
|
|
|
|
|
):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
command: discord.ApplicationCommand | None = None # pyright: ignore [reportMissingTypeArgument]
|
|
|
|
|
try:
|
|
|
|
|
if interaction.data:
|
|
|
|
|
command = self._application_commands[interaction.data["id"]] # pyright: ignore [reportUnknownVariableType, reportUnknownMemberType, reportGeneralTypeIssues]
|
|
|
|
|
except KeyError:
|
|
|
|
|
for cmd in self.application_commands + self.pending_application_commands: # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
|
|
|
|
|
if interaction.data:
|
|
|
|
|
guild_id = interaction.data.get("guild_id")
|
|
|
|
|
if guild_id:
|
|
|
|
|
guild_id = int(guild_id)
|
|
|
|
|
if cmd.name == interaction.data["name"] and ( # pyright: ignore [reportGeneralTypeIssues]
|
|
|
|
|
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
|
|
|
|
|
):
|
|
|
|
|
command = cmd # pyright: ignore [reportUnknownVariableType]
|
|
|
|
|
break
|
|
|
|
|
else:
|
|
|
|
|
if auto_sync and interaction.data:
|
|
|
|
|
guild_id = interaction.data.get("guild_id")
|
|
|
|
|
if guild_id is None:
|
|
|
|
|
await self.sync_commands() # pyright: ignore [reportUnknownMemberType]
|
|
|
|
|
else:
|
|
|
|
|
await self.sync_commands(check_guilds=[guild_id]) # pyright: ignore [reportUnknownMemberType]
|
|
|
|
|
return self._bot.dispatch("unknown_application_command", interaction)
|
|
|
|
|
|
|
|
|
|
if interaction.type is InteractionType.auto_complete:
|
2025-03-08 19:53:04 +01:00
|
|
|
self._bot.dispatch("application_command_auto_complete", interaction, command)
|
2025-03-09 18:22:10 +01:00
|
|
|
await super().on_application_command_auto_complete(interaction, command) # pyright: ignore [reportArgumentType, reportUnknownMemberType]
|
2025-03-08 18:16:37 +01:00
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
ctx = await self.get_application_context(interaction)
|
|
|
|
|
if command:
|
|
|
|
|
ctx.command = command
|
|
|
|
|
await self.invoke_application_command(ctx)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
async def on_application_command_auto_complete(self, *args: Never, **kwargs: Never) -> None: # pyright: ignore [reportIncompatibleMethodOverride]
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def _process_interaction_factory(
|
|
|
|
|
self,
|
|
|
|
|
) -> Callable[[Request], Coroutine[Any, Any, dict[str, Any]]]: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
@self.router.post("/", dependencies=[Depends(self._verify_request)])
|
|
|
|
|
async def process_interaction(request: Request) -> dict[str, Any]: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
return await self._process_interaction(request)
|
|
|
|
|
|
|
|
|
|
return process_interaction
|
|
|
|
|
|
|
|
|
|
async def _health(self) -> dict[str, str]:
|
|
|
|
|
return {"status": "ok"}
|
|
|
|
|
|
|
|
|
|
def _health_factory(
|
|
|
|
|
self,
|
|
|
|
|
) -> Callable[[Request], Coroutine[Any, Any, dict[str, str]]]: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
@self.router.get("/health")
|
|
|
|
|
async def health(_: Request) -> dict[str, str]:
|
|
|
|
|
return await self._health()
|
|
|
|
|
|
|
|
|
|
return health
|
|
|
|
|
|
2025-03-09 16:24:36 +01:00
|
|
|
async def _handle_webhook_event(self, data: dict[str, Any] | None, event_type: EventType) -> None: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
if not data:
|
|
|
|
|
raise HTTPException(status_code=400, detail="Missing event data")
|
|
|
|
|
|
|
|
|
|
match event_type:
|
|
|
|
|
case EventType.APPLICATION_AUTHORIZED:
|
|
|
|
|
event = ApplicationAuthorizedEvent(
|
|
|
|
|
user=discord.User(state=self._connection, data=data["user"]),
|
|
|
|
|
guild=(discord.Guild(state=self._connection, data=data["guild"]) if data.get("guild") else None),
|
|
|
|
|
type=discord.IntegrationType.guild_install
|
|
|
|
|
if data.get("guild")
|
|
|
|
|
else discord.IntegrationType.user_install,
|
|
|
|
|
)
|
|
|
|
|
logger.debug("Dispatching application_authorized event")
|
|
|
|
|
self.dispatch("application_authorized", event)
|
|
|
|
|
if event.type == discord.IntegrationType.guild_install:
|
|
|
|
|
self.dispatch("guild_join", event.guild)
|
|
|
|
|
case EventType.ENTITLEMENT_CREATE:
|
|
|
|
|
entitlement = Entitlement(data=data, state=self._connection) # pyright: ignore [reportArgumentType]
|
|
|
|
|
logger.debug("Dispatching entitlement_create event")
|
|
|
|
|
self.dispatch("entitlement_create", entitlement)
|
|
|
|
|
case _:
|
|
|
|
|
logger.warning(f"Unsupported webhook event type received: {event_type}")
|
|
|
|
|
|
|
|
|
|
async def _webhook_event(self, payload: WebhookEventPayload) -> Response | dict[str, Any]: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
match payload.type:
|
|
|
|
|
case WebhookType.PING:
|
|
|
|
|
return Response(status_code=204)
|
|
|
|
|
case WebhookType.Event:
|
|
|
|
|
if not payload.event:
|
|
|
|
|
raise HTTPException(status_code=400, detail="Missing event data")
|
|
|
|
|
await self._handle_webhook_event(payload.event.data, payload.event.type)
|
|
|
|
|
|
|
|
|
|
return {"ok": True}
|
|
|
|
|
|
|
|
|
|
def _webhook_event_factory(
|
|
|
|
|
self,
|
|
|
|
|
) -> Callable[[WebhookEventPayload], Coroutine[Any, Any, Response | dict[str, Any]]]: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
@self.router.post("/webhook", dependencies=[Depends(self._verify_request)], response_model=None)
|
|
|
|
|
async def webhook_event(payload: WebhookEventPayload) -> Response | dict[str, Any]: # pyright: ignore [reportExplicitAny]
|
|
|
|
|
return await self._webhook_event(payload)
|
|
|
|
|
|
|
|
|
|
return webhook_event
|
|
|
|
|
|
2025-03-08 18:16:37 +01:00
|
|
|
@override
|
|
|
|
|
async def connect( # pyright: ignore [reportIncompatibleMethodOverride]
|
|
|
|
|
self,
|
|
|
|
|
token: str,
|
|
|
|
|
public_key: str,
|
|
|
|
|
uvicorn_options: dict[str, Any] | None = None, # pyright: ignore [reportExplicitAny]
|
|
|
|
|
health: bool = True,
|
|
|
|
|
) -> None:
|
2025-03-09 18:04:00 +01:00
|
|
|
self._public_key = public_key
|
2025-03-08 18:16:37 +01:00
|
|
|
_ = self._process_interaction_factory()
|
2025-03-09 16:24:36 +01:00
|
|
|
_ = self._webhook_event_factory()
|
2025-03-08 18:16:37 +01:00
|
|
|
if health:
|
|
|
|
|
_ = self._health_factory()
|
2025-03-09 18:04:00 +01:00
|
|
|
self._app.include_router(self.router)
|
2025-03-08 18:16:37 +01:00
|
|
|
uvicorn_options = uvicorn_options or {}
|
|
|
|
|
uvicorn_options["log_level"] = uvicorn_options.get("log_level", logging.root.level)
|
2025-03-13 08:47:21 +01:00
|
|
|
config = self._UvicornConfig(self._app, **uvicorn_options)
|
|
|
|
|
server = self._UvicornServer(config)
|
2025-03-08 18:16:37 +01:00
|
|
|
try:
|
|
|
|
|
self.dispatch("connect")
|
|
|
|
|
await server.serve()
|
|
|
|
|
except (TimeoutError, OSError, HTTPException, aiohttp.ClientError):
|
|
|
|
|
logger.exception("An error occurred while serving the app.")
|
|
|
|
|
self.dispatch("disconnect")
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
async def close(self) -> None:
|
2025-03-13 09:21:01 +01:00
|
|
|
self._closed: bool = True
|
|
|
|
|
|
|
|
|
|
await self.http.close()
|
|
|
|
|
self._ready.clear()
|
2025-03-08 18:16:37 +01:00
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
async def start( # pyright: ignore [reportIncompatibleMethodOverride]
|
|
|
|
|
self,
|
|
|
|
|
token: str,
|
|
|
|
|
public_key: str,
|
2025-03-08 20:26:00 +01:00
|
|
|
uvicorn_options: dict[str, Any] | None = None, # pyright: ignore [reportExplicitAny]
|
2025-03-08 18:16:37 +01:00
|
|
|
health: bool = True,
|
|
|
|
|
) -> None:
|
2025-03-09 18:04:00 +01:00
|
|
|
if not token:
|
|
|
|
|
raise InvalidCredentialsError("No token provided")
|
|
|
|
|
if not public_key:
|
|
|
|
|
raise InvalidCredentialsError("No public key provided")
|
2025-03-08 18:16:37 +01:00
|
|
|
await self.login(token)
|
|
|
|
|
await self.connect(
|
|
|
|
|
token=token,
|
|
|
|
|
public_key=public_key,
|
|
|
|
|
uvicorn_options=uvicorn_options,
|
|
|
|
|
health=health,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@override
|
|
|
|
|
def run(
|
|
|
|
|
self,
|
|
|
|
|
*args: Any, # pyright: ignore [reportExplicitAny]
|
|
|
|
|
token: str,
|
|
|
|
|
public_key: str,
|
|
|
|
|
uvicorn_options: dict[str, Any] | None = None, # pyright: ignore [reportExplicitAny]
|
|
|
|
|
health: bool = True,
|
|
|
|
|
**kwargs: Any, # pyright: ignore [reportExplicitAny]
|
|
|
|
|
) -> None:
|
|
|
|
|
super().run(
|
|
|
|
|
*args,
|
|
|
|
|
token=token,
|
|
|
|
|
public_key=public_key,
|
|
|
|
|
uvicorn_options=uvicorn_options,
|
|
|
|
|
health=health,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|