diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 263e347..44bfdf8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -20,7 +20,8 @@ tests: script: - ls - cp config.ci.py config.py + - pipenv --venv - pipenv run ./manage.py migrate - pipenv run black --check litecord run.py tests manage - - pipenv run pyflakes run.py litecord/ + - pipenv run flake8 litecord run.py tests manage - pipenv run pytest tests diff --git a/litecord/gateway/schemas.py b/litecord/gateway/schemas.py new file mode 100644 index 0000000..4928a37 --- /dev/null +++ b/litecord/gateway/schemas.py @@ -0,0 +1,184 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +from typing import Dict + +from logbook import Logger + + +from litecord.gateway.errors import DecodeError +from litecord.schemas import LitecordValidator + +log = Logger(__name__) + + +def validate( + reqjson: Dict, + schema: Dict, +) -> Dict: + validator = LitecordValidator(schema) + + try: + valid = validator.validate(reqjson) + except Exception: + log.exception("Error while validating") + raise DecodeError(f"Error while validating: {reqjson}") + + if not valid: + errs = validator.errors + log.warning("Error validating doc {!r}: {!r}", reqjson, errs) + raise DecodeError(f"Error validating message : {errs!r}") + + return validator.document + + +BASE = { + "op": {"type": "number", "required": True}, + "s": {"type": "number", "required": False}, +} + +IDENTIFY_SCHEMA = { + **BASE, + **{ + "d": { + "type": "dict", + "schema": { + "token": {"type": "string", "required": True}, + "compress": {"type": "boolean", "required": False}, + "large_threshold": {"type": "number", "required": False}, + "shard": {"type": "list", "required": False}, + "presence": {"type": "dict", "required": False}, + }, + } + }, +} + +RESUME_SCHEMA = { + **BASE, + **{ + "d": { + "type": "dict", + "schema": { + "token": {"type": "string", "required": True}, + "session_id": {"type": "string", "required": True}, + "seq": {"type": "number", "required": True}, + }, + } + }, +} + +REQ_GUILD_SCHEMA = { + **BASE, + **{ + "d": { + "type": "dict", + "schema": { + "guild_id": {"type": "string", "required": True}, + "user_ids": {"type": "list", "required": False}, + "query": {"type": "string", "required": False}, + "limit": {"type": "number", "required": False}, + "presences": {"type": "bool", "required": False}, + }, + } + }, +} + +GUILD_SYNC_SCHEMA = { + **BASE, + **{ + "d": { + "type": "list", + "schema": {"type": "snowflake"}, + } + }, +} + + +GW_ACTIVITY = { + "name": {"type": "string", "required": True}, + "type": {"type": "activity_type", "required": True}, + "url": {"type": "string", "required": False, "nullable": True}, + "timestamps": { + "type": "dict", + "required": False, + "schema": { + "start": {"type": "number", "required": False}, + "end": {"type": "number", "required": False}, + }, + }, + "application_id": {"type": "snowflake", "required": False, "nullable": False}, + "details": {"type": "string", "required": False, "nullable": True}, + "state": {"type": "string", "required": False, "nullable": True}, + "party": { + "type": "dict", + "required": False, + "schema": { + "id": {"type": "snowflake", "required": False}, + "size": {"type": "list", "required": False}, + }, + }, + "assets": { + "type": "dict", + "required": False, + "schema": { + "large_image": {"type": "snowflake", "required": False}, + "large_text": {"type": "string", "required": False}, + "small_image": {"type": "snowflake", "required": False}, + "small_text": {"type": "string", "required": False}, + }, + }, + "secrets": { + "type": "dict", + "required": False, + "schema": { + "join": {"type": "string", "required": False}, + "spectate": {"type": "string", "required": False}, + "match": {"type": "string", "required": False}, + }, + }, + "instance": {"type": "boolean", "required": False}, + "flags": {"type": "number", "required": False}, + "emoji": { + "type": "dict", + "required": False, + "nullable": True, + "schema": { + "animated": {"type": "boolean", "required": False, "default": False}, + "id": {"coerce": int, "nullable": True, "default": None}, + "name": {"type": "string", "required": True}, + }, + }, +} + +GW_STATUS_UPDATE = { + "status": {"type": "status_external", "required": False, "default": "online"}, + "activities": { + "type": "list", + "required": False, + "schema": {"type": "dict", "schema": GW_ACTIVITY}, + }, + "afk": {"type": "boolean", "required": False}, + "since": {"type": "number", "required": False, "nullable": True}, + "game": { + "type": "dict", + "required": False, + "nullable": True, + "schema": GW_ACTIVITY, + }, +} diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 681cc56..8e58c28 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -32,7 +32,6 @@ from quart import current_app as app from litecord.auth import raw_token_check from litecord.enums import RelationshipType, ChannelType, ActivityType -from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.utils import ( task_wrapper, yield_chunks, @@ -59,6 +58,14 @@ from litecord.gateway.encoding import encode_json, decode_json, encode_etf, deco from litecord.gateway.utils import WebsocketFileHandler from litecord.pubsub.guild import GuildFlags from litecord.pubsub.channel import ChannelFlags +from litecord.gateway.schemas import ( + validate, + IDENTIFY_SCHEMA, + GW_STATUS_UPDATE, + RESUME_SCHEMA, + REQ_GUILD_SCHEMA, + GUILD_SYNC_SCHEMA, +) from litecord.storage import int_ @@ -651,13 +658,9 @@ class GatewayWebsocket: async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" - try: - data = payload["d"] - token = data["token"] - except KeyError: - raise DecodeError("Invalid identify parameters") - - # TODO proper validation of this payload + payload = validate(payload, IDENTIFY_SCHEMA) + data = payload["d"] + token = data["token"] compress = data.get("compress", False) large = data.get("large_threshold", 50) @@ -840,12 +843,9 @@ class GatewayWebsocket: async def handle_6(self, payload: Dict[str, Any]): """Handle OP 6 Resume.""" + payload = validate(payload, RESUME_SCHEMA) data = payload["d"] - - try: - token, sess_id, seq = data["token"], data["session_id"], data["seq"] - except KeyError: - raise DecodeError("Invalid resume payload") + token, sess_id, seq = data["token"], data["session_id"], data["seq"] try: user_id = await raw_token_check(token, self.app.db) @@ -915,6 +915,7 @@ class GatewayWebsocket: async def handle_8(self, payload: Dict): """Handle OP 8 Request Guild Members.""" + payload = validate(payload, REQ_GUILD_SCHEMA) data = payload["d"] gids = data["guild_id"] @@ -952,8 +953,8 @@ class GatewayWebsocket: async def handle_12(self, payload: Dict[str, Any]): """Handle OP 12 Guild Sync.""" + payload = validate(payload, GUILD_SYNC_SCHEMA) data = payload["d"] - gids = await self.user_storage.get_user_guilds(self.state.user_id) for guild_id in data: diff --git a/litecord/schemas.py b/litecord/schemas.py index cbb4781..3ba1f4c 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -418,79 +418,6 @@ MESSAGE_CREATE = { } -GW_ACTIVITY = { - "name": {"type": "string", "required": True}, - "type": {"type": "activity_type", "required": True}, - "url": {"type": "string", "required": False, "nullable": True}, - "timestamps": { - "type": "dict", - "required": False, - "schema": { - "start": {"type": "number", "required": False}, - "end": {"type": "number", "required": False}, - }, - }, - "application_id": {"type": "snowflake", "required": False, "nullable": False}, - "details": {"type": "string", "required": False, "nullable": True}, - "state": {"type": "string", "required": False, "nullable": True}, - "party": { - "type": "dict", - "required": False, - "schema": { - "id": {"type": "snowflake", "required": False}, - "size": {"type": "list", "required": False}, - }, - }, - "assets": { - "type": "dict", - "required": False, - "schema": { - "large_image": {"type": "snowflake", "required": False}, - "large_text": {"type": "string", "required": False}, - "small_image": {"type": "snowflake", "required": False}, - "small_text": {"type": "string", "required": False}, - }, - }, - "secrets": { - "type": "dict", - "required": False, - "schema": { - "join": {"type": "string", "required": False}, - "spectate": {"type": "string", "required": False}, - "match": {"type": "string", "required": False}, - }, - }, - "instance": {"type": "boolean", "required": False}, - "flags": {"type": "number", "required": False}, - "emoji": { - "type": "dict", - "required": False, - "nullable": True, - "schema": { - "animated": {"type": "boolean", "required": False, "default": False}, - "id": {"coerce": int, "nullable": True, "default": None}, - "name": {"type": "string", "required": True}, - }, - }, -} - -GW_STATUS_UPDATE = { - "status": {"type": "status_external", "required": False, "default": "online"}, - "activities": { - "type": "list", - "required": False, - "schema": {"type": "dict", "schema": GW_ACTIVITY}, - }, - "afk": {"type": "boolean", "required": False}, - "since": {"type": "number", "required": False, "nullable": True}, - "game": { - "type": "dict", - "required": False, - "nullable": True, - "schema": GW_ACTIVITY, - }, -} - INVITE = { # max_age in seconds # 0 for infinite diff --git a/tests/test_admin_api/test_guilds.py b/tests/test_admin_api/test_guilds.py index 3ff45fc..4d628ca 100644 --- a/tests/test_admin_api/test_guilds.py +++ b/tests/test_admin_api/test_guilds.py @@ -28,9 +28,10 @@ from litecord.errors import GuildNotFound async def _create_guild(test_cli_staff, *, region=None) -> dict: genned_name = secrets.token_hex(6) - resp = await test_cli_staff.post( - "/api/v6/guilds", json={"name": genned_name, "region": region} - ) + async with test_cli_staff.app.app_context(): + resp = await test_cli_staff.post( + "/api/v6/guilds", json={"name": genned_name, "region": region} + ) assert resp.status_code == 200 rjson = await resp.json @@ -62,13 +63,13 @@ async def _delete_guild(test_cli, guild_id: int): @pytest.mark.asyncio async def test_guild_fetch(test_cli_staff): """Test the creation and fetching of a guild via the Admin API.""" - rjson = await _create_guild(test_cli_staff) - guild_id = rjson["id"] - - try: - await _fetch_guild(test_cli_staff, guild_id) - finally: - await _delete_guild(test_cli_staff, int(guild_id)) + async with test_cli_staff.app.app_context(): + rjson = await _create_guild(test_cli_staff) + guild_id = rjson["id"] + try: + await _fetch_guild(test_cli_staff, guild_id) + finally: + await _delete_guild(test_cli_staff, int(guild_id)) @pytest.mark.asyncio diff --git a/tests/test_websocket.py b/tests/test_websocket.py index ed49838..a197cc2 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -103,6 +103,25 @@ async def test_ready(test_cli_user): await _close(conn) +@pytest.mark.asyncio +async def test_broken_identify(test_cli_user): + conn = await gw_start(test_cli_user.cli) + + # get the hello frame but ignore it + await _json(conn) + + await _json_send(conn, {"op": OP.IDENTIFY, "d": {"token": True}}) + + # try to get a ready + try: + await _json(conn) + raise AssertionError("Received a JSON message but expected close") + except websockets.ConnectionClosed as exc: + assert exc.code == 4002 + finally: + await _close(conn) + + @pytest.mark.asyncio async def test_ready_fields(test_cli_user): conn = await gw_start(test_cli_user.cli)