mirror of https://gitlab.com/litecord/litecord.git
parent
020e03bf6d
commit
e0e59f8b63
|
|
@ -1,4 +1,4 @@
|
|||
image: python:3.7-alpine
|
||||
image: python:3.8-alpine
|
||||
|
||||
variables:
|
||||
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
|
||||
|
|
|
|||
2
Pipfile
2
Pipfile
|
|
@ -26,7 +26,7 @@ flake8 = "*"
|
|||
pyflakes = "*"
|
||||
|
||||
[requires]
|
||||
python_version = "3.7"
|
||||
python_version = "3.8"
|
||||
|
||||
[pipenv]
|
||||
allow_prereleases = true
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "bdd693f88f93f57b4a6eca42e0c8f053e500ddc40e344d888af86143c1904751"
|
||||
"sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
"python_version": "3.7"
|
||||
"python_version": "3.8"
|
||||
},
|
||||
"sources": [
|
||||
{
|
||||
|
|
@ -238,10 +238,10 @@
|
|||
},
|
||||
"jinja2": {
|
||||
"hashes": [
|
||||
"sha256:93187ffbc7808079673ef52771baa950426fd664d3aad1d0fa3e95644360e250",
|
||||
"sha256:b0eaf100007721b5c16c1fc1eecb87409464edc10469ddc9a22a27a99123be49"
|
||||
"sha256:c10142f819c2d22bdcd17548c46fa9b77cf4fda45097854c689666bf425e7484",
|
||||
"sha256:c922560ac46888d47384de1dbdc3daaa2ea993af4b26a436dec31fa2c19ec668"
|
||||
],
|
||||
"version": "==2.11.1"
|
||||
"version": "==3.0.0a1"
|
||||
},
|
||||
"logbook": {
|
||||
"hashes": [
|
||||
|
|
@ -554,14 +554,6 @@
|
|||
"index": "pypi",
|
||||
"version": "==3.7.9"
|
||||
},
|
||||
"importlib-metadata": {
|
||||
"hashes": [
|
||||
"sha256:06f5b3a99029c7134207dd882428a66992a9de2bef7c2b699b5641f9886c3302",
|
||||
"sha256:b97607a1a18a5100839aec1dc26a1ea17ee0d93b20b0f008d80a5a050afb200b"
|
||||
],
|
||||
"markers": "python_version < '3.8'",
|
||||
"version": "==1.5.0"
|
||||
},
|
||||
"mccabe": {
|
||||
"hashes": [
|
||||
"sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42",
|
||||
|
|
@ -749,13 +741,6 @@
|
|||
"sha256:f28b3e8a6483e5d49e7f8949ac1a78314e740333ae305b4ba5defd3e74fb37a8"
|
||||
],
|
||||
"version": "==0.1.8"
|
||||
},
|
||||
"zipp": {
|
||||
"hashes": [
|
||||
"sha256:ccc94ed0909b58ffe34430ea5451f07bc0c76467d7081619a454bf5c98b89e28",
|
||||
"sha256:feae2f18633c32fc71f2de629bfb3bd3c9325cd4419642b1f1da42ee488d9b98"
|
||||
],
|
||||
"version": "==2.1.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ or third party libraries (such as [Eris](https://github.com/abalabahaha/eris)).
|
|||
|
||||
Requirements:
|
||||
|
||||
- **Python 3.7+**
|
||||
- **Python 3.8+**
|
||||
- PostgreSQL (tested using 9.6+), SQL knowledge is recommended.
|
||||
- gifsicle for GIF emoji and avatar handling
|
||||
- [pipenv]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import string
|
||||
from random import choice
|
||||
from typing import Optional
|
||||
|
||||
from quart import Blueprint, jsonify, current_app as app, request
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ async def _gen_inv() -> str:
|
|||
return "".join(choice(ALPHABET) for _ in range(6))
|
||||
|
||||
|
||||
async def gen_inv(ctx) -> str:
|
||||
async def gen_inv(ctx) -> Optional[str]:
|
||||
"""Generate an invite."""
|
||||
for _ in range(10):
|
||||
possible_inv = await _gen_inv()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
|
@ -146,7 +147,7 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
|||
|
||||
|
||||
async def create_message(
|
||||
channel_id: int, actual_guild_id: int, author_id: int, data: dict
|
||||
channel_id: int, actual_guild_id: Optional[int], author_id: int, data: dict
|
||||
) -> int:
|
||||
message_id = get_snowflake()
|
||||
|
||||
|
|
@ -159,7 +160,7 @@ async def create_message(
|
|||
content, tts, mention_everyone, nonce, message_type,
|
||||
embeds)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
""",
|
||||
""",
|
||||
message_id,
|
||||
channel_id,
|
||||
actual_guild_id,
|
||||
|
|
@ -186,7 +187,7 @@ async def _create_message(channel_id):
|
|||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
actual_guild_id = None
|
||||
actual_guild_id: Optional[int] = None
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
await channel_perm_check(user_id, channel_id, "send_messages")
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ async def add_pin(channel_id, message_id):
|
|||
)
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
|
||||
channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import List, Union, Tuple
|
||||
from typing import List, Union, Tuple, TypedDict, Optional
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
|
@ -60,7 +60,12 @@ def emoji_info_from_str(emoji: str) -> tuple:
|
|||
return emoji_type, emoji_id, emoji_name
|
||||
|
||||
|
||||
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
|
||||
class PartialEmoji(TypedDict):
|
||||
id: Optional[int]
|
||||
name: str
|
||||
|
||||
|
||||
def partial_emoji(emoji_type, emoji_id, emoji_name) -> PartialEmoji:
|
||||
print(emoji_type, emoji_id, emoji_name)
|
||||
return {
|
||||
"id": None if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||
|
|
|
|||
|
|
@ -386,14 +386,14 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
|||
return "", 204
|
||||
|
||||
|
||||
async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||
async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
|
||||
if "name" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guild_channels
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""",
|
||||
UPDATE guild_channels
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""",
|
||||
j["name"],
|
||||
channel_id,
|
||||
)
|
||||
|
|
@ -401,7 +401,9 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
|||
if "position" in j:
|
||||
channel_data = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
chans = [None] * len(channel_data)
|
||||
# get an ordered list of the chans array by position
|
||||
# TODO bad impl. can break easily. maybe dict?
|
||||
chans: List[Optional[int]] = [None] * len(channel_data)
|
||||
for chandata in channel_data:
|
||||
chans.insert(chandata["position"], int(chandata["id"]))
|
||||
|
||||
|
|
@ -422,7 +424,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
|||
left_shift = new_pos > current_pos
|
||||
|
||||
# find all channels that we'll have to shift
|
||||
shift_block = (
|
||||
shift_block: List[Optional[int]] = (
|
||||
chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
|
||||
)
|
||||
|
||||
|
|
@ -509,9 +511,7 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int):
|
|||
channel_id,
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
|
||||
)
|
||||
await send_sys_message(channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id)
|
||||
|
||||
if "icon" in j:
|
||||
new_icon = await app.icons.update(
|
||||
|
|
@ -528,13 +528,11 @@ async def _update_group_dm(channel_id: int, j: dict, author_id: int):
|
|||
channel_id,
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
|
||||
)
|
||||
await send_sys_message(channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id)
|
||||
|
||||
|
||||
@bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
|
||||
async def update_channel(channel_id):
|
||||
async def update_channel(channel_id: int):
|
||||
"""Update a channel's information"""
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Union, List
|
||||
from typing import Union, List, Optional
|
||||
|
||||
from quart import current_app as app
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ async def guild_owner_check(user_id: int, guild_id: int):
|
|||
|
||||
|
||||
async def channel_check(
|
||||
user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
|
||||
user_id, channel_id, *, only: Optional[Union[ChannelType, List[ChannelType]]] = None
|
||||
):
|
||||
"""Check if the current user is authorized
|
||||
to read the channel's information."""
|
||||
|
|
@ -77,10 +77,10 @@ async def channel_check(
|
|||
|
||||
ctype = ChannelType(chan_type)
|
||||
|
||||
if only and not isinstance(only, list):
|
||||
if (only is not None) and not isinstance(only, list):
|
||||
only = [only]
|
||||
|
||||
if only and ctype not in only:
|
||||
if (only is not None) and ctype not in only:
|
||||
raise ChannelNotFound("invalid channel type")
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
|
|
|
|||
|
|
@ -113,9 +113,7 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
await app.dispatcher.sub("channel", peer_id)
|
||||
|
||||
if user_id:
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
|
||||
)
|
||||
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
|
||||
|
||||
|
||||
async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||
|
|
@ -145,9 +143,7 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
|
||||
author_id = peer_id if user_id is None else user_id
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
|
||||
)
|
||||
await send_sys_message(channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id)
|
||||
|
||||
|
||||
async def gdm_destroy(channel_id):
|
||||
|
|
|
|||
|
|
@ -130,14 +130,11 @@ async def modify_channel_pos(guild_id):
|
|||
# the same schema and all.
|
||||
raw_j = await request.get_json()
|
||||
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
|
||||
j = j["roles"]
|
||||
roles = j["roles"]
|
||||
|
||||
channels = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
|
||||
|
||||
swap_pairs = gen_pairs(j, channel_positions)
|
||||
|
||||
swap_pairs = gen_pairs(roles, channel_positions)
|
||||
await _do_channel_swaps(guild_id, swap_pairs)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
|
@ -122,7 +122,7 @@ PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
|
|||
def gen_pairs(
|
||||
list_of_changes: List[Dict[str, int]],
|
||||
current_state: Dict[int, int],
|
||||
blacklist: List[int] = None,
|
||||
blacklist: Optional[List[int]] = None,
|
||||
) -> PairList:
|
||||
"""Generate a list of pairs that, when applied to the database,
|
||||
will generate the desired state given in list_of_changes.
|
||||
|
|
@ -162,7 +162,7 @@ def gen_pairs(
|
|||
List of swaps to do to achieve the preferred
|
||||
state given by ``list_of_changes``.
|
||||
"""
|
||||
pairs = []
|
||||
pairs: PairList = []
|
||||
blacklist = blacklist or []
|
||||
|
||||
preferred_state = {
|
||||
|
|
@ -222,9 +222,9 @@ async def update_guild_role_positions(guild_id):
|
|||
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
|
||||
|
||||
# extract the list out
|
||||
j = j["roles"]
|
||||
roles = j["roles"]
|
||||
|
||||
log.debug("role stuff: {!r}", j)
|
||||
log.debug("role stuff: {!r}", roles)
|
||||
|
||||
all_roles = await app.storage.get_role_data(guild_id)
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ async def update_guild_role_positions(guild_id):
|
|||
# NOTE: ^ this is related to the positioning of the roles.
|
||||
|
||||
pairs = gen_pairs(
|
||||
j,
|
||||
roles,
|
||||
roles_pos,
|
||||
# always ignore people trying to change
|
||||
# the @everyone's role position
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
|
|||
return f"data:image/jpeg;base64,{icon}" if icon else None
|
||||
|
||||
|
||||
async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
|
||||
async def _general_guild_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
|
||||
encoded = sanitize_icon(icon)
|
||||
|
||||
icon_kwargs = {"always_icon": True}
|
||||
|
|
|
|||
|
|
@ -222,14 +222,14 @@ async def get_guild_webhook(guild_id):
|
|||
async def get_single_webhook(webhook_id):
|
||||
"""Get a single webhook's information."""
|
||||
await _webhook_check_fw(webhook_id)
|
||||
return await jsonify(await get_webhook(webhook_id))
|
||||
return jsonify(await get_webhook(webhook_id))
|
||||
|
||||
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
|
||||
async def get_tokened_webhook(webhook_id, webhook_token):
|
||||
"""Get a webhook using its token."""
|
||||
await webhook_token_check(webhook_id, webhook_token)
|
||||
return await jsonify(await get_webhook(webhook_id, secure=False))
|
||||
return jsonify(await get_webhook(webhook_id, secure=False))
|
||||
|
||||
|
||||
async def _update_webhook(webhook_id: int, j: dict):
|
||||
|
|
@ -289,6 +289,7 @@ async def modify_webhook(webhook_id: int):
|
|||
await _update_webhook(webhook_id, j)
|
||||
|
||||
webhook = await get_webhook(webhook_id)
|
||||
assert webhook is not None
|
||||
|
||||
# we don't need to cast channel_id to int since that isn't
|
||||
# used in the dispatcher call
|
||||
|
|
@ -313,7 +314,9 @@ async def modify_webhook_tokened(webhook_id, webhook_token):
|
|||
async def delete_webhook(webhook_id: int):
|
||||
"""Delete a webhook."""
|
||||
webhook = await get_webhook(webhook_id)
|
||||
assert webhook is not None
|
||||
|
||||
# TODO use returning?
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
DELETE FROM webhooks
|
||||
|
|
@ -423,7 +426,10 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
|
|||
# we still fetch the URL to check its validity, mimetypes, etc
|
||||
# but in the end, we will store it under the webhook_avatars table,
|
||||
# not IconManager.
|
||||
resp, raw = await fetch_mediaproxy_img(avatar_url)
|
||||
res = await fetch_mediaproxy_img(avatar_url)
|
||||
if res is None:
|
||||
raise BadRequest("Failed to fetch URL.")
|
||||
resp, raw = res
|
||||
# raw_b64 = base64.b64encode(raw).decode()
|
||||
|
||||
mime = resp.headers["content-type"]
|
||||
|
|
@ -469,6 +475,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
|||
given_embeds = j.get("embeds", [])
|
||||
|
||||
webhook = await get_webhook(webhook_id)
|
||||
assert webhook is not None
|
||||
|
||||
# webhooks have TWO avatars. one is from settings, the other is from
|
||||
# the json's icon_url. one can be handled gracefully by IconManager,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ You should have received a copy of the GNU General Public License
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
from typing import Optional
|
||||
from quart import current_app as app
|
||||
|
||||
|
||||
|
|
@ -24,12 +25,15 @@ from litecord.enums import RelationshipType
|
|||
|
||||
|
||||
async def channel_ack(
|
||||
user_id: int, guild_id: int, channel_id: int, message_id: int = None
|
||||
user_id: int, guild_id: int, channel_id: int, message_id: Optional[int] = None
|
||||
):
|
||||
"""ACK a channel."""
|
||||
|
||||
if not message_id:
|
||||
message_id = await app.storage.chan_last_message(channel_id)
|
||||
message_id = message_id or await app.storage.chan_last_message(channel_id)
|
||||
|
||||
# never ack without a message, as that breaks read state.
|
||||
if message_id is None:
|
||||
return
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from random import randint
|
||||
from typing import Tuple, Optional
|
||||
from typing import Tuple, Optional, List
|
||||
|
||||
from quart import current_app as app
|
||||
from asyncpg import UniqueViolationError
|
||||
|
|
@ -33,13 +33,13 @@ from ..utils import rand_hex
|
|||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def mass_user_update(user_id):
|
||||
async def mass_user_update(user_id: int):
|
||||
"""Dispatch USER_UPDATE in a mass way."""
|
||||
# by using dispatch_with_filter
|
||||
# we're guaranteeing all shards will get
|
||||
# a USER_UPDATE once and not any others.
|
||||
|
||||
session_ids = []
|
||||
session_ids: List[str] = []
|
||||
|
||||
public_user = await app.storage.get_user(user_id)
|
||||
private_user = await app.storage.get_user(user_id, secure=True)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import re
|
|||
import asyncio
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
|
@ -35,16 +35,16 @@ log = Logger(__name__)
|
|||
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
|
||||
|
||||
|
||||
async def fetch_mediaproxy_img_meta(url) -> dict:
|
||||
async def fetch_mediaproxy_img_meta(url) -> Optional[dict]:
|
||||
"""Insert media metadata as an embed."""
|
||||
img_proxy_url = proxify(url)
|
||||
meta = await fetch_metadata(url)
|
||||
|
||||
if meta is None:
|
||||
return
|
||||
return None
|
||||
|
||||
if not meta["image"]:
|
||||
return
|
||||
return None
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
|
|
@ -135,13 +135,15 @@ async def process_url_embed(payload: dict, *, delay=0):
|
|||
# if it isn't, we forward an /embed/ scope call to mediaproxy
|
||||
# to generate an embed for us out of the url.
|
||||
|
||||
new_embeds = []
|
||||
new_embeds: List[dict] = []
|
||||
|
||||
for url in urls:
|
||||
url: List[dict] = EmbedURL(url)
|
||||
for upstream_url in urls:
|
||||
url = EmbedURL(upstream_url)
|
||||
|
||||
if is_media_url(url):
|
||||
embeds = [await fetch_mediaproxy_img_meta(url)]
|
||||
embed = await fetch_mediaproxy_img_meta(url)
|
||||
if embed is not None:
|
||||
embeds = [embed]
|
||||
else:
|
||||
embeds = await fetch_mediaproxy_embed(url)
|
||||
|
||||
|
|
|
|||
|
|
@ -98,6 +98,9 @@ class Icon:
|
|||
|
||||
return get_ext(self.mime)
|
||||
|
||||
def __bool__(self):
|
||||
return self.key and self.icon_hash and self.mime
|
||||
|
||||
|
||||
class ImageError(Exception):
|
||||
"""Image error class."""
|
||||
|
|
@ -197,6 +200,9 @@ def _gen_update_sql(scope: str) -> str:
|
|||
|
||||
def _invalid(kwargs: dict) -> Optional[Icon]:
|
||||
"""Send an invalid value."""
|
||||
# TODO: remove optinality off this (really badly designed):
|
||||
# - also remove kwargs off this function
|
||||
# - also make an Icon.empty() constructor, and remove need for this entirely
|
||||
if not kwargs.get("always_icon", False):
|
||||
return None
|
||||
|
||||
|
|
@ -519,6 +525,7 @@ class IconManager:
|
|||
key = str(key)
|
||||
|
||||
old_icon = await self.generic_get(scope, key, old_icon_hash)
|
||||
await self.delete(old_icon)
|
||||
if old_icon:
|
||||
await self.delete(old_icon)
|
||||
|
||||
return await self.put(scope, key, new_icon_data, **kwargs)
|
||||
|
|
|
|||
|
|
@ -88,6 +88,7 @@ class Permissions(ctypes.Union):
|
|||
|
||||
|
||||
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
|
||||
EMPTY_PERMISSIONS = Permissions(0)
|
||||
|
||||
|
||||
async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class ChannelDispatcher(DispatcherWithFlags):
|
|||
):
|
||||
continue
|
||||
|
||||
cur_sess = []
|
||||
cur_sess: List[str] = []
|
||||
|
||||
if (
|
||||
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithState
|
||||
|
|
@ -38,7 +39,7 @@ class FriendDispatcher(DispatcherWithState):
|
|||
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||
"""Dispatch an event to all of a users' friends."""
|
||||
peer_ids = self.state[user_id]
|
||||
sessions = []
|
||||
sessions: List[str] = []
|
||||
|
||||
for peer_id in peer_ids:
|
||||
# dispatch to the user instead of the "shards tied to a guild"
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from litecord.permissions import (
|
|||
overwrite_find_mix,
|
||||
get_permissions,
|
||||
role_permissions,
|
||||
EMPTY_PERMISSIONS,
|
||||
)
|
||||
from litecord.utils import index_by_func
|
||||
from litecord.utils import mmh3
|
||||
|
|
@ -264,8 +265,8 @@ class GuildMemberList:
|
|||
self.list = MemberList()
|
||||
|
||||
#: store the states that are subscribed to the list.
|
||||
# type is {session_id: set[list]}
|
||||
self.state: Dict[str, Set[List[int, int]]] = defaultdict(set)
|
||||
# type is {session_id: set[tuple]}
|
||||
self.state: Dict[str, Set[Tuple[int, int]]] = defaultdict(set)
|
||||
|
||||
self._list_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -414,8 +415,8 @@ class GuildMemberList:
|
|||
# inject default groups 'online' and 'offline'
|
||||
# their position is always going to be the last ones.
|
||||
self.list.groups = role_groups + [
|
||||
GroupInfo("online", "online", MAX_ROLES + 1, 0),
|
||||
GroupInfo("offline", "offline", MAX_ROLES + 2, 0),
|
||||
GroupInfo("online", "online", MAX_ROLES + 1, EMPTY_PERMISSIONS),
|
||||
GroupInfo("offline", "offline", MAX_ROLES + 2, EMPTY_PERMISSIONS),
|
||||
]
|
||||
|
||||
async def _get_group_for_member(
|
||||
|
|
@ -808,6 +809,8 @@ class GuildMemberList:
|
|||
ops = []
|
||||
|
||||
old_user_index = self._get_item_index(user_id)
|
||||
assert old_user_index is not None
|
||||
|
||||
old_group_index = self._get_group_item_index(old_group)
|
||||
|
||||
ops.append(Operation("DELETE", {"index": old_user_index}))
|
||||
|
|
@ -819,6 +822,7 @@ class GuildMemberList:
|
|||
await self._sort_groups()
|
||||
|
||||
new_user_index = self._get_item_index(user_id)
|
||||
assert new_user_index is not None
|
||||
|
||||
ops.append(
|
||||
Operation(
|
||||
|
|
@ -931,6 +935,7 @@ class GuildMemberList:
|
|||
# if unknown state, remove from the subscriber list
|
||||
if state is None:
|
||||
self.state.pop(session_id)
|
||||
continue
|
||||
|
||||
# if we aren't talking about the state the user
|
||||
# being removed is subscribed to, ignore
|
||||
|
|
@ -1365,8 +1370,8 @@ class GuildMemberList:
|
|||
len(self.state),
|
||||
)
|
||||
|
||||
self.guild_id = None
|
||||
self.channel_id = None
|
||||
self.guild_id = 0
|
||||
self.channel_id = 0
|
||||
self.main = None
|
||||
self._set_empty_list()
|
||||
self.state = {}
|
||||
|
|
@ -1392,7 +1397,7 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
#: store which guilds have their
|
||||
# respective GMLs
|
||||
# {guild_id: [chan_id, ...], ...}
|
||||
self.guild_map = defaultdict(list)
|
||||
self.guild_map: Dict[int, List[int]] = defaultdict(list)
|
||||
|
||||
async def get_gml(self, channel_id: int):
|
||||
"""Get a guild list for a channel ID,
|
||||
|
|
@ -1414,7 +1419,18 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
|
||||
def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]:
|
||||
"""Get all member lists for a given guild."""
|
||||
return list(map(self.state.get, self.guild_map[guild_id]))
|
||||
res: List[GuildMemberList] = []
|
||||
|
||||
channel_ids: List[int] = self.guild_map[guild_id]
|
||||
for channel_id in channel_ids:
|
||||
guild_list: Optional[GuildMemberList] = self.state.get(channel_id)
|
||||
if guild_list is None:
|
||||
self.guild_map[guild_id].remove(channel_id)
|
||||
continue
|
||||
|
||||
res.append(guild_list)
|
||||
|
||||
return res
|
||||
|
||||
async def unsub(self, chan_id, session_id):
|
||||
"""Unsubscribe a session from the list."""
|
||||
|
|
|
|||
|
|
@ -17,13 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Dict, Any, Optional, Union, TypedDict
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.schemas import USER_MENTION, ROLE_MENTION
|
||||
from litecord.blueprints.channel.reactions import EmojiType, emoji_sql, partial_emoji
|
||||
from litecord.blueprints.channel.reactions import (
|
||||
EmojiType,
|
||||
emoji_sql,
|
||||
partial_emoji,
|
||||
PartialEmoji,
|
||||
)
|
||||
|
||||
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
|
||||
|
||||
|
|
@ -64,6 +69,12 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: str):
|
|||
return list(filter(lambda recipient: recipient["id"] != user_id, recipients))
|
||||
|
||||
|
||||
class EmojiStats(TypedDict):
|
||||
count: int
|
||||
me: bool
|
||||
emoji: PartialEmoji
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Class for common SQL statements."""
|
||||
|
||||
|
|
@ -373,7 +384,7 @@ class Storage:
|
|||
members = await self.get_member_multi(guild_id, mids)
|
||||
return members
|
||||
|
||||
async def chan_last_message(self, channel_id: int):
|
||||
async def chan_last_message(self, channel_id: int) -> Optional[int]:
|
||||
"""Get the last message ID in a channel."""
|
||||
return await self.db.fetchval(
|
||||
"""
|
||||
|
|
@ -491,7 +502,7 @@ class Storage:
|
|||
return [r["member_id"] for r in user_ids]
|
||||
|
||||
async def _gdm_recipients(
|
||||
self, channel_id: int, reference_id: int = None
|
||||
self, channel_id: int, reference_id: Optional[int] = None
|
||||
) -> List[Dict]:
|
||||
"""Get the list of users that are recipients of the
|
||||
given Group DM."""
|
||||
|
|
@ -576,11 +587,12 @@ class Storage:
|
|||
|
||||
drow = dict(gdm_row)
|
||||
drow["type"] = chan_type
|
||||
drow["recipients"] = await self._gdm_recipients(
|
||||
channel_id, kwargs.get("user_id")
|
||||
)
|
||||
drow["last_message_id"] = await self.chan_last_message_str(channel_id)
|
||||
|
||||
user_id: Optional[int] = kwargs.get("user_id")
|
||||
assert user_id is not None
|
||||
drow["recipients"] = await self._gdm_recipients(channel_id, user_id)
|
||||
|
||||
drow["last_message_id"] = await self.chan_last_message_str(channel_id)
|
||||
return drow
|
||||
|
||||
return None
|
||||
|
|
@ -634,7 +646,7 @@ class Storage:
|
|||
return channels
|
||||
|
||||
async def get_role(
|
||||
self, role_id: int, guild_id: int = None
|
||||
self, role_id: int, guild_id: Optional[int] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""get a single role's information."""
|
||||
|
||||
|
|
@ -732,6 +744,8 @@ class Storage:
|
|||
|
||||
mids = [int(m["user"]["id"]) for m in members]
|
||||
|
||||
assert self.presence is not None
|
||||
|
||||
return {
|
||||
**res,
|
||||
**{
|
||||
|
|
@ -822,10 +836,10 @@ class Storage:
|
|||
)
|
||||
|
||||
# ordered list of emoji
|
||||
emoji = []
|
||||
emoji: List[Union[int, str]] = []
|
||||
|
||||
# the current state of emoji info
|
||||
react_stats = {}
|
||||
react_stats: Dict[Union[str, int], EmojiStats] = {}
|
||||
|
||||
# to generate the list, we pass through all
|
||||
# all reactions and insert them all.
|
||||
|
|
@ -1007,6 +1021,7 @@ class Storage:
|
|||
# calculate user mentions and role mentions by regex
|
||||
async def _get_member(user_id):
|
||||
user = await self.get_user(user_id)
|
||||
assert user is not None
|
||||
member = None
|
||||
|
||||
if guild_id:
|
||||
|
|
@ -1143,6 +1158,7 @@ class Storage:
|
|||
return {}
|
||||
|
||||
mids = await self.get_member_ids(guild_id)
|
||||
assert self.presence is not None
|
||||
pres = await self.presence.guild_presences(mids, guild_id)
|
||||
online_count = sum(1 for p in pres if p["status"] == "online")
|
||||
|
||||
|
|
@ -1172,7 +1188,7 @@ class Storage:
|
|||
|
||||
return dinv
|
||||
|
||||
async def get_dm(self, dm_id: int, user_id: int = None) -> Optional[Dict]:
|
||||
async def get_dm(self, dm_id: int, user_id: Optional[int] = None) -> Optional[Dict]:
|
||||
"""Get a DM channel."""
|
||||
dm_chan = await self.get_channel(dm_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from winter import get_snowflake
|
||||
from litecord.enums import MessageType
|
||||
|
|
@ -25,7 +26,7 @@ from litecord.enums import MessageType
|
|||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
||||
async def _handle_pin_msg(channel_id, _pinned_id, author_id):
|
||||
"""Handle a message pin."""
|
||||
new_id = get_snowflake()
|
||||
|
||||
|
|
@ -48,7 +49,7 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
|||
|
||||
|
||||
# TODO: decrease repetition between add and remove handlers
|
||||
async def _handle_recp_add(app, channel_id, author_id, peer_id):
|
||||
async def _handle_recp_add(channel_id, author_id, peer_id):
|
||||
new_id = get_snowflake()
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -69,7 +70,7 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
|
|||
return new_id
|
||||
|
||||
|
||||
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
||||
async def _handle_recp_rmv(channel_id, author_id, peer_id):
|
||||
new_id = get_snowflake()
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -90,7 +91,7 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
|||
return new_id
|
||||
|
||||
|
||||
async def _handle_gdm_name_edit(app, channel_id, author_id):
|
||||
async def _handle_gdm_name_edit(channel_id, author_id):
|
||||
new_id = get_snowflake()
|
||||
|
||||
gdm_name = await app.db.fetchval(
|
||||
|
|
@ -123,7 +124,7 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
|
|||
return new_id
|
||||
|
||||
|
||||
async def _handle_gdm_icon_edit(app, channel_id, author_id):
|
||||
async def _handle_gdm_icon_edit(channel_id, author_id):
|
||||
new_id = get_snowflake()
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -145,7 +146,7 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
|
|||
|
||||
|
||||
async def send_sys_message(
|
||||
app, channel_id: int, m_type: MessageType, *args, **kwargs
|
||||
channel_id: int, m_type: MessageType, *args, **kwargs
|
||||
) -> int:
|
||||
"""Send a system message.
|
||||
|
||||
|
|
@ -179,7 +180,7 @@ async def send_sys_message(
|
|||
except KeyError:
|
||||
raise ValueError("Invalid system message type")
|
||||
|
||||
message_id = await handler(app, channel_id, *args, **kwargs)
|
||||
message_id = await handler(channel_id, *args, **kwargs)
|
||||
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
[mypy]
|
||||
check_untyped_defs = True
|
||||
no_implicit_optional = True
|
||||
[mypy-logbook]
|
||||
ignore_missing_imports = True
|
||||
[mypy-quart]
|
||||
ignore_missing_imports = True
|
||||
[mypy-winter]
|
||||
ignore_missing_imports = True
|
||||
[mypy-asyncpg]
|
||||
ignore_missing_imports = True
|
||||
Loading…
Reference in New Issue