diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 930e60a..45e7d6c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: python:3.7-alpine +image: python:3.8-alpine variables: PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" diff --git a/Pipfile b/Pipfile index 95f5117..d2c026e 100644 --- a/Pipfile +++ b/Pipfile @@ -26,7 +26,7 @@ flake8 = "*" pyflakes = "*" [requires] -python_version = "3.7" +python_version = "3.8" [pipenv] allow_prereleases = true diff --git a/Pipfile.lock b/Pipfile.lock index 256740a..6230024 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,11 +1,11 @@ { "_meta": { "hash": { - "sha256": "bdd693f88f93f57b4a6eca42e0c8f053e500ddc40e344d888af86143c1904751" + "sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3" }, "pipfile-spec": 6, "requires": { - "python_version": "3.7" + "python_version": "3.8" }, "sources": [ { @@ -238,10 +238,10 @@ }, "jinja2": { "hashes": [ - "sha256:93187ffbc7808079673ef52771baa950426fd664d3aad1d0fa3e95644360e250", - "sha256:b0eaf100007721b5c16c1fc1eecb87409464edc10469ddc9a22a27a99123be49" + "sha256:c10142f819c2d22bdcd17548c46fa9b77cf4fda45097854c689666bf425e7484", + "sha256:c922560ac46888d47384de1dbdc3daaa2ea993af4b26a436dec31fa2c19ec668" ], - "version": "==2.11.1" + "version": "==3.0.0a1" }, "logbook": { "hashes": [ @@ -554,14 +554,6 @@ "index": "pypi", "version": "==3.7.9" }, - "importlib-metadata": { - "hashes": [ - "sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302", - "sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b" - ], - "markers": "python_version < '3.8'", - "version": "==1.5.0" - }, "mccabe": { "hashes": [ "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", @@ -749,13 +741,6 @@ "sha256:f28b3e8a6483e5d49e7f8949ac1a78314e740333ae305b4ba5defd3e74fb37a8" ], "version": "==0.1.8" - }, - "zipp": { - "hashes": [ - "sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28", - "sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98" - ], - "version": "==2.1.0" } } } diff --git a/README.md b/README.md index 49a0f57..c6bb2dc 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ or third party libraries (such as [Eris](https://github.com/abalabahaha/eris)). Requirements: -- **Python 3.7+** +- **Python 3.8+** - PostgreSQL (tested using 9.6+), SQL knowledge is recommended. - gifsicle for GIF emoji and avatar handling - [pipenv] diff --git a/litecord/blueprints/admin_api/instance_invites.py b/litecord/blueprints/admin_api/instance_invites.py index c410c47..50bfee4 100644 --- a/litecord/blueprints/admin_api/instance_invites.py +++ b/litecord/blueprints/admin_api/instance_invites.py @@ -19,6 +19,7 @@ along with this program. If not, see . import string from random import choice +from typing import Optional from quart import Blueprint, jsonify, current_app as app, request @@ -36,7 +37,7 @@ async def _gen_inv() -> str: return "".join(choice(ALPHABET) for _ in range(6)) -async def gen_inv(ctx) -> str: +async def gen_inv(ctx) -> Optional[str]: """Generate an invite.""" for _ in range(10): possible_inv = await _gen_inv() diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 4a92811..4a20706 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ from pathlib import Path +from typing import Optional from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger @@ -146,7 +147,7 @@ async def _dm_pre_dispatch(channel_id, peer_id): async def create_message( - channel_id: int, actual_guild_id: int, author_id: int, data: dict + channel_id: int, actual_guild_id: Optional[int], author_id: int, data: dict ) -> int: message_id = get_snowflake() @@ -159,7 +160,7 @@ async def create_message( content, tts, mention_everyone, nonce, message_type, embeds) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - """, + """, message_id, channel_id, actual_guild_id, @@ -186,7 +187,7 @@ async def _create_message(channel_id): user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - actual_guild_id = None + actual_guild_id: Optional[int] = None if ctype in GUILD_CHANS: await channel_perm_check(user_id, channel_id, "send_messages") diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py index 9ba3c88..9153ffc 100644 --- a/litecord/blueprints/channel/pins.py +++ b/litecord/blueprints/channel/pins.py @@ -114,7 +114,7 @@ async def add_pin(channel_id, message_id): ) await send_sys_message( - app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id + channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id ) return "", 204 diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index 9c1627f..3531b0e 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -18,7 +18,7 @@ along with this program. If not, see . """ from enum import IntEnum -from typing import List, Union, Tuple +from typing import List, Union, Tuple, TypedDict, Optional from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger @@ -60,7 +60,12 @@ def emoji_info_from_str(emoji: str) -> tuple: return emoji_type, emoji_id, emoji_name -def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: +class PartialEmoji(TypedDict): + id: Optional[int] + name: str + + +def partial_emoji(emoji_type, emoji_id, emoji_name) -> PartialEmoji: print(emoji_type, emoji_id, emoji_name) return { "id": None if emoji_type == EmojiType.UNICODE else emoji_id, diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2ede1cc..386f333 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -386,14 +386,14 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int): return "", 204 -async def _update_channel_common(channel_id, guild_id: int, j: dict): +async def _update_channel_common(channel_id: int, guild_id: int, j: dict): if "name" in j: await app.db.execute( """ - UPDATE guild_channels - SET name = $1 - WHERE id = $2 - """, + UPDATE guild_channels + SET name = $1 + WHERE id = $2 + """, j["name"], channel_id, ) @@ -401,7 +401,9 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict): if "position" in j: channel_data = await app.storage.get_channel_data(guild_id) - chans = [None] * len(channel_data) + # get an ordered list of the chans array by position + # TODO bad impl. can break easily. maybe dict? + chans: List[Optional[int]] = [None] * len(channel_data) for chandata in channel_data: chans.insert(chandata["position"], int(chandata["id"])) @@ -422,7 +424,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict): left_shift = new_pos > current_pos # find all channels that we'll have to shift - shift_block = ( + shift_block: List[Optional[int]] = ( chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos] ) @@ -509,9 +511,7 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int): channel_id, ) - await send_sys_message( - app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id - ) + await send_sys_message(channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id) if "icon" in j: new_icon = await app.icons.update( @@ -528,13 +528,11 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int): channel_id, ) - await send_sys_message( - app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id - ) + await send_sys_message(channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id) @bp.route("/", methods=["PUT", "PATCH"]) -async def update_channel(channel_id): +async def update_channel(channel_id: int): """Update a channel's information""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 5141f5b..52ab5bc 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import Union, List +from typing import Union, List, Optional from quart import current_app as app @@ -66,7 +66,7 @@ async def guild_owner_check(user_id: int, guild_id: int): async def channel_check( - user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None + user_id, channel_id, *, only: Optional[Union[ChannelType, List[ChannelType]]] = None ): """Check if the current user is authorized to read the channel's information.""" @@ -77,10 +77,10 @@ async def channel_check( ctype = ChannelType(chan_type) - if only and not isinstance(only, list): + if (only is not None) and not isinstance(only, list): only = [only] - if only and ctype not in only: + if (only is not None) and ctype not in only: raise ChannelNotFound("invalid channel type") if ctype in GUILD_CHANS: diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py index db37229..4530d68 100644 --- a/litecord/blueprints/dm_channels.py +++ b/litecord/blueprints/dm_channels.py @@ -113,9 +113,7 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None): await app.dispatcher.sub("channel", peer_id) if user_id: - await send_sys_message( - app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id - ) + await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id) async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None): @@ -145,9 +143,7 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None): author_id = peer_id if user_id is None else user_id - await send_sys_message( - app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id - ) + await send_sys_message(channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id) async def gdm_destroy(channel_id): diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index 89370ae..4f00e10 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -130,14 +130,11 @@ async def modify_channel_pos(guild_id): # the same schema and all. raw_j = await request.get_json() j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION) - j = j["roles"] + roles = j["roles"] channels = await app.storage.get_channel_data(guild_id) channel_positions = {chan["position"]: int(chan["id"]) for chan in channels} - - swap_pairs = gen_pairs(j, channel_positions) - + swap_pairs = gen_pairs(roles, channel_positions) await _do_channel_swaps(guild_id, swap_pairs) - return "", 204 diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index 98440fa..c790fcb 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import List, Dict, Tuple +from typing import List, Dict, Tuple, Optional from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger @@ -122,7 +122,7 @@ PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]] def gen_pairs( list_of_changes: List[Dict[str, int]], current_state: Dict[int, int], - blacklist: List[int] = None, + blacklist: Optional[List[int]] = None, ) -> PairList: """Generate a list of pairs that, when applied to the database, will generate the desired state given in list_of_changes. @@ -162,7 +162,7 @@ def gen_pairs( List of swaps to do to achieve the preferred state given by ``list_of_changes``. """ - pairs = [] + pairs: PairList = [] blacklist = blacklist or [] preferred_state = { @@ -222,9 +222,9 @@ async def update_guild_role_positions(guild_id): j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION) # extract the list out - j = j["roles"] + roles = j["roles"] - log.debug("role stuff: {!r}", j) + log.debug("role stuff: {!r}", roles) all_roles = await app.storage.get_role_data(guild_id) @@ -238,7 +238,7 @@ async def update_guild_role_positions(guild_id): # NOTE: ^ this is related to the positioning of the roles. pairs = gen_pairs( - j, + roles, roles_pos, # always ignore people trying to change # the @everyone's role position diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 1c2f2db..5d69ffb 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -96,7 +96,7 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]: return f"data:image/jpeg;base64,{icon}" if icon else None -async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs): +async def _general_guild_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs): encoded = sanitize_icon(icon) icon_kwargs = {"always_icon": True} diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py index 9f5cd30..a495c39 100644 --- a/litecord/blueprints/webhooks.py +++ b/litecord/blueprints/webhooks.py @@ -222,14 +222,14 @@ async def get_guild_webhook(guild_id): async def get_single_webhook(webhook_id): """Get a single webhook's information.""" await _webhook_check_fw(webhook_id) - return await jsonify(await get_webhook(webhook_id)) + return jsonify(await get_webhook(webhook_id)) @bp.route("/webhooks//", methods=["GET"]) async def get_tokened_webhook(webhook_id, webhook_token): """Get a webhook using its token.""" await webhook_token_check(webhook_id, webhook_token) - return await jsonify(await get_webhook(webhook_id, secure=False)) + return jsonify(await get_webhook(webhook_id, secure=False)) async def _update_webhook(webhook_id: int, j: dict): @@ -289,6 +289,7 @@ async def modify_webhook(webhook_id: int): await _update_webhook(webhook_id, j) webhook = await get_webhook(webhook_id) + assert webhook is not None # we don't need to cast channel_id to int since that isn't # used in the dispatcher call @@ -313,7 +314,9 @@ async def modify_webhook_tokened(webhook_id, webhook_token): async def delete_webhook(webhook_id: int): """Delete a webhook.""" webhook = await get_webhook(webhook_id) + assert webhook is not None + # TODO use returning? res = await app.db.execute( """ DELETE FROM webhooks @@ -423,7 +426,10 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str: # we still fetch the URL to check its validity, mimetypes, etc # but in the end, we will store it under the webhook_avatars table, # not IconManager. - resp, raw = await fetch_mediaproxy_img(avatar_url) + res = await fetch_mediaproxy_img(avatar_url) + if res is None: + raise BadRequest("Failed to fetch URL.") + resp, raw = res # raw_b64 = base64.b64encode(raw).decode() mime = resp.headers["content-type"] @@ -469,6 +475,7 @@ async def execute_webhook(webhook_id: int, webhook_token): given_embeds = j.get("embeds", []) webhook = await get_webhook(webhook_id) + assert webhook is not None # webhooks have TWO avatars. one is from settings, the other is from # the json's icon_url. one can be handled gracefully by IconManager, diff --git a/litecord/common/channels.py b/litecord/common/channels.py index dc85ef4..05e2aca 100644 --- a/litecord/common/channels.py +++ b/litecord/common/channels.py @@ -16,6 +16,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from typing import Optional from quart import current_app as app @@ -24,12 +25,15 @@ from litecord.enums import RelationshipType async def channel_ack( - user_id: int, guild_id: int, channel_id: int, message_id: int = None + user_id: int, guild_id: int, channel_id: int, message_id: Optional[int] = None ): """ACK a channel.""" - if not message_id: - message_id = await app.storage.chan_last_message(channel_id) + message_id = message_id or await app.storage.chan_last_message(channel_id) + + # never ack without a message, as that breaks read state. + if message_id is None: + return await app.db.execute( """ diff --git a/litecord/common/users.py b/litecord/common/users.py index 9d8794c..00614d2 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -18,7 +18,7 @@ along with this program. If not, see . """ from random import randint -from typing import Tuple, Optional +from typing import Tuple, Optional, List from quart import current_app as app from asyncpg import UniqueViolationError @@ -33,13 +33,13 @@ from ..utils import rand_hex log = Logger(__name__) -async def mass_user_update(user_id): +async def mass_user_update(user_id: int): """Dispatch USER_UPDATE in a mass way.""" # by using dispatch_with_filter # we're guaranteeing all shards will get # a USER_UPDATE once and not any others. - session_ids = [] + session_ids: List[str] = [] public_user = await app.storage.get_user(user_id) private_user = await app.storage.get_user(user_id, secure=True) diff --git a/litecord/embed/messages.py b/litecord/embed/messages.py index 7f60dc8..b61a0f1 100644 --- a/litecord/embed/messages.py +++ b/litecord/embed/messages.py @@ -21,7 +21,7 @@ import re import asyncio import urllib.parse from pathlib import Path -from typing import List +from typing import List, Optional from quart import current_app as app from logbook import Logger @@ -35,16 +35,16 @@ log = Logger(__name__) MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm") -async def fetch_mediaproxy_img_meta(url) -> dict: +async def fetch_mediaproxy_img_meta(url) -> Optional[dict]: """Insert media metadata as an embed.""" img_proxy_url = proxify(url) meta = await fetch_metadata(url) if meta is None: - return + return None if not meta["image"]: - return + return None return { "type": "image", @@ -135,13 +135,15 @@ async def process_url_embed(payload: dict, *, delay=0): # if it isn't, we forward an /embed/ scope call to mediaproxy # to generate an embed for us out of the url. - new_embeds = [] + new_embeds: List[dict] = [] - for url in urls: - url: List[dict] = EmbedURL(url) + for upstream_url in urls: + url = EmbedURL(upstream_url) if is_media_url(url): - embeds = [await fetch_mediaproxy_img_meta(url)] + embed = await fetch_mediaproxy_img_meta(url) + if embed is not None: + embeds = [embed] else: embeds = await fetch_mediaproxy_embed(url) diff --git a/litecord/images.py b/litecord/images.py index 345f2ef..69479af 100644 --- a/litecord/images.py +++ b/litecord/images.py @@ -98,6 +98,9 @@ class Icon: return get_ext(self.mime) + def __bool__(self): + return self.key and self.icon_hash and self.mime + class ImageError(Exception): """Image error class.""" @@ -197,6 +200,9 @@ def _gen_update_sql(scope: str) -> str: def _invalid(kwargs: dict) -> Optional[Icon]: """Send an invalid value.""" + # TODO: remove optinality off this (really badly designed): + # - also remove kwargs off this function + # - also make an Icon.empty() constructor, and remove need for this entirely if not kwargs.get("always_icon", False): return None @@ -519,6 +525,7 @@ class IconManager: key = str(key) old_icon = await self.generic_get(scope, key, old_icon_hash) - await self.delete(old_icon) + if old_icon: + await self.delete(old_icon) return await self.put(scope, key, new_icon_data, **kwargs) diff --git a/litecord/permissions.py b/litecord/permissions.py index 11656a1..5715433 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -88,6 +88,7 @@ class Permissions(ctypes.Union): ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) +EMPTY_PERMISSIONS = Permissions(0) async def get_role_perms(guild_id, role_id, storage=None) -> Permissions: diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index fe3c215..cc0404f 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -89,7 +89,7 @@ class ChannelDispatcher(DispatcherWithFlags): ): continue - cur_sess = [] + cur_sess: List[str] = [] if ( event in ("CHANNEL_CREATE", "CHANNEL_UPDATE") diff --git a/litecord/pubsub/friend.py b/litecord/pubsub/friend.py index 0d0ae6b..4bf727a 100644 --- a/litecord/pubsub/friend.py +++ b/litecord/pubsub/friend.py @@ -17,6 +17,7 @@ along with this program. If not, see . """ +from typing import List from logbook import Logger from .dispatcher import DispatcherWithState @@ -38,7 +39,7 @@ class FriendDispatcher(DispatcherWithState): async def dispatch_filter(self, user_id: int, func, event, data): """Dispatch an event to all of a users' friends.""" peer_ids = self.state[user_id] - sessions = [] + sessions: List[str] = [] for peer_id in peer_ids: # dispatch to the user instead of the "shards tied to a guild" diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index b391186..235b69f 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -39,6 +39,7 @@ from litecord.permissions import ( overwrite_find_mix, get_permissions, role_permissions, + EMPTY_PERMISSIONS, ) from litecord.utils import index_by_func from litecord.utils import mmh3 @@ -264,8 +265,8 @@ class GuildMemberList: self.list = MemberList() #: store the states that are subscribed to the list. - # type is {session_id: set[list]} - self.state: Dict[str, Set[List[int, int]]] = defaultdict(set) + # type is {session_id: set[tuple]} + self.state: Dict[str, Set[Tuple[int, int]]] = defaultdict(set) self._list_lock = asyncio.Lock() @@ -414,8 +415,8 @@ class GuildMemberList: # inject default groups 'online' and 'offline' # their position is always going to be the last ones. self.list.groups = role_groups + [ - GroupInfo("online", "online", MAX_ROLES + 1, 0), - GroupInfo("offline", "offline", MAX_ROLES + 2, 0), + GroupInfo("online", "online", MAX_ROLES + 1, EMPTY_PERMISSIONS), + GroupInfo("offline", "offline", MAX_ROLES + 2, EMPTY_PERMISSIONS), ] async def _get_group_for_member( @@ -808,6 +809,8 @@ class GuildMemberList: ops = [] old_user_index = self._get_item_index(user_id) + assert old_user_index is not None + old_group_index = self._get_group_item_index(old_group) ops.append(Operation("DELETE", {"index": old_user_index})) @@ -819,6 +822,7 @@ class GuildMemberList: await self._sort_groups() new_user_index = self._get_item_index(user_id) + assert new_user_index is not None ops.append( Operation( @@ -931,6 +935,7 @@ class GuildMemberList: # if unknown state, remove from the subscriber list if state is None: self.state.pop(session_id) + continue # if we aren't talking about the state the user # being removed is subscribed to, ignore @@ -1365,8 +1370,8 @@ class GuildMemberList: len(self.state), ) - self.guild_id = None - self.channel_id = None + self.guild_id = 0 + self.channel_id = 0 self.main = None self._set_empty_list() self.state = {} @@ -1392,7 +1397,7 @@ class LazyGuildDispatcher(Dispatcher): #: store which guilds have their # respective GMLs # {guild_id: [chan_id, ...], ...} - self.guild_map = defaultdict(list) + self.guild_map: Dict[int, List[int]] = defaultdict(list) async def get_gml(self, channel_id: int): """Get a guild list for a channel ID, @@ -1414,7 +1419,18 @@ class LazyGuildDispatcher(Dispatcher): def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]: """Get all member lists for a given guild.""" - return list(map(self.state.get, self.guild_map[guild_id])) + res: List[GuildMemberList] = [] + + channel_ids: List[int] = self.guild_map[guild_id] + for channel_id in channel_ids: + guild_list: Optional[GuildMemberList] = self.state.get(channel_id) + if guild_list is None: + self.guild_map[guild_id].remove(channel_id) + continue + + res.append(guild_list) + + return res async def unsub(self, chan_id, session_id): """Unsubscribe a session from the list.""" diff --git a/litecord/storage.py b/litecord/storage.py index 4532376..b01bd04 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -17,13 +17,18 @@ along with this program. If not, see . """ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Union, TypedDict from logbook import Logger from litecord.enums import ChannelType from litecord.schemas import USER_MENTION, ROLE_MENTION -from litecord.blueprints.channel.reactions import EmojiType, emoji_sql, partial_emoji +from litecord.blueprints.channel.reactions import ( + EmojiType, + emoji_sql, + partial_emoji, + PartialEmoji, +) from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE @@ -64,6 +69,12 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: str): return list(filter(lambda recipient: recipient["id"] != user_id, recipients)) +class EmojiStats(TypedDict): + count: int + me: bool + emoji: PartialEmoji + + class Storage: """Class for common SQL statements.""" @@ -373,7 +384,7 @@ class Storage: members = await self.get_member_multi(guild_id, mids) return members - async def chan_last_message(self, channel_id: int): + async def chan_last_message(self, channel_id: int) -> Optional[int]: """Get the last message ID in a channel.""" return await self.db.fetchval( """ @@ -491,7 +502,7 @@ class Storage: return [r["member_id"] for r in user_ids] async def _gdm_recipients( - self, channel_id: int, reference_id: int = None + self, channel_id: int, reference_id: Optional[int] = None ) -> List[Dict]: """Get the list of users that are recipients of the given Group DM.""" @@ -576,11 +587,12 @@ class Storage: drow = dict(gdm_row) drow["type"] = chan_type - drow["recipients"] = await self._gdm_recipients( - channel_id, kwargs.get("user_id") - ) - drow["last_message_id"] = await self.chan_last_message_str(channel_id) + user_id: Optional[int] = kwargs.get("user_id") + assert user_id is not None + drow["recipients"] = await self._gdm_recipients(channel_id, user_id) + + drow["last_message_id"] = await self.chan_last_message_str(channel_id) return drow return None @@ -634,7 +646,7 @@ class Storage: return channels async def get_role( - self, role_id: int, guild_id: int = None + self, role_id: int, guild_id: Optional[int] = None ) -> Optional[Dict[str, Any]]: """get a single role's information.""" @@ -732,6 +744,8 @@ class Storage: mids = [int(m["user"]["id"]) for m in members] + assert self.presence is not None + return { **res, **{ @@ -822,10 +836,10 @@ class Storage: ) # ordered list of emoji - emoji = [] + emoji: List[Union[int, str]] = [] # the current state of emoji info - react_stats = {} + react_stats: Dict[Union[str, int], EmojiStats] = {} # to generate the list, we pass through all # all reactions and insert them all. @@ -1007,6 +1021,7 @@ class Storage: # calculate user mentions and role mentions by regex async def _get_member(user_id): user = await self.get_user(user_id) + assert user is not None member = None if guild_id: @@ -1143,6 +1158,7 @@ class Storage: return {} mids = await self.get_member_ids(guild_id) + assert self.presence is not None pres = await self.presence.guild_presences(mids, guild_id) online_count = sum(1 for p in pres if p["status"] == "online") @@ -1172,7 +1188,7 @@ class Storage: return dinv - async def get_dm(self, dm_id: int, user_id: int = None) -> Optional[Dict]: + async def get_dm(self, dm_id: int, user_id: Optional[int] = None) -> Optional[Dict]: """Get a DM channel.""" dm_chan = await self.get_channel(dm_id) diff --git a/litecord/system_messages.py b/litecord/system_messages.py index c7da62f..2d548db 100644 --- a/litecord/system_messages.py +++ b/litecord/system_messages.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ from logbook import Logger +from quart import current_app as app from winter import get_snowflake from litecord.enums import MessageType @@ -25,7 +26,7 @@ from litecord.enums import MessageType log = Logger(__name__) -async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): +async def _handle_pin_msg(channel_id, _pinned_id, author_id): """Handle a message pin.""" new_id = get_snowflake() @@ -48,7 +49,7 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): # TODO: decrease repetition between add and remove handlers -async def _handle_recp_add(app, channel_id, author_id, peer_id): +async def _handle_recp_add(channel_id, author_id, peer_id): new_id = get_snowflake() await app.db.execute( @@ -69,7 +70,7 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id): return new_id -async def _handle_recp_rmv(app, channel_id, author_id, peer_id): +async def _handle_recp_rmv(channel_id, author_id, peer_id): new_id = get_snowflake() await app.db.execute( @@ -90,7 +91,7 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id): return new_id -async def _handle_gdm_name_edit(app, channel_id, author_id): +async def _handle_gdm_name_edit(channel_id, author_id): new_id = get_snowflake() gdm_name = await app.db.fetchval( @@ -123,7 +124,7 @@ async def _handle_gdm_name_edit(app, channel_id, author_id): return new_id -async def _handle_gdm_icon_edit(app, channel_id, author_id): +async def _handle_gdm_icon_edit(channel_id, author_id): new_id = get_snowflake() await app.db.execute( @@ -145,7 +146,7 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id): async def send_sys_message( - app, channel_id: int, m_type: MessageType, *args, **kwargs + channel_id: int, m_type: MessageType, *args, **kwargs ) -> int: """Send a system message. @@ -179,7 +180,7 @@ async def send_sys_message( except KeyError: raise ValueError("Invalid system message type") - message_id = await handler(app, channel_id, *args, **kwargs) + message_id = await handler(channel_id, *args, **kwargs) message = await app.storage.get_message(message_id) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..468186c --- /dev/null +++ b/mypy.ini @@ -0,0 +1,11 @@ +[mypy] +check_untyped_defs = True +no_implicit_optional = True +[mypy-logbook] +ignore_missing_imports = True +[mypy-quart] +ignore_missing_imports = True +[mypy-winter] +ignore_missing_imports = True +[mypy-asyncpg] +ignore_missing_imports = True diff --git a/tox.ini b/tox.ini index 5f75c76..b4ad110 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py3.7 +envlist = py3.8 [testenv] deps = -rrequirements.txt