Files
pycord-rest/src/pycord_rest/app.py

322 lines
14 KiB
Python
Raw Normal View History

2025-03-08 18:16:37 +01:00
# Copyright (c) Paillat-dev
# SPDX-License-Identifier: MIT
import functools
2025-03-08 18:16:37 +01:00
import logging
from collections.abc import Callable, Coroutine
from functools import cached_property
from typing import Any, Never, override
import aiohttp
import discord
import uvicorn
from discord import Entitlement, Interaction, InteractionType
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
2025-03-08 18:16:37 +01:00
from nacl.exceptions import BadSignatureError
from nacl.signing import VerifyKey
from .models import EventType, WebhookEventPayload, WebhookType
2025-03-08 18:16:37 +01:00
logger = logging.getLogger("pycord.rest")
class ApplicationAuthorizedEvent:
def __init__(self, user: discord.User, guild: discord.Guild | None, type: discord.IntegrationType) -> None: # noqa: A002
self.type: discord.IntegrationType = type
self.user: discord.User = user
self.guild: discord.Guild | None = guild
@override
def __repr__(self) -> str:
return (
f"<ApplicationAuthorizedEvent type={self.type} user={self.user}"
+ (f" guild={self.guild}" if self.guild else "")
+ ">"
)
class PycordRestError(discord.DiscordException):
    """Base class for the exceptions defined by this module."""
class InvalidCredentialsError(PycordRestError):
    """Raised when a required credential (token or public key) is missing."""
def not_supported[T, U](func: Callable[[T], U]) -> Callable[[T], U]:
@functools.wraps(func)
def inner(*args: T, **kwargs: T) -> U:
logger.warning(f"{func.__qualname__} is not supported by REST apps.")
return func(*args, **kwargs)
return inner
2025-03-08 18:16:37 +01:00
class App(discord.Bot):
_UvicornConfig: type[uvicorn.Config] = uvicorn.Config
_UvicornServer: type[uvicorn.Server] = uvicorn.Server
_FastAPI: type[FastAPI] = FastAPI
_APIRouter: type[APIRouter] = APIRouter
2025-03-08 18:16:37 +01:00
def __init__(self, *args: Any, **options: Any) -> None: # pyright: ignore [reportExplicitAny]
super().__init__(*args, **options) # pyright: ignore [reportUnknownMemberType]
self._app: FastAPI = self._FastAPI(openapi_url=None, docs_url=None, redoc_url=None)
self.router: APIRouter = self._APIRouter()
self._public_key: str | None = None
2025-03-08 18:16:37 +01:00
@property
@override
@not_supported
def latency(self) -> float:
return 0.0
2025-03-08 18:16:37 +01:00
@cached_property
def _verify_key(self) -> VerifyKey:
if self._public_key is None:
raise InvalidCredentialsError("No public key provided")
return VerifyKey(bytes.fromhex(self._public_key))
2025-03-08 18:16:37 +01:00
async def _dispatch_view(self, component_type: int, custom_id: str, interaction: Interaction) -> None:
# Code taken from ViewStore.dispatch
self._connection._view_store._ViewStore__verify_integrity() # noqa: SLF001 # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue, reportPrivateUsage]
message_id: int | None = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id)
value = self._connection._view_store._views.get(key) or self._connection._view_store._views.get( # pyright: ignore [reportUnknownVariableType, reportUnknownMemberType, reportPrivateUsage] # noqa: SLF001
(component_type, None, custom_id)
)
if value is None:
return
view, item = value # pyright: ignore [reportUnknownVariableType]
item.refresh_state(interaction)
# Code taken from View._dispatch_item
if view._View__stopped.done(): # noqa: SLF001 # pyright: ignore [reportAttributeAccessIssue, reportUnknownMemberType]
return
if interaction.message:
view.message = interaction.message
await view._scheduled_task(item, interaction) # noqa: SLF001 # pyright: ignore [reportPrivateUsage, reportUnknownMemberType]
async def _verify_request(self, request: Request) -> None:
signature = request.headers["X-Signature-Ed25519"]
timestamp = request.headers["X-Signature-Timestamp"]
body = (await request.body()).decode("utf-8")
try:
_ = self._verify_key.verify(f"{timestamp}{body}".encode(), bytes.fromhex(signature))
except BadSignatureError as e:
raise HTTPException(status_code=401, detail="Invalid request signature") from e
async def _process_interaction(self, request: Request) -> dict[str, Any]: # pyright: ignore [reportExplicitAny]
# Code taken from ConnectionState.parse_interaction_create
2025-03-08 18:16:37 +01:00
data = await request.json()
interaction = Interaction(data=data, state=self._connection)
match interaction.type:
case InteractionType.component:
custom_id: str = interaction.data["custom_id"] # pyright: ignore [reportGeneralTypeIssues, reportOptionalSubscript, reportUnknownVariableType]
component_type = interaction.data["component_type"] # pyright: ignore [reportGeneralTypeIssues, reportOptionalSubscript, reportUnknownVariableType]
await self._dispatch_view(component_type, custom_id, interaction) # pyright: ignore [reportUnknownArgumentType]
case InteractionType.modal_submit:
user_id, custom_id = ( # pyright: ignore [reportUnknownVariableType]
interaction.user.id, # pyright: ignore [reportOptionalMemberAccess]
interaction.data["custom_id"], # pyright: ignore [reportGeneralTypeIssues, reportOptionalSubscript]
)
await self._connection._modal_store.dispatch(user_id, custom_id, interaction) # pyright: ignore [reportUnknownArgumentType, reportPrivateUsage] # noqa: SLF001
case InteractionType.ping:
return {"type": 1}
case InteractionType.application_command | InteractionType.auto_complete:
await self.process_application_commands(interaction)
self.dispatch("interaction", interaction)
2025-03-08 18:16:37 +01:00
return {"ok": True}
@override
async def on_interaction(self, *args: Never, **kwargs: Never) -> None:
pass
2025-03-08 18:16:37 +01:00
@override
async def process_application_commands( # noqa: PLR0912
self, interaction: Interaction, auto_sync: bool | None = None
) -> None:
# Code taken from super().process_application_commands
2025-03-08 18:16:37 +01:00
if auto_sync is None:
auto_sync = self._bot.auto_sync_commands # pyright: ignore [reportUnknownVariableType, reportUnknownMemberType]
# TODO: find out why the isinstance check below doesn't stop the type errors below # noqa: FIX002, TD002, TD003
if interaction.type not in (
InteractionType.application_command,
InteractionType.auto_complete,
):
return None
command: discord.ApplicationCommand | None = None # pyright: ignore [reportMissingTypeArgument]
try:
if interaction.data:
command = self._application_commands[interaction.data["id"]] # pyright: ignore [reportUnknownVariableType, reportUnknownMemberType, reportGeneralTypeIssues]
except KeyError:
for cmd in self.application_commands + self.pending_application_commands: # pyright: ignore [reportUnknownMemberType, reportUnknownVariableType]
if interaction.data:
guild_id = interaction.data.get("guild_id")
if guild_id:
guild_id = int(guild_id)
if cmd.name == interaction.data["name"] and ( # pyright: ignore [reportGeneralTypeIssues]
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
):
command = cmd # pyright: ignore [reportUnknownVariableType]
break
else:
if auto_sync and interaction.data:
guild_id = interaction.data.get("guild_id")
if guild_id is None:
await self.sync_commands() # pyright: ignore [reportUnknownMemberType]
else:
await self.sync_commands(check_guilds=[guild_id]) # pyright: ignore [reportUnknownMemberType]
return self._bot.dispatch("unknown_application_command", interaction)
if interaction.type is InteractionType.auto_complete:
self._bot.dispatch("application_command_auto_complete", interaction, command)
2025-03-09 18:22:10 +01:00
await super().on_application_command_auto_complete(interaction, command) # pyright: ignore [reportArgumentType, reportUnknownMemberType]
2025-03-08 18:16:37 +01:00
return None
ctx = await self.get_application_context(interaction)
if command:
ctx.command = command
await self.invoke_application_command(ctx)
return None
@override
async def on_application_command_auto_complete(self, *args: Never, **kwargs: Never) -> None: # pyright: ignore [reportIncompatibleMethodOverride]
pass
def _process_interaction_factory(
self,
) -> Callable[[Request], Coroutine[Any, Any, dict[str, Any]]]: # pyright: ignore [reportExplicitAny]
@self.router.post("/", dependencies=[Depends(self._verify_request)])
async def process_interaction(request: Request) -> dict[str, Any]: # pyright: ignore [reportExplicitAny]
return await self._process_interaction(request)
return process_interaction
async def _health(self) -> dict[str, str]:
return {"status": "ok"}
def _health_factory(
self,
) -> Callable[[Request], Coroutine[Any, Any, dict[str, str]]]: # pyright: ignore [reportExplicitAny]
@self.router.get("/health")
async def health(_: Request) -> dict[str, str]:
return await self._health()
return health
async def _handle_webhook_event(self, data: dict[str, Any] | None, event_type: EventType) -> None: # pyright: ignore [reportExplicitAny]
if not data:
raise HTTPException(status_code=400, detail="Missing event data")
match event_type:
case EventType.APPLICATION_AUTHORIZED:
event = ApplicationAuthorizedEvent(
user=discord.User(state=self._connection, data=data["user"]),
guild=(discord.Guild(state=self._connection, data=data["guild"]) if data.get("guild") else None),
type=discord.IntegrationType.guild_install
if data.get("guild")
else discord.IntegrationType.user_install,
)
logger.debug("Dispatching application_authorized event")
self.dispatch("application_authorized", event)
if event.type == discord.IntegrationType.guild_install:
self.dispatch("guild_join", event.guild)
case EventType.ENTITLEMENT_CREATE:
entitlement = Entitlement(data=data, state=self._connection) # pyright: ignore [reportArgumentType]
logger.debug("Dispatching entitlement_create event")
self.dispatch("entitlement_create", entitlement)
case _:
logger.warning(f"Unsupported webhook event type received: {event_type}")
async def _webhook_event(self, payload: WebhookEventPayload) -> Response | dict[str, Any]: # pyright: ignore [reportExplicitAny]
match payload.type:
case WebhookType.PING:
return Response(status_code=204)
case WebhookType.Event:
if not payload.event:
raise HTTPException(status_code=400, detail="Missing event data")
await self._handle_webhook_event(payload.event.data, payload.event.type)
return {"ok": True}
def _webhook_event_factory(
self,
) -> Callable[[WebhookEventPayload], Coroutine[Any, Any, Response | dict[str, Any]]]: # pyright: ignore [reportExplicitAny]
@self.router.post("/webhook", dependencies=[Depends(self._verify_request)], response_model=None)
async def webhook_event(payload: WebhookEventPayload) -> Response | dict[str, Any]: # pyright: ignore [reportExplicitAny]
return await self._webhook_event(payload)
return webhook_event
2025-03-08 18:16:37 +01:00
@override
async def connect( # pyright: ignore [reportIncompatibleMethodOverride]
self,
token: str,
public_key: str,
uvicorn_options: dict[str, Any] | None = None, # pyright: ignore [reportExplicitAny]
health: bool = True,
) -> None:
self._public_key = public_key
2025-03-08 18:16:37 +01:00
_ = self._process_interaction_factory()
_ = self._webhook_event_factory()
2025-03-08 18:16:37 +01:00
if health:
_ = self._health_factory()
self._app.include_router(self.router)
2025-03-08 18:16:37 +01:00
uvicorn_options = uvicorn_options or {}
uvicorn_options["log_level"] = uvicorn_options.get("log_level", logging.root.level)
config = self._UvicornConfig(self._app, **uvicorn_options)
server = self._UvicornServer(config)
2025-03-08 18:16:37 +01:00
try:
self.dispatch("connect")
await server.serve()
except (TimeoutError, OSError, HTTPException, aiohttp.ClientError):
logger.exception("An error occurred while serving the app.")
self.dispatch("disconnect")
@override
async def close(self) -> None:
pass
@override
async def start( # pyright: ignore [reportIncompatibleMethodOverride]
self,
token: str,
public_key: str,
uvicorn_options: dict[str, Any] | None = None, # pyright: ignore [reportExplicitAny]
2025-03-08 18:16:37 +01:00
health: bool = True,
) -> None:
if not token:
raise InvalidCredentialsError("No token provided")
if not public_key:
raise InvalidCredentialsError("No public key provided")
2025-03-08 18:16:37 +01:00
await self.login(token)
await self.connect(
token=token,
public_key=public_key,
uvicorn_options=uvicorn_options,
health=health,
)
@override
def run(
self,
*args: Any, # pyright: ignore [reportExplicitAny]
token: str,
public_key: str,
uvicorn_options: dict[str, Any] | None = None, # pyright: ignore [reportExplicitAny]
health: bool = True,
**kwargs: Any, # pyright: ignore [reportExplicitAny]
) -> None:
super().run(
*args,
token=token,
public_key=public_key,
uvicorn_options=uvicorn_options,
health=health,
**kwargs,
)