mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'lazy-guilds' into 'master'
Lazy guilds Closes #12, #9, #8, and #2 See merge request luna/litecord!2
This commit is contained in:
commit
6245f08289
19
README.md
19
README.md
|
|
@ -7,8 +7,16 @@ This project is a rewrite of [litecord-reference].
|
|||
|
||||
[litecord-reference]: https://gitlab.com/luna/litecord-reference
|
||||
|
||||
## Notes
|
||||
|
||||
- There are no testing being run on the current codebase. Which means the code is definitely unstable.
|
||||
- No voice is planned to be developed, for now.
|
||||
- You must figure out connecting to the server yourself. Litecord will not distribute
|
||||
Discord's official client code nor provide ways to modify the client.
|
||||
|
||||
## Install
|
||||
|
||||
Requirements:
|
||||
- Python 3.6 or higher
|
||||
- PostgreSQL
|
||||
- [Pipenv]
|
||||
|
|
@ -28,6 +36,10 @@ $ psql -f schema.sql litecord
|
|||
# edit config.py as you wish
|
||||
$ cp config.example.py config.py
|
||||
|
||||
# run database migrations (this is a
|
||||
# required step in setup)
|
||||
$ pipenv run ./manage.py migrate
|
||||
|
||||
# Install all packages:
|
||||
$ pipenv install --dev
|
||||
```
|
||||
|
|
@ -42,3 +54,10 @@ Use `--access-log -` to output access logs to stdout.
|
|||
```sh
|
||||
$ pipenv run hypercorn run:app
|
||||
```
|
||||
|
||||
## Updating
|
||||
|
||||
```sh
|
||||
$ git pull
|
||||
$ pipenv run ./manage.py migrate
|
||||
```
|
||||
|
|
|
|||
|
|
@ -13,7 +13,11 @@ log = Logger(__name__)
|
|||
|
||||
async def raw_token_check(token, db=None):
|
||||
db = db or app.db
|
||||
user_id, _hmac = token.split('.')
|
||||
|
||||
# just try by fragments instead of
|
||||
# unpacking
|
||||
fragments = token.split('.')
|
||||
user_id = fragments[0]
|
||||
|
||||
try:
|
||||
user_id = base64.b64decode(user_id.encode())
|
||||
|
|
@ -35,6 +39,17 @@ async def raw_token_check(token, db=None):
|
|||
try:
|
||||
signer.unsign(token)
|
||||
log.debug('login for uid {} successful', user_id)
|
||||
|
||||
# update the user's last_session field
|
||||
# so that we can keep an exact track of activity,
|
||||
# even on long-lived single sessions (that can happen
|
||||
# with people leaving their clients open forever)
|
||||
await db.execute("""
|
||||
UPDATE users
|
||||
SET last_session = (now() at time zone 'utc')
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
|
||||
return user_id
|
||||
except BadSignature:
|
||||
log.warning('token failed for uid {}', user_id)
|
||||
|
|
@ -43,6 +58,12 @@ async def raw_token_check(token, db=None):
|
|||
|
||||
async def token_check():
|
||||
"""Check token information."""
|
||||
# first, check if the request info already has a uid
|
||||
try:
|
||||
return request.user_id
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
token = request.headers['Authorization']
|
||||
except KeyError:
|
||||
|
|
@ -51,4 +72,6 @@ async def token_check():
|
|||
if token.startswith('Bot '):
|
||||
token = token.replace('Bot ', '')
|
||||
|
||||
return await raw_token_check(token)
|
||||
user_id = await raw_token_check(token)
|
||||
request.user_id = user_id
|
||||
return user_id
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ async def register():
|
|||
|
||||
new_id = get_snowflake()
|
||||
|
||||
new_discrim = str(random.randint(1, 9999))
|
||||
new_discrim = random.randint(1, 9999)
|
||||
new_discrim = '%04d' % new_discrim
|
||||
|
||||
pwd_hash = await hash_data(password)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .messages import bp as channel_messages
|
||||
from .reactions import bp as channel_reactions
|
||||
from .pins import bp as channel_pins
|
||||
|
|
@ -0,0 +1,276 @@
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||
from litecord.blueprints.dms import try_dm_state
|
||||
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.schemas import validate, MESSAGE_CREATE
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channel_messages', __name__)
|
||||
|
||||
|
||||
def extract_limit(request, default: int = 50):
|
||||
try:
|
||||
limit = int(request.args.get('limit', default))
|
||||
|
||||
if limit not in range(0, 100):
|
||||
raise ValueError()
|
||||
except (TypeError, ValueError):
|
||||
raise BadRequest('limit not int')
|
||||
|
||||
return limit
|
||||
|
||||
|
||||
def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
||||
before, after = None, None
|
||||
|
||||
if 'around' in request.args:
|
||||
average = int(limit / 2)
|
||||
around = int(request.args['around'])
|
||||
|
||||
after = around - average
|
||||
before = around + average
|
||||
|
||||
elif 'before' in request.args:
|
||||
before = int(request.args['before'])
|
||||
elif 'after' in request.args:
|
||||
before = int(request.args['after'])
|
||||
|
||||
return before, after
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
||||
async def get_messages(channel_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check READ_MESSAGE_HISTORY permission
|
||||
ctype, peer_id = await channel_check(user_id, channel_id)
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
# make sure both parties will be subbed
|
||||
# to a dm
|
||||
await _dm_pre_dispatch(channel_id, user_id)
|
||||
await _dm_pre_dispatch(channel_id, peer_id)
|
||||
|
||||
limit = extract_limit(request, 50)
|
||||
|
||||
where_clause = ''
|
||||
before, after = query_tuple_from_args(request.args, limit)
|
||||
|
||||
if before:
|
||||
where_clause += f'AND id < {before}'
|
||||
|
||||
if after:
|
||||
where_clause += f'AND id > {after}'
|
||||
|
||||
message_ids = await app.db.fetch(f"""
|
||||
SELECT id
|
||||
FROM messages
|
||||
WHERE channel_id = $1 {where_clause}
|
||||
ORDER BY id DESC
|
||||
LIMIT {limit}
|
||||
""", channel_id)
|
||||
|
||||
result = []
|
||||
|
||||
for message_id in message_ids:
|
||||
msg = await app.storage.get_message(message_id['id'], user_id)
|
||||
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
|
||||
log.info('Fetched {} messages', len(result))
|
||||
return jsonify(result)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
||||
async def get_single_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: check READ_MESSAGE_HISTORY permissions
|
||||
message = await app.storage.get_message(message_id, user_id)
|
||||
|
||||
if not message:
|
||||
raise MessageNotFound()
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
||||
async def _dm_pre_dispatch(channel_id, peer_id):
|
||||
"""Do some checks pre-MESSAGE_CREATE so we
|
||||
make sure the receiving party will handle everything."""
|
||||
|
||||
# check the other party's dm_channel_state
|
||||
|
||||
dm_state = await app.db.fetchval("""
|
||||
SELECT dm_id
|
||||
FROM dm_channel_state
|
||||
WHERE user_id = $1 AND dm_id = $2
|
||||
""", peer_id, channel_id)
|
||||
|
||||
if dm_state:
|
||||
# the peer already has the channel
|
||||
# opened, so we don't need to do anything
|
||||
return
|
||||
|
||||
dm_chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
# dispatch CHANNEL_CREATE so the client knows which
|
||||
# channel the future event is about
|
||||
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
||||
|
||||
# subscribe the peer to the channel
|
||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||
|
||||
# insert it on dm_channel_state so the client
|
||||
# is subscribed on the future
|
||||
await try_dm_state(peer_id, channel_id)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||
async def create_message(channel_id):
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
await channel_perm_check(user_id, channel_id, 'send_messages')
|
||||
|
||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||
message_id = get_snowflake()
|
||||
|
||||
# TODO: check connection to the gateway
|
||||
|
||||
mentions_everyone = ('@everyone' in j['content'] and
|
||||
await channel_perm_check(
|
||||
user_id, channel_id, 'mention_everyone', False
|
||||
)
|
||||
)
|
||||
|
||||
is_tts = (j.get('tts', False) and
|
||||
await channel_perm_check(
|
||||
user_id, channel_id, 'send_tts_messages', False
|
||||
))
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
||||
mention_everyone, nonce, message_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
message_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
j['content'],
|
||||
|
||||
is_tts,
|
||||
mentions_everyone,
|
||||
|
||||
int(j.get('nonce', 0)),
|
||||
MessageType.DEFAULT.value
|
||||
)
|
||||
|
||||
payload = await app.storage.get_message(message_id, user_id)
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
# guild id here is the peer's ID.
|
||||
await _dm_pre_dispatch(channel_id, user_id)
|
||||
await _dm_pre_dispatch(channel_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_CREATE', payload)
|
||||
|
||||
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
||||
|
||||
if ctype == ChannelType.GUILD_TEXT:
|
||||
for str_uid in payload['mentions']:
|
||||
uid = int(str_uid)
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE user_read_state
|
||||
SET mention_count += 1
|
||||
WHERE user_id = $1 AND channel_id = $2
|
||||
""", uid, channel_id)
|
||||
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
|
||||
async def edit_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
author_id = await app.db.fetchval("""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
|
||||
if not author_id == user_id:
|
||||
raise Forbidden('You can not edit this message')
|
||||
|
||||
j = await request.get_json()
|
||||
updated = 'content' in j or 'embed' in j
|
||||
|
||||
if 'content' in j:
|
||||
await app.db.execute("""
|
||||
UPDATE messages
|
||||
SET content=$1
|
||||
WHERE messages.id = $2
|
||||
""", j['content'], message_id)
|
||||
|
||||
# TODO: update embed
|
||||
|
||||
message = await app.storage.get_message(message_id, user_id)
|
||||
|
||||
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
|
||||
if updated:
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_UPDATE', message)
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
|
||||
async def delete_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
author_id = await app.db.fetchval("""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
|
||||
by_perm = await channel_perm_check(
|
||||
user_id, channel_id, 'manage_messages', False
|
||||
)
|
||||
|
||||
by_ownership = author_id == user_id
|
||||
|
||||
if not by_perm and not by_ownership:
|
||||
raise Forbidden('You can not delete this message')
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id,
|
||||
'MESSAGE_DELETE', {
|
||||
'id': str(message_id),
|
||||
'channel_id': str(channel_id),
|
||||
|
||||
# for lazy guilds
|
||||
'guild_id': str(guild_id),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
from quart import Blueprint, current_app as app, request, jsonify
|
||||
|
||||
from litecord.auth import token_check
|
||||
from litecord.blueprints.checks import channel_check
|
||||
from litecord.snowflake import snowflake_datetime
|
||||
|
||||
bp = Blueprint('channel_pins', __name__)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins', methods=['GET'])
|
||||
async def get_pins(channel_id):
|
||||
"""Get the pins for a channel"""
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
ids = await app.db.fetch("""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
""", channel_id)
|
||||
|
||||
ids = [r['message_id'] for r in ids]
|
||||
res = []
|
||||
|
||||
for message_id in ids:
|
||||
message = await app.storage.get_message(message_id)
|
||||
if message is not None:
|
||||
res.append(message)
|
||||
|
||||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
|
||||
async def add_pin(channel_id, message_id):
|
||||
"""Add a pin to a channel"""
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: check MANAGE_MESSAGES permission
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO channel_pins (channel_id, message_id)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, message_id)
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
LIMIT 1
|
||||
""", channel_id)
|
||||
|
||||
timestamp = snowflake_datetime(row['message_id'])
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
|
||||
'channel_id': str(channel_id),
|
||||
'last_pin_timestamp': timestamp.isoformat()
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
|
||||
async def delete_pin(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: check MANAGE_MESSAGES permission
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM channel_pins
|
||||
WHERE channel_id = $1 AND message_id = $2
|
||||
""", channel_id, message_id)
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
LIMIT 1
|
||||
""", channel_id)
|
||||
|
||||
timestamp = snowflake_datetime(row['message_id'])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
|
||||
'channel_id': str(channel_id),
|
||||
'last_pin_timestamp': timestamp.isoformat()
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
|
@ -0,0 +1,225 @@
|
|||
from enum import IntEnum
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
||||
|
||||
from litecord.utils import async_map
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import channel_check
|
||||
from litecord.blueprints.channel.messages import (
|
||||
query_tuple_from_args, extract_limit
|
||||
)
|
||||
|
||||
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||
from litecord.enums import GUILD_CHANS
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channel_reactions', __name__)
|
||||
|
||||
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions'
|
||||
|
||||
|
||||
class EmojiType(IntEnum):
|
||||
CUSTOM = 0
|
||||
UNICODE = 1
|
||||
|
||||
|
||||
def emoji_info_from_str(emoji: str) -> tuple:
|
||||
"""Extract emoji information from an emoji string
|
||||
given on the reaction endpoints."""
|
||||
# custom emoji have an emoji of name:id
|
||||
# unicode emoji just have the raw unicode.
|
||||
|
||||
# try checking if the emoji is custom or unicode
|
||||
emoji_type = 0 if ':' in emoji else 1
|
||||
emoji_type = EmojiType(emoji_type)
|
||||
|
||||
# extract the emoji id OR the unicode value of the emoji
|
||||
# depending if it is custom or not
|
||||
emoji_id = (int(emoji.split(':')[1])
|
||||
if emoji_type == EmojiType.CUSTOM
|
||||
else emoji)
|
||||
|
||||
emoji_name = emoji.split(':')[0]
|
||||
|
||||
return emoji_type, emoji_id, emoji_name
|
||||
|
||||
|
||||
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
|
||||
print(emoji_type, emoji_id, emoji_name)
|
||||
return {
|
||||
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
|
||||
}
|
||||
|
||||
|
||||
def _make_payload(user_id, channel_id, message_id, partial):
|
||||
return {
|
||||
'user_id': str(user_id),
|
||||
'channel_id': str(channel_id),
|
||||
'message_id': str(message_id),
|
||||
'emoji': partial
|
||||
}
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT'])
|
||||
async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
||||
"""Put a reaction."""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check READ_MESSAGE_HISTORY permission
|
||||
# and ADD_REACTIONS. look on route docs.
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO message_reactions (message_id, user_id,
|
||||
emoji_type, emoji_id, emoji_text)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", message_id, user_id, emoji_type,
|
||||
|
||||
# if it is custom, we put the emoji_id on emoji_id
|
||||
# column, if it isn't, we put it on emoji_text
|
||||
# column.
|
||||
emoji_id if emoji_type == EmojiType.CUSTOM else None,
|
||||
emoji_id if emoji_type == EmojiType.UNICODE else None
|
||||
)
|
||||
|
||||
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
payload['guild_id'] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
|
||||
"""Extract SQL clauses to search for specific emoji
|
||||
in the message_reactions table."""
|
||||
param = f'${param}'
|
||||
|
||||
# know which column to filter with
|
||||
where_ext = (f'AND emoji_id = {param}'
|
||||
if emoji_type == EmojiType.CUSTOM else
|
||||
f'AND emoji_text = {param}')
|
||||
|
||||
# which emoji to remove (custom or unicode)
|
||||
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
|
||||
|
||||
return where_ext, main_emoji
|
||||
|
||||
|
||||
def _emoji_sql_simple(emoji: str, param=4):
|
||||
"""Simpler version of _emoji_sql for functions that
|
||||
don't need the results from emoji_info_from_str."""
|
||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
||||
|
||||
|
||||
async def remove_reaction(channel_id: int, message_id: int,
|
||||
user_id: int, emoji: str):
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||
where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name)
|
||||
|
||||
await app.db.execute(
|
||||
f"""
|
||||
DELETE FROM message_reactions
|
||||
WHERE message_id = $1
|
||||
AND user_id = $2
|
||||
AND emoji_type = $3
|
||||
{where_ext}
|
||||
""", message_id, user_id, emoji_type, main_emoji)
|
||||
|
||||
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
payload['guild_id'] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE'])
|
||||
async def remove_own_reaction(channel_id, message_id, emoji):
|
||||
"""Remove a reaction."""
|
||||
user_id = await token_check()
|
||||
|
||||
await remove_reaction(channel_id, message_id, user_id, emoji)
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE'])
|
||||
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
|
||||
"""Remove a reaction made by another user."""
|
||||
await token_check()
|
||||
|
||||
# TODO: check MANAGE_MESSAGES permission (and use user_id
|
||||
# from token_check to do it)
|
||||
await remove_reaction(channel_id, message_id, other_id, emoji)
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET'])
|
||||
async def list_users_reaction(channel_id, message_id, emoji):
|
||||
"""Get the list of all users who reacted with a certain emoji."""
|
||||
user_id = await token_check()
|
||||
|
||||
# this is not using either ctype or guild_id
|
||||
# that are returned by channel_check
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
limit = extract_limit(request, 25)
|
||||
before, after = query_tuple_from_args(request.args, limit)
|
||||
|
||||
before_clause = 'AND user_id < $2' if before else ''
|
||||
after_clause = 'AND user_id > $3' if after else ''
|
||||
|
||||
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
|
||||
|
||||
rows = await app.db.fetch(f"""
|
||||
SELECT user_id
|
||||
FROM message_reactions
|
||||
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
|
||||
""", message_id, before, after, main_emoji)
|
||||
|
||||
user_ids = [r['user_id'] for r in rows]
|
||||
users = await async_map(app.storage.get_user, user_ids)
|
||||
return jsonify(users)
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}', methods=['DELETE'])
|
||||
async def remove_all_reactions(channel_id, message_id):
|
||||
"""Remove all reactions in a message."""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check MANAGE_MESSAGES permission
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM message_reactions
|
||||
WHERE message_id = $1
|
||||
""", message_id)
|
||||
|
||||
payload = {
|
||||
'channel_id': str(channel_id),
|
||||
'message_id': str(message_id),
|
||||
}
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
payload['guild_id'] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)
|
||||
|
|
@ -3,14 +3,14 @@ import time
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
||||
from ..auth import token_check
|
||||
from ..snowflake import get_snowflake, snowflake_datetime
|
||||
from ..enums import ChannelType, MessageType, GUILD_CHANS
|
||||
from ..errors import Forbidden, ChannelNotFound, MessageNotFound
|
||||
from ..schemas import validate, MESSAGE_CREATE
|
||||
from litecord.auth import token_check
|
||||
from litecord.enums import ChannelType, GUILD_CHANS
|
||||
from litecord.errors import ChannelNotFound
|
||||
from litecord.schemas import (
|
||||
validate, CHAN_UPDATE, CHAN_OVERWRITE
|
||||
)
|
||||
|
||||
from .checks import channel_check, guild_check
|
||||
from .dms import try_dm_state
|
||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channels', __name__)
|
||||
|
|
@ -136,6 +136,7 @@ async def guild_cleanup(channel_id):
|
|||
|
||||
@bp.route('/<int:channel_id>', methods=['DELETE'])
|
||||
async def close_channel(channel_id):
|
||||
"""Close or delete a channel."""
|
||||
user_id = await token_check()
|
||||
|
||||
chan_type = await app.storage.get_chan_type(channel_id)
|
||||
|
|
@ -212,287 +213,199 @@ async def close_channel(channel_id):
|
|||
# TODO: group dm
|
||||
pass
|
||||
|
||||
return '', 404
|
||||
raise ChannelNotFound()
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
||||
async def get_messages(channel_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: before, after, around keys
|
||||
|
||||
message_ids = await app.db.fetch(f"""
|
||||
SELECT id
|
||||
FROM messages
|
||||
WHERE channel_id = $1
|
||||
ORDER BY id DESC
|
||||
LIMIT 100
|
||||
""", channel_id)
|
||||
|
||||
result = []
|
||||
|
||||
for message_id in message_ids:
|
||||
msg = await app.storage.get_message(message_id['id'])
|
||||
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
|
||||
log.info('Fetched {} messages', len(result))
|
||||
return jsonify(result)
|
||||
async def _update_pos(channel_id, pos: int):
|
||||
await app.db.execute("""
|
||||
UPDATE guild_channels
|
||||
SET position = $1
|
||||
WHERE id = $2
|
||||
""", pos, channel_id)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
||||
async def get_single_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: check READ_MESSAGE_HISTORY permissions
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
if not message:
|
||||
raise MessageNotFound()
|
||||
|
||||
return jsonify(message)
|
||||
async def _mass_chan_update(guild_id, channel_ids: int):
|
||||
for channel_id in channel_ids:
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch(
|
||||
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||
|
||||
|
||||
async def _dm_pre_dispatch(channel_id, peer_id):
|
||||
"""Do some checks pre-MESSAGE_CREATE so we
|
||||
make sure the receiving party will handle everything."""
|
||||
async def _process_overwrites(channel_id: int, overwrites: list):
|
||||
for overwrite in overwrites:
|
||||
|
||||
# check the other party's dm_channel_state
|
||||
# 0 for user overwrite, 1 for role overwrite
|
||||
target_type = 0 if overwrite['type'] == 'user' else 1
|
||||
target_role = None if target_type == 0 else overwrite['id']
|
||||
target_user = overwrite['id'] if target_type == 0 else None
|
||||
|
||||
dm_state = await app.db.fetchval("""
|
||||
SELECT dm_id
|
||||
FROM dm_channel_state
|
||||
WHERE user_id = $1 AND dm_id = $2
|
||||
""", peer_id, channel_id)
|
||||
|
||||
if dm_state:
|
||||
# the peer already has the channel
|
||||
# opened, so we don't need to do anything
|
||||
return
|
||||
|
||||
dm_chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
# dispatch CHANNEL_CREATE so the client knows which
|
||||
# channel the future event is about
|
||||
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
||||
|
||||
# subscribe the peer to the channel
|
||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||
|
||||
# insert it on dm_channel_state so the client
|
||||
# is subscribed on the future
|
||||
await try_dm_state(peer_id, channel_id)
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO channel_overwrites
|
||||
(channel_id, target_type, target_role,
|
||||
target_user, allow, deny)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT ON CONSTRAINT channel_overwrites_uniq
|
||||
DO
|
||||
UPDATE
|
||||
SET allow = $5, deny = $6
|
||||
WHERE channel_overwrites.channel_id = $1
|
||||
AND channel_overwrites.target_type = $2
|
||||
AND channel_overwrites.target_role = $3
|
||||
AND channel_overwrites.target_user = $4
|
||||
""",
|
||||
channel_id, target_type,
|
||||
target_role, target_user,
|
||||
overwrite['allow'], overwrite['deny'])
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||
async def create_message(channel_id):
|
||||
@bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT'])
|
||||
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||
"""Insert or modify a channel overwrite."""
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||
message_id = get_snowflake()
|
||||
if ctype not in GUILD_CHANS:
|
||||
raise ChannelNotFound('Only usable for guild channels.')
|
||||
|
||||
# TODO: check SEND_MESSAGES permission
|
||||
# TODO: check connection to the gateway
|
||||
await channel_perm_check(user_id, guild_id, 'manage_roles')
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
||||
mention_everyone, nonce, message_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
message_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
j['content'],
|
||||
|
||||
# TODO: check SEND_TTS_MESSAGES
|
||||
j.get('tts', False),
|
||||
|
||||
# TODO: check MENTION_EVERYONE permissions
|
||||
'@everyone' in j['content'],
|
||||
int(j.get('nonce', 0)),
|
||||
MessageType.DEFAULT.value
|
||||
j = validate(
|
||||
# inserting a fake id on the payload so validation passes through
|
||||
{**await request.get_json(), **{'id': -1}},
|
||||
CHAN_OVERWRITE
|
||||
)
|
||||
|
||||
payload = await app.storage.get_message(message_id)
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
# guild id here is the peer's ID.
|
||||
await _dm_pre_dispatch(channel_id, guild_id)
|
||||
await _process_overwrites(channel_id, [{
|
||||
'allow': j['allow'],
|
||||
'deny': j['deny'],
|
||||
'type': j['type'],
|
||||
'id': overwrite_id
|
||||
}])
|
||||
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_CREATE', payload)
|
||||
|
||||
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
||||
|
||||
if ctype == ChannelType.GUILD_TEXT:
|
||||
for str_uid in payload['mentions']:
|
||||
uid = int(str_uid)
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE user_read_state
|
||||
SET mention_count += 1
|
||||
WHERE user_id = $1 AND channel_id = $2
|
||||
""", uid, channel_id)
|
||||
|
||||
return jsonify(payload)
|
||||
await _mass_chan_update(guild_id, [channel_id])
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
|
||||
async def edit_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
author_id = await app.db.fetchval("""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
|
||||
if not author_id == user_id:
|
||||
raise Forbidden('You can not edit this message')
|
||||
|
||||
j = await request.get_json()
|
||||
updated = 'content' in j or 'embed' in j
|
||||
|
||||
if 'content' in j:
|
||||
async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||
if 'name' in j:
|
||||
await app.db.execute("""
|
||||
UPDATE messages
|
||||
SET content=$1
|
||||
WHERE messages.id = $2
|
||||
""", j['content'], message_id)
|
||||
UPDATE guild_channels
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""", j['name'], channel_id)
|
||||
|
||||
# TODO: update embed
|
||||
if 'position' in j:
|
||||
channel_data = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
message = await app.storage.get_message(message_id)
|
||||
chans = [None * len(channel_data)]
|
||||
for chandata in channel_data:
|
||||
chans.insert(chandata['position'], int(chandata['id']))
|
||||
|
||||
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
|
||||
if updated:
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_UPDATE', message)
|
||||
# are we changing to the left or to the right?
|
||||
|
||||
return jsonify(message)
|
||||
# left: [channel1, channel2, ..., channelN-1, channelN]
|
||||
# becomes
|
||||
# [channel1, channelN-1, channel2, ..., channelN]
|
||||
# so we can say that the "main change" is
|
||||
# channelN-1 going to the position channel2
|
||||
# was occupying.
|
||||
current_pos = chans.index(channel_id)
|
||||
new_pos = j['position']
|
||||
|
||||
# if the new position is bigger than the current one,
|
||||
# we're making a left shift of all the channels that are
|
||||
# beyond the current one, to make space
|
||||
left_shift = new_pos > current_pos
|
||||
|
||||
# find all channels that we'll have to shift
|
||||
shift_block = (chans[current_pos:new_pos]
|
||||
if left_shift else
|
||||
chans[new_pos:current_pos]
|
||||
)
|
||||
|
||||
shift = -1 if left_shift else 1
|
||||
|
||||
# do the shift (to the left or to the right)
|
||||
await app.db.executemany("""
|
||||
UPDATE guild_channels
|
||||
SET position = position + $1
|
||||
WHERE id = $2
|
||||
""", [(shift, chan_id) for chan_id in shift_block])
|
||||
|
||||
await _mass_chan_update(guild_id, shift_block)
|
||||
|
||||
# since theres now an empty slot, move current channel to it
|
||||
await _update_pos(channel_id, new_pos)
|
||||
|
||||
if 'channel_overwrites' in j:
|
||||
overwrites = j['channel_overwrites']
|
||||
await _process_overwrites(channel_id, overwrites)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
|
||||
async def delete_message(channel_id, message_id):
|
||||
async def _common_guild_chan(channel_id, j: dict):
|
||||
# common updates to the guild_channels table
|
||||
for field in [field for field in j.keys()
|
||||
if field in ('nsfw', 'parent_id')]:
|
||||
await app.db.execute(f"""
|
||||
UPDATE guild_channels
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], channel_id)
|
||||
|
||||
|
||||
async def _update_text_channel(channel_id: int, j: dict):
|
||||
# first do the specific ones related to guild_text_channels
|
||||
for field in [field for field in j.keys()
|
||||
if field in ('topic', 'rate_limit_per_user')]:
|
||||
await app.db.execute(f"""
|
||||
UPDATE guild_text_channels
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], channel_id)
|
||||
|
||||
await _common_guild_chan(channel_id, j)
|
||||
|
||||
|
||||
async def _update_voice_channel(channel_id: int, j: dict):
|
||||
# first do the specific ones in guild_voice_channels
|
||||
for field in [field for field in j.keys()
|
||||
if field in ('bitrate', 'user_limit')]:
|
||||
await app.db.execute(f"""
|
||||
UPDATE guild_voice_channels
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], channel_id)
|
||||
|
||||
# yes, i'm letting voice channels have nsfw, you cant stop me
|
||||
await _common_guild_chan(channel_id, j)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
|
||||
async def update_channel(channel_id):
|
||||
"""Update a channel's information"""
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
author_id = await app.db.fetchval("""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
if ctype not in GUILD_CHANS:
|
||||
raise ChannelNotFound('Can not edit non-guild channels.')
|
||||
|
||||
# TODO: MANAGE_MESSAGES permission check
|
||||
if author_id != user_id:
|
||||
raise Forbidden('You can not delete this message')
|
||||
await channel_perm_check(user_id, channel_id, 'manage_channels')
|
||||
j = validate(await request.get_json(), CHAN_UPDATE)
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
# TODO: categories?
|
||||
update_handler = {
|
||||
ChannelType.GUILD_TEXT: _update_text_channel,
|
||||
ChannelType.GUILD_VOICE: _update_voice_channel,
|
||||
}[ctype]
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id,
|
||||
'MESSAGE_DELETE', {
|
||||
'id': str(message_id),
|
||||
'channel_id': str(channel_id),
|
||||
await _update_channel_common(channel_id, guild_id, j)
|
||||
await update_handler(channel_id, j)
|
||||
|
||||
# for lazy guilds
|
||||
'guild_id': str(guild_id),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins', methods=['GET'])
|
||||
async def get_pins(channel_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
ids = await app.db.fetch("""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
""", channel_id)
|
||||
|
||||
ids = [r['message_id'] for r in ids]
|
||||
res = []
|
||||
|
||||
for message_id in ids:
|
||||
message = await app.storage.get_message(message_id)
|
||||
if message is not None:
|
||||
res.append(message)
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
|
||||
async def add_pin(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: check MANAGE_MESSAGES permission
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO channel_pins (channel_id, message_id)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, message_id)
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
LIMIT 1
|
||||
""", channel_id)
|
||||
|
||||
timestamp = snowflake_datetime(row['message_id'])
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
|
||||
'channel_id': str(channel_id),
|
||||
'last_pin_timestamp': timestamp.isoformat()
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
|
||||
async def delete_pin(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
# TODO: check MANAGE_MESSAGES permission
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM channel_pins
|
||||
WHERE channel_id = $1 AND message_id = $2
|
||||
""", channel_id, message_id)
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
LIMIT 1
|
||||
""", channel_id)
|
||||
|
||||
timestamp = snowflake_datetime(row['message_id'])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
|
||||
'channel_id': str(channel_id),
|
||||
'last_pin_timestamp': timestamp.isoformat()
|
||||
})
|
||||
|
||||
return '', 204
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/typing', methods=['POST'])
|
||||
|
|
@ -518,21 +431,18 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
|||
if not message_id:
|
||||
message_id = await app.storage.chan_last_message(channel_id)
|
||||
|
||||
res = await app.db.execute("""
|
||||
UPDATE user_read_state
|
||||
|
||||
SET last_message_id = $1,
|
||||
mention_count = 0
|
||||
|
||||
WHERE user_id = $2 AND channel_id = $3
|
||||
""", message_id, user_id, channel_id)
|
||||
|
||||
if res == 'UPDATE 0':
|
||||
await app.db.execute("""
|
||||
INSERT INTO user_read_state
|
||||
(user_id, channel_id, last_message_id, mention_count)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", user_id, channel_id, message_id, 0)
|
||||
await app.db.execute("""
|
||||
INSERT INTO user_read_state
|
||||
(user_id, channel_id, last_message_id, mention_count)
|
||||
VALUES
|
||||
($1, $2, $3, 0)
|
||||
ON CONFLICT ON CONSTRAINT user_read_state_pkey
|
||||
DO
|
||||
UPDATE
|
||||
SET last_message_id = $3, mention_count = 0
|
||||
WHERE user_read_state.user_id = $1
|
||||
AND user_read_state.channel_id = $2
|
||||
""", user_id, channel_id, message_id)
|
||||
|
||||
if guild_id:
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
|
|
@ -551,6 +461,7 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
|||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
||||
async def ack_channel(channel_id, message_id):
|
||||
"""Acknowledge a channel."""
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
|
|
@ -569,6 +480,7 @@ async def ack_channel(channel_id, message_id):
|
|||
|
||||
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
|
||||
async def delete_read_state(channel_id):
|
||||
"""Delete the read state of a channel."""
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from quart import current_app as app
|
||||
|
||||
from ..enums import ChannelType, GUILD_CHANS
|
||||
from ..errors import GuildNotFound, ChannelNotFound
|
||||
from litecord.enums import ChannelType, GUILD_CHANS
|
||||
from litecord.errors import (
|
||||
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
|
||||
)
|
||||
from litecord.permissions import base_permissions, get_permissions
|
||||
|
||||
|
||||
async def guild_check(user_id: int, guild_id: int):
|
||||
|
|
@ -16,6 +19,21 @@ async def guild_check(user_id: int, guild_id: int):
|
|||
raise GuildNotFound('guild not found')
|
||||
|
||||
|
||||
async def guild_owner_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is the owner of the guild."""
|
||||
owner_id = await app.db.fetchval("""
|
||||
SELECT owner_id
|
||||
FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
|
||||
if not owner_id:
|
||||
raise GuildNotFound()
|
||||
|
||||
if user_id != owner_id:
|
||||
raise Forbidden('You are not the owner of the guild')
|
||||
|
||||
|
||||
async def channel_check(user_id, channel_id):
|
||||
"""Check if the current user is authorized
|
||||
to read the channel's information."""
|
||||
|
|
@ -39,3 +57,27 @@ async def channel_check(user_id, channel_id):
|
|||
if ctype == ChannelType.DM:
|
||||
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||
return ctype, peer_id
|
||||
|
||||
|
||||
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||
"""Check guild permissions for a user."""
|
||||
base_perms = await base_permissions(user_id, guild_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
if not hasperm:
|
||||
raise MissingPermissions('Missing permissions.')
|
||||
|
||||
|
||||
async def channel_perm_check(user_id, channel_id,
|
||||
permission: str, raise_err=True):
|
||||
"""Check channel permissions for a user."""
|
||||
base_perms = await get_permissions(user_id, channel_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
print(base_perms)
|
||||
print(base_perms.binary)
|
||||
|
||||
if not hasperm and raise_err:
|
||||
raise MissingPermissions('Missing permissions.')
|
||||
|
||||
return hasperm
|
||||
|
|
|
|||
|
|
@ -38,41 +38,47 @@ async def try_dm_state(user_id: int, dm_id: int):
|
|||
""", user_id, dm_id)
|
||||
|
||||
|
||||
async def jsonify_dm(dm_id: int, user_id: int):
|
||||
dm_chan = await app.storage.get_dm(dm_id, user_id)
|
||||
return jsonify(dm_chan)
|
||||
|
||||
|
||||
async def create_dm(user_id, recipient_id):
|
||||
"""Create a new dm with a user,
|
||||
or get the existing DM id if it already exists."""
|
||||
|
||||
dm_id = await app.db.fetchval("""
|
||||
SELECT id
|
||||
FROM dm_channels
|
||||
WHERE (party1_id = $1 OR party2_id = $1) AND
|
||||
(party1_id = $2 OR party2_id = $2)
|
||||
""", user_id, recipient_id)
|
||||
|
||||
if dm_id:
|
||||
return await jsonify_dm(dm_id, user_id)
|
||||
|
||||
# if no dm was found, create a new one
|
||||
|
||||
dm_id = get_snowflake()
|
||||
await app.db.execute("""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", dm_id, ChannelType.DM.value)
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", dm_id, ChannelType.DM.value)
|
||||
await app.db.execute("""
|
||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", dm_id, user_id, recipient_id)
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", dm_id, user_id, recipient_id)
|
||||
# the dm state is something we use
|
||||
# to give the currently "open dms"
|
||||
# on the client.
|
||||
|
||||
# the dm state is something we use
|
||||
# to give the currently "open dms"
|
||||
# on the client.
|
||||
# we don't open a dm for the peer/recipient
|
||||
# until the user sends a message.
|
||||
await try_dm_state(user_id, dm_id)
|
||||
|
||||
# we don't open a dm for the peer/recipient
|
||||
# until the user sends a message.
|
||||
await try_dm_state(user_id, dm_id)
|
||||
|
||||
except UniqueViolationError:
|
||||
# the dm already exists
|
||||
dm_id = await app.db.fetchval("""
|
||||
SELECT id
|
||||
FROM dm_channels
|
||||
WHERE (party1_id = $1 OR party2_id = $1) AND
|
||||
(party2_id = $2 OR party2_id = $2)
|
||||
""", user_id, recipient_id)
|
||||
|
||||
dm = await app.storage.get_dm(dm_id, user_id)
|
||||
return jsonify(dm)
|
||||
return await jsonify_dm(dm_id, user_id)
|
||||
|
||||
|
||||
@bp.route('/@me/channels', methods=['POST'])
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import time
|
||||
|
||||
from quart import Blueprint, jsonify, current_app as app
|
||||
|
||||
from ..auth import token_check
|
||||
|
|
@ -6,12 +8,14 @@ bp = Blueprint('gateway', __name__)
|
|||
|
||||
|
||||
def get_gw():
|
||||
"""Get the gateway's web"""
|
||||
proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
|
||||
return f'{proto}{app.config["WEBSOCKET_URL"]}/ws'
|
||||
|
||||
|
||||
@bp.route('/gateway')
|
||||
def api_gateway():
|
||||
"""Get the raw URL."""
|
||||
return jsonify({
|
||||
'url': get_gw()
|
||||
})
|
||||
|
|
@ -27,9 +31,25 @@ async def api_gateway_bot():
|
|||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
shards = max(int(guild_count / 1200), 1)
|
||||
shards = max(int(guild_count / 1000), 1)
|
||||
|
||||
# get _ws.session ratelimit
|
||||
ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
|
||||
bucket = ratelimit.get_bucket(user_id)
|
||||
|
||||
# timestamp of bucket reset
|
||||
reset_ts = bucket._window + bucket.second
|
||||
|
||||
# how many seconds until bucket reset
|
||||
reset_after_ts = reset_ts - time.time()
|
||||
|
||||
return jsonify({
|
||||
'url': get_gw(),
|
||||
'shards': shards,
|
||||
|
||||
'session_start_limit': {
|
||||
'total': bucket.requests,
|
||||
'remaining': bucket._tokens,
|
||||
'reset_after': int(reset_after_ts * 1000),
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
from .roles import bp as guild_roles
|
||||
from .members import bp as guild_members
|
||||
from .channels import bp as guild_channels
|
||||
from .mod import bp as guild_mod
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import guild_check, guild_owner_check
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.errors import BadRequest
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.schemas import (
|
||||
validate, ROLE_UPDATE_POSITION
|
||||
)
|
||||
from litecord.blueprints.guild.roles import gen_pairs
|
||||
|
||||
|
||||
bp = Blueprint('guild_channels', __name__)
|
||||
|
||||
|
||||
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||
if ctype == ChannelType.GUILD_TEXT:
|
||||
await app.db.execute("""
|
||||
INSERT INTO guild_text_channels (id, topic)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, kwargs.get('topic', ''))
|
||||
elif ctype == ChannelType.GUILD_VOICE:
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
|
||||
VALUES ($1, $2, $3)
|
||||
""",
|
||||
channel_id,
|
||||
kwargs.get('bitrate', 64),
|
||||
kwargs.get('user_limit', 0)
|
||||
)
|
||||
|
||||
|
||||
async def create_guild_channel(guild_id: int, channel_id: int,
|
||||
ctype: ChannelType, **kwargs):
|
||||
"""Create a channel in a guild."""
|
||||
await app.db.execute("""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, ctype.value)
|
||||
|
||||
# calc new pos
|
||||
max_pos = await app.db.fetchval("""
|
||||
SELECT MAX(position)
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
# account for the first channel in a guild too
|
||||
max_pos = max_pos or 0
|
||||
|
||||
# all channels go to guild_channels
|
||||
await app.db.execute("""
|
||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", channel_id, guild_id, kwargs['name'], max_pos + 1)
|
||||
|
||||
# the rest of sql magic is dependant on the channel
|
||||
# we're creating (a text or voice or category),
|
||||
# so we use this function.
|
||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||
|
||||
|
||||
@bp.route('/<int:guild>/channels', methods=['GET'])
|
||||
async def get_guild_channels(guild_id):
|
||||
"""Get the list of channels in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_channel_data(guild_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['POST'])
|
||||
async def create_channel(guild_id):
|
||||
"""Create a channel in a guild."""
|
||||
user_id = await token_check()
|
||||
j = await request.get_json()
|
||||
|
||||
# TODO: check permissions for MANAGE_CHANNELS
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
channel_type = j.get('type', ChannelType.GUILD_TEXT)
|
||||
channel_type = ChannelType(channel_type)
|
||||
|
||||
if channel_type not in (ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE):
|
||||
raise BadRequest('Invalid channel type')
|
||||
|
||||
new_channel_id = get_snowflake()
|
||||
await create_guild_channel(
|
||||
guild_id, new_channel_id, channel_type, **j)
|
||||
|
||||
# TODO: do a better method
|
||||
# subscribe the currently subscribed users to the new channel
|
||||
# by getting all user ids and subscribing each one by one.
|
||||
|
||||
# since GuildDispatcher calls Storage.get_channel_ids,
|
||||
# it will subscribe all users to the newly created channel.
|
||||
guild_pubsub = app.dispatcher.backends['guild']
|
||||
user_ids = guild_pubsub.state[guild_id]
|
||||
for uid in user_ids:
|
||||
await app.dispatcher.sub('guild', guild_id, uid)
|
||||
|
||||
chan = await app.storage.get_channel(new_channel_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'CHANNEL_CREATE', chan)
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
||||
"""Fetch new information about the channel and dispatch
|
||||
a single CHANNEL_UPDATE event to the guild."""
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
|
||||
|
||||
|
||||
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||
"""Do a single channel swap, dispatching
|
||||
the CHANNEL_UPDATE events for after the swap"""
|
||||
pair1, pair2 = pair
|
||||
channel_1, new_pos_1 = pair1
|
||||
channel_2, new_pos_2 = pair2
|
||||
|
||||
# do the swap in a transaction.
|
||||
conn = await app.db.acquire()
|
||||
|
||||
async with conn.transaction():
|
||||
await conn.executemany("""
|
||||
UPDATE guild_channels
|
||||
SET position = $1
|
||||
WHERE id = $2 AND guild_id = $3
|
||||
""", [
|
||||
(new_pos_1, channel_1, guild_id),
|
||||
(new_pos_2, channel_2, guild_id)])
|
||||
|
||||
await app.db.release(conn)
|
||||
|
||||
await _chan_update_dispatch(guild_id, channel_1)
|
||||
await _chan_update_dispatch(guild_id, channel_2)
|
||||
|
||||
|
||||
async def _do_channel_swaps(guild_id: int, swap_pairs: list):
|
||||
"""Swap channel pairs' positions, given the list
|
||||
of pairs to do.
|
||||
|
||||
Dispatches CHANNEL_UPDATEs to the guild.
|
||||
"""
|
||||
for pair in swap_pairs:
|
||||
await _do_single_swap(guild_id, pair)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
|
||||
async def modify_channel_pos(guild_id):
|
||||
"""Change positions of channels in a guild."""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check MANAGE_CHANNELS
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
# same thing as guild.roles, so we use
|
||||
# the same schema and all.
|
||||
raw_j = await request.get_json()
|
||||
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
||||
j = j['roles']
|
||||
|
||||
channels = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
channel_positions = {chan['position']: int(chan['id'])
|
||||
for chan in channels}
|
||||
|
||||
swap_pairs = gen_pairs(
|
||||
j,
|
||||
channel_positions
|
||||
)
|
||||
|
||||
await _do_channel_swaps(guild_id, swap_pairs)
|
||||
|
||||
return '', 204
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import guild_check
|
||||
from litecord.errors import BadRequest
|
||||
from litecord.schemas import (
|
||||
validate, MEMBER_UPDATE
|
||||
)
|
||||
from litecord.blueprints.checks import guild_owner_check
|
||||
|
||||
|
||||
bp = Blueprint('guild_members', __name__)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
|
||||
async def get_guild_member(guild_id, member_id):
|
||||
"""Get a member's information in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
member = await app.storage.get_single_member(guild_id, member_id)
|
||||
return jsonify(member)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members', methods=['GET'])
|
||||
async def get_members(guild_id):
|
||||
"""Get members inside a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
limit, after = int(j.get('limit', 1)), j.get('after', 0)
|
||||
|
||||
if limit < 1 or limit > 1000:
|
||||
raise BadRequest('limit not in 1-1000 range')
|
||||
|
||||
user_ids = await app.db.fetch(f"""
|
||||
SELECT user_id
|
||||
WHERE guild_id = $1, user_id > $2
|
||||
LIMIT {limit}
|
||||
ORDER BY user_id ASC
|
||||
""", guild_id, after)
|
||||
|
||||
user_ids = [r[0] for r in user_ids]
|
||||
members = await app.storage.get_member_multi(guild_id, user_ids)
|
||||
return jsonify(members)
|
||||
|
||||
|
||||
async def _update_member_roles(guild_id: int, member_id: int,
|
||||
wanted_roles: list):
|
||||
"""Update the roles a member has."""
|
||||
|
||||
# first, fetch all current roles
|
||||
roles = await app.db.fetch("""
|
||||
SELECT role_id from member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
roles = [r['role_id'] for r in roles]
|
||||
|
||||
roles = set(roles)
|
||||
wanted_roles = set(wanted_roles)
|
||||
|
||||
# first, we need to find all added roles:
|
||||
# roles that are on wanted_roles but
|
||||
# not on roles
|
||||
added_roles = wanted_roles - roles
|
||||
|
||||
# and then the removed roles
|
||||
# which are roles in roles, but not
|
||||
# in wanted_roles
|
||||
removed_roles = roles - wanted_roles
|
||||
|
||||
conn = await app.db.acquire()
|
||||
|
||||
async with conn.transaction():
|
||||
# add roles
|
||||
await app.db.executemany("""
|
||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", [(member_id, guild_id, role_id)
|
||||
for role_id in added_roles])
|
||||
|
||||
# remove roles
|
||||
await app.db.executemany("""
|
||||
DELETE FROM member_roles
|
||||
WHERE
|
||||
user_id = $1
|
||||
AND guild_id = $2
|
||||
AND role_id = $3
|
||||
""", [(member_id, guild_id, role_id)
|
||||
for role_id in removed_roles])
|
||||
|
||||
await app.db.release(conn)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
|
||||
async def modify_guild_member(guild_id, member_id):
|
||||
"""Modify a members' information in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
j = validate(await request.get_json(), MEMBER_UPDATE)
|
||||
|
||||
if 'nick' in j:
|
||||
# TODO: check MANAGE_NICKNAMES
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET nickname = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['nick'], member_id, guild_id)
|
||||
|
||||
if 'mute' in j:
|
||||
# TODO: check MUTE_MEMBERS
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET muted = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['mute'], member_id, guild_id)
|
||||
|
||||
if 'deaf' in j:
|
||||
# TODO: check DEAFEN_MEMBERS
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET deafened = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['deaf'], member_id, guild_id)
|
||||
|
||||
if 'channel_id' in j:
|
||||
# TODO: check MOVE_MEMBERS and CONNECT to the channel
|
||||
# TODO: change the member's voice channel
|
||||
pass
|
||||
|
||||
if 'roles' in j:
|
||||
# TODO: check permissions
|
||||
await _update_member_roles(guild_id, member_id, j['roles'])
|
||||
|
||||
member = await app.storage.get_member_data_one(guild_id, member_id)
|
||||
member.pop('joined_at')
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||
'guild_id': str(guild_id)
|
||||
}, **member})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
|
||||
async def update_nickname(guild_id):
|
||||
"""Update a member's nickname in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET nickname = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['nick'], user_id, guild_id)
|
||||
|
||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||
member.pop('joined_at')
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||
'guild_id': str(guild_id)
|
||||
}, **member})
|
||||
|
||||
return j['nick']
|
||||
|
|
@ -0,0 +1,185 @@
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import guild_owner_check
|
||||
|
||||
from litecord.schemas import validate, GUILD_PRUNE
|
||||
|
||||
bp = Blueprint('guild_moderation', __name__)
|
||||
|
||||
|
||||
async def remove_member(guild_id: int, member_id: int):
|
||||
"""Do common tasks related to deleting a member from the guild,
|
||||
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', {
|
||||
'guild_id': guild_id,
|
||||
'unavailable': False,
|
||||
})
|
||||
|
||||
await app.dispatcher.unsub('guild', guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
||||
'guild_id': str(guild_id),
|
||||
'user': await app.storage.get_user(member_id),
|
||||
})
|
||||
|
||||
|
||||
async def remove_member_multi(guild_id: int, members: list):
|
||||
"""Remove multiple members."""
|
||||
for member_id in members:
|
||||
await remove_member(guild_id, member_id)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
|
||||
async def kick_guild_member(guild_id, member_id):
|
||||
"""Remove a member from a guild."""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check KICK_MEMBERS permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
await remove_member(guild_id, member_id)
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans', methods=['GET'])
|
||||
async def get_bans(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check BAN_MEMBERS permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
bans = await app.db.fetch("""
|
||||
SELECT user_id, reason
|
||||
FROM bans
|
||||
WHERE bans.guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
res = []
|
||||
|
||||
for ban in bans:
|
||||
res.append({
|
||||
'reason': ban['reason'],
|
||||
'user': await app.storage.get_user(ban['user_id'])
|
||||
})
|
||||
|
||||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
|
||||
async def create_ban(guild_id, member_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check BAN_MEMBERS permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO bans (guild_id, user_id, reason)
|
||||
VALUES ($1, $2, $3)
|
||||
""", guild_id, member_id, j.get('reason', ''))
|
||||
|
||||
await remove_member(guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
|
||||
'guild_id': str(guild_id),
|
||||
'user': await app.storage.get_user(member_id)
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE'])
|
||||
async def remove_ban(guild_id, banned_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check BAN_MEMBERS permission
|
||||
await guild_owner_check(guild_id, user_id)
|
||||
|
||||
res = await app.db.execute("""
|
||||
DELETE FROM bans
|
||||
WHERE guild_id = $1 AND user_id = $@
|
||||
""", guild_id, banned_id)
|
||||
|
||||
# we don't really need to dispatch GUILD_BAN_REMOVE
|
||||
# when no bans were actually removed.
|
||||
if res == 'DELETE 0':
|
||||
return '', 204
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
|
||||
'guild_id': str(guild_id),
|
||||
'user': await app.storage.get_user(banned_id)
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
async def get_prune(guild_id: int, days: int) -> list:
|
||||
"""Get all members in a guild that:
|
||||
|
||||
- did not login in ``days`` days.
|
||||
- don't have any roles.
|
||||
"""
|
||||
# a good solution would be in pure sql.
|
||||
member_ids = await app.db.fetch(f"""
|
||||
SELECT id
|
||||
FROM users
|
||||
JOIN members
|
||||
ON members.guild_id = $1 AND members.user_id = users.id
|
||||
WHERE users.last_session < (now() - (interval '{days} days'))
|
||||
""", guild_id)
|
||||
|
||||
member_ids = [r['id'] for r in member_ids]
|
||||
members = []
|
||||
|
||||
for member_id in member_ids:
|
||||
role_count = await app.db.fetchval("""
|
||||
SELECT COUNT(*)
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
if role_count == 0:
|
||||
members.append(member_id)
|
||||
|
||||
return members
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/prune', methods=['GET'])
|
||||
async def get_guild_prune_count(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check KICK_MEMBERS
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
j = validate(await request.get_json(), GUILD_PRUNE)
|
||||
days = j['days']
|
||||
member_ids = await get_prune(guild_id, days)
|
||||
|
||||
return jsonify({
|
||||
'pruned': len(member_ids),
|
||||
})
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/prune', methods=['POST'])
|
||||
async def begin_guild_prune(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check KICK_MEMBERS
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
j = validate(await request.get_json(), GUILD_PRUNE)
|
||||
days = j['days']
|
||||
member_ids = await get_prune(guild_id, days)
|
||||
|
||||
app.loop.create_task(remove_member_multi(guild_id, member_ids))
|
||||
|
||||
return jsonify({
|
||||
'pruned': len(member_ids)
|
||||
})
|
||||
|
|
@ -0,0 +1,315 @@
|
|||
from typing import List, Dict, Any, Union
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from litecord.auth import token_check
|
||||
|
||||
from litecord.blueprints.checks import (
|
||||
guild_check, guild_owner_check
|
||||
)
|
||||
from litecord.schemas import (
|
||||
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
||||
)
|
||||
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.utils import dict_get
|
||||
|
||||
DEFAULT_EVERYONE_PERMS = 104324161
|
||||
bp = Blueprint('guild_roles', __name__)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles', methods=['GET'])
|
||||
async def get_guild_roles(guild_id):
|
||||
"""Get all roles in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_role_data(guild_id)
|
||||
)
|
||||
|
||||
|
||||
async def create_role(guild_id, name: str, **kwargs):
|
||||
"""Create a role in a guild."""
|
||||
new_role_id = get_snowflake()
|
||||
|
||||
# TODO: use @everyone's perm number
|
||||
default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS)
|
||||
|
||||
max_pos = await app.db.fetchval("""
|
||||
SELECT MAX(position)
|
||||
FROM roles
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO roles (id, guild_id, name, color,
|
||||
hoist, position, permissions, managed, mentionable)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
""",
|
||||
new_role_id,
|
||||
guild_id,
|
||||
name,
|
||||
dict_get(kwargs, 'color', 0),
|
||||
dict_get(kwargs, 'hoist', False),
|
||||
|
||||
# set position = 0 when there isn't any
|
||||
# other role (when we're creating the
|
||||
# @everyone role)
|
||||
max_pos + 1 if max_pos is not None else 0,
|
||||
int(dict_get(kwargs, 'permissions', default_perms)),
|
||||
False,
|
||||
dict_get(kwargs, 'mentionable', False)
|
||||
)
|
||||
|
||||
role = await app.storage.get_role(new_role_id, guild_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'GUILD_ROLE_CREATE', {
|
||||
'guild_id': str(guild_id),
|
||||
'role': role,
|
||||
})
|
||||
|
||||
return role
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles', methods=['POST'])
|
||||
async def create_guild_role(guild_id: int):
|
||||
"""Add a role to a guild"""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: use check_guild and MANAGE_ROLES permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
# client can just send null
|
||||
j = validate(await request.get_json() or {}, ROLE_CREATE)
|
||||
|
||||
role_name = j['name']
|
||||
j.pop('name')
|
||||
|
||||
role = await create_role(guild_id, role_name, **j)
|
||||
|
||||
return jsonify(role)
|
||||
|
||||
|
||||
async def _role_update_dispatch(role_id: int, guild_id: int):
|
||||
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
||||
role = await app.storage.get_role(role_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
|
||||
'guild_id': str(guild_id),
|
||||
'role': role,
|
||||
})
|
||||
|
||||
return role
|
||||
|
||||
|
||||
async def _role_pairs_update(guild_id: int, pairs: list):
|
||||
"""Update the roles' positions.
|
||||
|
||||
Dispatches GUILD_ROLE_UPDATE for all roles being updated.
|
||||
"""
|
||||
for pair in pairs:
|
||||
pair_1, pair_2 = pair
|
||||
|
||||
role_1, new_pos_1 = pair_1
|
||||
role_2, new_pos_2 = pair_2
|
||||
|
||||
conn = await app.db.acquire()
|
||||
async with conn.transaction():
|
||||
# update happens in a transaction
|
||||
# so we don't fuck it up
|
||||
await conn.execute("""
|
||||
UPDATE roles
|
||||
SET position = $1
|
||||
WHERE roles.id = $2
|
||||
""", new_pos_1, role_1)
|
||||
|
||||
await conn.execute("""
|
||||
UPDATE roles
|
||||
SET position = $1
|
||||
WHERE roles.id = $2
|
||||
""", new_pos_2, role_2)
|
||||
|
||||
await app.db.release(conn)
|
||||
|
||||
# the route fires multiple Guild Role Update.
|
||||
await _role_update_dispatch(role_1, guild_id)
|
||||
await _role_update_dispatch(role_2, guild_id)
|
||||
|
||||
|
||||
def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||
current_state: Dict[int, int],
|
||||
blacklist: List[int] = None) -> List[tuple]:
|
||||
"""Generate a list of pairs that, when applied to the database,
|
||||
will generate the desired state given in list_of_changes.
|
||||
|
||||
We must check if the given list_of_changes isn't overwriting an
|
||||
element's (such as a role or a channel) position to an existing one,
|
||||
without there having an already existing change for the other one.
|
||||
|
||||
Here's a pratical explanation with roles:
|
||||
|
||||
R1 (in position RP1) wants to be in the same position
|
||||
as R2 (currently in position RP2).
|
||||
|
||||
So, if we did the simpler approach, list_of_changes
|
||||
would just contain the preferred change: (R1, RP2).
|
||||
|
||||
With gen_pairs, there MUST be a (R2, RP1) in list_of_changes,
|
||||
if there is, the given result in gen_pairs will be a pair
|
||||
((R1, RP2), (R2, RP1)) which is then used to actually
|
||||
update the roles' positions in a transaction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
list_of_changes:
|
||||
A list of dictionaries with ``id`` and ``position``
|
||||
fields, describing the preferred changes.
|
||||
current_state:
|
||||
Dictionary containing the current state of the list
|
||||
of elements (roles or channels). Points position
|
||||
to element ID.
|
||||
blacklist:
|
||||
List of IDs that shouldn't be moved.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
List of swaps to do to achieve the preferred
|
||||
state given by ``list_of_changes``.
|
||||
"""
|
||||
pairs = []
|
||||
blacklist = blacklist or []
|
||||
|
||||
preferred_state = {element['id']: element['position']
|
||||
for element in list_of_changes}
|
||||
|
||||
for blacklisted_id in blacklist:
|
||||
preferred_state.pop(blacklisted_id)
|
||||
|
||||
# for each change, we must find a matching change
|
||||
# in the same list, so we can make a swap pair
|
||||
for change in list_of_changes:
|
||||
element_1, new_pos_1 = change['id'], change['position']
|
||||
|
||||
# check current pairs
|
||||
# so we don't repeat an element
|
||||
flag = False
|
||||
|
||||
for pair in pairs:
|
||||
if (element_1, new_pos_1) in pair:
|
||||
flag = True
|
||||
|
||||
# skip if found
|
||||
if flag:
|
||||
continue
|
||||
|
||||
# search if there is a role/channel in the
|
||||
# position we want to change to
|
||||
element_2 = current_state.get(new_pos_1)
|
||||
|
||||
# if there is, is that existing channel being
|
||||
# swapped to another position?
|
||||
new_pos_2 = preferred_state.get(element_2)
|
||||
|
||||
# if its being swapped to leave space, add it
|
||||
# to the pairs list
|
||||
if new_pos_2:
|
||||
pairs.append(
|
||||
((element_1, new_pos_1), (element_2, new_pos_2))
|
||||
)
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles', methods=['PATCH'])
|
||||
async def update_guild_role_positions(guild_id):
|
||||
"""Update the positions for a bunch of roles."""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check MANAGE_ROLES
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
raw_j = await request.get_json()
|
||||
|
||||
# we need to do this hackiness because thats
|
||||
# cerberus for ya.
|
||||
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
||||
|
||||
# extract the list out
|
||||
j = j['roles']
|
||||
|
||||
all_roles = await app.storage.get_role_data(guild_id)
|
||||
|
||||
# we'll have to calculate pairs of changing roles,
|
||||
# then do the changes, etc.
|
||||
roles_pos = {role['position']: int(role['id']) for role in all_roles}
|
||||
|
||||
# TODO: check if the user can even change the roles in the first place,
|
||||
# preferrably when we have a proper perms system.
|
||||
|
||||
pairs = gen_pairs(
|
||||
j,
|
||||
roles_pos,
|
||||
|
||||
# always ignore people trying to change
|
||||
# the @everyone's role position
|
||||
[guild_id]
|
||||
)
|
||||
|
||||
await _role_pairs_update(guild_id, pairs)
|
||||
|
||||
# return the list of all roles back
|
||||
return jsonify(await app.storage.get_role_data(guild_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH'])
|
||||
async def update_guild_role(guild_id, role_id):
|
||||
"""Update a single role's information."""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check MANAGE_ROLES
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
j = validate(await request.get_json(), ROLE_UPDATE)
|
||||
|
||||
# we only update ints on the db, not Permissions
|
||||
j['permissions'] = int(j['permissions'])
|
||||
|
||||
for field in j:
|
||||
await app.db.execute(f"""
|
||||
UPDATE roles
|
||||
SET {field} = $1
|
||||
WHERE roles.id = $2 AND roles.guild_id = $3
|
||||
""", j[field], role_id, guild_id)
|
||||
|
||||
role = await _role_update_dispatch(role_id, guild_id)
|
||||
return jsonify(role)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE'])
|
||||
async def delete_guild_role(guild_id, role_id):
|
||||
"""Delete a role.
|
||||
|
||||
Dispatches GUILD_ROLE_DELETE.
|
||||
"""
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check MANAGE_ROLES
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
res = await app.db.execute("""
|
||||
DELETE FROM roles
|
||||
WHERE guild_id = $1 AND id = $2
|
||||
""", guild_id, role_id)
|
||||
|
||||
if res == 'DELETE 0':
|
||||
return '', 204
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
|
||||
'guild_id': str(guild_id),
|
||||
'role_id': str(role_id),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
|
@ -1,31 +1,23 @@
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from litecord.blueprints.guild.channels import create_guild_channel
|
||||
from litecord.blueprints.guild.roles import (
|
||||
create_role, DEFAULT_EVERYONE_PERMS
|
||||
)
|
||||
|
||||
from ..auth import token_check
|
||||
from ..snowflake import get_snowflake
|
||||
from ..enums import ChannelType
|
||||
from ..errors import Forbidden, GuildNotFound, BadRequest
|
||||
from ..schemas import validate, GUILD_UPDATE
|
||||
from ..schemas import (
|
||||
validate, GUILD_CREATE, GUILD_UPDATE
|
||||
)
|
||||
from .channels import channel_ack
|
||||
from .checks import guild_check
|
||||
from .checks import guild_check, guild_owner_check
|
||||
|
||||
|
||||
bp = Blueprint('guilds', __name__)
|
||||
|
||||
|
||||
async def guild_owner_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is the owner of the guild."""
|
||||
owner_id = await app.db.fetchval("""
|
||||
SELECT owner_id
|
||||
FROM guilds
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
if not owner_id:
|
||||
raise GuildNotFound()
|
||||
|
||||
if user_id != owner_id:
|
||||
raise Forbidden('You are not the owner of the guild')
|
||||
|
||||
|
||||
async def create_guild_settings(guild_id: int, user_id: int):
|
||||
"""Create guild settings for the user
|
||||
joining the guild."""
|
||||
|
|
@ -48,10 +40,59 @@ async def create_guild_settings(guild_id: int, user_id: int):
|
|||
""", m_notifs, user_id, guild_id)
|
||||
|
||||
|
||||
async def add_member(guild_id: int, user_id: int):
|
||||
"""Add a user to a guild."""
|
||||
await app.db.execute("""
|
||||
INSERT INTO members (user_id, guild_id)
|
||||
VALUES ($1, $2)
|
||||
""", user_id, guild_id)
|
||||
|
||||
await create_guild_settings(guild_id, user_id)
|
||||
|
||||
|
||||
async def guild_create_roles_prep(guild_id: int, roles: list):
|
||||
"""Create roles in preparation in guild create."""
|
||||
# by reaching this point in the code that means
|
||||
# roles is not nullable, which means
|
||||
# roles has at least one element, so we can access safely.
|
||||
|
||||
# the first member in the roles array
|
||||
# are patches to the @everyone role
|
||||
everyone_patches = roles[0]
|
||||
for field in everyone_patches:
|
||||
await app.db.execute(f"""
|
||||
UPDATE roles
|
||||
SET {field}={everyone_patches[field]}
|
||||
WHERE roles.id = $1
|
||||
""", guild_id)
|
||||
|
||||
default_perms = (everyone_patches.get('permissions')
|
||||
or DEFAULT_EVERYONE_PERMS)
|
||||
|
||||
# from the 2nd and forward,
|
||||
# should be treated as new roles
|
||||
for role in roles[1:]:
|
||||
await create_role(
|
||||
guild_id, role['name'], default_perms=default_perms, **role
|
||||
)
|
||||
|
||||
|
||||
async def guild_create_channels_prep(guild_id: int, channels: list):
|
||||
"""Create channels pre-guild create"""
|
||||
for channel_raw in channels:
|
||||
channel_id = get_snowflake()
|
||||
ctype = ChannelType(channel_raw['type'])
|
||||
|
||||
await create_guild_channel(guild_id, channel_id, ctype)
|
||||
|
||||
|
||||
@bp.route('', methods=['POST'])
|
||||
async def create_guild():
|
||||
"""Create a new guild, assigning
|
||||
the user creating it as the owner and
|
||||
making them join."""
|
||||
user_id = await token_check()
|
||||
j = await request.get_json()
|
||||
j = validate(await request.get_json(), GUILD_CREATE)
|
||||
|
||||
guild_id = get_snowflake()
|
||||
|
||||
|
|
@ -66,36 +107,37 @@ async def create_guild():
|
|||
j.get('default_message_notifications', 0),
|
||||
j.get('explicit_content_filter', 0))
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO members (user_id, guild_id)
|
||||
VALUES ($1, $2)
|
||||
""", user_id, guild_id)
|
||||
await add_member(guild_id, user_id)
|
||||
|
||||
await create_guild_settings(guild_id, user_id)
|
||||
# create the default @everyone role (everyone has it by default,
|
||||
# so we don't insert that in the table)
|
||||
|
||||
# we also don't use create_role because the id of the role
|
||||
# is the same as the id of the guild, and create_role
|
||||
# generates a new snowflake.
|
||||
await app.db.execute("""
|
||||
INSERT INTO roles (id, guild_id, name, position, permissions)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", guild_id, guild_id, '@everyone', 0, 104324161)
|
||||
""", guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
|
||||
|
||||
# add the @everyone role to the guild creator
|
||||
await app.db.execute("""
|
||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", user_id, guild_id, guild_id)
|
||||
|
||||
# create a single #general channel.
|
||||
general_id = get_snowflake()
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", general_id, ChannelType.GUILD_TEXT.value)
|
||||
await create_guild_channel(
|
||||
guild_id, general_id, ChannelType.GUILD_TEXT,
|
||||
name='general')
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", general_id, guild_id, 'general', 0)
|
||||
if j.get('roles'):
|
||||
await guild_create_roles_prep(guild_id, j['roles'])
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO guild_text_channels (id)
|
||||
VALUES ($1)
|
||||
""", general_id)
|
||||
|
||||
# TODO: j['roles'] and j['channels']
|
||||
if j.get('channels'):
|
||||
await guild_create_channels_prep(guild_id, j['channels'])
|
||||
|
||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
|
||||
|
|
@ -106,21 +148,22 @@ async def create_guild():
|
|||
|
||||
@bp.route('/<int:guild_id>', methods=['GET'])
|
||||
async def get_guild(guild_id):
|
||||
"""Get a single guilds' information."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
gj = await app.storage.get_guild(guild_id, user_id)
|
||||
gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250)
|
||||
|
||||
return jsonify({**gj, **gj_extra})
|
||||
return jsonify(
|
||||
await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['UPDATE'])
|
||||
async def update_guild(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
j = validate(await request.get_json(), GUILD_UPDATE)
|
||||
|
||||
# TODO: check MANAGE_GUILD
|
||||
await guild_check(user_id, guild_id)
|
||||
j = validate(await request.get_json(), GUILD_UPDATE)
|
||||
|
||||
if 'owner_id' in j:
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
|
@ -139,8 +182,6 @@ async def update_guild(guild_id):
|
|||
""", j['name'], guild_id)
|
||||
|
||||
if 'region' in j:
|
||||
# TODO: check region value
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE guilds
|
||||
SET region = $1
|
||||
|
|
@ -167,15 +208,14 @@ async def update_guild(guild_id):
|
|||
WHERE guild_id = $2
|
||||
""", j[field], guild_id)
|
||||
|
||||
# return guild object
|
||||
gj = await app.storage.get_guild(guild_id, user_id)
|
||||
gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250)
|
||||
guild = await app.storage.get_guild_full(
|
||||
guild_id, user_id
|
||||
)
|
||||
|
||||
gj_total = {**gj, **gj_extra}
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'GUILD_UPDATE', guild)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_UPDATE', gj_total)
|
||||
|
||||
return jsonify({**gj, **gj_extra})
|
||||
return jsonify(guild)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
||||
|
|
@ -185,7 +225,7 @@ async def delete_guild(guild_id):
|
|||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM guild
|
||||
DELETE FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
|
||||
|
|
@ -202,264 +242,12 @@ async def delete_guild(guild_id):
|
|||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild>/channels', methods=['GET'])
|
||||
async def get_guild_channels(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
channels = await app.storage.get_channel_data(guild_id)
|
||||
return jsonify(channels)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['POST'])
|
||||
async def create_channel(guild_id):
|
||||
user_id = await token_check()
|
||||
j = await request.get_json()
|
||||
|
||||
# TODO: check permissions for MANAGE_CHANNELS
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
new_channel_id = get_snowflake()
|
||||
channel_type = j.get('type', ChannelType.GUILD_TEXT)
|
||||
|
||||
channel_type = ChannelType(channel_type)
|
||||
|
||||
if channel_type not in (ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE):
|
||||
raise BadRequest('Invalid channel type')
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", new_channel_id, channel_type.value)
|
||||
|
||||
max_pos = await app.db.fetchval("""
|
||||
SELECT MAX(position)
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
if channel_type == ChannelType.GUILD_TEXT:
|
||||
await app.db.execute("""
|
||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", new_channel_id, guild_id, j['name'], max_pos + 1)
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO guild_text_channels (id)
|
||||
VALUES ($1)
|
||||
""", new_channel_id)
|
||||
|
||||
elif channel_type == ChannelType.GUILD_VOICE:
|
||||
raise NotImplementedError()
|
||||
|
||||
chan = await app.storage.get_channel(new_channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_CREATE', chan)
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
|
||||
async def modify_channel_pos(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
await request.get_json()
|
||||
|
||||
# TODO: this route
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
|
||||
async def get_guild_member(guild_id, member_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
member = await app.storage.get_single_member(guild_id, member_id)
|
||||
return jsonify(member)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members', methods=['GET'])
|
||||
async def get_members(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
limit, after = int(j.get('limit', 1)), j.get('after', 0)
|
||||
|
||||
if limit < 1 or limit > 1000:
|
||||
raise BadRequest('limit not in 1-1000 range')
|
||||
|
||||
user_ids = await app.db.fetch(f"""
|
||||
SELECT user_id
|
||||
WHERE guild_id = $1, user_id > $2
|
||||
LIMIT {limit}
|
||||
ORDER BY user_id ASC
|
||||
""", guild_id, after)
|
||||
|
||||
user_ids = [r[0] for r in user_ids]
|
||||
members = await app.storage.get_member_multi(guild_id, user_ids)
|
||||
return jsonify(members)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
|
||||
async def modify_guild_member(guild_id, member_id):
|
||||
j = await request.get_json()
|
||||
|
||||
if 'nick' in j:
|
||||
# TODO: check MANAGE_NICKNAMES
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET nickname = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['nick'], member_id, guild_id)
|
||||
|
||||
if 'mute' in j:
|
||||
# TODO: check MUTE_MEMBERS
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET muted = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['mute'], member_id, guild_id)
|
||||
|
||||
if 'deaf' in j:
|
||||
# TODO: check DEAFEN_MEMBERS
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET deafened = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['deaf'], member_id, guild_id)
|
||||
|
||||
if 'channel_id' in j:
|
||||
# TODO: check MOVE_MEMBERS
|
||||
# TODO: change the member's voice channel
|
||||
pass
|
||||
|
||||
member = await app.storage.get_member_data_one(guild_id, member_id)
|
||||
member.pop('joined_at')
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||
'guild_id': str(guild_id)
|
||||
}, **member})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
|
||||
async def update_nickname(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE members
|
||||
SET nickname = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['nick'], user_id, guild_id)
|
||||
|
||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||
member.pop('joined_at')
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||
'guild_id': str(guild_id)
|
||||
}, **member})
|
||||
|
||||
return j['nick']
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
|
||||
async def kick_member(guild_id, member_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check KICK_MEMBERS permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_user(user_id, 'GUILD_DELETE', {
|
||||
'guild_id': guild_id,
|
||||
'unavailable': False,
|
||||
})
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
||||
'guild': guild_id,
|
||||
'user': await app.storage.get_user(member_id),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans', methods=['GET'])
|
||||
async def get_bans(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check BAN_MEMBERS permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
bans = await app.db.fetch("""
|
||||
SELECT user_id, reason
|
||||
FROM bans
|
||||
WHERE bans.guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
res = []
|
||||
|
||||
for ban in bans:
|
||||
res.append({
|
||||
'reason': ban['reason'],
|
||||
'user': await app.storage.get_user(ban['user_id'])
|
||||
})
|
||||
|
||||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
|
||||
async def create_ban(guild_id, member_id):
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: check BAN_MEMBERS permission
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO bans (guild_id, user_id, reason)
|
||||
VALUES ($1, $2, $3)
|
||||
""", guild_id, member_id, j.get('reason', ''))
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, user_id)
|
||||
|
||||
await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', {
|
||||
'guild_id': guild_id,
|
||||
'unavailable': False,
|
||||
})
|
||||
|
||||
await app.dispatcher.unsub('guild', guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
||||
'guild': guild_id,
|
||||
'user': await app.storage.get_user(member_id),
|
||||
})
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {**{
|
||||
'guild': guild_id,
|
||||
}, **(await app.storage.get_user(member_id))})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/messages/search')
|
||||
async def search_messages(guild_id):
|
||||
"""Search messages in a guild.
|
||||
|
||||
This is an undocumented route.
|
||||
"""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
|
|
@ -474,6 +262,7 @@ async def search_messages(guild_id):
|
|||
|
||||
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
||||
async def ack_guild(guild_id):
|
||||
"""ACKnowledge all messages in the guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -185,7 +185,7 @@ async def use_invite(invite_code):
|
|||
})
|
||||
|
||||
# subscribe new member to guild, so they get events n stuff
|
||||
app.dispatcher.sub_guild(guild_id, user_id)
|
||||
await app.dispatcher.sub('guild', guild_id, user_id)
|
||||
|
||||
# tell the new member that theres the guild it just joined.
|
||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
||||
|
|
|
|||
|
|
@ -219,9 +219,11 @@ async def get_me_guilds():
|
|||
partial = await app.db.fetchrow("""
|
||||
SELECT id::text, name, icon, owner_id
|
||||
FROM guilds
|
||||
WHERE guild_id = $1
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
|
||||
partial = dict(partial)
|
||||
|
||||
# TODO: partial['permissions']
|
||||
partial['owner'] = partial['owner_id'] == user_id
|
||||
partial.pop('owner_id')
|
||||
|
|
@ -279,10 +281,11 @@ async def put_note(target_id: int):
|
|||
INSERT INTO notes (user_id, target_id, note)
|
||||
VALUES ($1, $2, $3)
|
||||
|
||||
ON CONFLICT DO UPDATE SET
|
||||
ON CONFLICT ON CONSTRAINT notes_pkey
|
||||
DO UPDATE SET
|
||||
note = $3
|
||||
WHERE
|
||||
user_id = $1 AND target_id = $2
|
||||
WHERE notes.user_id = $1
|
||||
AND notes.target_id = $2
|
||||
""", user_id, target_id, note)
|
||||
|
||||
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
|
||||
|
|
@ -315,7 +318,8 @@ async def patch_current_settings():
|
|||
await app.db.execute(f"""
|
||||
UPDATE user_settings
|
||||
SET {key}=$1
|
||||
""", j[key])
|
||||
WHERE id = $2
|
||||
""", j[key], user_id)
|
||||
|
||||
settings = await app.storage.get_user_settings(user_id)
|
||||
await app.dispatcher.dispatch_user(
|
||||
|
|
@ -444,20 +448,20 @@ async def patch_guild_settings(guild_id: int):
|
|||
continue
|
||||
|
||||
for field in chan_overrides:
|
||||
res = await app.db.execute(f"""
|
||||
UPDATE guild_settings_channel_overrides
|
||||
SET {field} = $1
|
||||
WHERE user_id = $2
|
||||
AND guild_id = $3
|
||||
AND channel_id = $4
|
||||
""", chan_overrides[field], user_id, guild_id, chan_id)
|
||||
|
||||
if res == 'UPDATE 0':
|
||||
await app.db.execute(f"""
|
||||
INSERT INTO guild_settings_channel_overrides
|
||||
(user_id, guild_id, channel_id, {field})
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", user_id, guild_id, chan_id, chan_overrides[field])
|
||||
await app.db.execute(f"""
|
||||
INSERT INTO guild_settings_channel_overrides
|
||||
(user_id, guild_id, channel_id, {field})
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
ON CONFLICT
|
||||
ON CONSTRAINT guild_settings_channel_overrides_pkey
|
||||
DO
|
||||
UPDATE
|
||||
SET {field} = $4
|
||||
WHERE guild_settings_channel_overrides.user_id = $1
|
||||
AND guild_settings_channel_overrides.guild_id = $2
|
||||
AND guild_settings_channel_overrides.channel_id = $3
|
||||
""", user_id, guild_id, chan_id, chan_overrides[field])
|
||||
|
||||
settings = await app.storage.get_guild_settings_one(user_id, guild_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from typing import List, Any
|
|||
from logbook import Logger
|
||||
|
||||
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
||||
UserDispatcher, ChannelDispatcher, FriendDispatcher
|
||||
UserDispatcher, ChannelDispatcher, FriendDispatcher, \
|
||||
LazyGuildDispatcher
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -35,6 +36,7 @@ class EventDispatcher:
|
|||
'channel': ChannelDispatcher(self),
|
||||
'user': UserDispatcher(self),
|
||||
'friend': FriendDispatcher(self),
|
||||
'lazy_guild': LazyGuildDispatcher(self),
|
||||
}
|
||||
|
||||
async def action(self, backend_str: str, action: str, key, identifier):
|
||||
|
|
@ -104,6 +106,15 @@ class EventDispatcher:
|
|||
for key in keys:
|
||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
||||
|
||||
async def dispatch_filter(self, backend_str: str,
|
||||
key: Any, func, *args):
|
||||
"""Dispatch to a backend that only accepts
|
||||
(event, data) arguments with an optional filter
|
||||
function."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch_filter(key, func, *args)
|
||||
|
||||
async def reset(self, backend_str: str, key: Any):
|
||||
"""Reset the bucket in the given backend."""
|
||||
backend = self.backends[backend_str]
|
||||
|
|
|
|||
|
|
@ -29,16 +29,24 @@ class NotFound(LitecordError):
|
|||
status_code = 404
|
||||
|
||||
|
||||
class GuildNotFound(LitecordError):
|
||||
status_code = 404
|
||||
class GuildNotFound(NotFound):
|
||||
error_code = 10004
|
||||
|
||||
|
||||
class ChannelNotFound(LitecordError):
|
||||
status_code = 404
|
||||
class ChannelNotFound(NotFound):
|
||||
error_code = 10003
|
||||
|
||||
|
||||
class MessageNotFound(LitecordError):
|
||||
status_code = 404
|
||||
class MessageNotFound(NotFound):
|
||||
error_code = 10008
|
||||
|
||||
|
||||
class Ratelimited(LitecordError):
|
||||
status_code = 429
|
||||
|
||||
|
||||
class MissingPermissions(Forbidden):
|
||||
error_code = 50013
|
||||
|
||||
|
||||
class WebsocketClose(Exception):
|
||||
|
|
|
|||
|
|
@ -1,18 +1,68 @@
|
|||
import asyncio
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from collections import defaultdict
|
||||
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from logbook import Logger
|
||||
|
||||
from .state import GatewayState
|
||||
from litecord.gateway.state import GatewayState
|
||||
from litecord.gateway.opcodes import OP
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class ManagerClose(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class StateDictWrapper:
|
||||
"""Wrap a mapping so that any kind of access to the mapping while the
|
||||
state manager is closed raises a ManagerClose error"""
|
||||
def __init__(self, state_manager, mapping):
|
||||
self.state_manager = state_manager
|
||||
self._map = mapping
|
||||
|
||||
def _check_closed(self):
|
||||
if self.state_manager.closed:
|
||||
raise ManagerClose()
|
||||
|
||||
def __getitem__(self, key):
|
||||
self._check_closed()
|
||||
return self._map[key]
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._check_closed()
|
||||
del self._map[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if not self.state_manager.accept_new:
|
||||
raise ManagerClose()
|
||||
|
||||
self._check_closed()
|
||||
self._map[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return self._map.__iter__()
|
||||
|
||||
def pop(self, key):
|
||||
return self._map.pop(key)
|
||||
|
||||
def values(self):
|
||||
return self._map.values()
|
||||
|
||||
|
||||
class StateManager:
|
||||
"""Manager for gateway state information."""
|
||||
|
||||
def __init__(self):
|
||||
#: closed flag
|
||||
self.closed = False
|
||||
|
||||
#: accept new states?
|
||||
self.accept_new = True
|
||||
|
||||
# {
|
||||
# user_id: {
|
||||
# session_id: GatewayState,
|
||||
|
|
@ -20,7 +70,10 @@ class StateManager:
|
|||
# },
|
||||
# user_id_2: {}, ...
|
||||
# }
|
||||
self.states = defaultdict(dict)
|
||||
self.states = StateDictWrapper(self, defaultdict(dict))
|
||||
|
||||
#: raw mapping from session ids to GatewayState
|
||||
self.states_raw = StateDictWrapper(self, {})
|
||||
|
||||
def insert(self, state: GatewayState):
|
||||
"""Insert a new state object."""
|
||||
|
|
@ -28,6 +81,7 @@ class StateManager:
|
|||
|
||||
log.debug('inserting state: {!r}', state)
|
||||
user_states[state.session_id] = state
|
||||
self.states_raw[state.session_id] = state
|
||||
|
||||
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||
"""Fetch a state object from the manager.
|
||||
|
|
@ -40,11 +94,20 @@ class StateManager:
|
|||
"""
|
||||
return self.states[user_id][session_id]
|
||||
|
||||
def fetch_raw(self, session_id: str) -> GatewayState:
|
||||
"""Fetch a single state given the Session ID."""
|
||||
return self.states_raw[session_id]
|
||||
|
||||
def remove(self, state):
|
||||
"""Remove a state from the registry"""
|
||||
if not state:
|
||||
return
|
||||
|
||||
try:
|
||||
self.states_raw.pop(state.session_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
log.debug('removing state: {!r}', state)
|
||||
self.states[state.user_id].pop(state.session_id)
|
||||
|
|
@ -100,3 +163,54 @@ class StateManager:
|
|||
states.extend(member_states)
|
||||
|
||||
return states
|
||||
|
||||
async def shutdown_single(self, state: GatewayState):
|
||||
"""Send OP Reconnect to a single connection."""
|
||||
websocket = state.ws
|
||||
|
||||
await websocket.send({
|
||||
'op': OP.RECONNECT
|
||||
})
|
||||
|
||||
# wait 200ms
|
||||
# so that the client has time to process
|
||||
# our payload then close the connection
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
try:
|
||||
# try to close the connection ourselves
|
||||
await websocket.ws.close(
|
||||
code=4000,
|
||||
reason='litecord shutting down'
|
||||
)
|
||||
except ConnectionClosed:
|
||||
log.info('client {} already closed', state)
|
||||
|
||||
def gen_close_tasks(self):
|
||||
"""Generate the tasks that will order the clients
|
||||
to reconnect.
|
||||
|
||||
This is required to be ran before :meth:`StateManager.close`,
|
||||
since this function doesn't wait for the tasks to complete.
|
||||
"""
|
||||
|
||||
self.accept_new = False
|
||||
|
||||
#: store the shutdown tasks
|
||||
tasks = []
|
||||
|
||||
for state in self.states_raw.values():
|
||||
if not state.ws:
|
||||
continue
|
||||
|
||||
tasks.append(
|
||||
self.shutdown_single(state)
|
||||
)
|
||||
|
||||
log.info('made {} shutdown tasks', len(tasks))
|
||||
|
||||
return tasks
|
||||
|
||||
def close(self):
|
||||
"""Close the state manager."""
|
||||
self.closed = True
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple(
|
|||
)
|
||||
|
||||
WebsocketObjects = collections.namedtuple(
|
||||
'WebsocketObjects', 'db state_manager storage loop dispatcher presence'
|
||||
'WebsocketObjects', ('db', 'state_manager', 'storage',
|
||||
'loop', 'dispatcher', 'presence', 'ratelimiter')
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -44,8 +45,38 @@ def encode_etf(payload) -> str:
|
|||
return earl.pack(payload)
|
||||
|
||||
|
||||
def _etf_decode_dict(data):
|
||||
# NOTE: this is a very slow implementation to
|
||||
# decode the dictionary.
|
||||
|
||||
if isinstance(data, bytes):
|
||||
return data.decode()
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
_copy = dict(data)
|
||||
result = {}
|
||||
|
||||
for key in _copy.keys():
|
||||
# assuming key is bytes rn.
|
||||
new_k = key.decode()
|
||||
|
||||
# maybe nested dicts, so...
|
||||
result[new_k] = _etf_decode_dict(data[key])
|
||||
|
||||
return result
|
||||
|
||||
def decode_etf(data: bytes):
|
||||
return earl.unpack(data)
|
||||
res = earl.unpack(data)
|
||||
|
||||
if isinstance(res, bytes):
|
||||
return data.decode()
|
||||
|
||||
if isinstance(res, dict):
|
||||
return _etf_decode_dict(res)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class GatewayWebsocket:
|
||||
|
|
@ -108,6 +139,11 @@ class GatewayWebsocket:
|
|||
else:
|
||||
await self.ws.send(encoded.decode())
|
||||
|
||||
def _check_ratelimit(self, key: str, ratelimit_key: str):
|
||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||
return bucket.update_rate_limit()
|
||||
|
||||
async def _hb_wait(self, interval: int):
|
||||
"""Wait heartbeat"""
|
||||
# if the client heartbeats in time,
|
||||
|
|
@ -312,6 +348,14 @@ class GatewayWebsocket:
|
|||
|
||||
async def update_status(self, status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
if self._check_ratelimit('presence', self.state.session_id):
|
||||
# Presence Updates beyond the ratelimit
|
||||
# are just silently dropped.
|
||||
return
|
||||
|
||||
if status is None:
|
||||
status = {
|
||||
'afk': False,
|
||||
|
|
@ -365,6 +409,15 @@ class GatewayWebsocket:
|
|||
'op': OP.HEARTBEAT_ACK,
|
||||
})
|
||||
|
||||
async def _connect_ratelimit(self, user_id: int):
|
||||
if self._check_ratelimit('connect', user_id):
|
||||
await self.invalidate_session(False)
|
||||
raise WebsocketClose(4009, 'You are being ratelimited.')
|
||||
|
||||
if self._check_ratelimit('session', user_id):
|
||||
await self.invalidate_session(False)
|
||||
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')
|
||||
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
"""Handle the OP 2 Identify packet."""
|
||||
try:
|
||||
|
|
@ -384,6 +437,8 @@ class GatewayWebsocket:
|
|||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, 'Authentication failed')
|
||||
|
||||
await self._connect_ratelimit(user_id)
|
||||
|
||||
bot = await self.ext.db.fetchval("""
|
||||
SELECT bot FROM users
|
||||
WHERE id = $1
|
||||
|
|
@ -641,9 +696,11 @@ class GatewayWebsocket:
|
|||
|
||||
This is the known structure of GUILD_MEMBER_LIST_UPDATE:
|
||||
|
||||
group_id = 'online' | 'offline' | role_id (string)
|
||||
|
||||
sync_item = {
|
||||
'group': {
|
||||
'id': string, // 'online' | 'offline' | any role id
|
||||
'id': group_id,
|
||||
'count': num
|
||||
}
|
||||
} | {
|
||||
|
|
@ -653,7 +710,7 @@ class GatewayWebsocket:
|
|||
list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE'
|
||||
|
||||
list_data = {
|
||||
'id': "everyone" // ??
|
||||
'id': channel_id | 'everyone',
|
||||
'guild_id': guild_id,
|
||||
|
||||
'ops': [
|
||||
|
|
@ -666,10 +723,10 @@ class GatewayWebsocket:
|
|||
// exists if op = 'SYNC'
|
||||
'items': sync_item[],
|
||||
|
||||
// exists if op = 'INSERT' or 'DELETE'
|
||||
// exists if op == 'INSERT' | 'DELETE' | 'UPDATE'
|
||||
'index': num,
|
||||
|
||||
// exists if op = 'INSERT'
|
||||
// exists if op == 'INSERT' | 'UPDATE'
|
||||
'item': sync_item,
|
||||
}
|
||||
],
|
||||
|
|
@ -678,31 +735,11 @@ class GatewayWebsocket:
|
|||
// separately from the online list?
|
||||
'groups': [
|
||||
{
|
||||
'id': string // 'online' | 'offline' | any role id
|
||||
'id': group_id
|
||||
'count': num
|
||||
}, ...
|
||||
]
|
||||
}
|
||||
|
||||
# Implementation defails.
|
||||
|
||||
Lazy guilds are complicated to deal with in the backend level
|
||||
as there are a lot of computation to be done for each request.
|
||||
|
||||
The current implementation is rudimentary and does not account
|
||||
for any roles inside the guild.
|
||||
|
||||
A correct implementation would take account of roles and make
|
||||
the correct groups on list_data:
|
||||
|
||||
For each channel in lazy_request['channels']:
|
||||
- get all roles that have Read Messages on the channel:
|
||||
- Also fetch their member counts, as it'll be important
|
||||
- with the role list, order them like you normally would
|
||||
(by their role priority)
|
||||
- based on the channel's range's min and max and the ordered
|
||||
role list, you can get the roles wanted for your list_data reply.
|
||||
- make new groups ONLY when the role is hoisted.
|
||||
"""
|
||||
data = payload['d']
|
||||
|
||||
|
|
@ -713,65 +750,16 @@ class GatewayWebsocket:
|
|||
if guild_id not in gids:
|
||||
return
|
||||
|
||||
member_ids = await self.storage.get_member_ids(guild_id)
|
||||
log.debug('lazy: loading {} members', len(member_ids))
|
||||
# make shard query
|
||||
lazy_guilds = self.ext.dispatcher.backends['lazy_guild']
|
||||
|
||||
# the current implementation is rudimentary and only
|
||||
# generates two groups: online and offline, using
|
||||
# PresenceManager.guild_presences to fill list_data.
|
||||
for chan_id, ranges in data.get('channels', {}).items():
|
||||
chan_id = int(chan_id)
|
||||
member_list = await lazy_guilds.get_gml(chan_id)
|
||||
|
||||
# this also doesn't take account the channels in lazy_request.
|
||||
|
||||
guild_presences = await self.presence.guild_presences(member_ids,
|
||||
guild_id)
|
||||
|
||||
online = [{'member': p}
|
||||
for p in guild_presences
|
||||
if p['status'] == 'online']
|
||||
offline = [{'member': p}
|
||||
for p in guild_presences
|
||||
if p['status'] == 'offline']
|
||||
|
||||
log.debug('lazy: {} presences, online={}, offline={}',
|
||||
len(guild_presences),
|
||||
len(online),
|
||||
len(offline))
|
||||
|
||||
# construct items in the WORST WAY POSSIBLE.
|
||||
items = [{
|
||||
'group': {
|
||||
'id': 'online',
|
||||
'count': len(online),
|
||||
}
|
||||
}] + online + [{
|
||||
'group': {
|
||||
'id': 'offline',
|
||||
'count': len(offline),
|
||||
}
|
||||
}] + offline
|
||||
|
||||
await self.dispatch('GUILD_MEMBER_LIST_UPDATE', {
|
||||
'id': 'everyone',
|
||||
'guild_id': data['guild_id'],
|
||||
'groups': [
|
||||
{
|
||||
'id': 'online',
|
||||
'count': len(online),
|
||||
},
|
||||
{
|
||||
'id': 'offline',
|
||||
'count': len(offline),
|
||||
}
|
||||
],
|
||||
|
||||
'ops': [
|
||||
{
|
||||
'range': [0, 99],
|
||||
'op': 'SYNC',
|
||||
'items': items
|
||||
}
|
||||
]
|
||||
})
|
||||
await member_list.shard_query(
|
||||
self.state.session_id, ranges
|
||||
)
|
||||
|
||||
async def process_message(self, payload):
|
||||
"""Process a single message coming in from the client."""
|
||||
|
|
@ -788,17 +776,36 @@ class GatewayWebsocket:
|
|||
|
||||
await handler(payload)
|
||||
|
||||
async def _msg_ratelimit(self):
|
||||
if self._check_ratelimit('messages', self.state.session_id):
|
||||
raise WebsocketClose(4008, 'You are being ratelimited.')
|
||||
|
||||
async def listen_messages(self):
|
||||
"""Listen for messages coming in from the websocket."""
|
||||
|
||||
# close anyone trying to login while the
|
||||
# server is shutting down
|
||||
if self.ext.state_manager.closed:
|
||||
raise WebsocketClose(4000, 'state manager closed')
|
||||
|
||||
if not self.ext.state_manager.accept_new:
|
||||
raise WebsocketClose(4000, 'state manager closed for new')
|
||||
|
||||
while True:
|
||||
message = await self.ws.recv()
|
||||
if len(message) > 4096:
|
||||
raise DecodeError('Payload length exceeded')
|
||||
|
||||
if self.state:
|
||||
await self._msg_ratelimit()
|
||||
|
||||
payload = self.decoder(message)
|
||||
await self.process_message(payload)
|
||||
|
||||
def _cleanup(self):
|
||||
for task in self.wsp.tasks.values():
|
||||
task.cancel()
|
||||
|
||||
if self.state:
|
||||
self.ext.state_manager.remove(self.state)
|
||||
self.state.ws = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,242 @@
|
|||
import ctypes
|
||||
|
||||
from quart import current_app as app, request
|
||||
|
||||
# so we don't keep repeating the same
|
||||
# type for all the fields
|
||||
_i = ctypes.c_uint8
|
||||
|
||||
class _RawPermsBits(ctypes.LittleEndianStructure):
|
||||
"""raw bitfield for discord's permission number."""
|
||||
_fields_ = [
|
||||
('create_invites', _i, 1),
|
||||
('kick_members', _i, 1),
|
||||
('ban_members', _i, 1),
|
||||
('administrator', _i, 1),
|
||||
('manage_channels', _i, 1),
|
||||
('manage_guild', _i, 1),
|
||||
('add_reactions', _i, 1),
|
||||
('view_audit_log', _i, 1),
|
||||
('priority_speaker', _i, 1),
|
||||
('_unused1', _i, 1),
|
||||
('read_messages', _i, 1),
|
||||
('send_messages', _i, 1),
|
||||
('send_tts', _i, 1),
|
||||
('manage_messages', _i, 1),
|
||||
('embed_links', _i, 1),
|
||||
('attach_files', _i, 1),
|
||||
('read_history', _i, 1),
|
||||
('mention_everyone', _i, 1),
|
||||
('external_emojis', _i, 1),
|
||||
('_unused2', _i, 1),
|
||||
('connect', _i, 1),
|
||||
('speak', _i, 1),
|
||||
('mute_members', _i, 1),
|
||||
('deafen_members', _i, 1),
|
||||
('move_members', _i, 1),
|
||||
('use_voice_activation', _i, 1),
|
||||
('change_nickname', _i, 1),
|
||||
('manage_nicknames', _i, 1),
|
||||
('manage_roles', _i, 1),
|
||||
('manage_webhooks', _i, 1),
|
||||
('manage_emojis', _i, 1),
|
||||
]
|
||||
|
||||
|
||||
class Permissions(ctypes.Union):
|
||||
_fields_ = [
|
||||
('bits', _RawPermsBits),
|
||||
('binary', ctypes.c_uint64),
|
||||
]
|
||||
|
||||
def __init__(self, val: int):
|
||||
self.binary = val
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Permissions binary={self.binary}>'
|
||||
|
||||
def __int__(self):
|
||||
return self.binary
|
||||
|
||||
def numby(self):
|
||||
return self.binary
|
||||
|
||||
|
||||
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
|
||||
|
||||
|
||||
async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
|
||||
"""Get the raw :class:`Permissions` object for a role."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
perms = await storage.db.fetchval("""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE guild_id = $1 AND id = $2
|
||||
""", guild_id, role_id)
|
||||
|
||||
return Permissions(perms)
|
||||
|
||||
|
||||
async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
||||
"""Compute the base permissions for a given user.
|
||||
|
||||
Base permissions are
|
||||
(permissions from @everyone role) +
|
||||
(permissions from any other role the member has)
|
||||
|
||||
This will give ALL_PERMISSIONS if base permissions
|
||||
has the Administrator bit set.
|
||||
"""
|
||||
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
owner_id = await storage.db.fetchval("""
|
||||
SELECT owner_id
|
||||
FROM guilds
|
||||
WHERE id = $1
|
||||
""", guild_id)
|
||||
|
||||
if owner_id == member_id:
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
# get permissions for @everyone
|
||||
permissions = await get_role_perms(guild_id, guild_id, storage)
|
||||
|
||||
role_ids = await storage.db.fetch("""
|
||||
SELECT role_id
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
role_perms = []
|
||||
|
||||
for row in role_ids:
|
||||
rperm = await storage.db.fetchval("""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE id = $1
|
||||
""", row['role_id'])
|
||||
|
||||
role_perms.append(rperm)
|
||||
|
||||
for perm_num in role_perms:
|
||||
permissions.binary |= perm_num
|
||||
|
||||
if permissions.bits.administrator:
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
return permissions
|
||||
|
||||
|
||||
def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
|
||||
# we make a copy of the binary representation
|
||||
# so we don't modify the old perms in-place
|
||||
# which could be an unwanted side-effect
|
||||
result = perms.binary
|
||||
|
||||
# negate the permissions that are denied
|
||||
result &= ~overwrite['deny']
|
||||
|
||||
# combine the permissions that are allowed
|
||||
result |= overwrite['allow']
|
||||
|
||||
return Permissions(result)
|
||||
|
||||
|
||||
def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
||||
target_id: int) -> Permissions:
|
||||
overwrite = overwrites.get(target_id)
|
||||
|
||||
if overwrite:
|
||||
# only mix if overwrite found
|
||||
return overwrite_mix(perms, overwrite)
|
||||
|
||||
return perms
|
||||
|
||||
|
||||
async def role_permissions(guild_id: int, role_id: int,
|
||||
channel_id: int, storage=None) -> Permissions:
|
||||
"""Get the permissions for a role, in relation to a channel"""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
perms = await get_role_perms(guild_id, role_id, storage)
|
||||
|
||||
overwrite = await storage.db.fetchrow("""
|
||||
SELECT allow, deny
|
||||
FROM channel_overwrites
|
||||
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
|
||||
""", channel_id, 1, role_id)
|
||||
|
||||
if overwrite:
|
||||
perms = overwrite_mix(perms, overwrite)
|
||||
|
||||
return perms
|
||||
|
||||
|
||||
async def compute_overwrites(base_perms, user_id, channel_id: int,
|
||||
guild_id: int = None, storage=None):
|
||||
"""Compute the permissions in the context of a channel."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
if base_perms.bits.administrator:
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
perms = base_perms
|
||||
|
||||
# list of overwrites
|
||||
overwrites = await storage.chan_overwrites(channel_id)
|
||||
|
||||
if not guild_id:
|
||||
guild_id = await storage.guild_from_channel(channel_id)
|
||||
|
||||
# make it a map for better usage
|
||||
overwrites = {o['id']: o for o in overwrites}
|
||||
|
||||
perms = overwrite_find_mix(perms, overwrites, guild_id)
|
||||
|
||||
# apply role specific overwrites
|
||||
allow, deny = 0, 0
|
||||
|
||||
# fetch roles from user and convert to int
|
||||
role_ids = await storage.get_member_role_ids(guild_id, user_id)
|
||||
role_ids = map(int, role_ids)
|
||||
|
||||
# make the allow and deny binaries
|
||||
for role_id in role_ids:
|
||||
overwrite = overwrites.get(role_id)
|
||||
if overwrite:
|
||||
allow |= overwrite['allow']
|
||||
deny |= overwrite['deny']
|
||||
|
||||
# final step for roles: mix
|
||||
perms = overwrite_mix(perms, {
|
||||
'allow': allow,
|
||||
'deny': deny
|
||||
})
|
||||
|
||||
# apply member specific overwrites
|
||||
perms = overwrite_find_mix(perms, overwrites, user_id)
|
||||
|
||||
return perms
|
||||
|
||||
|
||||
async def get_permissions(member_id, channel_id, *, storage=None):
|
||||
"""Get all the permissions for a user in a channel."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
guild_id = await storage.guild_from_channel(channel_id)
|
||||
|
||||
# for non guild channels
|
||||
if not guild_id:
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
base_perms = await base_permissions(member_id, guild_id, storage)
|
||||
|
||||
return await compute_overwrites(base_perms, member_id,
|
||||
channel_id, guild_id, storage)
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import List, Dict, Any
|
||||
from random import choice
|
||||
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
def status_cmp(status: str, other_status: str) -> bool:
|
||||
"""Compare if `status` is better than the `other_status`
|
||||
|
|
@ -100,20 +103,64 @@ class PresenceManager:
|
|||
|
||||
game = state['game']
|
||||
|
||||
await self.dispatcher.dispatch_guild(
|
||||
guild_id, 'PRESENCE_UPDATE', {
|
||||
'user': member['user'],
|
||||
'roles': member['roles'],
|
||||
'guild_id': guild_id,
|
||||
lazy_guild_store = self.dispatcher.backends['lazy_guild']
|
||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||
|
||||
'status': state['status'],
|
||||
# shards that are in lazy guilds with 'everyone'
|
||||
# enabled
|
||||
in_lazy = []
|
||||
|
||||
# rich presence stuff
|
||||
'game': game,
|
||||
'activities': [game] if game else []
|
||||
}
|
||||
for member_list in lists:
|
||||
session_ids = await member_list.pres_update(
|
||||
int(member['user']['id']),
|
||||
{
|
||||
'roles': member['roles'],
|
||||
'status': state['status'],
|
||||
'game': game
|
||||
}
|
||||
)
|
||||
|
||||
log.debug('Lazy Dispatch to {}',
|
||||
len(session_ids))
|
||||
|
||||
if member_list.channel_id == 'everyone':
|
||||
in_lazy.extend(session_ids)
|
||||
|
||||
pres_update_payload = {
|
||||
'user': member['user'],
|
||||
'roles': member['roles'],
|
||||
'guild_id': str(guild_id),
|
||||
|
||||
'status': state['status'],
|
||||
|
||||
# rich presence stuff
|
||||
'game': game,
|
||||
'activities': [game] if game else []
|
||||
}
|
||||
|
||||
def _sane_session(session_id):
|
||||
state = self.state_manager.fetch_raw(session_id)
|
||||
uid = int(member['user']['id'])
|
||||
|
||||
if not state:
|
||||
return False
|
||||
|
||||
# we don't want to send a presence update
|
||||
# to the same user
|
||||
return (state.user_id != uid and
|
||||
session_id not in in_lazy)
|
||||
|
||||
# everyone not in lazy guild mode
|
||||
# gets a PRESENCE_UPDATE
|
||||
await self.dispatcher.dispatch_filter(
|
||||
'guild', guild_id,
|
||||
_sane_session,
|
||||
|
||||
'PRESENCE_UPDATE', pres_update_payload
|
||||
)
|
||||
|
||||
return in_lazy
|
||||
|
||||
async def dispatch_pres(self, user_id: int, state: dict):
|
||||
"""Dispatch a new presence to all guilds the user is in.
|
||||
|
||||
|
|
@ -122,10 +169,12 @@ class PresenceManager:
|
|||
if state['status'] == 'invisible':
|
||||
state['status'] = 'offline'
|
||||
|
||||
# TODO: shard-aware
|
||||
guild_ids = await self.storage.get_user_guilds(user_id)
|
||||
|
||||
for guild_id in guild_ids:
|
||||
await self.dispatch_guild_pres(guild_id, user_id, state)
|
||||
await self.dispatch_guild_pres(
|
||||
guild_id, user_id, state)
|
||||
|
||||
# dispatch to all friends that are subscribed to them
|
||||
user = await self.storage.get_user(user_id)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@ from .member import MemberDispatcher
|
|||
from .user import UserDispatcher
|
||||
from .channel import ChannelDispatcher
|
||||
from .friend import FriendDispatcher
|
||||
from .lazy_guild import LazyGuildDispatcher
|
||||
|
||||
__all__ = ['GuildDispatcher', 'MemberDispatcher',
|
||||
'UserDispatcher', 'ChannelDispatcher',
|
||||
'FriendDispatcher']
|
||||
'FriendDispatcher', 'LazyGuildDispatcher']
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Any
|
||||
from collections import defaultdict
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,14 @@ class Dispatcher:
|
|||
"""Unsubscribe an elemtnt from the channel/key."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def dispatch_filter(self, _key, _func, *_args):
|
||||
"""Selectively dispatch to the list of subscribed users.
|
||||
|
||||
The selection logic is completly arbitraty and up to the
|
||||
Pub/Sub backend.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def dispatch(self, _key, *_args):
|
||||
"""Dispatch an event to the given channel/key."""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -47,6 +46,8 @@ class GuildDispatcher(DispatcherWithState):
|
|||
# when subbing a user to the guild, we should sub them
|
||||
# to every channel they have access to, in the guild.
|
||||
|
||||
# TODO: check for permissions
|
||||
|
||||
await self._chan_action('sub', guild_id, user_id)
|
||||
|
||||
async def unsub(self, guild_id: int, user_id: int):
|
||||
|
|
@ -56,9 +57,10 @@ class GuildDispatcher(DispatcherWithState):
|
|||
# same thing happening from sub() happens on unsub()
|
||||
await self._chan_action('unsub', guild_id, user_id)
|
||||
|
||||
async def dispatch(self, guild_id: int,
|
||||
event: str, data: Any):
|
||||
"""Dispatch an event to all subscribers of the guild."""
|
||||
async def dispatch_filter(self, guild_id: int, func,
|
||||
event: str, data: Any):
|
||||
"""Selectively dispatch to session ids that have
|
||||
func(session_id) true."""
|
||||
user_ids = self.state[guild_id]
|
||||
dispatched = 0
|
||||
|
||||
|
|
@ -75,8 +77,22 @@ class GuildDispatcher(DispatcherWithState):
|
|||
await self.unsub(guild_id, user_id)
|
||||
continue
|
||||
|
||||
# filter the ones that matter
|
||||
states = list(filter(
|
||||
lambda state: func(state.session_id), states
|
||||
))
|
||||
|
||||
dispatched += await self._dispatch_states(
|
||||
states, event, data)
|
||||
|
||||
log.info('Dispatched {} {!r} to {} states',
|
||||
guild_id, event, dispatched)
|
||||
|
||||
async def dispatch(self, guild_id: int,
|
||||
event: str, data: Any):
|
||||
"""Dispatch an event to all subscribers of the guild."""
|
||||
await self.dispatch_filter(
|
||||
guild_id,
|
||||
lambda sess_id: True,
|
||||
event, data,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,736 @@
|
|||
"""
|
||||
Main code for Lazy Guild implementation in litecord.
|
||||
"""
|
||||
import pprint
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.pubsub.dispatcher import Dispatcher
|
||||
from litecord.permissions import (
|
||||
Permissions, overwrite_find_mix, get_permissions, role_permissions
|
||||
)
|
||||
from litecord.utils import index_by_func
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
GroupID = Union[int, str]
|
||||
Presence = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
"""Store information about a specific group."""
|
||||
gid: GroupID
|
||||
name: str
|
||||
position: int
|
||||
permissions: Permissions
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemberList:
|
||||
"""Total information on the guild's member list."""
|
||||
groups: List[GroupInfo] = None
|
||||
group_info: Dict[GroupID, GroupInfo] = None
|
||||
data: Dict[GroupID, Presence] = None
|
||||
overwrites: Dict[int, Dict[str, Any]] = None
|
||||
|
||||
def __bool__(self):
|
||||
"""Return if the current member list is fully initialized."""
|
||||
list_dict = asdict(self)
|
||||
return all(v is not None for v in list_dict.values())
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over all groups in the correct order.
|
||||
|
||||
Yields a tuple containing :class:`GroupInfo` and
|
||||
the List[Presence] for the group.
|
||||
"""
|
||||
if not self.groups:
|
||||
return
|
||||
|
||||
for group in self.groups:
|
||||
yield group, self.data[group.gid]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Operation:
|
||||
"""Represents a member list operation."""
|
||||
list_op: str
|
||||
params: Dict[str, Any]
|
||||
|
||||
@property
|
||||
def to_dict(self) -> dict:
|
||||
res = {
|
||||
'op': self.list_op
|
||||
}
|
||||
|
||||
if self.list_op == 'SYNC':
|
||||
res['items'] = self.params['items']
|
||||
|
||||
if self.list_op in ('SYNC', 'INVALIDATE'):
|
||||
res['range'] = self.params['range']
|
||||
|
||||
if self.list_op in ('INSERT', 'DELETE', 'UPDATE'):
|
||||
res['index'] = self.params['index']
|
||||
|
||||
if self.list_op in ('INSERT', 'UPDATE'):
|
||||
res['item'] = self.params['item']
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def _to_simple_group(presence: dict) -> str:
|
||||
"""Return a simple group (not a role), given a presence."""
|
||||
return 'offline' if presence['status'] == 'offline' else 'online'
|
||||
|
||||
|
||||
def display_name(member_nicks: Dict[str, str], presence: Presence) -> str:
|
||||
"""Return the display name of a presence.
|
||||
|
||||
Used to sort groups.
|
||||
"""
|
||||
uid = presence['user']['id']
|
||||
|
||||
uname = presence['user']['username']
|
||||
nick = member_nicks.get(uid)
|
||||
|
||||
return nick or uname
|
||||
|
||||
|
||||
class GuildMemberList:
|
||||
"""This class stores the current member list information
|
||||
for a guild (by channel).
|
||||
|
||||
As channels can have different sets of roles that can
|
||||
read them and so, different lists, this is more of a
|
||||
"channel member list" than a guild member list.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
main_lg: LazyGuildDispatcher
|
||||
Main instance of :class:`LazyGuildDispatcher`,
|
||||
so that we're able to use things such as :class:`Storage`.
|
||||
guild_id: int
|
||||
The Guild ID this instance is referring to.
|
||||
channel_id: int
|
||||
The Channel ID this instance is referring to.
|
||||
member_list: List
|
||||
The actual member list information.
|
||||
state: set
|
||||
The set of session IDs that are subscribed to the guild.
|
||||
|
||||
User IDs being used as the identifier in GuildMemberList
|
||||
is a wrong assumption. It is true Discord rolled out
|
||||
lazy guilds to all of the userbase, but users that are bots,
|
||||
for example, can still rely on PRESENCE_UPDATEs.
|
||||
"""
|
||||
def __init__(self, guild_id: int,
|
||||
channel_id: int, main_lg):
|
||||
self.guild_id = guild_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
self.main = main_lg
|
||||
self.list = MemberList(None, None, None, None)
|
||||
|
||||
#: store the states that are subscribed to the list
|
||||
# type is{session_id: set[list]}
|
||||
self.state = defaultdict(set)
|
||||
|
||||
@property
|
||||
def storage(self):
|
||||
"""Get the global :class:`Storage` instance."""
|
||||
return self.main.app.storage
|
||||
|
||||
@property
|
||||
def presence(self):
|
||||
"""Get the global :class:`PresenceManager` instance."""
|
||||
return self.main.app.presence
|
||||
|
||||
@property
|
||||
def state_man(self):
|
||||
"""Get the global :class:`StateManager` instance."""
|
||||
return self.main.app.state_manager
|
||||
|
||||
@property
|
||||
def list_id(self):
|
||||
"""get the id of the member list."""
|
||||
return ('everyone'
|
||||
if self.channel_id == self.guild_id
|
||||
else str(self.channel_id))
|
||||
|
||||
def _set_empty_list(self):
|
||||
"""Set the member list as being empty."""
|
||||
self.list = MemberList(None, None, None, None)
|
||||
|
||||
async def _init_check(self):
|
||||
"""Check if the member list is initialized before
|
||||
messing with it."""
|
||||
if not self.list:
|
||||
await self._init_member_list()
|
||||
|
||||
async def _fetch_overwrites(self):
|
||||
overwrites = await self.storage.chan_overwrites(self.channel_id)
|
||||
overwrites = {int(ov['id']): ov for ov in overwrites}
|
||||
self.list.overwrites = overwrites
|
||||
|
||||
def _calc_member_group(self, roles: List[int], status: str):
|
||||
"""Calculate the best fitting group for a member,
|
||||
given their roles and their current status."""
|
||||
try:
|
||||
# the first group in the list
|
||||
# that the member is entitled to is
|
||||
# the selected group for the member.
|
||||
group_id = next(g.gid for g in self.list.groups
|
||||
if g.gid in roles)
|
||||
except StopIteration:
|
||||
# no group was found, so we fallback
|
||||
# to simple group"
|
||||
group_id = _to_simple_group({'status': status})
|
||||
|
||||
return group_id
|
||||
|
||||
async def get_roles(self) -> List[GroupInfo]:
|
||||
"""Get role information, but only:
|
||||
- the ID
|
||||
- the name
|
||||
- the position
|
||||
- the permissions
|
||||
|
||||
of all HOISTED roles AND roles that
|
||||
have permissions to read the channel
|
||||
being referred to this :class:`GuildMemberList`
|
||||
instance.
|
||||
|
||||
The list is sorted by position.
|
||||
"""
|
||||
roledata = await self.storage.db.fetch("""
|
||||
SELECT id, name, hoist, position, permissions
|
||||
FROM roles
|
||||
WHERE guild_id = $1
|
||||
""", self.guild_id)
|
||||
|
||||
hoisted = [
|
||||
GroupInfo(row['id'], row['name'],
|
||||
row['position'], row['permissions'])
|
||||
for row in roledata if row['hoist']
|
||||
]
|
||||
|
||||
# sort role list by position
|
||||
hoisted = sorted(hoisted, key=lambda group: group.position)
|
||||
|
||||
# we need to store them for later on
|
||||
# for members
|
||||
await self._fetch_overwrites()
|
||||
|
||||
def _can_read_chan(group: GroupInfo):
|
||||
# get the base role perms
|
||||
role_perms = group.permissions
|
||||
|
||||
# then the final perms for that role if
|
||||
# any overwrite exists in the channel
|
||||
final_perms = overwrite_find_mix(
|
||||
role_perms, self.list.overwrites, group.gid)
|
||||
|
||||
# update the group's permissions
|
||||
# with the mixed ones
|
||||
group.permissions = final_perms
|
||||
|
||||
# if the role can read messages, then its
|
||||
# part of the group.
|
||||
return final_perms.bits.read_messages
|
||||
|
||||
return list(filter(_can_read_chan, hoisted))
|
||||
|
||||
async def set_groups(self):
|
||||
"""Get the groups for the member list."""
|
||||
role_groups = await self.get_roles()
|
||||
role_ids = [g.gid for g in role_groups]
|
||||
|
||||
self.list.groups = role_ids + ['online', 'offline']
|
||||
|
||||
# inject default groups 'online' and 'offline'
|
||||
self.list.groups = role_ids + [
|
||||
GroupInfo('online', 'online', -1, -1),
|
||||
GroupInfo('offline', 'offline', -1, -1)
|
||||
]
|
||||
self.list.group_info = {g.gid: g for g in role_groups}
|
||||
|
||||
async def get_group(self, member_id: int,
|
||||
roles: List[Union[str, int]],
|
||||
status: str) -> int:
|
||||
"""Return a fitting group ID for the user."""
|
||||
member_roles = list(map(int, roles))
|
||||
|
||||
# get the member's permissions relative to the channel
|
||||
# (accounting for channel overwrites)
|
||||
member_perms = await get_permissions(
|
||||
member_id, self.channel_id, storage=self.storage)
|
||||
|
||||
if not member_perms.bits.read_messages:
|
||||
return None
|
||||
|
||||
# if the member is offline, we
|
||||
# default give them the offline group.
|
||||
group_id = ('offline' if status == 'offline'
|
||||
else self._calc_member_group(member_roles, status))
|
||||
|
||||
return group_id
|
||||
|
||||
async def _pass_1(self, guild_presences: List[Presence]):
|
||||
"""First pass on generating the member list.
|
||||
|
||||
This assigns all presences a single group.
|
||||
"""
|
||||
for presence in guild_presences:
|
||||
member_id = int(presence['user']['id'])
|
||||
|
||||
group_id = await self.get_group(
|
||||
member_id, presence['roles'], presence['status']
|
||||
)
|
||||
|
||||
self.list.data[group_id].append(presence)
|
||||
|
||||
async def get_member_nicks_dict(self) -> dict:
|
||||
"""Get a dictionary with nickname information."""
|
||||
members = await self.storage.get_member_data(self.guild_id)
|
||||
|
||||
# make a dictionary of member ids to nicknames
|
||||
# so we don't need to keep querying the db on
|
||||
# every loop iteration
|
||||
member_nicks = {m['user']['id']: m.get('nick')
|
||||
for m in members}
|
||||
|
||||
return member_nicks
|
||||
|
||||
async def _sort_groups(self):
|
||||
member_nicks = await self.get_member_nicks_dict()
|
||||
|
||||
for group_members in self.list.data.values():
|
||||
|
||||
# this should update the list in-place
|
||||
group_members.sort(
|
||||
key=lambda p: display_name(member_nicks, p))
|
||||
|
||||
async def _init_member_list(self):
|
||||
"""Generate the main member list with groups."""
|
||||
member_ids = await self.storage.get_member_ids(self.guild_id)
|
||||
|
||||
guild_presences = await self.presence.guild_presences(
|
||||
member_ids, self.guild_id)
|
||||
|
||||
await self.set_groups()
|
||||
|
||||
log.debug('{} presences, {} groups',
|
||||
len(guild_presences),
|
||||
len(self.list.groups))
|
||||
|
||||
self.list.data = {group.gid: [] for group in self.list.groups}
|
||||
|
||||
# first pass: set which presences
|
||||
# go to which groups
|
||||
await self._pass_1(guild_presences)
|
||||
|
||||
# second pass: sort each group's members
|
||||
# by the display name
|
||||
await self._sort_groups()
|
||||
|
||||
@property
|
||||
def items(self) -> list:
|
||||
"""Main items list."""
|
||||
|
||||
# TODO: maybe make this stored in the list
|
||||
# so we don't need to keep regenning?
|
||||
|
||||
if not self.list:
|
||||
return []
|
||||
|
||||
res = []
|
||||
|
||||
# NOTE: maybe use map()?
|
||||
for group, presences in self.list:
|
||||
res.append({
|
||||
'group': {
|
||||
'id': group.gid,
|
||||
'count': len(presences),
|
||||
}
|
||||
})
|
||||
|
||||
for presence in presences:
|
||||
res.append({
|
||||
'member': presence
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
async def sub(self, _session_id: str):
|
||||
"""Subscribe a shard to the member list."""
|
||||
await self._init_check()
|
||||
|
||||
def unsub(self, session_id: str):
|
||||
"""Unsubscribe a shard from the member list"""
|
||||
try:
|
||||
self.state.pop(session_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# once we reach 0 subscribers,
|
||||
# we drop the current member list we have (for memory)
|
||||
# but keep the GuildMemberList running (as
|
||||
# uninitialized) for a future subscriber.
|
||||
|
||||
if not self.state:
|
||||
self._set_empty_list()
|
||||
|
||||
def get_state(self, session_id: str):
|
||||
try:
|
||||
state = self.state_man.fetch_raw(session_id)
|
||||
return state
|
||||
except KeyError:
|
||||
self.unsub(session_id)
|
||||
return
|
||||
|
||||
async def _dispatch_sess(self, session_ids: List[str],
|
||||
operations: List[Operation]):
|
||||
|
||||
# construct the payload to dispatch
|
||||
payload = {
|
||||
'id': self.list_id,
|
||||
'guild_id': str(self.guild_id),
|
||||
|
||||
'groups': [
|
||||
{
|
||||
'count': len(presences),
|
||||
'id': group.gid
|
||||
} for group, presences in self.list
|
||||
],
|
||||
|
||||
'ops': [
|
||||
operation.to_dict
|
||||
for operation in operations
|
||||
]
|
||||
}
|
||||
|
||||
states = map(self.get_state, session_ids)
|
||||
states = filter(lambda state: state is not None, states)
|
||||
|
||||
dispatched = []
|
||||
|
||||
for state in states:
|
||||
await state.ws.dispatch(
|
||||
'GUILD_MEMBER_LIST_UPDATE', payload)
|
||||
|
||||
dispatched.append(state.session_id)
|
||||
|
||||
return dispatched
|
||||
|
||||
async def shard_query(self, session_id: str, ranges: list):
|
||||
"""Send a GUILD_MEMBER_LIST_UPDATE event
|
||||
for a shard that is querying about the member list.
|
||||
|
||||
Paramteters
|
||||
-----------
|
||||
session_id: str
|
||||
The Session ID querying information.
|
||||
channel_id: int
|
||||
The Channel ID that we want information on.
|
||||
ranges: List[List[int]]
|
||||
ranges of the list that we want.
|
||||
"""
|
||||
|
||||
# a guild list with a channel id of the guild
|
||||
# represents the 'everyone' global list.
|
||||
list_id = self.list_id
|
||||
|
||||
# if everyone can read the channel,
|
||||
# we direct the request to the 'everyone' gml instance
|
||||
# instead of the current one.
|
||||
everyone_perms = await role_permissions(
|
||||
self.guild_id,
|
||||
self.guild_id,
|
||||
self.channel_id,
|
||||
storage=self.storage
|
||||
)
|
||||
|
||||
if everyone_perms.bits.read_messages and list_id != 'everyone':
|
||||
everyone_gml = await self.main.get_gml(self.guild_id)
|
||||
|
||||
return await everyone_gml.shard_query(
|
||||
session_id, ranges
|
||||
)
|
||||
|
||||
await self._init_check()
|
||||
|
||||
ops = []
|
||||
|
||||
for start, end in ranges:
|
||||
itemcount = end - start
|
||||
|
||||
# ignore incorrect ranges
|
||||
if itemcount < 0:
|
||||
continue
|
||||
|
||||
self.state[session_id].add((start, end))
|
||||
|
||||
ops.append(Operation('SYNC', {
|
||||
'range': [start, end],
|
||||
'items': self.items[start:end]
|
||||
}))
|
||||
|
||||
await self._dispatch_sess([session_id], ops)
|
||||
|
||||
def get_item_index(self, user_id: Union[str, int]):
|
||||
"""Get the item index a user is on."""
|
||||
def _get_id(item):
|
||||
# item can be a group item or a member item
|
||||
return item.get('member', {}).get('user', {}).get('id')
|
||||
|
||||
# get the updated item's index
|
||||
return index_by_func(
|
||||
lambda p: _get_id(p) == str(user_id),
|
||||
self.items
|
||||
)
|
||||
|
||||
def state_is_subbed(self, item_index, session_id: str) -> bool:
|
||||
"""Return if a state's ranges include the given
|
||||
item index."""
|
||||
|
||||
ranges = self.state[session_id]
|
||||
|
||||
for range_start, range_end in ranges:
|
||||
if range_start <= item_index <= range_end:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_subs(self, item_index: int) -> filter:
|
||||
"""Get the list of subscribed states to a given item."""
|
||||
return filter(
|
||||
lambda sess_id: self.state_is_subbed(item_index, sess_id),
|
||||
self.state.keys()
|
||||
)
|
||||
|
||||
async def _pres_update_simple(self, user_id: int):
|
||||
item_index = self.get_item_index(user_id)
|
||||
|
||||
if item_index is None:
|
||||
log.warning('lazy guild got invalid pres update uid={}',
|
||||
user_id)
|
||||
return []
|
||||
|
||||
item = self.items[item_index]
|
||||
session_ids = self.get_subs(item_index)
|
||||
|
||||
# simple update means we just give an UPDATE
|
||||
# operation
|
||||
return await self._dispatch_sess(
|
||||
session_ids,
|
||||
[
|
||||
Operation('UPDATE', {
|
||||
'index': item_index,
|
||||
'item': item,
|
||||
})
|
||||
]
|
||||
)
|
||||
|
||||
async def _pres_update_complex(self, user_id: int,
|
||||
old_group: str, old_index: int,
|
||||
new_group: str):
|
||||
"""Move a member between groups."""
|
||||
log.debug('complex update: uid={} old={} old_idx={} new={}',
|
||||
user_id, old_group, old_index, new_group)
|
||||
old_group_presences = self.list.data[old_group]
|
||||
old_item_index = self.get_item_index(user_id)
|
||||
|
||||
# make a copy of current presence to insert in the new group
|
||||
current_presence = dict(old_group_presences[old_index])
|
||||
|
||||
# step 1: remove the old presence (old_index is relative
|
||||
# to the group, and not the items list)
|
||||
del old_group_presences[old_index]
|
||||
|
||||
# we need to insert current_presence to the new group
|
||||
# but we also need to calculate its index to insert on.
|
||||
presences = self.list.data[new_group]
|
||||
|
||||
best_index = 0
|
||||
member_nicks = await self.get_member_nicks_dict()
|
||||
current_name = display_name(member_nicks, current_presence)
|
||||
|
||||
# go through each one until we find the best placement
|
||||
for presence in presences:
|
||||
name = display_name(member_nicks, presence)
|
||||
|
||||
print(name, current_name, name < current_name)
|
||||
|
||||
# TODO: check if this works
|
||||
if name < current_name:
|
||||
break
|
||||
|
||||
best_index += 1
|
||||
|
||||
# insert the presence at the index
|
||||
presences.insert(best_index + 1, current_presence)
|
||||
|
||||
new_item_index = self.get_item_index(user_id)
|
||||
|
||||
log.debug('assigned new item index {} to uid {}',
|
||||
new_item_index, user_id)
|
||||
|
||||
session_ids_old = self.get_subs(old_item_index)
|
||||
session_ids_new = self.get_subs(new_item_index)
|
||||
|
||||
# dispatch events to both the old states and
|
||||
# new states.
|
||||
return await self._dispatch_sess(
|
||||
# inefficient, but necessary since we
|
||||
# want to merge both session ids.
|
||||
list(session_ids_old) + list(session_ids_new),
|
||||
[
|
||||
Operation('DELETE', {
|
||||
'index': old_item_index,
|
||||
}),
|
||||
|
||||
Operation('INSERT', {
|
||||
'index': new_item_index,
|
||||
'item': {
|
||||
'member': current_presence
|
||||
}
|
||||
})
|
||||
]
|
||||
)
|
||||
|
||||
async def pres_update(self, user_id: int,
|
||||
partial_presence: Dict[str, Any]):
|
||||
"""Update a presence inside the member list.
|
||||
|
||||
There are 4 types of updates that can happen for a user in a group:
|
||||
- from 'offline' to any
|
||||
- from any to 'offline'
|
||||
- from any to any
|
||||
- from G to G (with G being any group)
|
||||
|
||||
any: 'online' | role_id
|
||||
|
||||
All first, second, and third updates are 'complex' updates,
|
||||
which means we'll have to change the group the user is on
|
||||
to account for them.
|
||||
|
||||
The fourth is a 'simple' change, since we're not changing
|
||||
the group a user is on, and so there's less overhead
|
||||
involved.
|
||||
"""
|
||||
await self._init_check()
|
||||
|
||||
old_group, old_index, old_presence = None, None, None
|
||||
|
||||
for group, presences in self.list:
|
||||
p_idx = index_by_func(
|
||||
lambda p: p['user']['id'] == str(user_id),
|
||||
presences)
|
||||
|
||||
log.debug('p_idx for group {!r} = {}',
|
||||
group.gid, p_idx)
|
||||
|
||||
if p_idx is None:
|
||||
log.debug('skipping group {}', group)
|
||||
continue
|
||||
|
||||
# make a copy since we're modifying in-place
|
||||
old_group = group.gid
|
||||
old_index = p_idx
|
||||
old_presence = dict(presences[p_idx])
|
||||
|
||||
# be ready if it is a simple update
|
||||
presences[p_idx].update(partial_presence)
|
||||
break
|
||||
|
||||
if not old_group:
|
||||
log.warning('pres update with unknown old group uid={}',
|
||||
user_id)
|
||||
return []
|
||||
|
||||
roles = partial_presence.get('roles', old_presence['roles'])
|
||||
new_status = partial_presence.get('status', old_presence['status'])
|
||||
|
||||
new_group = await self.get_group(user_id, roles, new_status)
|
||||
|
||||
log.debug('pres update: gid={} cid={} old_g={} new_g={}',
|
||||
self.guild_id, self.channel_id, old_group, new_group)
|
||||
|
||||
# if we're going to the same group,
|
||||
# treat this as a simple update
|
||||
if old_group == new_group:
|
||||
return await self._pres_update_simple(user_id)
|
||||
|
||||
return await self._pres_update_complex(
|
||||
user_id, old_group, old_index, new_group)
|
||||
|
||||
async def dispatch(self, event: str, data: Any):
|
||||
"""Modify the member list and dispatch the respective
|
||||
events to subscribed shards.
|
||||
|
||||
GuildMemberList stores the current guilds' list
|
||||
in its :attr:`GuildMemberList.list` attribute,
|
||||
with that attribute being modified via different
|
||||
calls to :meth:`GuildMemberList.dispatch`
|
||||
"""
|
||||
|
||||
# if no subscribers, drop event
|
||||
if not self.list:
|
||||
return
|
||||
|
||||
|
||||
class LazyGuildDispatcher(Dispatcher):
|
||||
"""Main class holding the member lists for lazy guilds."""
|
||||
# channel ids
|
||||
KEY_TYPE = int
|
||||
|
||||
# the session ids subscribing to channels
|
||||
VAL_TYPE = str
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
|
||||
self.storage = main.app.storage
|
||||
|
||||
# {chan_id: gml, ...}
|
||||
self.state = {}
|
||||
|
||||
#: store which guilds have their
|
||||
# respective GMLs
|
||||
# {guild_id: [chan_id, ...], ...}
|
||||
self.guild_map = defaultdict(list)
|
||||
|
||||
async def get_gml(self, channel_id: int):
|
||||
"""Get a guild list for a channel ID,
|
||||
generating it if it doesn't exist."""
|
||||
try:
|
||||
return self.state[channel_id]
|
||||
except KeyError:
|
||||
guild_id = await self.storage.guild_from_channel(
|
||||
channel_id
|
||||
)
|
||||
|
||||
# if we don't find a guild, we just
|
||||
# set it the same as the channel.
|
||||
if not guild_id:
|
||||
guild_id = channel_id
|
||||
|
||||
gml = GuildMemberList(guild_id, channel_id, self)
|
||||
self.state[channel_id] = gml
|
||||
self.guild_map[guild_id].append(channel_id)
|
||||
return gml
|
||||
|
||||
def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]:
|
||||
"""Get all member lists for a given guild."""
|
||||
return list(map(
|
||||
self.state.get,
|
||||
self.guild_map[guild_id]
|
||||
))
|
||||
|
||||
async def unsub(self, chan_id, session_id):
|
||||
gml = await self.get_gml(chan_id)
|
||||
gml.unsub(session_id)
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
"""
|
||||
main litecord ratelimiting code
|
||||
|
||||
This code was copied from elixire's ratelimiting,
|
||||
which in turn is a work on top of discord.py's ratelimiting.
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
class RatelimitBucket:
|
||||
"""Main ratelimit bucket class."""
|
||||
def __init__(self, tokens, second):
|
||||
self.requests = tokens
|
||||
self.second = second
|
||||
|
||||
self._window = 0.0
|
||||
self._tokens = self.requests
|
||||
self.retries = 0
|
||||
self._last = 0.0
|
||||
|
||||
def get_tokens(self, current):
|
||||
"""Get the current amount of available tokens."""
|
||||
if not current:
|
||||
current = time.time()
|
||||
|
||||
# by default, use _tokens
|
||||
tokens = self._tokens
|
||||
|
||||
# if current timestamp is above _window + seconds
|
||||
# reset tokens to self.requests (default)
|
||||
if current > self._window + self.second:
|
||||
tokens = self.requests
|
||||
|
||||
return tokens
|
||||
|
||||
def update_rate_limit(self):
|
||||
"""Update current ratelimit state."""
|
||||
current = time.time()
|
||||
self._last = current
|
||||
self._tokens = self.get_tokens(current)
|
||||
|
||||
# we are using the ratelimit for the first time
|
||||
# so set current ratelimit window to right now
|
||||
if self._tokens == self.requests:
|
||||
self._window = current
|
||||
|
||||
# Are we currently ratelimited?
|
||||
if self._tokens == 0:
|
||||
self.retries += 1
|
||||
return self.second - (current - self._window)
|
||||
|
||||
# if not ratelimited, remove a token
|
||||
self.retries = 0
|
||||
self._tokens -= 1
|
||||
|
||||
# if we got ratelimited after that token removal,
|
||||
# set window to now
|
||||
if self._tokens == 0:
|
||||
self._window = current
|
||||
|
||||
def reset(self):
|
||||
"""Reset current ratelimit to default state."""
|
||||
self._tokens = self.requests
|
||||
self._last = 0.0
|
||||
self.retries = 0
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of this ratelimit.
|
||||
|
||||
Used to manage multiple ratelimits to users.
|
||||
"""
|
||||
return RatelimitBucket(self.requests,
|
||||
self.second)
|
||||
|
||||
def __repr__(self):
|
||||
return (f'<RatelimitBucket requests={self.requests} '
|
||||
f'second={self.second} window: {self._window} '
|
||||
f'tokens={self._tokens}>')
|
||||
|
||||
|
||||
class Ratelimit:
|
||||
"""Manages buckets."""
|
||||
def __init__(self, tokens, second, keys=None):
|
||||
self._cache = {}
|
||||
if keys is None:
|
||||
keys = tuple()
|
||||
self.keys = keys
|
||||
self._cooldown = RatelimitBucket(tokens, second)
|
||||
|
||||
def __repr__(self):
|
||||
return (f'<Ratelimit cooldown={self._cooldown}>')
|
||||
|
||||
def _verify_cache(self):
|
||||
current = time.time()
|
||||
dead_keys = [k for k, v in self._cache.items()
|
||||
if current > v._last + v.second]
|
||||
|
||||
for k in dead_keys:
|
||||
del self._cache[k]
|
||||
|
||||
def get_bucket(self, key) -> RatelimitBucket:
|
||||
if not self._cooldown:
|
||||
return None
|
||||
|
||||
self._verify_cache()
|
||||
|
||||
if key not in self._cache:
|
||||
bucket = self._cooldown.copy()
|
||||
self._cache[key] = bucket
|
||||
else:
|
||||
bucket = self._cache[key]
|
||||
|
||||
return bucket
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
from quart import current_app as app, request, g
|
||||
|
||||
from litecord.errors import Ratelimited
|
||||
from litecord.auth import token_check, Unauthorized
|
||||
|
||||
|
||||
async def _check_bucket(bucket):
|
||||
retry_after = bucket.update_rate_limit()
|
||||
|
||||
request.bucket = bucket
|
||||
|
||||
if retry_after:
|
||||
request.retry_after = retry_after
|
||||
|
||||
raise Ratelimited('You are being rate limited.', {
|
||||
'retry_after': int(retry_after * 1000),
|
||||
'global': request.bucket_global,
|
||||
})
|
||||
|
||||
|
||||
async def _handle_global(ratelimit):
|
||||
"""Global ratelimit is per-user."""
|
||||
try:
|
||||
user_id = await token_check()
|
||||
except Unauthorized:
|
||||
user_id = request.remote_addr
|
||||
|
||||
request.bucket_global = True
|
||||
bucket = ratelimit.get_bucket(user_id)
|
||||
await _check_bucket(bucket)
|
||||
|
||||
|
||||
async def _handle_specific(ratelimit):
|
||||
try:
|
||||
user_id = await token_check()
|
||||
except Unauthorized:
|
||||
user_id = request.remote_addr
|
||||
|
||||
# construct the key based on the ratelimit.keys
|
||||
keys = ratelimit.keys
|
||||
|
||||
# base key is the user id
|
||||
key_components = [f'user_id:{user_id}']
|
||||
|
||||
for key in keys:
|
||||
val = request.view_args[key]
|
||||
key_components.append(f'{key}:{val}')
|
||||
|
||||
bucket_key = ':'.join(key_components)
|
||||
bucket = ratelimit.get_bucket(bucket_key)
|
||||
await _check_bucket(bucket)
|
||||
|
||||
|
||||
async def ratelimit_handler():
|
||||
"""Main ratelimit handler.
|
||||
|
||||
Decides on which ratelimit to use.
|
||||
"""
|
||||
rule = request.url_rule
|
||||
|
||||
if rule is None:
|
||||
return await _handle_global(
|
||||
app.ratelimiter.global_bucket
|
||||
)
|
||||
|
||||
# rule.endpoint is composed of '<blueprint>.<function>'
|
||||
# and so we can use that to make routes with different
|
||||
# methods have different ratelimits
|
||||
rule_path = rule.endpoint
|
||||
|
||||
# some request ratelimit context.
|
||||
# TODO: maybe put those in a namedtuple or contextvar of sorts?
|
||||
request.bucket = None
|
||||
request.retry_after = None
|
||||
request.bucket_global = False
|
||||
|
||||
try:
|
||||
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
||||
await _handle_specific(ratelimit)
|
||||
except KeyError:
|
||||
await _handle_global(
|
||||
app.ratelimiter.global_bucket
|
||||
)
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
from litecord.ratelimits.bucket import Ratelimit
|
||||
|
||||
"""
|
||||
REST:
|
||||
POST Message | 5/5s | per-channel
|
||||
DELETE Message | 5/1s | per-channel
|
||||
PUT/DELETE Reaction | 1/0.25s | per-channel
|
||||
PATCH Member | 10/10s | per-guild
|
||||
PATCH Member Nick | 1/1s | per-guild
|
||||
PATCH Username | 2/3600s | per-account
|
||||
|All Requests| | 50/1s | per-account
|
||||
WS:
|
||||
Gateway Connect | 1/5s | per-account
|
||||
Presence Update | 5/60s | per-session
|
||||
|All Sent Messages| | 120/60s | per-session
|
||||
"""
|
||||
|
||||
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
|
||||
|
||||
RATELIMITS = {
|
||||
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
|
||||
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
|
||||
|
||||
# all of those share the same bucket.
|
||||
'channel_reactions.add_reaction': REACTION_BUCKET,
|
||||
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
|
||||
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
|
||||
|
||||
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
|
||||
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
|
||||
|
||||
# this only applies to username.
|
||||
# 'users.patch_me': Ratelimit(2, 3600),
|
||||
|
||||
'_ws.connect': Ratelimit(1, 5),
|
||||
'_ws.presence': Ratelimit(5, 60),
|
||||
'_ws.messages': Ratelimit(120, 60),
|
||||
|
||||
# 1000 / 4h for new session issuing
|
||||
'_ws.session': Ratelimit(1000, 14400)
|
||||
}
|
||||
|
||||
class RatelimitManager:
|
||||
"""Manager for the bucket managers"""
|
||||
def __init__(self):
|
||||
self._ratelimiters = {}
|
||||
self.global_bucket = Ratelimit(50, 1)
|
||||
self._fill_rtl()
|
||||
|
||||
def _fill_rtl(self):
|
||||
for path, rtl in RATELIMITS.items():
|
||||
self._ratelimiters[path] = rtl
|
||||
|
||||
def get_ratelimit(self, key: str) -> Ratelimit:
|
||||
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||
return self._ratelimiters.get(key, self.global_bucket)
|
||||
|
|
@ -1,11 +1,16 @@
|
|||
import re
|
||||
from typing import Union, Dict, List, Any
|
||||
|
||||
from cerberus import Validator
|
||||
from logbook import Logger
|
||||
|
||||
from .errors import BadRequest
|
||||
from .enums import ActivityType, StatusType, ExplicitFilter, \
|
||||
RelationshipType, MessageNotifications
|
||||
from .permissions import Permissions
|
||||
from .types import Color
|
||||
from .enums import (
|
||||
ActivityType, StatusType, ExplicitFilter, RelationshipType,
|
||||
MessageNotifications, ChannelType, VerificationLevel
|
||||
)
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -24,13 +29,21 @@ EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
|
|||
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
|
||||
|
||||
|
||||
def _in_enum(enum, value: int):
|
||||
try:
|
||||
enum(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
class LitecordValidator(Validator):
|
||||
def _validate_type_username(self, value: str) -> bool:
|
||||
"""Validate against the username regex."""
|
||||
return bool(USERNAME_REGEX.match(value))
|
||||
|
||||
def _validate_type_email(self, value: str) -> bool:
|
||||
"""Validate against the username regex."""
|
||||
"""Validate against the email regex."""
|
||||
return bool(EMAIL_REGEX.match(value))
|
||||
|
||||
def _validate_type_b64_icon(self, value: str) -> bool:
|
||||
|
|
@ -56,11 +69,17 @@ class LitecordValidator(Validator):
|
|||
|
||||
def _validate_type_voice_region(self, value: str) -> bool:
|
||||
# TODO: complete this list
|
||||
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
||||
return value.lower() in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
||||
|
||||
def _validate_type_verification_level(self, value: int) -> bool:
|
||||
return _in_enum(VerificationLevel, value)
|
||||
|
||||
def _validate_type_activity_type(self, value: int) -> bool:
|
||||
return value in ActivityType.values()
|
||||
|
||||
def _validate_type_channel_type(self, value: int) -> bool:
|
||||
return value in ChannelType.values()
|
||||
|
||||
def _validate_type_status_external(self, value: str) -> bool:
|
||||
statuses = StatusType.values()
|
||||
|
||||
|
|
@ -94,11 +113,31 @@ class LitecordValidator(Validator):
|
|||
|
||||
return val in MessageNotifications.values()
|
||||
|
||||
def _validate_type_guild_name(self, value: str) -> bool:
|
||||
return 2 <= len(value) <= 100
|
||||
|
||||
def validate(reqjson, schema, raise_err: bool = True):
|
||||
def _validate_type_role_name(self, value: str) -> bool:
|
||||
return 1 <= len(value) <= 100
|
||||
|
||||
def _validate_type_channel_name(self, value: str) -> bool:
|
||||
# for now, we'll use the same validation for guild_name
|
||||
return self._validate_type_guild_name(value)
|
||||
|
||||
|
||||
def validate(reqjson: Union[Dict, List], schema: Dict,
|
||||
raise_err: bool = True) -> Union[Dict, List]:
|
||||
"""Validate a given document (user-input) and give
|
||||
the correct document as a result.
|
||||
"""
|
||||
validator = LitecordValidator(schema)
|
||||
|
||||
if not validator.validate(reqjson):
|
||||
try:
|
||||
valid = validator.validate(reqjson)
|
||||
except Exception:
|
||||
log.exception('Error while validating')
|
||||
raise Exception(f'Error while validating: {reqjson}')
|
||||
|
||||
if not valid:
|
||||
errs = validator.errors
|
||||
log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
|
||||
|
||||
|
|
@ -146,16 +185,55 @@ USER_UPDATE = {
|
|||
|
||||
}
|
||||
|
||||
PARTIAL_ROLE_GUILD_CREATE = {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'name': {'type': 'role_name'},
|
||||
'color': {'type': 'number', 'default': 0},
|
||||
'hoist': {'type': 'boolean', 'default': False},
|
||||
|
||||
# NOTE: no position on partial role (on guild create)
|
||||
|
||||
'permissions': {'coerce': Permissions, 'required': False},
|
||||
'mentionable': {'type': 'boolean', 'default': False},
|
||||
}
|
||||
}
|
||||
|
||||
PARTIAL_CHANNEL_GUILD_CREATE = {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'name': {'type': 'channel_name'},
|
||||
'type': {'type': 'channel_type'},
|
||||
}
|
||||
}
|
||||
|
||||
GUILD_CREATE = {
|
||||
'name': {'type': 'guild_name'},
|
||||
'region': {'type': 'voice_region'},
|
||||
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
||||
|
||||
'verification_level': {
|
||||
'type': 'verification_level', 'default': 0},
|
||||
'default_message_notifications': {
|
||||
'type': 'msg_notifications', 'default': 0},
|
||||
'explicit_content_filter': {
|
||||
'type': 'explicit', 'default': 0},
|
||||
|
||||
'roles': {
|
||||
'type': 'list', 'required': False,
|
||||
'schema': PARTIAL_ROLE_GUILD_CREATE},
|
||||
'channels': {
|
||||
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
|
||||
}
|
||||
|
||||
|
||||
GUILD_UPDATE = {
|
||||
'name': {
|
||||
'type': 'string',
|
||||
'minlength': 2,
|
||||
'maxlength': 100,
|
||||
'type': 'guild_name',
|
||||
'required': False
|
||||
},
|
||||
'region': {'type': 'voice_region', 'required': False},
|
||||
'icon': {'type': 'icon', 'required': False},
|
||||
'icon': {'type': 'b64_icon', 'required': False},
|
||||
|
||||
'verification_level': {'type': 'verification_level', 'required': False},
|
||||
'default_message_notifications': {
|
||||
|
|
@ -173,13 +251,93 @@ GUILD_UPDATE = {
|
|||
}
|
||||
|
||||
|
||||
CHAN_OVERWRITE = {
|
||||
'id': {'coerce': int},
|
||||
'type': {'type': 'string', 'allowed': ['role', 'member']},
|
||||
'allow': {'coerce': Permissions},
|
||||
'deny': {'coerce': Permissions}
|
||||
}
|
||||
|
||||
|
||||
CHAN_UPDATE = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 2,
|
||||
'maxlength': 100, 'required': False},
|
||||
|
||||
'position': {'coerce': int, 'required': False},
|
||||
|
||||
'topic': {
|
||||
'type': 'string', 'minlength': 0,
|
||||
'maxlength': 1024, 'required': False},
|
||||
|
||||
'nsfw': {'type': 'boolean', 'required': False},
|
||||
'rate_limit_per_user': {
|
||||
'coerce': int, 'min': 0,
|
||||
'max': 120, 'required': False},
|
||||
|
||||
'bitrate': {
|
||||
'coerce': int, 'min': 8000,
|
||||
|
||||
# NOTE: 'max' is 96000 for non-vip guilds
|
||||
'max': 128000, 'required': False},
|
||||
|
||||
'user_limit': {
|
||||
# user_limit being 0 means infinite.
|
||||
'coerce': int, 'min': 0,
|
||||
'max': 99, 'required': False
|
||||
},
|
||||
|
||||
'permission_overwrites': {
|
||||
'type': 'list',
|
||||
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE},
|
||||
'required': False
|
||||
},
|
||||
|
||||
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
ROLE_CREATE = {
|
||||
'name': {'type': 'string', 'default': 'new role'},
|
||||
'permissions': {'coerce': Permissions, 'nullable': True},
|
||||
'color': {'coerce': Color, 'default': 0},
|
||||
'hoist': {'type': 'boolean', 'default': False},
|
||||
'mentionable': {'type': 'boolean', 'default': False},
|
||||
}
|
||||
|
||||
ROLE_UPDATE = {
|
||||
'name': {'type': 'string', 'required': False},
|
||||
'permissions': {'coerce': Permissions, 'required': False},
|
||||
'color': {'coerce': Color, 'required': False},
|
||||
'hoist': {'type': 'boolean', 'required': False},
|
||||
'mentionable': {'type': 'boolean', 'required': False},
|
||||
}
|
||||
|
||||
|
||||
ROLE_UPDATE_POSITION = {
|
||||
'roles': {
|
||||
'type': 'list',
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'id': {'coerce': int},
|
||||
'position': {'coerce': int},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
MEMBER_UPDATE = {
|
||||
'nick': {
|
||||
'type': 'nickname',
|
||||
'type': 'username',
|
||||
'minlength': 1, 'maxlength': 100,
|
||||
'required': False,
|
||||
},
|
||||
'roles': {'type': 'list', 'required': False},
|
||||
'roles': {'type': 'list', 'required': False,
|
||||
'schema': {'coerce': int}},
|
||||
'mute': {'type': 'boolean', 'required': False},
|
||||
'deaf': {'type': 'boolean', 'required': False},
|
||||
'channel_id': {'type': 'snowflake', 'required': False},
|
||||
|
|
@ -196,57 +354,60 @@ MESSAGE_CREATE = {
|
|||
|
||||
|
||||
GW_ACTIVITY = {
|
||||
'name': {'type': 'string', 'required': True},
|
||||
'type': {'type': 'activity_type', 'required': True},
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'name': {'type': 'string', 'required': True},
|
||||
'type': {'type': 'activity_type', 'required': True},
|
||||
|
||||
'url': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'url': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'timestamps': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'start': {'type': 'number', 'required': True},
|
||||
'end': {'type': 'number', 'required': True},
|
||||
'timestamps': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'start': {'type': 'number', 'required': True},
|
||||
'end': {'type': 'number', 'required': False},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
'application_id': {'type': 'snowflake', 'required': False,
|
||||
'nullable': False},
|
||||
'details': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'state': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'application_id': {'type': 'snowflake', 'required': False,
|
||||
'nullable': False},
|
||||
'details': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'state': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'party': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'id': {'type': 'snowflake', 'required': False},
|
||||
'size': {'type': 'list', 'required': False},
|
||||
}
|
||||
},
|
||||
'party': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'id': {'type': 'snowflake', 'required': False},
|
||||
'size': {'type': 'list', 'required': False},
|
||||
}
|
||||
},
|
||||
|
||||
'assets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'large_image': {'type': 'snowflake', 'required': False},
|
||||
'large_text': {'type': 'string', 'required': False},
|
||||
'small_image': {'type': 'snowflake', 'required': False},
|
||||
'small_text': {'type': 'string', 'required': False},
|
||||
}
|
||||
},
|
||||
'assets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'large_image': {'type': 'snowflake', 'required': False},
|
||||
'large_text': {'type': 'string', 'required': False},
|
||||
'small_image': {'type': 'snowflake', 'required': False},
|
||||
'small_text': {'type': 'string', 'required': False},
|
||||
}
|
||||
},
|
||||
|
||||
'secrets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'join': {'type': 'string', 'required': False},
|
||||
'spectate': {'type': 'string', 'required': False},
|
||||
'match': {'type': 'string', 'required': False},
|
||||
}
|
||||
},
|
||||
'secrets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'join': {'type': 'string', 'required': False},
|
||||
'spectate': {'type': 'string', 'required': False},
|
||||
'match': {'type': 'string', 'required': False},
|
||||
}
|
||||
},
|
||||
|
||||
'instance': {'type': 'boolean', 'required': False},
|
||||
'flags': {'type': 'number', 'required': False},
|
||||
'instance': {'type': 'boolean', 'required': False},
|
||||
'flags': {'type': 'number', 'required': False},
|
||||
}
|
||||
}
|
||||
|
||||
GW_STATUS_UPDATE = {
|
||||
|
|
@ -335,6 +496,8 @@ USER_SETTINGS = {
|
|||
'show_current_game': {'type': 'boolean', 'required': False},
|
||||
|
||||
'timezone_offset': {'type': 'number', 'required': False},
|
||||
|
||||
'status': {'type': 'status_external', 'required': False}
|
||||
}
|
||||
|
||||
RELATIONSHIP = {
|
||||
|
|
@ -395,3 +558,7 @@ GUILD_SETTINGS = {
|
|||
'required': False,
|
||||
}
|
||||
}
|
||||
|
||||
GUILD_PRUNE = {
|
||||
'days': {'type': 'number', 'coerce': int, 'min': 1}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@ from logbook import Logger
|
|||
|
||||
from .enums import ChannelType, RelationshipType
|
||||
from .schemas import USER_MENTION, ROLE_MENTION
|
||||
from litecord.blueprints.channel.reactions import (
|
||||
emoji_info_from_str, EmojiType, emoji_sql, partial_emoji
|
||||
)
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -163,17 +166,40 @@ class Storage:
|
|||
WHERE guild_id = $1 and user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
|
||||
members_roles = await self.db.fetch("""
|
||||
async def get_member_role_ids(self, guild_id: int,
|
||||
member_id: int) -> List[int]:
|
||||
"""Get a list of role IDs that are on a member."""
|
||||
roles = await self.db.fetch("""
|
||||
SELECT role_id::text
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
roles = [r['role_id'] for r in roles]
|
||||
|
||||
try:
|
||||
roles.remove(str(guild_id))
|
||||
except ValueError:
|
||||
# if the @everyone role isn't in, we add it
|
||||
# to member_roles automatically (it won't
|
||||
# be shown on the API, though).
|
||||
await self.db.execute("""
|
||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", member_id, guild_id, guild_id)
|
||||
|
||||
return list(map(str, roles))
|
||||
|
||||
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
|
||||
roles = await self.get_member_role_ids(guild_id, member_id)
|
||||
return {
|
||||
'user': await self.get_user(member_id),
|
||||
'nick': row['nickname'],
|
||||
'roles': [row[0] for row in members_roles],
|
||||
|
||||
# we don't send the @everyone role's id to
|
||||
# the user since it is known that everyone has
|
||||
# that role.
|
||||
'roles': roles,
|
||||
'joined_at': row['joined_at'].isoformat(),
|
||||
'deaf': row['deafened'],
|
||||
'mute': row['muted'],
|
||||
|
|
@ -289,7 +315,7 @@ class Storage:
|
|||
WHERE channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
|
||||
async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
|
||||
overwrite_rows = await self.db.fetch("""
|
||||
SELECT target_type, target_role, target_user, allow, deny
|
||||
FROM channel_overwrites
|
||||
|
|
@ -298,18 +324,20 @@ class Storage:
|
|||
|
||||
def _overwrite_convert(row):
|
||||
drow = dict(row)
|
||||
drow['type'] = drow['target_type']
|
||||
|
||||
target_type = drow['target_type']
|
||||
drow['type'] = 'user' if target_type == 0 else 'role'
|
||||
|
||||
# if type is 0, the overwrite is for a user
|
||||
# if type is 1, the overwrite is for a role
|
||||
drow['id'] = {
|
||||
0: drow['target_user'],
|
||||
1: drow['target_role'],
|
||||
}[drow['type']]
|
||||
}[target_type]
|
||||
|
||||
drow['id'] = str(drow['id'])
|
||||
|
||||
drow.pop('overwrite_type')
|
||||
drow.pop('target_type')
|
||||
drow.pop('target_user')
|
||||
drow.pop('target_role')
|
||||
|
||||
|
|
@ -335,8 +363,8 @@ class Storage:
|
|||
dbase['type'] = chan_type
|
||||
|
||||
res = await self._channels_extra(dbase)
|
||||
res['permission_overwrites'] = \
|
||||
list(await self._chan_overwrites(channel_id))
|
||||
res['permission_overwrites'] = await self.chan_overwrites(
|
||||
channel_id)
|
||||
|
||||
res['id'] = str(res['id'])
|
||||
return res
|
||||
|
|
@ -401,8 +429,8 @@ class Storage:
|
|||
|
||||
res = await self._channels_extra(drow)
|
||||
|
||||
res['permission_overwrites'] = \
|
||||
list(await self._chan_overwrites(row['id']))
|
||||
res['permission_overwrites'] = await self.chan_overwrites(
|
||||
row['id'])
|
||||
|
||||
# Making sure.
|
||||
res['id'] = str(res['id'])
|
||||
|
|
@ -440,6 +468,7 @@ class Storage:
|
|||
permissions, managed, mentionable
|
||||
FROM roles
|
||||
WHERE guild_id = $1
|
||||
ORDER BY position ASC
|
||||
""", guild_id)
|
||||
|
||||
return list(map(dict, roledata))
|
||||
|
|
@ -535,7 +564,70 @@ class Storage:
|
|||
|
||||
return res
|
||||
|
||||
async def get_message(self, message_id: int) -> Dict:
|
||||
async def get_reactions(self, message_id: int, user_id=None) -> List:
|
||||
"""Get all reactions in a message."""
|
||||
reactions = await self.db.fetch("""
|
||||
SELECT user_id, emoji_type, emoji_id, emoji_text
|
||||
FROM message_reactions
|
||||
WHERE message_id = $1
|
||||
ORDER BY react_ts
|
||||
""", message_id)
|
||||
|
||||
# ordered list of emoji
|
||||
emoji = []
|
||||
|
||||
# the current state of emoji info
|
||||
react_stats = {}
|
||||
|
||||
# to generate the list, we pass through all
|
||||
# all reactions and insert them all.
|
||||
|
||||
# we can't use a set() because that
|
||||
# doesn't guarantee any order.
|
||||
for row in reactions:
|
||||
etype = EmojiType(row['emoji_type'])
|
||||
eid, etext = row['emoji_id'], row['emoji_text']
|
||||
|
||||
# get the main key to use, given
|
||||
# the emoji information
|
||||
_, main_emoji = emoji_sql(etype, eid, etext)
|
||||
|
||||
if main_emoji in emoji:
|
||||
continue
|
||||
|
||||
# maintain order (first reacted comes first
|
||||
# on the reaction list)
|
||||
emoji.append(main_emoji)
|
||||
|
||||
react_stats[main_emoji] = {
|
||||
'count': 0,
|
||||
'me': False,
|
||||
'emoji': partial_emoji(etype, eid, etext)
|
||||
}
|
||||
|
||||
# then the 2nd pass, where we insert
|
||||
# the info for each reaction in the react_stats
|
||||
# dictionary
|
||||
for row in reactions:
|
||||
etype = EmojiType(row['emoji_type'])
|
||||
eid, etext = row['emoji_id'], row['emoji_text']
|
||||
|
||||
# same thing as the last loop,
|
||||
# extracting main key
|
||||
_, main_emoji = emoji_sql(etype, eid, etext)
|
||||
|
||||
stats = react_stats[main_emoji]
|
||||
stats['count'] += 1
|
||||
|
||||
if row['user_id'] == user_id:
|
||||
stats['me'] = True
|
||||
|
||||
# after processing reaction counts,
|
||||
# we get them in the same order
|
||||
# they were defined in the first loop.
|
||||
return list(map(react_stats.get, emoji))
|
||||
|
||||
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
||||
"""Get a single message's payload."""
|
||||
row = await self.db.fetchrow("""
|
||||
SELECT id::text, channel_id::text, author_id, content,
|
||||
|
|
@ -596,6 +688,8 @@ class Storage:
|
|||
res['mention_roles'] = await self._msg_regex(
|
||||
ROLE_MENTION, _get_role_mention, content)
|
||||
|
||||
res['reactions'] = await self.get_reactions(message_id, user_id)
|
||||
|
||||
# TODO: handle webhook authors
|
||||
res['author'] = await self.get_user(res['author_id'])
|
||||
res.pop('author_id')
|
||||
|
|
@ -606,9 +700,6 @@ class Storage:
|
|||
# TODO: res['embeds']
|
||||
res['embeds'] = []
|
||||
|
||||
# TODO: res['reactions']
|
||||
res['reactions'] = []
|
||||
|
||||
# TODO: res['pinned']
|
||||
res['pinned'] = False
|
||||
|
||||
|
|
@ -966,7 +1057,6 @@ class Storage:
|
|||
""", user_id)
|
||||
|
||||
for row in settings:
|
||||
print(dict(row))
|
||||
gid = int(row['guild_id'])
|
||||
drow = dict(row)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
|
||||
class Color:
|
||||
"""Custom color class"""
|
||||
def __init__(self, val: int):
|
||||
self.blue = val & 255
|
||||
self.green = (val >> 8) & 255
|
||||
self.red = (val >> 16) & 255
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
"""Give the actual RGB integer encoding this color."""
|
||||
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16)
|
||||
|
||||
def __int__(self):
|
||||
return self.value
|
||||
|
|
@ -22,3 +22,18 @@ async def task_wrapper(name: str, coro):
|
|||
pass
|
||||
except:
|
||||
log.exception('{} task error', name)
|
||||
|
||||
|
||||
def dict_get(mapping, key, default):
|
||||
"""Return `default` even when mapping[key] is None."""
|
||||
return mapping.get(key) or default
|
||||
|
||||
|
||||
def index_by_func(function, indexable: iter) -> int:
|
||||
"""Search in an idexable and return the index number
|
||||
for an iterm that has func(item) = True."""
|
||||
for index, item in enumerate(indexable):
|
||||
if function(item):
|
||||
return index
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
#!/usr/bin/env python3
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from manage.main import main
|
||||
|
||||
import config
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(config))
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .command import setup as migration
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
import inspect
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from collections import namedtuple
|
||||
from typing import Dict
|
||||
|
||||
import asyncpg
|
||||
from logbook import Logger
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
Migration = namedtuple('Migration', 'id name path')
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationContext:
|
||||
"""Hold information about migration."""
|
||||
migration_folder: Path
|
||||
scripts: Dict[int, Migration]
|
||||
|
||||
@property
|
||||
def latest(self):
|
||||
"""Return the latest migration ID."""
|
||||
return max(self.scripts.keys())
|
||||
|
||||
|
||||
def make_migration_ctx() -> MigrationContext:
|
||||
"""Create the MigrationContext instance."""
|
||||
# taken from https://stackoverflow.com/a/6628348
|
||||
script_path = inspect.stack()[0][1]
|
||||
script_folder = '/'.join(script_path.split('/')[:-1])
|
||||
script_folder = Path(script_folder)
|
||||
|
||||
migration_folder = script_folder / 'scripts'
|
||||
|
||||
mctx = MigrationContext(migration_folder, {})
|
||||
|
||||
for mig_path in migration_folder.glob('*.sql'):
|
||||
mig_path_str = str(mig_path)
|
||||
|
||||
# extract migration script id and name
|
||||
mig_filename = mig_path_str.split('/')[-1].split('.')[0]
|
||||
name_fragments = mig_filename.split('_')
|
||||
|
||||
mig_id = int(name_fragments[0])
|
||||
mig_name = '_'.join(name_fragments[1:])
|
||||
|
||||
mctx.scripts[mig_id] = Migration(
|
||||
mig_id, mig_name, mig_path)
|
||||
|
||||
return mctx
|
||||
|
||||
|
||||
async def _ensure_changelog(app, ctx):
|
||||
# make sure we have the migration table up
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
CREATE TABLE migration_log (
|
||||
change_num bigint NOT NULL,
|
||||
|
||||
apply_ts timestamp without time zone default
|
||||
(now() at time zone 'utc'),
|
||||
|
||||
description text,
|
||||
|
||||
PRIMARY KEY (change_num)
|
||||
);
|
||||
""")
|
||||
|
||||
# if we were able to create the
|
||||
# migration_log table, insert that we are
|
||||
# on the latest version.
|
||||
await app.db.execute("""
|
||||
INSERT INTO migration_log (change_num, description)
|
||||
VALUES ($1, $2)
|
||||
""", ctx.latest, 'migration setup')
|
||||
except asyncpg.DuplicateTableError:
|
||||
log.debug('existing migration table')
|
||||
|
||||
|
||||
async def apply_migration(app, migration: Migration):
|
||||
"""Apply a single migration."""
|
||||
migration_sql = migration.path.read_text(encoding='utf-8')
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
INSERT INTO migration_log (change_num, description)
|
||||
VALUES ($1, $2)
|
||||
""", migration.id, f'migration: {migration.name}')
|
||||
except asyncpg.UniqueViolationError:
|
||||
log.warning('already applied {}', migration.id)
|
||||
return
|
||||
|
||||
await app.db.execute(migration_sql)
|
||||
log.info('applied {}', migration.id)
|
||||
|
||||
|
||||
async def migrate_cmd(app, _args):
|
||||
"""Main migration command.
|
||||
|
||||
This makes sure the database
|
||||
is updated.
|
||||
"""
|
||||
|
||||
ctx = make_migration_ctx()
|
||||
|
||||
await _ensure_changelog(app, ctx)
|
||||
|
||||
# local point in the changelog
|
||||
local_change = await app.db.fetchval("""
|
||||
SELECT max(change_num)
|
||||
FROM migration_log
|
||||
""")
|
||||
|
||||
local_change = local_change or 0
|
||||
latest_change = ctx.latest
|
||||
|
||||
log.debug('local: {}, latest: {}', local_change, latest_change)
|
||||
|
||||
if local_change == latest_change:
|
||||
print('no changes to do, exiting')
|
||||
return
|
||||
|
||||
# we do local_change + 1 so we start from the
|
||||
# next migration to do, end in latest_change + 1
|
||||
# because of how range() works.
|
||||
for idx in range(local_change + 1, latest_change + 1):
|
||||
migration = ctx.scripts.get(idx)
|
||||
|
||||
print('applying', migration.id, migration.name)
|
||||
await apply_migration(app, migration)
|
||||
|
||||
|
||||
def setup(subparser):
|
||||
migrate_parser = subparser.add_parser(
|
||||
'migrate',
|
||||
help='Run migration tasks',
|
||||
description=migrate_cmd.__doc__
|
||||
)
|
||||
|
||||
migrate_parser.set_defaults(func=migrate_cmd)
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
-- unused tables
|
||||
DROP TABLE message_embeds;
|
||||
DROP TABLE embeds;
|
||||
|
||||
ALTER TABLE messages
|
||||
ADD COLUMN embeds jsonb DEFAULT '[]'
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
import asyncio
|
||||
import argparse
|
||||
from sys import argv
|
||||
from dataclasses import dataclass
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from run import init_app_managers, init_app_db
|
||||
from manage.cmd.migration import migration
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeApp:
|
||||
"""Fake app instance."""
|
||||
config: dict
|
||||
db = None
|
||||
loop: asyncio.BaseEventLoop = None
|
||||
ratelimiter = None
|
||||
state_manager = None
|
||||
storage = None
|
||||
dispatcher = None
|
||||
presence = None
|
||||
|
||||
|
||||
def init_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
subparser = parser.add_subparsers(help='operations')
|
||||
|
||||
migration(subparser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(config):
|
||||
"""Start the script"""
|
||||
loop = asyncio.get_event_loop()
|
||||
cfg = getattr(config, config.MODE)
|
||||
app = FakeApp(cfg.__dict__)
|
||||
|
||||
loop.run_until_complete(init_app_db(app))
|
||||
init_app_managers(app)
|
||||
|
||||
# initialize argparser
|
||||
parser = init_parser()
|
||||
|
||||
try:
|
||||
if len(argv) < 2:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
args = parser.parse_args()
|
||||
loop.run_until_complete(args.func(app, args))
|
||||
except Exception:
|
||||
log.exception('error while running command')
|
||||
finally:
|
||||
loop.run_until_complete(app.db.close())
|
||||
10
nginx.conf
10
nginx.conf
|
|
@ -5,13 +5,11 @@ server {
|
|||
location / {
|
||||
proxy_pass http://localhost:5000;
|
||||
}
|
||||
}
|
||||
|
||||
# Main litecord websocket proxy.
|
||||
server {
|
||||
server_name websocket.somewhere;
|
||||
|
||||
location / {
|
||||
# if you don't want to keep the gateway
|
||||
# domain as the main domain, you can
|
||||
# keep a separate server block
|
||||
location /ws {
|
||||
proxy_pass http://localhost:5001;
|
||||
|
||||
# those options are required for websockets
|
||||
|
|
|
|||
106
run.py
106
run.py
|
|
@ -9,9 +9,28 @@ from quart import Quart, g, jsonify, request
|
|||
from logbook import StreamHandler, Logger
|
||||
from logbook.compat import redirect_logging
|
||||
|
||||
# import the config set by instance owner
|
||||
import config
|
||||
from litecord.blueprints import gateway, auth, users, guilds, channels, \
|
||||
webhooks, science, voice, invites, relationships, dms
|
||||
|
||||
from litecord.blueprints import (
|
||||
gateway, auth, users, guilds, channels, webhooks, science,
|
||||
voice, invites, relationships, dms
|
||||
)
|
||||
|
||||
# those blueprints are separated from the "main" ones
|
||||
# for code readability if people want to dig through
|
||||
# the codebase.
|
||||
from litecord.blueprints.guild import (
|
||||
guild_roles, guild_members, guild_channels, guild_mod
|
||||
)
|
||||
|
||||
from litecord.blueprints.channel import (
|
||||
channel_messages, channel_reactions, channel_pins
|
||||
)
|
||||
|
||||
from litecord.ratelimits.handler import ratelimit_handler
|
||||
from litecord.ratelimits.main import RatelimitManager
|
||||
|
||||
from litecord.gateway import websocket_handler
|
||||
from litecord.errors import LitecordError
|
||||
from litecord.gateway.state_manager import StateManager
|
||||
|
|
@ -50,8 +69,18 @@ bps = {
|
|||
auth: '/auth',
|
||||
users: '/users',
|
||||
relationships: '/users',
|
||||
|
||||
guilds: '/guilds',
|
||||
guild_roles: '/guilds',
|
||||
guild_members: '/guilds',
|
||||
guild_channels: '/guilds',
|
||||
guild_mod: '/guilds',
|
||||
|
||||
channels: '/channels',
|
||||
channel_messages: '/channels',
|
||||
channel_reactions: '/channels',
|
||||
channel_pins: '/channels',
|
||||
|
||||
webhooks: None,
|
||||
science: None,
|
||||
voice: '/voice',
|
||||
|
|
@ -64,6 +93,11 @@ for bp, suffix in bps.items():
|
|||
app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}')
|
||||
|
||||
|
||||
@app.before_request
|
||||
async def app_before_request():
|
||||
await ratelimit_handler()
|
||||
|
||||
|
||||
@app.after_request
|
||||
async def app_after_request(resp):
|
||||
origin = request.headers.get('Origin', '*')
|
||||
|
|
@ -80,19 +114,44 @@ async def app_after_request(resp):
|
|||
# resp.headers['Access-Control-Allow-Methods'] = '*'
|
||||
resp.headers['Access-Control-Allow-Methods'] = \
|
||||
resp.headers.get('allow', '*')
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@app.before_serving
|
||||
async def app_before_serving():
|
||||
log.info('opening db')
|
||||
@app.after_request
|
||||
async def app_set_ratelimit_headers(resp):
|
||||
"""Set the specific ratelimit headers."""
|
||||
try:
|
||||
bucket = request.bucket
|
||||
|
||||
if bucket is None:
|
||||
raise AttributeError()
|
||||
|
||||
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
|
||||
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
|
||||
resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second)
|
||||
|
||||
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower()
|
||||
|
||||
# only add Retry-After if we actually hit a ratelimit
|
||||
retry_after = request.retry_after
|
||||
if request.retry_after:
|
||||
resp.headers['Retry-After'] = str(retry_after)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
async def init_app_db(app):
|
||||
"""Connect to databases"""
|
||||
app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
||||
|
||||
g.app = app
|
||||
|
||||
def init_app_managers(app):
|
||||
"""Initialize singleton classes."""
|
||||
app.loop = asyncio.get_event_loop()
|
||||
g.loop = asyncio.get_event_loop()
|
||||
|
||||
app.ratelimiter = RatelimitManager()
|
||||
app.state_manager = StateManager()
|
||||
app.storage = Storage(app.db)
|
||||
|
||||
|
|
@ -101,6 +160,17 @@ async def app_before_serving():
|
|||
app.state_manager, app.dispatcher)
|
||||
app.storage.presence = app.presence
|
||||
|
||||
|
||||
@app.before_serving
|
||||
async def app_before_serving():
|
||||
log.info('opening db')
|
||||
await init_app_db(app)
|
||||
|
||||
g.app = app
|
||||
g.loop = asyncio.get_event_loop()
|
||||
|
||||
init_app_managers(app)
|
||||
|
||||
# start the websocket, etc
|
||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||
log.info(f'starting websocket at {host} {port}')
|
||||
|
|
@ -108,8 +178,11 @@ async def app_before_serving():
|
|||
async def _wrapper(ws, url):
|
||||
# We wrap the main websocket_handler
|
||||
# so we can pass quart's app object.
|
||||
|
||||
# TODO: pass just the app object
|
||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
||||
app.loop, app.dispatcher, app.presence),
|
||||
app.loop, app.dispatcher, app.presence,
|
||||
app.ratelimiter),
|
||||
ws, url)
|
||||
|
||||
ws_future = websockets.serve(_wrapper, host, port)
|
||||
|
|
@ -119,6 +192,15 @@ async def app_before_serving():
|
|||
|
||||
@app.after_serving
|
||||
async def app_after_serving():
|
||||
"""Shutdown tasks for the server."""
|
||||
|
||||
# first close all clients, then close db
|
||||
tasks = app.state_manager.gen_close_tasks()
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, loop=app.loop)
|
||||
|
||||
app.state_manager.close()
|
||||
|
||||
log.info('closing db')
|
||||
await app.db.close()
|
||||
|
||||
|
|
@ -130,9 +212,13 @@ async def handle_litecord_err(err):
|
|||
except IndexError:
|
||||
ejson = {}
|
||||
|
||||
try:
|
||||
ejson['code'] = err.error_code
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return jsonify({
|
||||
'error': True,
|
||||
# 'code': err.code,
|
||||
'status': err.status_code,
|
||||
'message': err.message,
|
||||
**ejson
|
||||
|
|
|
|||
44
schema.sql
44
schema.sql
|
|
@ -75,6 +75,9 @@ CREATE TABLE IF NOT EXISTS users (
|
|||
phone varchar(60) DEFAULT '',
|
||||
password_hash text NOT NULL,
|
||||
|
||||
-- store the last time the user logged in via the gateway
|
||||
last_session timestamp without time zone default (now() at time zone 'utc'),
|
||||
|
||||
PRIMARY KEY (id, username, discriminator)
|
||||
);
|
||||
|
||||
|
|
@ -131,6 +134,10 @@ CREATE TABLE IF NOT EXISTS user_settings (
|
|||
|
||||
-- appearance
|
||||
message_display_compact bool DEFAULT false,
|
||||
|
||||
-- for now we store status but don't
|
||||
-- actively use it, since the official client
|
||||
-- sends its own presence on IDENTIFY
|
||||
status text DEFAULT 'online' NOT NULL,
|
||||
theme text DEFAULT 'dark' NOT NULL,
|
||||
developer_mode bool DEFAULT true,
|
||||
|
|
@ -328,7 +335,7 @@ CREATE TABLE IF NOT EXISTS channel_overwrites (
|
|||
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||
|
||||
-- target_type = 0 -> use target_user
|
||||
-- target_type = 1 -> user target_role
|
||||
-- target_type = 1 -> use target_role
|
||||
-- discord already has overwrite.type = 'role' | 'member'
|
||||
-- so this allows us to be more compliant with the API
|
||||
target_type integer default null,
|
||||
|
|
@ -344,11 +351,15 @@ CREATE TABLE IF NOT EXISTS channel_overwrites (
|
|||
-- they're bigints (64bits), discord,
|
||||
-- for now, only needs 53.
|
||||
allow bigint DEFAULT 0,
|
||||
deny bigint DEFAULT 0,
|
||||
|
||||
PRIMARY KEY (channel_id, target_role, target_user)
|
||||
deny bigint DEFAULT 0
|
||||
);
|
||||
|
||||
-- columns in private keys can't have NULL values,
|
||||
-- so instead we use a custom constraint with UNIQUE
|
||||
|
||||
ALTER TABLE channel_overwrites ADD CONSTRAINT channel_overwrites_uniq
|
||||
UNIQUE (channel_id, target_role, target_user);
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS features (
|
||||
id serial PRIMARY KEY,
|
||||
|
|
@ -479,11 +490,6 @@ CREATE TABLE IF NOT EXISTS bans (
|
|||
);
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS embeds (
|
||||
-- TODO: this table
|
||||
id bigint PRIMARY KEY
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id bigint PRIMARY KEY,
|
||||
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||
|
|
@ -504,6 +510,8 @@ CREATE TABLE IF NOT EXISTS messages (
|
|||
tts bool default false,
|
||||
mention_everyone bool default false,
|
||||
|
||||
embeds jsonb DEFAULT '[]',
|
||||
|
||||
nonce bigint default 0,
|
||||
|
||||
message_type int NOT NULL
|
||||
|
|
@ -515,22 +523,22 @@ CREATE TABLE IF NOT EXISTS message_attachments (
|
|||
PRIMARY KEY (message_id, attachment)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS message_embeds (
|
||||
message_id bigint REFERENCES messages (id) UNIQUE,
|
||||
embed_id bigint REFERENCES embeds (id),
|
||||
PRIMARY KEY (message_id, embed_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS message_reactions (
|
||||
message_id bigint REFERENCES messages (id),
|
||||
user_id bigint REFERENCES users (id),
|
||||
|
||||
-- since it can be a custom emote, or unicode emoji
|
||||
react_ts timestamp without time zone default (now() at time zone 'utc'),
|
||||
|
||||
-- emoji_type = 0 -> custom emoji
|
||||
-- emoji_type = 1 -> unicode emoji
|
||||
emoji_type int DEFAULT 0,
|
||||
emoji_id bigint REFERENCES guild_emoji (id),
|
||||
emoji_text text NOT NULL,
|
||||
PRIMARY KEY (message_id, user_id, emoji_id, emoji_text)
|
||||
emoji_text text
|
||||
);
|
||||
|
||||
ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq
|
||||
UNIQUE (message_id, user_id, emoji_id, emoji_text);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS channel_pins (
|
||||
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||
message_id bigint REFERENCES messages (id) ON DELETE CASCADE,
|
||||
|
|
|
|||
Loading…
Reference in New Issue