diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py
index fd8a060..41850e2 100644
--- a/litecord/blueprints/channel/messages.py
+++ b/litecord/blueprints/channel/messages.py
@@ -217,7 +217,8 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
# for the users that have a state
# in the channel.
if mentions_here:
- uids = []
+ uids = set()
+
await app.db.execute("""
UPDATE user_read_state
SET mention_count = mention_count + 1
@@ -229,7 +230,7 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
# that might not have read permissions
# to the channel.
if mentions_everyone:
- uids = []
+ uids = set()
member_ids = await app.storage.get_member_ids(guild_id)
diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py
index f900ef6..5cfff5d 100644
--- a/litecord/blueprints/channels.py
+++ b/litecord/blueprints/channels.py
@@ -18,6 +18,7 @@ along with this program. If not, see .
"""
import time
+from typing import List
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@@ -262,7 +263,7 @@ async def _update_pos(channel_id, pos: int):
""", pos, channel_id)
-async def _mass_chan_update(guild_id, channel_ids: int):
+async def _mass_chan_update(guild_id, channel_ids: List[int]):
for channel_id in channel_ids:
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
@@ -337,7 +338,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
if 'position' in j:
channel_data = await app.storage.get_channel_data(guild_id)
- chans = [None * len(channel_data)]
+ chans = [None] * len(channel_data)
for chandata in channel_data:
chans.insert(chandata['position'], int(chandata['id']))
diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py
index 8dd9204..3558168 100644
--- a/litecord/blueprints/guild/members.py
+++ b/litecord/blueprints/guild/members.py
@@ -68,7 +68,7 @@ async def get_members(guild_id):
async def _update_member_roles(guild_id: int, member_id: int,
- wanted_roles: list):
+ wanted_roles: set):
"""Update the roles a member has."""
# first, fetch all current roles
diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py
index bd419cb..2a06c73 100644
--- a/litecord/blueprints/guild/roles.py
+++ b/litecord/blueprints/guild/roles.py
@@ -17,7 +17,7 @@ along with this program. If not, see .
"""
-from typing import List, Dict
+from typing import List, Dict, Tuple
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@@ -184,10 +184,11 @@ async def _role_pairs_update(guild_id: int, pairs: list):
await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id)
+PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs(list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int],
- blacklist: List[int] = None) -> List[tuple]:
+ blacklist: List[int] = None) -> PairList:
"""Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes.
@@ -262,7 +263,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# if its being swapped to leave space, add it
# to the pairs list
- if new_pos_2:
+ if element_2 and new_pos_2:
pairs.append(
((element_1, new_pos_1), (element_2, new_pos_2))
)
diff --git a/litecord/blueprints/voice.py b/litecord/blueprints/voice.py
index 4d788cd..a06eec1 100644
--- a/litecord/blueprints/voice.py
+++ b/litecord/blueprints/voice.py
@@ -17,6 +17,7 @@ along with this program. If not, see .
"""
+from typing import Optional
from collections import Counter
from random import choice
@@ -36,7 +37,7 @@ def _majority_region_count(regions: list) -> str:
return region
-async def _choose_random_region() -> str:
+async def _choose_random_region() -> Optional[str]:
"""Give a random voice region."""
regions = await app.db.fetch("""
SELECT id
@@ -51,7 +52,7 @@ async def _choose_random_region() -> str:
return choice(regions)
-async def _majority_region_any(user_id) -> str:
+async def _majority_region_any(user_id) -> Optional[str]:
"""Calculate the most likely region to make the user happy, but
this is based on the guilds the user is IN, instead of the guilds
the user owns."""
@@ -79,7 +80,7 @@ async def _majority_region_any(user_id) -> str:
return most_common
-async def majority_region(user_id) -> str:
+async def majority_region(user_id: int) -> Optional[str]:
"""Given a user ID, give the most likely region for the user to be
happy with."""
regions = await app.db.fetch("""
diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py
index 9a4ddee..e6fa7d4 100644
--- a/litecord/gateway/websocket.py
+++ b/litecord/gateway/websocket.py
@@ -235,7 +235,7 @@ class GatewayWebsocket:
's': None
})
- def _check_ratelimit(self, key: str, ratelimit_key: str):
+ def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@@ -292,7 +292,7 @@ class GatewayWebsocket:
await self.send(payload)
- async def _make_guild_list(self) -> List[int]:
+ async def _make_guild_list(self) -> List[Dict[str, Any]]:
user_id = self.state.user_id
guild_ids = await self._guild_ids()
@@ -772,7 +772,7 @@ class GatewayWebsocket:
await self._resume(range(seq, state.seq))
- async def _req_guild_members(self, guild_id: str, user_ids: List[int],
+ async def _req_guild_members(self, guild_id, user_ids: List[int],
query: str, limit: int):
try:
guild_id = int(guild_id)
diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py
index 2b32a1d..b931bec 100644
--- a/litecord/pubsub/channel.py
+++ b/litecord/pubsub/channel.py
@@ -17,7 +17,7 @@ along with this program. If not, see .
"""
-from typing import Any
+from typing import Any, List
from logbook import Logger
@@ -54,13 +54,13 @@ class ChannelDispatcher(DispatcherWithState):
VAL_TYPE = int
async def dispatch(self, channel_id,
- event: str, data: Any):
+ event: str, data: Any) -> List[str]:
"""Dispatch an event to a channel."""
# get everyone who is subscribed
# and store the number of states we dispatched the event to
user_ids = self.state[channel_id]
dispatched = 0
- sessions = []
+ sessions: List[str] = []
# making a copy of user_ids since
# we'll modify it later on.
diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py
index 4b60019..dd03ef2 100644
--- a/litecord/pubsub/dispatcher.py
+++ b/litecord/pubsub/dispatcher.py
@@ -17,9 +17,7 @@ along with this program. If not, see .
"""
-"""
-litecord.pubsub.dispatcher: main dispatcher class
-"""
+from typing import List
from collections import defaultdict
from logbook import Logger
@@ -82,7 +80,8 @@ class Dispatcher:
"""
raise NotImplementedError
- async def _dispatch_states(self, states: list, event: str, data) -> int:
+ async def _dispatch_states(self, states: list, event: str,
+ data) -> List[str]:
"""Dispatch an event to a list of states."""
res = []
diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py
index da87367..93c15aa 100644
--- a/litecord/pubsub/lazy_guild.py
+++ b/litecord/pubsub/lazy_guild.py
@@ -28,7 +28,7 @@ lazy guilds:
import asyncio
from collections import defaultdict
-from typing import Any, List, Dict, Union
+from typing import Any, List, Dict, Union, Optional, Iterable, Tuple
from dataclasses import dataclass, asdict, field
from logbook import Logger
@@ -39,7 +39,7 @@ from litecord.permissions import (
)
from litecord.utils import index_by_func
from litecord.utils import mmh3
-
+from litecord.gateway.state import GatewayState
log = Logger(__name__)
@@ -113,7 +113,7 @@ class MemberList:
yield group, self.data[group.gid]
@property
- def iter_non_empty(self) -> tuple:
+ def iter_non_empty(self) -> Generator[Tuple[GroupInfo, List[int]]]:
"""Only iterate through non-empty groups.
Note that while the offline group can be empty, it is always
@@ -359,7 +359,7 @@ class GuildMemberList:
# then the final perms for that role if
# any overwrite exists in the channel
final_perms = overwrite_find_mix(
- role_perms, self.list.overwrites, group.gid)
+ role_perms, self.list.overwrites, int(group.gid))
# update the group's permissions
# with the mixed ones
@@ -423,7 +423,7 @@ class GuildMemberList:
async def _get_group_for_member(self, member_id: int,
roles: List[Union[str, int]],
- status: str) -> GroupID:
+ status: str) -> Optional[GroupID]:
"""Return a fitting group ID for the member."""
member_roles = list(map(int, roles))
@@ -463,15 +463,15 @@ class GuildMemberList:
self.list.members[member_id] = member
self.list.data[group_id].append(member_id)
- def _display_name(self, member_id: int) -> str:
+ def _display_name(self, member_id: int) -> Optional[str]:
"""Get the display name for a given member.
This is more efficient than the old function (not method) of same
name, as we dont need to pass nickname information to it.
"""
- member = self.list.members.get(member_id)
-
- if not member_id:
+ try:
+ member = self.list.members[member_id]
+ except KeyError:
return None
username = member['user']['username']
@@ -578,7 +578,7 @@ class GuildMemberList:
if not self.state:
self._set_empty_list()
- def _get_state(self, session_id: str):
+ def _get_state(self, session_id: str) -> Optional[GatewayState]:
"""Get the state for a session id.
Wrapper for :meth:`StateManager.fetch_raw`
@@ -625,7 +625,8 @@ class GuildMemberList:
return dispatched
- async def _resync(self, session_ids: int, item_index: int) -> List[str]:
+ async def _resync(self, session_ids: List[int],
+ item_index: int) -> List[str]:
"""Send a SYNC event to all states that are subscribed to an item.
Returns
@@ -729,7 +730,7 @@ class GuildMemberList:
# send SYNCs to the state that requested
await self._dispatch_sess([session_id], ops)
- def _get_item_index(self, user_id: Union[str, int]) -> int:
+ def _get_item_index(self, user_id: Union[str, int]) -> Optional[int]:
"""Get the item index a user is on."""
# NOTE: this is inefficient
user_id = int(user_id)
@@ -749,7 +750,7 @@ class GuildMemberList:
return None
- def _get_group_item_index(self, group_id: GroupID) -> int:
+ def _get_group_item_index(self, group_id: GroupID) -> Optional[int]:
"""Get the item index a group is on."""
index = 0
@@ -773,7 +774,7 @@ class GuildMemberList:
return False
- def _get_subs(self, item_index: int) -> filter:
+ def _get_subs(self, item_index: int) -> Iterable[str]:
"""Get the list of subscribed states to a given item."""
return filter(
lambda sess_id: self._is_subbed(item_index, sess_id),
@@ -1141,7 +1142,7 @@ class GuildMemberList:
# when bots come along.
self.list.data[new_group.gid] = []
- def _get_role_as_group_idx(self, role_id: int) -> int:
+ def _get_role_as_group_idx(self, role_id: int) -> Optional[int]:
"""Get a group index representing the given role id.
Returns
diff --git a/litecord/schemas.py b/litecord/schemas.py
index 8b4fd28..3683334 100644
--- a/litecord/schemas.py
+++ b/litecord/schemas.py
@@ -147,7 +147,7 @@ class LitecordValidator(Validator):
def validate(reqjson: Union[Dict, List], schema: Dict,
- raise_err: bool = True) -> Union[Dict, List]:
+ raise_err: bool = True) -> Dict:
"""Validate a given document (user-input) and give
the correct document as a result.
"""
diff --git a/litecord/storage.py b/litecord/storage.py
index 962bb57..2b63e82 100644
--- a/litecord/storage.py
+++ b/litecord/storage.py
@@ -17,7 +17,7 @@ along with this program. If not, see .
"""
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
from logbook import Logger
@@ -77,7 +77,7 @@ class Storage:
self.db = app.db
self.presence = None
- async def fetchrow_with_json(self, query: str, *args):
+ async def fetchrow_with_json(self, query: str, *args) -> Any:
"""Fetch a single row with JSON/JSONB support."""
# the pool by itself doesn't have
# set_type_codec, so we must set it manually
@@ -86,19 +86,19 @@ class Storage:
await pg_set_json(con)
return await con.fetchrow(query, *args)
- async def fetch_with_json(self, query: str, *args):
+ async def fetch_with_json(self, query: str, *args) -> List[Any]:
"""Fetch many rows with JSON/JSONB support."""
async with self.db.acquire() as con:
await pg_set_json(con)
return await con.fetch(query, *args)
- async def execute_with_json(self, query: str, *args):
+ async def execute_with_json(self, query: str, *args) -> str:
"""Execute a SQL statement with JSON/JSONB support."""
async with self.db.acquire() as con:
await pg_set_json(con)
return await con.execute(query, *args)
- async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
+ async def get_user(self, user_id, secure=False) -> Optional[Dict[str, Any]]:
"""Get a single user payload."""
user_id = int(user_id)
@@ -115,7 +115,7 @@ class Storage:
""", user_id)
if not user_row:
- return
+ return None
duser = dict(user_row)
@@ -141,7 +141,7 @@ class Storage:
"""Search a user"""
if len(discriminator) < 4:
# how do we do this in f-strings again..?
- discriminator = '%04d' % discriminator
+ discriminator = '%04d' % int(discriminator)
return await self.db.fetchval("""
SELECT id FROM users
@@ -219,12 +219,12 @@ class Storage:
}
async def get_member_data_one(self, guild_id: int,
- member_id: int) -> Dict[str, Any]:
+ member_id: int) -> Optional[Dict[str, Any]]:
"""Get data about one member in a guild."""
basic = await self._member_basic(guild_id, member_id)
if not basic:
- return
+ return None
return await self._member_dict(basic, guild_id, member_id)
@@ -376,7 +376,7 @@ class Storage:
return [r['member_id'] for r in user_ids]
async def _gdm_recipients(self, channel_id: int,
- reference_id: int = None) -> List[int]:
+ reference_id: int = None) -> List[Dict]:
"""Get the list of users that are recipients of the
given Group DM."""
recipients = await self.gdm_recipient_ids(channel_id)
@@ -392,7 +392,8 @@ class Storage:
return res
- async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]:
+ async def get_channel(self, channel_id: int,
+ **kwargs) -> Optional[Dict[str, Any]]:
"""Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id)
ctype = ChannelType(chan_type)
@@ -501,7 +502,7 @@ class Storage:
return channels
async def get_role(self, role_id: int,
- guild_id: int = None) -> Dict[str, Any]:
+ guild_id: int = None) -> Optional[Dict[str, Any]]:
"""get a single role's information."""
guild_field = 'AND guild_id = $2' if guild_id else ''
@@ -519,7 +520,7 @@ class Storage:
""", *args)
if not row:
- return
+ return None
return dict(row)
@@ -769,7 +770,8 @@ class Storage:
res.pop('author_id')
- async def get_message(self, message_id: int, user_id=None) -> Dict:
+ async def get_message(self, message_id: int,
+ user_id: Optional[int] = None) -> Optional[Dict]:
"""Get a single message's payload."""
row = await self.fetchrow_with_json("""
SELECT id::text, channel_id::text, author_id, webhook_id, content,
@@ -780,7 +782,7 @@ class Storage:
""", message_id)
if not row:
- return
+ return None
res = dict(row)
res['nonce'] = str(res['nonce'])
@@ -915,7 +917,8 @@ class Storage:
'approximate_member_count': len(mids),
}
- async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]:
+ async def get_invite_metadata(self,
+ invite_code: str) -> Optional[Dict[str, Any]]:
"""Fetch invite metadata (max_age and friends)."""
invite = await self.db.fetchrow("""
SELECT code, inviter, created_at, uses,
@@ -925,7 +928,7 @@ class Storage:
""", invite_code)
if invite is None:
- return
+ return None
dinv = dict_(invite)
inviter = await self.get_user(invite['inviter'])
@@ -966,7 +969,7 @@ class Storage:
return parties[0]
- async def get_emoji(self, emoji_id: int) -> Dict:
+ async def get_emoji(self, emoji_id: int) -> Optional[Dict[str, Any]]:
"""Get a single emoji."""
row = await self.db.fetchrow("""
SELECT id::text, name, animated, managed,
@@ -976,7 +979,7 @@ class Storage:
""", emoji_id)
if not row:
- return
+ return None
drow = dict(row)
diff --git a/litecord/utils.py b/litecord/utils.py
index 5647a37..1dca864 100644
--- a/litecord/utils.py
+++ b/litecord/utils.py
@@ -20,7 +20,7 @@ along with this program. If not, see .
import asyncio
import json
from logbook import Logger
-from typing import Any
+from typing import Any, Iterable
from quart.json import JSONEncoder
log = Logger(__name__)
@@ -160,7 +160,7 @@ async def pg_set_json(con):
)
-def yield_chunks(input_list: list, chunk_size: int):
+def yield_chunks(input_list: Iterable, chunk_size: int):
"""Yield successive n-sized chunks from l.
Taken from https://stackoverflow.com/a/312464.