mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'lazy-guilds' into 'master'
Lazy guilds Closes #12, #9, #8, and #2 See merge request luna/litecord!2
This commit is contained in:
commit
6245f08289
19
README.md
19
README.md
|
|
@ -7,8 +7,16 @@ This project is a rewrite of [litecord-reference].
|
||||||
|
|
||||||
[litecord-reference]: https://gitlab.com/luna/litecord-reference
|
[litecord-reference]: https://gitlab.com/luna/litecord-reference
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- There are no testing being run on the current codebase. Which means the code is definitely unstable.
|
||||||
|
- No voice is planned to be developed, for now.
|
||||||
|
- You must figure out connecting to the server yourself. Litecord will not distribute
|
||||||
|
Discord's official client code nor provide ways to modify the client.
|
||||||
|
|
||||||
## Install
|
## Install
|
||||||
|
|
||||||
|
Requirements:
|
||||||
- Python 3.6 or higher
|
- Python 3.6 or higher
|
||||||
- PostgreSQL
|
- PostgreSQL
|
||||||
- [Pipenv]
|
- [Pipenv]
|
||||||
|
|
@ -28,6 +36,10 @@ $ psql -f schema.sql litecord
|
||||||
# edit config.py as you wish
|
# edit config.py as you wish
|
||||||
$ cp config.example.py config.py
|
$ cp config.example.py config.py
|
||||||
|
|
||||||
|
# run database migrations (this is a
|
||||||
|
# required step in setup)
|
||||||
|
$ pipenv run ./manage.py migrate
|
||||||
|
|
||||||
# Install all packages:
|
# Install all packages:
|
||||||
$ pipenv install --dev
|
$ pipenv install --dev
|
||||||
```
|
```
|
||||||
|
|
@ -42,3 +54,10 @@ Use `--access-log -` to output access logs to stdout.
|
||||||
```sh
|
```sh
|
||||||
$ pipenv run hypercorn run:app
|
$ pipenv run hypercorn run:app
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Updating
|
||||||
|
|
||||||
|
```sh
|
||||||
|
$ git pull
|
||||||
|
$ pipenv run ./manage.py migrate
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,11 @@ log = Logger(__name__)
|
||||||
|
|
||||||
async def raw_token_check(token, db=None):
|
async def raw_token_check(token, db=None):
|
||||||
db = db or app.db
|
db = db or app.db
|
||||||
user_id, _hmac = token.split('.')
|
|
||||||
|
# just try by fragments instead of
|
||||||
|
# unpacking
|
||||||
|
fragments = token.split('.')
|
||||||
|
user_id = fragments[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = base64.b64decode(user_id.encode())
|
user_id = base64.b64decode(user_id.encode())
|
||||||
|
|
@ -35,6 +39,17 @@ async def raw_token_check(token, db=None):
|
||||||
try:
|
try:
|
||||||
signer.unsign(token)
|
signer.unsign(token)
|
||||||
log.debug('login for uid {} successful', user_id)
|
log.debug('login for uid {} successful', user_id)
|
||||||
|
|
||||||
|
# update the user's last_session field
|
||||||
|
# so that we can keep an exact track of activity,
|
||||||
|
# even on long-lived single sessions (that can happen
|
||||||
|
# with people leaving their clients open forever)
|
||||||
|
await db.execute("""
|
||||||
|
UPDATE users
|
||||||
|
SET last_session = (now() at time zone 'utc')
|
||||||
|
WHERE id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
except BadSignature:
|
except BadSignature:
|
||||||
log.warning('token failed for uid {}', user_id)
|
log.warning('token failed for uid {}', user_id)
|
||||||
|
|
@ -43,6 +58,12 @@ async def raw_token_check(token, db=None):
|
||||||
|
|
||||||
async def token_check():
|
async def token_check():
|
||||||
"""Check token information."""
|
"""Check token information."""
|
||||||
|
# first, check if the request info already has a uid
|
||||||
|
try:
|
||||||
|
return request.user_id
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = request.headers['Authorization']
|
token = request.headers['Authorization']
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
|
@ -51,4 +72,6 @@ async def token_check():
|
||||||
if token.startswith('Bot '):
|
if token.startswith('Bot '):
|
||||||
token = token.replace('Bot ', '')
|
token = token.replace('Bot ', '')
|
||||||
|
|
||||||
return await raw_token_check(token)
|
user_id = await raw_token_check(token)
|
||||||
|
request.user_id = user_id
|
||||||
|
return user_id
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ async def register():
|
||||||
|
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
||||||
new_discrim = str(random.randint(1, 9999))
|
new_discrim = random.randint(1, 9999)
|
||||||
new_discrim = '%04d' % new_discrim
|
new_discrim = '%04d' % new_discrim
|
||||||
|
|
||||||
pwd_hash = await hash_data(password)
|
pwd_hash = await hash_data(password)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .messages import bp as channel_messages
|
||||||
|
from .reactions import bp as channel_reactions
|
||||||
|
from .pins import bp as channel_pins
|
||||||
|
|
@ -0,0 +1,276 @@
|
||||||
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
|
from litecord.blueprints.dms import try_dm_state
|
||||||
|
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||||
|
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||||
|
from litecord.snowflake import get_snowflake
|
||||||
|
from litecord.schemas import validate, MESSAGE_CREATE
|
||||||
|
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
bp = Blueprint('channel_messages', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_limit(request, default: int = 50):
|
||||||
|
try:
|
||||||
|
limit = int(request.args.get('limit', default))
|
||||||
|
|
||||||
|
if limit not in range(0, 100):
|
||||||
|
raise ValueError()
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise BadRequest('limit not int')
|
||||||
|
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
||||||
|
before, after = None, None
|
||||||
|
|
||||||
|
if 'around' in request.args:
|
||||||
|
average = int(limit / 2)
|
||||||
|
around = int(request.args['around'])
|
||||||
|
|
||||||
|
after = around - average
|
||||||
|
before = around + average
|
||||||
|
|
||||||
|
elif 'before' in request.args:
|
||||||
|
before = int(request.args['before'])
|
||||||
|
elif 'after' in request.args:
|
||||||
|
before = int(request.args['after'])
|
||||||
|
|
||||||
|
return before, after
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
||||||
|
async def get_messages(channel_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check READ_MESSAGE_HISTORY permission
|
||||||
|
ctype, peer_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
# make sure both parties will be subbed
|
||||||
|
# to a dm
|
||||||
|
await _dm_pre_dispatch(channel_id, user_id)
|
||||||
|
await _dm_pre_dispatch(channel_id, peer_id)
|
||||||
|
|
||||||
|
limit = extract_limit(request, 50)
|
||||||
|
|
||||||
|
where_clause = ''
|
||||||
|
before, after = query_tuple_from_args(request.args, limit)
|
||||||
|
|
||||||
|
if before:
|
||||||
|
where_clause += f'AND id < {before}'
|
||||||
|
|
||||||
|
if after:
|
||||||
|
where_clause += f'AND id > {after}'
|
||||||
|
|
||||||
|
message_ids = await app.db.fetch(f"""
|
||||||
|
SELECT id
|
||||||
|
FROM messages
|
||||||
|
WHERE channel_id = $1 {where_clause}
|
||||||
|
ORDER BY id DESC
|
||||||
|
LIMIT {limit}
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for message_id in message_ids:
|
||||||
|
msg = await app.storage.get_message(message_id['id'], user_id)
|
||||||
|
|
||||||
|
if msg is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.append(msg)
|
||||||
|
|
||||||
|
log.info('Fetched {} messages', len(result))
|
||||||
|
return jsonify(result)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
||||||
|
async def get_single_message(channel_id, message_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
# TODO: check READ_MESSAGE_HISTORY permissions
|
||||||
|
message = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
||||||
|
if not message:
|
||||||
|
raise MessageNotFound()
|
||||||
|
|
||||||
|
return jsonify(message)
|
||||||
|
|
||||||
|
|
||||||
|
async def _dm_pre_dispatch(channel_id, peer_id):
|
||||||
|
"""Do some checks pre-MESSAGE_CREATE so we
|
||||||
|
make sure the receiving party will handle everything."""
|
||||||
|
|
||||||
|
# check the other party's dm_channel_state
|
||||||
|
|
||||||
|
dm_state = await app.db.fetchval("""
|
||||||
|
SELECT dm_id
|
||||||
|
FROM dm_channel_state
|
||||||
|
WHERE user_id = $1 AND dm_id = $2
|
||||||
|
""", peer_id, channel_id)
|
||||||
|
|
||||||
|
if dm_state:
|
||||||
|
# the peer already has the channel
|
||||||
|
# opened, so we don't need to do anything
|
||||||
|
return
|
||||||
|
|
||||||
|
dm_chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
# dispatch CHANNEL_CREATE so the client knows which
|
||||||
|
# channel the future event is about
|
||||||
|
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
||||||
|
|
||||||
|
# subscribe the peer to the channel
|
||||||
|
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||||
|
|
||||||
|
# insert it on dm_channel_state so the client
|
||||||
|
# is subscribed on the future
|
||||||
|
await try_dm_state(peer_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||||
|
async def create_message(channel_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
await channel_perm_check(user_id, channel_id, 'send_messages')
|
||||||
|
|
||||||
|
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||||
|
message_id = get_snowflake()
|
||||||
|
|
||||||
|
# TODO: check connection to the gateway
|
||||||
|
|
||||||
|
mentions_everyone = ('@everyone' in j['content'] and
|
||||||
|
await channel_perm_check(
|
||||||
|
user_id, channel_id, 'mention_everyone', False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
is_tts = (j.get('tts', False) and
|
||||||
|
await channel_perm_check(
|
||||||
|
user_id, channel_id, 'send_tts_messages', False
|
||||||
|
))
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
||||||
|
mention_everyone, nonce, message_type)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
""",
|
||||||
|
message_id,
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
j['content'],
|
||||||
|
|
||||||
|
is_tts,
|
||||||
|
mentions_everyone,
|
||||||
|
|
||||||
|
int(j.get('nonce', 0)),
|
||||||
|
MessageType.DEFAULT.value
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
# guild id here is the peer's ID.
|
||||||
|
await _dm_pre_dispatch(channel_id, user_id)
|
||||||
|
await _dm_pre_dispatch(channel_id, guild_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch('channel', channel_id,
|
||||||
|
'MESSAGE_CREATE', payload)
|
||||||
|
|
||||||
|
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
||||||
|
|
||||||
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
|
for str_uid in payload['mentions']:
|
||||||
|
uid = int(str_uid)
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count += 1
|
||||||
|
WHERE user_id = $1 AND channel_id = $2
|
||||||
|
""", uid, channel_id)
|
||||||
|
|
||||||
|
return jsonify(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
|
||||||
|
async def edit_message(channel_id, message_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
author_id = await app.db.fetchval("""
|
||||||
|
SELECT author_id FROM messages
|
||||||
|
WHERE messages.id = $1
|
||||||
|
""", message_id)
|
||||||
|
|
||||||
|
if not author_id == user_id:
|
||||||
|
raise Forbidden('You can not edit this message')
|
||||||
|
|
||||||
|
j = await request.get_json()
|
||||||
|
updated = 'content' in j or 'embed' in j
|
||||||
|
|
||||||
|
if 'content' in j:
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE messages
|
||||||
|
SET content=$1
|
||||||
|
WHERE messages.id = $2
|
||||||
|
""", j['content'], message_id)
|
||||||
|
|
||||||
|
# TODO: update embed
|
||||||
|
|
||||||
|
message = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
||||||
|
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
|
||||||
|
if updated:
|
||||||
|
await app.dispatcher.dispatch('channel', channel_id,
|
||||||
|
'MESSAGE_UPDATE', message)
|
||||||
|
|
||||||
|
return jsonify(message)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
|
||||||
|
async def delete_message(channel_id, message_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
author_id = await app.db.fetchval("""
|
||||||
|
SELECT author_id FROM messages
|
||||||
|
WHERE messages.id = $1
|
||||||
|
""", message_id)
|
||||||
|
|
||||||
|
by_perm = await channel_perm_check(
|
||||||
|
user_id, channel_id, 'manage_messages', False
|
||||||
|
)
|
||||||
|
|
||||||
|
by_ownership = author_id == user_id
|
||||||
|
|
||||||
|
if not by_perm and not by_ownership:
|
||||||
|
raise Forbidden('You can not delete this message')
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM messages
|
||||||
|
WHERE messages.id = $1
|
||||||
|
""", message_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id,
|
||||||
|
'MESSAGE_DELETE', {
|
||||||
|
'id': str(message_id),
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
|
||||||
|
# for lazy guilds
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
from quart import Blueprint, current_app as app, request, jsonify
|
||||||
|
|
||||||
|
from litecord.auth import token_check
|
||||||
|
from litecord.blueprints.checks import channel_check
|
||||||
|
from litecord.snowflake import snowflake_datetime
|
||||||
|
|
||||||
|
bp = Blueprint('channel_pins', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/pins', methods=['GET'])
|
||||||
|
async def get_pins(channel_id):
|
||||||
|
"""Get the pins for a channel"""
|
||||||
|
user_id = await token_check()
|
||||||
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
ids = await app.db.fetch("""
|
||||||
|
SELECT message_id
|
||||||
|
FROM channel_pins
|
||||||
|
WHERE channel_id = $1
|
||||||
|
ORDER BY message_id ASC
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
ids = [r['message_id'] for r in ids]
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for message_id in ids:
|
||||||
|
message = await app.storage.get_message(message_id)
|
||||||
|
if message is not None:
|
||||||
|
res.append(message)
|
||||||
|
|
||||||
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
|
||||||
|
async def add_pin(channel_id, message_id):
|
||||||
|
"""Add a pin to a channel"""
|
||||||
|
user_id = await token_check()
|
||||||
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
# TODO: check MANAGE_MESSAGES permission
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO channel_pins (channel_id, message_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", channel_id, message_id)
|
||||||
|
|
||||||
|
row = await app.db.fetchrow("""
|
||||||
|
SELECT message_id
|
||||||
|
FROM channel_pins
|
||||||
|
WHERE channel_id = $1
|
||||||
|
ORDER BY message_id ASC
|
||||||
|
LIMIT 1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
timestamp = snowflake_datetime(row['message_id'])
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
'last_pin_timestamp': timestamp.isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
|
||||||
|
async def delete_pin(channel_id, message_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
# TODO: check MANAGE_MESSAGES permission
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM channel_pins
|
||||||
|
WHERE channel_id = $1 AND message_id = $2
|
||||||
|
""", channel_id, message_id)
|
||||||
|
|
||||||
|
row = await app.db.fetchrow("""
|
||||||
|
SELECT message_id
|
||||||
|
FROM channel_pins
|
||||||
|
WHERE channel_id = $1
|
||||||
|
ORDER BY message_id ASC
|
||||||
|
LIMIT 1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
timestamp = snowflake_datetime(row['message_id'])
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
'last_pin_timestamp': timestamp.isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
@ -0,0 +1,225 @@
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
||||||
|
from litecord.utils import async_map
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import channel_check
|
||||||
|
from litecord.blueprints.channel.messages import (
|
||||||
|
query_tuple_from_args, extract_limit
|
||||||
|
)
|
||||||
|
|
||||||
|
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||||
|
from litecord.enums import GUILD_CHANS
|
||||||
|
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
bp = Blueprint('channel_reactions', __name__)
|
||||||
|
|
||||||
|
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions'
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiType(IntEnum):
|
||||||
|
CUSTOM = 0
|
||||||
|
UNICODE = 1
|
||||||
|
|
||||||
|
|
||||||
|
def emoji_info_from_str(emoji: str) -> tuple:
|
||||||
|
"""Extract emoji information from an emoji string
|
||||||
|
given on the reaction endpoints."""
|
||||||
|
# custom emoji have an emoji of name:id
|
||||||
|
# unicode emoji just have the raw unicode.
|
||||||
|
|
||||||
|
# try checking if the emoji is custom or unicode
|
||||||
|
emoji_type = 0 if ':' in emoji else 1
|
||||||
|
emoji_type = EmojiType(emoji_type)
|
||||||
|
|
||||||
|
# extract the emoji id OR the unicode value of the emoji
|
||||||
|
# depending if it is custom or not
|
||||||
|
emoji_id = (int(emoji.split(':')[1])
|
||||||
|
if emoji_type == EmojiType.CUSTOM
|
||||||
|
else emoji)
|
||||||
|
|
||||||
|
emoji_name = emoji.split(':')[0]
|
||||||
|
|
||||||
|
return emoji_type, emoji_id, emoji_name
|
||||||
|
|
||||||
|
|
||||||
|
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
|
||||||
|
print(emoji_type, emoji_id, emoji_name)
|
||||||
|
return {
|
||||||
|
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||||
|
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_payload(user_id, channel_id, message_id, partial):
|
||||||
|
return {
|
||||||
|
'user_id': str(user_id),
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
'message_id': str(message_id),
|
||||||
|
'emoji': partial
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT'])
|
||||||
|
async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
||||||
|
"""Put a reaction."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check READ_MESSAGE_HISTORY permission
|
||||||
|
# and ADD_REACTIONS. look on route docs.
|
||||||
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO message_reactions (message_id, user_id,
|
||||||
|
emoji_type, emoji_id, emoji_text)
|
||||||
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
""", message_id, user_id, emoji_type,
|
||||||
|
|
||||||
|
# if it is custom, we put the emoji_id on emoji_id
|
||||||
|
# column, if it isn't, we put it on emoji_text
|
||||||
|
# column.
|
||||||
|
emoji_id if emoji_type == EmojiType.CUSTOM else None,
|
||||||
|
emoji_id if emoji_type == EmojiType.UNICODE else None
|
||||||
|
)
|
||||||
|
|
||||||
|
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||||
|
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
payload['guild_id'] = str(guild_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
|
||||||
|
"""Extract SQL clauses to search for specific emoji
|
||||||
|
in the message_reactions table."""
|
||||||
|
param = f'${param}'
|
||||||
|
|
||||||
|
# know which column to filter with
|
||||||
|
where_ext = (f'AND emoji_id = {param}'
|
||||||
|
if emoji_type == EmojiType.CUSTOM else
|
||||||
|
f'AND emoji_text = {param}')
|
||||||
|
|
||||||
|
# which emoji to remove (custom or unicode)
|
||||||
|
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
|
||||||
|
|
||||||
|
return where_ext, main_emoji
|
||||||
|
|
||||||
|
|
||||||
|
def _emoji_sql_simple(emoji: str, param=4):
|
||||||
|
"""Simpler version of _emoji_sql for functions that
|
||||||
|
don't need the results from emoji_info_from_str."""
|
||||||
|
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||||
|
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_reaction(channel_id: int, message_id: int,
|
||||||
|
user_id: int, emoji: str):
|
||||||
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||||
|
where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
|
DELETE FROM message_reactions
|
||||||
|
WHERE message_id = $1
|
||||||
|
AND user_id = $2
|
||||||
|
AND emoji_type = $3
|
||||||
|
{where_ext}
|
||||||
|
""", message_id, user_id, emoji_type, main_emoji)
|
||||||
|
|
||||||
|
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||||
|
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
payload['guild_id'] = str(guild_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE'])
|
||||||
|
async def remove_own_reaction(channel_id, message_id, emoji):
|
||||||
|
"""Remove a reaction."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
await remove_reaction(channel_id, message_id, user_id, emoji)
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE'])
|
||||||
|
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
|
||||||
|
"""Remove a reaction made by another user."""
|
||||||
|
await token_check()
|
||||||
|
|
||||||
|
# TODO: check MANAGE_MESSAGES permission (and use user_id
|
||||||
|
# from token_check to do it)
|
||||||
|
await remove_reaction(channel_id, message_id, other_id, emoji)
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET'])
|
||||||
|
async def list_users_reaction(channel_id, message_id, emoji):
|
||||||
|
"""Get the list of all users who reacted with a certain emoji."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# this is not using either ctype or guild_id
|
||||||
|
# that are returned by channel_check
|
||||||
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
limit = extract_limit(request, 25)
|
||||||
|
before, after = query_tuple_from_args(request.args, limit)
|
||||||
|
|
||||||
|
before_clause = 'AND user_id < $2' if before else ''
|
||||||
|
after_clause = 'AND user_id > $3' if after else ''
|
||||||
|
|
||||||
|
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
|
||||||
|
|
||||||
|
rows = await app.db.fetch(f"""
|
||||||
|
SELECT user_id
|
||||||
|
FROM message_reactions
|
||||||
|
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
|
||||||
|
""", message_id, before, after, main_emoji)
|
||||||
|
|
||||||
|
user_ids = [r['user_id'] for r in rows]
|
||||||
|
users = await async_map(app.storage.get_user, user_ids)
|
||||||
|
return jsonify(users)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route(f'{BASEPATH}', methods=['DELETE'])
|
||||||
|
async def remove_all_reactions(channel_id, message_id):
|
||||||
|
"""Remove all reactions in a message."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check MANAGE_MESSAGES permission
|
||||||
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM message_reactions
|
||||||
|
WHERE message_id = $1
|
||||||
|
""", message_id)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
'message_id': str(message_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
payload['guild_id'] = str(guild_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)
|
||||||
|
|
@ -3,14 +3,14 @@ import time
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from ..auth import token_check
|
from litecord.auth import token_check
|
||||||
from ..snowflake import get_snowflake, snowflake_datetime
|
from litecord.enums import ChannelType, GUILD_CHANS
|
||||||
from ..enums import ChannelType, MessageType, GUILD_CHANS
|
from litecord.errors import ChannelNotFound
|
||||||
from ..errors import Forbidden, ChannelNotFound, MessageNotFound
|
from litecord.schemas import (
|
||||||
from ..schemas import validate, MESSAGE_CREATE
|
validate, CHAN_UPDATE, CHAN_OVERWRITE
|
||||||
|
)
|
||||||
|
|
||||||
from .checks import channel_check, guild_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
from .dms import try_dm_state
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channels', __name__)
|
bp = Blueprint('channels', __name__)
|
||||||
|
|
@ -136,6 +136,7 @@ async def guild_cleanup(channel_id):
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['DELETE'])
|
@bp.route('/<int:channel_id>', methods=['DELETE'])
|
||||||
async def close_channel(channel_id):
|
async def close_channel(channel_id):
|
||||||
|
"""Close or delete a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
chan_type = await app.storage.get_chan_type(channel_id)
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
|
@ -212,287 +213,199 @@ async def close_channel(channel_id):
|
||||||
# TODO: group dm
|
# TODO: group dm
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return '', 404
|
raise ChannelNotFound()
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
async def _update_pos(channel_id, pos: int):
|
||||||
async def get_messages(channel_id):
|
await app.db.execute("""
|
||||||
user_id = await token_check()
|
UPDATE guild_channels
|
||||||
await channel_check(user_id, channel_id)
|
SET position = $1
|
||||||
|
WHERE id = $2
|
||||||
# TODO: before, after, around keys
|
""", pos, channel_id)
|
||||||
|
|
||||||
message_ids = await app.db.fetch(f"""
|
|
||||||
SELECT id
|
|
||||||
FROM messages
|
|
||||||
WHERE channel_id = $1
|
|
||||||
ORDER BY id DESC
|
|
||||||
LIMIT 100
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
result = []
|
|
||||||
|
|
||||||
for message_id in message_ids:
|
|
||||||
msg = await app.storage.get_message(message_id['id'])
|
|
||||||
|
|
||||||
if msg is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
result.append(msg)
|
|
||||||
|
|
||||||
log.info('Fetched {} messages', len(result))
|
|
||||||
return jsonify(result)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
async def _mass_chan_update(guild_id, channel_ids: int):
|
||||||
async def get_single_message(channel_id, message_id):
|
for channel_id in channel_ids:
|
||||||
user_id = await token_check()
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await channel_check(user_id, channel_id)
|
await app.dispatcher.dispatch(
|
||||||
|
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||||
# TODO: check READ_MESSAGE_HISTORY permissions
|
|
||||||
message = await app.storage.get_message(message_id)
|
|
||||||
|
|
||||||
if not message:
|
|
||||||
raise MessageNotFound()
|
|
||||||
|
|
||||||
return jsonify(message)
|
|
||||||
|
|
||||||
|
|
||||||
async def _dm_pre_dispatch(channel_id, peer_id):
|
async def _process_overwrites(channel_id: int, overwrites: list):
|
||||||
"""Do some checks pre-MESSAGE_CREATE so we
|
for overwrite in overwrites:
|
||||||
make sure the receiving party will handle everything."""
|
|
||||||
|
|
||||||
# check the other party's dm_channel_state
|
# 0 for user overwrite, 1 for role overwrite
|
||||||
|
target_type = 0 if overwrite['type'] == 'user' else 1
|
||||||
|
target_role = None if target_type == 0 else overwrite['id']
|
||||||
|
target_user = overwrite['id'] if target_type == 0 else None
|
||||||
|
|
||||||
dm_state = await app.db.fetchval("""
|
await app.db.execute(
|
||||||
SELECT dm_id
|
"""
|
||||||
FROM dm_channel_state
|
INSERT INTO channel_overwrites
|
||||||
WHERE user_id = $1 AND dm_id = $2
|
(channel_id, target_type, target_role,
|
||||||
""", peer_id, channel_id)
|
target_user, allow, deny)
|
||||||
|
VALUES
|
||||||
if dm_state:
|
($1, $2, $3, $4, $5, $6)
|
||||||
# the peer already has the channel
|
ON CONFLICT ON CONSTRAINT channel_overwrites_uniq
|
||||||
# opened, so we don't need to do anything
|
DO
|
||||||
return
|
UPDATE
|
||||||
|
SET allow = $5, deny = $6
|
||||||
dm_chan = await app.storage.get_channel(channel_id)
|
WHERE channel_overwrites.channel_id = $1
|
||||||
|
AND channel_overwrites.target_type = $2
|
||||||
# dispatch CHANNEL_CREATE so the client knows which
|
AND channel_overwrites.target_role = $3
|
||||||
# channel the future event is about
|
AND channel_overwrites.target_user = $4
|
||||||
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
""",
|
||||||
|
channel_id, target_type,
|
||||||
# subscribe the peer to the channel
|
target_role, target_user,
|
||||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
overwrite['allow'], overwrite['deny'])
|
||||||
|
|
||||||
# insert it on dm_channel_state so the client
|
|
||||||
# is subscribed on the future
|
|
||||||
await try_dm_state(peer_id, channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
@bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT'])
|
||||||
async def create_message(channel_id):
|
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||||
|
"""Insert or modify a channel overwrite."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
if ctype not in GUILD_CHANS:
|
||||||
message_id = get_snowflake()
|
raise ChannelNotFound('Only usable for guild channels.')
|
||||||
|
|
||||||
# TODO: check SEND_MESSAGES permission
|
await channel_perm_check(user_id, guild_id, 'manage_roles')
|
||||||
# TODO: check connection to the gateway
|
|
||||||
|
|
||||||
await app.db.execute(
|
j = validate(
|
||||||
"""
|
# inserting a fake id on the payload so validation passes through
|
||||||
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
{**await request.get_json(), **{'id': -1}},
|
||||||
mention_everyone, nonce, message_type)
|
CHAN_OVERWRITE
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
||||||
""",
|
|
||||||
message_id,
|
|
||||||
channel_id,
|
|
||||||
user_id,
|
|
||||||
j['content'],
|
|
||||||
|
|
||||||
# TODO: check SEND_TTS_MESSAGES
|
|
||||||
j.get('tts', False),
|
|
||||||
|
|
||||||
# TODO: check MENTION_EVERYONE permissions
|
|
||||||
'@everyone' in j['content'],
|
|
||||||
int(j.get('nonce', 0)),
|
|
||||||
MessageType.DEFAULT.value
|
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = await app.storage.get_message(message_id)
|
await _process_overwrites(channel_id, [{
|
||||||
|
'allow': j['allow'],
|
||||||
|
'deny': j['deny'],
|
||||||
|
'type': j['type'],
|
||||||
|
'id': overwrite_id
|
||||||
|
}])
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
await _mass_chan_update(guild_id, [channel_id])
|
||||||
# guild id here is the peer's ID.
|
return '', 204
|
||||||
await _dm_pre_dispatch(channel_id, guild_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
|
||||||
'MESSAGE_CREATE', payload)
|
|
||||||
|
|
||||||
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
|
||||||
|
|
||||||
if ctype == ChannelType.GUILD_TEXT:
|
|
||||||
for str_uid in payload['mentions']:
|
|
||||||
uid = int(str_uid)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count += 1
|
|
||||||
WHERE user_id = $1 AND channel_id = $2
|
|
||||||
""", uid, channel_id)
|
|
||||||
|
|
||||||
return jsonify(payload)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
|
async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||||
async def edit_message(channel_id, message_id):
|
if 'name' in j:
|
||||||
user_id = await token_check()
|
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
|
||||||
|
|
||||||
author_id = await app.db.fetchval("""
|
|
||||||
SELECT author_id FROM messages
|
|
||||||
WHERE messages.id = $1
|
|
||||||
""", message_id)
|
|
||||||
|
|
||||||
if not author_id == user_id:
|
|
||||||
raise Forbidden('You can not edit this message')
|
|
||||||
|
|
||||||
j = await request.get_json()
|
|
||||||
updated = 'content' in j or 'embed' in j
|
|
||||||
|
|
||||||
if 'content' in j:
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
UPDATE messages
|
UPDATE guild_channels
|
||||||
SET content=$1
|
SET name = $1
|
||||||
WHERE messages.id = $2
|
WHERE id = $2
|
||||||
""", j['content'], message_id)
|
""", j['name'], channel_id)
|
||||||
|
|
||||||
# TODO: update embed
|
if 'position' in j:
|
||||||
|
channel_data = await app.storage.get_channel_data(guild_id)
|
||||||
|
|
||||||
message = await app.storage.get_message(message_id)
|
chans = [None * len(channel_data)]
|
||||||
|
for chandata in channel_data:
|
||||||
|
chans.insert(chandata['position'], int(chandata['id']))
|
||||||
|
|
||||||
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
|
# are we changing to the left or to the right?
|
||||||
if updated:
|
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
|
||||||
'MESSAGE_UPDATE', message)
|
|
||||||
|
|
||||||
return jsonify(message)
|
# left: [channel1, channel2, ..., channelN-1, channelN]
|
||||||
|
# becomes
|
||||||
|
# [channel1, channelN-1, channel2, ..., channelN]
|
||||||
|
# so we can say that the "main change" is
|
||||||
|
# channelN-1 going to the position channel2
|
||||||
|
# was occupying.
|
||||||
|
current_pos = chans.index(channel_id)
|
||||||
|
new_pos = j['position']
|
||||||
|
|
||||||
|
# if the new position is bigger than the current one,
|
||||||
|
# we're making a left shift of all the channels that are
|
||||||
|
# beyond the current one, to make space
|
||||||
|
left_shift = new_pos > current_pos
|
||||||
|
|
||||||
|
# find all channels that we'll have to shift
|
||||||
|
shift_block = (chans[current_pos:new_pos]
|
||||||
|
if left_shift else
|
||||||
|
chans[new_pos:current_pos]
|
||||||
|
)
|
||||||
|
|
||||||
|
shift = -1 if left_shift else 1
|
||||||
|
|
||||||
|
# do the shift (to the left or to the right)
|
||||||
|
await app.db.executemany("""
|
||||||
|
UPDATE guild_channels
|
||||||
|
SET position = position + $1
|
||||||
|
WHERE id = $2
|
||||||
|
""", [(shift, chan_id) for chan_id in shift_block])
|
||||||
|
|
||||||
|
await _mass_chan_update(guild_id, shift_block)
|
||||||
|
|
||||||
|
# since theres now an empty slot, move current channel to it
|
||||||
|
await _update_pos(channel_id, new_pos)
|
||||||
|
|
||||||
|
if 'channel_overwrites' in j:
|
||||||
|
overwrites = j['channel_overwrites']
|
||||||
|
await _process_overwrites(channel_id, overwrites)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
|
async def _common_guild_chan(channel_id, j: dict):
|
||||||
async def delete_message(channel_id, message_id):
|
# common updates to the guild_channels table
|
||||||
|
for field in [field for field in j.keys()
|
||||||
|
if field in ('nsfw', 'parent_id')]:
|
||||||
|
await app.db.execute(f"""
|
||||||
|
UPDATE guild_channels
|
||||||
|
SET {field} = $1
|
||||||
|
WHERE id = $2
|
||||||
|
""", j[field], channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_text_channel(channel_id: int, j: dict):
|
||||||
|
# first do the specific ones related to guild_text_channels
|
||||||
|
for field in [field for field in j.keys()
|
||||||
|
if field in ('topic', 'rate_limit_per_user')]:
|
||||||
|
await app.db.execute(f"""
|
||||||
|
UPDATE guild_text_channels
|
||||||
|
SET {field} = $1
|
||||||
|
WHERE id = $2
|
||||||
|
""", j[field], channel_id)
|
||||||
|
|
||||||
|
await _common_guild_chan(channel_id, j)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_voice_channel(channel_id: int, j: dict):
|
||||||
|
# first do the specific ones in guild_voice_channels
|
||||||
|
for field in [field for field in j.keys()
|
||||||
|
if field in ('bitrate', 'user_limit')]:
|
||||||
|
await app.db.execute(f"""
|
||||||
|
UPDATE guild_voice_channels
|
||||||
|
SET {field} = $1
|
||||||
|
WHERE id = $2
|
||||||
|
""", j[field], channel_id)
|
||||||
|
|
||||||
|
# yes, i'm letting voice channels have nsfw, you cant stop me
|
||||||
|
await _common_guild_chan(channel_id, j)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
|
||||||
|
async def update_channel(channel_id):
|
||||||
|
"""Update a channel's information"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
author_id = await app.db.fetchval("""
|
if ctype not in GUILD_CHANS:
|
||||||
SELECT author_id FROM messages
|
raise ChannelNotFound('Can not edit non-guild channels.')
|
||||||
WHERE messages.id = $1
|
|
||||||
""", message_id)
|
|
||||||
|
|
||||||
# TODO: MANAGE_MESSAGES permission check
|
await channel_perm_check(user_id, channel_id, 'manage_channels')
|
||||||
if author_id != user_id:
|
j = validate(await request.get_json(), CHAN_UPDATE)
|
||||||
raise Forbidden('You can not delete this message')
|
|
||||||
|
|
||||||
await app.db.execute("""
|
# TODO: categories?
|
||||||
DELETE FROM messages
|
update_handler = {
|
||||||
WHERE messages.id = $1
|
ChannelType.GUILD_TEXT: _update_text_channel,
|
||||||
""", message_id)
|
ChannelType.GUILD_VOICE: _update_voice_channel,
|
||||||
|
}[ctype]
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await _update_channel_common(channel_id, guild_id, j)
|
||||||
'channel', channel_id,
|
await update_handler(channel_id, j)
|
||||||
'MESSAGE_DELETE', {
|
|
||||||
'id': str(message_id),
|
|
||||||
'channel_id': str(channel_id),
|
|
||||||
|
|
||||||
# for lazy guilds
|
chan = await app.storage.get_channel(channel_id)
|
||||||
'guild_id': str(guild_id),
|
await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||||
})
|
return jsonify(chan)
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/pins', methods=['GET'])
|
|
||||||
async def get_pins(channel_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
await channel_check(user_id, channel_id)
|
|
||||||
|
|
||||||
ids = await app.db.fetch("""
|
|
||||||
SELECT message_id
|
|
||||||
FROM channel_pins
|
|
||||||
WHERE channel_id = $1
|
|
||||||
ORDER BY message_id ASC
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
ids = [r['message_id'] for r in ids]
|
|
||||||
res = []
|
|
||||||
|
|
||||||
for message_id in ids:
|
|
||||||
message = await app.storage.get_message(message_id)
|
|
||||||
if message is not None:
|
|
||||||
res.append(message)
|
|
||||||
|
|
||||||
return jsonify(message)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
|
|
||||||
async def add_pin(channel_id, message_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
|
||||||
|
|
||||||
# TODO: check MANAGE_MESSAGES permission
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
INSERT INTO channel_pins (channel_id, message_id)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
""", channel_id, message_id)
|
|
||||||
|
|
||||||
row = await app.db.fetchrow("""
|
|
||||||
SELECT message_id
|
|
||||||
FROM channel_pins
|
|
||||||
WHERE channel_id = $1
|
|
||||||
ORDER BY message_id ASC
|
|
||||||
LIMIT 1
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
timestamp = snowflake_datetime(row['message_id'])
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
|
|
||||||
'channel_id': str(channel_id),
|
|
||||||
'last_pin_timestamp': timestamp.isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
|
|
||||||
async def delete_pin(channel_id, message_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
|
||||||
|
|
||||||
# TODO: check MANAGE_MESSAGES permission
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
DELETE FROM channel_pins
|
|
||||||
WHERE channel_id = $1 AND message_id = $2
|
|
||||||
""", channel_id, message_id)
|
|
||||||
|
|
||||||
row = await app.db.fetchrow("""
|
|
||||||
SELECT message_id
|
|
||||||
FROM channel_pins
|
|
||||||
WHERE channel_id = $1
|
|
||||||
ORDER BY message_id ASC
|
|
||||||
LIMIT 1
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
timestamp = snowflake_datetime(row['message_id'])
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
|
||||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
|
|
||||||
'channel_id': str(channel_id),
|
|
||||||
'last_pin_timestamp': timestamp.isoformat()
|
|
||||||
})
|
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/typing', methods=['POST'])
|
@bp.route('/<int:channel_id>/typing', methods=['POST'])
|
||||||
|
|
@ -518,21 +431,18 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||||
if not message_id:
|
if not message_id:
|
||||||
message_id = await app.storage.chan_last_message(channel_id)
|
message_id = await app.storage.chan_last_message(channel_id)
|
||||||
|
|
||||||
res = await app.db.execute("""
|
await app.db.execute("""
|
||||||
UPDATE user_read_state
|
INSERT INTO user_read_state
|
||||||
|
(user_id, channel_id, last_message_id, mention_count)
|
||||||
SET last_message_id = $1,
|
VALUES
|
||||||
mention_count = 0
|
($1, $2, $3, 0)
|
||||||
|
ON CONFLICT ON CONSTRAINT user_read_state_pkey
|
||||||
WHERE user_id = $2 AND channel_id = $3
|
DO
|
||||||
""", message_id, user_id, channel_id)
|
UPDATE
|
||||||
|
SET last_message_id = $3, mention_count = 0
|
||||||
if res == 'UPDATE 0':
|
WHERE user_read_state.user_id = $1
|
||||||
await app.db.execute("""
|
AND user_read_state.channel_id = $2
|
||||||
INSERT INTO user_read_state
|
""", user_id, channel_id, message_id)
|
||||||
(user_id, channel_id, last_message_id, mention_count)
|
|
||||||
VALUES ($1, $2, $3, $4)
|
|
||||||
""", user_id, channel_id, message_id, 0)
|
|
||||||
|
|
||||||
if guild_id:
|
if guild_id:
|
||||||
await app.dispatcher.dispatch_user_guild(
|
await app.dispatcher.dispatch_user_guild(
|
||||||
|
|
@ -551,6 +461,7 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
||||||
async def ack_channel(channel_id, message_id):
|
async def ack_channel(channel_id, message_id):
|
||||||
|
"""Acknowledge a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
|
@ -569,6 +480,7 @@ async def ack_channel(channel_id, message_id):
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
|
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
|
||||||
async def delete_read_state(channel_id):
|
async def delete_read_state(channel_id):
|
||||||
|
"""Delete the read state of a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from ..enums import ChannelType, GUILD_CHANS
|
from litecord.enums import ChannelType, GUILD_CHANS
|
||||||
from ..errors import GuildNotFound, ChannelNotFound
|
from litecord.errors import (
|
||||||
|
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
|
||||||
|
)
|
||||||
|
from litecord.permissions import base_permissions, get_permissions
|
||||||
|
|
||||||
|
|
||||||
async def guild_check(user_id: int, guild_id: int):
|
async def guild_check(user_id: int, guild_id: int):
|
||||||
|
|
@ -16,6 +19,21 @@ async def guild_check(user_id: int, guild_id: int):
|
||||||
raise GuildNotFound('guild not found')
|
raise GuildNotFound('guild not found')
|
||||||
|
|
||||||
|
|
||||||
|
async def guild_owner_check(user_id: int, guild_id: int):
|
||||||
|
"""Check if a user is the owner of the guild."""
|
||||||
|
owner_id = await app.db.fetchval("""
|
||||||
|
SELECT owner_id
|
||||||
|
FROM guilds
|
||||||
|
WHERE guilds.id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
if not owner_id:
|
||||||
|
raise GuildNotFound()
|
||||||
|
|
||||||
|
if user_id != owner_id:
|
||||||
|
raise Forbidden('You are not the owner of the guild')
|
||||||
|
|
||||||
|
|
||||||
async def channel_check(user_id, channel_id):
|
async def channel_check(user_id, channel_id):
|
||||||
"""Check if the current user is authorized
|
"""Check if the current user is authorized
|
||||||
to read the channel's information."""
|
to read the channel's information."""
|
||||||
|
|
@ -39,3 +57,27 @@ async def channel_check(user_id, channel_id):
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||||
return ctype, peer_id
|
return ctype, peer_id
|
||||||
|
|
||||||
|
|
||||||
|
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||||
|
"""Check guild permissions for a user."""
|
||||||
|
base_perms = await base_permissions(user_id, guild_id)
|
||||||
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
|
if not hasperm:
|
||||||
|
raise MissingPermissions('Missing permissions.')
|
||||||
|
|
||||||
|
|
||||||
|
async def channel_perm_check(user_id, channel_id,
|
||||||
|
permission: str, raise_err=True):
|
||||||
|
"""Check channel permissions for a user."""
|
||||||
|
base_perms = await get_permissions(user_id, channel_id)
|
||||||
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
|
print(base_perms)
|
||||||
|
print(base_perms.binary)
|
||||||
|
|
||||||
|
if not hasperm and raise_err:
|
||||||
|
raise MissingPermissions('Missing permissions.')
|
||||||
|
|
||||||
|
return hasperm
|
||||||
|
|
|
||||||
|
|
@ -38,41 +38,47 @@ async def try_dm_state(user_id: int, dm_id: int):
|
||||||
""", user_id, dm_id)
|
""", user_id, dm_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def jsonify_dm(dm_id: int, user_id: int):
|
||||||
|
dm_chan = await app.storage.get_dm(dm_id, user_id)
|
||||||
|
return jsonify(dm_chan)
|
||||||
|
|
||||||
|
|
||||||
async def create_dm(user_id, recipient_id):
|
async def create_dm(user_id, recipient_id):
|
||||||
"""Create a new dm with a user,
|
"""Create a new dm with a user,
|
||||||
or get the existing DM id if it already exists."""
|
or get the existing DM id if it already exists."""
|
||||||
|
|
||||||
|
dm_id = await app.db.fetchval("""
|
||||||
|
SELECT id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE (party1_id = $1 OR party2_id = $1) AND
|
||||||
|
(party1_id = $2 OR party2_id = $2)
|
||||||
|
""", user_id, recipient_id)
|
||||||
|
|
||||||
|
if dm_id:
|
||||||
|
return await jsonify_dm(dm_id, user_id)
|
||||||
|
|
||||||
|
# if no dm was found, create a new one
|
||||||
|
|
||||||
dm_id = get_snowflake()
|
dm_id = get_snowflake()
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO channels (id, channel_type)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", dm_id, ChannelType.DM.value)
|
||||||
|
|
||||||
try:
|
await app.db.execute("""
|
||||||
await app.db.execute("""
|
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||||
INSERT INTO channels (id, channel_type)
|
VALUES ($1, $2, $3)
|
||||||
VALUES ($1, $2)
|
""", dm_id, user_id, recipient_id)
|
||||||
""", dm_id, ChannelType.DM.value)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
# the dm state is something we use
|
||||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
# to give the currently "open dms"
|
||||||
VALUES ($1, $2, $3)
|
# on the client.
|
||||||
""", dm_id, user_id, recipient_id)
|
|
||||||
|
|
||||||
# the dm state is something we use
|
# we don't open a dm for the peer/recipient
|
||||||
# to give the currently "open dms"
|
# until the user sends a message.
|
||||||
# on the client.
|
await try_dm_state(user_id, dm_id)
|
||||||
|
|
||||||
# we don't open a dm for the peer/recipient
|
return await jsonify_dm(dm_id, user_id)
|
||||||
# until the user sends a message.
|
|
||||||
await try_dm_state(user_id, dm_id)
|
|
||||||
|
|
||||||
except UniqueViolationError:
|
|
||||||
# the dm already exists
|
|
||||||
dm_id = await app.db.fetchval("""
|
|
||||||
SELECT id
|
|
||||||
FROM dm_channels
|
|
||||||
WHERE (party1_id = $1 OR party2_id = $1) AND
|
|
||||||
(party2_id = $2 OR party2_id = $2)
|
|
||||||
""", user_id, recipient_id)
|
|
||||||
|
|
||||||
dm = await app.storage.get_dm(dm_id, user_id)
|
|
||||||
return jsonify(dm)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/channels', methods=['POST'])
|
@bp.route('/@me/channels', methods=['POST'])
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import time
|
||||||
|
|
||||||
from quart import Blueprint, jsonify, current_app as app
|
from quart import Blueprint, jsonify, current_app as app
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
|
|
@ -6,12 +8,14 @@ bp = Blueprint('gateway', __name__)
|
||||||
|
|
||||||
|
|
||||||
def get_gw():
|
def get_gw():
|
||||||
|
"""Get the gateway's web"""
|
||||||
proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
|
proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
|
||||||
return f'{proto}{app.config["WEBSOCKET_URL"]}/ws'
|
return f'{proto}{app.config["WEBSOCKET_URL"]}/ws'
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/gateway')
|
@bp.route('/gateway')
|
||||||
def api_gateway():
|
def api_gateway():
|
||||||
|
"""Get the raw URL."""
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'url': get_gw()
|
'url': get_gw()
|
||||||
})
|
})
|
||||||
|
|
@ -27,9 +31,25 @@ async def api_gateway_bot():
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
||||||
shards = max(int(guild_count / 1200), 1)
|
shards = max(int(guild_count / 1000), 1)
|
||||||
|
|
||||||
|
# get _ws.session ratelimit
|
||||||
|
ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
|
||||||
|
bucket = ratelimit.get_bucket(user_id)
|
||||||
|
|
||||||
|
# timestamp of bucket reset
|
||||||
|
reset_ts = bucket._window + bucket.second
|
||||||
|
|
||||||
|
# how many seconds until bucket reset
|
||||||
|
reset_after_ts = reset_ts - time.time()
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'url': get_gw(),
|
'url': get_gw(),
|
||||||
'shards': shards,
|
'shards': shards,
|
||||||
|
|
||||||
|
'session_start_limit': {
|
||||||
|
'total': bucket.requests,
|
||||||
|
'remaining': bucket._tokens,
|
||||||
|
'reset_after': int(reset_after_ts * 1000),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .roles import bp as guild_roles
|
||||||
|
from .members import bp as guild_members
|
||||||
|
from .channels import bp as guild_channels
|
||||||
|
from .mod import bp as guild_mod
|
||||||
|
|
@ -0,0 +1,180 @@
|
||||||
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import guild_check, guild_owner_check
|
||||||
|
from litecord.snowflake import get_snowflake
|
||||||
|
from litecord.errors import BadRequest
|
||||||
|
from litecord.enums import ChannelType
|
||||||
|
from litecord.schemas import (
|
||||||
|
validate, ROLE_UPDATE_POSITION
|
||||||
|
)
|
||||||
|
from litecord.blueprints.guild.roles import gen_pairs
|
||||||
|
|
||||||
|
|
||||||
|
bp = Blueprint('guild_channels', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||||
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO guild_text_channels (id, topic)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", channel_id, kwargs.get('topic', ''))
|
||||||
|
elif ctype == ChannelType.GUILD_VOICE:
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
kwargs.get('bitrate', 64),
|
||||||
|
kwargs.get('user_limit', 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_guild_channel(guild_id: int, channel_id: int,
|
||||||
|
ctype: ChannelType, **kwargs):
|
||||||
|
"""Create a channel in a guild."""
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO channels (id, channel_type)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", channel_id, ctype.value)
|
||||||
|
|
||||||
|
# calc new pos
|
||||||
|
max_pos = await app.db.fetchval("""
|
||||||
|
SELECT MAX(position)
|
||||||
|
FROM guild_channels
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
# account for the first channel in a guild too
|
||||||
|
max_pos = max_pos or 0
|
||||||
|
|
||||||
|
# all channels go to guild_channels
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
""", channel_id, guild_id, kwargs['name'], max_pos + 1)
|
||||||
|
|
||||||
|
# the rest of sql magic is dependant on the channel
|
||||||
|
# we're creating (a text or voice or category),
|
||||||
|
# so we use this function.
|
||||||
|
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild>/channels', methods=['GET'])
|
||||||
|
async def get_guild_channels(guild_id):
|
||||||
|
"""Get the list of channels in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
await app.storage.get_channel_data(guild_id))
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/channels', methods=['POST'])
|
||||||
|
async def create_channel(guild_id):
|
||||||
|
"""Create a channel in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
j = await request.get_json()
|
||||||
|
|
||||||
|
# TODO: check permissions for MANAGE_CHANNELS
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
channel_type = j.get('type', ChannelType.GUILD_TEXT)
|
||||||
|
channel_type = ChannelType(channel_type)
|
||||||
|
|
||||||
|
if channel_type not in (ChannelType.GUILD_TEXT,
|
||||||
|
ChannelType.GUILD_VOICE):
|
||||||
|
raise BadRequest('Invalid channel type')
|
||||||
|
|
||||||
|
new_channel_id = get_snowflake()
|
||||||
|
await create_guild_channel(
|
||||||
|
guild_id, new_channel_id, channel_type, **j)
|
||||||
|
|
||||||
|
# TODO: do a better method
|
||||||
|
# subscribe the currently subscribed users to the new channel
|
||||||
|
# by getting all user ids and subscribing each one by one.
|
||||||
|
|
||||||
|
# since GuildDispatcher calls Storage.get_channel_ids,
|
||||||
|
# it will subscribe all users to the newly created channel.
|
||||||
|
guild_pubsub = app.dispatcher.backends['guild']
|
||||||
|
user_ids = guild_pubsub.state[guild_id]
|
||||||
|
for uid in user_ids:
|
||||||
|
await app.dispatcher.sub('guild', guild_id, uid)
|
||||||
|
|
||||||
|
chan = await app.storage.get_channel(new_channel_id)
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'CHANNEL_CREATE', chan)
|
||||||
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
||||||
|
async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
||||||
|
"""Fetch new information about the channel and dispatch
|
||||||
|
a single CHANNEL_UPDATE event to the guild."""
|
||||||
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
|
||||||
|
|
||||||
|
|
||||||
|
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||||
|
"""Do a single channel swap, dispatching
|
||||||
|
the CHANNEL_UPDATE events for after the swap"""
|
||||||
|
pair1, pair2 = pair
|
||||||
|
channel_1, new_pos_1 = pair1
|
||||||
|
channel_2, new_pos_2 = pair2
|
||||||
|
|
||||||
|
# do the swap in a transaction.
|
||||||
|
conn = await app.db.acquire()
|
||||||
|
|
||||||
|
async with conn.transaction():
|
||||||
|
await conn.executemany("""
|
||||||
|
UPDATE guild_channels
|
||||||
|
SET position = $1
|
||||||
|
WHERE id = $2 AND guild_id = $3
|
||||||
|
""", [
|
||||||
|
(new_pos_1, channel_1, guild_id),
|
||||||
|
(new_pos_2, channel_2, guild_id)])
|
||||||
|
|
||||||
|
await app.db.release(conn)
|
||||||
|
|
||||||
|
await _chan_update_dispatch(guild_id, channel_1)
|
||||||
|
await _chan_update_dispatch(guild_id, channel_2)
|
||||||
|
|
||||||
|
|
||||||
|
async def _do_channel_swaps(guild_id: int, swap_pairs: list):
|
||||||
|
"""Swap channel pairs' positions, given the list
|
||||||
|
of pairs to do.
|
||||||
|
|
||||||
|
Dispatches CHANNEL_UPDATEs to the guild.
|
||||||
|
"""
|
||||||
|
for pair in swap_pairs:
|
||||||
|
await _do_single_swap(guild_id, pair)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
|
||||||
|
async def modify_channel_pos(guild_id):
|
||||||
|
"""Change positions of channels in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check MANAGE_CHANNELS
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
# same thing as guild.roles, so we use
|
||||||
|
# the same schema and all.
|
||||||
|
raw_j = await request.get_json()
|
||||||
|
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
||||||
|
j = j['roles']
|
||||||
|
|
||||||
|
channels = await app.storage.get_channel_data(guild_id)
|
||||||
|
|
||||||
|
channel_positions = {chan['position']: int(chan['id'])
|
||||||
|
for chan in channels}
|
||||||
|
|
||||||
|
swap_pairs = gen_pairs(
|
||||||
|
j,
|
||||||
|
channel_positions
|
||||||
|
)
|
||||||
|
|
||||||
|
await _do_channel_swaps(guild_id, swap_pairs)
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
@ -0,0 +1,172 @@
|
||||||
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import guild_check
|
||||||
|
from litecord.errors import BadRequest
|
||||||
|
from litecord.schemas import (
|
||||||
|
validate, MEMBER_UPDATE
|
||||||
|
)
|
||||||
|
from litecord.blueprints.checks import guild_owner_check
|
||||||
|
|
||||||
|
|
||||||
|
bp = Blueprint('guild_members', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
|
||||||
|
async def get_guild_member(guild_id, member_id):
|
||||||
|
"""Get a member's information in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
member = await app.storage.get_single_member(guild_id, member_id)
|
||||||
|
return jsonify(member)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/members', methods=['GET'])
|
||||||
|
async def get_members(guild_id):
|
||||||
|
"""Get members inside a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = await request.get_json()
|
||||||
|
|
||||||
|
limit, after = int(j.get('limit', 1)), j.get('after', 0)
|
||||||
|
|
||||||
|
if limit < 1 or limit > 1000:
|
||||||
|
raise BadRequest('limit not in 1-1000 range')
|
||||||
|
|
||||||
|
user_ids = await app.db.fetch(f"""
|
||||||
|
SELECT user_id
|
||||||
|
WHERE guild_id = $1, user_id > $2
|
||||||
|
LIMIT {limit}
|
||||||
|
ORDER BY user_id ASC
|
||||||
|
""", guild_id, after)
|
||||||
|
|
||||||
|
user_ids = [r[0] for r in user_ids]
|
||||||
|
members = await app.storage.get_member_multi(guild_id, user_ids)
|
||||||
|
return jsonify(members)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_member_roles(guild_id: int, member_id: int,
|
||||||
|
wanted_roles: list):
|
||||||
|
"""Update the roles a member has."""
|
||||||
|
|
||||||
|
# first, fetch all current roles
|
||||||
|
roles = await app.db.fetch("""
|
||||||
|
SELECT role_id from member_roles
|
||||||
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
roles = [r['role_id'] for r in roles]
|
||||||
|
|
||||||
|
roles = set(roles)
|
||||||
|
wanted_roles = set(wanted_roles)
|
||||||
|
|
||||||
|
# first, we need to find all added roles:
|
||||||
|
# roles that are on wanted_roles but
|
||||||
|
# not on roles
|
||||||
|
added_roles = wanted_roles - roles
|
||||||
|
|
||||||
|
# and then the removed roles
|
||||||
|
# which are roles in roles, but not
|
||||||
|
# in wanted_roles
|
||||||
|
removed_roles = roles - wanted_roles
|
||||||
|
|
||||||
|
conn = await app.db.acquire()
|
||||||
|
|
||||||
|
async with conn.transaction():
|
||||||
|
# add roles
|
||||||
|
await app.db.executemany("""
|
||||||
|
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""", [(member_id, guild_id, role_id)
|
||||||
|
for role_id in added_roles])
|
||||||
|
|
||||||
|
# remove roles
|
||||||
|
await app.db.executemany("""
|
||||||
|
DELETE FROM member_roles
|
||||||
|
WHERE
|
||||||
|
user_id = $1
|
||||||
|
AND guild_id = $2
|
||||||
|
AND role_id = $3
|
||||||
|
""", [(member_id, guild_id, role_id)
|
||||||
|
for role_id in removed_roles])
|
||||||
|
|
||||||
|
await app.db.release(conn)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
|
||||||
|
async def modify_guild_member(guild_id, member_id):
|
||||||
|
"""Modify a members' information in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = validate(await request.get_json(), MEMBER_UPDATE)
|
||||||
|
|
||||||
|
if 'nick' in j:
|
||||||
|
# TODO: check MANAGE_NICKNAMES
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE members
|
||||||
|
SET nickname = $1
|
||||||
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
|
""", j['nick'], member_id, guild_id)
|
||||||
|
|
||||||
|
if 'mute' in j:
|
||||||
|
# TODO: check MUTE_MEMBERS
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE members
|
||||||
|
SET muted = $1
|
||||||
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
|
""", j['mute'], member_id, guild_id)
|
||||||
|
|
||||||
|
if 'deaf' in j:
|
||||||
|
# TODO: check DEAFEN_MEMBERS
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE members
|
||||||
|
SET deafened = $1
|
||||||
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
|
""", j['deaf'], member_id, guild_id)
|
||||||
|
|
||||||
|
if 'channel_id' in j:
|
||||||
|
# TODO: check MOVE_MEMBERS and CONNECT to the channel
|
||||||
|
# TODO: change the member's voice channel
|
||||||
|
pass
|
||||||
|
|
||||||
|
if 'roles' in j:
|
||||||
|
# TODO: check permissions
|
||||||
|
await _update_member_roles(guild_id, member_id, j['roles'])
|
||||||
|
|
||||||
|
member = await app.storage.get_member_data_one(guild_id, member_id)
|
||||||
|
member.pop('joined_at')
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||||
|
'guild_id': str(guild_id)
|
||||||
|
}, **member})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
|
||||||
|
async def update_nickname(guild_id):
|
||||||
|
"""Update a member's nickname in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = await request.get_json()
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE members
|
||||||
|
SET nickname = $1
|
||||||
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
|
""", j['nick'], user_id, guild_id)
|
||||||
|
|
||||||
|
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||||
|
member.pop('joined_at')
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||||
|
'guild_id': str(guild_id)
|
||||||
|
}, **member})
|
||||||
|
|
||||||
|
return j['nick']
|
||||||
|
|
@ -0,0 +1,185 @@
|
||||||
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import guild_owner_check
|
||||||
|
|
||||||
|
from litecord.schemas import validate, GUILD_PRUNE
|
||||||
|
|
||||||
|
bp = Blueprint('guild_moderation', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_member(guild_id: int, member_id: int):
|
||||||
|
"""Do common tasks related to deleting a member from the guild,
|
||||||
|
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM members
|
||||||
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', {
|
||||||
|
'guild_id': guild_id,
|
||||||
|
'unavailable': False,
|
||||||
|
})
|
||||||
|
|
||||||
|
await app.dispatcher.unsub('guild', guild_id, member_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'user': await app.storage.get_user(member_id),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_member_multi(guild_id: int, members: list):
|
||||||
|
"""Remove multiple members."""
|
||||||
|
for member_id in members:
|
||||||
|
await remove_member(guild_id, member_id)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
|
||||||
|
async def kick_guild_member(guild_id, member_id):
|
||||||
|
"""Remove a member from a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check KICK_MEMBERS permission
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
await remove_member(guild_id, member_id)
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/bans', methods=['GET'])
|
||||||
|
async def get_bans(guild_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check BAN_MEMBERS permission
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
bans = await app.db.fetch("""
|
||||||
|
SELECT user_id, reason
|
||||||
|
FROM bans
|
||||||
|
WHERE bans.guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for ban in bans:
|
||||||
|
res.append({
|
||||||
|
'reason': ban['reason'],
|
||||||
|
'user': await app.storage.get_user(ban['user_id'])
|
||||||
|
})
|
||||||
|
|
||||||
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
|
||||||
|
async def create_ban(guild_id, member_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check BAN_MEMBERS permission
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = await request.get_json()
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO bans (guild_id, user_id, reason)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""", guild_id, member_id, j.get('reason', ''))
|
||||||
|
|
||||||
|
await remove_member(guild_id, member_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'user': await app.storage.get_user(member_id)
|
||||||
|
})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE'])
|
||||||
|
async def remove_ban(guild_id, banned_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check BAN_MEMBERS permission
|
||||||
|
await guild_owner_check(guild_id, user_id)
|
||||||
|
|
||||||
|
res = await app.db.execute("""
|
||||||
|
DELETE FROM bans
|
||||||
|
WHERE guild_id = $1 AND user_id = $@
|
||||||
|
""", guild_id, banned_id)
|
||||||
|
|
||||||
|
# we don't really need to dispatch GUILD_BAN_REMOVE
|
||||||
|
# when no bans were actually removed.
|
||||||
|
if res == 'DELETE 0':
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'user': await app.storage.get_user(banned_id)
|
||||||
|
})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
async def get_prune(guild_id: int, days: int) -> list:
|
||||||
|
"""Get all members in a guild that:
|
||||||
|
|
||||||
|
- did not login in ``days`` days.
|
||||||
|
- don't have any roles.
|
||||||
|
"""
|
||||||
|
# a good solution would be in pure sql.
|
||||||
|
member_ids = await app.db.fetch(f"""
|
||||||
|
SELECT id
|
||||||
|
FROM users
|
||||||
|
JOIN members
|
||||||
|
ON members.guild_id = $1 AND members.user_id = users.id
|
||||||
|
WHERE users.last_session < (now() - (interval '{days} days'))
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
member_ids = [r['id'] for r in member_ids]
|
||||||
|
members = []
|
||||||
|
|
||||||
|
for member_id in member_ids:
|
||||||
|
role_count = await app.db.fetchval("""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM member_roles
|
||||||
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
if role_count == 0:
|
||||||
|
members.append(member_id)
|
||||||
|
|
||||||
|
return members
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/prune', methods=['GET'])
|
||||||
|
async def get_guild_prune_count(guild_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check KICK_MEMBERS
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = validate(await request.get_json(), GUILD_PRUNE)
|
||||||
|
days = j['days']
|
||||||
|
member_ids = await get_prune(guild_id, days)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'pruned': len(member_ids),
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/prune', methods=['POST'])
|
||||||
|
async def begin_guild_prune(guild_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check KICK_MEMBERS
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = validate(await request.get_json(), GUILD_PRUNE)
|
||||||
|
days = j['days']
|
||||||
|
member_ids = await get_prune(guild_id, days)
|
||||||
|
|
||||||
|
app.loop.create_task(remove_member_multi(guild_id, member_ids))
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'pruned': len(member_ids)
|
||||||
|
})
|
||||||
|
|
@ -0,0 +1,315 @@
|
||||||
|
from typing import List, Dict, Any, Union
|
||||||
|
|
||||||
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
|
from litecord.auth import token_check
|
||||||
|
|
||||||
|
from litecord.blueprints.checks import (
|
||||||
|
guild_check, guild_owner_check
|
||||||
|
)
|
||||||
|
from litecord.schemas import (
|
||||||
|
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
||||||
|
)
|
||||||
|
|
||||||
|
from litecord.snowflake import get_snowflake
|
||||||
|
from litecord.utils import dict_get
|
||||||
|
|
||||||
|
DEFAULT_EVERYONE_PERMS = 104324161
|
||||||
|
bp = Blueprint('guild_roles', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/roles', methods=['GET'])
|
||||||
|
async def get_guild_roles(guild_id):
|
||||||
|
"""Get all roles in a guild."""
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
await app.storage.get_role_data(guild_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_role(guild_id, name: str, **kwargs):
|
||||||
|
"""Create a role in a guild."""
|
||||||
|
new_role_id = get_snowflake()
|
||||||
|
|
||||||
|
# TODO: use @everyone's perm number
|
||||||
|
default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS)
|
||||||
|
|
||||||
|
max_pos = await app.db.fetchval("""
|
||||||
|
SELECT MAX(position)
|
||||||
|
FROM roles
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO roles (id, guild_id, name, color,
|
||||||
|
hoist, position, permissions, managed, mentionable)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
|
""",
|
||||||
|
new_role_id,
|
||||||
|
guild_id,
|
||||||
|
name,
|
||||||
|
dict_get(kwargs, 'color', 0),
|
||||||
|
dict_get(kwargs, 'hoist', False),
|
||||||
|
|
||||||
|
# set position = 0 when there isn't any
|
||||||
|
# other role (when we're creating the
|
||||||
|
# @everyone role)
|
||||||
|
max_pos + 1 if max_pos is not None else 0,
|
||||||
|
int(dict_get(kwargs, 'permissions', default_perms)),
|
||||||
|
False,
|
||||||
|
dict_get(kwargs, 'mentionable', False)
|
||||||
|
)
|
||||||
|
|
||||||
|
role = await app.storage.get_role(new_role_id, guild_id)
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'GUILD_ROLE_CREATE', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'role': role,
|
||||||
|
})
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/roles', methods=['POST'])
|
||||||
|
async def create_guild_role(guild_id: int):
|
||||||
|
"""Add a role to a guild"""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: use check_guild and MANAGE_ROLES permission
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
# client can just send null
|
||||||
|
j = validate(await request.get_json() or {}, ROLE_CREATE)
|
||||||
|
|
||||||
|
role_name = j['name']
|
||||||
|
j.pop('name')
|
||||||
|
|
||||||
|
role = await create_role(guild_id, role_name, **j)
|
||||||
|
|
||||||
|
return jsonify(role)
|
||||||
|
|
||||||
|
|
||||||
|
async def _role_update_dispatch(role_id: int, guild_id: int):
|
||||||
|
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
||||||
|
role = await app.storage.get_role(role_id, guild_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'role': role,
|
||||||
|
})
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
async def _role_pairs_update(guild_id: int, pairs: list):
|
||||||
|
"""Update the roles' positions.
|
||||||
|
|
||||||
|
Dispatches GUILD_ROLE_UPDATE for all roles being updated.
|
||||||
|
"""
|
||||||
|
for pair in pairs:
|
||||||
|
pair_1, pair_2 = pair
|
||||||
|
|
||||||
|
role_1, new_pos_1 = pair_1
|
||||||
|
role_2, new_pos_2 = pair_2
|
||||||
|
|
||||||
|
conn = await app.db.acquire()
|
||||||
|
async with conn.transaction():
|
||||||
|
# update happens in a transaction
|
||||||
|
# so we don't fuck it up
|
||||||
|
await conn.execute("""
|
||||||
|
UPDATE roles
|
||||||
|
SET position = $1
|
||||||
|
WHERE roles.id = $2
|
||||||
|
""", new_pos_1, role_1)
|
||||||
|
|
||||||
|
await conn.execute("""
|
||||||
|
UPDATE roles
|
||||||
|
SET position = $1
|
||||||
|
WHERE roles.id = $2
|
||||||
|
""", new_pos_2, role_2)
|
||||||
|
|
||||||
|
await app.db.release(conn)
|
||||||
|
|
||||||
|
# the route fires multiple Guild Role Update.
|
||||||
|
await _role_update_dispatch(role_1, guild_id)
|
||||||
|
await _role_update_dispatch(role_2, guild_id)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||||
|
current_state: Dict[int, int],
|
||||||
|
blacklist: List[int] = None) -> List[tuple]:
|
||||||
|
"""Generate a list of pairs that, when applied to the database,
|
||||||
|
will generate the desired state given in list_of_changes.
|
||||||
|
|
||||||
|
We must check if the given list_of_changes isn't overwriting an
|
||||||
|
element's (such as a role or a channel) position to an existing one,
|
||||||
|
without there having an already existing change for the other one.
|
||||||
|
|
||||||
|
Here's a pratical explanation with roles:
|
||||||
|
|
||||||
|
R1 (in position RP1) wants to be in the same position
|
||||||
|
as R2 (currently in position RP2).
|
||||||
|
|
||||||
|
So, if we did the simpler approach, list_of_changes
|
||||||
|
would just contain the preferred change: (R1, RP2).
|
||||||
|
|
||||||
|
With gen_pairs, there MUST be a (R2, RP1) in list_of_changes,
|
||||||
|
if there is, the given result in gen_pairs will be a pair
|
||||||
|
((R1, RP2), (R2, RP1)) which is then used to actually
|
||||||
|
update the roles' positions in a transaction.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
list_of_changes:
|
||||||
|
A list of dictionaries with ``id`` and ``position``
|
||||||
|
fields, describing the preferred changes.
|
||||||
|
current_state:
|
||||||
|
Dictionary containing the current state of the list
|
||||||
|
of elements (roles or channels). Points position
|
||||||
|
to element ID.
|
||||||
|
blacklist:
|
||||||
|
List of IDs that shouldn't be moved.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
List of swaps to do to achieve the preferred
|
||||||
|
state given by ``list_of_changes``.
|
||||||
|
"""
|
||||||
|
pairs = []
|
||||||
|
blacklist = blacklist or []
|
||||||
|
|
||||||
|
preferred_state = {element['id']: element['position']
|
||||||
|
for element in list_of_changes}
|
||||||
|
|
||||||
|
for blacklisted_id in blacklist:
|
||||||
|
preferred_state.pop(blacklisted_id)
|
||||||
|
|
||||||
|
# for each change, we must find a matching change
|
||||||
|
# in the same list, so we can make a swap pair
|
||||||
|
for change in list_of_changes:
|
||||||
|
element_1, new_pos_1 = change['id'], change['position']
|
||||||
|
|
||||||
|
# check current pairs
|
||||||
|
# so we don't repeat an element
|
||||||
|
flag = False
|
||||||
|
|
||||||
|
for pair in pairs:
|
||||||
|
if (element_1, new_pos_1) in pair:
|
||||||
|
flag = True
|
||||||
|
|
||||||
|
# skip if found
|
||||||
|
if flag:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# search if there is a role/channel in the
|
||||||
|
# position we want to change to
|
||||||
|
element_2 = current_state.get(new_pos_1)
|
||||||
|
|
||||||
|
# if there is, is that existing channel being
|
||||||
|
# swapped to another position?
|
||||||
|
new_pos_2 = preferred_state.get(element_2)
|
||||||
|
|
||||||
|
# if its being swapped to leave space, add it
|
||||||
|
# to the pairs list
|
||||||
|
if new_pos_2:
|
||||||
|
pairs.append(
|
||||||
|
((element_1, new_pos_1), (element_2, new_pos_2))
|
||||||
|
)
|
||||||
|
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/roles', methods=['PATCH'])
|
||||||
|
async def update_guild_role_positions(guild_id):
|
||||||
|
"""Update the positions for a bunch of roles."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check MANAGE_ROLES
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
raw_j = await request.get_json()
|
||||||
|
|
||||||
|
# we need to do this hackiness because thats
|
||||||
|
# cerberus for ya.
|
||||||
|
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
||||||
|
|
||||||
|
# extract the list out
|
||||||
|
j = j['roles']
|
||||||
|
|
||||||
|
all_roles = await app.storage.get_role_data(guild_id)
|
||||||
|
|
||||||
|
# we'll have to calculate pairs of changing roles,
|
||||||
|
# then do the changes, etc.
|
||||||
|
roles_pos = {role['position']: int(role['id']) for role in all_roles}
|
||||||
|
|
||||||
|
# TODO: check if the user can even change the roles in the first place,
|
||||||
|
# preferrably when we have a proper perms system.
|
||||||
|
|
||||||
|
pairs = gen_pairs(
|
||||||
|
j,
|
||||||
|
roles_pos,
|
||||||
|
|
||||||
|
# always ignore people trying to change
|
||||||
|
# the @everyone's role position
|
||||||
|
[guild_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
await _role_pairs_update(guild_id, pairs)
|
||||||
|
|
||||||
|
# return the list of all roles back
|
||||||
|
return jsonify(await app.storage.get_role_data(guild_id))
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH'])
|
||||||
|
async def update_guild_role(guild_id, role_id):
|
||||||
|
"""Update a single role's information."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check MANAGE_ROLES
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
j = validate(await request.get_json(), ROLE_UPDATE)
|
||||||
|
|
||||||
|
# we only update ints on the db, not Permissions
|
||||||
|
j['permissions'] = int(j['permissions'])
|
||||||
|
|
||||||
|
for field in j:
|
||||||
|
await app.db.execute(f"""
|
||||||
|
UPDATE roles
|
||||||
|
SET {field} = $1
|
||||||
|
WHERE roles.id = $2 AND roles.guild_id = $3
|
||||||
|
""", j[field], role_id, guild_id)
|
||||||
|
|
||||||
|
role = await _role_update_dispatch(role_id, guild_id)
|
||||||
|
return jsonify(role)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE'])
|
||||||
|
async def delete_guild_role(guild_id, role_id):
|
||||||
|
"""Delete a role.
|
||||||
|
|
||||||
|
Dispatches GUILD_ROLE_DELETE.
|
||||||
|
"""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check MANAGE_ROLES
|
||||||
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
|
res = await app.db.execute("""
|
||||||
|
DELETE FROM roles
|
||||||
|
WHERE guild_id = $1 AND id = $2
|
||||||
|
""", guild_id, role_id)
|
||||||
|
|
||||||
|
if res == 'DELETE 0':
|
||||||
|
return '', 204
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'role_id': str(role_id),
|
||||||
|
})
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
@ -1,31 +1,23 @@
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
|
from litecord.blueprints.guild.channels import create_guild_channel
|
||||||
|
from litecord.blueprints.guild.roles import (
|
||||||
|
create_role, DEFAULT_EVERYONE_PERMS
|
||||||
|
)
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
from ..snowflake import get_snowflake
|
from ..snowflake import get_snowflake
|
||||||
from ..enums import ChannelType
|
from ..enums import ChannelType
|
||||||
from ..errors import Forbidden, GuildNotFound, BadRequest
|
from ..schemas import (
|
||||||
from ..schemas import validate, GUILD_UPDATE
|
validate, GUILD_CREATE, GUILD_UPDATE
|
||||||
|
)
|
||||||
from .channels import channel_ack
|
from .channels import channel_ack
|
||||||
from .checks import guild_check
|
from .checks import guild_check, guild_owner_check
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint('guilds', __name__)
|
bp = Blueprint('guilds', __name__)
|
||||||
|
|
||||||
|
|
||||||
async def guild_owner_check(user_id: int, guild_id: int):
|
|
||||||
"""Check if a user is the owner of the guild."""
|
|
||||||
owner_id = await app.db.fetchval("""
|
|
||||||
SELECT owner_id
|
|
||||||
FROM guilds
|
|
||||||
WHERE guild_id = $1
|
|
||||||
""", guild_id)
|
|
||||||
|
|
||||||
if not owner_id:
|
|
||||||
raise GuildNotFound()
|
|
||||||
|
|
||||||
if user_id != owner_id:
|
|
||||||
raise Forbidden('You are not the owner of the guild')
|
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_settings(guild_id: int, user_id: int):
|
async def create_guild_settings(guild_id: int, user_id: int):
|
||||||
"""Create guild settings for the user
|
"""Create guild settings for the user
|
||||||
joining the guild."""
|
joining the guild."""
|
||||||
|
|
@ -48,10 +40,59 @@ async def create_guild_settings(guild_id: int, user_id: int):
|
||||||
""", m_notifs, user_id, guild_id)
|
""", m_notifs, user_id, guild_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_member(guild_id: int, user_id: int):
|
||||||
|
"""Add a user to a guild."""
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO members (user_id, guild_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", user_id, guild_id)
|
||||||
|
|
||||||
|
await create_guild_settings(guild_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def guild_create_roles_prep(guild_id: int, roles: list):
|
||||||
|
"""Create roles in preparation in guild create."""
|
||||||
|
# by reaching this point in the code that means
|
||||||
|
# roles is not nullable, which means
|
||||||
|
# roles has at least one element, so we can access safely.
|
||||||
|
|
||||||
|
# the first member in the roles array
|
||||||
|
# are patches to the @everyone role
|
||||||
|
everyone_patches = roles[0]
|
||||||
|
for field in everyone_patches:
|
||||||
|
await app.db.execute(f"""
|
||||||
|
UPDATE roles
|
||||||
|
SET {field}={everyone_patches[field]}
|
||||||
|
WHERE roles.id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
default_perms = (everyone_patches.get('permissions')
|
||||||
|
or DEFAULT_EVERYONE_PERMS)
|
||||||
|
|
||||||
|
# from the 2nd and forward,
|
||||||
|
# should be treated as new roles
|
||||||
|
for role in roles[1:]:
|
||||||
|
await create_role(
|
||||||
|
guild_id, role['name'], default_perms=default_perms, **role
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def guild_create_channels_prep(guild_id: int, channels: list):
|
||||||
|
"""Create channels pre-guild create"""
|
||||||
|
for channel_raw in channels:
|
||||||
|
channel_id = get_snowflake()
|
||||||
|
ctype = ChannelType(channel_raw['type'])
|
||||||
|
|
||||||
|
await create_guild_channel(guild_id, channel_id, ctype)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('', methods=['POST'])
|
@bp.route('', methods=['POST'])
|
||||||
async def create_guild():
|
async def create_guild():
|
||||||
|
"""Create a new guild, assigning
|
||||||
|
the user creating it as the owner and
|
||||||
|
making them join."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
j = await request.get_json()
|
j = validate(await request.get_json(), GUILD_CREATE)
|
||||||
|
|
||||||
guild_id = get_snowflake()
|
guild_id = get_snowflake()
|
||||||
|
|
||||||
|
|
@ -66,36 +107,37 @@ async def create_guild():
|
||||||
j.get('default_message_notifications', 0),
|
j.get('default_message_notifications', 0),
|
||||||
j.get('explicit_content_filter', 0))
|
j.get('explicit_content_filter', 0))
|
||||||
|
|
||||||
await app.db.execute("""
|
await add_member(guild_id, user_id)
|
||||||
INSERT INTO members (user_id, guild_id)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
""", user_id, guild_id)
|
|
||||||
|
|
||||||
await create_guild_settings(guild_id, user_id)
|
# create the default @everyone role (everyone has it by default,
|
||||||
|
# so we don't insert that in the table)
|
||||||
|
|
||||||
|
# we also don't use create_role because the id of the role
|
||||||
|
# is the same as the id of the guild, and create_role
|
||||||
|
# generates a new snowflake.
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
INSERT INTO roles (id, guild_id, name, position, permissions)
|
INSERT INTO roles (id, guild_id, name, position, permissions)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
""", guild_id, guild_id, '@everyone', 0, 104324161)
|
""", guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
|
||||||
|
|
||||||
|
# add the @everyone role to the guild creator
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""", user_id, guild_id, guild_id)
|
||||||
|
|
||||||
|
# create a single #general channel.
|
||||||
general_id = get_snowflake()
|
general_id = get_snowflake()
|
||||||
|
|
||||||
await app.db.execute("""
|
await create_guild_channel(
|
||||||
INSERT INTO channels (id, channel_type)
|
guild_id, general_id, ChannelType.GUILD_TEXT,
|
||||||
VALUES ($1, $2)
|
name='general')
|
||||||
""", general_id, ChannelType.GUILD_TEXT.value)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
if j.get('roles'):
|
||||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
await guild_create_roles_prep(guild_id, j['roles'])
|
||||||
VALUES ($1, $2, $3, $4)
|
|
||||||
""", general_id, guild_id, 'general', 0)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
if j.get('channels'):
|
||||||
INSERT INTO guild_text_channels (id)
|
await guild_create_channels_prep(guild_id, j['channels'])
|
||||||
VALUES ($1)
|
|
||||||
""", general_id)
|
|
||||||
|
|
||||||
# TODO: j['roles'] and j['channels']
|
|
||||||
|
|
||||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||||
|
|
||||||
|
|
@ -106,21 +148,22 @@ async def create_guild():
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['GET'])
|
@bp.route('/<int:guild_id>', methods=['GET'])
|
||||||
async def get_guild(guild_id):
|
async def get_guild(guild_id):
|
||||||
|
"""Get a single guilds' information."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
gj = await app.storage.get_guild(guild_id, user_id)
|
return jsonify(
|
||||||
gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250)
|
await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||||
|
)
|
||||||
return jsonify({**gj, **gj_extra})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['UPDATE'])
|
@bp.route('/<int:guild_id>', methods=['UPDATE'])
|
||||||
async def update_guild(guild_id):
|
async def update_guild(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
j = validate(await request.get_json(), GUILD_UPDATE)
|
|
||||||
|
|
||||||
# TODO: check MANAGE_GUILD
|
# TODO: check MANAGE_GUILD
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
j = validate(await request.get_json(), GUILD_UPDATE)
|
||||||
|
|
||||||
if 'owner_id' in j:
|
if 'owner_id' in j:
|
||||||
await guild_owner_check(user_id, guild_id)
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
@ -139,8 +182,6 @@ async def update_guild(guild_id):
|
||||||
""", j['name'], guild_id)
|
""", j['name'], guild_id)
|
||||||
|
|
||||||
if 'region' in j:
|
if 'region' in j:
|
||||||
# TODO: check region value
|
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET region = $1
|
SET region = $1
|
||||||
|
|
@ -167,15 +208,14 @@ async def update_guild(guild_id):
|
||||||
WHERE guild_id = $2
|
WHERE guild_id = $2
|
||||||
""", j[field], guild_id)
|
""", j[field], guild_id)
|
||||||
|
|
||||||
# return guild object
|
guild = await app.storage.get_guild_full(
|
||||||
gj = await app.storage.get_guild(guild_id, user_id)
|
guild_id, user_id
|
||||||
gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250)
|
)
|
||||||
|
|
||||||
gj_total = {**gj, **gj_extra}
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'GUILD_UPDATE', guild)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_UPDATE', gj_total)
|
return jsonify(guild)
|
||||||
|
|
||||||
return jsonify({**gj, **gj_extra})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
||||||
|
|
@ -185,7 +225,7 @@ async def delete_guild(guild_id):
|
||||||
await guild_owner_check(user_id, guild_id)
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
DELETE FROM guild
|
DELETE FROM guilds
|
||||||
WHERE guilds.id = $1
|
WHERE guilds.id = $1
|
||||||
""", guild_id)
|
""", guild_id)
|
||||||
|
|
||||||
|
|
@ -202,264 +242,12 @@ async def delete_guild(guild_id):
|
||||||
return '', 204
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild>/channels', methods=['GET'])
|
|
||||||
async def get_guild_channels(guild_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
|
|
||||||
channels = await app.storage.get_channel_data(guild_id)
|
|
||||||
return jsonify(channels)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/channels', methods=['POST'])
|
|
||||||
async def create_channel(guild_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
j = await request.get_json()
|
|
||||||
|
|
||||||
# TODO: check permissions for MANAGE_CHANNELS
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
|
|
||||||
new_channel_id = get_snowflake()
|
|
||||||
channel_type = j.get('type', ChannelType.GUILD_TEXT)
|
|
||||||
|
|
||||||
channel_type = ChannelType(channel_type)
|
|
||||||
|
|
||||||
if channel_type not in (ChannelType.GUILD_TEXT,
|
|
||||||
ChannelType.GUILD_VOICE):
|
|
||||||
raise BadRequest('Invalid channel type')
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
INSERT INTO channels (id, channel_type)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
""", new_channel_id, channel_type.value)
|
|
||||||
|
|
||||||
max_pos = await app.db.fetchval("""
|
|
||||||
SELECT MAX(position)
|
|
||||||
FROM guild_channels
|
|
||||||
WHERE guild_id = $1
|
|
||||||
""", guild_id)
|
|
||||||
|
|
||||||
if channel_type == ChannelType.GUILD_TEXT:
|
|
||||||
await app.db.execute("""
|
|
||||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
|
||||||
VALUES ($1, $2, $3, $4)
|
|
||||||
""", new_channel_id, guild_id, j['name'], max_pos + 1)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
INSERT INTO guild_text_channels (id)
|
|
||||||
VALUES ($1)
|
|
||||||
""", new_channel_id)
|
|
||||||
|
|
||||||
elif channel_type == ChannelType.GUILD_VOICE:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
chan = await app.storage.get_channel(new_channel_id)
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_CREATE', chan)
|
|
||||||
return jsonify(chan)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
|
|
||||||
async def modify_channel_pos(guild_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
await request.get_json()
|
|
||||||
|
|
||||||
# TODO: this route
|
|
||||||
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
|
|
||||||
async def get_guild_member(guild_id, member_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
|
|
||||||
member = await app.storage.get_single_member(guild_id, member_id)
|
|
||||||
return jsonify(member)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members', methods=['GET'])
|
|
||||||
async def get_members(guild_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
|
|
||||||
j = await request.get_json()
|
|
||||||
|
|
||||||
limit, after = int(j.get('limit', 1)), j.get('after', 0)
|
|
||||||
|
|
||||||
if limit < 1 or limit > 1000:
|
|
||||||
raise BadRequest('limit not in 1-1000 range')
|
|
||||||
|
|
||||||
user_ids = await app.db.fetch(f"""
|
|
||||||
SELECT user_id
|
|
||||||
WHERE guild_id = $1, user_id > $2
|
|
||||||
LIMIT {limit}
|
|
||||||
ORDER BY user_id ASC
|
|
||||||
""", guild_id, after)
|
|
||||||
|
|
||||||
user_ids = [r[0] for r in user_ids]
|
|
||||||
members = await app.storage.get_member_multi(guild_id, user_ids)
|
|
||||||
return jsonify(members)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
|
|
||||||
async def modify_guild_member(guild_id, member_id):
|
|
||||||
j = await request.get_json()
|
|
||||||
|
|
||||||
if 'nick' in j:
|
|
||||||
# TODO: check MANAGE_NICKNAMES
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE members
|
|
||||||
SET nickname = $1
|
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
|
||||||
""", j['nick'], member_id, guild_id)
|
|
||||||
|
|
||||||
if 'mute' in j:
|
|
||||||
# TODO: check MUTE_MEMBERS
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE members
|
|
||||||
SET muted = $1
|
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
|
||||||
""", j['mute'], member_id, guild_id)
|
|
||||||
|
|
||||||
if 'deaf' in j:
|
|
||||||
# TODO: check DEAFEN_MEMBERS
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE members
|
|
||||||
SET deafened = $1
|
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
|
||||||
""", j['deaf'], member_id, guild_id)
|
|
||||||
|
|
||||||
if 'channel_id' in j:
|
|
||||||
# TODO: check MOVE_MEMBERS
|
|
||||||
# TODO: change the member's voice channel
|
|
||||||
pass
|
|
||||||
|
|
||||||
member = await app.storage.get_member_data_one(guild_id, member_id)
|
|
||||||
member.pop('joined_at')
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
|
||||||
'guild_id': str(guild_id)
|
|
||||||
}, **member})
|
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
|
|
||||||
async def update_nickname(guild_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
|
|
||||||
j = await request.get_json()
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE members
|
|
||||||
SET nickname = $1
|
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
|
||||||
""", j['nick'], user_id, guild_id)
|
|
||||||
|
|
||||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
|
||||||
member.pop('joined_at')
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
|
||||||
'guild_id': str(guild_id)
|
|
||||||
}, **member})
|
|
||||||
|
|
||||||
return j['nick']
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
|
|
||||||
async def kick_member(guild_id, member_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
|
|
||||||
# TODO: check KICK_MEMBERS permission
|
|
||||||
await guild_owner_check(user_id, guild_id)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
DELETE FROM members
|
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
|
||||||
""", guild_id, member_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(user_id, 'GUILD_DELETE', {
|
|
||||||
'guild_id': guild_id,
|
|
||||||
'unavailable': False,
|
|
||||||
})
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
|
||||||
'guild': guild_id,
|
|
||||||
'user': await app.storage.get_user(member_id),
|
|
||||||
})
|
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/bans', methods=['GET'])
|
|
||||||
async def get_bans(guild_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
|
|
||||||
# TODO: check BAN_MEMBERS permission
|
|
||||||
await guild_owner_check(user_id, guild_id)
|
|
||||||
|
|
||||||
bans = await app.db.fetch("""
|
|
||||||
SELECT user_id, reason
|
|
||||||
FROM bans
|
|
||||||
WHERE bans.guild_id = $1
|
|
||||||
""", guild_id)
|
|
||||||
|
|
||||||
res = []
|
|
||||||
|
|
||||||
for ban in bans:
|
|
||||||
res.append({
|
|
||||||
'reason': ban['reason'],
|
|
||||||
'user': await app.storage.get_user(ban['user_id'])
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify(res)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
|
|
||||||
async def create_ban(guild_id, member_id):
|
|
||||||
user_id = await token_check()
|
|
||||||
|
|
||||||
# TODO: check BAN_MEMBERS permission
|
|
||||||
await guild_owner_check(user_id, guild_id)
|
|
||||||
|
|
||||||
j = await request.get_json()
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
INSERT INTO bans (guild_id, user_id, reason)
|
|
||||||
VALUES ($1, $2, $3)
|
|
||||||
""", guild_id, member_id, j.get('reason', ''))
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
DELETE FROM members
|
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
|
||||||
""", guild_id, user_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', {
|
|
||||||
'guild_id': guild_id,
|
|
||||||
'unavailable': False,
|
|
||||||
})
|
|
||||||
|
|
||||||
await app.dispatcher.unsub('guild', guild_id, member_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
|
||||||
'guild': guild_id,
|
|
||||||
'user': await app.storage.get_user(member_id),
|
|
||||||
})
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {**{
|
|
||||||
'guild': guild_id,
|
|
||||||
}, **(await app.storage.get_user(member_id))})
|
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/messages/search')
|
@bp.route('/<int:guild_id>/messages/search')
|
||||||
async def search_messages(guild_id):
|
async def search_messages(guild_id):
|
||||||
|
"""Search messages in a guild.
|
||||||
|
|
||||||
|
This is an undocumented route.
|
||||||
|
"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
|
@ -474,6 +262,7 @@ async def search_messages(guild_id):
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
||||||
async def ack_guild(guild_id):
|
async def ack_guild(guild_id):
|
||||||
|
"""ACKnowledge all messages in the guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -185,7 +185,7 @@ async def use_invite(invite_code):
|
||||||
})
|
})
|
||||||
|
|
||||||
# subscribe new member to guild, so they get events n stuff
|
# subscribe new member to guild, so they get events n stuff
|
||||||
app.dispatcher.sub_guild(guild_id, user_id)
|
await app.dispatcher.sub('guild', guild_id, user_id)
|
||||||
|
|
||||||
# tell the new member that theres the guild it just joined.
|
# tell the new member that theres the guild it just joined.
|
||||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
||||||
|
|
|
||||||
|
|
@ -219,9 +219,11 @@ async def get_me_guilds():
|
||||||
partial = await app.db.fetchrow("""
|
partial = await app.db.fetchrow("""
|
||||||
SELECT id::text, name, icon, owner_id
|
SELECT id::text, name, icon, owner_id
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE guild_id = $1
|
WHERE guilds.id = $1
|
||||||
""", guild_id)
|
""", guild_id)
|
||||||
|
|
||||||
|
partial = dict(partial)
|
||||||
|
|
||||||
# TODO: partial['permissions']
|
# TODO: partial['permissions']
|
||||||
partial['owner'] = partial['owner_id'] == user_id
|
partial['owner'] = partial['owner_id'] == user_id
|
||||||
partial.pop('owner_id')
|
partial.pop('owner_id')
|
||||||
|
|
@ -279,10 +281,11 @@ async def put_note(target_id: int):
|
||||||
INSERT INTO notes (user_id, target_id, note)
|
INSERT INTO notes (user_id, target_id, note)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
|
|
||||||
ON CONFLICT DO UPDATE SET
|
ON CONFLICT ON CONSTRAINT notes_pkey
|
||||||
|
DO UPDATE SET
|
||||||
note = $3
|
note = $3
|
||||||
WHERE
|
WHERE notes.user_id = $1
|
||||||
user_id = $1 AND target_id = $2
|
AND notes.target_id = $2
|
||||||
""", user_id, target_id, note)
|
""", user_id, target_id, note)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
|
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
|
||||||
|
|
@ -315,7 +318,8 @@ async def patch_current_settings():
|
||||||
await app.db.execute(f"""
|
await app.db.execute(f"""
|
||||||
UPDATE user_settings
|
UPDATE user_settings
|
||||||
SET {key}=$1
|
SET {key}=$1
|
||||||
""", j[key])
|
WHERE id = $2
|
||||||
|
""", j[key], user_id)
|
||||||
|
|
||||||
settings = await app.storage.get_user_settings(user_id)
|
settings = await app.storage.get_user_settings(user_id)
|
||||||
await app.dispatcher.dispatch_user(
|
await app.dispatcher.dispatch_user(
|
||||||
|
|
@ -444,20 +448,20 @@ async def patch_guild_settings(guild_id: int):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for field in chan_overrides:
|
for field in chan_overrides:
|
||||||
res = await app.db.execute(f"""
|
await app.db.execute(f"""
|
||||||
UPDATE guild_settings_channel_overrides
|
INSERT INTO guild_settings_channel_overrides
|
||||||
SET {field} = $1
|
(user_id, guild_id, channel_id, {field})
|
||||||
WHERE user_id = $2
|
VALUES
|
||||||
AND guild_id = $3
|
($1, $2, $3, $4)
|
||||||
AND channel_id = $4
|
ON CONFLICT
|
||||||
""", chan_overrides[field], user_id, guild_id, chan_id)
|
ON CONSTRAINT guild_settings_channel_overrides_pkey
|
||||||
|
DO
|
||||||
if res == 'UPDATE 0':
|
UPDATE
|
||||||
await app.db.execute(f"""
|
SET {field} = $4
|
||||||
INSERT INTO guild_settings_channel_overrides
|
WHERE guild_settings_channel_overrides.user_id = $1
|
||||||
(user_id, guild_id, channel_id, {field})
|
AND guild_settings_channel_overrides.guild_id = $2
|
||||||
VALUES ($1, $2, $3, $4)
|
AND guild_settings_channel_overrides.channel_id = $3
|
||||||
""", user_id, guild_id, chan_id, chan_overrides[field])
|
""", user_id, guild_id, chan_id, chan_overrides[field])
|
||||||
|
|
||||||
settings = await app.storage.get_guild_settings_one(user_id, guild_id)
|
settings = await app.storage.get_guild_settings_one(user_id, guild_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ from typing import List, Any
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
||||||
UserDispatcher, ChannelDispatcher, FriendDispatcher
|
UserDispatcher, ChannelDispatcher, FriendDispatcher, \
|
||||||
|
LazyGuildDispatcher
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -35,6 +36,7 @@ class EventDispatcher:
|
||||||
'channel': ChannelDispatcher(self),
|
'channel': ChannelDispatcher(self),
|
||||||
'user': UserDispatcher(self),
|
'user': UserDispatcher(self),
|
||||||
'friend': FriendDispatcher(self),
|
'friend': FriendDispatcher(self),
|
||||||
|
'lazy_guild': LazyGuildDispatcher(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def action(self, backend_str: str, action: str, key, identifier):
|
async def action(self, backend_str: str, action: str, key, identifier):
|
||||||
|
|
@ -104,6 +106,15 @@ class EventDispatcher:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
await self.dispatch(backend_str, key, *args, **kwargs)
|
||||||
|
|
||||||
|
async def dispatch_filter(self, backend_str: str,
|
||||||
|
key: Any, func, *args):
|
||||||
|
"""Dispatch to a backend that only accepts
|
||||||
|
(event, data) arguments with an optional filter
|
||||||
|
function."""
|
||||||
|
backend = self.backends[backend_str]
|
||||||
|
key = backend.KEY_TYPE(key)
|
||||||
|
return await backend.dispatch_filter(key, func, *args)
|
||||||
|
|
||||||
async def reset(self, backend_str: str, key: Any):
|
async def reset(self, backend_str: str, key: Any):
|
||||||
"""Reset the bucket in the given backend."""
|
"""Reset the bucket in the given backend."""
|
||||||
backend = self.backends[backend_str]
|
backend = self.backends[backend_str]
|
||||||
|
|
|
||||||
|
|
@ -29,16 +29,24 @@ class NotFound(LitecordError):
|
||||||
status_code = 404
|
status_code = 404
|
||||||
|
|
||||||
|
|
||||||
class GuildNotFound(LitecordError):
|
class GuildNotFound(NotFound):
|
||||||
status_code = 404
|
error_code = 10004
|
||||||
|
|
||||||
|
|
||||||
class ChannelNotFound(LitecordError):
|
class ChannelNotFound(NotFound):
|
||||||
status_code = 404
|
error_code = 10003
|
||||||
|
|
||||||
|
|
||||||
class MessageNotFound(LitecordError):
|
class MessageNotFound(NotFound):
|
||||||
status_code = 404
|
error_code = 10008
|
||||||
|
|
||||||
|
|
||||||
|
class Ratelimited(LitecordError):
|
||||||
|
status_code = 429
|
||||||
|
|
||||||
|
|
||||||
|
class MissingPermissions(Forbidden):
|
||||||
|
error_code = 50013
|
||||||
|
|
||||||
|
|
||||||
class WebsocketClose(Exception):
|
class WebsocketClose(Exception):
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,68 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from websockets.exceptions import ConnectionClosed
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .state import GatewayState
|
from litecord.gateway.state import GatewayState
|
||||||
|
from litecord.gateway.opcodes import OP
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ManagerClose(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StateDictWrapper:
|
||||||
|
"""Wrap a mapping so that any kind of access to the mapping while the
|
||||||
|
state manager is closed raises a ManagerClose error"""
|
||||||
|
def __init__(self, state_manager, mapping):
|
||||||
|
self.state_manager = state_manager
|
||||||
|
self._map = mapping
|
||||||
|
|
||||||
|
def _check_closed(self):
|
||||||
|
if self.state_manager.closed:
|
||||||
|
raise ManagerClose()
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
self._check_closed()
|
||||||
|
return self._map[key]
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
self._check_closed()
|
||||||
|
del self._map[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
if not self.state_manager.accept_new:
|
||||||
|
raise ManagerClose()
|
||||||
|
|
||||||
|
self._check_closed()
|
||||||
|
self._map[key] = value
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self._map.__iter__()
|
||||||
|
|
||||||
|
def pop(self, key):
|
||||||
|
return self._map.pop(key)
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return self._map.values()
|
||||||
|
|
||||||
|
|
||||||
class StateManager:
|
class StateManager:
|
||||||
"""Manager for gateway state information."""
|
"""Manager for gateway state information."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
#: closed flag
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
#: accept new states?
|
||||||
|
self.accept_new = True
|
||||||
|
|
||||||
# {
|
# {
|
||||||
# user_id: {
|
# user_id: {
|
||||||
# session_id: GatewayState,
|
# session_id: GatewayState,
|
||||||
|
|
@ -20,7 +70,10 @@ class StateManager:
|
||||||
# },
|
# },
|
||||||
# user_id_2: {}, ...
|
# user_id_2: {}, ...
|
||||||
# }
|
# }
|
||||||
self.states = defaultdict(dict)
|
self.states = StateDictWrapper(self, defaultdict(dict))
|
||||||
|
|
||||||
|
#: raw mapping from session ids to GatewayState
|
||||||
|
self.states_raw = StateDictWrapper(self, {})
|
||||||
|
|
||||||
def insert(self, state: GatewayState):
|
def insert(self, state: GatewayState):
|
||||||
"""Insert a new state object."""
|
"""Insert a new state object."""
|
||||||
|
|
@ -28,6 +81,7 @@ class StateManager:
|
||||||
|
|
||||||
log.debug('inserting state: {!r}', state)
|
log.debug('inserting state: {!r}', state)
|
||||||
user_states[state.session_id] = state
|
user_states[state.session_id] = state
|
||||||
|
self.states_raw[state.session_id] = state
|
||||||
|
|
||||||
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||||
"""Fetch a state object from the manager.
|
"""Fetch a state object from the manager.
|
||||||
|
|
@ -40,11 +94,20 @@ class StateManager:
|
||||||
"""
|
"""
|
||||||
return self.states[user_id][session_id]
|
return self.states[user_id][session_id]
|
||||||
|
|
||||||
|
def fetch_raw(self, session_id: str) -> GatewayState:
|
||||||
|
"""Fetch a single state given the Session ID."""
|
||||||
|
return self.states_raw[session_id]
|
||||||
|
|
||||||
def remove(self, state):
|
def remove(self, state):
|
||||||
"""Remove a state from the registry"""
|
"""Remove a state from the registry"""
|
||||||
if not state:
|
if not state:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.states_raw.pop(state.session_id)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug('removing state: {!r}', state)
|
log.debug('removing state: {!r}', state)
|
||||||
self.states[state.user_id].pop(state.session_id)
|
self.states[state.user_id].pop(state.session_id)
|
||||||
|
|
@ -100,3 +163,54 @@ class StateManager:
|
||||||
states.extend(member_states)
|
states.extend(member_states)
|
||||||
|
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
async def shutdown_single(self, state: GatewayState):
|
||||||
|
"""Send OP Reconnect to a single connection."""
|
||||||
|
websocket = state.ws
|
||||||
|
|
||||||
|
await websocket.send({
|
||||||
|
'op': OP.RECONNECT
|
||||||
|
})
|
||||||
|
|
||||||
|
# wait 200ms
|
||||||
|
# so that the client has time to process
|
||||||
|
# our payload then close the connection
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# try to close the connection ourselves
|
||||||
|
await websocket.ws.close(
|
||||||
|
code=4000,
|
||||||
|
reason='litecord shutting down'
|
||||||
|
)
|
||||||
|
except ConnectionClosed:
|
||||||
|
log.info('client {} already closed', state)
|
||||||
|
|
||||||
|
def gen_close_tasks(self):
|
||||||
|
"""Generate the tasks that will order the clients
|
||||||
|
to reconnect.
|
||||||
|
|
||||||
|
This is required to be ran before :meth:`StateManager.close`,
|
||||||
|
since this function doesn't wait for the tasks to complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.accept_new = False
|
||||||
|
|
||||||
|
#: store the shutdown tasks
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for state in self.states_raw.values():
|
||||||
|
if not state.ws:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
self.shutdown_single(state)
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info('made {} shutdown tasks', len(tasks))
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the state manager."""
|
||||||
|
self.closed = True
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
WebsocketObjects = collections.namedtuple(
|
||||||
'WebsocketObjects', 'db state_manager storage loop dispatcher presence'
|
'WebsocketObjects', ('db', 'state_manager', 'storage',
|
||||||
|
'loop', 'dispatcher', 'presence', 'ratelimiter')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,8 +45,38 @@ def encode_etf(payload) -> str:
|
||||||
return earl.pack(payload)
|
return earl.pack(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _etf_decode_dict(data):
|
||||||
|
# NOTE: this is a very slow implementation to
|
||||||
|
# decode the dictionary.
|
||||||
|
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return data.decode()
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_copy = dict(data)
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for key in _copy.keys():
|
||||||
|
# assuming key is bytes rn.
|
||||||
|
new_k = key.decode()
|
||||||
|
|
||||||
|
# maybe nested dicts, so...
|
||||||
|
result[new_k] = _etf_decode_dict(data[key])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def decode_etf(data: bytes):
|
def decode_etf(data: bytes):
|
||||||
return earl.unpack(data)
|
res = earl.unpack(data)
|
||||||
|
|
||||||
|
if isinstance(res, bytes):
|
||||||
|
return data.decode()
|
||||||
|
|
||||||
|
if isinstance(res, dict):
|
||||||
|
return _etf_decode_dict(res)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
|
|
@ -108,6 +139,11 @@ class GatewayWebsocket:
|
||||||
else:
|
else:
|
||||||
await self.ws.send(encoded.decode())
|
await self.ws.send(encoded.decode())
|
||||||
|
|
||||||
|
def _check_ratelimit(self, key: str, ratelimit_key: str):
|
||||||
|
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
||||||
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
async def _hb_wait(self, interval: int):
|
async def _hb_wait(self, interval: int):
|
||||||
"""Wait heartbeat"""
|
"""Wait heartbeat"""
|
||||||
# if the client heartbeats in time,
|
# if the client heartbeats in time,
|
||||||
|
|
@ -312,6 +348,14 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
|
if not self.state:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._check_ratelimit('presence', self.state.session_id):
|
||||||
|
# Presence Updates beyond the ratelimit
|
||||||
|
# are just silently dropped.
|
||||||
|
return
|
||||||
|
|
||||||
if status is None:
|
if status is None:
|
||||||
status = {
|
status = {
|
||||||
'afk': False,
|
'afk': False,
|
||||||
|
|
@ -365,6 +409,15 @@ class GatewayWebsocket:
|
||||||
'op': OP.HEARTBEAT_ACK,
|
'op': OP.HEARTBEAT_ACK,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
async def _connect_ratelimit(self, user_id: int):
|
||||||
|
if self._check_ratelimit('connect', user_id):
|
||||||
|
await self.invalidate_session(False)
|
||||||
|
raise WebsocketClose(4009, 'You are being ratelimited.')
|
||||||
|
|
||||||
|
if self._check_ratelimit('session', user_id):
|
||||||
|
await self.invalidate_session(False)
|
||||||
|
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')
|
||||||
|
|
||||||
async def handle_2(self, payload: Dict[str, Any]):
|
async def handle_2(self, payload: Dict[str, Any]):
|
||||||
"""Handle the OP 2 Identify packet."""
|
"""Handle the OP 2 Identify packet."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -384,6 +437,8 @@ class GatewayWebsocket:
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, 'Authentication failed')
|
raise WebsocketClose(4004, 'Authentication failed')
|
||||||
|
|
||||||
|
await self._connect_ratelimit(user_id)
|
||||||
|
|
||||||
bot = await self.ext.db.fetchval("""
|
bot = await self.ext.db.fetchval("""
|
||||||
SELECT bot FROM users
|
SELECT bot FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -641,9 +696,11 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
This is the known structure of GUILD_MEMBER_LIST_UPDATE:
|
This is the known structure of GUILD_MEMBER_LIST_UPDATE:
|
||||||
|
|
||||||
|
group_id = 'online' | 'offline' | role_id (string)
|
||||||
|
|
||||||
sync_item = {
|
sync_item = {
|
||||||
'group': {
|
'group': {
|
||||||
'id': string, // 'online' | 'offline' | any role id
|
'id': group_id,
|
||||||
'count': num
|
'count': num
|
||||||
}
|
}
|
||||||
} | {
|
} | {
|
||||||
|
|
@ -653,7 +710,7 @@ class GatewayWebsocket:
|
||||||
list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE'
|
list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE'
|
||||||
|
|
||||||
list_data = {
|
list_data = {
|
||||||
'id': "everyone" // ??
|
'id': channel_id | 'everyone',
|
||||||
'guild_id': guild_id,
|
'guild_id': guild_id,
|
||||||
|
|
||||||
'ops': [
|
'ops': [
|
||||||
|
|
@ -666,10 +723,10 @@ class GatewayWebsocket:
|
||||||
// exists if op = 'SYNC'
|
// exists if op = 'SYNC'
|
||||||
'items': sync_item[],
|
'items': sync_item[],
|
||||||
|
|
||||||
// exists if op = 'INSERT' or 'DELETE'
|
// exists if op == 'INSERT' | 'DELETE' | 'UPDATE'
|
||||||
'index': num,
|
'index': num,
|
||||||
|
|
||||||
// exists if op = 'INSERT'
|
// exists if op == 'INSERT' | 'UPDATE'
|
||||||
'item': sync_item,
|
'item': sync_item,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
@ -678,31 +735,11 @@ class GatewayWebsocket:
|
||||||
// separately from the online list?
|
// separately from the online list?
|
||||||
'groups': [
|
'groups': [
|
||||||
{
|
{
|
||||||
'id': string // 'online' | 'offline' | any role id
|
'id': group_id
|
||||||
'count': num
|
'count': num
|
||||||
}, ...
|
}, ...
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Implementation defails.
|
|
||||||
|
|
||||||
Lazy guilds are complicated to deal with in the backend level
|
|
||||||
as there are a lot of computation to be done for each request.
|
|
||||||
|
|
||||||
The current implementation is rudimentary and does not account
|
|
||||||
for any roles inside the guild.
|
|
||||||
|
|
||||||
A correct implementation would take account of roles and make
|
|
||||||
the correct groups on list_data:
|
|
||||||
|
|
||||||
For each channel in lazy_request['channels']:
|
|
||||||
- get all roles that have Read Messages on the channel:
|
|
||||||
- Also fetch their member counts, as it'll be important
|
|
||||||
- with the role list, order them like you normally would
|
|
||||||
(by their role priority)
|
|
||||||
- based on the channel's range's min and max and the ordered
|
|
||||||
role list, you can get the roles wanted for your list_data reply.
|
|
||||||
- make new groups ONLY when the role is hoisted.
|
|
||||||
"""
|
"""
|
||||||
data = payload['d']
|
data = payload['d']
|
||||||
|
|
||||||
|
|
@ -713,65 +750,16 @@ class GatewayWebsocket:
|
||||||
if guild_id not in gids:
|
if guild_id not in gids:
|
||||||
return
|
return
|
||||||
|
|
||||||
member_ids = await self.storage.get_member_ids(guild_id)
|
# make shard query
|
||||||
log.debug('lazy: loading {} members', len(member_ids))
|
lazy_guilds = self.ext.dispatcher.backends['lazy_guild']
|
||||||
|
|
||||||
# the current implementation is rudimentary and only
|
for chan_id, ranges in data.get('channels', {}).items():
|
||||||
# generates two groups: online and offline, using
|
chan_id = int(chan_id)
|
||||||
# PresenceManager.guild_presences to fill list_data.
|
member_list = await lazy_guilds.get_gml(chan_id)
|
||||||
|
|
||||||
# this also doesn't take account the channels in lazy_request.
|
await member_list.shard_query(
|
||||||
|
self.state.session_id, ranges
|
||||||
guild_presences = await self.presence.guild_presences(member_ids,
|
)
|
||||||
guild_id)
|
|
||||||
|
|
||||||
online = [{'member': p}
|
|
||||||
for p in guild_presences
|
|
||||||
if p['status'] == 'online']
|
|
||||||
offline = [{'member': p}
|
|
||||||
for p in guild_presences
|
|
||||||
if p['status'] == 'offline']
|
|
||||||
|
|
||||||
log.debug('lazy: {} presences, online={}, offline={}',
|
|
||||||
len(guild_presences),
|
|
||||||
len(online),
|
|
||||||
len(offline))
|
|
||||||
|
|
||||||
# construct items in the WORST WAY POSSIBLE.
|
|
||||||
items = [{
|
|
||||||
'group': {
|
|
||||||
'id': 'online',
|
|
||||||
'count': len(online),
|
|
||||||
}
|
|
||||||
}] + online + [{
|
|
||||||
'group': {
|
|
||||||
'id': 'offline',
|
|
||||||
'count': len(offline),
|
|
||||||
}
|
|
||||||
}] + offline
|
|
||||||
|
|
||||||
await self.dispatch('GUILD_MEMBER_LIST_UPDATE', {
|
|
||||||
'id': 'everyone',
|
|
||||||
'guild_id': data['guild_id'],
|
|
||||||
'groups': [
|
|
||||||
{
|
|
||||||
'id': 'online',
|
|
||||||
'count': len(online),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'id': 'offline',
|
|
||||||
'count': len(offline),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
|
|
||||||
'ops': [
|
|
||||||
{
|
|
||||||
'range': [0, 99],
|
|
||||||
'op': 'SYNC',
|
|
||||||
'items': items
|
|
||||||
}
|
|
||||||
]
|
|
||||||
})
|
|
||||||
|
|
||||||
async def process_message(self, payload):
|
async def process_message(self, payload):
|
||||||
"""Process a single message coming in from the client."""
|
"""Process a single message coming in from the client."""
|
||||||
|
|
@ -788,17 +776,36 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
await handler(payload)
|
await handler(payload)
|
||||||
|
|
||||||
|
async def _msg_ratelimit(self):
|
||||||
|
if self._check_ratelimit('messages', self.state.session_id):
|
||||||
|
raise WebsocketClose(4008, 'You are being ratelimited.')
|
||||||
|
|
||||||
async def listen_messages(self):
|
async def listen_messages(self):
|
||||||
"""Listen for messages coming in from the websocket."""
|
"""Listen for messages coming in from the websocket."""
|
||||||
|
|
||||||
|
# close anyone trying to login while the
|
||||||
|
# server is shutting down
|
||||||
|
if self.ext.state_manager.closed:
|
||||||
|
raise WebsocketClose(4000, 'state manager closed')
|
||||||
|
|
||||||
|
if not self.ext.state_manager.accept_new:
|
||||||
|
raise WebsocketClose(4000, 'state manager closed for new')
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
message = await self.ws.recv()
|
message = await self.ws.recv()
|
||||||
if len(message) > 4096:
|
if len(message) > 4096:
|
||||||
raise DecodeError('Payload length exceeded')
|
raise DecodeError('Payload length exceeded')
|
||||||
|
|
||||||
|
if self.state:
|
||||||
|
await self._msg_ratelimit()
|
||||||
|
|
||||||
payload = self.decoder(message)
|
payload = self.decoder(message)
|
||||||
await self.process_message(payload)
|
await self.process_message(payload)
|
||||||
|
|
||||||
def _cleanup(self):
|
def _cleanup(self):
|
||||||
|
for task in self.wsp.tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
self.ext.state_manager.remove(self.state)
|
self.ext.state_manager.remove(self.state)
|
||||||
self.state.ws = None
|
self.state.ws = None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,242 @@
|
||||||
|
import ctypes
|
||||||
|
|
||||||
|
from quart import current_app as app, request
|
||||||
|
|
||||||
|
# so we don't keep repeating the same
|
||||||
|
# type for all the fields
|
||||||
|
_i = ctypes.c_uint8
|
||||||
|
|
||||||
|
class _RawPermsBits(ctypes.LittleEndianStructure):
|
||||||
|
"""raw bitfield for discord's permission number."""
|
||||||
|
_fields_ = [
|
||||||
|
('create_invites', _i, 1),
|
||||||
|
('kick_members', _i, 1),
|
||||||
|
('ban_members', _i, 1),
|
||||||
|
('administrator', _i, 1),
|
||||||
|
('manage_channels', _i, 1),
|
||||||
|
('manage_guild', _i, 1),
|
||||||
|
('add_reactions', _i, 1),
|
||||||
|
('view_audit_log', _i, 1),
|
||||||
|
('priority_speaker', _i, 1),
|
||||||
|
('_unused1', _i, 1),
|
||||||
|
('read_messages', _i, 1),
|
||||||
|
('send_messages', _i, 1),
|
||||||
|
('send_tts', _i, 1),
|
||||||
|
('manage_messages', _i, 1),
|
||||||
|
('embed_links', _i, 1),
|
||||||
|
('attach_files', _i, 1),
|
||||||
|
('read_history', _i, 1),
|
||||||
|
('mention_everyone', _i, 1),
|
||||||
|
('external_emojis', _i, 1),
|
||||||
|
('_unused2', _i, 1),
|
||||||
|
('connect', _i, 1),
|
||||||
|
('speak', _i, 1),
|
||||||
|
('mute_members', _i, 1),
|
||||||
|
('deafen_members', _i, 1),
|
||||||
|
('move_members', _i, 1),
|
||||||
|
('use_voice_activation', _i, 1),
|
||||||
|
('change_nickname', _i, 1),
|
||||||
|
('manage_nicknames', _i, 1),
|
||||||
|
('manage_roles', _i, 1),
|
||||||
|
('manage_webhooks', _i, 1),
|
||||||
|
('manage_emojis', _i, 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Permissions(ctypes.Union):
|
||||||
|
_fields_ = [
|
||||||
|
('bits', _RawPermsBits),
|
||||||
|
('binary', ctypes.c_uint64),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, val: int):
|
||||||
|
self.binary = val
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'<Permissions binary={self.binary}>'
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.binary
|
||||||
|
|
||||||
|
def numby(self):
|
||||||
|
return self.binary
|
||||||
|
|
||||||
|
|
||||||
|
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
|
||||||
|
"""Get the raw :class:`Permissions` object for a role."""
|
||||||
|
if not storage:
|
||||||
|
storage = app.storage
|
||||||
|
|
||||||
|
perms = await storage.db.fetchval("""
|
||||||
|
SELECT permissions
|
||||||
|
FROM roles
|
||||||
|
WHERE guild_id = $1 AND id = $2
|
||||||
|
""", guild_id, role_id)
|
||||||
|
|
||||||
|
return Permissions(perms)
|
||||||
|
|
||||||
|
|
||||||
|
async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
||||||
|
"""Compute the base permissions for a given user.
|
||||||
|
|
||||||
|
Base permissions are
|
||||||
|
(permissions from @everyone role) +
|
||||||
|
(permissions from any other role the member has)
|
||||||
|
|
||||||
|
This will give ALL_PERMISSIONS if base permissions
|
||||||
|
has the Administrator bit set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not storage:
|
||||||
|
storage = app.storage
|
||||||
|
|
||||||
|
owner_id = await storage.db.fetchval("""
|
||||||
|
SELECT owner_id
|
||||||
|
FROM guilds
|
||||||
|
WHERE id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
if owner_id == member_id:
|
||||||
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
|
# get permissions for @everyone
|
||||||
|
permissions = await get_role_perms(guild_id, guild_id, storage)
|
||||||
|
|
||||||
|
role_ids = await storage.db.fetch("""
|
||||||
|
SELECT role_id
|
||||||
|
FROM member_roles
|
||||||
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
role_perms = []
|
||||||
|
|
||||||
|
for row in role_ids:
|
||||||
|
rperm = await storage.db.fetchval("""
|
||||||
|
SELECT permissions
|
||||||
|
FROM roles
|
||||||
|
WHERE id = $1
|
||||||
|
""", row['role_id'])
|
||||||
|
|
||||||
|
role_perms.append(rperm)
|
||||||
|
|
||||||
|
for perm_num in role_perms:
|
||||||
|
permissions.binary |= perm_num
|
||||||
|
|
||||||
|
if permissions.bits.administrator:
|
||||||
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
|
return permissions
|
||||||
|
|
||||||
|
|
||||||
|
def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
|
||||||
|
# we make a copy of the binary representation
|
||||||
|
# so we don't modify the old perms in-place
|
||||||
|
# which could be an unwanted side-effect
|
||||||
|
result = perms.binary
|
||||||
|
|
||||||
|
# negate the permissions that are denied
|
||||||
|
result &= ~overwrite['deny']
|
||||||
|
|
||||||
|
# combine the permissions that are allowed
|
||||||
|
result |= overwrite['allow']
|
||||||
|
|
||||||
|
return Permissions(result)
|
||||||
|
|
||||||
|
|
||||||
|
def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
||||||
|
target_id: int) -> Permissions:
|
||||||
|
overwrite = overwrites.get(target_id)
|
||||||
|
|
||||||
|
if overwrite:
|
||||||
|
# only mix if overwrite found
|
||||||
|
return overwrite_mix(perms, overwrite)
|
||||||
|
|
||||||
|
return perms
|
||||||
|
|
||||||
|
|
||||||
|
async def role_permissions(guild_id: int, role_id: int,
|
||||||
|
channel_id: int, storage=None) -> Permissions:
|
||||||
|
"""Get the permissions for a role, in relation to a channel"""
|
||||||
|
if not storage:
|
||||||
|
storage = app.storage
|
||||||
|
|
||||||
|
perms = await get_role_perms(guild_id, role_id, storage)
|
||||||
|
|
||||||
|
overwrite = await storage.db.fetchrow("""
|
||||||
|
SELECT allow, deny
|
||||||
|
FROM channel_overwrites
|
||||||
|
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
|
||||||
|
""", channel_id, 1, role_id)
|
||||||
|
|
||||||
|
if overwrite:
|
||||||
|
perms = overwrite_mix(perms, overwrite)
|
||||||
|
|
||||||
|
return perms
|
||||||
|
|
||||||
|
|
||||||
|
async def compute_overwrites(base_perms, user_id, channel_id: int,
|
||||||
|
guild_id: int = None, storage=None):
|
||||||
|
"""Compute the permissions in the context of a channel."""
|
||||||
|
if not storage:
|
||||||
|
storage = app.storage
|
||||||
|
|
||||||
|
if base_perms.bits.administrator:
|
||||||
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
|
perms = base_perms
|
||||||
|
|
||||||
|
# list of overwrites
|
||||||
|
overwrites = await storage.chan_overwrites(channel_id)
|
||||||
|
|
||||||
|
if not guild_id:
|
||||||
|
guild_id = await storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
|
# make it a map for better usage
|
||||||
|
overwrites = {o['id']: o for o in overwrites}
|
||||||
|
|
||||||
|
perms = overwrite_find_mix(perms, overwrites, guild_id)
|
||||||
|
|
||||||
|
# apply role specific overwrites
|
||||||
|
allow, deny = 0, 0
|
||||||
|
|
||||||
|
# fetch roles from user and convert to int
|
||||||
|
role_ids = await storage.get_member_role_ids(guild_id, user_id)
|
||||||
|
role_ids = map(int, role_ids)
|
||||||
|
|
||||||
|
# make the allow and deny binaries
|
||||||
|
for role_id in role_ids:
|
||||||
|
overwrite = overwrites.get(role_id)
|
||||||
|
if overwrite:
|
||||||
|
allow |= overwrite['allow']
|
||||||
|
deny |= overwrite['deny']
|
||||||
|
|
||||||
|
# final step for roles: mix
|
||||||
|
perms = overwrite_mix(perms, {
|
||||||
|
'allow': allow,
|
||||||
|
'deny': deny
|
||||||
|
})
|
||||||
|
|
||||||
|
# apply member specific overwrites
|
||||||
|
perms = overwrite_find_mix(perms, overwrites, user_id)
|
||||||
|
|
||||||
|
return perms
|
||||||
|
|
||||||
|
|
||||||
|
async def get_permissions(member_id, channel_id, *, storage=None):
|
||||||
|
"""Get all the permissions for a user in a channel."""
|
||||||
|
if not storage:
|
||||||
|
storage = app.storage
|
||||||
|
|
||||||
|
guild_id = await storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
|
# for non guild channels
|
||||||
|
if not guild_id:
|
||||||
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
|
base_perms = await base_permissions(member_id, guild_id, storage)
|
||||||
|
|
||||||
|
return await compute_overwrites(base_perms, member_id,
|
||||||
|
channel_id, guild_id, storage)
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from random import choice
|
from random import choice
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def status_cmp(status: str, other_status: str) -> bool:
|
def status_cmp(status: str, other_status: str) -> bool:
|
||||||
"""Compare if `status` is better than the `other_status`
|
"""Compare if `status` is better than the `other_status`
|
||||||
|
|
@ -100,20 +103,64 @@ class PresenceManager:
|
||||||
|
|
||||||
game = state['game']
|
game = state['game']
|
||||||
|
|
||||||
await self.dispatcher.dispatch_guild(
|
lazy_guild_store = self.dispatcher.backends['lazy_guild']
|
||||||
guild_id, 'PRESENCE_UPDATE', {
|
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||||
'user': member['user'],
|
|
||||||
'roles': member['roles'],
|
|
||||||
'guild_id': guild_id,
|
|
||||||
|
|
||||||
'status': state['status'],
|
# shards that are in lazy guilds with 'everyone'
|
||||||
|
# enabled
|
||||||
|
in_lazy = []
|
||||||
|
|
||||||
# rich presence stuff
|
for member_list in lists:
|
||||||
'game': game,
|
session_ids = await member_list.pres_update(
|
||||||
'activities': [game] if game else []
|
int(member['user']['id']),
|
||||||
}
|
{
|
||||||
|
'roles': member['roles'],
|
||||||
|
'status': state['status'],
|
||||||
|
'game': game
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
log.debug('Lazy Dispatch to {}',
|
||||||
|
len(session_ids))
|
||||||
|
|
||||||
|
if member_list.channel_id == 'everyone':
|
||||||
|
in_lazy.extend(session_ids)
|
||||||
|
|
||||||
|
pres_update_payload = {
|
||||||
|
'user': member['user'],
|
||||||
|
'roles': member['roles'],
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
|
||||||
|
'status': state['status'],
|
||||||
|
|
||||||
|
# rich presence stuff
|
||||||
|
'game': game,
|
||||||
|
'activities': [game] if game else []
|
||||||
|
}
|
||||||
|
|
||||||
|
def _sane_session(session_id):
|
||||||
|
state = self.state_manager.fetch_raw(session_id)
|
||||||
|
uid = int(member['user']['id'])
|
||||||
|
|
||||||
|
if not state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# we don't want to send a presence update
|
||||||
|
# to the same user
|
||||||
|
return (state.user_id != uid and
|
||||||
|
session_id not in in_lazy)
|
||||||
|
|
||||||
|
# everyone not in lazy guild mode
|
||||||
|
# gets a PRESENCE_UPDATE
|
||||||
|
await self.dispatcher.dispatch_filter(
|
||||||
|
'guild', guild_id,
|
||||||
|
_sane_session,
|
||||||
|
|
||||||
|
'PRESENCE_UPDATE', pres_update_payload
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return in_lazy
|
||||||
|
|
||||||
async def dispatch_pres(self, user_id: int, state: dict):
|
async def dispatch_pres(self, user_id: int, state: dict):
|
||||||
"""Dispatch a new presence to all guilds the user is in.
|
"""Dispatch a new presence to all guilds the user is in.
|
||||||
|
|
||||||
|
|
@ -122,10 +169,12 @@ class PresenceManager:
|
||||||
if state['status'] == 'invisible':
|
if state['status'] == 'invisible':
|
||||||
state['status'] = 'offline'
|
state['status'] = 'offline'
|
||||||
|
|
||||||
|
# TODO: shard-aware
|
||||||
guild_ids = await self.storage.get_user_guilds(user_id)
|
guild_ids = await self.storage.get_user_guilds(user_id)
|
||||||
|
|
||||||
for guild_id in guild_ids:
|
for guild_id in guild_ids:
|
||||||
await self.dispatch_guild_pres(guild_id, user_id, state)
|
await self.dispatch_guild_pres(
|
||||||
|
guild_id, user_id, state)
|
||||||
|
|
||||||
# dispatch to all friends that are subscribed to them
|
# dispatch to all friends that are subscribed to them
|
||||||
user = await self.storage.get_user(user_id)
|
user = await self.storage.get_user(user_id)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,8 @@ from .member import MemberDispatcher
|
||||||
from .user import UserDispatcher
|
from .user import UserDispatcher
|
||||||
from .channel import ChannelDispatcher
|
from .channel import ChannelDispatcher
|
||||||
from .friend import FriendDispatcher
|
from .friend import FriendDispatcher
|
||||||
|
from .lazy_guild import LazyGuildDispatcher
|
||||||
|
|
||||||
__all__ = ['GuildDispatcher', 'MemberDispatcher',
|
__all__ = ['GuildDispatcher', 'MemberDispatcher',
|
||||||
'UserDispatcher', 'ChannelDispatcher',
|
'UserDispatcher', 'ChannelDispatcher',
|
||||||
'FriendDispatcher']
|
'FriendDispatcher', 'LazyGuildDispatcher']
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,14 @@ class Dispatcher:
|
||||||
"""Unsubscribe an elemtnt from the channel/key."""
|
"""Unsubscribe an elemtnt from the channel/key."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def dispatch_filter(self, _key, _func, *_args):
|
||||||
|
"""Selectively dispatch to the list of subscribed users.
|
||||||
|
|
||||||
|
The selection logic is completly arbitraty and up to the
|
||||||
|
Pub/Sub backend.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def dispatch(self, _key, *_args):
|
async def dispatch(self, _key, *_args):
|
||||||
"""Dispatch an event to the given channel/key."""
|
"""Dispatch an event to the given channel/key."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
@ -47,6 +46,8 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
# when subbing a user to the guild, we should sub them
|
# when subbing a user to the guild, we should sub them
|
||||||
# to every channel they have access to, in the guild.
|
# to every channel they have access to, in the guild.
|
||||||
|
|
||||||
|
# TODO: check for permissions
|
||||||
|
|
||||||
await self._chan_action('sub', guild_id, user_id)
|
await self._chan_action('sub', guild_id, user_id)
|
||||||
|
|
||||||
async def unsub(self, guild_id: int, user_id: int):
|
async def unsub(self, guild_id: int, user_id: int):
|
||||||
|
|
@ -56,9 +57,10 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
# same thing happening from sub() happens on unsub()
|
# same thing happening from sub() happens on unsub()
|
||||||
await self._chan_action('unsub', guild_id, user_id)
|
await self._chan_action('unsub', guild_id, user_id)
|
||||||
|
|
||||||
async def dispatch(self, guild_id: int,
|
async def dispatch_filter(self, guild_id: int, func,
|
||||||
event: str, data: Any):
|
event: str, data: Any):
|
||||||
"""Dispatch an event to all subscribers of the guild."""
|
"""Selectively dispatch to session ids that have
|
||||||
|
func(session_id) true."""
|
||||||
user_ids = self.state[guild_id]
|
user_ids = self.state[guild_id]
|
||||||
dispatched = 0
|
dispatched = 0
|
||||||
|
|
||||||
|
|
@ -75,8 +77,22 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
await self.unsub(guild_id, user_id)
|
await self.unsub(guild_id, user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# filter the ones that matter
|
||||||
|
states = list(filter(
|
||||||
|
lambda state: func(state.session_id), states
|
||||||
|
))
|
||||||
|
|
||||||
dispatched += await self._dispatch_states(
|
dispatched += await self._dispatch_states(
|
||||||
states, event, data)
|
states, event, data)
|
||||||
|
|
||||||
log.info('Dispatched {} {!r} to {} states',
|
log.info('Dispatched {} {!r} to {} states',
|
||||||
guild_id, event, dispatched)
|
guild_id, event, dispatched)
|
||||||
|
|
||||||
|
async def dispatch(self, guild_id: int,
|
||||||
|
event: str, data: Any):
|
||||||
|
"""Dispatch an event to all subscribers of the guild."""
|
||||||
|
await self.dispatch_filter(
|
||||||
|
guild_id,
|
||||||
|
lambda sess_id: True,
|
||||||
|
event, data,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,736 @@
|
||||||
|
"""
|
||||||
|
Main code for Lazy Guild implementation in litecord.
|
||||||
|
"""
|
||||||
|
import pprint
|
||||||
|
from dataclasses import dataclass, asdict
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, List, Dict, Union
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
from litecord.pubsub.dispatcher import Dispatcher
|
||||||
|
from litecord.permissions import (
|
||||||
|
Permissions, overwrite_find_mix, get_permissions, role_permissions
|
||||||
|
)
|
||||||
|
from litecord.utils import index_by_func
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
GroupID = Union[int, str]
|
||||||
|
Presence = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GroupInfo:
|
||||||
|
"""Store information about a specific group."""
|
||||||
|
gid: GroupID
|
||||||
|
name: str
|
||||||
|
position: int
|
||||||
|
permissions: Permissions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemberList:
|
||||||
|
"""Total information on the guild's member list."""
|
||||||
|
groups: List[GroupInfo] = None
|
||||||
|
group_info: Dict[GroupID, GroupInfo] = None
|
||||||
|
data: Dict[GroupID, Presence] = None
|
||||||
|
overwrites: Dict[int, Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
"""Return if the current member list is fully initialized."""
|
||||||
|
list_dict = asdict(self)
|
||||||
|
return all(v is not None for v in list_dict.values())
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
"""Iterate over all groups in the correct order.
|
||||||
|
|
||||||
|
Yields a tuple containing :class:`GroupInfo` and
|
||||||
|
the List[Presence] for the group.
|
||||||
|
"""
|
||||||
|
if not self.groups:
|
||||||
|
return
|
||||||
|
|
||||||
|
for group in self.groups:
|
||||||
|
yield group, self.data[group.gid]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Operation:
|
||||||
|
"""Represents a member list operation."""
|
||||||
|
list_op: str
|
||||||
|
params: Dict[str, Any]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
res = {
|
||||||
|
'op': self.list_op
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.list_op == 'SYNC':
|
||||||
|
res['items'] = self.params['items']
|
||||||
|
|
||||||
|
if self.list_op in ('SYNC', 'INVALIDATE'):
|
||||||
|
res['range'] = self.params['range']
|
||||||
|
|
||||||
|
if self.list_op in ('INSERT', 'DELETE', 'UPDATE'):
|
||||||
|
res['index'] = self.params['index']
|
||||||
|
|
||||||
|
if self.list_op in ('INSERT', 'UPDATE'):
|
||||||
|
res['item'] = self.params['item']
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _to_simple_group(presence: dict) -> str:
|
||||||
|
"""Return a simple group (not a role), given a presence."""
|
||||||
|
return 'offline' if presence['status'] == 'offline' else 'online'
|
||||||
|
|
||||||
|
|
||||||
|
def display_name(member_nicks: Dict[str, str], presence: Presence) -> str:
|
||||||
|
"""Return the display name of a presence.
|
||||||
|
|
||||||
|
Used to sort groups.
|
||||||
|
"""
|
||||||
|
uid = presence['user']['id']
|
||||||
|
|
||||||
|
uname = presence['user']['username']
|
||||||
|
nick = member_nicks.get(uid)
|
||||||
|
|
||||||
|
return nick or uname
|
||||||
|
|
||||||
|
|
||||||
|
class GuildMemberList:
|
||||||
|
"""This class stores the current member list information
|
||||||
|
for a guild (by channel).
|
||||||
|
|
||||||
|
As channels can have different sets of roles that can
|
||||||
|
read them and so, different lists, this is more of a
|
||||||
|
"channel member list" than a guild member list.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
main_lg: LazyGuildDispatcher
|
||||||
|
Main instance of :class:`LazyGuildDispatcher`,
|
||||||
|
so that we're able to use things such as :class:`Storage`.
|
||||||
|
guild_id: int
|
||||||
|
The Guild ID this instance is referring to.
|
||||||
|
channel_id: int
|
||||||
|
The Channel ID this instance is referring to.
|
||||||
|
member_list: List
|
||||||
|
The actual member list information.
|
||||||
|
state: set
|
||||||
|
The set of session IDs that are subscribed to the guild.
|
||||||
|
|
||||||
|
User IDs being used as the identifier in GuildMemberList
|
||||||
|
is a wrong assumption. It is true Discord rolled out
|
||||||
|
lazy guilds to all of the userbase, but users that are bots,
|
||||||
|
for example, can still rely on PRESENCE_UPDATEs.
|
||||||
|
"""
|
||||||
|
def __init__(self, guild_id: int,
|
||||||
|
channel_id: int, main_lg):
|
||||||
|
self.guild_id = guild_id
|
||||||
|
self.channel_id = channel_id
|
||||||
|
|
||||||
|
self.main = main_lg
|
||||||
|
self.list = MemberList(None, None, None, None)
|
||||||
|
|
||||||
|
#: store the states that are subscribed to the list
|
||||||
|
# type is{session_id: set[list]}
|
||||||
|
self.state = defaultdict(set)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def storage(self):
|
||||||
|
"""Get the global :class:`Storage` instance."""
|
||||||
|
return self.main.app.storage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def presence(self):
|
||||||
|
"""Get the global :class:`PresenceManager` instance."""
|
||||||
|
return self.main.app.presence
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_man(self):
|
||||||
|
"""Get the global :class:`StateManager` instance."""
|
||||||
|
return self.main.app.state_manager
|
||||||
|
|
||||||
|
@property
|
||||||
|
def list_id(self):
|
||||||
|
"""get the id of the member list."""
|
||||||
|
return ('everyone'
|
||||||
|
if self.channel_id == self.guild_id
|
||||||
|
else str(self.channel_id))
|
||||||
|
|
||||||
|
def _set_empty_list(self):
|
||||||
|
"""Set the member list as being empty."""
|
||||||
|
self.list = MemberList(None, None, None, None)
|
||||||
|
|
||||||
|
async def _init_check(self):
|
||||||
|
"""Check if the member list is initialized before
|
||||||
|
messing with it."""
|
||||||
|
if not self.list:
|
||||||
|
await self._init_member_list()
|
||||||
|
|
||||||
|
async def _fetch_overwrites(self):
|
||||||
|
overwrites = await self.storage.chan_overwrites(self.channel_id)
|
||||||
|
overwrites = {int(ov['id']): ov for ov in overwrites}
|
||||||
|
self.list.overwrites = overwrites
|
||||||
|
|
||||||
|
def _calc_member_group(self, roles: List[int], status: str):
|
||||||
|
"""Calculate the best fitting group for a member,
|
||||||
|
given their roles and their current status."""
|
||||||
|
try:
|
||||||
|
# the first group in the list
|
||||||
|
# that the member is entitled to is
|
||||||
|
# the selected group for the member.
|
||||||
|
group_id = next(g.gid for g in self.list.groups
|
||||||
|
if g.gid in roles)
|
||||||
|
except StopIteration:
|
||||||
|
# no group was found, so we fallback
|
||||||
|
# to simple group"
|
||||||
|
group_id = _to_simple_group({'status': status})
|
||||||
|
|
||||||
|
return group_id
|
||||||
|
|
||||||
|
async def get_roles(self) -> List[GroupInfo]:
|
||||||
|
"""Get role information, but only:
|
||||||
|
- the ID
|
||||||
|
- the name
|
||||||
|
- the position
|
||||||
|
- the permissions
|
||||||
|
|
||||||
|
of all HOISTED roles AND roles that
|
||||||
|
have permissions to read the channel
|
||||||
|
being referred to this :class:`GuildMemberList`
|
||||||
|
instance.
|
||||||
|
|
||||||
|
The list is sorted by position.
|
||||||
|
"""
|
||||||
|
roledata = await self.storage.db.fetch("""
|
||||||
|
SELECT id, name, hoist, position, permissions
|
||||||
|
FROM roles
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", self.guild_id)
|
||||||
|
|
||||||
|
hoisted = [
|
||||||
|
GroupInfo(row['id'], row['name'],
|
||||||
|
row['position'], row['permissions'])
|
||||||
|
for row in roledata if row['hoist']
|
||||||
|
]
|
||||||
|
|
||||||
|
# sort role list by position
|
||||||
|
hoisted = sorted(hoisted, key=lambda group: group.position)
|
||||||
|
|
||||||
|
# we need to store them for later on
|
||||||
|
# for members
|
||||||
|
await self._fetch_overwrites()
|
||||||
|
|
||||||
|
def _can_read_chan(group: GroupInfo):
|
||||||
|
# get the base role perms
|
||||||
|
role_perms = group.permissions
|
||||||
|
|
||||||
|
# then the final perms for that role if
|
||||||
|
# any overwrite exists in the channel
|
||||||
|
final_perms = overwrite_find_mix(
|
||||||
|
role_perms, self.list.overwrites, group.gid)
|
||||||
|
|
||||||
|
# update the group's permissions
|
||||||
|
# with the mixed ones
|
||||||
|
group.permissions = final_perms
|
||||||
|
|
||||||
|
# if the role can read messages, then its
|
||||||
|
# part of the group.
|
||||||
|
return final_perms.bits.read_messages
|
||||||
|
|
||||||
|
return list(filter(_can_read_chan, hoisted))
|
||||||
|
|
||||||
|
async def set_groups(self):
|
||||||
|
"""Get the groups for the member list."""
|
||||||
|
role_groups = await self.get_roles()
|
||||||
|
role_ids = [g.gid for g in role_groups]
|
||||||
|
|
||||||
|
self.list.groups = role_ids + ['online', 'offline']
|
||||||
|
|
||||||
|
# inject default groups 'online' and 'offline'
|
||||||
|
self.list.groups = role_ids + [
|
||||||
|
GroupInfo('online', 'online', -1, -1),
|
||||||
|
GroupInfo('offline', 'offline', -1, -1)
|
||||||
|
]
|
||||||
|
self.list.group_info = {g.gid: g for g in role_groups}
|
||||||
|
|
||||||
|
async def get_group(self, member_id: int,
|
||||||
|
roles: List[Union[str, int]],
|
||||||
|
status: str) -> int:
|
||||||
|
"""Return a fitting group ID for the user."""
|
||||||
|
member_roles = list(map(int, roles))
|
||||||
|
|
||||||
|
# get the member's permissions relative to the channel
|
||||||
|
# (accounting for channel overwrites)
|
||||||
|
member_perms = await get_permissions(
|
||||||
|
member_id, self.channel_id, storage=self.storage)
|
||||||
|
|
||||||
|
if not member_perms.bits.read_messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# if the member is offline, we
|
||||||
|
# default give them the offline group.
|
||||||
|
group_id = ('offline' if status == 'offline'
|
||||||
|
else self._calc_member_group(member_roles, status))
|
||||||
|
|
||||||
|
return group_id
|
||||||
|
|
||||||
|
async def _pass_1(self, guild_presences: List[Presence]):
|
||||||
|
"""First pass on generating the member list.
|
||||||
|
|
||||||
|
This assigns all presences a single group.
|
||||||
|
"""
|
||||||
|
for presence in guild_presences:
|
||||||
|
member_id = int(presence['user']['id'])
|
||||||
|
|
||||||
|
group_id = await self.get_group(
|
||||||
|
member_id, presence['roles'], presence['status']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.list.data[group_id].append(presence)
|
||||||
|
|
||||||
|
async def get_member_nicks_dict(self) -> dict:
|
||||||
|
"""Get a dictionary with nickname information."""
|
||||||
|
members = await self.storage.get_member_data(self.guild_id)
|
||||||
|
|
||||||
|
# make a dictionary of member ids to nicknames
|
||||||
|
# so we don't need to keep querying the db on
|
||||||
|
# every loop iteration
|
||||||
|
member_nicks = {m['user']['id']: m.get('nick')
|
||||||
|
for m in members}
|
||||||
|
|
||||||
|
return member_nicks
|
||||||
|
|
||||||
|
async def _sort_groups(self):
|
||||||
|
member_nicks = await self.get_member_nicks_dict()
|
||||||
|
|
||||||
|
for group_members in self.list.data.values():
|
||||||
|
|
||||||
|
# this should update the list in-place
|
||||||
|
group_members.sort(
|
||||||
|
key=lambda p: display_name(member_nicks, p))
|
||||||
|
|
||||||
|
async def _init_member_list(self):
|
||||||
|
"""Generate the main member list with groups."""
|
||||||
|
member_ids = await self.storage.get_member_ids(self.guild_id)
|
||||||
|
|
||||||
|
guild_presences = await self.presence.guild_presences(
|
||||||
|
member_ids, self.guild_id)
|
||||||
|
|
||||||
|
await self.set_groups()
|
||||||
|
|
||||||
|
log.debug('{} presences, {} groups',
|
||||||
|
len(guild_presences),
|
||||||
|
len(self.list.groups))
|
||||||
|
|
||||||
|
self.list.data = {group.gid: [] for group in self.list.groups}
|
||||||
|
|
||||||
|
# first pass: set which presences
|
||||||
|
# go to which groups
|
||||||
|
await self._pass_1(guild_presences)
|
||||||
|
|
||||||
|
# second pass: sort each group's members
|
||||||
|
# by the display name
|
||||||
|
await self._sort_groups()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def items(self) -> list:
|
||||||
|
"""Main items list."""
|
||||||
|
|
||||||
|
# TODO: maybe make this stored in the list
|
||||||
|
# so we don't need to keep regenning?
|
||||||
|
|
||||||
|
if not self.list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
# NOTE: maybe use map()?
|
||||||
|
for group, presences in self.list:
|
||||||
|
res.append({
|
||||||
|
'group': {
|
||||||
|
'id': group.gid,
|
||||||
|
'count': len(presences),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
for presence in presences:
|
||||||
|
res.append({
|
||||||
|
'member': presence
|
||||||
|
})
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def sub(self, _session_id: str):
|
||||||
|
"""Subscribe a shard to the member list."""
|
||||||
|
await self._init_check()
|
||||||
|
|
||||||
|
def unsub(self, session_id: str):
|
||||||
|
"""Unsubscribe a shard from the member list"""
|
||||||
|
try:
|
||||||
|
self.state.pop(session_id)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# once we reach 0 subscribers,
|
||||||
|
# we drop the current member list we have (for memory)
|
||||||
|
# but keep the GuildMemberList running (as
|
||||||
|
# uninitialized) for a future subscriber.
|
||||||
|
|
||||||
|
if not self.state:
|
||||||
|
self._set_empty_list()
|
||||||
|
|
||||||
|
def get_state(self, session_id: str):
|
||||||
|
try:
|
||||||
|
state = self.state_man.fetch_raw(session_id)
|
||||||
|
return state
|
||||||
|
except KeyError:
|
||||||
|
self.unsub(session_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _dispatch_sess(self, session_ids: List[str],
|
||||||
|
operations: List[Operation]):
|
||||||
|
|
||||||
|
# construct the payload to dispatch
|
||||||
|
payload = {
|
||||||
|
'id': self.list_id,
|
||||||
|
'guild_id': str(self.guild_id),
|
||||||
|
|
||||||
|
'groups': [
|
||||||
|
{
|
||||||
|
'count': len(presences),
|
||||||
|
'id': group.gid
|
||||||
|
} for group, presences in self.list
|
||||||
|
],
|
||||||
|
|
||||||
|
'ops': [
|
||||||
|
operation.to_dict
|
||||||
|
for operation in operations
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
states = map(self.get_state, session_ids)
|
||||||
|
states = filter(lambda state: state is not None, states)
|
||||||
|
|
||||||
|
dispatched = []
|
||||||
|
|
||||||
|
for state in states:
|
||||||
|
await state.ws.dispatch(
|
||||||
|
'GUILD_MEMBER_LIST_UPDATE', payload)
|
||||||
|
|
||||||
|
dispatched.append(state.session_id)
|
||||||
|
|
||||||
|
return dispatched
|
||||||
|
|
||||||
|
async def shard_query(self, session_id: str, ranges: list):
|
||||||
|
"""Send a GUILD_MEMBER_LIST_UPDATE event
|
||||||
|
for a shard that is querying about the member list.
|
||||||
|
|
||||||
|
Paramteters
|
||||||
|
-----------
|
||||||
|
session_id: str
|
||||||
|
The Session ID querying information.
|
||||||
|
channel_id: int
|
||||||
|
The Channel ID that we want information on.
|
||||||
|
ranges: List[List[int]]
|
||||||
|
ranges of the list that we want.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# a guild list with a channel id of the guild
|
||||||
|
# represents the 'everyone' global list.
|
||||||
|
list_id = self.list_id
|
||||||
|
|
||||||
|
# if everyone can read the channel,
|
||||||
|
# we direct the request to the 'everyone' gml instance
|
||||||
|
# instead of the current one.
|
||||||
|
everyone_perms = await role_permissions(
|
||||||
|
self.guild_id,
|
||||||
|
self.guild_id,
|
||||||
|
self.channel_id,
|
||||||
|
storage=self.storage
|
||||||
|
)
|
||||||
|
|
||||||
|
if everyone_perms.bits.read_messages and list_id != 'everyone':
|
||||||
|
everyone_gml = await self.main.get_gml(self.guild_id)
|
||||||
|
|
||||||
|
return await everyone_gml.shard_query(
|
||||||
|
session_id, ranges
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._init_check()
|
||||||
|
|
||||||
|
ops = []
|
||||||
|
|
||||||
|
for start, end in ranges:
|
||||||
|
itemcount = end - start
|
||||||
|
|
||||||
|
# ignore incorrect ranges
|
||||||
|
if itemcount < 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.state[session_id].add((start, end))
|
||||||
|
|
||||||
|
ops.append(Operation('SYNC', {
|
||||||
|
'range': [start, end],
|
||||||
|
'items': self.items[start:end]
|
||||||
|
}))
|
||||||
|
|
||||||
|
await self._dispatch_sess([session_id], ops)
|
||||||
|
|
||||||
|
def get_item_index(self, user_id: Union[str, int]):
|
||||||
|
"""Get the item index a user is on."""
|
||||||
|
def _get_id(item):
|
||||||
|
# item can be a group item or a member item
|
||||||
|
return item.get('member', {}).get('user', {}).get('id')
|
||||||
|
|
||||||
|
# get the updated item's index
|
||||||
|
return index_by_func(
|
||||||
|
lambda p: _get_id(p) == str(user_id),
|
||||||
|
self.items
|
||||||
|
)
|
||||||
|
|
||||||
|
def state_is_subbed(self, item_index, session_id: str) -> bool:
|
||||||
|
"""Return if a state's ranges include the given
|
||||||
|
item index."""
|
||||||
|
|
||||||
|
ranges = self.state[session_id]
|
||||||
|
|
||||||
|
for range_start, range_end in ranges:
|
||||||
|
if range_start <= item_index <= range_end:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_subs(self, item_index: int) -> filter:
|
||||||
|
"""Get the list of subscribed states to a given item."""
|
||||||
|
return filter(
|
||||||
|
lambda sess_id: self.state_is_subbed(item_index, sess_id),
|
||||||
|
self.state.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _pres_update_simple(self, user_id: int):
|
||||||
|
item_index = self.get_item_index(user_id)
|
||||||
|
|
||||||
|
if item_index is None:
|
||||||
|
log.warning('lazy guild got invalid pres update uid={}',
|
||||||
|
user_id)
|
||||||
|
return []
|
||||||
|
|
||||||
|
item = self.items[item_index]
|
||||||
|
session_ids = self.get_subs(item_index)
|
||||||
|
|
||||||
|
# simple update means we just give an UPDATE
|
||||||
|
# operation
|
||||||
|
return await self._dispatch_sess(
|
||||||
|
session_ids,
|
||||||
|
[
|
||||||
|
Operation('UPDATE', {
|
||||||
|
'index': item_index,
|
||||||
|
'item': item,
|
||||||
|
})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _pres_update_complex(self, user_id: int,
|
||||||
|
old_group: str, old_index: int,
|
||||||
|
new_group: str):
|
||||||
|
"""Move a member between groups."""
|
||||||
|
log.debug('complex update: uid={} old={} old_idx={} new={}',
|
||||||
|
user_id, old_group, old_index, new_group)
|
||||||
|
old_group_presences = self.list.data[old_group]
|
||||||
|
old_item_index = self.get_item_index(user_id)
|
||||||
|
|
||||||
|
# make a copy of current presence to insert in the new group
|
||||||
|
current_presence = dict(old_group_presences[old_index])
|
||||||
|
|
||||||
|
# step 1: remove the old presence (old_index is relative
|
||||||
|
# to the group, and not the items list)
|
||||||
|
del old_group_presences[old_index]
|
||||||
|
|
||||||
|
# we need to insert current_presence to the new group
|
||||||
|
# but we also need to calculate its index to insert on.
|
||||||
|
presences = self.list.data[new_group]
|
||||||
|
|
||||||
|
best_index = 0
|
||||||
|
member_nicks = await self.get_member_nicks_dict()
|
||||||
|
current_name = display_name(member_nicks, current_presence)
|
||||||
|
|
||||||
|
# go through each one until we find the best placement
|
||||||
|
for presence in presences:
|
||||||
|
name = display_name(member_nicks, presence)
|
||||||
|
|
||||||
|
print(name, current_name, name < current_name)
|
||||||
|
|
||||||
|
# TODO: check if this works
|
||||||
|
if name < current_name:
|
||||||
|
break
|
||||||
|
|
||||||
|
best_index += 1
|
||||||
|
|
||||||
|
# insert the presence at the index
|
||||||
|
presences.insert(best_index + 1, current_presence)
|
||||||
|
|
||||||
|
new_item_index = self.get_item_index(user_id)
|
||||||
|
|
||||||
|
log.debug('assigned new item index {} to uid {}',
|
||||||
|
new_item_index, user_id)
|
||||||
|
|
||||||
|
session_ids_old = self.get_subs(old_item_index)
|
||||||
|
session_ids_new = self.get_subs(new_item_index)
|
||||||
|
|
||||||
|
# dispatch events to both the old states and
|
||||||
|
# new states.
|
||||||
|
return await self._dispatch_sess(
|
||||||
|
# inefficient, but necessary since we
|
||||||
|
# want to merge both session ids.
|
||||||
|
list(session_ids_old) + list(session_ids_new),
|
||||||
|
[
|
||||||
|
Operation('DELETE', {
|
||||||
|
'index': old_item_index,
|
||||||
|
}),
|
||||||
|
|
||||||
|
Operation('INSERT', {
|
||||||
|
'index': new_item_index,
|
||||||
|
'item': {
|
||||||
|
'member': current_presence
|
||||||
|
}
|
||||||
|
})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def pres_update(self, user_id: int,
|
||||||
|
partial_presence: Dict[str, Any]):
|
||||||
|
"""Update a presence inside the member list.
|
||||||
|
|
||||||
|
There are 4 types of updates that can happen for a user in a group:
|
||||||
|
- from 'offline' to any
|
||||||
|
- from any to 'offline'
|
||||||
|
- from any to any
|
||||||
|
- from G to G (with G being any group)
|
||||||
|
|
||||||
|
any: 'online' | role_id
|
||||||
|
|
||||||
|
All first, second, and third updates are 'complex' updates,
|
||||||
|
which means we'll have to change the group the user is on
|
||||||
|
to account for them.
|
||||||
|
|
||||||
|
The fourth is a 'simple' change, since we're not changing
|
||||||
|
the group a user is on, and so there's less overhead
|
||||||
|
involved.
|
||||||
|
"""
|
||||||
|
await self._init_check()
|
||||||
|
|
||||||
|
old_group, old_index, old_presence = None, None, None
|
||||||
|
|
||||||
|
for group, presences in self.list:
|
||||||
|
p_idx = index_by_func(
|
||||||
|
lambda p: p['user']['id'] == str(user_id),
|
||||||
|
presences)
|
||||||
|
|
||||||
|
log.debug('p_idx for group {!r} = {}',
|
||||||
|
group.gid, p_idx)
|
||||||
|
|
||||||
|
if p_idx is None:
|
||||||
|
log.debug('skipping group {}', group)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# make a copy since we're modifying in-place
|
||||||
|
old_group = group.gid
|
||||||
|
old_index = p_idx
|
||||||
|
old_presence = dict(presences[p_idx])
|
||||||
|
|
||||||
|
# be ready if it is a simple update
|
||||||
|
presences[p_idx].update(partial_presence)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not old_group:
|
||||||
|
log.warning('pres update with unknown old group uid={}',
|
||||||
|
user_id)
|
||||||
|
return []
|
||||||
|
|
||||||
|
roles = partial_presence.get('roles', old_presence['roles'])
|
||||||
|
new_status = partial_presence.get('status', old_presence['status'])
|
||||||
|
|
||||||
|
new_group = await self.get_group(user_id, roles, new_status)
|
||||||
|
|
||||||
|
log.debug('pres update: gid={} cid={} old_g={} new_g={}',
|
||||||
|
self.guild_id, self.channel_id, old_group, new_group)
|
||||||
|
|
||||||
|
# if we're going to the same group,
|
||||||
|
# treat this as a simple update
|
||||||
|
if old_group == new_group:
|
||||||
|
return await self._pres_update_simple(user_id)
|
||||||
|
|
||||||
|
return await self._pres_update_complex(
|
||||||
|
user_id, old_group, old_index, new_group)
|
||||||
|
|
||||||
|
async def dispatch(self, event: str, data: Any):
|
||||||
|
"""Modify the member list and dispatch the respective
|
||||||
|
events to subscribed shards.
|
||||||
|
|
||||||
|
GuildMemberList stores the current guilds' list
|
||||||
|
in its :attr:`GuildMemberList.list` attribute,
|
||||||
|
with that attribute being modified via different
|
||||||
|
calls to :meth:`GuildMemberList.dispatch`
|
||||||
|
"""
|
||||||
|
|
||||||
|
# if no subscribers, drop event
|
||||||
|
if not self.list:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
class LazyGuildDispatcher(Dispatcher):
|
||||||
|
"""Main class holding the member lists for lazy guilds."""
|
||||||
|
# channel ids
|
||||||
|
KEY_TYPE = int
|
||||||
|
|
||||||
|
# the session ids subscribing to channels
|
||||||
|
VAL_TYPE = str
|
||||||
|
|
||||||
|
def __init__(self, main):
|
||||||
|
super().__init__(main)
|
||||||
|
|
||||||
|
self.storage = main.app.storage
|
||||||
|
|
||||||
|
# {chan_id: gml, ...}
|
||||||
|
self.state = {}
|
||||||
|
|
||||||
|
#: store which guilds have their
|
||||||
|
# respective GMLs
|
||||||
|
# {guild_id: [chan_id, ...], ...}
|
||||||
|
self.guild_map = defaultdict(list)
|
||||||
|
|
||||||
|
async def get_gml(self, channel_id: int):
|
||||||
|
"""Get a guild list for a channel ID,
|
||||||
|
generating it if it doesn't exist."""
|
||||||
|
try:
|
||||||
|
return self.state[channel_id]
|
||||||
|
except KeyError:
|
||||||
|
guild_id = await self.storage.guild_from_channel(
|
||||||
|
channel_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# if we don't find a guild, we just
|
||||||
|
# set it the same as the channel.
|
||||||
|
if not guild_id:
|
||||||
|
guild_id = channel_id
|
||||||
|
|
||||||
|
gml = GuildMemberList(guild_id, channel_id, self)
|
||||||
|
self.state[channel_id] = gml
|
||||||
|
self.guild_map[guild_id].append(channel_id)
|
||||||
|
return gml
|
||||||
|
|
||||||
|
def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]:
|
||||||
|
"""Get all member lists for a given guild."""
|
||||||
|
return list(map(
|
||||||
|
self.state.get,
|
||||||
|
self.guild_map[guild_id]
|
||||||
|
))
|
||||||
|
|
||||||
|
async def unsub(self, chan_id, session_id):
|
||||||
|
gml = await self.get_gml(chan_id)
|
||||||
|
gml.unsub(session_id)
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""
|
||||||
|
main litecord ratelimiting code
|
||||||
|
|
||||||
|
This code was copied from elixire's ratelimiting,
|
||||||
|
which in turn is a work on top of discord.py's ratelimiting.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class RatelimitBucket:
|
||||||
|
"""Main ratelimit bucket class."""
|
||||||
|
def __init__(self, tokens, second):
|
||||||
|
self.requests = tokens
|
||||||
|
self.second = second
|
||||||
|
|
||||||
|
self._window = 0.0
|
||||||
|
self._tokens = self.requests
|
||||||
|
self.retries = 0
|
||||||
|
self._last = 0.0
|
||||||
|
|
||||||
|
def get_tokens(self, current):
|
||||||
|
"""Get the current amount of available tokens."""
|
||||||
|
if not current:
|
||||||
|
current = time.time()
|
||||||
|
|
||||||
|
# by default, use _tokens
|
||||||
|
tokens = self._tokens
|
||||||
|
|
||||||
|
# if current timestamp is above _window + seconds
|
||||||
|
# reset tokens to self.requests (default)
|
||||||
|
if current > self._window + self.second:
|
||||||
|
tokens = self.requests
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def update_rate_limit(self):
|
||||||
|
"""Update current ratelimit state."""
|
||||||
|
current = time.time()
|
||||||
|
self._last = current
|
||||||
|
self._tokens = self.get_tokens(current)
|
||||||
|
|
||||||
|
# we are using the ratelimit for the first time
|
||||||
|
# so set current ratelimit window to right now
|
||||||
|
if self._tokens == self.requests:
|
||||||
|
self._window = current
|
||||||
|
|
||||||
|
# Are we currently ratelimited?
|
||||||
|
if self._tokens == 0:
|
||||||
|
self.retries += 1
|
||||||
|
return self.second - (current - self._window)
|
||||||
|
|
||||||
|
# if not ratelimited, remove a token
|
||||||
|
self.retries = 0
|
||||||
|
self._tokens -= 1
|
||||||
|
|
||||||
|
# if we got ratelimited after that token removal,
|
||||||
|
# set window to now
|
||||||
|
if self._tokens == 0:
|
||||||
|
self._window = current
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset current ratelimit to default state."""
|
||||||
|
self._tokens = self.requests
|
||||||
|
self._last = 0.0
|
||||||
|
self.retries = 0
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
"""Create a copy of this ratelimit.
|
||||||
|
|
||||||
|
Used to manage multiple ratelimits to users.
|
||||||
|
"""
|
||||||
|
return RatelimitBucket(self.requests,
|
||||||
|
self.second)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f'<RatelimitBucket requests={self.requests} '
|
||||||
|
f'second={self.second} window: {self._window} '
|
||||||
|
f'tokens={self._tokens}>')
|
||||||
|
|
||||||
|
|
||||||
|
class Ratelimit:
|
||||||
|
"""Manages buckets."""
|
||||||
|
def __init__(self, tokens, second, keys=None):
|
||||||
|
self._cache = {}
|
||||||
|
if keys is None:
|
||||||
|
keys = tuple()
|
||||||
|
self.keys = keys
|
||||||
|
self._cooldown = RatelimitBucket(tokens, second)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f'<Ratelimit cooldown={self._cooldown}>')
|
||||||
|
|
||||||
|
def _verify_cache(self):
|
||||||
|
current = time.time()
|
||||||
|
dead_keys = [k for k, v in self._cache.items()
|
||||||
|
if current > v._last + v.second]
|
||||||
|
|
||||||
|
for k in dead_keys:
|
||||||
|
del self._cache[k]
|
||||||
|
|
||||||
|
def get_bucket(self, key) -> RatelimitBucket:
|
||||||
|
if not self._cooldown:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._verify_cache()
|
||||||
|
|
||||||
|
if key not in self._cache:
|
||||||
|
bucket = self._cooldown.copy()
|
||||||
|
self._cache[key] = bucket
|
||||||
|
else:
|
||||||
|
bucket = self._cache[key]
|
||||||
|
|
||||||
|
return bucket
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
from quart import current_app as app, request, g
|
||||||
|
|
||||||
|
from litecord.errors import Ratelimited
|
||||||
|
from litecord.auth import token_check, Unauthorized
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_bucket(bucket):
|
||||||
|
retry_after = bucket.update_rate_limit()
|
||||||
|
|
||||||
|
request.bucket = bucket
|
||||||
|
|
||||||
|
if retry_after:
|
||||||
|
request.retry_after = retry_after
|
||||||
|
|
||||||
|
raise Ratelimited('You are being rate limited.', {
|
||||||
|
'retry_after': int(retry_after * 1000),
|
||||||
|
'global': request.bucket_global,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_global(ratelimit):
|
||||||
|
"""Global ratelimit is per-user."""
|
||||||
|
try:
|
||||||
|
user_id = await token_check()
|
||||||
|
except Unauthorized:
|
||||||
|
user_id = request.remote_addr
|
||||||
|
|
||||||
|
request.bucket_global = True
|
||||||
|
bucket = ratelimit.get_bucket(user_id)
|
||||||
|
await _check_bucket(bucket)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_specific(ratelimit):
|
||||||
|
try:
|
||||||
|
user_id = await token_check()
|
||||||
|
except Unauthorized:
|
||||||
|
user_id = request.remote_addr
|
||||||
|
|
||||||
|
# construct the key based on the ratelimit.keys
|
||||||
|
keys = ratelimit.keys
|
||||||
|
|
||||||
|
# base key is the user id
|
||||||
|
key_components = [f'user_id:{user_id}']
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
val = request.view_args[key]
|
||||||
|
key_components.append(f'{key}:{val}')
|
||||||
|
|
||||||
|
bucket_key = ':'.join(key_components)
|
||||||
|
bucket = ratelimit.get_bucket(bucket_key)
|
||||||
|
await _check_bucket(bucket)
|
||||||
|
|
||||||
|
|
||||||
|
async def ratelimit_handler():
|
||||||
|
"""Main ratelimit handler.
|
||||||
|
|
||||||
|
Decides on which ratelimit to use.
|
||||||
|
"""
|
||||||
|
rule = request.url_rule
|
||||||
|
|
||||||
|
if rule is None:
|
||||||
|
return await _handle_global(
|
||||||
|
app.ratelimiter.global_bucket
|
||||||
|
)
|
||||||
|
|
||||||
|
# rule.endpoint is composed of '<blueprint>.<function>'
|
||||||
|
# and so we can use that to make routes with different
|
||||||
|
# methods have different ratelimits
|
||||||
|
rule_path = rule.endpoint
|
||||||
|
|
||||||
|
# some request ratelimit context.
|
||||||
|
# TODO: maybe put those in a namedtuple or contextvar of sorts?
|
||||||
|
request.bucket = None
|
||||||
|
request.retry_after = None
|
||||||
|
request.bucket_global = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
||||||
|
await _handle_specific(ratelimit)
|
||||||
|
except KeyError:
|
||||||
|
await _handle_global(
|
||||||
|
app.ratelimiter.global_bucket
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
from litecord.ratelimits.bucket import Ratelimit
|
||||||
|
|
||||||
|
"""
|
||||||
|
REST:
|
||||||
|
POST Message | 5/5s | per-channel
|
||||||
|
DELETE Message | 5/1s | per-channel
|
||||||
|
PUT/DELETE Reaction | 1/0.25s | per-channel
|
||||||
|
PATCH Member | 10/10s | per-guild
|
||||||
|
PATCH Member Nick | 1/1s | per-guild
|
||||||
|
PATCH Username | 2/3600s | per-account
|
||||||
|
|All Requests| | 50/1s | per-account
|
||||||
|
WS:
|
||||||
|
Gateway Connect | 1/5s | per-account
|
||||||
|
Presence Update | 5/60s | per-session
|
||||||
|
|All Sent Messages| | 120/60s | per-session
|
||||||
|
"""
|
||||||
|
|
||||||
|
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
|
||||||
|
|
||||||
|
RATELIMITS = {
|
||||||
|
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
|
||||||
|
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
|
||||||
|
|
||||||
|
# all of those share the same bucket.
|
||||||
|
'channel_reactions.add_reaction': REACTION_BUCKET,
|
||||||
|
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
|
||||||
|
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
|
||||||
|
|
||||||
|
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
|
||||||
|
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
|
||||||
|
|
||||||
|
# this only applies to username.
|
||||||
|
# 'users.patch_me': Ratelimit(2, 3600),
|
||||||
|
|
||||||
|
'_ws.connect': Ratelimit(1, 5),
|
||||||
|
'_ws.presence': Ratelimit(5, 60),
|
||||||
|
'_ws.messages': Ratelimit(120, 60),
|
||||||
|
|
||||||
|
# 1000 / 4h for new session issuing
|
||||||
|
'_ws.session': Ratelimit(1000, 14400)
|
||||||
|
}
|
||||||
|
|
||||||
|
class RatelimitManager:
|
||||||
|
"""Manager for the bucket managers"""
|
||||||
|
def __init__(self):
|
||||||
|
self._ratelimiters = {}
|
||||||
|
self.global_bucket = Ratelimit(50, 1)
|
||||||
|
self._fill_rtl()
|
||||||
|
|
||||||
|
def _fill_rtl(self):
|
||||||
|
for path, rtl in RATELIMITS.items():
|
||||||
|
self._ratelimiters[path] = rtl
|
||||||
|
|
||||||
|
def get_ratelimit(self, key: str) -> Ratelimit:
|
||||||
|
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||||
|
return self._ratelimiters.get(key, self.global_bucket)
|
||||||
|
|
@ -1,11 +1,16 @@
|
||||||
import re
|
import re
|
||||||
|
from typing import Union, Dict, List, Any
|
||||||
|
|
||||||
from cerberus import Validator
|
from cerberus import Validator
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .errors import BadRequest
|
from .errors import BadRequest
|
||||||
from .enums import ActivityType, StatusType, ExplicitFilter, \
|
from .permissions import Permissions
|
||||||
RelationshipType, MessageNotifications
|
from .types import Color
|
||||||
|
from .enums import (
|
||||||
|
ActivityType, StatusType, ExplicitFilter, RelationshipType,
|
||||||
|
MessageNotifications, ChannelType, VerificationLevel
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -24,13 +29,21 @@ EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
|
||||||
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
|
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
|
||||||
|
|
||||||
|
|
||||||
|
def _in_enum(enum, value: int):
|
||||||
|
try:
|
||||||
|
enum(value)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LitecordValidator(Validator):
|
class LitecordValidator(Validator):
|
||||||
def _validate_type_username(self, value: str) -> bool:
|
def _validate_type_username(self, value: str) -> bool:
|
||||||
"""Validate against the username regex."""
|
"""Validate against the username regex."""
|
||||||
return bool(USERNAME_REGEX.match(value))
|
return bool(USERNAME_REGEX.match(value))
|
||||||
|
|
||||||
def _validate_type_email(self, value: str) -> bool:
|
def _validate_type_email(self, value: str) -> bool:
|
||||||
"""Validate against the username regex."""
|
"""Validate against the email regex."""
|
||||||
return bool(EMAIL_REGEX.match(value))
|
return bool(EMAIL_REGEX.match(value))
|
||||||
|
|
||||||
def _validate_type_b64_icon(self, value: str) -> bool:
|
def _validate_type_b64_icon(self, value: str) -> bool:
|
||||||
|
|
@ -56,11 +69,17 @@ class LitecordValidator(Validator):
|
||||||
|
|
||||||
def _validate_type_voice_region(self, value: str) -> bool:
|
def _validate_type_voice_region(self, value: str) -> bool:
|
||||||
# TODO: complete this list
|
# TODO: complete this list
|
||||||
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
return value.lower() in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
||||||
|
|
||||||
|
def _validate_type_verification_level(self, value: int) -> bool:
|
||||||
|
return _in_enum(VerificationLevel, value)
|
||||||
|
|
||||||
def _validate_type_activity_type(self, value: int) -> bool:
|
def _validate_type_activity_type(self, value: int) -> bool:
|
||||||
return value in ActivityType.values()
|
return value in ActivityType.values()
|
||||||
|
|
||||||
|
def _validate_type_channel_type(self, value: int) -> bool:
|
||||||
|
return value in ChannelType.values()
|
||||||
|
|
||||||
def _validate_type_status_external(self, value: str) -> bool:
|
def _validate_type_status_external(self, value: str) -> bool:
|
||||||
statuses = StatusType.values()
|
statuses = StatusType.values()
|
||||||
|
|
||||||
|
|
@ -94,11 +113,31 @@ class LitecordValidator(Validator):
|
||||||
|
|
||||||
return val in MessageNotifications.values()
|
return val in MessageNotifications.values()
|
||||||
|
|
||||||
|
def _validate_type_guild_name(self, value: str) -> bool:
|
||||||
|
return 2 <= len(value) <= 100
|
||||||
|
|
||||||
def validate(reqjson, schema, raise_err: bool = True):
|
def _validate_type_role_name(self, value: str) -> bool:
|
||||||
|
return 1 <= len(value) <= 100
|
||||||
|
|
||||||
|
def _validate_type_channel_name(self, value: str) -> bool:
|
||||||
|
# for now, we'll use the same validation for guild_name
|
||||||
|
return self._validate_type_guild_name(value)
|
||||||
|
|
||||||
|
|
||||||
|
def validate(reqjson: Union[Dict, List], schema: Dict,
|
||||||
|
raise_err: bool = True) -> Union[Dict, List]:
|
||||||
|
"""Validate a given document (user-input) and give
|
||||||
|
the correct document as a result.
|
||||||
|
"""
|
||||||
validator = LitecordValidator(schema)
|
validator = LitecordValidator(schema)
|
||||||
|
|
||||||
if not validator.validate(reqjson):
|
try:
|
||||||
|
valid = validator.validate(reqjson)
|
||||||
|
except Exception:
|
||||||
|
log.exception('Error while validating')
|
||||||
|
raise Exception(f'Error while validating: {reqjson}')
|
||||||
|
|
||||||
|
if not valid:
|
||||||
errs = validator.errors
|
errs = validator.errors
|
||||||
log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
|
log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
|
||||||
|
|
||||||
|
|
@ -146,16 +185,55 @@ USER_UPDATE = {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PARTIAL_ROLE_GUILD_CREATE = {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': {
|
||||||
|
'name': {'type': 'role_name'},
|
||||||
|
'color': {'type': 'number', 'default': 0},
|
||||||
|
'hoist': {'type': 'boolean', 'default': False},
|
||||||
|
|
||||||
|
# NOTE: no position on partial role (on guild create)
|
||||||
|
|
||||||
|
'permissions': {'coerce': Permissions, 'required': False},
|
||||||
|
'mentionable': {'type': 'boolean', 'default': False},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
PARTIAL_CHANNEL_GUILD_CREATE = {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': {
|
||||||
|
'name': {'type': 'channel_name'},
|
||||||
|
'type': {'type': 'channel_type'},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GUILD_CREATE = {
|
||||||
|
'name': {'type': 'guild_name'},
|
||||||
|
'region': {'type': 'voice_region'},
|
||||||
|
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
||||||
|
|
||||||
|
'verification_level': {
|
||||||
|
'type': 'verification_level', 'default': 0},
|
||||||
|
'default_message_notifications': {
|
||||||
|
'type': 'msg_notifications', 'default': 0},
|
||||||
|
'explicit_content_filter': {
|
||||||
|
'type': 'explicit', 'default': 0},
|
||||||
|
|
||||||
|
'roles': {
|
||||||
|
'type': 'list', 'required': False,
|
||||||
|
'schema': PARTIAL_ROLE_GUILD_CREATE},
|
||||||
|
'channels': {
|
||||||
|
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
GUILD_UPDATE = {
|
GUILD_UPDATE = {
|
||||||
'name': {
|
'name': {
|
||||||
'type': 'string',
|
'type': 'guild_name',
|
||||||
'minlength': 2,
|
|
||||||
'maxlength': 100,
|
|
||||||
'required': False
|
'required': False
|
||||||
},
|
},
|
||||||
'region': {'type': 'voice_region', 'required': False},
|
'region': {'type': 'voice_region', 'required': False},
|
||||||
'icon': {'type': 'icon', 'required': False},
|
'icon': {'type': 'b64_icon', 'required': False},
|
||||||
|
|
||||||
'verification_level': {'type': 'verification_level', 'required': False},
|
'verification_level': {'type': 'verification_level', 'required': False},
|
||||||
'default_message_notifications': {
|
'default_message_notifications': {
|
||||||
|
|
@ -173,13 +251,93 @@ GUILD_UPDATE = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CHAN_OVERWRITE = {
|
||||||
|
'id': {'coerce': int},
|
||||||
|
'type': {'type': 'string', 'allowed': ['role', 'member']},
|
||||||
|
'allow': {'coerce': Permissions},
|
||||||
|
'deny': {'coerce': Permissions}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
CHAN_UPDATE = {
|
||||||
|
'name': {
|
||||||
|
'type': 'string', 'minlength': 2,
|
||||||
|
'maxlength': 100, 'required': False},
|
||||||
|
|
||||||
|
'position': {'coerce': int, 'required': False},
|
||||||
|
|
||||||
|
'topic': {
|
||||||
|
'type': 'string', 'minlength': 0,
|
||||||
|
'maxlength': 1024, 'required': False},
|
||||||
|
|
||||||
|
'nsfw': {'type': 'boolean', 'required': False},
|
||||||
|
'rate_limit_per_user': {
|
||||||
|
'coerce': int, 'min': 0,
|
||||||
|
'max': 120, 'required': False},
|
||||||
|
|
||||||
|
'bitrate': {
|
||||||
|
'coerce': int, 'min': 8000,
|
||||||
|
|
||||||
|
# NOTE: 'max' is 96000 for non-vip guilds
|
||||||
|
'max': 128000, 'required': False},
|
||||||
|
|
||||||
|
'user_limit': {
|
||||||
|
# user_limit being 0 means infinite.
|
||||||
|
'coerce': int, 'min': 0,
|
||||||
|
'max': 99, 'required': False
|
||||||
|
},
|
||||||
|
|
||||||
|
'permission_overwrites': {
|
||||||
|
'type': 'list',
|
||||||
|
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE},
|
||||||
|
'required': False
|
||||||
|
},
|
||||||
|
|
||||||
|
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ROLE_CREATE = {
|
||||||
|
'name': {'type': 'string', 'default': 'new role'},
|
||||||
|
'permissions': {'coerce': Permissions, 'nullable': True},
|
||||||
|
'color': {'coerce': Color, 'default': 0},
|
||||||
|
'hoist': {'type': 'boolean', 'default': False},
|
||||||
|
'mentionable': {'type': 'boolean', 'default': False},
|
||||||
|
}
|
||||||
|
|
||||||
|
ROLE_UPDATE = {
|
||||||
|
'name': {'type': 'string', 'required': False},
|
||||||
|
'permissions': {'coerce': Permissions, 'required': False},
|
||||||
|
'color': {'coerce': Color, 'required': False},
|
||||||
|
'hoist': {'type': 'boolean', 'required': False},
|
||||||
|
'mentionable': {'type': 'boolean', 'required': False},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ROLE_UPDATE_POSITION = {
|
||||||
|
'roles': {
|
||||||
|
'type': 'list',
|
||||||
|
'schema': {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': {
|
||||||
|
'id': {'coerce': int},
|
||||||
|
'position': {'coerce': int},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
MEMBER_UPDATE = {
|
MEMBER_UPDATE = {
|
||||||
'nick': {
|
'nick': {
|
||||||
'type': 'nickname',
|
'type': 'username',
|
||||||
'minlength': 1, 'maxlength': 100,
|
'minlength': 1, 'maxlength': 100,
|
||||||
'required': False,
|
'required': False,
|
||||||
},
|
},
|
||||||
'roles': {'type': 'list', 'required': False},
|
'roles': {'type': 'list', 'required': False,
|
||||||
|
'schema': {'coerce': int}},
|
||||||
'mute': {'type': 'boolean', 'required': False},
|
'mute': {'type': 'boolean', 'required': False},
|
||||||
'deaf': {'type': 'boolean', 'required': False},
|
'deaf': {'type': 'boolean', 'required': False},
|
||||||
'channel_id': {'type': 'snowflake', 'required': False},
|
'channel_id': {'type': 'snowflake', 'required': False},
|
||||||
|
|
@ -196,57 +354,60 @@ MESSAGE_CREATE = {
|
||||||
|
|
||||||
|
|
||||||
GW_ACTIVITY = {
|
GW_ACTIVITY = {
|
||||||
'name': {'type': 'string', 'required': True},
|
'type': 'dict',
|
||||||
'type': {'type': 'activity_type', 'required': True},
|
'schema': {
|
||||||
|
'name': {'type': 'string', 'required': True},
|
||||||
|
'type': {'type': 'activity_type', 'required': True},
|
||||||
|
|
||||||
'url': {'type': 'string', 'required': False, 'nullable': True},
|
'url': {'type': 'string', 'required': False, 'nullable': True},
|
||||||
|
|
||||||
'timestamps': {
|
'timestamps': {
|
||||||
'type': 'dict',
|
'type': 'dict',
|
||||||
'required': False,
|
'required': False,
|
||||||
'schema': {
|
'schema': {
|
||||||
'start': {'type': 'number', 'required': True},
|
'start': {'type': 'number', 'required': True},
|
||||||
'end': {'type': 'number', 'required': True},
|
'end': {'type': 'number', 'required': False},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
|
||||||
|
|
||||||
'application_id': {'type': 'snowflake', 'required': False,
|
'application_id': {'type': 'snowflake', 'required': False,
|
||||||
'nullable': False},
|
'nullable': False},
|
||||||
'details': {'type': 'string', 'required': False, 'nullable': True},
|
'details': {'type': 'string', 'required': False, 'nullable': True},
|
||||||
'state': {'type': 'string', 'required': False, 'nullable': True},
|
'state': {'type': 'string', 'required': False, 'nullable': True},
|
||||||
|
|
||||||
'party': {
|
'party': {
|
||||||
'type': 'dict',
|
'type': 'dict',
|
||||||
'required': False,
|
'required': False,
|
||||||
'schema': {
|
'schema': {
|
||||||
'id': {'type': 'snowflake', 'required': False},
|
'id': {'type': 'snowflake', 'required': False},
|
||||||
'size': {'type': 'list', 'required': False},
|
'size': {'type': 'list', 'required': False},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
'assets': {
|
'assets': {
|
||||||
'type': 'dict',
|
'type': 'dict',
|
||||||
'required': False,
|
'required': False,
|
||||||
'schema': {
|
'schema': {
|
||||||
'large_image': {'type': 'snowflake', 'required': False},
|
'large_image': {'type': 'snowflake', 'required': False},
|
||||||
'large_text': {'type': 'string', 'required': False},
|
'large_text': {'type': 'string', 'required': False},
|
||||||
'small_image': {'type': 'snowflake', 'required': False},
|
'small_image': {'type': 'snowflake', 'required': False},
|
||||||
'small_text': {'type': 'string', 'required': False},
|
'small_text': {'type': 'string', 'required': False},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
'secrets': {
|
'secrets': {
|
||||||
'type': 'dict',
|
'type': 'dict',
|
||||||
'required': False,
|
'required': False,
|
||||||
'schema': {
|
'schema': {
|
||||||
'join': {'type': 'string', 'required': False},
|
'join': {'type': 'string', 'required': False},
|
||||||
'spectate': {'type': 'string', 'required': False},
|
'spectate': {'type': 'string', 'required': False},
|
||||||
'match': {'type': 'string', 'required': False},
|
'match': {'type': 'string', 'required': False},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
'instance': {'type': 'boolean', 'required': False},
|
'instance': {'type': 'boolean', 'required': False},
|
||||||
'flags': {'type': 'number', 'required': False},
|
'flags': {'type': 'number', 'required': False},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GW_STATUS_UPDATE = {
|
GW_STATUS_UPDATE = {
|
||||||
|
|
@ -335,6 +496,8 @@ USER_SETTINGS = {
|
||||||
'show_current_game': {'type': 'boolean', 'required': False},
|
'show_current_game': {'type': 'boolean', 'required': False},
|
||||||
|
|
||||||
'timezone_offset': {'type': 'number', 'required': False},
|
'timezone_offset': {'type': 'number', 'required': False},
|
||||||
|
|
||||||
|
'status': {'type': 'status_external', 'required': False}
|
||||||
}
|
}
|
||||||
|
|
||||||
RELATIONSHIP = {
|
RELATIONSHIP = {
|
||||||
|
|
@ -395,3 +558,7 @@ GUILD_SETTINGS = {
|
||||||
'required': False,
|
'required': False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GUILD_PRUNE = {
|
||||||
|
'days': {'type': 'number', 'coerce': int, 'min': 1}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,9 @@ from logbook import Logger
|
||||||
|
|
||||||
from .enums import ChannelType, RelationshipType
|
from .enums import ChannelType, RelationshipType
|
||||||
from .schemas import USER_MENTION, ROLE_MENTION
|
from .schemas import USER_MENTION, ROLE_MENTION
|
||||||
|
from litecord.blueprints.channel.reactions import (
|
||||||
|
emoji_info_from_str, EmojiType, emoji_sql, partial_emoji
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -163,17 +166,40 @@ class Storage:
|
||||||
WHERE guild_id = $1 and user_id = $2
|
WHERE guild_id = $1 and user_id = $2
|
||||||
""", guild_id, member_id)
|
""", guild_id, member_id)
|
||||||
|
|
||||||
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
|
async def get_member_role_ids(self, guild_id: int,
|
||||||
members_roles = await self.db.fetch("""
|
member_id: int) -> List[int]:
|
||||||
|
"""Get a list of role IDs that are on a member."""
|
||||||
|
roles = await self.db.fetch("""
|
||||||
SELECT role_id::text
|
SELECT role_id::text
|
||||||
FROM member_roles
|
FROM member_roles
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, member_id)
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
roles = [r['role_id'] for r in roles]
|
||||||
|
|
||||||
|
try:
|
||||||
|
roles.remove(str(guild_id))
|
||||||
|
except ValueError:
|
||||||
|
# if the @everyone role isn't in, we add it
|
||||||
|
# to member_roles automatically (it won't
|
||||||
|
# be shown on the API, though).
|
||||||
|
await self.db.execute("""
|
||||||
|
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""", member_id, guild_id, guild_id)
|
||||||
|
|
||||||
|
return list(map(str, roles))
|
||||||
|
|
||||||
|
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
|
||||||
|
roles = await self.get_member_role_ids(guild_id, member_id)
|
||||||
return {
|
return {
|
||||||
'user': await self.get_user(member_id),
|
'user': await self.get_user(member_id),
|
||||||
'nick': row['nickname'],
|
'nick': row['nickname'],
|
||||||
'roles': [row[0] for row in members_roles],
|
|
||||||
|
# we don't send the @everyone role's id to
|
||||||
|
# the user since it is known that everyone has
|
||||||
|
# that role.
|
||||||
|
'roles': roles,
|
||||||
'joined_at': row['joined_at'].isoformat(),
|
'joined_at': row['joined_at'].isoformat(),
|
||||||
'deaf': row['deafened'],
|
'deaf': row['deafened'],
|
||||||
'mute': row['muted'],
|
'mute': row['muted'],
|
||||||
|
|
@ -289,7 +315,7 @@ class Storage:
|
||||||
WHERE channels.id = $1
|
WHERE channels.id = $1
|
||||||
""", channel_id)
|
""", channel_id)
|
||||||
|
|
||||||
async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
|
async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
|
||||||
overwrite_rows = await self.db.fetch("""
|
overwrite_rows = await self.db.fetch("""
|
||||||
SELECT target_type, target_role, target_user, allow, deny
|
SELECT target_type, target_role, target_user, allow, deny
|
||||||
FROM channel_overwrites
|
FROM channel_overwrites
|
||||||
|
|
@ -298,18 +324,20 @@ class Storage:
|
||||||
|
|
||||||
def _overwrite_convert(row):
|
def _overwrite_convert(row):
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
drow['type'] = drow['target_type']
|
|
||||||
|
target_type = drow['target_type']
|
||||||
|
drow['type'] = 'user' if target_type == 0 else 'role'
|
||||||
|
|
||||||
# if type is 0, the overwrite is for a user
|
# if type is 0, the overwrite is for a user
|
||||||
# if type is 1, the overwrite is for a role
|
# if type is 1, the overwrite is for a role
|
||||||
drow['id'] = {
|
drow['id'] = {
|
||||||
0: drow['target_user'],
|
0: drow['target_user'],
|
||||||
1: drow['target_role'],
|
1: drow['target_role'],
|
||||||
}[drow['type']]
|
}[target_type]
|
||||||
|
|
||||||
drow['id'] = str(drow['id'])
|
drow['id'] = str(drow['id'])
|
||||||
|
|
||||||
drow.pop('overwrite_type')
|
drow.pop('target_type')
|
||||||
drow.pop('target_user')
|
drow.pop('target_user')
|
||||||
drow.pop('target_role')
|
drow.pop('target_role')
|
||||||
|
|
||||||
|
|
@ -335,8 +363,8 @@ class Storage:
|
||||||
dbase['type'] = chan_type
|
dbase['type'] = chan_type
|
||||||
|
|
||||||
res = await self._channels_extra(dbase)
|
res = await self._channels_extra(dbase)
|
||||||
res['permission_overwrites'] = \
|
res['permission_overwrites'] = await self.chan_overwrites(
|
||||||
list(await self._chan_overwrites(channel_id))
|
channel_id)
|
||||||
|
|
||||||
res['id'] = str(res['id'])
|
res['id'] = str(res['id'])
|
||||||
return res
|
return res
|
||||||
|
|
@ -401,8 +429,8 @@ class Storage:
|
||||||
|
|
||||||
res = await self._channels_extra(drow)
|
res = await self._channels_extra(drow)
|
||||||
|
|
||||||
res['permission_overwrites'] = \
|
res['permission_overwrites'] = await self.chan_overwrites(
|
||||||
list(await self._chan_overwrites(row['id']))
|
row['id'])
|
||||||
|
|
||||||
# Making sure.
|
# Making sure.
|
||||||
res['id'] = str(res['id'])
|
res['id'] = str(res['id'])
|
||||||
|
|
@ -440,6 +468,7 @@ class Storage:
|
||||||
permissions, managed, mentionable
|
permissions, managed, mentionable
|
||||||
FROM roles
|
FROM roles
|
||||||
WHERE guild_id = $1
|
WHERE guild_id = $1
|
||||||
|
ORDER BY position ASC
|
||||||
""", guild_id)
|
""", guild_id)
|
||||||
|
|
||||||
return list(map(dict, roledata))
|
return list(map(dict, roledata))
|
||||||
|
|
@ -535,7 +564,70 @@ class Storage:
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_message(self, message_id: int) -> Dict:
|
async def get_reactions(self, message_id: int, user_id=None) -> List:
|
||||||
|
"""Get all reactions in a message."""
|
||||||
|
reactions = await self.db.fetch("""
|
||||||
|
SELECT user_id, emoji_type, emoji_id, emoji_text
|
||||||
|
FROM message_reactions
|
||||||
|
WHERE message_id = $1
|
||||||
|
ORDER BY react_ts
|
||||||
|
""", message_id)
|
||||||
|
|
||||||
|
# ordered list of emoji
|
||||||
|
emoji = []
|
||||||
|
|
||||||
|
# the current state of emoji info
|
||||||
|
react_stats = {}
|
||||||
|
|
||||||
|
# to generate the list, we pass through all
|
||||||
|
# all reactions and insert them all.
|
||||||
|
|
||||||
|
# we can't use a set() because that
|
||||||
|
# doesn't guarantee any order.
|
||||||
|
for row in reactions:
|
||||||
|
etype = EmojiType(row['emoji_type'])
|
||||||
|
eid, etext = row['emoji_id'], row['emoji_text']
|
||||||
|
|
||||||
|
# get the main key to use, given
|
||||||
|
# the emoji information
|
||||||
|
_, main_emoji = emoji_sql(etype, eid, etext)
|
||||||
|
|
||||||
|
if main_emoji in emoji:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# maintain order (first reacted comes first
|
||||||
|
# on the reaction list)
|
||||||
|
emoji.append(main_emoji)
|
||||||
|
|
||||||
|
react_stats[main_emoji] = {
|
||||||
|
'count': 0,
|
||||||
|
'me': False,
|
||||||
|
'emoji': partial_emoji(etype, eid, etext)
|
||||||
|
}
|
||||||
|
|
||||||
|
# then the 2nd pass, where we insert
|
||||||
|
# the info for each reaction in the react_stats
|
||||||
|
# dictionary
|
||||||
|
for row in reactions:
|
||||||
|
etype = EmojiType(row['emoji_type'])
|
||||||
|
eid, etext = row['emoji_id'], row['emoji_text']
|
||||||
|
|
||||||
|
# same thing as the last loop,
|
||||||
|
# extracting main key
|
||||||
|
_, main_emoji = emoji_sql(etype, eid, etext)
|
||||||
|
|
||||||
|
stats = react_stats[main_emoji]
|
||||||
|
stats['count'] += 1
|
||||||
|
|
||||||
|
if row['user_id'] == user_id:
|
||||||
|
stats['me'] = True
|
||||||
|
|
||||||
|
# after processing reaction counts,
|
||||||
|
# we get them in the same order
|
||||||
|
# they were defined in the first loop.
|
||||||
|
return list(map(react_stats.get, emoji))
|
||||||
|
|
||||||
|
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
||||||
"""Get a single message's payload."""
|
"""Get a single message's payload."""
|
||||||
row = await self.db.fetchrow("""
|
row = await self.db.fetchrow("""
|
||||||
SELECT id::text, channel_id::text, author_id, content,
|
SELECT id::text, channel_id::text, author_id, content,
|
||||||
|
|
@ -596,6 +688,8 @@ class Storage:
|
||||||
res['mention_roles'] = await self._msg_regex(
|
res['mention_roles'] = await self._msg_regex(
|
||||||
ROLE_MENTION, _get_role_mention, content)
|
ROLE_MENTION, _get_role_mention, content)
|
||||||
|
|
||||||
|
res['reactions'] = await self.get_reactions(message_id, user_id)
|
||||||
|
|
||||||
# TODO: handle webhook authors
|
# TODO: handle webhook authors
|
||||||
res['author'] = await self.get_user(res['author_id'])
|
res['author'] = await self.get_user(res['author_id'])
|
||||||
res.pop('author_id')
|
res.pop('author_id')
|
||||||
|
|
@ -606,9 +700,6 @@ class Storage:
|
||||||
# TODO: res['embeds']
|
# TODO: res['embeds']
|
||||||
res['embeds'] = []
|
res['embeds'] = []
|
||||||
|
|
||||||
# TODO: res['reactions']
|
|
||||||
res['reactions'] = []
|
|
||||||
|
|
||||||
# TODO: res['pinned']
|
# TODO: res['pinned']
|
||||||
res['pinned'] = False
|
res['pinned'] = False
|
||||||
|
|
||||||
|
|
@ -966,7 +1057,6 @@ class Storage:
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
||||||
for row in settings:
|
for row in settings:
|
||||||
print(dict(row))
|
|
||||||
gid = int(row['guild_id'])
|
gid = int(row['guild_id'])
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
|
||||||
|
class Color:
|
||||||
|
"""Custom color class"""
|
||||||
|
def __init__(self, val: int):
|
||||||
|
self.blue = val & 255
|
||||||
|
self.green = (val >> 8) & 255
|
||||||
|
self.red = (val >> 16) & 255
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
"""Give the actual RGB integer encoding this color."""
|
||||||
|
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16)
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.value
|
||||||
|
|
@ -22,3 +22,18 @@ async def task_wrapper(name: str, coro):
|
||||||
pass
|
pass
|
||||||
except:
|
except:
|
||||||
log.exception('{} task error', name)
|
log.exception('{} task error', name)
|
||||||
|
|
||||||
|
|
||||||
|
def dict_get(mapping, key, default):
|
||||||
|
"""Return `default` even when mapping[key] is None."""
|
||||||
|
return mapping.get(key) or default
|
||||||
|
|
||||||
|
|
||||||
|
def index_by_func(function, indexable: iter) -> int:
|
||||||
|
"""Search in an idexable and return the index number
|
||||||
|
for an iterm that has func(item) = True."""
|
||||||
|
for index, item in enumerate(indexable):
|
||||||
|
if function(item):
|
||||||
|
return index
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from manage.main import main
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
sys.exit(main(config))
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from .command import setup as migration
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
import inspect
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import asyncpg
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
Migration = namedtuple('Migration', 'id name path')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MigrationContext:
|
||||||
|
"""Hold information about migration."""
|
||||||
|
migration_folder: Path
|
||||||
|
scripts: Dict[int, Migration]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def latest(self):
|
||||||
|
"""Return the latest migration ID."""
|
||||||
|
return max(self.scripts.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def make_migration_ctx() -> MigrationContext:
|
||||||
|
"""Create the MigrationContext instance."""
|
||||||
|
# taken from https://stackoverflow.com/a/6628348
|
||||||
|
script_path = inspect.stack()[0][1]
|
||||||
|
script_folder = '/'.join(script_path.split('/')[:-1])
|
||||||
|
script_folder = Path(script_folder)
|
||||||
|
|
||||||
|
migration_folder = script_folder / 'scripts'
|
||||||
|
|
||||||
|
mctx = MigrationContext(migration_folder, {})
|
||||||
|
|
||||||
|
for mig_path in migration_folder.glob('*.sql'):
|
||||||
|
mig_path_str = str(mig_path)
|
||||||
|
|
||||||
|
# extract migration script id and name
|
||||||
|
mig_filename = mig_path_str.split('/')[-1].split('.')[0]
|
||||||
|
name_fragments = mig_filename.split('_')
|
||||||
|
|
||||||
|
mig_id = int(name_fragments[0])
|
||||||
|
mig_name = '_'.join(name_fragments[1:])
|
||||||
|
|
||||||
|
mctx.scripts[mig_id] = Migration(
|
||||||
|
mig_id, mig_name, mig_path)
|
||||||
|
|
||||||
|
return mctx
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_changelog(app, ctx):
|
||||||
|
# make sure we have the migration table up
|
||||||
|
|
||||||
|
try:
|
||||||
|
await app.db.execute("""
|
||||||
|
CREATE TABLE migration_log (
|
||||||
|
change_num bigint NOT NULL,
|
||||||
|
|
||||||
|
apply_ts timestamp without time zone default
|
||||||
|
(now() at time zone 'utc'),
|
||||||
|
|
||||||
|
description text,
|
||||||
|
|
||||||
|
PRIMARY KEY (change_num)
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
# if we were able to create the
|
||||||
|
# migration_log table, insert that we are
|
||||||
|
# on the latest version.
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO migration_log (change_num, description)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", ctx.latest, 'migration setup')
|
||||||
|
except asyncpg.DuplicateTableError:
|
||||||
|
log.debug('existing migration table')
|
||||||
|
|
||||||
|
|
||||||
|
async def apply_migration(app, migration: Migration):
|
||||||
|
"""Apply a single migration."""
|
||||||
|
migration_sql = migration.path.read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
try:
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO migration_log (change_num, description)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", migration.id, f'migration: {migration.name}')
|
||||||
|
except asyncpg.UniqueViolationError:
|
||||||
|
log.warning('already applied {}', migration.id)
|
||||||
|
return
|
||||||
|
|
||||||
|
await app.db.execute(migration_sql)
|
||||||
|
log.info('applied {}', migration.id)
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_cmd(app, _args):
|
||||||
|
"""Main migration command.
|
||||||
|
|
||||||
|
This makes sure the database
|
||||||
|
is updated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ctx = make_migration_ctx()
|
||||||
|
|
||||||
|
await _ensure_changelog(app, ctx)
|
||||||
|
|
||||||
|
# local point in the changelog
|
||||||
|
local_change = await app.db.fetchval("""
|
||||||
|
SELECT max(change_num)
|
||||||
|
FROM migration_log
|
||||||
|
""")
|
||||||
|
|
||||||
|
local_change = local_change or 0
|
||||||
|
latest_change = ctx.latest
|
||||||
|
|
||||||
|
log.debug('local: {}, latest: {}', local_change, latest_change)
|
||||||
|
|
||||||
|
if local_change == latest_change:
|
||||||
|
print('no changes to do, exiting')
|
||||||
|
return
|
||||||
|
|
||||||
|
# we do local_change + 1 so we start from the
|
||||||
|
# next migration to do, end in latest_change + 1
|
||||||
|
# because of how range() works.
|
||||||
|
for idx in range(local_change + 1, latest_change + 1):
|
||||||
|
migration = ctx.scripts.get(idx)
|
||||||
|
|
||||||
|
print('applying', migration.id, migration.name)
|
||||||
|
await apply_migration(app, migration)
|
||||||
|
|
||||||
|
|
||||||
|
def setup(subparser):
|
||||||
|
migrate_parser = subparser.add_parser(
|
||||||
|
'migrate',
|
||||||
|
help='Run migration tasks',
|
||||||
|
description=migrate_cmd.__doc__
|
||||||
|
)
|
||||||
|
|
||||||
|
migrate_parser.set_defaults(func=migrate_cmd)
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
-- unused tables
|
||||||
|
DROP TABLE message_embeds;
|
||||||
|
DROP TABLE embeds;
|
||||||
|
|
||||||
|
ALTER TABLE messages
|
||||||
|
ADD COLUMN embeds jsonb DEFAULT '[]'
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
from sys import argv
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
from run import init_app_managers, init_app_db
|
||||||
|
from manage.cmd.migration import migration
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeApp:
|
||||||
|
"""Fake app instance."""
|
||||||
|
config: dict
|
||||||
|
db = None
|
||||||
|
loop: asyncio.BaseEventLoop = None
|
||||||
|
ratelimiter = None
|
||||||
|
state_manager = None
|
||||||
|
storage = None
|
||||||
|
dispatcher = None
|
||||||
|
presence = None
|
||||||
|
|
||||||
|
|
||||||
|
def init_parser():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
subparser = parser.add_subparsers(help='operations')
|
||||||
|
|
||||||
|
migration(subparser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main(config):
|
||||||
|
"""Start the script"""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
cfg = getattr(config, config.MODE)
|
||||||
|
app = FakeApp(cfg.__dict__)
|
||||||
|
|
||||||
|
loop.run_until_complete(init_app_db(app))
|
||||||
|
init_app_managers(app)
|
||||||
|
|
||||||
|
# initialize argparser
|
||||||
|
parser = init_parser()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if len(argv) < 2:
|
||||||
|
parser.print_help()
|
||||||
|
return
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
loop.run_until_complete(args.func(app, args))
|
||||||
|
except Exception:
|
||||||
|
log.exception('error while running command')
|
||||||
|
finally:
|
||||||
|
loop.run_until_complete(app.db.close())
|
||||||
10
nginx.conf
10
nginx.conf
|
|
@ -5,13 +5,11 @@ server {
|
||||||
location / {
|
location / {
|
||||||
proxy_pass http://localhost:5000;
|
proxy_pass http://localhost:5000;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
# Main litecord websocket proxy.
|
# if you don't want to keep the gateway
|
||||||
server {
|
# domain as the main domain, you can
|
||||||
server_name websocket.somewhere;
|
# keep a separate server block
|
||||||
|
location /ws {
|
||||||
location / {
|
|
||||||
proxy_pass http://localhost:5001;
|
proxy_pass http://localhost:5001;
|
||||||
|
|
||||||
# those options are required for websockets
|
# those options are required for websockets
|
||||||
|
|
|
||||||
106
run.py
106
run.py
|
|
@ -9,9 +9,28 @@ from quart import Quart, g, jsonify, request
|
||||||
from logbook import StreamHandler, Logger
|
from logbook import StreamHandler, Logger
|
||||||
from logbook.compat import redirect_logging
|
from logbook.compat import redirect_logging
|
||||||
|
|
||||||
|
# import the config set by instance owner
|
||||||
import config
|
import config
|
||||||
from litecord.blueprints import gateway, auth, users, guilds, channels, \
|
|
||||||
webhooks, science, voice, invites, relationships, dms
|
from litecord.blueprints import (
|
||||||
|
gateway, auth, users, guilds, channels, webhooks, science,
|
||||||
|
voice, invites, relationships, dms
|
||||||
|
)
|
||||||
|
|
||||||
|
# those blueprints are separated from the "main" ones
|
||||||
|
# for code readability if people want to dig through
|
||||||
|
# the codebase.
|
||||||
|
from litecord.blueprints.guild import (
|
||||||
|
guild_roles, guild_members, guild_channels, guild_mod
|
||||||
|
)
|
||||||
|
|
||||||
|
from litecord.blueprints.channel import (
|
||||||
|
channel_messages, channel_reactions, channel_pins
|
||||||
|
)
|
||||||
|
|
||||||
|
from litecord.ratelimits.handler import ratelimit_handler
|
||||||
|
from litecord.ratelimits.main import RatelimitManager
|
||||||
|
|
||||||
from litecord.gateway import websocket_handler
|
from litecord.gateway import websocket_handler
|
||||||
from litecord.errors import LitecordError
|
from litecord.errors import LitecordError
|
||||||
from litecord.gateway.state_manager import StateManager
|
from litecord.gateway.state_manager import StateManager
|
||||||
|
|
@ -50,8 +69,18 @@ bps = {
|
||||||
auth: '/auth',
|
auth: '/auth',
|
||||||
users: '/users',
|
users: '/users',
|
||||||
relationships: '/users',
|
relationships: '/users',
|
||||||
|
|
||||||
guilds: '/guilds',
|
guilds: '/guilds',
|
||||||
|
guild_roles: '/guilds',
|
||||||
|
guild_members: '/guilds',
|
||||||
|
guild_channels: '/guilds',
|
||||||
|
guild_mod: '/guilds',
|
||||||
|
|
||||||
channels: '/channels',
|
channels: '/channels',
|
||||||
|
channel_messages: '/channels',
|
||||||
|
channel_reactions: '/channels',
|
||||||
|
channel_pins: '/channels',
|
||||||
|
|
||||||
webhooks: None,
|
webhooks: None,
|
||||||
science: None,
|
science: None,
|
||||||
voice: '/voice',
|
voice: '/voice',
|
||||||
|
|
@ -64,6 +93,11 @@ for bp, suffix in bps.items():
|
||||||
app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}')
|
app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}')
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_request
|
||||||
|
async def app_before_request():
|
||||||
|
await ratelimit_handler()
|
||||||
|
|
||||||
|
|
||||||
@app.after_request
|
@app.after_request
|
||||||
async def app_after_request(resp):
|
async def app_after_request(resp):
|
||||||
origin = request.headers.get('Origin', '*')
|
origin = request.headers.get('Origin', '*')
|
||||||
|
|
@ -80,19 +114,44 @@ async def app_after_request(resp):
|
||||||
# resp.headers['Access-Control-Allow-Methods'] = '*'
|
# resp.headers['Access-Control-Allow-Methods'] = '*'
|
||||||
resp.headers['Access-Control-Allow-Methods'] = \
|
resp.headers['Access-Control-Allow-Methods'] = \
|
||||||
resp.headers.get('allow', '*')
|
resp.headers.get('allow', '*')
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@app.before_serving
|
@app.after_request
|
||||||
async def app_before_serving():
|
async def app_set_ratelimit_headers(resp):
|
||||||
log.info('opening db')
|
"""Set the specific ratelimit headers."""
|
||||||
|
try:
|
||||||
|
bucket = request.bucket
|
||||||
|
|
||||||
|
if bucket is None:
|
||||||
|
raise AttributeError()
|
||||||
|
|
||||||
|
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
|
||||||
|
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
|
||||||
|
resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second)
|
||||||
|
|
||||||
|
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower()
|
||||||
|
|
||||||
|
# only add Retry-After if we actually hit a ratelimit
|
||||||
|
retry_after = request.retry_after
|
||||||
|
if request.retry_after:
|
||||||
|
resp.headers['Retry-After'] = str(retry_after)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
async def init_app_db(app):
|
||||||
|
"""Connect to databases"""
|
||||||
app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
||||||
|
|
||||||
g.app = app
|
|
||||||
|
|
||||||
|
def init_app_managers(app):
|
||||||
|
"""Initialize singleton classes."""
|
||||||
app.loop = asyncio.get_event_loop()
|
app.loop = asyncio.get_event_loop()
|
||||||
g.loop = asyncio.get_event_loop()
|
app.ratelimiter = RatelimitManager()
|
||||||
|
|
||||||
app.state_manager = StateManager()
|
app.state_manager = StateManager()
|
||||||
app.storage = Storage(app.db)
|
app.storage = Storage(app.db)
|
||||||
|
|
||||||
|
|
@ -101,6 +160,17 @@ async def app_before_serving():
|
||||||
app.state_manager, app.dispatcher)
|
app.state_manager, app.dispatcher)
|
||||||
app.storage.presence = app.presence
|
app.storage.presence = app.presence
|
||||||
|
|
||||||
|
|
||||||
|
@app.before_serving
|
||||||
|
async def app_before_serving():
|
||||||
|
log.info('opening db')
|
||||||
|
await init_app_db(app)
|
||||||
|
|
||||||
|
g.app = app
|
||||||
|
g.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
init_app_managers(app)
|
||||||
|
|
||||||
# start the websocket, etc
|
# start the websocket, etc
|
||||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||||
log.info(f'starting websocket at {host} {port}')
|
log.info(f'starting websocket at {host} {port}')
|
||||||
|
|
@ -108,8 +178,11 @@ async def app_before_serving():
|
||||||
async def _wrapper(ws, url):
|
async def _wrapper(ws, url):
|
||||||
# We wrap the main websocket_handler
|
# We wrap the main websocket_handler
|
||||||
# so we can pass quart's app object.
|
# so we can pass quart's app object.
|
||||||
|
|
||||||
|
# TODO: pass just the app object
|
||||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
await websocket_handler((app.db, app.state_manager, app.storage,
|
||||||
app.loop, app.dispatcher, app.presence),
|
app.loop, app.dispatcher, app.presence,
|
||||||
|
app.ratelimiter),
|
||||||
ws, url)
|
ws, url)
|
||||||
|
|
||||||
ws_future = websockets.serve(_wrapper, host, port)
|
ws_future = websockets.serve(_wrapper, host, port)
|
||||||
|
|
@ -119,6 +192,15 @@ async def app_before_serving():
|
||||||
|
|
||||||
@app.after_serving
|
@app.after_serving
|
||||||
async def app_after_serving():
|
async def app_after_serving():
|
||||||
|
"""Shutdown tasks for the server."""
|
||||||
|
|
||||||
|
# first close all clients, then close db
|
||||||
|
tasks = app.state_manager.gen_close_tasks()
|
||||||
|
if tasks:
|
||||||
|
await asyncio.wait(tasks, loop=app.loop)
|
||||||
|
|
||||||
|
app.state_manager.close()
|
||||||
|
|
||||||
log.info('closing db')
|
log.info('closing db')
|
||||||
await app.db.close()
|
await app.db.close()
|
||||||
|
|
||||||
|
|
@ -130,9 +212,13 @@ async def handle_litecord_err(err):
|
||||||
except IndexError:
|
except IndexError:
|
||||||
ejson = {}
|
ejson = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
ejson['code'] = err.error_code
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'error': True,
|
'error': True,
|
||||||
# 'code': err.code,
|
|
||||||
'status': err.status_code,
|
'status': err.status_code,
|
||||||
'message': err.message,
|
'message': err.message,
|
||||||
**ejson
|
**ejson
|
||||||
|
|
|
||||||
44
schema.sql
44
schema.sql
|
|
@ -75,6 +75,9 @@ CREATE TABLE IF NOT EXISTS users (
|
||||||
phone varchar(60) DEFAULT '',
|
phone varchar(60) DEFAULT '',
|
||||||
password_hash text NOT NULL,
|
password_hash text NOT NULL,
|
||||||
|
|
||||||
|
-- store the last time the user logged in via the gateway
|
||||||
|
last_session timestamp without time zone default (now() at time zone 'utc'),
|
||||||
|
|
||||||
PRIMARY KEY (id, username, discriminator)
|
PRIMARY KEY (id, username, discriminator)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -131,6 +134,10 @@ CREATE TABLE IF NOT EXISTS user_settings (
|
||||||
|
|
||||||
-- appearance
|
-- appearance
|
||||||
message_display_compact bool DEFAULT false,
|
message_display_compact bool DEFAULT false,
|
||||||
|
|
||||||
|
-- for now we store status but don't
|
||||||
|
-- actively use it, since the official client
|
||||||
|
-- sends its own presence on IDENTIFY
|
||||||
status text DEFAULT 'online' NOT NULL,
|
status text DEFAULT 'online' NOT NULL,
|
||||||
theme text DEFAULT 'dark' NOT NULL,
|
theme text DEFAULT 'dark' NOT NULL,
|
||||||
developer_mode bool DEFAULT true,
|
developer_mode bool DEFAULT true,
|
||||||
|
|
@ -328,7 +335,7 @@ CREATE TABLE IF NOT EXISTS channel_overwrites (
|
||||||
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||||
|
|
||||||
-- target_type = 0 -> use target_user
|
-- target_type = 0 -> use target_user
|
||||||
-- target_type = 1 -> user target_role
|
-- target_type = 1 -> use target_role
|
||||||
-- discord already has overwrite.type = 'role' | 'member'
|
-- discord already has overwrite.type = 'role' | 'member'
|
||||||
-- so this allows us to be more compliant with the API
|
-- so this allows us to be more compliant with the API
|
||||||
target_type integer default null,
|
target_type integer default null,
|
||||||
|
|
@ -344,11 +351,15 @@ CREATE TABLE IF NOT EXISTS channel_overwrites (
|
||||||
-- they're bigints (64bits), discord,
|
-- they're bigints (64bits), discord,
|
||||||
-- for now, only needs 53.
|
-- for now, only needs 53.
|
||||||
allow bigint DEFAULT 0,
|
allow bigint DEFAULT 0,
|
||||||
deny bigint DEFAULT 0,
|
deny bigint DEFAULT 0
|
||||||
|
|
||||||
PRIMARY KEY (channel_id, target_role, target_user)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- columns in private keys can't have NULL values,
|
||||||
|
-- so instead we use a custom constraint with UNIQUE
|
||||||
|
|
||||||
|
ALTER TABLE channel_overwrites ADD CONSTRAINT channel_overwrites_uniq
|
||||||
|
UNIQUE (channel_id, target_role, target_user);
|
||||||
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS features (
|
CREATE TABLE IF NOT EXISTS features (
|
||||||
id serial PRIMARY KEY,
|
id serial PRIMARY KEY,
|
||||||
|
|
@ -479,11 +490,6 @@ CREATE TABLE IF NOT EXISTS bans (
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS embeds (
|
|
||||||
-- TODO: this table
|
|
||||||
id bigint PRIMARY KEY
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
id bigint PRIMARY KEY,
|
id bigint PRIMARY KEY,
|
||||||
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||||
|
|
@ -504,6 +510,8 @@ CREATE TABLE IF NOT EXISTS messages (
|
||||||
tts bool default false,
|
tts bool default false,
|
||||||
mention_everyone bool default false,
|
mention_everyone bool default false,
|
||||||
|
|
||||||
|
embeds jsonb DEFAULT '[]',
|
||||||
|
|
||||||
nonce bigint default 0,
|
nonce bigint default 0,
|
||||||
|
|
||||||
message_type int NOT NULL
|
message_type int NOT NULL
|
||||||
|
|
@ -515,22 +523,22 @@ CREATE TABLE IF NOT EXISTS message_attachments (
|
||||||
PRIMARY KEY (message_id, attachment)
|
PRIMARY KEY (message_id, attachment)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS message_embeds (
|
|
||||||
message_id bigint REFERENCES messages (id) UNIQUE,
|
|
||||||
embed_id bigint REFERENCES embeds (id),
|
|
||||||
PRIMARY KEY (message_id, embed_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS message_reactions (
|
CREATE TABLE IF NOT EXISTS message_reactions (
|
||||||
message_id bigint REFERENCES messages (id),
|
message_id bigint REFERENCES messages (id),
|
||||||
user_id bigint REFERENCES users (id),
|
user_id bigint REFERENCES users (id),
|
||||||
|
|
||||||
-- since it can be a custom emote, or unicode emoji
|
react_ts timestamp without time zone default (now() at time zone 'utc'),
|
||||||
|
|
||||||
|
-- emoji_type = 0 -> custom emoji
|
||||||
|
-- emoji_type = 1 -> unicode emoji
|
||||||
|
emoji_type int DEFAULT 0,
|
||||||
emoji_id bigint REFERENCES guild_emoji (id),
|
emoji_id bigint REFERENCES guild_emoji (id),
|
||||||
emoji_text text NOT NULL,
|
emoji_text text
|
||||||
PRIMARY KEY (message_id, user_id, emoji_id, emoji_text)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq
|
||||||
|
UNIQUE (message_id, user_id, emoji_id, emoji_text);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS channel_pins (
|
CREATE TABLE IF NOT EXISTS channel_pins (
|
||||||
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||||
message_id bigint REFERENCES messages (id) ON DELETE CASCADE,
|
message_id bigint REFERENCES messages (id) ON DELETE CASCADE,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue