fix types (big)

varied fixes across many, many files.
This commit is contained in:
Luna 2020-02-06 21:10:51 +00:00
parent 020e03bf6d
commit e0e59f8b63
27 changed files with 163 additions and 114 deletions

View File

@ -1,4 +1,4 @@
image: python:3.7-alpine image: python:3.8-alpine
variables: variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"

View File

@ -26,7 +26,7 @@ flake8 = "*"
pyflakes = "*" pyflakes = "*"
[requires] [requires]
python_version = "3.7" python_version = "3.8"
[pipenv] [pipenv]
allow_prereleases = true allow_prereleases = true

25
Pipfile.lock generated
View File

@ -1,11 +1,11 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "bdd693f88f93f57b4a6eca42e0c8f053e500ddc40e344d888af86143c1904751" "sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
"python_version": "3.7" "python_version": "3.8"
}, },
"sources": [ "sources": [
{ {
@ -238,10 +238,10 @@
}, },
"jinja2": { "jinja2": {
"hashes": [ "hashes": [
"sha256:93187ffbc7808079673ef52771baa950426fd664d3aad1d0fa3e95644360e250", "sha256:c10142f819c2d22bdcd17548c46fa9b77cf4fda45097854c689666bf425e7484",
"sha256:b0eaf100007721b5c16c1fc1eecb87409464edc10469ddc9a22a27a99123be49" "sha256:c922560ac46888d47384de1dbdc3daaa2ea993af4b26a436dec31fa2c19ec668"
], ],
"version": "==2.11.1" "version": "==3.0.0a1"
}, },
"logbook": { "logbook": {
"hashes": [ "hashes": [
@ -554,14 +554,6 @@
"index": "pypi", "index": "pypi",
"version": "==3.7.9" "version": "==3.7.9"
}, },
"importlib-metadata": {
"hashes": [
"sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302",
"sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b"
],
"markers": "python_version < '3.8'",
"version": "==1.5.0"
},
"mccabe": { "mccabe": {
"hashes": [ "hashes": [
"sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42",
@ -749,13 +741,6 @@
"sha256:f28b3e8a6483e5d49e7f8949ac1a78314e740333ae305b4ba5defd3e74fb37a8" "sha256:f28b3e8a6483e5d49e7f8949ac1a78314e740333ae305b4ba5defd3e74fb37a8"
], ],
"version": "==0.1.8" "version": "==0.1.8"
},
"zipp": {
"hashes": [
"sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28",
"sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98"
],
"version": "==2.1.0"
} }
} }
} }

View File

@ -66,7 +66,7 @@ or third party libraries (such as [Eris](https://github.com/abalabahaha/eris)).
Requirements: Requirements:
- **Python 3.7+** - **Python 3.8+**
- PostgreSQL (tested using 9.6+), SQL knowledge is recommended. - PostgreSQL (tested using 9.6+), SQL knowledge is recommended.
- gifsicle for GIF emoji and avatar handling - gifsicle for GIF emoji and avatar handling
- [pipenv] - [pipenv]

View File

@ -19,6 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import string import string
from random import choice from random import choice
from typing import Optional
from quart import Blueprint, jsonify, current_app as app, request from quart import Blueprint, jsonify, current_app as app, request
@ -36,7 +37,7 @@ async def _gen_inv() -> str:
return "".join(choice(ALPHABET) for _ in range(6)) return "".join(choice(ALPHABET) for _ in range(6))
async def gen_inv(ctx) -> str: async def gen_inv(ctx) -> Optional[str]:
"""Generate an invite.""" """Generate an invite."""
for _ in range(10): for _ in range(10):
possible_inv = await _gen_inv() possible_inv = await _gen_inv()

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from pathlib import Path from pathlib import Path
from typing import Optional
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
@ -146,7 +147,7 @@ async def _dm_pre_dispatch(channel_id, peer_id):
async def create_message( async def create_message(
channel_id: int, actual_guild_id: int, author_id: int, data: dict channel_id: int, actual_guild_id: Optional[int], author_id: int, data: dict
) -> int: ) -> int:
message_id = get_snowflake() message_id = get_snowflake()
@ -186,7 +187,7 @@ async def _create_message(channel_id):
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
actual_guild_id = None actual_guild_id: Optional[int] = None
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
await channel_perm_check(user_id, channel_id, "send_messages") await channel_perm_check(user_id, channel_id, "send_messages")

View File

@ -114,7 +114,7 @@ async def add_pin(channel_id, message_id):
) )
await send_sys_message( await send_sys_message(
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
) )
return "", 204 return "", 204

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from enum import IntEnum from enum import IntEnum
from typing import List, Union, Tuple from typing import List, Union, Tuple, TypedDict, Optional
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
@ -60,7 +60,12 @@ def emoji_info_from_str(emoji: str) -> tuple:
return emoji_type, emoji_id, emoji_name return emoji_type, emoji_id, emoji_name
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: class PartialEmoji(TypedDict):
id: Optional[int]
name: str
def partial_emoji(emoji_type, emoji_id, emoji_name) -> PartialEmoji:
print(emoji_type, emoji_id, emoji_name) print(emoji_type, emoji_id, emoji_name)
return { return {
"id": None if emoji_type == EmojiType.UNICODE else emoji_id, "id": None if emoji_type == EmojiType.UNICODE else emoji_id,

View File

@ -386,7 +386,7 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
return "", 204 return "", 204
async def _update_channel_common(channel_id, guild_id: int, j: dict): async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
if "name" in j: if "name" in j:
await app.db.execute( await app.db.execute(
""" """
@ -401,7 +401,9 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
if "position" in j: if "position" in j:
channel_data = await app.storage.get_channel_data(guild_id) channel_data = await app.storage.get_channel_data(guild_id)
chans = [None] * len(channel_data) # get an ordered list of the chans array by position
# TODO bad impl. can break easily. maybe dict?
chans: List[Optional[int]] = [None] * len(channel_data)
for chandata in channel_data: for chandata in channel_data:
chans.insert(chandata["position"], int(chandata["id"])) chans.insert(chandata["position"], int(chandata["id"]))
@ -422,7 +424,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
left_shift = new_pos > current_pos left_shift = new_pos > current_pos
# find all channels that we'll have to shift # find all channels that we'll have to shift
shift_block = ( shift_block: List[Optional[int]] = (
chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos] chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
) )
@ -509,9 +511,7 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int):
channel_id, channel_id,
) )
await send_sys_message( await send_sys_message(channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id)
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
)
if "icon" in j: if "icon" in j:
new_icon = await app.icons.update( new_icon = await app.icons.update(
@ -528,13 +528,11 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int):
channel_id, channel_id,
) )
await send_sys_message( await send_sys_message(channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id)
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
)
@bp.route("/<int:channel_id>", methods=["PUT", "PATCH"]) @bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
async def update_channel(channel_id): async def update_channel(channel_id: int):
"""Update a channel's information""" """Update a channel's information"""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Union, List from typing import Union, List, Optional
from quart import current_app as app from quart import current_app as app
@ -66,7 +66,7 @@ async def guild_owner_check(user_id: int, guild_id: int):
async def channel_check( async def channel_check(
user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None user_id, channel_id, *, only: Optional[Union[ChannelType, List[ChannelType]]] = None
): ):
"""Check if the current user is authorized """Check if the current user is authorized
to read the channel's information.""" to read the channel's information."""
@ -77,10 +77,10 @@ async def channel_check(
ctype = ChannelType(chan_type) ctype = ChannelType(chan_type)
if only and not isinstance(only, list): if (only is not None) and not isinstance(only, list):
only = [only] only = [only]
if only and ctype not in only: if (only is not None) and ctype not in only:
raise ChannelNotFound("invalid channel type") raise ChannelNotFound("invalid channel type")
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:

View File

@ -113,9 +113,7 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
await app.dispatcher.sub("channel", peer_id) await app.dispatcher.sub("channel", peer_id)
if user_id: if user_id:
await send_sys_message( await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
)
async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None): async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
@ -145,9 +143,7 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
author_id = peer_id if user_id is None else user_id author_id = peer_id if user_id is None else user_id
await send_sys_message( await send_sys_message(channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id)
app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
)
async def gdm_destroy(channel_id): async def gdm_destroy(channel_id):

View File

@ -130,14 +130,11 @@ async def modify_channel_pos(guild_id):
# the same schema and all. # the same schema and all.
raw_j = await request.get_json() raw_j = await request.get_json()
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION) j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
j = j["roles"] roles = j["roles"]
channels = await app.storage.get_channel_data(guild_id) channels = await app.storage.get_channel_data(guild_id)
channel_positions = {chan["position"]: int(chan["id"]) for chan in channels} channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
swap_pairs = gen_pairs(roles, channel_positions)
swap_pairs = gen_pairs(j, channel_positions)
await _do_channel_swaps(guild_id, swap_pairs) await _do_channel_swaps(guild_id, swap_pairs)
return "", 204 return "", 204

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Dict, Tuple from typing import List, Dict, Tuple, Optional
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
@ -122,7 +122,7 @@ PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
def gen_pairs( def gen_pairs(
list_of_changes: List[Dict[str, int]], list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int], current_state: Dict[int, int],
blacklist: List[int] = None, blacklist: Optional[List[int]] = None,
) -> PairList: ) -> PairList:
"""Generate a list of pairs that, when applied to the database, """Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes. will generate the desired state given in list_of_changes.
@ -162,7 +162,7 @@ def gen_pairs(
List of swaps to do to achieve the preferred List of swaps to do to achieve the preferred
state given by ``list_of_changes``. state given by ``list_of_changes``.
""" """
pairs = [] pairs: PairList = []
blacklist = blacklist or [] blacklist = blacklist or []
preferred_state = { preferred_state = {
@ -222,9 +222,9 @@ async def update_guild_role_positions(guild_id):
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION) j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
# extract the list out # extract the list out
j = j["roles"] roles = j["roles"]
log.debug("role stuff: {!r}", j) log.debug("role stuff: {!r}", roles)
all_roles = await app.storage.get_role_data(guild_id) all_roles = await app.storage.get_role_data(guild_id)
@ -238,7 +238,7 @@ async def update_guild_role_positions(guild_id):
# NOTE: ^ this is related to the positioning of the roles. # NOTE: ^ this is related to the positioning of the roles.
pairs = gen_pairs( pairs = gen_pairs(
j, roles,
roles_pos, roles_pos,
# always ignore people trying to change # always ignore people trying to change
# the @everyone's role position # the @everyone's role position

View File

@ -96,7 +96,7 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
return f"data:image/jpeg;base64,{icon}" if icon else None return f"data:image/jpeg;base64,{icon}" if icon else None
async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs): async def _general_guild_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
encoded = sanitize_icon(icon) encoded = sanitize_icon(icon)
icon_kwargs = {"always_icon": True} icon_kwargs = {"always_icon": True}

View File

@ -222,14 +222,14 @@ async def get_guild_webhook(guild_id):
async def get_single_webhook(webhook_id): async def get_single_webhook(webhook_id):
"""Get a single webhook's information.""" """Get a single webhook's information."""
await _webhook_check_fw(webhook_id) await _webhook_check_fw(webhook_id)
return await jsonify(await get_webhook(webhook_id)) return jsonify(await get_webhook(webhook_id))
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"]) @bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
async def get_tokened_webhook(webhook_id, webhook_token): async def get_tokened_webhook(webhook_id, webhook_token):
"""Get a webhook using its token.""" """Get a webhook using its token."""
await webhook_token_check(webhook_id, webhook_token) await webhook_token_check(webhook_id, webhook_token)
return await jsonify(await get_webhook(webhook_id, secure=False)) return jsonify(await get_webhook(webhook_id, secure=False))
async def _update_webhook(webhook_id: int, j: dict): async def _update_webhook(webhook_id: int, j: dict):
@ -289,6 +289,7 @@ async def modify_webhook(webhook_id: int):
await _update_webhook(webhook_id, j) await _update_webhook(webhook_id, j)
webhook = await get_webhook(webhook_id) webhook = await get_webhook(webhook_id)
assert webhook is not None
# we don't need to cast channel_id to int since that isn't # we don't need to cast channel_id to int since that isn't
# used in the dispatcher call # used in the dispatcher call
@ -313,7 +314,9 @@ async def modify_webhook_tokened(webhook_id, webhook_token):
async def delete_webhook(webhook_id: int): async def delete_webhook(webhook_id: int):
"""Delete a webhook.""" """Delete a webhook."""
webhook = await get_webhook(webhook_id) webhook = await get_webhook(webhook_id)
assert webhook is not None
# TODO use returning?
res = await app.db.execute( res = await app.db.execute(
""" """
DELETE FROM webhooks DELETE FROM webhooks
@ -423,7 +426,10 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
# we still fetch the URL to check its validity, mimetypes, etc # we still fetch the URL to check its validity, mimetypes, etc
# but in the end, we will store it under the webhook_avatars table, # but in the end, we will store it under the webhook_avatars table,
# not IconManager. # not IconManager.
resp, raw = await fetch_mediaproxy_img(avatar_url) res = await fetch_mediaproxy_img(avatar_url)
if res is None:
raise BadRequest("Failed to fetch URL.")
resp, raw = res
# raw_b64 = base64.b64encode(raw).decode() # raw_b64 = base64.b64encode(raw).decode()
mime = resp.headers["content-type"] mime = resp.headers["content-type"]
@ -469,6 +475,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
given_embeds = j.get("embeds", []) given_embeds = j.get("embeds", [])
webhook = await get_webhook(webhook_id) webhook = await get_webhook(webhook_id)
assert webhook is not None
# webhooks have TWO avatars. one is from settings, the other is from # webhooks have TWO avatars. one is from settings, the other is from
# the json's icon_url. one can be handled gracefully by IconManager, # the json's icon_url. one can be handled gracefully by IconManager,

View File

@ -16,6 +16,7 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Optional
from quart import current_app as app from quart import current_app as app
@ -24,12 +25,15 @@ from litecord.enums import RelationshipType
async def channel_ack( async def channel_ack(
user_id: int, guild_id: int, channel_id: int, message_id: int = None user_id: int, guild_id: int, channel_id: int, message_id: Optional[int] = None
): ):
"""ACK a channel.""" """ACK a channel."""
if not message_id: message_id = message_id or await app.storage.chan_last_message(channel_id)
message_id = await app.storage.chan_last_message(channel_id)
# never ack without a message, as that breaks read state.
if message_id is None:
return
await app.db.execute( await app.db.execute(
""" """

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from random import randint from random import randint
from typing import Tuple, Optional from typing import Tuple, Optional, List
from quart import current_app as app from quart import current_app as app
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
@ -33,13 +33,13 @@ from ..utils import rand_hex
log = Logger(__name__) log = Logger(__name__)
async def mass_user_update(user_id): async def mass_user_update(user_id: int):
"""Dispatch USER_UPDATE in a mass way.""" """Dispatch USER_UPDATE in a mass way."""
# by using dispatch_with_filter # by using dispatch_with_filter
# we're guaranteeing all shards will get # we're guaranteeing all shards will get
# a USER_UPDATE once and not any others. # a USER_UPDATE once and not any others.
session_ids = [] session_ids: List[str] = []
public_user = await app.storage.get_user(user_id) public_user = await app.storage.get_user(user_id)
private_user = await app.storage.get_user(user_id, secure=True) private_user = await app.storage.get_user(user_id, secure=True)

View File

@ -21,7 +21,7 @@ import re
import asyncio import asyncio
import urllib.parse import urllib.parse
from pathlib import Path from pathlib import Path
from typing import List from typing import List, Optional
from quart import current_app as app from quart import current_app as app
from logbook import Logger from logbook import Logger
@ -35,16 +35,16 @@ log = Logger(__name__)
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm") MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
async def fetch_mediaproxy_img_meta(url) -> dict: async def fetch_mediaproxy_img_meta(url) -> Optional[dict]:
"""Insert media metadata as an embed.""" """Insert media metadata as an embed."""
img_proxy_url = proxify(url) img_proxy_url = proxify(url)
meta = await fetch_metadata(url) meta = await fetch_metadata(url)
if meta is None: if meta is None:
return return None
if not meta["image"]: if not meta["image"]:
return return None
return { return {
"type": "image", "type": "image",
@ -135,13 +135,15 @@ async def process_url_embed(payload: dict, *, delay=0):
# if it isn't, we forward an /embed/ scope call to mediaproxy # if it isn't, we forward an /embed/ scope call to mediaproxy
# to generate an embed for us out of the url. # to generate an embed for us out of the url.
new_embeds = [] new_embeds: List[dict] = []
for url in urls: for upstream_url in urls:
url: List[dict] = EmbedURL(url) url = EmbedURL(upstream_url)
if is_media_url(url): if is_media_url(url):
embeds = [await fetch_mediaproxy_img_meta(url)] embed = await fetch_mediaproxy_img_meta(url)
if embed is not None:
embeds = [embed]
else: else:
embeds = await fetch_mediaproxy_embed(url) embeds = await fetch_mediaproxy_embed(url)

View File

@ -98,6 +98,9 @@ class Icon:
return get_ext(self.mime) return get_ext(self.mime)
def __bool__(self):
return self.key and self.icon_hash and self.mime
class ImageError(Exception): class ImageError(Exception):
"""Image error class.""" """Image error class."""
@ -197,6 +200,9 @@ def _gen_update_sql(scope: str) -> str:
def _invalid(kwargs: dict) -> Optional[Icon]: def _invalid(kwargs: dict) -> Optional[Icon]:
"""Send an invalid value.""" """Send an invalid value."""
# TODO: remove optinality off this (really badly designed):
# - also remove kwargs off this function
# - also make an Icon.empty() constructor, and remove need for this entirely
if not kwargs.get("always_icon", False): if not kwargs.get("always_icon", False):
return None return None
@ -519,6 +525,7 @@ class IconManager:
key = str(key) key = str(key)
old_icon = await self.generic_get(scope, key, old_icon_hash) old_icon = await self.generic_get(scope, key, old_icon_hash)
if old_icon:
await self.delete(old_icon) await self.delete(old_icon)
return await self.put(scope, key, new_icon_data, **kwargs) return await self.put(scope, key, new_icon_data, **kwargs)

View File

@ -88,6 +88,7 @@ class Permissions(ctypes.Union):
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
EMPTY_PERMISSIONS = Permissions(0)
async def get_role_perms(guild_id, role_id, storage=None) -> Permissions: async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:

View File

@ -89,7 +89,7 @@ class ChannelDispatcher(DispatcherWithFlags):
): ):
continue continue
cur_sess = [] cur_sess: List[str] = []
if ( if (
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE") event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")

View File

@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithState from .dispatcher import DispatcherWithState
@ -38,7 +39,7 @@ class FriendDispatcher(DispatcherWithState):
async def dispatch_filter(self, user_id: int, func, event, data): async def dispatch_filter(self, user_id: int, func, event, data):
"""Dispatch an event to all of a users' friends.""" """Dispatch an event to all of a users' friends."""
peer_ids = self.state[user_id] peer_ids = self.state[user_id]
sessions = [] sessions: List[str] = []
for peer_id in peer_ids: for peer_id in peer_ids:
# dispatch to the user instead of the "shards tied to a guild" # dispatch to the user instead of the "shards tied to a guild"

View File

@ -39,6 +39,7 @@ from litecord.permissions import (
overwrite_find_mix, overwrite_find_mix,
get_permissions, get_permissions,
role_permissions, role_permissions,
EMPTY_PERMISSIONS,
) )
from litecord.utils import index_by_func from litecord.utils import index_by_func
from litecord.utils import mmh3 from litecord.utils import mmh3
@ -264,8 +265,8 @@ class GuildMemberList:
self.list = MemberList() self.list = MemberList()
#: store the states that are subscribed to the list. #: store the states that are subscribed to the list.
# type is {session_id: set[list]} # type is {session_id: set[tuple]}
self.state: Dict[str, Set[List[int, int]]] = defaultdict(set) self.state: Dict[str, Set[Tuple[int, int]]] = defaultdict(set)
self._list_lock = asyncio.Lock() self._list_lock = asyncio.Lock()
@ -414,8 +415,8 @@ class GuildMemberList:
# inject default groups 'online' and 'offline' # inject default groups 'online' and 'offline'
# their position is always going to be the last ones. # their position is always going to be the last ones.
self.list.groups = role_groups + [ self.list.groups = role_groups + [
GroupInfo("online", "online", MAX_ROLES + 1, 0), GroupInfo("online", "online", MAX_ROLES + 1, EMPTY_PERMISSIONS),
GroupInfo("offline", "offline", MAX_ROLES + 2, 0), GroupInfo("offline", "offline", MAX_ROLES + 2, EMPTY_PERMISSIONS),
] ]
async def _get_group_for_member( async def _get_group_for_member(
@ -808,6 +809,8 @@ class GuildMemberList:
ops = [] ops = []
old_user_index = self._get_item_index(user_id) old_user_index = self._get_item_index(user_id)
assert old_user_index is not None
old_group_index = self._get_group_item_index(old_group) old_group_index = self._get_group_item_index(old_group)
ops.append(Operation("DELETE", {"index": old_user_index})) ops.append(Operation("DELETE", {"index": old_user_index}))
@ -819,6 +822,7 @@ class GuildMemberList:
await self._sort_groups() await self._sort_groups()
new_user_index = self._get_item_index(user_id) new_user_index = self._get_item_index(user_id)
assert new_user_index is not None
ops.append( ops.append(
Operation( Operation(
@ -931,6 +935,7 @@ class GuildMemberList:
# if unknown state, remove from the subscriber list # if unknown state, remove from the subscriber list
if state is None: if state is None:
self.state.pop(session_id) self.state.pop(session_id)
continue
# if we aren't talking about the state the user # if we aren't talking about the state the user
# being removed is subscribed to, ignore # being removed is subscribed to, ignore
@ -1365,8 +1370,8 @@ class GuildMemberList:
len(self.state), len(self.state),
) )
self.guild_id = None self.guild_id = 0
self.channel_id = None self.channel_id = 0
self.main = None self.main = None
self._set_empty_list() self._set_empty_list()
self.state = {} self.state = {}
@ -1392,7 +1397,7 @@ class LazyGuildDispatcher(Dispatcher):
#: store which guilds have their #: store which guilds have their
# respective GMLs # respective GMLs
# {guild_id: [chan_id, ...], ...} # {guild_id: [chan_id, ...], ...}
self.guild_map = defaultdict(list) self.guild_map: Dict[int, List[int]] = defaultdict(list)
async def get_gml(self, channel_id: int): async def get_gml(self, channel_id: int):
"""Get a guild list for a channel ID, """Get a guild list for a channel ID,
@ -1414,7 +1419,18 @@ class LazyGuildDispatcher(Dispatcher):
def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]: def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]:
"""Get all member lists for a given guild.""" """Get all member lists for a given guild."""
return list(map(self.state.get, self.guild_map[guild_id])) res: List[GuildMemberList] = []
channel_ids: List[int] = self.guild_map[guild_id]
for channel_id in channel_ids:
guild_list: Optional[GuildMemberList] = self.state.get(channel_id)
if guild_list is None:
self.guild_map[guild_id].remove(channel_id)
continue
res.append(guild_list)
return res
async def unsub(self, chan_id, session_id): async def unsub(self, chan_id, session_id):
"""Unsubscribe a session from the list.""" """Unsubscribe a session from the list."""

View File

@ -17,13 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional, Union, TypedDict
from logbook import Logger from logbook import Logger
from litecord.enums import ChannelType from litecord.enums import ChannelType
from litecord.schemas import USER_MENTION, ROLE_MENTION from litecord.schemas import USER_MENTION, ROLE_MENTION
from litecord.blueprints.channel.reactions import EmojiType, emoji_sql, partial_emoji from litecord.blueprints.channel.reactions import (
EmojiType,
emoji_sql,
partial_emoji,
PartialEmoji,
)
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
@ -64,6 +69,12 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: str):
return list(filter(lambda recipient: recipient["id"] != user_id, recipients)) return list(filter(lambda recipient: recipient["id"] != user_id, recipients))
class EmojiStats(TypedDict):
count: int
me: bool
emoji: PartialEmoji
class Storage: class Storage:
"""Class for common SQL statements.""" """Class for common SQL statements."""
@ -373,7 +384,7 @@ class Storage:
members = await self.get_member_multi(guild_id, mids) members = await self.get_member_multi(guild_id, mids)
return members return members
async def chan_last_message(self, channel_id: int): async def chan_last_message(self, channel_id: int) -> Optional[int]:
"""Get the last message ID in a channel.""" """Get the last message ID in a channel."""
return await self.db.fetchval( return await self.db.fetchval(
""" """
@ -491,7 +502,7 @@ class Storage:
return [r["member_id"] for r in user_ids] return [r["member_id"] for r in user_ids]
async def _gdm_recipients( async def _gdm_recipients(
self, channel_id: int, reference_id: int = None self, channel_id: int, reference_id: Optional[int] = None
) -> List[Dict]: ) -> List[Dict]:
"""Get the list of users that are recipients of the """Get the list of users that are recipients of the
given Group DM.""" given Group DM."""
@ -576,11 +587,12 @@ class Storage:
drow = dict(gdm_row) drow = dict(gdm_row)
drow["type"] = chan_type drow["type"] = chan_type
drow["recipients"] = await self._gdm_recipients(
channel_id, kwargs.get("user_id")
)
drow["last_message_id"] = await self.chan_last_message_str(channel_id)
user_id: Optional[int] = kwargs.get("user_id")
assert user_id is not None
drow["recipients"] = await self._gdm_recipients(channel_id, user_id)
drow["last_message_id"] = await self.chan_last_message_str(channel_id)
return drow return drow
return None return None
@ -634,7 +646,7 @@ class Storage:
return channels return channels
async def get_role( async def get_role(
self, role_id: int, guild_id: int = None self, role_id: int, guild_id: Optional[int] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""get a single role's information.""" """get a single role's information."""
@ -732,6 +744,8 @@ class Storage:
mids = [int(m["user"]["id"]) for m in members] mids = [int(m["user"]["id"]) for m in members]
assert self.presence is not None
return { return {
**res, **res,
**{ **{
@ -822,10 +836,10 @@ class Storage:
) )
# ordered list of emoji # ordered list of emoji
emoji = [] emoji: List[Union[int, str]] = []
# the current state of emoji info # the current state of emoji info
react_stats = {} react_stats: Dict[Union[str, int], EmojiStats] = {}
# to generate the list, we pass through all # to generate the list, we pass through all
# all reactions and insert them all. # all reactions and insert them all.
@ -1007,6 +1021,7 @@ class Storage:
# calculate user mentions and role mentions by regex # calculate user mentions and role mentions by regex
async def _get_member(user_id): async def _get_member(user_id):
user = await self.get_user(user_id) user = await self.get_user(user_id)
assert user is not None
member = None member = None
if guild_id: if guild_id:
@ -1143,6 +1158,7 @@ class Storage:
return {} return {}
mids = await self.get_member_ids(guild_id) mids = await self.get_member_ids(guild_id)
assert self.presence is not None
pres = await self.presence.guild_presences(mids, guild_id) pres = await self.presence.guild_presences(mids, guild_id)
online_count = sum(1 for p in pres if p["status"] == "online") online_count = sum(1 for p in pres if p["status"] == "online")
@ -1172,7 +1188,7 @@ class Storage:
return dinv return dinv
async def get_dm(self, dm_id: int, user_id: int = None) -> Optional[Dict]: async def get_dm(self, dm_id: int, user_id: Optional[int] = None) -> Optional[Dict]:
"""Get a DM channel.""" """Get a DM channel."""
dm_chan = await self.get_channel(dm_id) dm_chan = await self.get_channel(dm_id)

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from logbook import Logger from logbook import Logger
from quart import current_app as app
from winter import get_snowflake from winter import get_snowflake
from litecord.enums import MessageType from litecord.enums import MessageType
@ -25,7 +26,7 @@ from litecord.enums import MessageType
log = Logger(__name__) log = Logger(__name__)
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): async def _handle_pin_msg(channel_id, _pinned_id, author_id):
"""Handle a message pin.""" """Handle a message pin."""
new_id = get_snowflake() new_id = get_snowflake()
@ -48,7 +49,7 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
# TODO: decrease repetition between add and remove handlers # TODO: decrease repetition between add and remove handlers
async def _handle_recp_add(app, channel_id, author_id, peer_id): async def _handle_recp_add(channel_id, author_id, peer_id):
new_id = get_snowflake() new_id = get_snowflake()
await app.db.execute( await app.db.execute(
@ -69,7 +70,7 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
return new_id return new_id
async def _handle_recp_rmv(app, channel_id, author_id, peer_id): async def _handle_recp_rmv(channel_id, author_id, peer_id):
new_id = get_snowflake() new_id = get_snowflake()
await app.db.execute( await app.db.execute(
@ -90,7 +91,7 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
return new_id return new_id
async def _handle_gdm_name_edit(app, channel_id, author_id): async def _handle_gdm_name_edit(channel_id, author_id):
new_id = get_snowflake() new_id = get_snowflake()
gdm_name = await app.db.fetchval( gdm_name = await app.db.fetchval(
@ -123,7 +124,7 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
return new_id return new_id
async def _handle_gdm_icon_edit(app, channel_id, author_id): async def _handle_gdm_icon_edit(channel_id, author_id):
new_id = get_snowflake() new_id = get_snowflake()
await app.db.execute( await app.db.execute(
@ -145,7 +146,7 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
async def send_sys_message( async def send_sys_message(
app, channel_id: int, m_type: MessageType, *args, **kwargs channel_id: int, m_type: MessageType, *args, **kwargs
) -> int: ) -> int:
"""Send a system message. """Send a system message.
@ -179,7 +180,7 @@ async def send_sys_message(
except KeyError: except KeyError:
raise ValueError("Invalid system message type") raise ValueError("Invalid system message type")
message_id = await handler(app, channel_id, *args, **kwargs) message_id = await handler(channel_id, *args, **kwargs)
message = await app.storage.get_message(message_id) message = await app.storage.get_message(message_id)

11
mypy.ini Normal file
View File

@ -0,0 +1,11 @@
[mypy]
check_untyped_defs = True
no_implicit_optional = True
[mypy-logbook]
ignore_missing_imports = True
[mypy-quart]
ignore_missing_imports = True
[mypy-winter]
ignore_missing_imports = True
[mypy-asyncpg]
ignore_missing_imports = True

View File

@ -1,5 +1,5 @@
[tox] [tox]
envlist = py3.7 envlist = py3.8
[testenv] [testenv]
deps = -rrequirements.txt deps = -rrequirements.txt