diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index fd8a060..41850e2 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -217,7 +217,8 @@ async def _guild_text_mentions(payload: dict, guild_id: int, # for the users that have a state # in the channel. if mentions_here: - uids = [] + uids = set() + await app.db.execute(""" UPDATE user_read_state SET mention_count = mention_count + 1 @@ -229,7 +230,7 @@ async def _guild_text_mentions(payload: dict, guild_id: int, # that might not have read permissions # to the channel. if mentions_everyone: - uids = [] + uids = set() member_ids = await app.storage.get_member_ids(guild_id) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index f900ef6..5cfff5d 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ import time +from typing import List from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger @@ -262,7 +263,7 @@ async def _update_pos(channel_id, pos: int): """, pos, channel_id) -async def _mass_chan_update(guild_id, channel_ids: int): +async def _mass_chan_update(guild_id, channel_ids: List[int]): for channel_id in channel_ids: chan = await app.storage.get_channel(channel_id) await app.dispatcher.dispatch( @@ -337,7 +338,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict): if 'position' in j: channel_data = await app.storage.get_channel_data(guild_id) - chans = [None * len(channel_data)] + chans = [None] * len(channel_data) for chandata in channel_data: chans.insert(chandata['position'], int(chandata['id'])) diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index 8dd9204..3558168 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -68,7 +68,7 @@ async def get_members(guild_id): async def _update_member_roles(guild_id: int, member_id: int, - wanted_roles: list): + wanted_roles: set): """Update the roles a member has.""" # first, fetch all current roles diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index bd419cb..2a06c73 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import List, Dict +from typing import List, Dict, Tuple from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger @@ -184,10 +184,11 @@ async def _role_pairs_update(guild_id: int, pairs: list): await _role_update_dispatch(role_1, guild_id) await _role_update_dispatch(role_2, guild_id) +PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]] def gen_pairs(list_of_changes: List[Dict[str, int]], current_state: Dict[int, int], - blacklist: List[int] = None) -> List[tuple]: + blacklist: List[int] = None) -> PairList: """Generate a list of pairs that, when applied to the database, will generate the desired state given in list_of_changes. @@ -262,7 +263,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]], # if its being swapped to leave space, add it # to the pairs list - if new_pos_2: + if element_2 and new_pos_2: pairs.append( ((element_1, new_pos_1), (element_2, new_pos_2)) ) diff --git a/litecord/blueprints/voice.py b/litecord/blueprints/voice.py index 4d788cd..a06eec1 100644 --- a/litecord/blueprints/voice.py +++ b/litecord/blueprints/voice.py @@ -17,6 +17,7 @@ along with this program. If not, see . """ +from typing import Optional from collections import Counter from random import choice @@ -36,7 +37,7 @@ def _majority_region_count(regions: list) -> str: return region -async def _choose_random_region() -> str: +async def _choose_random_region() -> Optional[str]: """Give a random voice region.""" regions = await app.db.fetch(""" SELECT id @@ -51,7 +52,7 @@ async def _choose_random_region() -> str: return choice(regions) -async def _majority_region_any(user_id) -> str: +async def _majority_region_any(user_id) -> Optional[str]: """Calculate the most likely region to make the user happy, but this is based on the guilds the user is IN, instead of the guilds the user owns.""" @@ -79,7 +80,7 @@ async def _majority_region_any(user_id) -> str: return most_common -async def majority_region(user_id) -> str: +async def majority_region(user_id: int) -> Optional[str]: """Given a user ID, give the most likely region for the user to be happy with.""" regions = await app.db.fetch(""" diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 9a4ddee..e6fa7d4 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -235,7 +235,7 @@ class GatewayWebsocket: 's': None }) - def _check_ratelimit(self, key: str, ratelimit_key: str): + def _check_ratelimit(self, key: str, ratelimit_key): ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') bucket = ratelimit.get_bucket(ratelimit_key) return bucket.update_rate_limit() @@ -292,7 +292,7 @@ class GatewayWebsocket: await self.send(payload) - async def _make_guild_list(self) -> List[int]: + async def _make_guild_list(self) -> List[Dict[str, Any]]: user_id = self.state.user_id guild_ids = await self._guild_ids() @@ -772,7 +772,7 @@ class GatewayWebsocket: await self._resume(range(seq, state.seq)) - async def _req_guild_members(self, guild_id: str, user_ids: List[int], + async def _req_guild_members(self, guild_id, user_ids: List[int], query: str, limit: int): try: guild_id = int(guild_id) diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 2b32a1d..b931bec 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import Any +from typing import Any, List from logbook import Logger @@ -54,13 +54,13 @@ class ChannelDispatcher(DispatcherWithState): VAL_TYPE = int async def dispatch(self, channel_id, - event: str, data: Any): + event: str, data: Any) -> List[str]: """Dispatch an event to a channel.""" # get everyone who is subscribed # and store the number of states we dispatched the event to user_ids = self.state[channel_id] dispatched = 0 - sessions = [] + sessions: List[str] = [] # making a copy of user_ids since # we'll modify it later on. diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index 4b60019..dd03ef2 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -17,9 +17,7 @@ along with this program. If not, see . """ -""" -litecord.pubsub.dispatcher: main dispatcher class -""" +from typing import List from collections import defaultdict from logbook import Logger @@ -82,7 +80,8 @@ class Dispatcher: """ raise NotImplementedError - async def _dispatch_states(self, states: list, event: str, data) -> int: + async def _dispatch_states(self, states: list, event: str, + data) -> List[str]: """Dispatch an event to a list of states.""" res = [] diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index da87367..93c15aa 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -28,7 +28,7 @@ lazy guilds: import asyncio from collections import defaultdict -from typing import Any, List, Dict, Union +from typing import Any, List, Dict, Union, Optional, Iterable, Tuple from dataclasses import dataclass, asdict, field from logbook import Logger @@ -39,7 +39,7 @@ from litecord.permissions import ( ) from litecord.utils import index_by_func from litecord.utils import mmh3 - +from litecord.gateway.state import GatewayState log = Logger(__name__) @@ -113,7 +113,7 @@ class MemberList: yield group, self.data[group.gid] @property - def iter_non_empty(self) -> tuple: + def iter_non_empty(self) -> Generator[Tuple[GroupInfo, List[int]]]: """Only iterate through non-empty groups. Note that while the offline group can be empty, it is always @@ -359,7 +359,7 @@ class GuildMemberList: # then the final perms for that role if # any overwrite exists in the channel final_perms = overwrite_find_mix( - role_perms, self.list.overwrites, group.gid) + role_perms, self.list.overwrites, int(group.gid)) # update the group's permissions # with the mixed ones @@ -423,7 +423,7 @@ class GuildMemberList: async def _get_group_for_member(self, member_id: int, roles: List[Union[str, int]], - status: str) -> GroupID: + status: str) -> Optional[GroupID]: """Return a fitting group ID for the member.""" member_roles = list(map(int, roles)) @@ -463,15 +463,15 @@ class GuildMemberList: self.list.members[member_id] = member self.list.data[group_id].append(member_id) - def _display_name(self, member_id: int) -> str: + def _display_name(self, member_id: int) -> Optional[str]: """Get the display name for a given member. This is more efficient than the old function (not method) of same name, as we dont need to pass nickname information to it. """ - member = self.list.members.get(member_id) - - if not member_id: + try: + member = self.list.members[member_id] + except KeyError: return None username = member['user']['username'] @@ -578,7 +578,7 @@ class GuildMemberList: if not self.state: self._set_empty_list() - def _get_state(self, session_id: str): + def _get_state(self, session_id: str) -> Optional[GatewayState]: """Get the state for a session id. Wrapper for :meth:`StateManager.fetch_raw` @@ -625,7 +625,8 @@ class GuildMemberList: return dispatched - async def _resync(self, session_ids: int, item_index: int) -> List[str]: + async def _resync(self, session_ids: List[int], + item_index: int) -> List[str]: """Send a SYNC event to all states that are subscribed to an item. Returns @@ -729,7 +730,7 @@ class GuildMemberList: # send SYNCs to the state that requested await self._dispatch_sess([session_id], ops) - def _get_item_index(self, user_id: Union[str, int]) -> int: + def _get_item_index(self, user_id: Union[str, int]) -> Optional[int]: """Get the item index a user is on.""" # NOTE: this is inefficient user_id = int(user_id) @@ -749,7 +750,7 @@ class GuildMemberList: return None - def _get_group_item_index(self, group_id: GroupID) -> int: + def _get_group_item_index(self, group_id: GroupID) -> Optional[int]: """Get the item index a group is on.""" index = 0 @@ -773,7 +774,7 @@ class GuildMemberList: return False - def _get_subs(self, item_index: int) -> filter: + def _get_subs(self, item_index: int) -> Iterable[str]: """Get the list of subscribed states to a given item.""" return filter( lambda sess_id: self._is_subbed(item_index, sess_id), @@ -1141,7 +1142,7 @@ class GuildMemberList: # when bots come along. self.list.data[new_group.gid] = [] - def _get_role_as_group_idx(self, role_id: int) -> int: + def _get_role_as_group_idx(self, role_id: int) -> Optional[int]: """Get a group index representing the given role id. Returns diff --git a/litecord/schemas.py b/litecord/schemas.py index 8b4fd28..3683334 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -147,7 +147,7 @@ class LitecordValidator(Validator): def validate(reqjson: Union[Dict, List], schema: Dict, - raise_err: bool = True) -> Union[Dict, List]: + raise_err: bool = True) -> Dict: """Validate a given document (user-input) and give the correct document as a result. """ diff --git a/litecord/storage.py b/litecord/storage.py index 962bb57..2b63e82 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional from logbook import Logger @@ -77,7 +77,7 @@ class Storage: self.db = app.db self.presence = None - async def fetchrow_with_json(self, query: str, *args): + async def fetchrow_with_json(self, query: str, *args) -> Any: """Fetch a single row with JSON/JSONB support.""" # the pool by itself doesn't have # set_type_codec, so we must set it manually @@ -86,19 +86,19 @@ class Storage: await pg_set_json(con) return await con.fetchrow(query, *args) - async def fetch_with_json(self, query: str, *args): + async def fetch_with_json(self, query: str, *args) -> List[Any]: """Fetch many rows with JSON/JSONB support.""" async with self.db.acquire() as con: await pg_set_json(con) return await con.fetch(query, *args) - async def execute_with_json(self, query: str, *args): + async def execute_with_json(self, query: str, *args) -> str: """Execute a SQL statement with JSON/JSONB support.""" async with self.db.acquire() as con: await pg_set_json(con) return await con.execute(query, *args) - async def get_user(self, user_id, secure=False) -> Dict[str, Any]: + async def get_user(self, user_id, secure=False) -> Optional[Dict[str, Any]]: """Get a single user payload.""" user_id = int(user_id) @@ -115,7 +115,7 @@ class Storage: """, user_id) if not user_row: - return + return None duser = dict(user_row) @@ -141,7 +141,7 @@ class Storage: """Search a user""" if len(discriminator) < 4: # how do we do this in f-strings again..? - discriminator = '%04d' % discriminator + discriminator = '%04d' % int(discriminator) return await self.db.fetchval(""" SELECT id FROM users @@ -219,12 +219,12 @@ class Storage: } async def get_member_data_one(self, guild_id: int, - member_id: int) -> Dict[str, Any]: + member_id: int) -> Optional[Dict[str, Any]]: """Get data about one member in a guild.""" basic = await self._member_basic(guild_id, member_id) if not basic: - return + return None return await self._member_dict(basic, guild_id, member_id) @@ -376,7 +376,7 @@ class Storage: return [r['member_id'] for r in user_ids] async def _gdm_recipients(self, channel_id: int, - reference_id: int = None) -> List[int]: + reference_id: int = None) -> List[Dict]: """Get the list of users that are recipients of the given Group DM.""" recipients = await self.gdm_recipient_ids(channel_id) @@ -392,7 +392,8 @@ class Storage: return res - async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]: + async def get_channel(self, channel_id: int, + **kwargs) -> Optional[Dict[str, Any]]: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) ctype = ChannelType(chan_type) @@ -501,7 +502,7 @@ class Storage: return channels async def get_role(self, role_id: int, - guild_id: int = None) -> Dict[str, Any]: + guild_id: int = None) -> Optional[Dict[str, Any]]: """get a single role's information.""" guild_field = 'AND guild_id = $2' if guild_id else '' @@ -519,7 +520,7 @@ class Storage: """, *args) if not row: - return + return None return dict(row) @@ -769,7 +770,8 @@ class Storage: res.pop('author_id') - async def get_message(self, message_id: int, user_id=None) -> Dict: + async def get_message(self, message_id: int, + user_id: Optional[int] = None) -> Optional[Dict]: """Get a single message's payload.""" row = await self.fetchrow_with_json(""" SELECT id::text, channel_id::text, author_id, webhook_id, content, @@ -780,7 +782,7 @@ class Storage: """, message_id) if not row: - return + return None res = dict(row) res['nonce'] = str(res['nonce']) @@ -915,7 +917,8 @@ class Storage: 'approximate_member_count': len(mids), } - async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]: + async def get_invite_metadata(self, + invite_code: str) -> Optional[Dict[str, Any]]: """Fetch invite metadata (max_age and friends).""" invite = await self.db.fetchrow(""" SELECT code, inviter, created_at, uses, @@ -925,7 +928,7 @@ class Storage: """, invite_code) if invite is None: - return + return None dinv = dict_(invite) inviter = await self.get_user(invite['inviter']) @@ -966,7 +969,7 @@ class Storage: return parties[0] - async def get_emoji(self, emoji_id: int) -> Dict: + async def get_emoji(self, emoji_id: int) -> Optional[Dict[str, Any]]: """Get a single emoji.""" row = await self.db.fetchrow(""" SELECT id::text, name, animated, managed, @@ -976,7 +979,7 @@ class Storage: """, emoji_id) if not row: - return + return None drow = dict(row) diff --git a/litecord/utils.py b/litecord/utils.py index 5647a37..1dca864 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -20,7 +20,7 @@ along with this program. If not, see . import asyncio import json from logbook import Logger -from typing import Any +from typing import Any, Iterable from quart.json import JSONEncoder log = Logger(__name__) @@ -160,7 +160,7 @@ async def pg_set_json(con): ) -def yield_chunks(input_list: list, chunk_size: int): +def yield_chunks(input_list: Iterable, chunk_size: int): """Yield successive n-sized chunks from l. Taken from https://stackoverflow.com/a/312464.