Merge branch 'feature/improve-typing' into 'master'

improve typing (episode 1)

Closes #85

See merge request litecord/litecord!59
This commit is contained in:
Luna 2020-02-06 21:10:52 +00:00
commit c3b281c940
27 changed files with 163 additions and 114 deletions

View File

@ -1,4 +1,4 @@
image: python:3.7-alpine
image: python:3.8-alpine
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"

View File

@ -26,7 +26,7 @@ flake8 = "*"
pyflakes = "*"
[requires]
python_version = "3.7"
python_version = "3.8"
[pipenv]
allow_prereleases = true

25
Pipfile.lock generated
View File

@ -1,11 +1,11 @@
{
"_meta": {
"hash": {
"sha256": "bdd693f88f93f57b4a6eca42e0c8f053e500ddc40e344d888af86143c1904751"
"sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3"
},
"pipfile-spec": 6,
"requires": {
"python_version": "3.7"
"python_version": "3.8"
},
"sources": [
{
@ -238,10 +238,10 @@
},
"jinja2": {
"hashes": [
"sha256:93187ffbc7808079673ef52771baa950426fd664d3aad1d0fa3e95644360e250",
"sha256:b0eaf100007721b5c16c1fc1eecb87409464edc10469ddc9a22a27a99123be49"
"sha256:c10142f819c2d22bdcd17548c46fa9b77cf4fda45097854c689666bf425e7484",
"sha256:c922560ac46888d47384de1dbdc3daaa2ea993af4b26a436dec31fa2c19ec668"
],
"version": "==2.11.1"
"version": "==3.0.0a1"
},
"logbook": {
"hashes": [
@ -554,14 +554,6 @@
"index": "pypi",
"version": "==3.7.9"
},
"importlib-metadata": {
"hashes": [
"sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302",
"sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b"
],
"markers": "python_version < '3.8'",
"version": "==1.5.0"
},
"mccabe": {
"hashes": [
"sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42",
@ -749,13 +741,6 @@
"sha256:f28b3e8a6483e5d49e7f8949ac1a78314e740333ae305b4ba5defd3e74fb37a8"
],
"version": "==0.1.8"
},
"zipp": {
"hashes": [
"sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28",
"sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98"
],
"version": "==2.1.0"
}
}
}

View File

@ -66,7 +66,7 @@ or third party libraries (such as [Eris](https://github.com/abalabahaha/eris)).
Requirements:
- **Python 3.7+**
- **Python 3.8+**
- PostgreSQL (tested using 9.6+), SQL knowledge is recommended.
- gifsicle for GIF emoji and avatar handling
- [pipenv]

View File

@ -19,6 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import string
from random import choice
from typing import Optional
from quart import Blueprint, jsonify, current_app as app, request
@ -36,7 +37,7 @@ async def _gen_inv() -> str:
return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str:
async def gen_inv(ctx) -> Optional[str]:
"""Generate an invite."""
for _ in range(10):
possible_inv = await _gen_inv()

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from pathlib import Path
from typing import Optional
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@ -146,7 +147,7 @@ async def _dm_pre_dispatch(channel_id, peer_id):
async def create_message(
channel_id: int, actual_guild_id: int, author_id: int, data: dict
channel_id: int, actual_guild_id: Optional[int], author_id: int, data: dict
) -> int:
message_id = get_snowflake()
@ -159,7 +160,7 @@ async def create_message(
content, tts, mention_everyone, nonce, message_type,
embeds)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
""",
""",
message_id,
channel_id,
actual_guild_id,
@ -186,7 +187,7 @@ async def _create_message(channel_id):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
actual_guild_id = None
actual_guild_id: Optional[int] = None
if ctype in GUILD_CHANS:
await channel_perm_check(user_id, channel_id, "send_messages")

View File

@ -114,7 +114,7 @@ async def add_pin(channel_id, message_id):
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
)
return "", 204

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from enum import IntEnum
from typing import List, Union, Tuple
from typing import List, Union, Tuple, TypedDict, Optional
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@ -60,7 +60,12 @@ def emoji_info_from_str(emoji: str) -> tuple:
return emoji_type, emoji_id, emoji_name
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
class PartialEmoji(TypedDict):
id: Optional[int]
name: str
def partial_emoji(emoji_type, emoji_id, emoji_name) -> PartialEmoji:
print(emoji_type, emoji_id, emoji_name)
return {
"id": None if emoji_type == EmojiType.UNICODE else emoji_id,

View File

@ -386,14 +386,14 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
return "", 204
async def _update_channel_common(channel_id, guild_id: int, j: dict):
async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
if "name" in j:
await app.db.execute(
"""
UPDATE guild_channels
SET name = $1
WHERE id = $2
""",
UPDATE guild_channels
SET name = $1
WHERE id = $2
""",
j["name"],
channel_id,
)
@ -401,7 +401,9 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
if "position" in j:
channel_data = await app.storage.get_channel_data(guild_id)
chans = [None] * len(channel_data)
# get an ordered list of the chans array by position
# TODO bad impl. can break easily. maybe dict?
chans: List[Optional[int]] = [None] * len(channel_data)
for chandata in channel_data:
chans.insert(chandata["position"], int(chandata["id"]))
@ -422,7 +424,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
left_shift = new_pos > current_pos
# find all channels that we'll have to shift
shift_block = (
shift_block: List[Optional[int]] = (
chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
)
@ -509,9 +511,7 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int):
channel_id,
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
)
await send_sys_message(channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id)
if "icon" in j:
new_icon = await app.icons.update(
@ -528,13 +528,11 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int):
channel_id,
)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
)
await send_sys_message(channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id)
@bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
async def update_channel(channel_id):
async def update_channel(channel_id: int):
"""Update a channel's information"""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import Union, List
from typing import Union, List, Optional
from quart import current_app as app
@ -66,7 +66,7 @@ async def guild_owner_check(user_id: int, guild_id: int):
async def channel_check(
user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
user_id, channel_id, *, only: Optional[Union[ChannelType, List[ChannelType]]] = None
):
"""Check if the current user is authorized
to read the channel's information."""
@ -77,10 +77,10 @@ async def channel_check(
ctype = ChannelType(chan_type)
if only and not isinstance(only, list):
if (only is not None) and not isinstance(only, list):
only = [only]
if only and ctype not in only:
if (only is not None) and ctype not in only:
raise ChannelNotFound("invalid channel type")
if ctype in GUILD_CHANS:

View File

@ -113,9 +113,7 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
await app.dispatcher.sub("channel", peer_id)
if user_id:
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
)
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
@ -145,9 +143,7 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
author_id = peer_id if user_id is None else user_id
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
)
await send_sys_message(channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id)
async def gdm_destroy(channel_id):

View File

@ -130,14 +130,11 @@ async def modify_channel_pos(guild_id):
# the same schema and all.
raw_j = await request.get_json()
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
j = j["roles"]
roles = j["roles"]
channels = await app.storage.get_channel_data(guild_id)
channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
swap_pairs = gen_pairs(j, channel_positions)
swap_pairs = gen_pairs(roles, channel_positions)
await _do_channel_swaps(guild_id, swap_pairs)
return "", 204

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Optional
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
@ -122,7 +122,7 @@ PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs(
list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int],
blacklist: List[int] = None,
blacklist: Optional[List[int]] = None,
) -> PairList:
"""Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes.
@ -162,7 +162,7 @@ def gen_pairs(
List of swaps to do to achieve the preferred
state given by ``list_of_changes``.
"""
pairs = []
pairs: PairList = []
blacklist = blacklist or []
preferred_state = {
@ -222,9 +222,9 @@ async def update_guild_role_positions(guild_id):
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
# extract the list out
j = j["roles"]
roles = j["roles"]
log.debug("role stuff: {!r}", j)
log.debug("role stuff: {!r}", roles)
all_roles = await app.storage.get_role_data(guild_id)
@ -238,7 +238,7 @@ async def update_guild_role_positions(guild_id):
# NOTE: ^ this is related to the positioning of the roles.
pairs = gen_pairs(
j,
roles,
roles_pos,
# always ignore people trying to change
# the @everyone's role position

View File

@ -96,7 +96,7 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
return f"data:image/jpeg;base64,{icon}" if icon else None
async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
async def _general_guild_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
encoded = sanitize_icon(icon)
icon_kwargs = {"always_icon": True}

View File

@ -222,14 +222,14 @@ async def get_guild_webhook(guild_id):
async def get_single_webhook(webhook_id):
"""Get a single webhook's information."""
await _webhook_check_fw(webhook_id)
return await jsonify(await get_webhook(webhook_id))
return jsonify(await get_webhook(webhook_id))
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
async def get_tokened_webhook(webhook_id, webhook_token):
"""Get a webhook using its token."""
await webhook_token_check(webhook_id, webhook_token)
return await jsonify(await get_webhook(webhook_id, secure=False))
return jsonify(await get_webhook(webhook_id, secure=False))
async def _update_webhook(webhook_id: int, j: dict):
@ -289,6 +289,7 @@ async def modify_webhook(webhook_id: int):
await _update_webhook(webhook_id, j)
webhook = await get_webhook(webhook_id)
assert webhook is not None
# we don't need to cast channel_id to int since that isn't
# used in the dispatcher call
@ -313,7 +314,9 @@ async def modify_webhook_tokened(webhook_id, webhook_token):
async def delete_webhook(webhook_id: int):
"""Delete a webhook."""
webhook = await get_webhook(webhook_id)
assert webhook is not None
# TODO use returning?
res = await app.db.execute(
"""
DELETE FROM webhooks
@ -423,7 +426,10 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
# we still fetch the URL to check its validity, mimetypes, etc
# but in the end, we will store it under the webhook_avatars table,
# not IconManager.
resp, raw = await fetch_mediaproxy_img(avatar_url)
res = await fetch_mediaproxy_img(avatar_url)
if res is None:
raise BadRequest("Failed to fetch URL.")
resp, raw = res
# raw_b64 = base64.b64encode(raw).decode()
mime = resp.headers["content-type"]
@ -469,6 +475,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
given_embeds = j.get("embeds", [])
webhook = await get_webhook(webhook_id)
assert webhook is not None
# webhooks have TWO avatars. one is from settings, the other is from
# the json's icon_url. one can be handled gracefully by IconManager,

View File

@ -16,6 +16,7 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import Optional
from quart import current_app as app
@ -24,12 +25,15 @@ from litecord.enums import RelationshipType
async def channel_ack(
user_id: int, guild_id: int, channel_id: int, message_id: int = None
user_id: int, guild_id: int, channel_id: int, message_id: Optional[int] = None
):
"""ACK a channel."""
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
message_id = message_id or await app.storage.chan_last_message(channel_id)
# never ack without a message, as that breaks read state.
if message_id is None:
return
await app.db.execute(
"""

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from random import randint
from typing import Tuple, Optional
from typing import Tuple, Optional, List
from quart import current_app as app
from asyncpg import UniqueViolationError
@ -33,13 +33,13 @@ from ..utils import rand_hex
log = Logger(__name__)
async def mass_user_update(user_id):
async def mass_user_update(user_id: int):
"""Dispatch USER_UPDATE in a mass way."""
# by using dispatch_with_filter
# we're guaranteeing all shards will get
# a USER_UPDATE once and not any others.
session_ids = []
session_ids: List[str] = []
public_user = await app.storage.get_user(user_id)
private_user = await app.storage.get_user(user_id, secure=True)

View File

@ -21,7 +21,7 @@ import re
import asyncio
import urllib.parse
from pathlib import Path
from typing import List
from typing import List, Optional
from quart import current_app as app
from logbook import Logger
@ -35,16 +35,16 @@ log = Logger(__name__)
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
async def fetch_mediaproxy_img_meta(url) -> dict:
async def fetch_mediaproxy_img_meta(url) -> Optional[dict]:
"""Insert media metadata as an embed."""
img_proxy_url = proxify(url)
meta = await fetch_metadata(url)
if meta is None:
return
return None
if not meta["image"]:
return
return None
return {
"type": "image",
@ -135,13 +135,15 @@ async def process_url_embed(payload: dict, *, delay=0):
# if it isn't, we forward an /embed/ scope call to mediaproxy
# to generate an embed for us out of the url.
new_embeds = []
new_embeds: List[dict] = []
for url in urls:
url: List[dict] = EmbedURL(url)
for upstream_url in urls:
url = EmbedURL(upstream_url)
if is_media_url(url):
embeds = [await fetch_mediaproxy_img_meta(url)]
embed = await fetch_mediaproxy_img_meta(url)
if embed is not None:
embeds = [embed]
else:
embeds = await fetch_mediaproxy_embed(url)

View File

@ -98,6 +98,9 @@ class Icon:
return get_ext(self.mime)
def __bool__(self):
return self.key and self.icon_hash and self.mime
class ImageError(Exception):
"""Image error class."""
@ -197,6 +200,9 @@ def _gen_update_sql(scope: str) -> str:
def _invalid(kwargs: dict) -> Optional[Icon]:
"""Send an invalid value."""
# TODO: remove optinality off this (really badly designed):
# - also remove kwargs off this function
# - also make an Icon.empty() constructor, and remove need for this entirely
if not kwargs.get("always_icon", False):
return None
@ -519,6 +525,7 @@ class IconManager:
key = str(key)
old_icon = await self.generic_get(scope, key, old_icon_hash)
await self.delete(old_icon)
if old_icon:
await self.delete(old_icon)
return await self.put(scope, key, new_icon_data, **kwargs)

View File

@ -88,6 +88,7 @@ class Permissions(ctypes.Union):
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
EMPTY_PERMISSIONS = Permissions(0)
async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:

View File

@ -89,7 +89,7 @@ class ChannelDispatcher(DispatcherWithFlags):
):
continue
cur_sess = []
cur_sess: List[str] = []
if (
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")

View File

@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List
from logbook import Logger
from .dispatcher import DispatcherWithState
@ -38,7 +39,7 @@ class FriendDispatcher(DispatcherWithState):
async def dispatch_filter(self, user_id: int, func, event, data):
"""Dispatch an event to all of a users' friends."""
peer_ids = self.state[user_id]
sessions = []
sessions: List[str] = []
for peer_id in peer_ids:
# dispatch to the user instead of the "shards tied to a guild"

View File

@ -39,6 +39,7 @@ from litecord.permissions import (
overwrite_find_mix,
get_permissions,
role_permissions,
EMPTY_PERMISSIONS,
)
from litecord.utils import index_by_func
from litecord.utils import mmh3
@ -264,8 +265,8 @@ class GuildMemberList:
self.list = MemberList()
#: store the states that are subscribed to the list.
# type is {session_id: set[list]}
self.state: Dict[str, Set[List[int, int]]] = defaultdict(set)
# type is {session_id: set[tuple]}
self.state: Dict[str, Set[Tuple[int, int]]] = defaultdict(set)
self._list_lock = asyncio.Lock()
@ -414,8 +415,8 @@ class GuildMemberList:
# inject default groups 'online' and 'offline'
# their position is always going to be the last ones.
self.list.groups = role_groups + [
GroupInfo("online", "online", MAX_ROLES + 1, 0),
GroupInfo("offline", "offline", MAX_ROLES + 2, 0),
GroupInfo("online", "online", MAX_ROLES + 1, EMPTY_PERMISSIONS),
GroupInfo("offline", "offline", MAX_ROLES + 2, EMPTY_PERMISSIONS),
]
async def _get_group_for_member(
@ -808,6 +809,8 @@ class GuildMemberList:
ops = []
old_user_index = self._get_item_index(user_id)
assert old_user_index is not None
old_group_index = self._get_group_item_index(old_group)
ops.append(Operation("DELETE", {"index": old_user_index}))
@ -819,6 +822,7 @@ class GuildMemberList:
await self._sort_groups()
new_user_index = self._get_item_index(user_id)
assert new_user_index is not None
ops.append(
Operation(
@ -931,6 +935,7 @@ class GuildMemberList:
# if unknown state, remove from the subscriber list
if state is None:
self.state.pop(session_id)
continue
# if we aren't talking about the state the user
# being removed is subscribed to, ignore
@ -1365,8 +1370,8 @@ class GuildMemberList:
len(self.state),
)
self.guild_id = None
self.channel_id = None
self.guild_id = 0
self.channel_id = 0
self.main = None
self._set_empty_list()
self.state = {}
@ -1392,7 +1397,7 @@ class LazyGuildDispatcher(Dispatcher):
#: store which guilds have their
# respective GMLs
# {guild_id: [chan_id, ...], ...}
self.guild_map = defaultdict(list)
self.guild_map: Dict[int, List[int]] = defaultdict(list)
async def get_gml(self, channel_id: int):
"""Get a guild list for a channel ID,
@ -1414,7 +1419,18 @@ class LazyGuildDispatcher(Dispatcher):
def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]:
"""Get all member lists for a given guild."""
return list(map(self.state.get, self.guild_map[guild_id]))
res: List[GuildMemberList] = []
channel_ids: List[int] = self.guild_map[guild_id]
for channel_id in channel_ids:
guild_list: Optional[GuildMemberList] = self.state.get(channel_id)
if guild_list is None:
self.guild_map[guild_id].remove(channel_id)
continue
res.append(guild_list)
return res
async def unsub(self, chan_id, session_id):
"""Unsubscribe a session from the list."""

View File

@ -17,13 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union, TypedDict
from logbook import Logger
from litecord.enums import ChannelType
from litecord.schemas import USER_MENTION, ROLE_MENTION
from litecord.blueprints.channel.reactions import EmojiType, emoji_sql, partial_emoji
from litecord.blueprints.channel.reactions import (
EmojiType,
emoji_sql,
partial_emoji,
PartialEmoji,
)
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
@ -64,6 +69,12 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: str):
return list(filter(lambda recipient: recipient["id"] != user_id, recipients))
class EmojiStats(TypedDict):
count: int
me: bool
emoji: PartialEmoji
class Storage:
"""Class for common SQL statements."""
@ -373,7 +384,7 @@ class Storage:
members = await self.get_member_multi(guild_id, mids)
return members
async def chan_last_message(self, channel_id: int):
async def chan_last_message(self, channel_id: int) -> Optional[int]:
"""Get the last message ID in a channel."""
return await self.db.fetchval(
"""
@ -491,7 +502,7 @@ class Storage:
return [r["member_id"] for r in user_ids]
async def _gdm_recipients(
self, channel_id: int, reference_id: int = None
self, channel_id: int, reference_id: Optional[int] = None
) -> List[Dict]:
"""Get the list of users that are recipients of the
given Group DM."""
@ -576,11 +587,12 @@ class Storage:
drow = dict(gdm_row)
drow["type"] = chan_type
drow["recipients"] = await self._gdm_recipients(
channel_id, kwargs.get("user_id")
)
drow["last_message_id"] = await self.chan_last_message_str(channel_id)
user_id: Optional[int] = kwargs.get("user_id")
assert user_id is not None
drow["recipients"] = await self._gdm_recipients(channel_id, user_id)
drow["last_message_id"] = await self.chan_last_message_str(channel_id)
return drow
return None
@ -634,7 +646,7 @@ class Storage:
return channels
async def get_role(
self, role_id: int, guild_id: int = None
self, role_id: int, guild_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""get a single role's information."""
@ -732,6 +744,8 @@ class Storage:
mids = [int(m["user"]["id"]) for m in members]
assert self.presence is not None
return {
**res,
**{
@ -822,10 +836,10 @@ class Storage:
)
# ordered list of emoji
emoji = []
emoji: List[Union[int, str]] = []
# the current state of emoji info
react_stats = {}
react_stats: Dict[Union[str, int], EmojiStats] = {}
# to generate the list, we pass through all
# all reactions and insert them all.
@ -1007,6 +1021,7 @@ class Storage:
# calculate user mentions and role mentions by regex
async def _get_member(user_id):
user = await self.get_user(user_id)
assert user is not None
member = None
if guild_id:
@ -1143,6 +1158,7 @@ class Storage:
return {}
mids = await self.get_member_ids(guild_id)
assert self.presence is not None
pres = await self.presence.guild_presences(mids, guild_id)
online_count = sum(1 for p in pres if p["status"] == "online")
@ -1172,7 +1188,7 @@ class Storage:
return dinv
async def get_dm(self, dm_id: int, user_id: int = None) -> Optional[Dict]:
async def get_dm(self, dm_id: int, user_id: Optional[int] = None) -> Optional[Dict]:
"""Get a DM channel."""
dm_chan = await self.get_channel(dm_id)

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from logbook import Logger
from quart import current_app as app
from winter import get_snowflake
from litecord.enums import MessageType
@ -25,7 +26,7 @@ from litecord.enums import MessageType
log = Logger(__name__)
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
async def _handle_pin_msg(channel_id, _pinned_id, author_id):
"""Handle a message pin."""
new_id = get_snowflake()
@ -48,7 +49,7 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
# TODO: decrease repetition between add and remove handlers
async def _handle_recp_add(app, channel_id, author_id, peer_id):
async def _handle_recp_add(channel_id, author_id, peer_id):
new_id = get_snowflake()
await app.db.execute(
@ -69,7 +70,7 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
return new_id
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
async def _handle_recp_rmv(channel_id, author_id, peer_id):
new_id = get_snowflake()
await app.db.execute(
@ -90,7 +91,7 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
return new_id
async def _handle_gdm_name_edit(app, channel_id, author_id):
async def _handle_gdm_name_edit(channel_id, author_id):
new_id = get_snowflake()
gdm_name = await app.db.fetchval(
@ -123,7 +124,7 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
return new_id
async def _handle_gdm_icon_edit(app, channel_id, author_id):
async def _handle_gdm_icon_edit(channel_id, author_id):
new_id = get_snowflake()
await app.db.execute(
@ -145,7 +146,7 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
async def send_sys_message(
app, channel_id: int, m_type: MessageType, *args, **kwargs
channel_id: int, m_type: MessageType, *args, **kwargs
) -> int:
"""Send a system message.
@ -179,7 +180,7 @@ async def send_sys_message(
except KeyError:
raise ValueError("Invalid system message type")
message_id = await handler(app, channel_id, *args, **kwargs)
message_id = await handler(channel_id, *args, **kwargs)
message = await app.storage.get_message(message_id)

11
mypy.ini Normal file
View File

@ -0,0 +1,11 @@
[mypy]
check_untyped_defs = True
no_implicit_optional = True
[mypy-logbook]
ignore_missing_imports = True
[mypy-quart]
ignore_missing_imports = True
[mypy-winter]
ignore_missing_imports = True
[mypy-asyncpg]
ignore_missing_imports = True

View File

@ -1,5 +1,5 @@
[tox]
envlist = py3.7
envlist = py3.8
[testenv]
deps = -rrequirements.txt