diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py index 3e5e9d2..477d543 100644 --- a/litecord/blueprints/channel/pins.py +++ b/litecord/blueprints/channel/pins.py @@ -37,6 +37,33 @@ class SysMsgInvalidAction(BadRequest): error_code = 50021 +async def _dispatch_pins_update(channel_id: int) -> None: + message_id = await app.db.fetchval( + """ + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + LIMIT 1 + """, + channel_id, + ) + + timestamp = ( + app.winter_factory.to_datetime(message_id) if message_id is not None else None + ) + await app.dispatcher.channel.dispatch( + channel_id, + ( + "CHANNEL_PINS_UPDATE", + { + "channel_id": str(channel_id), + "last_pin_timestamp": timestamp_(timestamp), + }, + ), + ) + + @bp.route("//pins", methods=["GET"]) async def get_pins(channel_id): """Get the pins for a channel""" @@ -82,7 +109,7 @@ async def add_pin(channel_id, message_id): ) if mtype in SYS_MESSAGES: - raise SysMsgInvalidAction("Cannot execute action on a system message") + raise SysMsgInvalidAction("Cannot pin a system message") await app.db.execute( """ @@ -93,29 +120,7 @@ async def add_pin(channel_id, message_id): message_id, ) - row = await app.db.fetchrow( - """ - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - LIMIT 1 - """, - channel_id, - ) - - timestamp = app.winter_factory.to_datetime(row["message_id"]) - - await app.dispatcher.channel.dispatch( - channel_id, - ( - "CHANNEL_PINS_UPDATE", - { - "channel_id": str(channel_id), - "last_pin_timestamp": timestamp_(timestamp), - }, - ), - ) + await _dispatch_pins_update(channel_id) await send_sys_message( channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id @@ -140,28 +145,6 @@ async def delete_pin(channel_id, message_id): message_id, ) - row = await app.db.fetchrow( - """ - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - LIMIT 1 - """, - channel_id, - ) - - timestamp = app.winter_factory.to_datetime(row["message_id"]) - - await app.dispatcher.channel.dispatch( - channel_id, - ( - "CHANNEL_PINS_UPDATE", - { - "channel_id": str(channel_id), - "last_pin_timestamp": timestamp.isoformat(), - }, - ), - ) + await _dispatch_pins_update(channel_id) return "", 204 diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index ae61820..f077a4a 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -194,6 +194,9 @@ async def close_channel(channel_id): user_id = await token_check() chan_type = await app.storage.get_chan_type(channel_id) + if chan_type is None: + raise ChannelNotFound("Channel not found") + ctype = ChannelType(chan_type) if ctype in GUILD_CHANS: @@ -253,8 +256,7 @@ async def close_channel(channel_id): await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_DELETE", chan)) await app.dispatcher.channel.drop(channel_id) return jsonify(chan) - - if ctype == ChannelType.DM: + elif ctype == ChannelType.DM: chan = await app.storage.get_channel(channel_id) # we don't ever actually delete DM channels off the database. @@ -275,8 +277,7 @@ async def close_channel(channel_id): await dispatch_user(user_id, ("CHANNEL_DELETE", chan)) return jsonify(chan) - - if ctype == ChannelType.GROUP_DM: + elif ctype == ChannelType.GROUP_DM: await gdm_remove_recipient(channel_id, user_id) gdm_count = await app.db.fetchval( @@ -291,8 +292,8 @@ async def close_channel(channel_id): if gdm_count == 0: # destroy dm await gdm_destroy(channel_id) - - raise ChannelNotFound() + else: + raise RuntimeError(f"Data inconsistency: Unknown channel type {ctype}") async def _update_pos(channel_id, pos: int): diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 90d67e7..aa126b8 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -92,7 +92,7 @@ async def channel_check( """, channel_id, ) - + assert guild_id is not None await guild_check(user_id, guild_id) return ctype, guild_id diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index c458f6e..a9f3b88 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -232,6 +232,22 @@ async def _del_from_table(table: str, user_id: int): async def delete_guild(guild_id: int): """Delete a single guild.""" await _del_from_table("vanity_invites", guild_id) + + # while most guild channel tables have 'ON DELETE CASCADE', this + # must not be true to the channels table, which is generic for any channel. + # + # the drawback is that this causes breakdown on the data's semantics as + # we get a channel with a type of GUILD_TEXT/GUILD_VOICE but without any + # entry on the guild_channels table, causing errors. + for channel_id in await app.storage.get_channel_ids(guild_id): + await app.db.execute( + """ + DELETE FROM channels + WHERE channels.id = $1 + """, + channel_id, + ) + await app.db.execute( """ DELETE FROM guilds diff --git a/litecord/storage.py b/litecord/storage.py index 1b3b9d8..eca2481 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -231,7 +231,7 @@ class Storage: drow["max_members"] = 100000 # used by guilds with DISCOVERABLE feature - drow["preffered_locale"] = "en-US" + drow["preferred_locale"] = "en-US" # feature won't be impl'd drow["guild_scheduled_events"] = [] @@ -431,8 +431,7 @@ class Storage: drow["last_message_id"] = last_msg return {**row, **drow} - - if chan_type == ChannelType.GUILD_VOICE: + elif chan_type == ChannelType.GUILD_VOICE: vrow = await self.db.fetchrow( """ SELECT bitrate, user_limit @@ -443,11 +442,11 @@ class Storage: ) return {**row, **dict(vrow)} + else: + # this only exists to trick mypy. this codepath is unreachable + raise AssertionError("Unreachable code path.") - # this only exists to trick mypy. this codepath is unreachable - raise RuntimeError("Unreachable code path.") - - async def get_chan_type(self, channel_id: int) -> int: + async def get_chan_type(self, channel_id: int) -> Optional[int]: """Get the channel type integer, given channel ID.""" return await self.db.fetchval( """ @@ -533,6 +532,9 @@ class Storage: async def get_channel(self, channel_id: int, **kwargs) -> Optional[Dict[str, Any]]: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) + if chan_type is None: + return None + ctype = ChannelType(chan_type) if ctype in ( @@ -603,7 +605,9 @@ class Storage: drow["last_message_id"] = await self.chan_last_message_str(channel_id) return drow - return None + raise RuntimeError( + f"Data Inconsistency: Channel type {ctype} is not properly handled" + ) async def get_channel_ids(self, guild_id: int) -> List[int]: """Get all channel IDs in a guild.""" diff --git a/litecord/utils.py b/litecord/utils.py index 338a5b6..355bebf 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -279,7 +279,7 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple: if "before" in args: before = int(args["before"]) elif "after" in args: - before = int(args["after"]) + after = int(args["after"]) return before, after diff --git a/tests/common.py b/tests/common.py index 4e573a7..6f13cff 100644 --- a/tests/common.py +++ b/tests/common.py @@ -18,29 +18,411 @@ along with this program. If not, see . """ import secrets +from typing import Optional +from dataclasses import dataclass + +from litecord.common.users import create_user, delete_user +from litecord.common.guilds import delete_guild, create_guild_channel +from litecord.blueprints.channel.messages import create_message +from litecord.blueprints.auth import make_token +from litecord.storage import int_ +from litecord.enums import ChannelType, UserFlags +from litecord.errors import ChannelNotFound, MessageNotFound def email() -> str: return f"{secrets.token_hex(5)}@{secrets.token_hex(5)}.com" +def random_email() -> str: + # TODO: move everyone who uses email() to random_email() + return email() + + +def random_username() -> str: + return secrets.token_hex(10) + + +@dataclass +class WrappedUser: + test_cli: "TestClient" + id: int + name: str + discriminator: str + avatar: Optional[str] + flags: UserFlags + public_flags: UserFlags + bot: bool + premium: bool + bio: str + accent_color: Optional[int] + + # secure fields + email: str + verified: str + + # extra-secure tokens (not here by default) + password: Optional[str] = None + password_hash: Optional[str] = None + token: Optional[str] = None + + # not there by default + premium_type: Optional[str] = None + mobile: Optional[bool] = None + phone: Optional[bool] = None + mfa_enabled: Optional[bool] = None + + async def refetch(self) -> dict: + async with self.test_cli.app.app_context(): + rjson = await self.test_cli.app.storage.get_user(self.id, secure=True) + return WrappedUser.from_json(self.test_cli, rjson) + + async def delete(self): + return await delete_user(self.id) + + @classmethod + def from_json(cls, test_cli, data_not_owned): + data = dict(data_not_owned) # take ownership of data via copy + data["name"] = data.pop("username") + return cls( + test_cli, + **{ + **data, + **{ + "id": int(data["id"]), + }, + }, + ) + + +@dataclass +class WrappedGuild: + test_cli: "TestClient" + id: int + owner: bool # value depends on the user that fetched guild + owner_id: int + name: str + unavailable: bool + icon: Optional[str] + splash: Optional[str] + region: Optional[str] + afk_timeout: int + afk_channel_id: Optional[str] + afk_timeout: int + verification_level: int + default_message_notifications: int + explicit_content_filter: int + mfa_level: int + embed_enabled: bool + embed_channel_id: int + widget_enabled: bool + widget_channel_id: int + system_channel_id: int + rules_channel_id: int + public_updates_channel_id: int + features: str + features: str + banner: Optional[str] + description: Optional[str] + preferred_locale: Optional[str] + discovery_splash: Optional[str] + + vanity_url_code: Optional[str] + max_presences: int + max_members: int + guild_scheduled_events: list + + joined_at: str # value depends on the user that fetched the guild + + member_count: int + members: list + channels: list + roles: list + presences: list + emojis: list + voice_states: list + + large: Optional[bool] = None + + async def delete(self): + await delete_guild(self.id) + + async def refetch(self) -> "WrappedGuild": + async with self.test_cli.app.app_context(): + guild = await self.test_cli.app.storage.get_guild_full( + self.id, user_id=self.test_cli.user["id"] + ) + return WrappedGuild.from_json(self.test_cli, guild) + + @classmethod + def from_json(cls, test_cli, rjson): + return cls( + test_cli, + **{ + **rjson, + **{ + "id": int(rjson["id"]), + "owner_id": int(rjson["owner_id"]), + "afk_channel_id": int_(rjson["afk_channel_id"]), + "embed_channel_id": int_(rjson["embed_channel_id"]), + "widget_enabled": int_(rjson["widget_enabled"]), + "widget_channel_id": int_(rjson["widget_channel_id"]), + "system_channel_id": int_(rjson["system_channel_id"]), + "rules_channel_id": int_(rjson["rules_channel_id"]), + "public_updates_channel_id": int_( + rjson["public_updates_channel_id"] + ), + }, + }, + ) + + +@dataclass +class WrappedGuildChannel: + test_cli: "TestClient" + id: int + type: int + guild_id: int + parent_id: Optional[int] + name: str + position: int + nsfw: bool + topic: str + rate_limit_per_user: int + last_message_id: int + permission_overwrites: list + + async def delete(self): + async with self.test_cli.app.app_context(): + resp = await self.test_cli.delete( + f"/api/v6/channels/{self.id}", + ) + rjson = await resp.json + + if resp.status_code == 404 and rjson["code"] == ChannelNotFound.error_code: + return + + assert resp.status_code == 200 + assert rjson["id"] == str(self.id) + + async def refetch(self) -> dict: + async with self.test_cli.app.app_context(): + channel_data = await self.test_cli.app.storage.get_channel(self.id) + return WrappedGuildChannel.from_json(self.test_cli, channel_data) + + @classmethod + def from_json(cls, test_cli, rjson): + return cls( + test_cli, + **{ + **rjson, + **{ + "id": int(rjson["id"]), + "guild_id": int(rjson["guild_id"]), + "parent_id": int_(rjson["parent_id"]), + "last_message_id": int_(rjson["last_message_id"]), + "rate_limit_per_user": int_(rjson["rate_limit_per_user"]), + }, + }, + ) + + +@dataclass +class WrappedMessage: + test_cli: "TestClient" + + id: int + channel_id: int + author: dict + + type: int + content: str + + timestamp: str + edited_timestamp: str + + tts: bool + mention_everyone: bool + nonce: str + embeds: list + mentions: list + mention_roles: list + reactions: list + attachments: list + pinned: bool + message_reference: Optional[dict] + allowed_mentions: Optional[dict] + member: Optional[dict] = None + flags: Optional[int] = None + guild_id: Optional[int] = None + + async def delete(self): + async with self.test_cli.app.app_context(): + resp = await self.test_cli.delete( + f"/api/v6/channels/{self.channel_id}/messages/{self.id}", + ) + rjson = await resp.json + + if resp.status_code == 404 and rjson["code"] in ( + ChannelNotFound.error_code, + MessageNotFound.error_code, + ): + return + + assert resp.status_code == 200 + assert rjson["id"] == str(self.id) + + async def refetch(self) -> Optional["WrappedMessage"]: + async with self.test_cli.app.app_context(): + message_data = await self.test_cli.app.storage.get_message(self.id) + if message_data is None: + return None + return WrappedMessage.from_json(self.test_cli, message_data) + + @classmethod + def from_json(cls, test_cli, rjson): + return cls( + test_cli, + **{ + **rjson, + **{ + "id": int(rjson["id"]), + "channel_id": int(rjson["channel_id"]), + "guild_id": int_(rjson["guild_id"]), + }, + }, + ) + + class TestClient: - """Test client that wraps pytest-sanic's TestClient and a test - user and adds authorization headers to test requests.""" + """Test client wrapper class. Adds Authorization headers to all requests + and manages test resource setup and destruction.""" def __init__(self, test_cli, test_user): self.cli = test_cli self.app = test_cli.app self.user = test_user + self.resources = [] def __getitem__(self, key): return self.user[key] + def add_resource(self, resource): + self.resources.append(resource) + return resource + + async def cleanup(self): + for resource in self.resources: + async with self.app.app_context(): + await resource.delete() + + async def create_user( + self, + *, + username: Optional[str] = None, + email: Optional[str] = None, + password: Optional[str] = None, + ) -> WrappedUser: + username = username or random_username() + email = email or random_email() + password = password or random_username() + + async with self.app.app_context(): + user_id, password_hash = await create_user(username, email, password) + user_token = make_token(user_id, password_hash) + full_user_object = await self.app.storage.get_user(user_id, secure=True) + + return self.add_resource( + WrappedUser.from_json( + self, + { + **full_user_object, + **{ + "token": user_token, + "password_hash": password_hash, + }, + }, + ) + ) + + async def create_guild( + self, + *, + name: Optional[str] = None, + region: Optional[str] = None, + owner: Optional["WrappedUser"] = None, + ) -> WrappedGuild: + name = name or secrets.token_hex(6) + owner_token = owner.token if owner else self.user["token"] + + async with self.app.app_context(): + # TODO move guild creation logic to litecord.common.guild + # TODO make tests use aiosqlite on memory for db + resp = await self.post( + "/api/v6/guilds", + json={"name": name, "region": region}, + headers={"authorization": owner_token}, + ) + rjson = await resp.json + + return self.add_resource(WrappedGuild.from_json(self, rjson)) + + async def create_guild_channel( + self, + *, + guild_id: int, + name: Optional[str] = None, + type: ChannelType = ChannelType.GUILD_TEXT, + **kwargs, + ) -> WrappedGuild: + name = name or secrets.token_hex(6) + channel_id = self.app.winter_factory.snowflake() + + async with self.app.app_context(): + await create_guild_channel( + guild_id, channel_id, type, **{**{"name": name}, **kwargs} + ) + channel_data = await self.app.storage.get_channel(channel_id) + + return self.add_resource(WrappedGuildChannel.from_json(self, channel_data)) + + async def create_message( + self, + *, + guild_id: int, + channel_id: int, + content: Optional[str] = None, + author_id: Optional[int] = None, + ) -> WrappedGuild: + content = content or secrets.token_hex(6) + author_id = author_id or self.user["id"] + + async with self.app.app_context(): + message_id = await create_message( + channel_id, + guild_id, + author_id, + { + "content": content, + "tts": False, + "nonce": 0, + "everyone_mention": False, + "embeds": [], + "message_reference": None, + "allowed_mentions": None, + }, + ) + + message_data = await self.app.storage.get_message(message_id) + + return self.add_resource(WrappedMessage.from_json(self, message_data)) + def _inject_auth(self, kwargs: dict) -> list: """Inject the test user's API key into the test request before passing the request on to the underlying TestClient.""" headers = kwargs.get("headers", {}) - headers["authorization"] = self.user["token"] + if "authorization" not in headers: + headers["authorization"] = self.user["token"] return headers async def get(self, *args, **kwargs): diff --git a/tests/conftest.py b/tests/conftest.py index f0444c9..e625dba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,7 @@ sys.path.append(os.getcwd()) from tests.common import email, TestClient -from run import app as main_app, set_blueprints +from run import app as main_app from litecord.common.users import create_user, delete_user from litecord.enums import UserFlags @@ -36,8 +36,7 @@ from litecord.blueprints.auth import make_token @pytest.fixture(name="app") -def _test_app(unused_tcp_port, event_loop): - set_blueprints(main_app) +async def _test_app(unused_tcp_port): main_app.config["_testing"] = True # reassign an unused tcp port for websockets @@ -53,13 +52,13 @@ def _test_app(unused_tcp_port, event_loop): main_app.config["REGISTRATIONS"] = True # make sure we're calling the before_serving hooks - event_loop.run_until_complete(main_app.startup()) + await main_app.startup() # https://docs.pytest.org/en/latest/fixture.html#fixture-finalization-executing-teardown-code yield main_app # properly teardown - event_loop.run_until_complete(main_app.shutdown()) + await main_app.shutdown() @pytest.fixture(name="test_cli") @@ -107,7 +106,9 @@ async def test_user_fixture(app): async def test_cli_user(test_cli, test_user): """Yield a TestClient instance that contains a randomly generated user.""" - yield TestClient(test_cli, test_user) + client = TestClient(test_cli, test_user) + yield client + await client.cleanup() @pytest.fixture @@ -138,5 +139,7 @@ async def test_cli_staff(test_cli): user_id, ) - yield TestClient(test_cli, test_user) + client = TestClient(test_cli, test_user) + yield client + await client.cleanup() await _user_fixture_teardown(test_cli.app, test_user) diff --git a/tests/test_admin_api/test_guilds.py b/tests/test_admin_api/test_guilds.py index 125193c..cdb4110 100644 --- a/tests/test_admin_api/test_guilds.py +++ b/tests/test_admin_api/test_guilds.py @@ -21,30 +21,13 @@ import secrets import pytest -from litecord.blueprints.guilds import delete_guild from litecord.errors import GuildNotFound -async def _create_guild(test_cli_staff, *, region=None) -> dict: - genned_name = secrets.token_hex(6) - - async with test_cli_staff.app.app_context(): - resp = await test_cli_staff.post( - "/api/v6/guilds", json={"name": genned_name, "region": region} - ) - - assert resp.status_code == 200 - rjson = await resp.json - assert isinstance(rjson, dict) - assert rjson["name"] == genned_name - - return rjson - - -async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): +async def _fetch_guild(test_cli_staff, guild_id: str, *, return_early: bool = False): resp = await test_cli_staff.get(f"/api/v6/admin/guilds/{guild_id}") - if ret_early: + if return_early: return resp assert resp.status_code == 200 @@ -55,73 +38,54 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): return rjson -async def _delete_guild(test_cli, guild_id: int): - async with test_cli.app.app_context(): - await delete_guild(int(guild_id)) - - @pytest.mark.asyncio async def test_guild_fetch(test_cli_staff): """Test the creation and fetching of a guild via the Admin API.""" - async with test_cli_staff.app.app_context(): - rjson = await _create_guild(test_cli_staff) - guild_id = rjson["id"] - try: - await _fetch_guild(test_cli_staff, guild_id) - finally: - await _delete_guild(test_cli_staff, int(guild_id)) + guild = await test_cli_staff.create_guild() + await _fetch_guild(test_cli_staff, str(guild.id)) @pytest.mark.asyncio async def test_guild_update(test_cli_staff): """Test the update of a guild via the Admin API.""" - async with test_cli_staff.app.app_context(): - rjson = await _create_guild(test_cli_staff) - guild_id = rjson["id"] - assert not rjson["unavailable"] + guild = await test_cli_staff.create_guild() + guild_id = str(guild.id) - try: - # I believe setting up an entire gateway client registered to the guild - # would be overkill to test the side-effects, so... I'm not - # testing them. Yes, I know its a bad idea, but if someone has an easier - # way to write that, do send an MR. - resp = await test_cli_staff.patch( - f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True} - ) + # I believe setting up an entire gateway client registered to the guild + # would be overkill to test the side-effects, so... I'm not + # testing them. Yes, I know its a bad idea, but if someone has an easier + # way to write that, do send an MR. + resp = await test_cli_staff.patch( + f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True} + ) - assert resp.status_code == 200 - rjson = await resp.json - assert isinstance(rjson, dict) - assert rjson["id"] == guild_id - assert rjson["unavailable"] + assert resp.status_code == 200 + rjson = await resp.json + assert isinstance(rjson, dict) + assert rjson["id"] == guild_id + assert rjson["unavailable"] - rjson = await _fetch_guild(test_cli_staff, guild_id) - assert rjson["unavailable"] - finally: - await _delete_guild(test_cli_staff, int(guild_id)) + rjson = await _fetch_guild(test_cli_staff, guild_id) + assert rjson["id"] == guild_id + assert rjson["unavailable"] @pytest.mark.asyncio async def test_guild_delete(test_cli_staff): """Test the update of a guild via the Admin API.""" - async with test_cli_staff.app.app_context(): - rjson = await _create_guild(test_cli_staff) - guild_id = rjson["id"] + guild = await test_cli_staff.create_guild() + guild_id = str(guild.id) - try: - resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}") + resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}") + assert resp.status_code == 204 - assert resp.status_code == 204 + resp = await _fetch_guild(test_cli_staff, guild_id, return_early=True) + assert resp.status_code == 404 - resp = await _fetch_guild(test_cli_staff, guild_id, ret_early=True) - - assert resp.status_code == 404 - rjson = await resp.json - assert isinstance(rjson, dict) - assert rjson["error"] - assert rjson["code"] == GuildNotFound.error_code - finally: - await _delete_guild(test_cli_staff, int(guild_id)) + rjson = await resp.json + assert isinstance(rjson, dict) + assert rjson["error"] + assert rjson["code"] == GuildNotFound.error_code @pytest.mark.asyncio @@ -132,17 +96,15 @@ async def test_guild_create_voice(test_cli_staff): "/api/v6/admin/voice/regions", json={"id": region_id, "name": region_name} ) assert resp.status_code == 200 - guild_id = None + rjson = await resp.json + assert isinstance(rjson, list) + assert region_id in [r["id"] for r in rjson] + # This test is basically creating the guild with a self-selected region + # then deleting the guild afterwards on test resource cleanup try: - rjson = await resp.json - assert isinstance(rjson, list) - assert region_id in [r["id"] for r in rjson] - guild_id = await _create_guild(test_cli_staff, region=region_id) + await test_cli_staff.create_guild(region=region_id) finally: - if guild_id: - await _delete_guild(test_cli_staff, int(guild_id["id"])) - await test_cli_staff.app.db.execute( """ DELETE FROM voice_regions diff --git a/tests/test_admin_api/test_users.py b/tests/test_admin_api/test_users.py index 8b2e738..814cd29 100644 --- a/tests/test_admin_api/test_users.py +++ b/tests/test_admin_api/test_users.py @@ -21,6 +21,7 @@ import secrets import pytest +from tests.common import email from litecord.enums import UserFlags @@ -41,6 +42,20 @@ async def test_list_users(test_cli_staff): assert rjson +@pytest.mark.asyncio +async def test_find_single_user(test_cli_staff): + user = await test_cli_staff.create_user( + username="test_user" + secrets.token_hex(2), email=email() + ) + resp = await _search(test_cli_staff, username=user.name) + + assert resp.status_code == 200 + rjson = await resp.json + assert isinstance(rjson, list) + fetched_user = rjson[0] + assert fetched_user["id"] == str(user.id) + + async def _setup_user(test_cli) -> dict: genned = secrets.token_hex(7) @@ -105,24 +120,17 @@ async def test_create_delete(test_cli_staff): @pytest.mark.asyncio async def test_user_update(test_cli_staff): """Test user update.""" - rjson = await _setup_user(test_cli_staff) + user = await test_cli_staff.create_user() - user_id = rjson["id"] + # set them as partner flag + resp = await test_cli_staff.patch( + f"/api/v6/admin/users/{user.id}", json={"flags": UserFlags.partner} + ) - # test update + assert resp.status_code == 200 + rjson = await resp.json + assert rjson["id"] == str(user.id) + assert rjson["flags"] == UserFlags.partner - try: - # set them as partner flag - resp = await test_cli_staff.patch( - f"/api/v6/admin/users/{user_id}", json={"flags": UserFlags.partner} - ) - - assert resp.status_code == 200 - rjson = await resp.json - assert rjson["id"] == user_id - assert rjson["flags"] == UserFlags.partner - - # TODO: maybe we can check for side effects by fetching the - # user manually too... - finally: - await _del_user(test_cli_staff, user_id) + refetched = await user.refetch() + assert refetched.flags == UserFlags.partner diff --git a/tests/test_channels.py b/tests/test_channels.py new file mode 100644 index 0000000..47d33b4 --- /dev/null +++ b/tests/test_channels.py @@ -0,0 +1,146 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +import pytest +from litecord.common.guilds import add_member + +pytestmark = pytest.mark.asyncio + + +async def test_channel_create(test_cli_user): + guild = await test_cli_user.create_guild() + + # guild test object teardown should destroy the channel as well! + resp = await test_cli_user.post( + f"/api/v6/guilds/{guild.id}/channels", + json={ + "name": "hello-world", + }, + ) + assert resp.status_code == 200 + rjson = await resp.json + channel_id: str = rjson["id"] + assert rjson["name"] == "hello-world" + + refetched_guild = await guild.refetch() + assert len(refetched_guild.channels) == 2 + assert channel_id in (channel["id"] for channel in refetched_guild.channels) + + resp = await test_cli_user.get(f"/api/v6/channels/{channel_id}") + assert resp.status_code == 200 + rjson = await resp.json + assert rjson["id"] == channel_id + + +async def test_channel_delete(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + + resp = await test_cli_user.delete( + f"/api/v6/channels/{channel.id}", + ) + assert resp.status_code == 200 + rjson = await resp.json + assert rjson["id"] == str(channel.id) + + +async def test_channel_message_send(test_cli_user): + guild = await test_cli_user.create_guild() + channel = guild.channels[0] + resp = await test_cli_user.post( + f'/api/v6/channels/{channel["id"]}/messages', + json={ + "content": "hello world", + }, + ) + assert resp.status_code == 200 + rjson = await resp.json + assert rjson["content"] == "hello world" + + +async def test_channel_message_send_on_new_channel(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + assert channel.guild_id == guild.id + + refetched_guild = await guild.refetch() + assert len(refetched_guild.channels) == 2 + + resp = await test_cli_user.post( + f"/api/v6/channels/{channel.id}/messages", + json={ + "content": "hello world", + }, + ) + assert resp.status_code == 200 + rjson = await resp.json + assert rjson["content"] == "hello world" + + +async def test_channel_message_delete(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + message = await test_cli_user.create_message( + guild_id=guild.id, channel_id=channel.id + ) + + resp = await test_cli_user.delete( + f"/api/v6/channels/{channel.id}/messages/{message.id}", + ) + assert resp.status_code == 204 + + assert (await message.refetch()) is None + + +async def test_channel_message_delete_different_author(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + user = await test_cli_user.create_user() + async with test_cli_user.app.app_context(): + await add_member(guild.id, user.id) + + message = await test_cli_user.create_message( + guild_id=guild.id, channel_id=channel.id, author_id=user.id + ) + + resp = await test_cli_user.delete( + f"/api/v6/channels/{channel.id}/messages/{message.id}", + headers={"authorization": user.token}, + ) + assert resp.status_code == 204 + + +async def test_channel_message_bulk_delete(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + messages = [] + for _ in range(10): + messages.append( + await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) + ) + + resp = await test_cli_user.post( + f"/api/v6/channels/{channel.id}/messages/bulk-delete", + json={"messages": [message.id for message in messages]}, + ) + assert resp.status_code == 204 + + # assert everyone cant be refetched + for message in messages: + assert (await message.refetch()) is None diff --git a/tests/test_messages.py b/tests/test_messages.py new file mode 100644 index 0000000..a92b58d --- /dev/null +++ b/tests/test_messages.py @@ -0,0 +1,118 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +import pytest + +pytestmark = pytest.mark.asyncio + + +async def test_message_listing(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + messages = [] + for _ in range(10): + messages.append( + await test_cli_user.create_message(guild_id=guild.id, channel_id=channel.id) + ) + + # assert all messages we just created can be refetched if we give the + # middle message to the 'around' parameter + middle_message_id = messages[5].id + + resp = await test_cli_user.get( + f"/api/v6/channels/{channel.id}/messages", + query_string={"around": middle_message_id}, + ) + assert resp.status_code == 200 + rjson = await resp.json + + fetched_ids = [m["id"] for m in rjson] + for message in messages: + assert str(message.id) in fetched_ids + + # assert all messages are below given id if its on 'before' param + + resp = await test_cli_user.get( + f"/api/v6/channels/{channel.id}/messages", + query_string={"before": middle_message_id}, + ) + assert resp.status_code == 200 + rjson = await resp.json + + for message_json in rjson: + assert int(message_json["id"]) <= middle_message_id + + # assert all message are above given id if its on 'after' param + resp = await test_cli_user.get( + f"/api/v6/channels/{channel.id}/messages", + query_string={"after": middle_message_id}, + ) + assert resp.status_code == 200 + rjson = await resp.json + + for message_json in rjson: + assert int(message_json["id"]) >= middle_message_id + + +async def test_message_update(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + message = await test_cli_user.create_message( + guild_id=guild.id, channel_id=channel.id + ) + + resp = await test_cli_user.patch( + f"/api/v6/channels/{channel.id}/messages/{message.id}", + json={"content": "awooga"}, + ) + assert resp.status_code == 200 + rjson = await resp.json + + assert rjson["id"] == str(message.id) + assert rjson["content"] == "awooga" + + refetched = await message.refetch() + assert refetched.content == "awooga" + + +async def test_message_pinning(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + message = await test_cli_user.create_message( + guild_id=guild.id, channel_id=channel.id + ) + + resp = await test_cli_user.put(f"/api/v6/channels/{channel.id}/pins/{message.id}") + assert resp.status_code == 204 + + resp = await test_cli_user.get(f"/api/v6/channels/{channel.id}/pins") + assert resp.status_code == 200 + rjson = await resp.json + assert len(rjson) == 1 + assert rjson[0]["id"] == str(message.id) + + resp = await test_cli_user.delete( + f"/api/v6/channels/{channel.id}/pins/{message.id}" + ) + assert resp.status_code == 204 + + resp = await test_cli_user.get(f"/api/v6/channels/{channel.id}/pins") + assert resp.status_code == 200 + rjson = await resp.json + assert len(rjson) == 0 diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py new file mode 100644 index 0000000..e73c615 --- /dev/null +++ b/tests/test_webhooks.py @@ -0,0 +1,52 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +import pytest + +pytestmark = pytest.mark.asyncio + + +async def test_webhook_flow(test_cli_user): + guild = await test_cli_user.create_guild() + channel = await test_cli_user.create_guild_channel(guild_id=guild.id) + + resp = await test_cli_user.post( + f"/api/v6/channels/{channel.id}/webhooks", json={"name": "awooga"} + ) + assert resp.status_code == 200 + rjson = await resp.json + assert rjson["channel_id"] == str(channel.id) + assert rjson["guild_id"] == str(guild.id) + assert rjson["name"] == "awooga" + + webhook_id = rjson["id"] + webhook_token = rjson["token"] + + resp = await test_cli_user.post( + f"/api/v6/webhooks/{webhook_id}/{webhook_token}", + json={"content": "test_message"}, + headers={"authorization": ""}, + ) + assert resp.status_code == 204 + + refetched_channel = await channel.refetch() + message = await test_cli_user.app.storage.get_message( + refetched_channel.last_message_id + ) + assert message["author"]["id"] == webhook_id