diff --git a/tests/common.py b/tests/common.py index 7f00fc7..636f620 100644 --- a/tests/common.py +++ b/tests/common.py @@ -22,9 +22,12 @@ from typing import Optional from dataclasses import dataclass from litecord.common.users import create_user, delete_user -from litecord.common.guilds import delete_guild +from litecord.common.guilds import delete_guild, create_guild_channel +from litecord.blueprints.channel.messages import create_message from litecord.blueprints.auth import make_token from litecord.storage import int_ +from litecord.enums import ChannelType +from litecord.errors import ChannelNotFound, MessageNotFound def email() -> str: @@ -102,7 +105,9 @@ class WrappedGuild: async def refetch(self) -> "WrappedGuild": async with self.test_cli.app.app_context(): - guild = await self.test_cli.app.storage.get_guild_full(self.id) + guild = await self.test_cli.app.storage.get_guild_full( + self.id, user_id=self.test_cli.user["id"] + ) return WrappedGuild.from_json(self.test_cli, guild) @classmethod @@ -128,6 +133,119 @@ class WrappedGuild: ) +@dataclass +class WrappedGuildChannel: + test_cli: "TestClient" + id: int + type: int + guild_id: int + parent_id: Optional[int] + name: str + position: int + nsfw: bool + topic: str + rate_limit_per_user: int + last_message_id: int + permission_overwrites: list + + async def delete(self): + async with self.test_cli.app.app_context(): + resp = await self.test_cli.delete( + f"/api/v6/channels/{self.id}", + ) + rjson = await resp.json + + if resp.status_code == 404 and rjson["code"] == ChannelNotFound.error_code: + return + + assert resp.status_code == 200 + assert rjson["id"] == str(self.id) + + async def refetch(self) -> dict: + async with self.test_cli.app.app_context(): + channel_data = await self.test_cli.app.storage.get_channel(self.id) + return WrappedGuildChannel.from_json(self.test_cli, channel_data) + + @classmethod + def from_json(cls, test_cli, rjson): + return cls( + test_cli, + **{ + **rjson, + **{ + "id": int(rjson["id"]), + "guild_id": int(rjson["guild_id"]), + "parent_id": int_(rjson["parent_id"]), + }, + }, + ) + + +@dataclass +class WrappedMessage: + test_cli: "TestClient" + + id: int + channel_id: int + author: dict + + type: int + content: str + + timestamp: str + edited_timestamp: str + + tts: bool + mention_everyone: bool + nonce: str + embeds: list + mentions: list + mention_roles: list + reactions: list + attachments: list + pinned: bool + message_reference: Optional[dict] + allowed_mentions: Optional[dict] + member: Optional[dict] = None + flags: Optional[int] = None + guild_id: Optional[int] = None + + async def delete(self): + async with self.test_cli.app.app_context(): + resp = await self.test_cli.delete( + f"/api/v6/channels/{self.channel_id}/messages/{self.id}", + ) + rjson = await resp.json + + if resp.status_code == 404 and rjson["code"] in ( + ChannelNotFound.error_code, + MessageNotFound.error_code, + ): + return + + assert resp.status_code == 200 + assert rjson["id"] == str(self.id) + + async def refetch(self) -> dict: + async with self.test_cli.app.app_context(): + message_data = await self.test_cli.app.storage.get_message(self.id) + return WrappedMessage.from_json(self.test_cli, message_data) + + @classmethod + def from_json(cls, test_cli, rjson): + return cls( + test_cli, + **{ + **rjson, + **{ + "id": int(rjson["id"]), + "channel_id": int(rjson["channel_id"]), + "guild_id": int_(rjson["guild_id"]), + }, + }, + ) + + class TestClient: """Test client wrapper class. Adds Authorization headers to all requests and manages test resource setup and destruction.""" @@ -189,6 +307,56 @@ class TestClient: return self.add_resource(WrappedGuild.from_json(self, rjson)) + async def create_guild_channel( + self, + *, + guild_id: int, + name: Optional[str] = None, + type: ChannelType = ChannelType.GUILD_TEXT, + **kwargs, + ) -> WrappedGuild: + name = name or secrets.token_hex(6) + channel_id = self.app.winter_factory.snowflake() + + async with self.app.app_context(): + await create_guild_channel( + guild_id, channel_id, type, **{**{"name": name}, **kwargs} + ) + channel_data = await self.app.storage.get_channel(channel_id) + + return self.add_resource(WrappedGuildChannel.from_json(self, channel_data)) + + async def create_message( + self, + *, + guild_id: int, + channel_id: int, + content: Optional[str] = None, + author_id: Optional[int] = None, + ) -> WrappedGuild: + content = content or secrets.token_hex(6) + author_id = author_id or self.user["id"] + + async with self.app.app_context(): + message_id = await create_message( + channel_id, + guild_id, + author_id, + { + "content": content, + "tts": False, + "nonce": 0, + "everyone_mention": False, + "embeds": [], + "message_reference": None, + "allowed_mentions": None, + }, + ) + + message_data = await self.app.storage.get_message(message_id) + + return self.add_resource(WrappedMessage.from_json(self, message_data)) + def _inject_auth(self, kwargs: dict) -> list: """Inject the test user's API key into the test request before passing the request on to the underlying TestClient."""