typings, episode 1

(i installed mypy and its beautiful)
This commit is contained in:
Luna 2019-03-04 05:09:04 -03:00
parent 506bd8afbe
commit d91030a2c1
12 changed files with 65 additions and 58 deletions

View File

@ -217,7 +217,8 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
# for the users that have a state # for the users that have a state
# in the channel. # in the channel.
if mentions_here: if mentions_here:
uids = [] uids = set()
await app.db.execute(""" await app.db.execute("""
UPDATE user_read_state UPDATE user_read_state
SET mention_count = mention_count + 1 SET mention_count = mention_count + 1
@ -229,7 +230,7 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
# that might not have read permissions # that might not have read permissions
# to the channel. # to the channel.
if mentions_everyone: if mentions_everyone:
uids = [] uids = set()
member_ids = await app.storage.get_member_ids(guild_id) member_ids = await app.storage.get_member_ids(guild_id)

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import time import time
from typing import List
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
@ -262,7 +263,7 @@ async def _update_pos(channel_id, pos: int):
""", pos, channel_id) """, pos, channel_id)
async def _mass_chan_update(guild_id, channel_ids: int): async def _mass_chan_update(guild_id, channel_ids: List[int]):
for channel_id in channel_ids: for channel_id in channel_ids:
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch( await app.dispatcher.dispatch(
@ -337,7 +338,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
if 'position' in j: if 'position' in j:
channel_data = await app.storage.get_channel_data(guild_id) channel_data = await app.storage.get_channel_data(guild_id)
chans = [None * len(channel_data)] chans = [None] * len(channel_data)
for chandata in channel_data: for chandata in channel_data:
chans.insert(chandata['position'], int(chandata['id'])) chans.insert(chandata['position'], int(chandata['id']))

View File

@ -68,7 +68,7 @@ async def get_members(guild_id):
async def _update_member_roles(guild_id: int, member_id: int, async def _update_member_roles(guild_id: int, member_id: int,
wanted_roles: list): wanted_roles: set):
"""Update the roles a member has.""" """Update the roles a member has."""
# first, fetch all current roles # first, fetch all current roles

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Dict from typing import List, Dict, Tuple
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
@ -184,10 +184,11 @@ async def _role_pairs_update(guild_id: int, pairs: list):
await _role_update_dispatch(role_1, guild_id) await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id) await _role_update_dispatch(role_2, guild_id)
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs(list_of_changes: List[Dict[str, int]], def gen_pairs(list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int], current_state: Dict[int, int],
blacklist: List[int] = None) -> List[tuple]: blacklist: List[int] = None) -> PairList:
"""Generate a list of pairs that, when applied to the database, """Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes. will generate the desired state given in list_of_changes.
@ -262,7 +263,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# if its being swapped to leave space, add it # if its being swapped to leave space, add it
# to the pairs list # to the pairs list
if new_pos_2: if element_2 and new_pos_2:
pairs.append( pairs.append(
((element_1, new_pos_1), (element_2, new_pos_2)) ((element_1, new_pos_1), (element_2, new_pos_2))
) )

View File

@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Optional
from collections import Counter from collections import Counter
from random import choice from random import choice
@ -36,7 +37,7 @@ def _majority_region_count(regions: list) -> str:
return region return region
async def _choose_random_region() -> str: async def _choose_random_region() -> Optional[str]:
"""Give a random voice region.""" """Give a random voice region."""
regions = await app.db.fetch(""" regions = await app.db.fetch("""
SELECT id SELECT id
@ -51,7 +52,7 @@ async def _choose_random_region() -> str:
return choice(regions) return choice(regions)
async def _majority_region_any(user_id) -> str: async def _majority_region_any(user_id) -> Optional[str]:
"""Calculate the most likely region to make the user happy, but """Calculate the most likely region to make the user happy, but
this is based on the guilds the user is IN, instead of the guilds this is based on the guilds the user is IN, instead of the guilds
the user owns.""" the user owns."""
@ -79,7 +80,7 @@ async def _majority_region_any(user_id) -> str:
return most_common return most_common
async def majority_region(user_id) -> str: async def majority_region(user_id: int) -> Optional[str]:
"""Given a user ID, give the most likely region for the user to be """Given a user ID, give the most likely region for the user to be
happy with.""" happy with."""
regions = await app.db.fetch(""" regions = await app.db.fetch("""

View File

@ -235,7 +235,7 @@ class GatewayWebsocket:
's': None 's': None
}) })
def _check_ratelimit(self, key: str, ratelimit_key: str): def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
bucket = ratelimit.get_bucket(ratelimit_key) bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit() return bucket.update_rate_limit()
@ -292,7 +292,7 @@ class GatewayWebsocket:
await self.send(payload) await self.send(payload)
async def _make_guild_list(self) -> List[int]: async def _make_guild_list(self) -> List[Dict[str, Any]]:
user_id = self.state.user_id user_id = self.state.user_id
guild_ids = await self._guild_ids() guild_ids = await self._guild_ids()
@ -772,7 +772,7 @@ class GatewayWebsocket:
await self._resume(range(seq, state.seq)) await self._resume(range(seq, state.seq))
async def _req_guild_members(self, guild_id: str, user_ids: List[int], async def _req_guild_members(self, guild_id, user_ids: List[int],
query: str, limit: int): query: str, limit: int):
try: try:
guild_id = int(guild_id) guild_id = int(guild_id)

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Any from typing import Any, List
from logbook import Logger from logbook import Logger
@ -54,13 +54,13 @@ class ChannelDispatcher(DispatcherWithState):
VAL_TYPE = int VAL_TYPE = int
async def dispatch(self, channel_id, async def dispatch(self, channel_id,
event: str, data: Any): event: str, data: Any) -> List[str]:
"""Dispatch an event to a channel.""" """Dispatch an event to a channel."""
# get everyone who is subscribed # get everyone who is subscribed
# and store the number of states we dispatched the event to # and store the number of states we dispatched the event to
user_ids = self.state[channel_id] user_ids = self.state[channel_id]
dispatched = 0 dispatched = 0
sessions = [] sessions: List[str] = []
# making a copy of user_ids since # making a copy of user_ids since
# we'll modify it later on. # we'll modify it later on.

View File

@ -17,9 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
""" from typing import List
litecord.pubsub.dispatcher: main dispatcher class
"""
from collections import defaultdict from collections import defaultdict
from logbook import Logger from logbook import Logger
@ -82,7 +80,8 @@ class Dispatcher:
""" """
raise NotImplementedError raise NotImplementedError
async def _dispatch_states(self, states: list, event: str, data) -> int: async def _dispatch_states(self, states: list, event: str,
data) -> List[str]:
"""Dispatch an event to a list of states.""" """Dispatch an event to a list of states."""
res = [] res = []

View File

@ -28,7 +28,7 @@ lazy guilds:
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from typing import Any, List, Dict, Union from typing import Any, List, Dict, Union, Optional, Iterable, Tuple
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from logbook import Logger from logbook import Logger
@ -39,7 +39,7 @@ from litecord.permissions import (
) )
from litecord.utils import index_by_func from litecord.utils import index_by_func
from litecord.utils import mmh3 from litecord.utils import mmh3
from litecord.gateway.state import GatewayState
log = Logger(__name__) log = Logger(__name__)
@ -113,7 +113,7 @@ class MemberList:
yield group, self.data[group.gid] yield group, self.data[group.gid]
@property @property
def iter_non_empty(self) -> tuple: def iter_non_empty(self) -> Generator[Tuple[GroupInfo, List[int]]]:
"""Only iterate through non-empty groups. """Only iterate through non-empty groups.
Note that while the offline group can be empty, it is always Note that while the offline group can be empty, it is always
@ -359,7 +359,7 @@ class GuildMemberList:
# then the final perms for that role if # then the final perms for that role if
# any overwrite exists in the channel # any overwrite exists in the channel
final_perms = overwrite_find_mix( final_perms = overwrite_find_mix(
role_perms, self.list.overwrites, group.gid) role_perms, self.list.overwrites, int(group.gid))
# update the group's permissions # update the group's permissions
# with the mixed ones # with the mixed ones
@ -423,7 +423,7 @@ class GuildMemberList:
async def _get_group_for_member(self, member_id: int, async def _get_group_for_member(self, member_id: int,
roles: List[Union[str, int]], roles: List[Union[str, int]],
status: str) -> GroupID: status: str) -> Optional[GroupID]:
"""Return a fitting group ID for the member.""" """Return a fitting group ID for the member."""
member_roles = list(map(int, roles)) member_roles = list(map(int, roles))
@ -463,15 +463,15 @@ class GuildMemberList:
self.list.members[member_id] = member self.list.members[member_id] = member
self.list.data[group_id].append(member_id) self.list.data[group_id].append(member_id)
def _display_name(self, member_id: int) -> str: def _display_name(self, member_id: int) -> Optional[str]:
"""Get the display name for a given member. """Get the display name for a given member.
This is more efficient than the old function (not method) of same This is more efficient than the old function (not method) of same
name, as we dont need to pass nickname information to it. name, as we dont need to pass nickname information to it.
""" """
member = self.list.members.get(member_id) try:
member = self.list.members[member_id]
if not member_id: except KeyError:
return None return None
username = member['user']['username'] username = member['user']['username']
@ -578,7 +578,7 @@ class GuildMemberList:
if not self.state: if not self.state:
self._set_empty_list() self._set_empty_list()
def _get_state(self, session_id: str): def _get_state(self, session_id: str) -> Optional[GatewayState]:
"""Get the state for a session id. """Get the state for a session id.
Wrapper for :meth:`StateManager.fetch_raw` Wrapper for :meth:`StateManager.fetch_raw`
@ -625,7 +625,8 @@ class GuildMemberList:
return dispatched return dispatched
async def _resync(self, session_ids: int, item_index: int) -> List[str]: async def _resync(self, session_ids: List[int],
item_index: int) -> List[str]:
"""Send a SYNC event to all states that are subscribed to an item. """Send a SYNC event to all states that are subscribed to an item.
Returns Returns
@ -729,7 +730,7 @@ class GuildMemberList:
# send SYNCs to the state that requested # send SYNCs to the state that requested
await self._dispatch_sess([session_id], ops) await self._dispatch_sess([session_id], ops)
def _get_item_index(self, user_id: Union[str, int]) -> int: def _get_item_index(self, user_id: Union[str, int]) -> Optional[int]:
"""Get the item index a user is on.""" """Get the item index a user is on."""
# NOTE: this is inefficient # NOTE: this is inefficient
user_id = int(user_id) user_id = int(user_id)
@ -749,7 +750,7 @@ class GuildMemberList:
return None return None
def _get_group_item_index(self, group_id: GroupID) -> int: def _get_group_item_index(self, group_id: GroupID) -> Optional[int]:
"""Get the item index a group is on.""" """Get the item index a group is on."""
index = 0 index = 0
@ -773,7 +774,7 @@ class GuildMemberList:
return False return False
def _get_subs(self, item_index: int) -> filter: def _get_subs(self, item_index: int) -> Iterable[str]:
"""Get the list of subscribed states to a given item.""" """Get the list of subscribed states to a given item."""
return filter( return filter(
lambda sess_id: self._is_subbed(item_index, sess_id), lambda sess_id: self._is_subbed(item_index, sess_id),
@ -1141,7 +1142,7 @@ class GuildMemberList:
# when bots come along. # when bots come along.
self.list.data[new_group.gid] = [] self.list.data[new_group.gid] = []
def _get_role_as_group_idx(self, role_id: int) -> int: def _get_role_as_group_idx(self, role_id: int) -> Optional[int]:
"""Get a group index representing the given role id. """Get a group index representing the given role id.
Returns Returns

View File

@ -147,7 +147,7 @@ class LitecordValidator(Validator):
def validate(reqjson: Union[Dict, List], schema: Dict, def validate(reqjson: Union[Dict, List], schema: Dict,
raise_err: bool = True) -> Union[Dict, List]: raise_err: bool = True) -> Dict:
"""Validate a given document (user-input) and give """Validate a given document (user-input) and give
the correct document as a result. the correct document as a result.
""" """

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Dict, Any from typing import List, Dict, Any, Optional
from logbook import Logger from logbook import Logger
@ -77,7 +77,7 @@ class Storage:
self.db = app.db self.db = app.db
self.presence = None self.presence = None
async def fetchrow_with_json(self, query: str, *args): async def fetchrow_with_json(self, query: str, *args) -> Any:
"""Fetch a single row with JSON/JSONB support.""" """Fetch a single row with JSON/JSONB support."""
# the pool by itself doesn't have # the pool by itself doesn't have
# set_type_codec, so we must set it manually # set_type_codec, so we must set it manually
@ -86,19 +86,19 @@ class Storage:
await pg_set_json(con) await pg_set_json(con)
return await con.fetchrow(query, *args) return await con.fetchrow(query, *args)
async def fetch_with_json(self, query: str, *args): async def fetch_with_json(self, query: str, *args) -> List[Any]:
"""Fetch many rows with JSON/JSONB support.""" """Fetch many rows with JSON/JSONB support."""
async with self.db.acquire() as con: async with self.db.acquire() as con:
await pg_set_json(con) await pg_set_json(con)
return await con.fetch(query, *args) return await con.fetch(query, *args)
async def execute_with_json(self, query: str, *args): async def execute_with_json(self, query: str, *args) -> str:
"""Execute a SQL statement with JSON/JSONB support.""" """Execute a SQL statement with JSON/JSONB support."""
async with self.db.acquire() as con: async with self.db.acquire() as con:
await pg_set_json(con) await pg_set_json(con)
return await con.execute(query, *args) return await con.execute(query, *args)
async def get_user(self, user_id, secure=False) -> Dict[str, Any]: async def get_user(self, user_id, secure=False) -> Optional[Dict[str, Any]]:
"""Get a single user payload.""" """Get a single user payload."""
user_id = int(user_id) user_id = int(user_id)
@ -115,7 +115,7 @@ class Storage:
""", user_id) """, user_id)
if not user_row: if not user_row:
return return None
duser = dict(user_row) duser = dict(user_row)
@ -141,7 +141,7 @@ class Storage:
"""Search a user""" """Search a user"""
if len(discriminator) < 4: if len(discriminator) < 4:
# how do we do this in f-strings again..? # how do we do this in f-strings again..?
discriminator = '%04d' % discriminator discriminator = '%04d' % int(discriminator)
return await self.db.fetchval(""" return await self.db.fetchval("""
SELECT id FROM users SELECT id FROM users
@ -219,12 +219,12 @@ class Storage:
} }
async def get_member_data_one(self, guild_id: int, async def get_member_data_one(self, guild_id: int,
member_id: int) -> Dict[str, Any]: member_id: int) -> Optional[Dict[str, Any]]:
"""Get data about one member in a guild.""" """Get data about one member in a guild."""
basic = await self._member_basic(guild_id, member_id) basic = await self._member_basic(guild_id, member_id)
if not basic: if not basic:
return return None
return await self._member_dict(basic, guild_id, member_id) return await self._member_dict(basic, guild_id, member_id)
@ -376,7 +376,7 @@ class Storage:
return [r['member_id'] for r in user_ids] return [r['member_id'] for r in user_ids]
async def _gdm_recipients(self, channel_id: int, async def _gdm_recipients(self, channel_id: int,
reference_id: int = None) -> List[int]: reference_id: int = None) -> List[Dict]:
"""Get the list of users that are recipients of the """Get the list of users that are recipients of the
given Group DM.""" given Group DM."""
recipients = await self.gdm_recipient_ids(channel_id) recipients = await self.gdm_recipient_ids(channel_id)
@ -392,7 +392,8 @@ class Storage:
return res return res
async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]: async def get_channel(self, channel_id: int,
**kwargs) -> Optional[Dict[str, Any]]:
"""Fetch a single channel's information.""" """Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id) chan_type = await self.get_chan_type(channel_id)
ctype = ChannelType(chan_type) ctype = ChannelType(chan_type)
@ -501,7 +502,7 @@ class Storage:
return channels return channels
async def get_role(self, role_id: int, async def get_role(self, role_id: int,
guild_id: int = None) -> Dict[str, Any]: guild_id: int = None) -> Optional[Dict[str, Any]]:
"""get a single role's information.""" """get a single role's information."""
guild_field = 'AND guild_id = $2' if guild_id else '' guild_field = 'AND guild_id = $2' if guild_id else ''
@ -519,7 +520,7 @@ class Storage:
""", *args) """, *args)
if not row: if not row:
return return None
return dict(row) return dict(row)
@ -769,7 +770,8 @@ class Storage:
res.pop('author_id') res.pop('author_id')
async def get_message(self, message_id: int, user_id=None) -> Dict: async def get_message(self, message_id: int,
user_id: Optional[int] = None) -> Optional[Dict]:
"""Get a single message's payload.""" """Get a single message's payload."""
row = await self.fetchrow_with_json(""" row = await self.fetchrow_with_json("""
SELECT id::text, channel_id::text, author_id, webhook_id, content, SELECT id::text, channel_id::text, author_id, webhook_id, content,
@ -780,7 +782,7 @@ class Storage:
""", message_id) """, message_id)
if not row: if not row:
return return None
res = dict(row) res = dict(row)
res['nonce'] = str(res['nonce']) res['nonce'] = str(res['nonce'])
@ -915,7 +917,8 @@ class Storage:
'approximate_member_count': len(mids), 'approximate_member_count': len(mids),
} }
async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]: async def get_invite_metadata(self,
invite_code: str) -> Optional[Dict[str, Any]]:
"""Fetch invite metadata (max_age and friends).""" """Fetch invite metadata (max_age and friends)."""
invite = await self.db.fetchrow(""" invite = await self.db.fetchrow("""
SELECT code, inviter, created_at, uses, SELECT code, inviter, created_at, uses,
@ -925,7 +928,7 @@ class Storage:
""", invite_code) """, invite_code)
if invite is None: if invite is None:
return return None
dinv = dict_(invite) dinv = dict_(invite)
inviter = await self.get_user(invite['inviter']) inviter = await self.get_user(invite['inviter'])
@ -966,7 +969,7 @@ class Storage:
return parties[0] return parties[0]
async def get_emoji(self, emoji_id: int) -> Dict: async def get_emoji(self, emoji_id: int) -> Optional[Dict[str, Any]]:
"""Get a single emoji.""" """Get a single emoji."""
row = await self.db.fetchrow(""" row = await self.db.fetchrow("""
SELECT id::text, name, animated, managed, SELECT id::text, name, animated, managed,
@ -976,7 +979,7 @@ class Storage:
""", emoji_id) """, emoji_id)
if not row: if not row:
return return None
drow = dict(row) drow = dict(row)

View File

@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import json import json
from logbook import Logger from logbook import Logger
from typing import Any from typing import Any, Iterable
from quart.json import JSONEncoder from quart.json import JSONEncoder
log = Logger(__name__) log = Logger(__name__)
@ -160,7 +160,7 @@ async def pg_set_json(con):
) )
def yield_chunks(input_list: list, chunk_size: int): def yield_chunks(input_list: Iterable, chunk_size: int):
"""Yield successive n-sized chunks from l. """Yield successive n-sized chunks from l.
Taken from https://stackoverflow.com/a/312464. Taken from https://stackoverflow.com/a/312464.