typings, episode 1

(i installed mypy and its beautiful)
This commit is contained in:
Luna 2019-03-04 05:09:04 -03:00
parent 506bd8afbe
commit d91030a2c1
12 changed files with 65 additions and 58 deletions

View File

@ -217,7 +217,8 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
# for the users that have a state
# in the channel.
if mentions_here:
uids = []
uids = set()
await app.db.execute("""
UPDATE user_read_state
SET mention_count = mention_count + 1
@ -229,7 +230,7 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
# that might not have read permissions
# to the channel.
if mentions_everyone:
uids = []
uids = set()
member_ids = await app.storage.get_member_ids(guild_id)

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import time
from typing import List
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@ -262,7 +263,7 @@ async def _update_pos(channel_id, pos: int):
""", pos, channel_id)
async def _mass_chan_update(guild_id, channel_ids: int):
async def _mass_chan_update(guild_id, channel_ids: List[int]):
for channel_id in channel_ids:
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
@ -337,7 +338,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
if 'position' in j:
channel_data = await app.storage.get_channel_data(guild_id)
chans = [None * len(channel_data)]
chans = [None] * len(channel_data)
for chandata in channel_data:
chans.insert(chandata['position'], int(chandata['id']))

View File

@ -68,7 +68,7 @@ async def get_members(guild_id):
async def _update_member_roles(guild_id: int, member_id: int,
wanted_roles: list):
wanted_roles: set):
"""Update the roles a member has."""
# first, fetch all current roles

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List, Dict
from typing import List, Dict, Tuple
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@ -184,10 +184,11 @@ async def _role_pairs_update(guild_id: int, pairs: list):
await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id)
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs(list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int],
blacklist: List[int] = None) -> List[tuple]:
blacklist: List[int] = None) -> PairList:
"""Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes.
@ -262,7 +263,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
# if its being swapped to leave space, add it
# to the pairs list
if new_pos_2:
if element_2 and new_pos_2:
pairs.append(
((element_1, new_pos_1), (element_2, new_pos_2))
)

View File

@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import Optional
from collections import Counter
from random import choice
@ -36,7 +37,7 @@ def _majority_region_count(regions: list) -> str:
return region
async def _choose_random_region() -> str:
async def _choose_random_region() -> Optional[str]:
"""Give a random voice region."""
regions = await app.db.fetch("""
SELECT id
@ -51,7 +52,7 @@ async def _choose_random_region() -> str:
return choice(regions)
async def _majority_region_any(user_id) -> str:
async def _majority_region_any(user_id) -> Optional[str]:
"""Calculate the most likely region to make the user happy, but
this is based on the guilds the user is IN, instead of the guilds
the user owns."""
@ -79,7 +80,7 @@ async def _majority_region_any(user_id) -> str:
return most_common
async def majority_region(user_id) -> str:
async def majority_region(user_id: int) -> Optional[str]:
"""Given a user ID, give the most likely region for the user to be
happy with."""
regions = await app.db.fetch("""

View File

@ -235,7 +235,7 @@ class GatewayWebsocket:
's': None
})
def _check_ratelimit(self, key: str, ratelimit_key: str):
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -292,7 +292,7 @@ class GatewayWebsocket:
await self.send(payload)
async def _make_guild_list(self) -> List[int]:
async def _make_guild_list(self) -> List[Dict[str, Any]]:
user_id = self.state.user_id
guild_ids = await self._guild_ids()
@ -772,7 +772,7 @@ class GatewayWebsocket:
await self._resume(range(seq, state.seq))
async def _req_guild_members(self, guild_id: str, user_ids: List[int],
async def _req_guild_members(self, guild_id, user_ids: List[int],
query: str, limit: int):
try:
guild_id = int(guild_id)

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import Any
from typing import Any, List
from logbook import Logger
@ -54,13 +54,13 @@ class ChannelDispatcher(DispatcherWithState):
VAL_TYPE = int
async def dispatch(self, channel_id,
event: str, data: Any):
event: str, data: Any) -> List[str]:
"""Dispatch an event to a channel."""
# get everyone who is subscribed
# and store the number of states we dispatched the event to
user_ids = self.state[channel_id]
dispatched = 0
sessions = []
sessions: List[str] = []
# making a copy of user_ids since
# we'll modify it later on.

View File

@ -17,9 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
"""
litecord.pubsub.dispatcher: main dispatcher class
"""
from typing import List
from collections import defaultdict
from logbook import Logger
@ -82,7 +80,8 @@ class Dispatcher:
"""
raise NotImplementedError
async def _dispatch_states(self, states: list, event: str, data) -> int:
async def _dispatch_states(self, states: list, event: str,
data) -> List[str]:
"""Dispatch an event to a list of states."""
res = []

View File

@ -28,7 +28,7 @@ lazy guilds:
import asyncio
from collections import defaultdict
from typing import Any, List, Dict, Union
from typing import Any, List, Dict, Union, Optional, Iterable, Tuple
from dataclasses import dataclass, asdict, field
from logbook import Logger
@ -39,7 +39,7 @@ from litecord.permissions import (
)
from litecord.utils import index_by_func
from litecord.utils import mmh3
from litecord.gateway.state import GatewayState
log = Logger(__name__)
@ -113,7 +113,7 @@ class MemberList:
yield group, self.data[group.gid]
@property
def iter_non_empty(self) -> tuple:
def iter_non_empty(self) -> Generator[Tuple[GroupInfo, List[int]]]:
"""Only iterate through non-empty groups.
Note that while the offline group can be empty, it is always
@ -359,7 +359,7 @@ class GuildMemberList:
# then the final perms for that role if
# any overwrite exists in the channel
final_perms = overwrite_find_mix(
role_perms, self.list.overwrites, group.gid)
role_perms, self.list.overwrites, int(group.gid))
# update the group's permissions
# with the mixed ones
@ -423,7 +423,7 @@ class GuildMemberList:
async def _get_group_for_member(self, member_id: int,
roles: List[Union[str, int]],
status: str) -> GroupID:
status: str) -> Optional[GroupID]:
"""Return a fitting group ID for the member."""
member_roles = list(map(int, roles))
@ -463,15 +463,15 @@ class GuildMemberList:
self.list.members[member_id] = member
self.list.data[group_id].append(member_id)
def _display_name(self, member_id: int) -> str:
def _display_name(self, member_id: int) -> Optional[str]:
"""Get the display name for a given member.
This is more efficient than the old function (not method) of same
name, as we dont need to pass nickname information to it.
"""
member = self.list.members.get(member_id)
if not member_id:
try:
member = self.list.members[member_id]
except KeyError:
return None
username = member['user']['username']
@ -578,7 +578,7 @@ class GuildMemberList:
if not self.state:
self._set_empty_list()
def _get_state(self, session_id: str):
def _get_state(self, session_id: str) -> Optional[GatewayState]:
"""Get the state for a session id.
Wrapper for :meth:`StateManager.fetch_raw`
@ -625,7 +625,8 @@ class GuildMemberList:
return dispatched
async def _resync(self, session_ids: int, item_index: int) -> List[str]:
async def _resync(self, session_ids: List[int],
item_index: int) -> List[str]:
"""Send a SYNC event to all states that are subscribed to an item.
Returns
@ -729,7 +730,7 @@ class GuildMemberList:
# send SYNCs to the state that requested
await self._dispatch_sess([session_id], ops)
def _get_item_index(self, user_id: Union[str, int]) -> int:
def _get_item_index(self, user_id: Union[str, int]) -> Optional[int]:
"""Get the item index a user is on."""
# NOTE: this is inefficient
user_id = int(user_id)
@ -749,7 +750,7 @@ class GuildMemberList:
return None
def _get_group_item_index(self, group_id: GroupID) -> int:
def _get_group_item_index(self, group_id: GroupID) -> Optional[int]:
"""Get the item index a group is on."""
index = 0
@ -773,7 +774,7 @@ class GuildMemberList:
return False
def _get_subs(self, item_index: int) -> filter:
def _get_subs(self, item_index: int) -> Iterable[str]:
"""Get the list of subscribed states to a given item."""
return filter(
lambda sess_id: self._is_subbed(item_index, sess_id),
@ -1141,7 +1142,7 @@ class GuildMemberList:
# when bots come along.
self.list.data[new_group.gid] = []
def _get_role_as_group_idx(self, role_id: int) -> int:
def _get_role_as_group_idx(self, role_id: int) -> Optional[int]:
"""Get a group index representing the given role id.
Returns

View File

@ -147,7 +147,7 @@ class LitecordValidator(Validator):
def validate(reqjson: Union[Dict, List], schema: Dict,
raise_err: bool = True) -> Union[Dict, List]:
raise_err: bool = True) -> Dict:
"""Validate a given document (user-input) and give
the correct document as a result.
"""

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
from logbook import Logger
@ -77,7 +77,7 @@ class Storage:
self.db = app.db
self.presence = None
async def fetchrow_with_json(self, query: str, *args):
async def fetchrow_with_json(self, query: str, *args) -> Any:
"""Fetch a single row with JSON/JSONB support."""
# the pool by itself doesn't have
# set_type_codec, so we must set it manually
@ -86,19 +86,19 @@ class Storage:
await pg_set_json(con)
return await con.fetchrow(query, *args)
async def fetch_with_json(self, query: str, *args):
async def fetch_with_json(self, query: str, *args) -> List[Any]:
"""Fetch many rows with JSON/JSONB support."""
async with self.db.acquire() as con:
await pg_set_json(con)
return await con.fetch(query, *args)
async def execute_with_json(self, query: str, *args):
async def execute_with_json(self, query: str, *args) -> str:
"""Execute a SQL statement with JSON/JSONB support."""
async with self.db.acquire() as con:
await pg_set_json(con)
return await con.execute(query, *args)
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
async def get_user(self, user_id, secure=False) -> Optional[Dict[str, Any]]:
"""Get a single user payload."""
user_id = int(user_id)
@ -115,7 +115,7 @@ class Storage:
""", user_id)
if not user_row:
return
return None
duser = dict(user_row)
@ -141,7 +141,7 @@ class Storage:
"""Search a user"""
if len(discriminator) < 4:
# how do we do this in f-strings again..?
discriminator = '%04d' % discriminator
discriminator = '%04d' % int(discriminator)
return await self.db.fetchval("""
SELECT id FROM users
@ -219,12 +219,12 @@ class Storage:
}
async def get_member_data_one(self, guild_id: int,
member_id: int) -> Dict[str, Any]:
member_id: int) -> Optional[Dict[str, Any]]:
"""Get data about one member in a guild."""
basic = await self._member_basic(guild_id, member_id)
if not basic:
return
return None
return await self._member_dict(basic, guild_id, member_id)
@ -376,7 +376,7 @@ class Storage:
return [r['member_id'] for r in user_ids]
async def _gdm_recipients(self, channel_id: int,
reference_id: int = None) -> List[int]:
reference_id: int = None) -> List[Dict]:
"""Get the list of users that are recipients of the
given Group DM."""
recipients = await self.gdm_recipient_ids(channel_id)
@ -392,7 +392,8 @@ class Storage:
return res
async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]:
async def get_channel(self, channel_id: int,
**kwargs) -> Optional[Dict[str, Any]]:
"""Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id)
ctype = ChannelType(chan_type)
@ -501,7 +502,7 @@ class Storage:
return channels
async def get_role(self, role_id: int,
guild_id: int = None) -> Dict[str, Any]:
guild_id: int = None) -> Optional[Dict[str, Any]]:
"""get a single role's information."""
guild_field = 'AND guild_id = $2' if guild_id else ''
@ -519,7 +520,7 @@ class Storage:
""", *args)
if not row:
return
return None
return dict(row)
@ -769,7 +770,8 @@ class Storage:
res.pop('author_id')
async def get_message(self, message_id: int, user_id=None) -> Dict:
async def get_message(self, message_id: int,
user_id: Optional[int] = None) -> Optional[Dict]:
"""Get a single message's payload."""
row = await self.fetchrow_with_json("""
SELECT id::text, channel_id::text, author_id, webhook_id, content,
@ -780,7 +782,7 @@ class Storage:
""", message_id)
if not row:
return
return None
res = dict(row)
res['nonce'] = str(res['nonce'])
@ -915,7 +917,8 @@ class Storage:
'approximate_member_count': len(mids),
}
async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]:
async def get_invite_metadata(self,
invite_code: str) -> Optional[Dict[str, Any]]:
"""Fetch invite metadata (max_age and friends)."""
invite = await self.db.fetchrow("""
SELECT code, inviter, created_at, uses,
@ -925,7 +928,7 @@ class Storage:
""", invite_code)
if invite is None:
return
return None
dinv = dict_(invite)
inviter = await self.get_user(invite['inviter'])
@ -966,7 +969,7 @@ class Storage:
return parties[0]
async def get_emoji(self, emoji_id: int) -> Dict:
async def get_emoji(self, emoji_id: int) -> Optional[Dict[str, Any]]:
"""Get a single emoji."""
row = await self.db.fetchrow("""
SELECT id::text, name, animated, managed,
@ -976,7 +979,7 @@ class Storage:
""", emoji_id)
if not row:
return
return None
drow = dict(row)

View File

@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import json
from logbook import Logger
from typing import Any
from typing import Any, Iterable
from quart.json import JSONEncoder
log = Logger(__name__)
@ -160,7 +160,7 @@ async def pg_set_json(con):
)
def yield_chunks(input_list: list, chunk_size: int):
def yield_chunks(input_list: Iterable, chunk_size: int):
"""Yield successive n-sized chunks from l.
Taken from https://stackoverflow.com/a/312464.