Merge branch 'lazy-guilds' into 'master'

Lazy guilds

Closes #12, #9, #8, and #2

See merge request luna/litecord!2
This commit is contained in:
Luna Mendes 2018-11-08 02:12:45 +00:00
commit 6245f08289
46 changed files with 4159 additions and 849 deletions

View File

@ -7,8 +7,16 @@ This project is a rewrite of [litecord-reference].
[litecord-reference]: https://gitlab.com/luna/litecord-reference [litecord-reference]: https://gitlab.com/luna/litecord-reference
## Notes
- There are no testing being run on the current codebase. Which means the code is definitely unstable.
- No voice is planned to be developed, for now.
- You must figure out connecting to the server yourself. Litecord will not distribute
Discord's official client code nor provide ways to modify the client.
## Install ## Install
Requirements:
- Python 3.6 or higher - Python 3.6 or higher
- PostgreSQL - PostgreSQL
- [Pipenv] - [Pipenv]
@ -28,6 +36,10 @@ $ psql -f schema.sql litecord
# edit config.py as you wish # edit config.py as you wish
$ cp config.example.py config.py $ cp config.example.py config.py
# run database migrations (this is a
# required step in setup)
$ pipenv run ./manage.py migrate
# Install all packages: # Install all packages:
$ pipenv install --dev $ pipenv install --dev
``` ```
@ -42,3 +54,10 @@ Use `--access-log -` to output access logs to stdout.
```sh ```sh
$ pipenv run hypercorn run:app $ pipenv run hypercorn run:app
``` ```
## Updating
```sh
$ git pull
$ pipenv run ./manage.py migrate
```

View File

@ -13,7 +13,11 @@ log = Logger(__name__)
async def raw_token_check(token, db=None): async def raw_token_check(token, db=None):
db = db or app.db db = db or app.db
user_id, _hmac = token.split('.')
# just try by fragments instead of
# unpacking
fragments = token.split('.')
user_id = fragments[0]
try: try:
user_id = base64.b64decode(user_id.encode()) user_id = base64.b64decode(user_id.encode())
@ -35,6 +39,17 @@ async def raw_token_check(token, db=None):
try: try:
signer.unsign(token) signer.unsign(token)
log.debug('login for uid {} successful', user_id) log.debug('login for uid {} successful', user_id)
# update the user's last_session field
# so that we can keep an exact track of activity,
# even on long-lived single sessions (that can happen
# with people leaving their clients open forever)
await db.execute("""
UPDATE users
SET last_session = (now() at time zone 'utc')
WHERE id = $1
""", user_id)
return user_id return user_id
except BadSignature: except BadSignature:
log.warning('token failed for uid {}', user_id) log.warning('token failed for uid {}', user_id)
@ -43,6 +58,12 @@ async def raw_token_check(token, db=None):
async def token_check(): async def token_check():
"""Check token information.""" """Check token information."""
# first, check if the request info already has a uid
try:
return request.user_id
except AttributeError:
pass
try: try:
token = request.headers['Authorization'] token = request.headers['Authorization']
except KeyError: except KeyError:
@ -51,4 +72,6 @@ async def token_check():
if token.startswith('Bot '): if token.startswith('Bot '):
token = token.replace('Bot ', '') token = token.replace('Bot ', '')
return await raw_token_check(token) user_id = await raw_token_check(token)
request.user_id = user_id
return user_id

View File

@ -65,7 +65,7 @@ async def register():
new_id = get_snowflake() new_id = get_snowflake()
new_discrim = str(random.randint(1, 9999)) new_discrim = random.randint(1, 9999)
new_discrim = '%04d' % new_discrim new_discrim = '%04d' % new_discrim
pwd_hash = await hash_data(password) pwd_hash = await hash_data(password)

View File

@ -0,0 +1,3 @@
from .messages import bp as channel_messages
from .reactions import bp as channel_reactions
from .pins import bp as channel_pins

View File

@ -0,0 +1,276 @@
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.dms import try_dm_state
from litecord.errors import MessageNotFound, Forbidden, BadRequest
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.snowflake import get_snowflake
from litecord.schemas import validate, MESSAGE_CREATE
log = Logger(__name__)
bp = Blueprint('channel_messages', __name__)
def extract_limit(request, default: int = 50):
try:
limit = int(request.args.get('limit', default))
if limit not in range(0, 100):
raise ValueError()
except (TypeError, ValueError):
raise BadRequest('limit not int')
return limit
def query_tuple_from_args(args: dict, limit: int) -> tuple:
before, after = None, None
if 'around' in request.args:
average = int(limit / 2)
around = int(request.args['around'])
after = around - average
before = around + average
elif 'before' in request.args:
before = int(request.args['before'])
elif 'after' in request.args:
before = int(request.args['after'])
return before, after
@bp.route('/<int:channel_id>/messages', methods=['GET'])
async def get_messages(channel_id):
user_id = await token_check()
# TODO: check READ_MESSAGE_HISTORY permission
ctype, peer_id = await channel_check(user_id, channel_id)
if ctype == ChannelType.DM:
# make sure both parties will be subbed
# to a dm
await _dm_pre_dispatch(channel_id, user_id)
await _dm_pre_dispatch(channel_id, peer_id)
limit = extract_limit(request, 50)
where_clause = ''
before, after = query_tuple_from_args(request.args, limit)
if before:
where_clause += f'AND id < {before}'
if after:
where_clause += f'AND id > {after}'
message_ids = await app.db.fetch(f"""
SELECT id
FROM messages
WHERE channel_id = $1 {where_clause}
ORDER BY id DESC
LIMIT {limit}
""", channel_id)
result = []
for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'], user_id)
if msg is None:
continue
result.append(msg)
log.info('Fetched {} messages', len(result))
return jsonify(result)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
async def get_single_message(channel_id, message_id):
user_id = await token_check()
await channel_check(user_id, channel_id)
# TODO: check READ_MESSAGE_HISTORY permissions
message = await app.storage.get_message(message_id, user_id)
if not message:
raise MessageNotFound()
return jsonify(message)
async def _dm_pre_dispatch(channel_id, peer_id):
"""Do some checks pre-MESSAGE_CREATE so we
make sure the receiving party will handle everything."""
# check the other party's dm_channel_state
dm_state = await app.db.fetchval("""
SELECT dm_id
FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2
""", peer_id, channel_id)
if dm_state:
# the peer already has the channel
# opened, so we don't need to do anything
return
dm_chan = await app.storage.get_channel(channel_id)
# dispatch CHANNEL_CREATE so the client knows which
# channel the future event is about
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
# subscribe the peer to the channel
await app.dispatcher.sub('channel', channel_id, peer_id)
# insert it on dm_channel_state so the client
# is subscribed on the future
await try_dm_state(peer_id, channel_id)
@bp.route('/<int:channel_id>/messages', methods=['POST'])
async def create_message(channel_id):
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
if ctype in GUILD_CHANS:
await channel_perm_check(user_id, channel_id, 'send_messages')
j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake()
# TODO: check connection to the gateway
mentions_everyone = ('@everyone' in j['content'] and
await channel_perm_check(
user_id, channel_id, 'mention_everyone', False
)
)
is_tts = (j.get('tts', False) and
await channel_perm_check(
user_id, channel_id, 'send_tts_messages', False
))
await app.db.execute(
"""
INSERT INTO messages (id, channel_id, author_id, content, tts,
mention_everyone, nonce, message_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
""",
message_id,
channel_id,
user_id,
j['content'],
is_tts,
mentions_everyone,
int(j.get('nonce', 0)),
MessageType.DEFAULT.value
)
payload = await app.storage.get_message(message_id, user_id)
if ctype == ChannelType.DM:
# guild id here is the peer's ID.
await _dm_pre_dispatch(channel_id, user_id)
await _dm_pre_dispatch(channel_id, guild_id)
await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_CREATE', payload)
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
if ctype == ChannelType.GUILD_TEXT:
for str_uid in payload['mentions']:
uid = int(str_uid)
await app.db.execute("""
UPDATE user_read_state
SET mention_count += 1
WHERE user_id = $1 AND channel_id = $2
""", uid, channel_id)
return jsonify(payload)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
async def edit_message(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval("""
SELECT author_id FROM messages
WHERE messages.id = $1
""", message_id)
if not author_id == user_id:
raise Forbidden('You can not edit this message')
j = await request.get_json()
updated = 'content' in j or 'embed' in j
if 'content' in j:
await app.db.execute("""
UPDATE messages
SET content=$1
WHERE messages.id = $2
""", j['content'], message_id)
# TODO: update embed
message = await app.storage.get_message(message_id, user_id)
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
if updated:
await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_UPDATE', message)
return jsonify(message)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
async def delete_message(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval("""
SELECT author_id FROM messages
WHERE messages.id = $1
""", message_id)
by_perm = await channel_perm_check(
user_id, channel_id, 'manage_messages', False
)
by_ownership = author_id == user_id
if not by_perm and not by_ownership:
raise Forbidden('You can not delete this message')
await app.db.execute("""
DELETE FROM messages
WHERE messages.id = $1
""", message_id)
await app.dispatcher.dispatch(
'channel', channel_id,
'MESSAGE_DELETE', {
'id': str(message_id),
'channel_id': str(channel_id),
# for lazy guilds
'guild_id': str(guild_id),
})
return '', 204

View File

@ -0,0 +1,93 @@
from quart import Blueprint, current_app as app, request, jsonify
from litecord.auth import token_check
from litecord.blueprints.checks import channel_check
from litecord.snowflake import snowflake_datetime
bp = Blueprint('channel_pins', __name__)
@bp.route('/<int:channel_id>/pins', methods=['GET'])
async def get_pins(channel_id):
"""Get the pins for a channel"""
user_id = await token_check()
await channel_check(user_id, channel_id)
ids = await app.db.fetch("""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
""", channel_id)
ids = [r['message_id'] for r in ids]
res = []
for message_id in ids:
message = await app.storage.get_message(message_id)
if message is not None:
res.append(message)
return jsonify(res)
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
async def add_pin(channel_id, message_id):
"""Add a pin to a channel"""
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
# TODO: check MANAGE_MESSAGES permission
await app.db.execute("""
INSERT INTO channel_pins (channel_id, message_id)
VALUES ($1, $2)
""", channel_id, message_id)
row = await app.db.fetchrow("""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
""", channel_id)
timestamp = snowflake_datetime(row['message_id'])
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp.isoformat()
})
return '', 204
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
async def delete_pin(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
# TODO: check MANAGE_MESSAGES permission
await app.db.execute("""
DELETE FROM channel_pins
WHERE channel_id = $1 AND message_id = $2
""", channel_id, message_id)
row = await app.db.fetchrow("""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
""", channel_id)
timestamp = snowflake_datetime(row['message_id'])
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp.isoformat()
})
return '', 204

View File

@ -0,0 +1,225 @@
from enum import IntEnum
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.utils import async_map
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check
from litecord.blueprints.channel.messages import (
query_tuple_from_args, extract_limit
)
from litecord.errors import MessageNotFound, Forbidden, BadRequest
from litecord.enums import GUILD_CHANS
log = Logger(__name__)
bp = Blueprint('channel_reactions', __name__)
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions'
class EmojiType(IntEnum):
CUSTOM = 0
UNICODE = 1
def emoji_info_from_str(emoji: str) -> tuple:
"""Extract emoji information from an emoji string
given on the reaction endpoints."""
# custom emoji have an emoji of name:id
# unicode emoji just have the raw unicode.
# try checking if the emoji is custom or unicode
emoji_type = 0 if ':' in emoji else 1
emoji_type = EmojiType(emoji_type)
# extract the emoji id OR the unicode value of the emoji
# depending if it is custom or not
emoji_id = (int(emoji.split(':')[1])
if emoji_type == EmojiType.CUSTOM
else emoji)
emoji_name = emoji.split(':')[0]
return emoji_type, emoji_id, emoji_name
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
print(emoji_type, emoji_id, emoji_name)
return {
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
}
def _make_payload(user_id, channel_id, message_id, partial):
return {
'user_id': str(user_id),
'channel_id': str(channel_id),
'message_id': str(message_id),
'emoji': partial
}
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT'])
async def add_reaction(channel_id: int, message_id: int, emoji: str):
"""Put a reaction."""
user_id = await token_check()
# TODO: check READ_MESSAGE_HISTORY permission
# and ADD_REACTIONS. look on route docs.
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
await app.db.execute(
"""
INSERT INTO message_reactions (message_id, user_id,
emoji_type, emoji_id, emoji_text)
VALUES ($1, $2, $3, $4, $5)
""", message_id, user_id, emoji_type,
# if it is custom, we put the emoji_id on emoji_id
# column, if it isn't, we put it on emoji_text
# column.
emoji_id if emoji_type == EmojiType.CUSTOM else None,
emoji_id if emoji_type == EmojiType.UNICODE else None
)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
return '', 204
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
"""Extract SQL clauses to search for specific emoji
in the message_reactions table."""
param = f'${param}'
# know which column to filter with
where_ext = (f'AND emoji_id = {param}'
if emoji_type == EmojiType.CUSTOM else
f'AND emoji_text = {param}')
# which emoji to remove (custom or unicode)
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
return where_ext, main_emoji
def _emoji_sql_simple(emoji: str, param=4):
"""Simpler version of _emoji_sql for functions that
don't need the results from emoji_info_from_str."""
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
async def remove_reaction(channel_id: int, message_id: int,
user_id: int, emoji: str):
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name)
await app.db.execute(
f"""
DELETE FROM message_reactions
WHERE message_id = $1
AND user_id = $2
AND emoji_type = $3
{where_ext}
""", message_id, user_id, emoji_type, main_emoji)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE'])
async def remove_own_reaction(channel_id, message_id, emoji):
"""Remove a reaction."""
user_id = await token_check()
await remove_reaction(channel_id, message_id, user_id, emoji)
return '', 204
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE'])
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
"""Remove a reaction made by another user."""
await token_check()
# TODO: check MANAGE_MESSAGES permission (and use user_id
# from token_check to do it)
await remove_reaction(channel_id, message_id, other_id, emoji)
return '', 204
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET'])
async def list_users_reaction(channel_id, message_id, emoji):
"""Get the list of all users who reacted with a certain emoji."""
user_id = await token_check()
# this is not using either ctype or guild_id
# that are returned by channel_check
await channel_check(user_id, channel_id)
limit = extract_limit(request, 25)
before, after = query_tuple_from_args(request.args, limit)
before_clause = 'AND user_id < $2' if before else ''
after_clause = 'AND user_id > $3' if after else ''
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
rows = await app.db.fetch(f"""
SELECT user_id
FROM message_reactions
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
""", message_id, before, after, main_emoji)
user_ids = [r['user_id'] for r in rows]
users = await async_map(app.storage.get_user, user_ids)
return jsonify(users)
@bp.route(f'{BASEPATH}', methods=['DELETE'])
async def remove_all_reactions(channel_id, message_id):
"""Remove all reactions in a message."""
user_id = await token_check()
# TODO: check MANAGE_MESSAGES permission
ctype, guild_id = await channel_check(user_id, channel_id)
await app.db.execute("""
DELETE FROM message_reactions
WHERE message_id = $1
""", message_id)
payload = {
'channel_id': str(channel_id),
'message_id': str(message_id),
}
if ctype in GUILD_CHANS:
payload['guild_id'] = str(guild_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)

View File

@ -3,14 +3,14 @@ import time
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
from ..auth import token_check from litecord.auth import token_check
from ..snowflake import get_snowflake, snowflake_datetime from litecord.enums import ChannelType, GUILD_CHANS
from ..enums import ChannelType, MessageType, GUILD_CHANS from litecord.errors import ChannelNotFound
from ..errors import Forbidden, ChannelNotFound, MessageNotFound from litecord.schemas import (
from ..schemas import validate, MESSAGE_CREATE validate, CHAN_UPDATE, CHAN_OVERWRITE
)
from .checks import channel_check, guild_check from litecord.blueprints.checks import channel_check, channel_perm_check
from .dms import try_dm_state
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channels', __name__) bp = Blueprint('channels', __name__)
@ -136,6 +136,7 @@ async def guild_cleanup(channel_id):
@bp.route('/<int:channel_id>', methods=['DELETE']) @bp.route('/<int:channel_id>', methods=['DELETE'])
async def close_channel(channel_id): async def close_channel(channel_id):
"""Close or delete a channel."""
user_id = await token_check() user_id = await token_check()
chan_type = await app.storage.get_chan_type(channel_id) chan_type = await app.storage.get_chan_type(channel_id)
@ -212,287 +213,199 @@ async def close_channel(channel_id):
# TODO: group dm # TODO: group dm
pass pass
return '', 404 raise ChannelNotFound()
@bp.route('/<int:channel_id>/messages', methods=['GET']) async def _update_pos(channel_id, pos: int):
async def get_messages(channel_id): await app.db.execute("""
user_id = await token_check() UPDATE guild_channels
await channel_check(user_id, channel_id) SET position = $1
WHERE id = $2
# TODO: before, after, around keys """, pos, channel_id)
message_ids = await app.db.fetch(f"""
SELECT id
FROM messages
WHERE channel_id = $1
ORDER BY id DESC
LIMIT 100
""", channel_id)
result = []
for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'])
if msg is None:
continue
result.append(msg)
log.info('Fetched {} messages', len(result))
return jsonify(result)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET']) async def _mass_chan_update(guild_id, channel_ids: int):
async def get_single_message(channel_id, message_id): for channel_id in channel_ids:
user_id = await token_check() chan = await app.storage.get_channel(channel_id)
await channel_check(user_id, channel_id) await app.dispatcher.dispatch(
'guild', guild_id, 'CHANNEL_UPDATE', chan)
# TODO: check READ_MESSAGE_HISTORY permissions
message = await app.storage.get_message(message_id)
if not message:
raise MessageNotFound()
return jsonify(message)
async def _dm_pre_dispatch(channel_id, peer_id): async def _process_overwrites(channel_id: int, overwrites: list):
"""Do some checks pre-MESSAGE_CREATE so we for overwrite in overwrites:
make sure the receiving party will handle everything."""
# check the other party's dm_channel_state # 0 for user overwrite, 1 for role overwrite
target_type = 0 if overwrite['type'] == 'user' else 1
target_role = None if target_type == 0 else overwrite['id']
target_user = overwrite['id'] if target_type == 0 else None
dm_state = await app.db.fetchval(""" await app.db.execute(
SELECT dm_id """
FROM dm_channel_state INSERT INTO channel_overwrites
WHERE user_id = $1 AND dm_id = $2 (channel_id, target_type, target_role,
""", peer_id, channel_id) target_user, allow, deny)
VALUES
if dm_state: ($1, $2, $3, $4, $5, $6)
# the peer already has the channel ON CONFLICT ON CONSTRAINT channel_overwrites_uniq
# opened, so we don't need to do anything DO
return UPDATE
SET allow = $5, deny = $6
dm_chan = await app.storage.get_channel(channel_id) WHERE channel_overwrites.channel_id = $1
AND channel_overwrites.target_type = $2
# dispatch CHANNEL_CREATE so the client knows which AND channel_overwrites.target_role = $3
# channel the future event is about AND channel_overwrites.target_user = $4
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) """,
channel_id, target_type,
# subscribe the peer to the channel target_role, target_user,
await app.dispatcher.sub('channel', channel_id, peer_id) overwrite['allow'], overwrite['deny'])
# insert it on dm_channel_state so the client
# is subscribed on the future
await try_dm_state(peer_id, channel_id)
@bp.route('/<int:channel_id>/messages', methods=['POST']) @bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT'])
async def create_message(channel_id): async def put_channel_overwrite(channel_id: int, overwrite_id: int):
"""Insert or modify a channel overwrite."""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
j = validate(await request.get_json(), MESSAGE_CREATE) if ctype not in GUILD_CHANS:
message_id = get_snowflake() raise ChannelNotFound('Only usable for guild channels.')
# TODO: check SEND_MESSAGES permission await channel_perm_check(user_id, guild_id, 'manage_roles')
# TODO: check connection to the gateway
await app.db.execute( j = validate(
""" # inserting a fake id on the payload so validation passes through
INSERT INTO messages (id, channel_id, author_id, content, tts, {**await request.get_json(), **{'id': -1}},
mention_everyone, nonce, message_type) CHAN_OVERWRITE
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
""",
message_id,
channel_id,
user_id,
j['content'],
# TODO: check SEND_TTS_MESSAGES
j.get('tts', False),
# TODO: check MENTION_EVERYONE permissions
'@everyone' in j['content'],
int(j.get('nonce', 0)),
MessageType.DEFAULT.value
) )
payload = await app.storage.get_message(message_id) await _process_overwrites(channel_id, [{
'allow': j['allow'],
if ctype == ChannelType.DM: 'deny': j['deny'],
# guild id here is the peer's ID. 'type': j['type'],
await _dm_pre_dispatch(channel_id, guild_id) 'id': overwrite_id
}])
await app.dispatcher.dispatch('channel', channel_id, await _mass_chan_update(guild_id, [channel_id])
'MESSAGE_CREATE', payload) return '', 204
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
if ctype == ChannelType.GUILD_TEXT:
for str_uid in payload['mentions']:
uid = int(str_uid)
await app.db.execute("""
UPDATE user_read_state
SET mention_count += 1
WHERE user_id = $1 AND channel_id = $2
""", uid, channel_id)
return jsonify(payload)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH']) async def _update_channel_common(channel_id, guild_id: int, j: dict):
async def edit_message(channel_id, message_id): if 'name' in j:
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval("""
SELECT author_id FROM messages
WHERE messages.id = $1
""", message_id)
if not author_id == user_id:
raise Forbidden('You can not edit this message')
j = await request.get_json()
updated = 'content' in j or 'embed' in j
if 'content' in j:
await app.db.execute(""" await app.db.execute("""
UPDATE messages UPDATE guild_channels
SET content=$1 SET name = $1
WHERE messages.id = $2 WHERE id = $2
""", j['content'], message_id) """, j['name'], channel_id)
# TODO: update embed if 'position' in j:
channel_data = await app.storage.get_channel_data(guild_id)
message = await app.storage.get_message(message_id) chans = [None * len(channel_data)]
for chandata in channel_data:
chans.insert(chandata['position'], int(chandata['id']))
# only dispatch MESSAGE_UPDATE if we actually had any update to start with # are we changing to the left or to the right?
if updated:
await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_UPDATE', message)
return jsonify(message) # left: [channel1, channel2, ..., channelN-1, channelN]
# becomes
# [channel1, channelN-1, channel2, ..., channelN]
# so we can say that the "main change" is
# channelN-1 going to the position channel2
# was occupying.
current_pos = chans.index(channel_id)
new_pos = j['position']
# if the new position is bigger than the current one,
# we're making a left shift of all the channels that are
# beyond the current one, to make space
left_shift = new_pos > current_pos
# find all channels that we'll have to shift
shift_block = (chans[current_pos:new_pos]
if left_shift else
chans[new_pos:current_pos]
)
shift = -1 if left_shift else 1
# do the shift (to the left or to the right)
await app.db.executemany("""
UPDATE guild_channels
SET position = position + $1
WHERE id = $2
""", [(shift, chan_id) for chan_id in shift_block])
await _mass_chan_update(guild_id, shift_block)
# since theres now an empty slot, move current channel to it
await _update_pos(channel_id, new_pos)
if 'channel_overwrites' in j:
overwrites = j['channel_overwrites']
await _process_overwrites(channel_id, overwrites)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE']) async def _common_guild_chan(channel_id, j: dict):
async def delete_message(channel_id, message_id): # common updates to the guild_channels table
for field in [field for field in j.keys()
if field in ('nsfw', 'parent_id')]:
await app.db.execute(f"""
UPDATE guild_channels
SET {field} = $1
WHERE id = $2
""", j[field], channel_id)
async def _update_text_channel(channel_id: int, j: dict):
# first do the specific ones related to guild_text_channels
for field in [field for field in j.keys()
if field in ('topic', 'rate_limit_per_user')]:
await app.db.execute(f"""
UPDATE guild_text_channels
SET {field} = $1
WHERE id = $2
""", j[field], channel_id)
await _common_guild_chan(channel_id, j)
async def _update_voice_channel(channel_id: int, j: dict):
# first do the specific ones in guild_voice_channels
for field in [field for field in j.keys()
if field in ('bitrate', 'user_limit')]:
await app.db.execute(f"""
UPDATE guild_voice_channels
SET {field} = $1
WHERE id = $2
""", j[field], channel_id)
# yes, i'm letting voice channels have nsfw, you cant stop me
await _common_guild_chan(channel_id, j)
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
async def update_channel(channel_id):
"""Update a channel's information"""
user_id = await token_check() user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
author_id = await app.db.fetchval(""" if ctype not in GUILD_CHANS:
SELECT author_id FROM messages raise ChannelNotFound('Can not edit non-guild channels.')
WHERE messages.id = $1
""", message_id)
# TODO: MANAGE_MESSAGES permission check await channel_perm_check(user_id, channel_id, 'manage_channels')
if author_id != user_id: j = validate(await request.get_json(), CHAN_UPDATE)
raise Forbidden('You can not delete this message')
await app.db.execute(""" # TODO: categories?
DELETE FROM messages update_handler = {
WHERE messages.id = $1 ChannelType.GUILD_TEXT: _update_text_channel,
""", message_id) ChannelType.GUILD_VOICE: _update_voice_channel,
}[ctype]
await app.dispatcher.dispatch( await _update_channel_common(channel_id, guild_id, j)
'channel', channel_id, await update_handler(channel_id, j)
'MESSAGE_DELETE', {
'id': str(message_id),
'channel_id': str(channel_id),
# for lazy guilds chan = await app.storage.get_channel(channel_id)
'guild_id': str(guild_id), await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan)
}) return jsonify(chan)
return '', 204
@bp.route('/<int:channel_id>/pins', methods=['GET'])
async def get_pins(channel_id):
user_id = await token_check()
await channel_check(user_id, channel_id)
ids = await app.db.fetch("""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
""", channel_id)
ids = [r['message_id'] for r in ids]
res = []
for message_id in ids:
message = await app.storage.get_message(message_id)
if message is not None:
res.append(message)
return jsonify(message)
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
async def add_pin(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
# TODO: check MANAGE_MESSAGES permission
await app.db.execute("""
INSERT INTO channel_pins (channel_id, message_id)
VALUES ($1, $2)
""", channel_id, message_id)
row = await app.db.fetchrow("""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
""", channel_id)
timestamp = snowflake_datetime(row['message_id'])
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp.isoformat()
})
return '', 204
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
async def delete_pin(channel_id, message_id):
user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id)
# TODO: check MANAGE_MESSAGES permission
await app.db.execute("""
DELETE FROM channel_pins
WHERE channel_id = $1 AND message_id = $2
""", channel_id, message_id)
row = await app.db.fetchrow("""
SELECT message_id
FROM channel_pins
WHERE channel_id = $1
ORDER BY message_id ASC
LIMIT 1
""", channel_id)
timestamp = snowflake_datetime(row['message_id'])
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
'channel_id': str(channel_id),
'last_pin_timestamp': timestamp.isoformat()
})
return '', 204
@bp.route('/<int:channel_id>/typing', methods=['POST']) @bp.route('/<int:channel_id>/typing', methods=['POST'])
@ -518,21 +431,18 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
if not message_id: if not message_id:
message_id = await app.storage.chan_last_message(channel_id) message_id = await app.storage.chan_last_message(channel_id)
res = await app.db.execute(""" await app.db.execute("""
UPDATE user_read_state INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
SET last_message_id = $1, VALUES
mention_count = 0 ($1, $2, $3, 0)
ON CONFLICT ON CONSTRAINT user_read_state_pkey
WHERE user_id = $2 AND channel_id = $3 DO
""", message_id, user_id, channel_id) UPDATE
SET last_message_id = $3, mention_count = 0
if res == 'UPDATE 0': WHERE user_read_state.user_id = $1
await app.db.execute(""" AND user_read_state.channel_id = $2
INSERT INTO user_read_state """, user_id, channel_id, message_id)
(user_id, channel_id, last_message_id, mention_count)
VALUES ($1, $2, $3, $4)
""", user_id, channel_id, message_id, 0)
if guild_id: if guild_id:
await app.dispatcher.dispatch_user_guild( await app.dispatcher.dispatch_user_guild(
@ -551,6 +461,7 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST']) @bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
async def ack_channel(channel_id, message_id): async def ack_channel(channel_id, message_id):
"""Acknowledge a channel."""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
@ -569,6 +480,7 @@ async def ack_channel(channel_id, message_id):
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE']) @bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
async def delete_read_state(channel_id): async def delete_read_state(channel_id):
"""Delete the read state of a channel."""
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)

View File

@ -1,7 +1,10 @@
from quart import current_app as app from quart import current_app as app
from ..enums import ChannelType, GUILD_CHANS from litecord.enums import ChannelType, GUILD_CHANS
from ..errors import GuildNotFound, ChannelNotFound from litecord.errors import (
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
)
from litecord.permissions import base_permissions, get_permissions
async def guild_check(user_id: int, guild_id: int): async def guild_check(user_id: int, guild_id: int):
@ -16,6 +19,21 @@ async def guild_check(user_id: int, guild_id: int):
raise GuildNotFound('guild not found') raise GuildNotFound('guild not found')
async def guild_owner_check(user_id: int, guild_id: int):
"""Check if a user is the owner of the guild."""
owner_id = await app.db.fetchval("""
SELECT owner_id
FROM guilds
WHERE guilds.id = $1
""", guild_id)
if not owner_id:
raise GuildNotFound()
if user_id != owner_id:
raise Forbidden('You are not the owner of the guild')
async def channel_check(user_id, channel_id): async def channel_check(user_id, channel_id):
"""Check if the current user is authorized """Check if the current user is authorized
to read the channel's information.""" to read the channel's information."""
@ -39,3 +57,27 @@ async def channel_check(user_id, channel_id):
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
peer_id = await app.storage.get_dm_peer(channel_id, user_id) peer_id = await app.storage.get_dm_peer(channel_id, user_id)
return ctype, peer_id return ctype, peer_id
async def guild_perm_check(user_id, guild_id, permission: str):
"""Check guild permissions for a user."""
base_perms = await base_permissions(user_id, guild_id)
hasperm = getattr(base_perms.bits, permission)
if not hasperm:
raise MissingPermissions('Missing permissions.')
async def channel_perm_check(user_id, channel_id,
permission: str, raise_err=True):
"""Check channel permissions for a user."""
base_perms = await get_permissions(user_id, channel_id)
hasperm = getattr(base_perms.bits, permission)
print(base_perms)
print(base_perms.binary)
if not hasperm and raise_err:
raise MissingPermissions('Missing permissions.')
return hasperm

View File

@ -38,41 +38,47 @@ async def try_dm_state(user_id: int, dm_id: int):
""", user_id, dm_id) """, user_id, dm_id)
async def jsonify_dm(dm_id: int, user_id: int):
dm_chan = await app.storage.get_dm(dm_id, user_id)
return jsonify(dm_chan)
async def create_dm(user_id, recipient_id): async def create_dm(user_id, recipient_id):
"""Create a new dm with a user, """Create a new dm with a user,
or get the existing DM id if it already exists.""" or get the existing DM id if it already exists."""
dm_id = await app.db.fetchval("""
SELECT id
FROM dm_channels
WHERE (party1_id = $1 OR party2_id = $1) AND
(party1_id = $2 OR party2_id = $2)
""", user_id, recipient_id)
if dm_id:
return await jsonify_dm(dm_id, user_id)
# if no dm was found, create a new one
dm_id = get_snowflake() dm_id = get_snowflake()
await app.db.execute("""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", dm_id, ChannelType.DM.value)
try: await app.db.execute("""
await app.db.execute(""" INSERT INTO dm_channels (id, party1_id, party2_id)
INSERT INTO channels (id, channel_type) VALUES ($1, $2, $3)
VALUES ($1, $2) """, dm_id, user_id, recipient_id)
""", dm_id, ChannelType.DM.value)
await app.db.execute(""" # the dm state is something we use
INSERT INTO dm_channels (id, party1_id, party2_id) # to give the currently "open dms"
VALUES ($1, $2, $3) # on the client.
""", dm_id, user_id, recipient_id)
# the dm state is something we use # we don't open a dm for the peer/recipient
# to give the currently "open dms" # until the user sends a message.
# on the client. await try_dm_state(user_id, dm_id)
# we don't open a dm for the peer/recipient return await jsonify_dm(dm_id, user_id)
# until the user sends a message.
await try_dm_state(user_id, dm_id)
except UniqueViolationError:
# the dm already exists
dm_id = await app.db.fetchval("""
SELECT id
FROM dm_channels
WHERE (party1_id = $1 OR party2_id = $1) AND
(party2_id = $2 OR party2_id = $2)
""", user_id, recipient_id)
dm = await app.storage.get_dm(dm_id, user_id)
return jsonify(dm)
@bp.route('/@me/channels', methods=['POST']) @bp.route('/@me/channels', methods=['POST'])

View File

@ -1,3 +1,5 @@
import time
from quart import Blueprint, jsonify, current_app as app from quart import Blueprint, jsonify, current_app as app
from ..auth import token_check from ..auth import token_check
@ -6,12 +8,14 @@ bp = Blueprint('gateway', __name__)
def get_gw(): def get_gw():
"""Get the gateway's web"""
proto = 'wss://' if app.config['IS_SSL'] else 'ws://' proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
return f'{proto}{app.config["WEBSOCKET_URL"]}/ws' return f'{proto}{app.config["WEBSOCKET_URL"]}/ws'
@bp.route('/gateway') @bp.route('/gateway')
def api_gateway(): def api_gateway():
"""Get the raw URL."""
return jsonify({ return jsonify({
'url': get_gw() 'url': get_gw()
}) })
@ -27,9 +31,25 @@ async def api_gateway_bot():
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """, user_id)
shards = max(int(guild_count / 1200), 1) shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit
ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
bucket = ratelimit.get_bucket(user_id)
# timestamp of bucket reset
reset_ts = bucket._window + bucket.second
# how many seconds until bucket reset
reset_after_ts = reset_ts - time.time()
return jsonify({ return jsonify({
'url': get_gw(), 'url': get_gw(),
'shards': shards, 'shards': shards,
'session_start_limit': {
'total': bucket.requests,
'remaining': bucket._tokens,
'reset_after': int(reset_after_ts * 1000),
}
}) })

View File

@ -0,0 +1,4 @@
from .roles import bp as guild_roles
from .members import bp as guild_members
from .channels import bp as guild_channels
from .mod import bp as guild_mod

View File

@ -0,0 +1,180 @@
from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import guild_check, guild_owner_check
from litecord.snowflake import get_snowflake
from litecord.errors import BadRequest
from litecord.enums import ChannelType
from litecord.schemas import (
validate, ROLE_UPDATE_POSITION
)
from litecord.blueprints.guild.roles import gen_pairs
bp = Blueprint('guild_channels', __name__)
async def _specific_chan_create(channel_id, ctype, **kwargs):
if ctype == ChannelType.GUILD_TEXT:
await app.db.execute("""
INSERT INTO guild_text_channels (id, topic)
VALUES ($1, $2)
""", channel_id, kwargs.get('topic', ''))
elif ctype == ChannelType.GUILD_VOICE:
await app.db.execute(
"""
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
VALUES ($1, $2, $3)
""",
channel_id,
kwargs.get('bitrate', 64),
kwargs.get('user_limit', 0)
)
async def create_guild_channel(guild_id: int, channel_id: int,
ctype: ChannelType, **kwargs):
"""Create a channel in a guild."""
await app.db.execute("""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", channel_id, ctype.value)
# calc new pos
max_pos = await app.db.fetchval("""
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
""", guild_id)
# account for the first channel in a guild too
max_pos = max_pos or 0
# all channels go to guild_channels
await app.db.execute("""
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
""", channel_id, guild_id, kwargs['name'], max_pos + 1)
# the rest of sql magic is dependant on the channel
# we're creating (a text or voice or category),
# so we use this function.
await _specific_chan_create(channel_id, ctype, **kwargs)
@bp.route('/<int:guild>/channels', methods=['GET'])
async def get_guild_channels(guild_id):
"""Get the list of channels in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_channel_data(guild_id))
@bp.route('/<int:guild_id>/channels', methods=['POST'])
async def create_channel(guild_id):
"""Create a channel in a guild."""
user_id = await token_check()
j = await request.get_json()
# TODO: check permissions for MANAGE_CHANNELS
await guild_check(user_id, guild_id)
channel_type = j.get('type', ChannelType.GUILD_TEXT)
channel_type = ChannelType(channel_type)
if channel_type not in (ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE):
raise BadRequest('Invalid channel type')
new_channel_id = get_snowflake()
await create_guild_channel(
guild_id, new_channel_id, channel_type, **j)
# TODO: do a better method
# subscribe the currently subscribed users to the new channel
# by getting all user ids and subscribing each one by one.
# since GuildDispatcher calls Storage.get_channel_ids,
# it will subscribe all users to the newly created channel.
guild_pubsub = app.dispatcher.backends['guild']
user_ids = guild_pubsub.state[guild_id]
for uid in user_ids:
await app.dispatcher.sub('guild', guild_id, uid)
chan = await app.storage.get_channel(new_channel_id)
await app.dispatcher.dispatch_guild(
guild_id, 'CHANNEL_CREATE', chan)
return jsonify(chan)
async def _chan_update_dispatch(guild_id: int, channel_id: int):
"""Fetch new information about the channel and dispatch
a single CHANNEL_UPDATE event to the guild."""
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
async def _do_single_swap(guild_id: int, pair: tuple):
"""Do a single channel swap, dispatching
the CHANNEL_UPDATE events for after the swap"""
pair1, pair2 = pair
channel_1, new_pos_1 = pair1
channel_2, new_pos_2 = pair2
# do the swap in a transaction.
conn = await app.db.acquire()
async with conn.transaction():
await conn.executemany("""
UPDATE guild_channels
SET position = $1
WHERE id = $2 AND guild_id = $3
""", [
(new_pos_1, channel_1, guild_id),
(new_pos_2, channel_2, guild_id)])
await app.db.release(conn)
await _chan_update_dispatch(guild_id, channel_1)
await _chan_update_dispatch(guild_id, channel_2)
async def _do_channel_swaps(guild_id: int, swap_pairs: list):
"""Swap channel pairs' positions, given the list
of pairs to do.
Dispatches CHANNEL_UPDATEs to the guild.
"""
for pair in swap_pairs:
await _do_single_swap(guild_id, pair)
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
async def modify_channel_pos(guild_id):
"""Change positions of channels in a guild."""
user_id = await token_check()
# TODO: check MANAGE_CHANNELS
await guild_owner_check(user_id, guild_id)
# same thing as guild.roles, so we use
# the same schema and all.
raw_j = await request.get_json()
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
j = j['roles']
channels = await app.storage.get_channel_data(guild_id)
channel_positions = {chan['position']: int(chan['id'])
for chan in channels}
swap_pairs = gen_pairs(
j,
channel_positions
)
await _do_channel_swaps(guild_id, swap_pairs)
return '', 204

View File

@ -0,0 +1,172 @@
from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import guild_check
from litecord.errors import BadRequest
from litecord.schemas import (
validate, MEMBER_UPDATE
)
from litecord.blueprints.checks import guild_owner_check
bp = Blueprint('guild_members', __name__)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
async def get_guild_member(guild_id, member_id):
"""Get a member's information in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
member = await app.storage.get_single_member(guild_id, member_id)
return jsonify(member)
@bp.route('/<int:guild_id>/members', methods=['GET'])
async def get_members(guild_id):
"""Get members inside a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
j = await request.get_json()
limit, after = int(j.get('limit', 1)), j.get('after', 0)
if limit < 1 or limit > 1000:
raise BadRequest('limit not in 1-1000 range')
user_ids = await app.db.fetch(f"""
SELECT user_id
WHERE guild_id = $1, user_id > $2
LIMIT {limit}
ORDER BY user_id ASC
""", guild_id, after)
user_ids = [r[0] for r in user_ids]
members = await app.storage.get_member_multi(guild_id, user_ids)
return jsonify(members)
async def _update_member_roles(guild_id: int, member_id: int,
wanted_roles: list):
"""Update the roles a member has."""
# first, fetch all current roles
roles = await app.db.fetch("""
SELECT role_id from member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
roles = [r['role_id'] for r in roles]
roles = set(roles)
wanted_roles = set(wanted_roles)
# first, we need to find all added roles:
# roles that are on wanted_roles but
# not on roles
added_roles = wanted_roles - roles
# and then the removed roles
# which are roles in roles, but not
# in wanted_roles
removed_roles = roles - wanted_roles
conn = await app.db.acquire()
async with conn.transaction():
# add roles
await app.db.executemany("""
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
""", [(member_id, guild_id, role_id)
for role_id in added_roles])
# remove roles
await app.db.executemany("""
DELETE FROM member_roles
WHERE
user_id = $1
AND guild_id = $2
AND role_id = $3
""", [(member_id, guild_id, role_id)
for role_id in removed_roles])
await app.db.release(conn)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
async def modify_guild_member(guild_id, member_id):
"""Modify a members' information in a guild."""
user_id = await token_check()
await guild_owner_check(user_id, guild_id)
j = validate(await request.get_json(), MEMBER_UPDATE)
if 'nick' in j:
# TODO: check MANAGE_NICKNAMES
await app.db.execute("""
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
""", j['nick'], member_id, guild_id)
if 'mute' in j:
# TODO: check MUTE_MEMBERS
await app.db.execute("""
UPDATE members
SET muted = $1
WHERE user_id = $2 AND guild_id = $3
""", j['mute'], member_id, guild_id)
if 'deaf' in j:
# TODO: check DEAFEN_MEMBERS
await app.db.execute("""
UPDATE members
SET deafened = $1
WHERE user_id = $2 AND guild_id = $3
""", j['deaf'], member_id, guild_id)
if 'channel_id' in j:
# TODO: check MOVE_MEMBERS and CONNECT to the channel
# TODO: change the member's voice channel
pass
if 'roles' in j:
# TODO: check permissions
await _update_member_roles(guild_id, member_id, j['roles'])
member = await app.storage.get_member_data_one(guild_id, member_id)
member.pop('joined_at')
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
'guild_id': str(guild_id)
}, **member})
return '', 204
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
async def update_nickname(guild_id):
"""Update a member's nickname in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
j = await request.get_json()
await app.db.execute("""
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
""", j['nick'], user_id, guild_id)
member = await app.storage.get_member_data_one(guild_id, user_id)
member.pop('joined_at')
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
'guild_id': str(guild_id)
}, **member})
return j['nick']

View File

@ -0,0 +1,185 @@
from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import guild_owner_check
from litecord.schemas import validate, GUILD_PRUNE
bp = Blueprint('guild_moderation', __name__)
async def remove_member(guild_id: int, member_id: int):
"""Do common tasks related to deleting a member from the guild,
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
await app.db.execute("""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', {
'guild_id': guild_id,
'unavailable': False,
})
await app.dispatcher.unsub('guild', guild_id, member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(member_id),
})
async def remove_member_multi(guild_id: int, members: list):
"""Remove multiple members."""
for member_id in members:
await remove_member(guild_id, member_id)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
async def kick_guild_member(guild_id, member_id):
"""Remove a member from a guild."""
user_id = await token_check()
# TODO: check KICK_MEMBERS permission
await guild_owner_check(user_id, guild_id)
await remove_member(guild_id, member_id)
return '', 204
@bp.route('/<int:guild_id>/bans', methods=['GET'])
async def get_bans(guild_id):
user_id = await token_check()
# TODO: check BAN_MEMBERS permission
await guild_owner_check(user_id, guild_id)
bans = await app.db.fetch("""
SELECT user_id, reason
FROM bans
WHERE bans.guild_id = $1
""", guild_id)
res = []
for ban in bans:
res.append({
'reason': ban['reason'],
'user': await app.storage.get_user(ban['user_id'])
})
return jsonify(res)
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
async def create_ban(guild_id, member_id):
user_id = await token_check()
# TODO: check BAN_MEMBERS permission
await guild_owner_check(user_id, guild_id)
j = await request.get_json()
await app.db.execute("""
INSERT INTO bans (guild_id, user_id, reason)
VALUES ($1, $2, $3)
""", guild_id, member_id, j.get('reason', ''))
await remove_member(guild_id, member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(member_id)
})
return '', 204
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE'])
async def remove_ban(guild_id, banned_id):
user_id = await token_check()
# TODO: check BAN_MEMBERS permission
await guild_owner_check(guild_id, user_id)
res = await app.db.execute("""
DELETE FROM bans
WHERE guild_id = $1 AND user_id = $@
""", guild_id, banned_id)
# we don't really need to dispatch GUILD_BAN_REMOVE
# when no bans were actually removed.
if res == 'DELETE 0':
return '', 204
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(banned_id)
})
return '', 204
async def get_prune(guild_id: int, days: int) -> list:
"""Get all members in a guild that:
- did not login in ``days`` days.
- don't have any roles.
"""
# a good solution would be in pure sql.
member_ids = await app.db.fetch(f"""
SELECT id
FROM users
JOIN members
ON members.guild_id = $1 AND members.user_id = users.id
WHERE users.last_session < (now() - (interval '{days} days'))
""", guild_id)
member_ids = [r['id'] for r in member_ids]
members = []
for member_id in member_ids:
role_count = await app.db.fetchval("""
SELECT COUNT(*)
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
if role_count == 0:
members.append(member_id)
return members
@bp.route('/<int:guild_id>/prune', methods=['GET'])
async def get_guild_prune_count(guild_id):
user_id = await token_check()
# TODO: check KICK_MEMBERS
await guild_owner_check(user_id, guild_id)
j = validate(await request.get_json(), GUILD_PRUNE)
days = j['days']
member_ids = await get_prune(guild_id, days)
return jsonify({
'pruned': len(member_ids),
})
@bp.route('/<int:guild_id>/prune', methods=['POST'])
async def begin_guild_prune(guild_id):
user_id = await token_check()
# TODO: check KICK_MEMBERS
await guild_owner_check(user_id, guild_id)
j = validate(await request.get_json(), GUILD_PRUNE)
days = j['days']
member_ids = await get_prune(guild_id, days)
app.loop.create_task(remove_member_multi(guild_id, member_ids))
return jsonify({
'pruned': len(member_ids)
})

View File

@ -0,0 +1,315 @@
from typing import List, Dict, Any, Union
from quart import Blueprint, request, current_app as app, jsonify
from litecord.auth import token_check
from litecord.blueprints.checks import (
guild_check, guild_owner_check
)
from litecord.schemas import (
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
)
from litecord.snowflake import get_snowflake
from litecord.utils import dict_get
DEFAULT_EVERYONE_PERMS = 104324161
bp = Blueprint('guild_roles', __name__)
@bp.route('/<int:guild_id>/roles', methods=['GET'])
async def get_guild_roles(guild_id):
"""Get all roles in a guild."""
user_id = await token_check()
await guild_check(user_id, guild_id)
return jsonify(
await app.storage.get_role_data(guild_id)
)
async def create_role(guild_id, name: str, **kwargs):
"""Create a role in a guild."""
new_role_id = get_snowflake()
# TODO: use @everyone's perm number
default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS)
max_pos = await app.db.fetchval("""
SELECT MAX(position)
FROM roles
WHERE guild_id = $1
""", guild_id)
await app.db.execute(
"""
INSERT INTO roles (id, guild_id, name, color,
hoist, position, permissions, managed, mentionable)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
""",
new_role_id,
guild_id,
name,
dict_get(kwargs, 'color', 0),
dict_get(kwargs, 'hoist', False),
# set position = 0 when there isn't any
# other role (when we're creating the
# @everyone role)
max_pos + 1 if max_pos is not None else 0,
int(dict_get(kwargs, 'permissions', default_perms)),
False,
dict_get(kwargs, 'mentionable', False)
)
role = await app.storage.get_role(new_role_id, guild_id)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_ROLE_CREATE', {
'guild_id': str(guild_id),
'role': role,
})
return role
@bp.route('/<int:guild_id>/roles', methods=['POST'])
async def create_guild_role(guild_id: int):
"""Add a role to a guild"""
user_id = await token_check()
# TODO: use check_guild and MANAGE_ROLES permission
await guild_owner_check(user_id, guild_id)
# client can just send null
j = validate(await request.get_json() or {}, ROLE_CREATE)
role_name = j['name']
j.pop('name')
role = await create_role(guild_id, role_name, **j)
return jsonify(role)
async def _role_update_dispatch(role_id: int, guild_id: int):
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
role = await app.storage.get_role(role_id, guild_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
'guild_id': str(guild_id),
'role': role,
})
return role
async def _role_pairs_update(guild_id: int, pairs: list):
"""Update the roles' positions.
Dispatches GUILD_ROLE_UPDATE for all roles being updated.
"""
for pair in pairs:
pair_1, pair_2 = pair
role_1, new_pos_1 = pair_1
role_2, new_pos_2 = pair_2
conn = await app.db.acquire()
async with conn.transaction():
# update happens in a transaction
# so we don't fuck it up
await conn.execute("""
UPDATE roles
SET position = $1
WHERE roles.id = $2
""", new_pos_1, role_1)
await conn.execute("""
UPDATE roles
SET position = $1
WHERE roles.id = $2
""", new_pos_2, role_2)
await app.db.release(conn)
# the route fires multiple Guild Role Update.
await _role_update_dispatch(role_1, guild_id)
await _role_update_dispatch(role_2, guild_id)
def gen_pairs(list_of_changes: List[Dict[str, int]],
current_state: Dict[int, int],
blacklist: List[int] = None) -> List[tuple]:
"""Generate a list of pairs that, when applied to the database,
will generate the desired state given in list_of_changes.
We must check if the given list_of_changes isn't overwriting an
element's (such as a role or a channel) position to an existing one,
without there having an already existing change for the other one.
Here's a pratical explanation with roles:
R1 (in position RP1) wants to be in the same position
as R2 (currently in position RP2).
So, if we did the simpler approach, list_of_changes
would just contain the preferred change: (R1, RP2).
With gen_pairs, there MUST be a (R2, RP1) in list_of_changes,
if there is, the given result in gen_pairs will be a pair
((R1, RP2), (R2, RP1)) which is then used to actually
update the roles' positions in a transaction.
Parameters
----------
list_of_changes:
A list of dictionaries with ``id`` and ``position``
fields, describing the preferred changes.
current_state:
Dictionary containing the current state of the list
of elements (roles or channels). Points position
to element ID.
blacklist:
List of IDs that shouldn't be moved.
Returns
-------
list
List of swaps to do to achieve the preferred
state given by ``list_of_changes``.
"""
pairs = []
blacklist = blacklist or []
preferred_state = {element['id']: element['position']
for element in list_of_changes}
for blacklisted_id in blacklist:
preferred_state.pop(blacklisted_id)
# for each change, we must find a matching change
# in the same list, so we can make a swap pair
for change in list_of_changes:
element_1, new_pos_1 = change['id'], change['position']
# check current pairs
# so we don't repeat an element
flag = False
for pair in pairs:
if (element_1, new_pos_1) in pair:
flag = True
# skip if found
if flag:
continue
# search if there is a role/channel in the
# position we want to change to
element_2 = current_state.get(new_pos_1)
# if there is, is that existing channel being
# swapped to another position?
new_pos_2 = preferred_state.get(element_2)
# if its being swapped to leave space, add it
# to the pairs list
if new_pos_2:
pairs.append(
((element_1, new_pos_1), (element_2, new_pos_2))
)
return pairs
@bp.route('/<int:guild_id>/roles', methods=['PATCH'])
async def update_guild_role_positions(guild_id):
"""Update the positions for a bunch of roles."""
user_id = await token_check()
# TODO: check MANAGE_ROLES
await guild_owner_check(user_id, guild_id)
raw_j = await request.get_json()
# we need to do this hackiness because thats
# cerberus for ya.
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
# extract the list out
j = j['roles']
all_roles = await app.storage.get_role_data(guild_id)
# we'll have to calculate pairs of changing roles,
# then do the changes, etc.
roles_pos = {role['position']: int(role['id']) for role in all_roles}
# TODO: check if the user can even change the roles in the first place,
# preferrably when we have a proper perms system.
pairs = gen_pairs(
j,
roles_pos,
# always ignore people trying to change
# the @everyone's role position
[guild_id]
)
await _role_pairs_update(guild_id, pairs)
# return the list of all roles back
return jsonify(await app.storage.get_role_data(guild_id))
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH'])
async def update_guild_role(guild_id, role_id):
"""Update a single role's information."""
user_id = await token_check()
# TODO: check MANAGE_ROLES
await guild_owner_check(user_id, guild_id)
j = validate(await request.get_json(), ROLE_UPDATE)
# we only update ints on the db, not Permissions
j['permissions'] = int(j['permissions'])
for field in j:
await app.db.execute(f"""
UPDATE roles
SET {field} = $1
WHERE roles.id = $2 AND roles.guild_id = $3
""", j[field], role_id, guild_id)
role = await _role_update_dispatch(role_id, guild_id)
return jsonify(role)
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE'])
async def delete_guild_role(guild_id, role_id):
"""Delete a role.
Dispatches GUILD_ROLE_DELETE.
"""
user_id = await token_check()
# TODO: check MANAGE_ROLES
await guild_owner_check(user_id, guild_id)
res = await app.db.execute("""
DELETE FROM roles
WHERE guild_id = $1 AND id = $2
""", guild_id, role_id)
if res == 'DELETE 0':
return '', 204
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
'guild_id': str(guild_id),
'role_id': str(role_id),
})
return '', 204

View File

@ -1,31 +1,23 @@
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from litecord.blueprints.guild.channels import create_guild_channel
from litecord.blueprints.guild.roles import (
create_role, DEFAULT_EVERYONE_PERMS
)
from ..auth import token_check from ..auth import token_check
from ..snowflake import get_snowflake from ..snowflake import get_snowflake
from ..enums import ChannelType from ..enums import ChannelType
from ..errors import Forbidden, GuildNotFound, BadRequest from ..schemas import (
from ..schemas import validate, GUILD_UPDATE validate, GUILD_CREATE, GUILD_UPDATE
)
from .channels import channel_ack from .channels import channel_ack
from .checks import guild_check from .checks import guild_check, guild_owner_check
bp = Blueprint('guilds', __name__) bp = Blueprint('guilds', __name__)
async def guild_owner_check(user_id: int, guild_id: int):
"""Check if a user is the owner of the guild."""
owner_id = await app.db.fetchval("""
SELECT owner_id
FROM guilds
WHERE guild_id = $1
""", guild_id)
if not owner_id:
raise GuildNotFound()
if user_id != owner_id:
raise Forbidden('You are not the owner of the guild')
async def create_guild_settings(guild_id: int, user_id: int): async def create_guild_settings(guild_id: int, user_id: int):
"""Create guild settings for the user """Create guild settings for the user
joining the guild.""" joining the guild."""
@ -48,10 +40,59 @@ async def create_guild_settings(guild_id: int, user_id: int):
""", m_notifs, user_id, guild_id) """, m_notifs, user_id, guild_id)
async def add_member(guild_id: int, user_id: int):
"""Add a user to a guild."""
await app.db.execute("""
INSERT INTO members (user_id, guild_id)
VALUES ($1, $2)
""", user_id, guild_id)
await create_guild_settings(guild_id, user_id)
async def guild_create_roles_prep(guild_id: int, roles: list):
"""Create roles in preparation in guild create."""
# by reaching this point in the code that means
# roles is not nullable, which means
# roles has at least one element, so we can access safely.
# the first member in the roles array
# are patches to the @everyone role
everyone_patches = roles[0]
for field in everyone_patches:
await app.db.execute(f"""
UPDATE roles
SET {field}={everyone_patches[field]}
WHERE roles.id = $1
""", guild_id)
default_perms = (everyone_patches.get('permissions')
or DEFAULT_EVERYONE_PERMS)
# from the 2nd and forward,
# should be treated as new roles
for role in roles[1:]:
await create_role(
guild_id, role['name'], default_perms=default_perms, **role
)
async def guild_create_channels_prep(guild_id: int, channels: list):
"""Create channels pre-guild create"""
for channel_raw in channels:
channel_id = get_snowflake()
ctype = ChannelType(channel_raw['type'])
await create_guild_channel(guild_id, channel_id, ctype)
@bp.route('', methods=['POST']) @bp.route('', methods=['POST'])
async def create_guild(): async def create_guild():
"""Create a new guild, assigning
the user creating it as the owner and
making them join."""
user_id = await token_check() user_id = await token_check()
j = await request.get_json() j = validate(await request.get_json(), GUILD_CREATE)
guild_id = get_snowflake() guild_id = get_snowflake()
@ -66,36 +107,37 @@ async def create_guild():
j.get('default_message_notifications', 0), j.get('default_message_notifications', 0),
j.get('explicit_content_filter', 0)) j.get('explicit_content_filter', 0))
await app.db.execute(""" await add_member(guild_id, user_id)
INSERT INTO members (user_id, guild_id)
VALUES ($1, $2)
""", user_id, guild_id)
await create_guild_settings(guild_id, user_id) # create the default @everyone role (everyone has it by default,
# so we don't insert that in the table)
# we also don't use create_role because the id of the role
# is the same as the id of the guild, and create_role
# generates a new snowflake.
await app.db.execute(""" await app.db.execute("""
INSERT INTO roles (id, guild_id, name, position, permissions) INSERT INTO roles (id, guild_id, name, position, permissions)
VALUES ($1, $2, $3, $4, $5) VALUES ($1, $2, $3, $4, $5)
""", guild_id, guild_id, '@everyone', 0, 104324161) """, guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
# add the @everyone role to the guild creator
await app.db.execute("""
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
""", user_id, guild_id, guild_id)
# create a single #general channel.
general_id = get_snowflake() general_id = get_snowflake()
await app.db.execute(""" await create_guild_channel(
INSERT INTO channels (id, channel_type) guild_id, general_id, ChannelType.GUILD_TEXT,
VALUES ($1, $2) name='general')
""", general_id, ChannelType.GUILD_TEXT.value)
await app.db.execute(""" if j.get('roles'):
INSERT INTO guild_channels (id, guild_id, name, position) await guild_create_roles_prep(guild_id, j['roles'])
VALUES ($1, $2, $3, $4)
""", general_id, guild_id, 'general', 0)
await app.db.execute(""" if j.get('channels'):
INSERT INTO guild_text_channels (id) await guild_create_channels_prep(guild_id, j['channels'])
VALUES ($1)
""", general_id)
# TODO: j['roles'] and j['channels']
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
@ -106,21 +148,22 @@ async def create_guild():
@bp.route('/<int:guild_id>', methods=['GET']) @bp.route('/<int:guild_id>', methods=['GET'])
async def get_guild(guild_id): async def get_guild(guild_id):
"""Get a single guilds' information."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id)
gj = await app.storage.get_guild(guild_id, user_id) return jsonify(
gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) await app.storage.get_guild_full(guild_id, user_id, 250)
)
return jsonify({**gj, **gj_extra})
@bp.route('/<int:guild_id>', methods=['UPDATE']) @bp.route('/<int:guild_id>', methods=['UPDATE'])
async def update_guild(guild_id): async def update_guild(guild_id):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id)
j = validate(await request.get_json(), GUILD_UPDATE)
# TODO: check MANAGE_GUILD # TODO: check MANAGE_GUILD
await guild_check(user_id, guild_id)
j = validate(await request.get_json(), GUILD_UPDATE)
if 'owner_id' in j: if 'owner_id' in j:
await guild_owner_check(user_id, guild_id) await guild_owner_check(user_id, guild_id)
@ -139,8 +182,6 @@ async def update_guild(guild_id):
""", j['name'], guild_id) """, j['name'], guild_id)
if 'region' in j: if 'region' in j:
# TODO: check region value
await app.db.execute(""" await app.db.execute("""
UPDATE guilds UPDATE guilds
SET region = $1 SET region = $1
@ -167,15 +208,14 @@ async def update_guild(guild_id):
WHERE guild_id = $2 WHERE guild_id = $2
""", j[field], guild_id) """, j[field], guild_id)
# return guild object guild = await app.storage.get_guild_full(
gj = await app.storage.get_guild(guild_id, user_id) guild_id, user_id
gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) )
gj_total = {**gj, **gj_extra} await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_UPDATE', guild)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_UPDATE', gj_total) return jsonify(guild)
return jsonify({**gj, **gj_extra})
@bp.route('/<int:guild_id>', methods=['DELETE']) @bp.route('/<int:guild_id>', methods=['DELETE'])
@ -185,7 +225,7 @@ async def delete_guild(guild_id):
await guild_owner_check(user_id, guild_id) await guild_owner_check(user_id, guild_id)
await app.db.execute(""" await app.db.execute("""
DELETE FROM guild DELETE FROM guilds
WHERE guilds.id = $1 WHERE guilds.id = $1
""", guild_id) """, guild_id)
@ -202,264 +242,12 @@ async def delete_guild(guild_id):
return '', 204 return '', 204
@bp.route('/<int:guild>/channels', methods=['GET'])
async def get_guild_channels(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
channels = await app.storage.get_channel_data(guild_id)
return jsonify(channels)
@bp.route('/<int:guild_id>/channels', methods=['POST'])
async def create_channel(guild_id):
user_id = await token_check()
j = await request.get_json()
# TODO: check permissions for MANAGE_CHANNELS
await guild_check(user_id, guild_id)
new_channel_id = get_snowflake()
channel_type = j.get('type', ChannelType.GUILD_TEXT)
channel_type = ChannelType(channel_type)
if channel_type not in (ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE):
raise BadRequest('Invalid channel type')
await app.db.execute("""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", new_channel_id, channel_type.value)
max_pos = await app.db.fetchval("""
SELECT MAX(position)
FROM guild_channels
WHERE guild_id = $1
""", guild_id)
if channel_type == ChannelType.GUILD_TEXT:
await app.db.execute("""
INSERT INTO guild_channels (id, guild_id, name, position)
VALUES ($1, $2, $3, $4)
""", new_channel_id, guild_id, j['name'], max_pos + 1)
await app.db.execute("""
INSERT INTO guild_text_channels (id)
VALUES ($1)
""", new_channel_id)
elif channel_type == ChannelType.GUILD_VOICE:
raise NotImplementedError()
chan = await app.storage.get_channel(new_channel_id)
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_CREATE', chan)
return jsonify(chan)
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
async def modify_channel_pos(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
await request.get_json()
# TODO: this route
raise NotImplementedError
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
async def get_guild_member(guild_id, member_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
member = await app.storage.get_single_member(guild_id, member_id)
return jsonify(member)
@bp.route('/<int:guild_id>/members', methods=['GET'])
async def get_members(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
j = await request.get_json()
limit, after = int(j.get('limit', 1)), j.get('after', 0)
if limit < 1 or limit > 1000:
raise BadRequest('limit not in 1-1000 range')
user_ids = await app.db.fetch(f"""
SELECT user_id
WHERE guild_id = $1, user_id > $2
LIMIT {limit}
ORDER BY user_id ASC
""", guild_id, after)
user_ids = [r[0] for r in user_ids]
members = await app.storage.get_member_multi(guild_id, user_ids)
return jsonify(members)
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
async def modify_guild_member(guild_id, member_id):
j = await request.get_json()
if 'nick' in j:
# TODO: check MANAGE_NICKNAMES
await app.db.execute("""
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
""", j['nick'], member_id, guild_id)
if 'mute' in j:
# TODO: check MUTE_MEMBERS
await app.db.execute("""
UPDATE members
SET muted = $1
WHERE user_id = $2 AND guild_id = $3
""", j['mute'], member_id, guild_id)
if 'deaf' in j:
# TODO: check DEAFEN_MEMBERS
await app.db.execute("""
UPDATE members
SET deafened = $1
WHERE user_id = $2 AND guild_id = $3
""", j['deaf'], member_id, guild_id)
if 'channel_id' in j:
# TODO: check MOVE_MEMBERS
# TODO: change the member's voice channel
pass
member = await app.storage.get_member_data_one(guild_id, member_id)
member.pop('joined_at')
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
'guild_id': str(guild_id)
}, **member})
return '', 204
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
async def update_nickname(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
j = await request.get_json()
await app.db.execute("""
UPDATE members
SET nickname = $1
WHERE user_id = $2 AND guild_id = $3
""", j['nick'], user_id, guild_id)
member = await app.storage.get_member_data_one(guild_id, user_id)
member.pop('joined_at')
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
'guild_id': str(guild_id)
}, **member})
return j['nick']
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
async def kick_member(guild_id, member_id):
user_id = await token_check()
# TODO: check KICK_MEMBERS permission
await guild_owner_check(user_id, guild_id)
await app.db.execute("""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
await app.dispatcher.dispatch_user(user_id, 'GUILD_DELETE', {
'guild_id': guild_id,
'unavailable': False,
})
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
'guild': guild_id,
'user': await app.storage.get_user(member_id),
})
return '', 204
@bp.route('/<int:guild_id>/bans', methods=['GET'])
async def get_bans(guild_id):
user_id = await token_check()
# TODO: check BAN_MEMBERS permission
await guild_owner_check(user_id, guild_id)
bans = await app.db.fetch("""
SELECT user_id, reason
FROM bans
WHERE bans.guild_id = $1
""", guild_id)
res = []
for ban in bans:
res.append({
'reason': ban['reason'],
'user': await app.storage.get_user(ban['user_id'])
})
return jsonify(res)
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
async def create_ban(guild_id, member_id):
user_id = await token_check()
# TODO: check BAN_MEMBERS permission
await guild_owner_check(user_id, guild_id)
j = await request.get_json()
await app.db.execute("""
INSERT INTO bans (guild_id, user_id, reason)
VALUES ($1, $2, $3)
""", guild_id, member_id, j.get('reason', ''))
await app.db.execute("""
DELETE FROM members
WHERE guild_id = $1 AND user_id = $2
""", guild_id, user_id)
await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', {
'guild_id': guild_id,
'unavailable': False,
})
await app.dispatcher.unsub('guild', guild_id, member_id)
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
'guild': guild_id,
'user': await app.storage.get_user(member_id),
})
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {**{
'guild': guild_id,
}, **(await app.storage.get_user(member_id))})
return '', 204
@bp.route('/<int:guild_id>/messages/search') @bp.route('/<int:guild_id>/messages/search')
async def search_messages(guild_id): async def search_messages(guild_id):
"""Search messages in a guild.
This is an undocumented route.
"""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
@ -474,6 +262,7 @@ async def search_messages(guild_id):
@bp.route('/<int:guild_id>/ack', methods=['POST']) @bp.route('/<int:guild_id>/ack', methods=['POST'])
async def ack_guild(guild_id): async def ack_guild(guild_id):
"""ACKnowledge all messages in the guild."""
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)

View File

@ -185,7 +185,7 @@ async def use_invite(invite_code):
}) })
# subscribe new member to guild, so they get events n stuff # subscribe new member to guild, so they get events n stuff
app.dispatcher.sub_guild(guild_id, user_id) await app.dispatcher.sub('guild', guild_id, user_id)
# tell the new member that theres the guild it just joined. # tell the new member that theres the guild it just joined.
# we use dispatch_user_guild so that we send the GUILD_CREATE # we use dispatch_user_guild so that we send the GUILD_CREATE

View File

@ -219,9 +219,11 @@ async def get_me_guilds():
partial = await app.db.fetchrow(""" partial = await app.db.fetchrow("""
SELECT id::text, name, icon, owner_id SELECT id::text, name, icon, owner_id
FROM guilds FROM guilds
WHERE guild_id = $1 WHERE guilds.id = $1
""", guild_id) """, guild_id)
partial = dict(partial)
# TODO: partial['permissions'] # TODO: partial['permissions']
partial['owner'] = partial['owner_id'] == user_id partial['owner'] = partial['owner_id'] == user_id
partial.pop('owner_id') partial.pop('owner_id')
@ -279,10 +281,11 @@ async def put_note(target_id: int):
INSERT INTO notes (user_id, target_id, note) INSERT INTO notes (user_id, target_id, note)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
ON CONFLICT DO UPDATE SET ON CONFLICT ON CONSTRAINT notes_pkey
DO UPDATE SET
note = $3 note = $3
WHERE WHERE notes.user_id = $1
user_id = $1 AND target_id = $2 AND notes.target_id = $2
""", user_id, target_id, note) """, user_id, target_id, note)
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', { await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
@ -315,7 +318,8 @@ async def patch_current_settings():
await app.db.execute(f""" await app.db.execute(f"""
UPDATE user_settings UPDATE user_settings
SET {key}=$1 SET {key}=$1
""", j[key]) WHERE id = $2
""", j[key], user_id)
settings = await app.storage.get_user_settings(user_id) settings = await app.storage.get_user_settings(user_id)
await app.dispatcher.dispatch_user( await app.dispatcher.dispatch_user(
@ -444,20 +448,20 @@ async def patch_guild_settings(guild_id: int):
continue continue
for field in chan_overrides: for field in chan_overrides:
res = await app.db.execute(f""" await app.db.execute(f"""
UPDATE guild_settings_channel_overrides INSERT INTO guild_settings_channel_overrides
SET {field} = $1 (user_id, guild_id, channel_id, {field})
WHERE user_id = $2 VALUES
AND guild_id = $3 ($1, $2, $3, $4)
AND channel_id = $4 ON CONFLICT
""", chan_overrides[field], user_id, guild_id, chan_id) ON CONSTRAINT guild_settings_channel_overrides_pkey
DO
if res == 'UPDATE 0': UPDATE
await app.db.execute(f""" SET {field} = $4
INSERT INTO guild_settings_channel_overrides WHERE guild_settings_channel_overrides.user_id = $1
(user_id, guild_id, channel_id, {field}) AND guild_settings_channel_overrides.guild_id = $2
VALUES ($1, $2, $3, $4) AND guild_settings_channel_overrides.channel_id = $3
""", user_id, guild_id, chan_id, chan_overrides[field]) """, user_id, guild_id, chan_id, chan_overrides[field])
settings = await app.storage.get_guild_settings_one(user_id, guild_id) settings = await app.storage.get_guild_settings_one(user_id, guild_id)

View File

@ -4,7 +4,8 @@ from typing import List, Any
from logbook import Logger from logbook import Logger
from .pubsub import GuildDispatcher, MemberDispatcher, \ from .pubsub import GuildDispatcher, MemberDispatcher, \
UserDispatcher, ChannelDispatcher, FriendDispatcher UserDispatcher, ChannelDispatcher, FriendDispatcher, \
LazyGuildDispatcher
log = Logger(__name__) log = Logger(__name__)
@ -35,6 +36,7 @@ class EventDispatcher:
'channel': ChannelDispatcher(self), 'channel': ChannelDispatcher(self),
'user': UserDispatcher(self), 'user': UserDispatcher(self),
'friend': FriendDispatcher(self), 'friend': FriendDispatcher(self),
'lazy_guild': LazyGuildDispatcher(self),
} }
async def action(self, backend_str: str, action: str, key, identifier): async def action(self, backend_str: str, action: str, key, identifier):
@ -104,6 +106,15 @@ class EventDispatcher:
for key in keys: for key in keys:
await self.dispatch(backend_str, key, *args, **kwargs) await self.dispatch(backend_str, key, *args, **kwargs)
async def dispatch_filter(self, backend_str: str,
key: Any, func, *args):
"""Dispatch to a backend that only accepts
(event, data) arguments with an optional filter
function."""
backend = self.backends[backend_str]
key = backend.KEY_TYPE(key)
return await backend.dispatch_filter(key, func, *args)
async def reset(self, backend_str: str, key: Any): async def reset(self, backend_str: str, key: Any):
"""Reset the bucket in the given backend.""" """Reset the bucket in the given backend."""
backend = self.backends[backend_str] backend = self.backends[backend_str]

View File

@ -29,16 +29,24 @@ class NotFound(LitecordError):
status_code = 404 status_code = 404
class GuildNotFound(LitecordError): class GuildNotFound(NotFound):
status_code = 404 error_code = 10004
class ChannelNotFound(LitecordError): class ChannelNotFound(NotFound):
status_code = 404 error_code = 10003
class MessageNotFound(LitecordError): class MessageNotFound(NotFound):
status_code = 404 error_code = 10008
class Ratelimited(LitecordError):
status_code = 429
class MissingPermissions(Forbidden):
error_code = 50013
class WebsocketClose(Exception): class WebsocketClose(Exception):

View File

@ -1,18 +1,68 @@
import asyncio
from typing import List, Dict, Any from typing import List, Dict, Any
from collections import defaultdict from collections import defaultdict
from websockets.exceptions import ConnectionClosed
from logbook import Logger from logbook import Logger
from .state import GatewayState from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP
log = Logger(__name__) log = Logger(__name__)
class ManagerClose(Exception):
pass
class StateDictWrapper:
"""Wrap a mapping so that any kind of access to the mapping while the
state manager is closed raises a ManagerClose error"""
def __init__(self, state_manager, mapping):
self.state_manager = state_manager
self._map = mapping
def _check_closed(self):
if self.state_manager.closed:
raise ManagerClose()
def __getitem__(self, key):
self._check_closed()
return self._map[key]
def __delitem__(self, key):
self._check_closed()
del self._map[key]
def __setitem__(self, key, value):
if not self.state_manager.accept_new:
raise ManagerClose()
self._check_closed()
self._map[key] = value
def __iter__(self):
return self._map.__iter__()
def pop(self, key):
return self._map.pop(key)
def values(self):
return self._map.values()
class StateManager: class StateManager:
"""Manager for gateway state information.""" """Manager for gateway state information."""
def __init__(self): def __init__(self):
#: closed flag
self.closed = False
#: accept new states?
self.accept_new = True
# { # {
# user_id: { # user_id: {
# session_id: GatewayState, # session_id: GatewayState,
@ -20,7 +70,10 @@ class StateManager:
# }, # },
# user_id_2: {}, ... # user_id_2: {}, ...
# } # }
self.states = defaultdict(dict) self.states = StateDictWrapper(self, defaultdict(dict))
#: raw mapping from session ids to GatewayState
self.states_raw = StateDictWrapper(self, {})
def insert(self, state: GatewayState): def insert(self, state: GatewayState):
"""Insert a new state object.""" """Insert a new state object."""
@ -28,6 +81,7 @@ class StateManager:
log.debug('inserting state: {!r}', state) log.debug('inserting state: {!r}', state)
user_states[state.session_id] = state user_states[state.session_id] = state
self.states_raw[state.session_id] = state
def fetch(self, user_id: int, session_id: str) -> GatewayState: def fetch(self, user_id: int, session_id: str) -> GatewayState:
"""Fetch a state object from the manager. """Fetch a state object from the manager.
@ -40,11 +94,20 @@ class StateManager:
""" """
return self.states[user_id][session_id] return self.states[user_id][session_id]
def fetch_raw(self, session_id: str) -> GatewayState:
"""Fetch a single state given the Session ID."""
return self.states_raw[session_id]
def remove(self, state): def remove(self, state):
"""Remove a state from the registry""" """Remove a state from the registry"""
if not state: if not state:
return return
try:
self.states_raw.pop(state.session_id)
except KeyError:
pass
try: try:
log.debug('removing state: {!r}', state) log.debug('removing state: {!r}', state)
self.states[state.user_id].pop(state.session_id) self.states[state.user_id].pop(state.session_id)
@ -100,3 +163,54 @@ class StateManager:
states.extend(member_states) states.extend(member_states)
return states return states
async def shutdown_single(self, state: GatewayState):
"""Send OP Reconnect to a single connection."""
websocket = state.ws
await websocket.send({
'op': OP.RECONNECT
})
# wait 200ms
# so that the client has time to process
# our payload then close the connection
await asyncio.sleep(0.2)
try:
# try to close the connection ourselves
await websocket.ws.close(
code=4000,
reason='litecord shutting down'
)
except ConnectionClosed:
log.info('client {} already closed', state)
def gen_close_tasks(self):
"""Generate the tasks that will order the clients
to reconnect.
This is required to be ran before :meth:`StateManager.close`,
since this function doesn't wait for the tasks to complete.
"""
self.accept_new = False
#: store the shutdown tasks
tasks = []
for state in self.states_raw.values():
if not state.ws:
continue
tasks.append(
self.shutdown_single(state)
)
log.info('made {} shutdown tasks', len(tasks))
return tasks
def close(self):
"""Close the state manager."""
self.closed = True

View File

@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple(
) )
WebsocketObjects = collections.namedtuple( WebsocketObjects = collections.namedtuple(
'WebsocketObjects', 'db state_manager storage loop dispatcher presence' 'WebsocketObjects', ('db', 'state_manager', 'storage',
'loop', 'dispatcher', 'presence', 'ratelimiter')
) )
@ -44,8 +45,38 @@ def encode_etf(payload) -> str:
return earl.pack(payload) return earl.pack(payload)
def _etf_decode_dict(data):
# NOTE: this is a very slow implementation to
# decode the dictionary.
if isinstance(data, bytes):
return data.decode()
if not isinstance(data, dict):
return data
_copy = dict(data)
result = {}
for key in _copy.keys():
# assuming key is bytes rn.
new_k = key.decode()
# maybe nested dicts, so...
result[new_k] = _etf_decode_dict(data[key])
return result
def decode_etf(data: bytes): def decode_etf(data: bytes):
return earl.unpack(data) res = earl.unpack(data)
if isinstance(res, bytes):
return data.decode()
if isinstance(res, dict):
return _etf_decode_dict(res)
return res
class GatewayWebsocket: class GatewayWebsocket:
@ -108,6 +139,11 @@ class GatewayWebsocket:
else: else:
await self.ws.send(encoded.decode()) await self.ws.send(encoded.decode())
def _check_ratelimit(self, key: str, ratelimit_key: str):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
async def _hb_wait(self, interval: int): async def _hb_wait(self, interval: int):
"""Wait heartbeat""" """Wait heartbeat"""
# if the client heartbeats in time, # if the client heartbeats in time,
@ -312,6 +348,14 @@ class GatewayWebsocket:
async def update_status(self, status: dict): async def update_status(self, status: dict):
"""Update the status of the current websocket connection.""" """Update the status of the current websocket connection."""
if not self.state:
return
if self._check_ratelimit('presence', self.state.session_id):
# Presence Updates beyond the ratelimit
# are just silently dropped.
return
if status is None: if status is None:
status = { status = {
'afk': False, 'afk': False,
@ -365,6 +409,15 @@ class GatewayWebsocket:
'op': OP.HEARTBEAT_ACK, 'op': OP.HEARTBEAT_ACK,
}) })
async def _connect_ratelimit(self, user_id: int):
if self._check_ratelimit('connect', user_id):
await self.invalidate_session(False)
raise WebsocketClose(4009, 'You are being ratelimited.')
if self._check_ratelimit('session', user_id):
await self.invalidate_session(False)
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')
async def handle_2(self, payload: Dict[str, Any]): async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet.""" """Handle the OP 2 Identify packet."""
try: try:
@ -384,6 +437,8 @@ class GatewayWebsocket:
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Authentication failed') raise WebsocketClose(4004, 'Authentication failed')
await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval(""" bot = await self.ext.db.fetchval("""
SELECT bot FROM users SELECT bot FROM users
WHERE id = $1 WHERE id = $1
@ -641,9 +696,11 @@ class GatewayWebsocket:
This is the known structure of GUILD_MEMBER_LIST_UPDATE: This is the known structure of GUILD_MEMBER_LIST_UPDATE:
group_id = 'online' | 'offline' | role_id (string)
sync_item = { sync_item = {
'group': { 'group': {
'id': string, // 'online' | 'offline' | any role id 'id': group_id,
'count': num 'count': num
} }
} | { } | {
@ -653,7 +710,7 @@ class GatewayWebsocket:
list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE' list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE'
list_data = { list_data = {
'id': "everyone" // ?? 'id': channel_id | 'everyone',
'guild_id': guild_id, 'guild_id': guild_id,
'ops': [ 'ops': [
@ -666,10 +723,10 @@ class GatewayWebsocket:
// exists if op = 'SYNC' // exists if op = 'SYNC'
'items': sync_item[], 'items': sync_item[],
// exists if op = 'INSERT' or 'DELETE' // exists if op == 'INSERT' | 'DELETE' | 'UPDATE'
'index': num, 'index': num,
// exists if op = 'INSERT' // exists if op == 'INSERT' | 'UPDATE'
'item': sync_item, 'item': sync_item,
} }
], ],
@ -678,31 +735,11 @@ class GatewayWebsocket:
// separately from the online list? // separately from the online list?
'groups': [ 'groups': [
{ {
'id': string // 'online' | 'offline' | any role id 'id': group_id
'count': num 'count': num
}, ... }, ...
] ]
} }
# Implementation defails.
Lazy guilds are complicated to deal with in the backend level
as there are a lot of computation to be done for each request.
The current implementation is rudimentary and does not account
for any roles inside the guild.
A correct implementation would take account of roles and make
the correct groups on list_data:
For each channel in lazy_request['channels']:
- get all roles that have Read Messages on the channel:
- Also fetch their member counts, as it'll be important
- with the role list, order them like you normally would
(by their role priority)
- based on the channel's range's min and max and the ordered
role list, you can get the roles wanted for your list_data reply.
- make new groups ONLY when the role is hoisted.
""" """
data = payload['d'] data = payload['d']
@ -713,65 +750,16 @@ class GatewayWebsocket:
if guild_id not in gids: if guild_id not in gids:
return return
member_ids = await self.storage.get_member_ids(guild_id) # make shard query
log.debug('lazy: loading {} members', len(member_ids)) lazy_guilds = self.ext.dispatcher.backends['lazy_guild']
# the current implementation is rudimentary and only for chan_id, ranges in data.get('channels', {}).items():
# generates two groups: online and offline, using chan_id = int(chan_id)
# PresenceManager.guild_presences to fill list_data. member_list = await lazy_guilds.get_gml(chan_id)
# this also doesn't take account the channels in lazy_request. await member_list.shard_query(
self.state.session_id, ranges
guild_presences = await self.presence.guild_presences(member_ids, )
guild_id)
online = [{'member': p}
for p in guild_presences
if p['status'] == 'online']
offline = [{'member': p}
for p in guild_presences
if p['status'] == 'offline']
log.debug('lazy: {} presences, online={}, offline={}',
len(guild_presences),
len(online),
len(offline))
# construct items in the WORST WAY POSSIBLE.
items = [{
'group': {
'id': 'online',
'count': len(online),
}
}] + online + [{
'group': {
'id': 'offline',
'count': len(offline),
}
}] + offline
await self.dispatch('GUILD_MEMBER_LIST_UPDATE', {
'id': 'everyone',
'guild_id': data['guild_id'],
'groups': [
{
'id': 'online',
'count': len(online),
},
{
'id': 'offline',
'count': len(offline),
}
],
'ops': [
{
'range': [0, 99],
'op': 'SYNC',
'items': items
}
]
})
async def process_message(self, payload): async def process_message(self, payload):
"""Process a single message coming in from the client.""" """Process a single message coming in from the client."""
@ -788,17 +776,36 @@ class GatewayWebsocket:
await handler(payload) await handler(payload)
async def _msg_ratelimit(self):
if self._check_ratelimit('messages', self.state.session_id):
raise WebsocketClose(4008, 'You are being ratelimited.')
async def listen_messages(self): async def listen_messages(self):
"""Listen for messages coming in from the websocket.""" """Listen for messages coming in from the websocket."""
# close anyone trying to login while the
# server is shutting down
if self.ext.state_manager.closed:
raise WebsocketClose(4000, 'state manager closed')
if not self.ext.state_manager.accept_new:
raise WebsocketClose(4000, 'state manager closed for new')
while True: while True:
message = await self.ws.recv() message = await self.ws.recv()
if len(message) > 4096: if len(message) > 4096:
raise DecodeError('Payload length exceeded') raise DecodeError('Payload length exceeded')
if self.state:
await self._msg_ratelimit()
payload = self.decoder(message) payload = self.decoder(message)
await self.process_message(payload) await self.process_message(payload)
def _cleanup(self): def _cleanup(self):
for task in self.wsp.tasks.values():
task.cancel()
if self.state: if self.state:
self.ext.state_manager.remove(self.state) self.ext.state_manager.remove(self.state)
self.state.ws = None self.state.ws = None

242
litecord/permissions.py Normal file
View File

@ -0,0 +1,242 @@
import ctypes
from quart import current_app as app, request
# so we don't keep repeating the same
# type for all the fields
_i = ctypes.c_uint8
class _RawPermsBits(ctypes.LittleEndianStructure):
"""raw bitfield for discord's permission number."""
_fields_ = [
('create_invites', _i, 1),
('kick_members', _i, 1),
('ban_members', _i, 1),
('administrator', _i, 1),
('manage_channels', _i, 1),
('manage_guild', _i, 1),
('add_reactions', _i, 1),
('view_audit_log', _i, 1),
('priority_speaker', _i, 1),
('_unused1', _i, 1),
('read_messages', _i, 1),
('send_messages', _i, 1),
('send_tts', _i, 1),
('manage_messages', _i, 1),
('embed_links', _i, 1),
('attach_files', _i, 1),
('read_history', _i, 1),
('mention_everyone', _i, 1),
('external_emojis', _i, 1),
('_unused2', _i, 1),
('connect', _i, 1),
('speak', _i, 1),
('mute_members', _i, 1),
('deafen_members', _i, 1),
('move_members', _i, 1),
('use_voice_activation', _i, 1),
('change_nickname', _i, 1),
('manage_nicknames', _i, 1),
('manage_roles', _i, 1),
('manage_webhooks', _i, 1),
('manage_emojis', _i, 1),
]
class Permissions(ctypes.Union):
_fields_ = [
('bits', _RawPermsBits),
('binary', ctypes.c_uint64),
]
def __init__(self, val: int):
self.binary = val
def __repr__(self):
return f'<Permissions binary={self.binary}>'
def __int__(self):
return self.binary
def numby(self):
return self.binary
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
"""Get the raw :class:`Permissions` object for a role."""
if not storage:
storage = app.storage
perms = await storage.db.fetchval("""
SELECT permissions
FROM roles
WHERE guild_id = $1 AND id = $2
""", guild_id, role_id)
return Permissions(perms)
async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
"""Compute the base permissions for a given user.
Base permissions are
(permissions from @everyone role) +
(permissions from any other role the member has)
This will give ALL_PERMISSIONS if base permissions
has the Administrator bit set.
"""
if not storage:
storage = app.storage
owner_id = await storage.db.fetchval("""
SELECT owner_id
FROM guilds
WHERE id = $1
""", guild_id)
if owner_id == member_id:
return ALL_PERMISSIONS
# get permissions for @everyone
permissions = await get_role_perms(guild_id, guild_id, storage)
role_ids = await storage.db.fetch("""
SELECT role_id
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
role_perms = []
for row in role_ids:
rperm = await storage.db.fetchval("""
SELECT permissions
FROM roles
WHERE id = $1
""", row['role_id'])
role_perms.append(rperm)
for perm_num in role_perms:
permissions.binary |= perm_num
if permissions.bits.administrator:
return ALL_PERMISSIONS
return permissions
def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
# we make a copy of the binary representation
# so we don't modify the old perms in-place
# which could be an unwanted side-effect
result = perms.binary
# negate the permissions that are denied
result &= ~overwrite['deny']
# combine the permissions that are allowed
result |= overwrite['allow']
return Permissions(result)
def overwrite_find_mix(perms: Permissions, overwrites: dict,
target_id: int) -> Permissions:
overwrite = overwrites.get(target_id)
if overwrite:
# only mix if overwrite found
return overwrite_mix(perms, overwrite)
return perms
async def role_permissions(guild_id: int, role_id: int,
channel_id: int, storage=None) -> Permissions:
"""Get the permissions for a role, in relation to a channel"""
if not storage:
storage = app.storage
perms = await get_role_perms(guild_id, role_id, storage)
overwrite = await storage.db.fetchrow("""
SELECT allow, deny
FROM channel_overwrites
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
""", channel_id, 1, role_id)
if overwrite:
perms = overwrite_mix(perms, overwrite)
return perms
async def compute_overwrites(base_perms, user_id, channel_id: int,
guild_id: int = None, storage=None):
"""Compute the permissions in the context of a channel."""
if not storage:
storage = app.storage
if base_perms.bits.administrator:
return ALL_PERMISSIONS
perms = base_perms
# list of overwrites
overwrites = await storage.chan_overwrites(channel_id)
if not guild_id:
guild_id = await storage.guild_from_channel(channel_id)
# make it a map for better usage
overwrites = {o['id']: o for o in overwrites}
perms = overwrite_find_mix(perms, overwrites, guild_id)
# apply role specific overwrites
allow, deny = 0, 0
# fetch roles from user and convert to int
role_ids = await storage.get_member_role_ids(guild_id, user_id)
role_ids = map(int, role_ids)
# make the allow and deny binaries
for role_id in role_ids:
overwrite = overwrites.get(role_id)
if overwrite:
allow |= overwrite['allow']
deny |= overwrite['deny']
# final step for roles: mix
perms = overwrite_mix(perms, {
'allow': allow,
'deny': deny
})
# apply member specific overwrites
perms = overwrite_find_mix(perms, overwrites, user_id)
return perms
async def get_permissions(member_id, channel_id, *, storage=None):
"""Get all the permissions for a user in a channel."""
if not storage:
storage = app.storage
guild_id = await storage.guild_from_channel(channel_id)
# for non guild channels
if not guild_id:
return ALL_PERMISSIONS
base_perms = await base_permissions(member_id, guild_id, storage)
return await compute_overwrites(base_perms, member_id,
channel_id, guild_id, storage)

View File

@ -1,8 +1,11 @@
from typing import List, Dict, Any from typing import List, Dict, Any
from random import choice from random import choice
from logbook import Logger
from quart import current_app as app from quart import current_app as app
log = Logger(__name__)
def status_cmp(status: str, other_status: str) -> bool: def status_cmp(status: str, other_status: str) -> bool:
"""Compare if `status` is better than the `other_status` """Compare if `status` is better than the `other_status`
@ -100,20 +103,64 @@ class PresenceManager:
game = state['game'] game = state['game']
await self.dispatcher.dispatch_guild( lazy_guild_store = self.dispatcher.backends['lazy_guild']
guild_id, 'PRESENCE_UPDATE', { lists = lazy_guild_store.get_gml_guild(guild_id)
'user': member['user'],
'roles': member['roles'],
'guild_id': guild_id,
'status': state['status'], # shards that are in lazy guilds with 'everyone'
# enabled
in_lazy = []
# rich presence stuff for member_list in lists:
'game': game, session_ids = await member_list.pres_update(
'activities': [game] if game else [] int(member['user']['id']),
} {
'roles': member['roles'],
'status': state['status'],
'game': game
}
)
log.debug('Lazy Dispatch to {}',
len(session_ids))
if member_list.channel_id == 'everyone':
in_lazy.extend(session_ids)
pres_update_payload = {
'user': member['user'],
'roles': member['roles'],
'guild_id': str(guild_id),
'status': state['status'],
# rich presence stuff
'game': game,
'activities': [game] if game else []
}
def _sane_session(session_id):
state = self.state_manager.fetch_raw(session_id)
uid = int(member['user']['id'])
if not state:
return False
# we don't want to send a presence update
# to the same user
return (state.user_id != uid and
session_id not in in_lazy)
# everyone not in lazy guild mode
# gets a PRESENCE_UPDATE
await self.dispatcher.dispatch_filter(
'guild', guild_id,
_sane_session,
'PRESENCE_UPDATE', pres_update_payload
) )
return in_lazy
async def dispatch_pres(self, user_id: int, state: dict): async def dispatch_pres(self, user_id: int, state: dict):
"""Dispatch a new presence to all guilds the user is in. """Dispatch a new presence to all guilds the user is in.
@ -122,10 +169,12 @@ class PresenceManager:
if state['status'] == 'invisible': if state['status'] == 'invisible':
state['status'] = 'offline' state['status'] = 'offline'
# TODO: shard-aware
guild_ids = await self.storage.get_user_guilds(user_id) guild_ids = await self.storage.get_user_guilds(user_id)
for guild_id in guild_ids: for guild_id in guild_ids:
await self.dispatch_guild_pres(guild_id, user_id, state) await self.dispatch_guild_pres(
guild_id, user_id, state)
# dispatch to all friends that are subscribed to them # dispatch to all friends that are subscribed to them
user = await self.storage.get_user(user_id) user = await self.storage.get_user(user_id)

View File

@ -3,7 +3,8 @@ from .member import MemberDispatcher
from .user import UserDispatcher from .user import UserDispatcher
from .channel import ChannelDispatcher from .channel import ChannelDispatcher
from .friend import FriendDispatcher from .friend import FriendDispatcher
from .lazy_guild import LazyGuildDispatcher
__all__ = ['GuildDispatcher', 'MemberDispatcher', __all__ = ['GuildDispatcher', 'MemberDispatcher',
'UserDispatcher', 'ChannelDispatcher', 'UserDispatcher', 'ChannelDispatcher',
'FriendDispatcher'] 'FriendDispatcher', 'LazyGuildDispatcher']

View File

@ -1,5 +1,4 @@
from typing import Any from typing import Any
from collections import defaultdict
from logbook import Logger from logbook import Logger

View File

@ -37,6 +37,14 @@ class Dispatcher:
"""Unsubscribe an elemtnt from the channel/key.""" """Unsubscribe an elemtnt from the channel/key."""
raise NotImplementedError raise NotImplementedError
async def dispatch_filter(self, _key, _func, *_args):
"""Selectively dispatch to the list of subscribed users.
The selection logic is completly arbitraty and up to the
Pub/Sub backend.
"""
raise NotImplementedError
async def dispatch(self, _key, *_args): async def dispatch(self, _key, *_args):
"""Dispatch an event to the given channel/key.""" """Dispatch an event to the given channel/key."""
raise NotImplementedError raise NotImplementedError

View File

@ -1,4 +1,3 @@
from collections import defaultdict
from typing import Any from typing import Any
from logbook import Logger from logbook import Logger
@ -47,6 +46,8 @@ class GuildDispatcher(DispatcherWithState):
# when subbing a user to the guild, we should sub them # when subbing a user to the guild, we should sub them
# to every channel they have access to, in the guild. # to every channel they have access to, in the guild.
# TODO: check for permissions
await self._chan_action('sub', guild_id, user_id) await self._chan_action('sub', guild_id, user_id)
async def unsub(self, guild_id: int, user_id: int): async def unsub(self, guild_id: int, user_id: int):
@ -56,9 +57,10 @@ class GuildDispatcher(DispatcherWithState):
# same thing happening from sub() happens on unsub() # same thing happening from sub() happens on unsub()
await self._chan_action('unsub', guild_id, user_id) await self._chan_action('unsub', guild_id, user_id)
async def dispatch(self, guild_id: int, async def dispatch_filter(self, guild_id: int, func,
event: str, data: Any): event: str, data: Any):
"""Dispatch an event to all subscribers of the guild.""" """Selectively dispatch to session ids that have
func(session_id) true."""
user_ids = self.state[guild_id] user_ids = self.state[guild_id]
dispatched = 0 dispatched = 0
@ -75,8 +77,22 @@ class GuildDispatcher(DispatcherWithState):
await self.unsub(guild_id, user_id) await self.unsub(guild_id, user_id)
continue continue
# filter the ones that matter
states = list(filter(
lambda state: func(state.session_id), states
))
dispatched += await self._dispatch_states( dispatched += await self._dispatch_states(
states, event, data) states, event, data)
log.info('Dispatched {} {!r} to {} states', log.info('Dispatched {} {!r} to {} states',
guild_id, event, dispatched) guild_id, event, dispatched)
async def dispatch(self, guild_id: int,
event: str, data: Any):
"""Dispatch an event to all subscribers of the guild."""
await self.dispatch_filter(
guild_id,
lambda sess_id: True,
event, data,
)

View File

@ -0,0 +1,736 @@
"""
Main code for Lazy Guild implementation in litecord.
"""
import pprint
from dataclasses import dataclass, asdict
from collections import defaultdict
from typing import Any, List, Dict, Union
from logbook import Logger
from litecord.pubsub.dispatcher import Dispatcher
from litecord.permissions import (
Permissions, overwrite_find_mix, get_permissions, role_permissions
)
from litecord.utils import index_by_func
log = Logger(__name__)
GroupID = Union[int, str]
Presence = Dict[str, Any]
@dataclass
class GroupInfo:
"""Store information about a specific group."""
gid: GroupID
name: str
position: int
permissions: Permissions
@dataclass
class MemberList:
"""Total information on the guild's member list."""
groups: List[GroupInfo] = None
group_info: Dict[GroupID, GroupInfo] = None
data: Dict[GroupID, Presence] = None
overwrites: Dict[int, Dict[str, Any]] = None
def __bool__(self):
"""Return if the current member list is fully initialized."""
list_dict = asdict(self)
return all(v is not None for v in list_dict.values())
def __iter__(self):
"""Iterate over all groups in the correct order.
Yields a tuple containing :class:`GroupInfo` and
the List[Presence] for the group.
"""
if not self.groups:
return
for group in self.groups:
yield group, self.data[group.gid]
@dataclass
class Operation:
"""Represents a member list operation."""
list_op: str
params: Dict[str, Any]
@property
def to_dict(self) -> dict:
res = {
'op': self.list_op
}
if self.list_op == 'SYNC':
res['items'] = self.params['items']
if self.list_op in ('SYNC', 'INVALIDATE'):
res['range'] = self.params['range']
if self.list_op in ('INSERT', 'DELETE', 'UPDATE'):
res['index'] = self.params['index']
if self.list_op in ('INSERT', 'UPDATE'):
res['item'] = self.params['item']
return res
def _to_simple_group(presence: dict) -> str:
"""Return a simple group (not a role), given a presence."""
return 'offline' if presence['status'] == 'offline' else 'online'
def display_name(member_nicks: Dict[str, str], presence: Presence) -> str:
"""Return the display name of a presence.
Used to sort groups.
"""
uid = presence['user']['id']
uname = presence['user']['username']
nick = member_nicks.get(uid)
return nick or uname
class GuildMemberList:
"""This class stores the current member list information
for a guild (by channel).
As channels can have different sets of roles that can
read them and so, different lists, this is more of a
"channel member list" than a guild member list.
Attributes
----------
main_lg: LazyGuildDispatcher
Main instance of :class:`LazyGuildDispatcher`,
so that we're able to use things such as :class:`Storage`.
guild_id: int
The Guild ID this instance is referring to.
channel_id: int
The Channel ID this instance is referring to.
member_list: List
The actual member list information.
state: set
The set of session IDs that are subscribed to the guild.
User IDs being used as the identifier in GuildMemberList
is a wrong assumption. It is true Discord rolled out
lazy guilds to all of the userbase, but users that are bots,
for example, can still rely on PRESENCE_UPDATEs.
"""
def __init__(self, guild_id: int,
channel_id: int, main_lg):
self.guild_id = guild_id
self.channel_id = channel_id
self.main = main_lg
self.list = MemberList(None, None, None, None)
#: store the states that are subscribed to the list
# type is{session_id: set[list]}
self.state = defaultdict(set)
@property
def storage(self):
"""Get the global :class:`Storage` instance."""
return self.main.app.storage
@property
def presence(self):
"""Get the global :class:`PresenceManager` instance."""
return self.main.app.presence
@property
def state_man(self):
"""Get the global :class:`StateManager` instance."""
return self.main.app.state_manager
@property
def list_id(self):
"""get the id of the member list."""
return ('everyone'
if self.channel_id == self.guild_id
else str(self.channel_id))
def _set_empty_list(self):
"""Set the member list as being empty."""
self.list = MemberList(None, None, None, None)
async def _init_check(self):
"""Check if the member list is initialized before
messing with it."""
if not self.list:
await self._init_member_list()
async def _fetch_overwrites(self):
overwrites = await self.storage.chan_overwrites(self.channel_id)
overwrites = {int(ov['id']): ov for ov in overwrites}
self.list.overwrites = overwrites
def _calc_member_group(self, roles: List[int], status: str):
"""Calculate the best fitting group for a member,
given their roles and their current status."""
try:
# the first group in the list
# that the member is entitled to is
# the selected group for the member.
group_id = next(g.gid for g in self.list.groups
if g.gid in roles)
except StopIteration:
# no group was found, so we fallback
# to simple group"
group_id = _to_simple_group({'status': status})
return group_id
async def get_roles(self) -> List[GroupInfo]:
"""Get role information, but only:
- the ID
- the name
- the position
- the permissions
of all HOISTED roles AND roles that
have permissions to read the channel
being referred to this :class:`GuildMemberList`
instance.
The list is sorted by position.
"""
roledata = await self.storage.db.fetch("""
SELECT id, name, hoist, position, permissions
FROM roles
WHERE guild_id = $1
""", self.guild_id)
hoisted = [
GroupInfo(row['id'], row['name'],
row['position'], row['permissions'])
for row in roledata if row['hoist']
]
# sort role list by position
hoisted = sorted(hoisted, key=lambda group: group.position)
# we need to store them for later on
# for members
await self._fetch_overwrites()
def _can_read_chan(group: GroupInfo):
# get the base role perms
role_perms = group.permissions
# then the final perms for that role if
# any overwrite exists in the channel
final_perms = overwrite_find_mix(
role_perms, self.list.overwrites, group.gid)
# update the group's permissions
# with the mixed ones
group.permissions = final_perms
# if the role can read messages, then its
# part of the group.
return final_perms.bits.read_messages
return list(filter(_can_read_chan, hoisted))
async def set_groups(self):
"""Get the groups for the member list."""
role_groups = await self.get_roles()
role_ids = [g.gid for g in role_groups]
self.list.groups = role_ids + ['online', 'offline']
# inject default groups 'online' and 'offline'
self.list.groups = role_ids + [
GroupInfo('online', 'online', -1, -1),
GroupInfo('offline', 'offline', -1, -1)
]
self.list.group_info = {g.gid: g for g in role_groups}
async def get_group(self, member_id: int,
roles: List[Union[str, int]],
status: str) -> int:
"""Return a fitting group ID for the user."""
member_roles = list(map(int, roles))
# get the member's permissions relative to the channel
# (accounting for channel overwrites)
member_perms = await get_permissions(
member_id, self.channel_id, storage=self.storage)
if not member_perms.bits.read_messages:
return None
# if the member is offline, we
# default give them the offline group.
group_id = ('offline' if status == 'offline'
else self._calc_member_group(member_roles, status))
return group_id
async def _pass_1(self, guild_presences: List[Presence]):
"""First pass on generating the member list.
This assigns all presences a single group.
"""
for presence in guild_presences:
member_id = int(presence['user']['id'])
group_id = await self.get_group(
member_id, presence['roles'], presence['status']
)
self.list.data[group_id].append(presence)
async def get_member_nicks_dict(self) -> dict:
"""Get a dictionary with nickname information."""
members = await self.storage.get_member_data(self.guild_id)
# make a dictionary of member ids to nicknames
# so we don't need to keep querying the db on
# every loop iteration
member_nicks = {m['user']['id']: m.get('nick')
for m in members}
return member_nicks
async def _sort_groups(self):
member_nicks = await self.get_member_nicks_dict()
for group_members in self.list.data.values():
# this should update the list in-place
group_members.sort(
key=lambda p: display_name(member_nicks, p))
async def _init_member_list(self):
"""Generate the main member list with groups."""
member_ids = await self.storage.get_member_ids(self.guild_id)
guild_presences = await self.presence.guild_presences(
member_ids, self.guild_id)
await self.set_groups()
log.debug('{} presences, {} groups',
len(guild_presences),
len(self.list.groups))
self.list.data = {group.gid: [] for group in self.list.groups}
# first pass: set which presences
# go to which groups
await self._pass_1(guild_presences)
# second pass: sort each group's members
# by the display name
await self._sort_groups()
@property
def items(self) -> list:
"""Main items list."""
# TODO: maybe make this stored in the list
# so we don't need to keep regenning?
if not self.list:
return []
res = []
# NOTE: maybe use map()?
for group, presences in self.list:
res.append({
'group': {
'id': group.gid,
'count': len(presences),
}
})
for presence in presences:
res.append({
'member': presence
})
return res
async def sub(self, _session_id: str):
"""Subscribe a shard to the member list."""
await self._init_check()
def unsub(self, session_id: str):
"""Unsubscribe a shard from the member list"""
try:
self.state.pop(session_id)
except KeyError:
pass
# once we reach 0 subscribers,
# we drop the current member list we have (for memory)
# but keep the GuildMemberList running (as
# uninitialized) for a future subscriber.
if not self.state:
self._set_empty_list()
def get_state(self, session_id: str):
try:
state = self.state_man.fetch_raw(session_id)
return state
except KeyError:
self.unsub(session_id)
return
async def _dispatch_sess(self, session_ids: List[str],
operations: List[Operation]):
# construct the payload to dispatch
payload = {
'id': self.list_id,
'guild_id': str(self.guild_id),
'groups': [
{
'count': len(presences),
'id': group.gid
} for group, presences in self.list
],
'ops': [
operation.to_dict
for operation in operations
]
}
states = map(self.get_state, session_ids)
states = filter(lambda state: state is not None, states)
dispatched = []
for state in states:
await state.ws.dispatch(
'GUILD_MEMBER_LIST_UPDATE', payload)
dispatched.append(state.session_id)
return dispatched
async def shard_query(self, session_id: str, ranges: list):
"""Send a GUILD_MEMBER_LIST_UPDATE event
for a shard that is querying about the member list.
Paramteters
-----------
session_id: str
The Session ID querying information.
channel_id: int
The Channel ID that we want information on.
ranges: List[List[int]]
ranges of the list that we want.
"""
# a guild list with a channel id of the guild
# represents the 'everyone' global list.
list_id = self.list_id
# if everyone can read the channel,
# we direct the request to the 'everyone' gml instance
# instead of the current one.
everyone_perms = await role_permissions(
self.guild_id,
self.guild_id,
self.channel_id,
storage=self.storage
)
if everyone_perms.bits.read_messages and list_id != 'everyone':
everyone_gml = await self.main.get_gml(self.guild_id)
return await everyone_gml.shard_query(
session_id, ranges
)
await self._init_check()
ops = []
for start, end in ranges:
itemcount = end - start
# ignore incorrect ranges
if itemcount < 0:
continue
self.state[session_id].add((start, end))
ops.append(Operation('SYNC', {
'range': [start, end],
'items': self.items[start:end]
}))
await self._dispatch_sess([session_id], ops)
def get_item_index(self, user_id: Union[str, int]):
"""Get the item index a user is on."""
def _get_id(item):
# item can be a group item or a member item
return item.get('member', {}).get('user', {}).get('id')
# get the updated item's index
return index_by_func(
lambda p: _get_id(p) == str(user_id),
self.items
)
def state_is_subbed(self, item_index, session_id: str) -> bool:
"""Return if a state's ranges include the given
item index."""
ranges = self.state[session_id]
for range_start, range_end in ranges:
if range_start <= item_index <= range_end:
return True
return False
def get_subs(self, item_index: int) -> filter:
"""Get the list of subscribed states to a given item."""
return filter(
lambda sess_id: self.state_is_subbed(item_index, sess_id),
self.state.keys()
)
async def _pres_update_simple(self, user_id: int):
item_index = self.get_item_index(user_id)
if item_index is None:
log.warning('lazy guild got invalid pres update uid={}',
user_id)
return []
item = self.items[item_index]
session_ids = self.get_subs(item_index)
# simple update means we just give an UPDATE
# operation
return await self._dispatch_sess(
session_ids,
[
Operation('UPDATE', {
'index': item_index,
'item': item,
})
]
)
async def _pres_update_complex(self, user_id: int,
old_group: str, old_index: int,
new_group: str):
"""Move a member between groups."""
log.debug('complex update: uid={} old={} old_idx={} new={}',
user_id, old_group, old_index, new_group)
old_group_presences = self.list.data[old_group]
old_item_index = self.get_item_index(user_id)
# make a copy of current presence to insert in the new group
current_presence = dict(old_group_presences[old_index])
# step 1: remove the old presence (old_index is relative
# to the group, and not the items list)
del old_group_presences[old_index]
# we need to insert current_presence to the new group
# but we also need to calculate its index to insert on.
presences = self.list.data[new_group]
best_index = 0
member_nicks = await self.get_member_nicks_dict()
current_name = display_name(member_nicks, current_presence)
# go through each one until we find the best placement
for presence in presences:
name = display_name(member_nicks, presence)
print(name, current_name, name < current_name)
# TODO: check if this works
if name < current_name:
break
best_index += 1
# insert the presence at the index
presences.insert(best_index + 1, current_presence)
new_item_index = self.get_item_index(user_id)
log.debug('assigned new item index {} to uid {}',
new_item_index, user_id)
session_ids_old = self.get_subs(old_item_index)
session_ids_new = self.get_subs(new_item_index)
# dispatch events to both the old states and
# new states.
return await self._dispatch_sess(
# inefficient, but necessary since we
# want to merge both session ids.
list(session_ids_old) + list(session_ids_new),
[
Operation('DELETE', {
'index': old_item_index,
}),
Operation('INSERT', {
'index': new_item_index,
'item': {
'member': current_presence
}
})
]
)
async def pres_update(self, user_id: int,
partial_presence: Dict[str, Any]):
"""Update a presence inside the member list.
There are 4 types of updates that can happen for a user in a group:
- from 'offline' to any
- from any to 'offline'
- from any to any
- from G to G (with G being any group)
any: 'online' | role_id
All first, second, and third updates are 'complex' updates,
which means we'll have to change the group the user is on
to account for them.
The fourth is a 'simple' change, since we're not changing
the group a user is on, and so there's less overhead
involved.
"""
await self._init_check()
old_group, old_index, old_presence = None, None, None
for group, presences in self.list:
p_idx = index_by_func(
lambda p: p['user']['id'] == str(user_id),
presences)
log.debug('p_idx for group {!r} = {}',
group.gid, p_idx)
if p_idx is None:
log.debug('skipping group {}', group)
continue
# make a copy since we're modifying in-place
old_group = group.gid
old_index = p_idx
old_presence = dict(presences[p_idx])
# be ready if it is a simple update
presences[p_idx].update(partial_presence)
break
if not old_group:
log.warning('pres update with unknown old group uid={}',
user_id)
return []
roles = partial_presence.get('roles', old_presence['roles'])
new_status = partial_presence.get('status', old_presence['status'])
new_group = await self.get_group(user_id, roles, new_status)
log.debug('pres update: gid={} cid={} old_g={} new_g={}',
self.guild_id, self.channel_id, old_group, new_group)
# if we're going to the same group,
# treat this as a simple update
if old_group == new_group:
return await self._pres_update_simple(user_id)
return await self._pres_update_complex(
user_id, old_group, old_index, new_group)
async def dispatch(self, event: str, data: Any):
"""Modify the member list and dispatch the respective
events to subscribed shards.
GuildMemberList stores the current guilds' list
in its :attr:`GuildMemberList.list` attribute,
with that attribute being modified via different
calls to :meth:`GuildMemberList.dispatch`
"""
# if no subscribers, drop event
if not self.list:
return
class LazyGuildDispatcher(Dispatcher):
"""Main class holding the member lists for lazy guilds."""
# channel ids
KEY_TYPE = int
# the session ids subscribing to channels
VAL_TYPE = str
def __init__(self, main):
super().__init__(main)
self.storage = main.app.storage
# {chan_id: gml, ...}
self.state = {}
#: store which guilds have their
# respective GMLs
# {guild_id: [chan_id, ...], ...}
self.guild_map = defaultdict(list)
async def get_gml(self, channel_id: int):
"""Get a guild list for a channel ID,
generating it if it doesn't exist."""
try:
return self.state[channel_id]
except KeyError:
guild_id = await self.storage.guild_from_channel(
channel_id
)
# if we don't find a guild, we just
# set it the same as the channel.
if not guild_id:
guild_id = channel_id
gml = GuildMemberList(guild_id, channel_id, self)
self.state[channel_id] = gml
self.guild_map[guild_id].append(channel_id)
return gml
def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]:
"""Get all member lists for a given guild."""
return list(map(
self.state.get,
self.guild_map[guild_id]
))
async def unsub(self, chan_id, session_id):
gml = await self.get_gml(chan_id)
gml.unsub(session_id)

View File

@ -0,0 +1,113 @@
"""
main litecord ratelimiting code
This code was copied from elixire's ratelimiting,
which in turn is a work on top of discord.py's ratelimiting.
"""
import time
class RatelimitBucket:
"""Main ratelimit bucket class."""
def __init__(self, tokens, second):
self.requests = tokens
self.second = second
self._window = 0.0
self._tokens = self.requests
self.retries = 0
self._last = 0.0
def get_tokens(self, current):
"""Get the current amount of available tokens."""
if not current:
current = time.time()
# by default, use _tokens
tokens = self._tokens
# if current timestamp is above _window + seconds
# reset tokens to self.requests (default)
if current > self._window + self.second:
tokens = self.requests
return tokens
def update_rate_limit(self):
"""Update current ratelimit state."""
current = time.time()
self._last = current
self._tokens = self.get_tokens(current)
# we are using the ratelimit for the first time
# so set current ratelimit window to right now
if self._tokens == self.requests:
self._window = current
# Are we currently ratelimited?
if self._tokens == 0:
self.retries += 1
return self.second - (current - self._window)
# if not ratelimited, remove a token
self.retries = 0
self._tokens -= 1
# if we got ratelimited after that token removal,
# set window to now
if self._tokens == 0:
self._window = current
def reset(self):
"""Reset current ratelimit to default state."""
self._tokens = self.requests
self._last = 0.0
self.retries = 0
def copy(self):
"""Create a copy of this ratelimit.
Used to manage multiple ratelimits to users.
"""
return RatelimitBucket(self.requests,
self.second)
def __repr__(self):
return (f'<RatelimitBucket requests={self.requests} '
f'second={self.second} window: {self._window} '
f'tokens={self._tokens}>')
class Ratelimit:
"""Manages buckets."""
def __init__(self, tokens, second, keys=None):
self._cache = {}
if keys is None:
keys = tuple()
self.keys = keys
self._cooldown = RatelimitBucket(tokens, second)
def __repr__(self):
return (f'<Ratelimit cooldown={self._cooldown}>')
def _verify_cache(self):
current = time.time()
dead_keys = [k for k, v in self._cache.items()
if current > v._last + v.second]
for k in dead_keys:
del self._cache[k]
def get_bucket(self, key) -> RatelimitBucket:
if not self._cooldown:
return None
self._verify_cache()
if key not in self._cache:
bucket = self._cooldown.copy()
self._cache[key] = bucket
else:
bucket = self._cache[key]
return bucket

View File

@ -0,0 +1,83 @@
from quart import current_app as app, request, g
from litecord.errors import Ratelimited
from litecord.auth import token_check, Unauthorized
async def _check_bucket(bucket):
retry_after = bucket.update_rate_limit()
request.bucket = bucket
if retry_after:
request.retry_after = retry_after
raise Ratelimited('You are being rate limited.', {
'retry_after': int(retry_after * 1000),
'global': request.bucket_global,
})
async def _handle_global(ratelimit):
"""Global ratelimit is per-user."""
try:
user_id = await token_check()
except Unauthorized:
user_id = request.remote_addr
request.bucket_global = True
bucket = ratelimit.get_bucket(user_id)
await _check_bucket(bucket)
async def _handle_specific(ratelimit):
try:
user_id = await token_check()
except Unauthorized:
user_id = request.remote_addr
# construct the key based on the ratelimit.keys
keys = ratelimit.keys
# base key is the user id
key_components = [f'user_id:{user_id}']
for key in keys:
val = request.view_args[key]
key_components.append(f'{key}:{val}')
bucket_key = ':'.join(key_components)
bucket = ratelimit.get_bucket(bucket_key)
await _check_bucket(bucket)
async def ratelimit_handler():
"""Main ratelimit handler.
Decides on which ratelimit to use.
"""
rule = request.url_rule
if rule is None:
return await _handle_global(
app.ratelimiter.global_bucket
)
# rule.endpoint is composed of '<blueprint>.<function>'
# and so we can use that to make routes with different
# methods have different ratelimits
rule_path = rule.endpoint
# some request ratelimit context.
# TODO: maybe put those in a namedtuple or contextvar of sorts?
request.bucket = None
request.retry_after = None
request.bucket_global = False
try:
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
await _handle_specific(ratelimit)
except KeyError:
await _handle_global(
app.ratelimiter.global_bucket
)

View File

@ -0,0 +1,56 @@
from litecord.ratelimits.bucket import Ratelimit
"""
REST:
POST Message | 5/5s | per-channel
DELETE Message | 5/1s | per-channel
PUT/DELETE Reaction | 1/0.25s | per-channel
PATCH Member | 10/10s | per-guild
PATCH Member Nick | 1/1s | per-guild
PATCH Username | 2/3600s | per-account
|All Requests| | 50/1s | per-account
WS:
Gateway Connect | 1/5s | per-account
Presence Update | 5/60s | per-session
|All Sent Messages| | 120/60s | per-session
"""
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
RATELIMITS = {
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
# all of those share the same bucket.
'channel_reactions.add_reaction': REACTION_BUCKET,
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
# this only applies to username.
# 'users.patch_me': Ratelimit(2, 3600),
'_ws.connect': Ratelimit(1, 5),
'_ws.presence': Ratelimit(5, 60),
'_ws.messages': Ratelimit(120, 60),
# 1000 / 4h for new session issuing
'_ws.session': Ratelimit(1000, 14400)
}
class RatelimitManager:
"""Manager for the bucket managers"""
def __init__(self):
self._ratelimiters = {}
self.global_bucket = Ratelimit(50, 1)
self._fill_rtl()
def _fill_rtl(self):
for path, rtl in RATELIMITS.items():
self._ratelimiters[path] = rtl
def get_ratelimit(self, key: str) -> Ratelimit:
"""Get the :class:`Ratelimit` instance for a given path."""
return self._ratelimiters.get(key, self.global_bucket)

View File

@ -1,11 +1,16 @@
import re import re
from typing import Union, Dict, List, Any
from cerberus import Validator from cerberus import Validator
from logbook import Logger from logbook import Logger
from .errors import BadRequest from .errors import BadRequest
from .enums import ActivityType, StatusType, ExplicitFilter, \ from .permissions import Permissions
RelationshipType, MessageNotifications from .types import Color
from .enums import (
ActivityType, StatusType, ExplicitFilter, RelationshipType,
MessageNotifications, ChannelType, VerificationLevel
)
log = Logger(__name__) log = Logger(__name__)
@ -24,13 +29,21 @@ EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M) ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
def _in_enum(enum, value: int):
try:
enum(value)
return True
except ValueError:
return False
class LitecordValidator(Validator): class LitecordValidator(Validator):
def _validate_type_username(self, value: str) -> bool: def _validate_type_username(self, value: str) -> bool:
"""Validate against the username regex.""" """Validate against the username regex."""
return bool(USERNAME_REGEX.match(value)) return bool(USERNAME_REGEX.match(value))
def _validate_type_email(self, value: str) -> bool: def _validate_type_email(self, value: str) -> bool:
"""Validate against the username regex.""" """Validate against the email regex."""
return bool(EMAIL_REGEX.match(value)) return bool(EMAIL_REGEX.match(value))
def _validate_type_b64_icon(self, value: str) -> bool: def _validate_type_b64_icon(self, value: str) -> bool:
@ -56,11 +69,17 @@ class LitecordValidator(Validator):
def _validate_type_voice_region(self, value: str) -> bool: def _validate_type_voice_region(self, value: str) -> bool:
# TODO: complete this list # TODO: complete this list
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia') return value.lower() in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
def _validate_type_verification_level(self, value: int) -> bool:
return _in_enum(VerificationLevel, value)
def _validate_type_activity_type(self, value: int) -> bool: def _validate_type_activity_type(self, value: int) -> bool:
return value in ActivityType.values() return value in ActivityType.values()
def _validate_type_channel_type(self, value: int) -> bool:
return value in ChannelType.values()
def _validate_type_status_external(self, value: str) -> bool: def _validate_type_status_external(self, value: str) -> bool:
statuses = StatusType.values() statuses = StatusType.values()
@ -94,11 +113,31 @@ class LitecordValidator(Validator):
return val in MessageNotifications.values() return val in MessageNotifications.values()
def _validate_type_guild_name(self, value: str) -> bool:
return 2 <= len(value) <= 100
def validate(reqjson, schema, raise_err: bool = True): def _validate_type_role_name(self, value: str) -> bool:
return 1 <= len(value) <= 100
def _validate_type_channel_name(self, value: str) -> bool:
# for now, we'll use the same validation for guild_name
return self._validate_type_guild_name(value)
def validate(reqjson: Union[Dict, List], schema: Dict,
raise_err: bool = True) -> Union[Dict, List]:
"""Validate a given document (user-input) and give
the correct document as a result.
"""
validator = LitecordValidator(schema) validator = LitecordValidator(schema)
if not validator.validate(reqjson): try:
valid = validator.validate(reqjson)
except Exception:
log.exception('Error while validating')
raise Exception(f'Error while validating: {reqjson}')
if not valid:
errs = validator.errors errs = validator.errors
log.warning('Error validating doc {!r}: {!r}', reqjson, errs) log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
@ -146,16 +185,55 @@ USER_UPDATE = {
} }
PARTIAL_ROLE_GUILD_CREATE = {
'type': 'dict',
'schema': {
'name': {'type': 'role_name'},
'color': {'type': 'number', 'default': 0},
'hoist': {'type': 'boolean', 'default': False},
# NOTE: no position on partial role (on guild create)
'permissions': {'coerce': Permissions, 'required': False},
'mentionable': {'type': 'boolean', 'default': False},
}
}
PARTIAL_CHANNEL_GUILD_CREATE = {
'type': 'dict',
'schema': {
'name': {'type': 'channel_name'},
'type': {'type': 'channel_type'},
}
}
GUILD_CREATE = {
'name': {'type': 'guild_name'},
'region': {'type': 'voice_region'},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
'verification_level': {
'type': 'verification_level', 'default': 0},
'default_message_notifications': {
'type': 'msg_notifications', 'default': 0},
'explicit_content_filter': {
'type': 'explicit', 'default': 0},
'roles': {
'type': 'list', 'required': False,
'schema': PARTIAL_ROLE_GUILD_CREATE},
'channels': {
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
}
GUILD_UPDATE = { GUILD_UPDATE = {
'name': { 'name': {
'type': 'string', 'type': 'guild_name',
'minlength': 2,
'maxlength': 100,
'required': False 'required': False
}, },
'region': {'type': 'voice_region', 'required': False}, 'region': {'type': 'voice_region', 'required': False},
'icon': {'type': 'icon', 'required': False}, 'icon': {'type': 'b64_icon', 'required': False},
'verification_level': {'type': 'verification_level', 'required': False}, 'verification_level': {'type': 'verification_level', 'required': False},
'default_message_notifications': { 'default_message_notifications': {
@ -173,13 +251,93 @@ GUILD_UPDATE = {
} }
CHAN_OVERWRITE = {
'id': {'coerce': int},
'type': {'type': 'string', 'allowed': ['role', 'member']},
'allow': {'coerce': Permissions},
'deny': {'coerce': Permissions}
}
CHAN_UPDATE = {
'name': {
'type': 'string', 'minlength': 2,
'maxlength': 100, 'required': False},
'position': {'coerce': int, 'required': False},
'topic': {
'type': 'string', 'minlength': 0,
'maxlength': 1024, 'required': False},
'nsfw': {'type': 'boolean', 'required': False},
'rate_limit_per_user': {
'coerce': int, 'min': 0,
'max': 120, 'required': False},
'bitrate': {
'coerce': int, 'min': 8000,
# NOTE: 'max' is 96000 for non-vip guilds
'max': 128000, 'required': False},
'user_limit': {
# user_limit being 0 means infinite.
'coerce': int, 'min': 0,
'max': 99, 'required': False
},
'permission_overwrites': {
'type': 'list',
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE},
'required': False
},
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
}
ROLE_CREATE = {
'name': {'type': 'string', 'default': 'new role'},
'permissions': {'coerce': Permissions, 'nullable': True},
'color': {'coerce': Color, 'default': 0},
'hoist': {'type': 'boolean', 'default': False},
'mentionable': {'type': 'boolean', 'default': False},
}
ROLE_UPDATE = {
'name': {'type': 'string', 'required': False},
'permissions': {'coerce': Permissions, 'required': False},
'color': {'coerce': Color, 'required': False},
'hoist': {'type': 'boolean', 'required': False},
'mentionable': {'type': 'boolean', 'required': False},
}
ROLE_UPDATE_POSITION = {
'roles': {
'type': 'list',
'schema': {
'type': 'dict',
'schema': {
'id': {'coerce': int},
'position': {'coerce': int},
},
}
}
}
MEMBER_UPDATE = { MEMBER_UPDATE = {
'nick': { 'nick': {
'type': 'nickname', 'type': 'username',
'minlength': 1, 'maxlength': 100, 'minlength': 1, 'maxlength': 100,
'required': False, 'required': False,
}, },
'roles': {'type': 'list', 'required': False}, 'roles': {'type': 'list', 'required': False,
'schema': {'coerce': int}},
'mute': {'type': 'boolean', 'required': False}, 'mute': {'type': 'boolean', 'required': False},
'deaf': {'type': 'boolean', 'required': False}, 'deaf': {'type': 'boolean', 'required': False},
'channel_id': {'type': 'snowflake', 'required': False}, 'channel_id': {'type': 'snowflake', 'required': False},
@ -196,57 +354,60 @@ MESSAGE_CREATE = {
GW_ACTIVITY = { GW_ACTIVITY = {
'name': {'type': 'string', 'required': True}, 'type': 'dict',
'type': {'type': 'activity_type', 'required': True}, 'schema': {
'name': {'type': 'string', 'required': True},
'type': {'type': 'activity_type', 'required': True},
'url': {'type': 'string', 'required': False, 'nullable': True}, 'url': {'type': 'string', 'required': False, 'nullable': True},
'timestamps': { 'timestamps': {
'type': 'dict', 'type': 'dict',
'required': False, 'required': False,
'schema': { 'schema': {
'start': {'type': 'number', 'required': True}, 'start': {'type': 'number', 'required': True},
'end': {'type': 'number', 'required': True}, 'end': {'type': 'number', 'required': False},
},
}, },
},
'application_id': {'type': 'snowflake', 'required': False, 'application_id': {'type': 'snowflake', 'required': False,
'nullable': False}, 'nullable': False},
'details': {'type': 'string', 'required': False, 'nullable': True}, 'details': {'type': 'string', 'required': False, 'nullable': True},
'state': {'type': 'string', 'required': False, 'nullable': True}, 'state': {'type': 'string', 'required': False, 'nullable': True},
'party': { 'party': {
'type': 'dict', 'type': 'dict',
'required': False, 'required': False,
'schema': { 'schema': {
'id': {'type': 'snowflake', 'required': False}, 'id': {'type': 'snowflake', 'required': False},
'size': {'type': 'list', 'required': False}, 'size': {'type': 'list', 'required': False},
} }
}, },
'assets': { 'assets': {
'type': 'dict', 'type': 'dict',
'required': False, 'required': False,
'schema': { 'schema': {
'large_image': {'type': 'snowflake', 'required': False}, 'large_image': {'type': 'snowflake', 'required': False},
'large_text': {'type': 'string', 'required': False}, 'large_text': {'type': 'string', 'required': False},
'small_image': {'type': 'snowflake', 'required': False}, 'small_image': {'type': 'snowflake', 'required': False},
'small_text': {'type': 'string', 'required': False}, 'small_text': {'type': 'string', 'required': False},
} }
}, },
'secrets': { 'secrets': {
'type': 'dict', 'type': 'dict',
'required': False, 'required': False,
'schema': { 'schema': {
'join': {'type': 'string', 'required': False}, 'join': {'type': 'string', 'required': False},
'spectate': {'type': 'string', 'required': False}, 'spectate': {'type': 'string', 'required': False},
'match': {'type': 'string', 'required': False}, 'match': {'type': 'string', 'required': False},
} }
}, },
'instance': {'type': 'boolean', 'required': False}, 'instance': {'type': 'boolean', 'required': False},
'flags': {'type': 'number', 'required': False}, 'flags': {'type': 'number', 'required': False},
}
} }
GW_STATUS_UPDATE = { GW_STATUS_UPDATE = {
@ -335,6 +496,8 @@ USER_SETTINGS = {
'show_current_game': {'type': 'boolean', 'required': False}, 'show_current_game': {'type': 'boolean', 'required': False},
'timezone_offset': {'type': 'number', 'required': False}, 'timezone_offset': {'type': 'number', 'required': False},
'status': {'type': 'status_external', 'required': False}
} }
RELATIONSHIP = { RELATIONSHIP = {
@ -395,3 +558,7 @@ GUILD_SETTINGS = {
'required': False, 'required': False,
} }
} }
GUILD_PRUNE = {
'days': {'type': 'number', 'coerce': int, 'min': 1}
}

View File

@ -5,6 +5,9 @@ from logbook import Logger
from .enums import ChannelType, RelationshipType from .enums import ChannelType, RelationshipType
from .schemas import USER_MENTION, ROLE_MENTION from .schemas import USER_MENTION, ROLE_MENTION
from litecord.blueprints.channel.reactions import (
emoji_info_from_str, EmojiType, emoji_sql, partial_emoji
)
log = Logger(__name__) log = Logger(__name__)
@ -163,17 +166,40 @@ class Storage:
WHERE guild_id = $1 and user_id = $2 WHERE guild_id = $1 and user_id = $2
""", guild_id, member_id) """, guild_id, member_id)
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: async def get_member_role_ids(self, guild_id: int,
members_roles = await self.db.fetch(""" member_id: int) -> List[int]:
"""Get a list of role IDs that are on a member."""
roles = await self.db.fetch("""
SELECT role_id::text SELECT role_id::text
FROM member_roles FROM member_roles
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id) """, guild_id, member_id)
roles = [r['role_id'] for r in roles]
try:
roles.remove(str(guild_id))
except ValueError:
# if the @everyone role isn't in, we add it
# to member_roles automatically (it won't
# be shown on the API, though).
await self.db.execute("""
INSERT INTO member_roles (user_id, guild_id, role_id)
VALUES ($1, $2, $3)
""", member_id, guild_id, guild_id)
return list(map(str, roles))
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
roles = await self.get_member_role_ids(guild_id, member_id)
return { return {
'user': await self.get_user(member_id), 'user': await self.get_user(member_id),
'nick': row['nickname'], 'nick': row['nickname'],
'roles': [row[0] for row in members_roles],
# we don't send the @everyone role's id to
# the user since it is known that everyone has
# that role.
'roles': roles,
'joined_at': row['joined_at'].isoformat(), 'joined_at': row['joined_at'].isoformat(),
'deaf': row['deafened'], 'deaf': row['deafened'],
'mute': row['muted'], 'mute': row['muted'],
@ -289,7 +315,7 @@ class Storage:
WHERE channels.id = $1 WHERE channels.id = $1
""", channel_id) """, channel_id)
async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
overwrite_rows = await self.db.fetch(""" overwrite_rows = await self.db.fetch("""
SELECT target_type, target_role, target_user, allow, deny SELECT target_type, target_role, target_user, allow, deny
FROM channel_overwrites FROM channel_overwrites
@ -298,18 +324,20 @@ class Storage:
def _overwrite_convert(row): def _overwrite_convert(row):
drow = dict(row) drow = dict(row)
drow['type'] = drow['target_type']
target_type = drow['target_type']
drow['type'] = 'user' if target_type == 0 else 'role'
# if type is 0, the overwrite is for a user # if type is 0, the overwrite is for a user
# if type is 1, the overwrite is for a role # if type is 1, the overwrite is for a role
drow['id'] = { drow['id'] = {
0: drow['target_user'], 0: drow['target_user'],
1: drow['target_role'], 1: drow['target_role'],
}[drow['type']] }[target_type]
drow['id'] = str(drow['id']) drow['id'] = str(drow['id'])
drow.pop('overwrite_type') drow.pop('target_type')
drow.pop('target_user') drow.pop('target_user')
drow.pop('target_role') drow.pop('target_role')
@ -335,8 +363,8 @@ class Storage:
dbase['type'] = chan_type dbase['type'] = chan_type
res = await self._channels_extra(dbase) res = await self._channels_extra(dbase)
res['permission_overwrites'] = \ res['permission_overwrites'] = await self.chan_overwrites(
list(await self._chan_overwrites(channel_id)) channel_id)
res['id'] = str(res['id']) res['id'] = str(res['id'])
return res return res
@ -401,8 +429,8 @@ class Storage:
res = await self._channels_extra(drow) res = await self._channels_extra(drow)
res['permission_overwrites'] = \ res['permission_overwrites'] = await self.chan_overwrites(
list(await self._chan_overwrites(row['id'])) row['id'])
# Making sure. # Making sure.
res['id'] = str(res['id']) res['id'] = str(res['id'])
@ -440,6 +468,7 @@ class Storage:
permissions, managed, mentionable permissions, managed, mentionable
FROM roles FROM roles
WHERE guild_id = $1 WHERE guild_id = $1
ORDER BY position ASC
""", guild_id) """, guild_id)
return list(map(dict, roledata)) return list(map(dict, roledata))
@ -535,7 +564,70 @@ class Storage:
return res return res
async def get_message(self, message_id: int) -> Dict: async def get_reactions(self, message_id: int, user_id=None) -> List:
"""Get all reactions in a message."""
reactions = await self.db.fetch("""
SELECT user_id, emoji_type, emoji_id, emoji_text
FROM message_reactions
WHERE message_id = $1
ORDER BY react_ts
""", message_id)
# ordered list of emoji
emoji = []
# the current state of emoji info
react_stats = {}
# to generate the list, we pass through all
# all reactions and insert them all.
# we can't use a set() because that
# doesn't guarantee any order.
for row in reactions:
etype = EmojiType(row['emoji_type'])
eid, etext = row['emoji_id'], row['emoji_text']
# get the main key to use, given
# the emoji information
_, main_emoji = emoji_sql(etype, eid, etext)
if main_emoji in emoji:
continue
# maintain order (first reacted comes first
# on the reaction list)
emoji.append(main_emoji)
react_stats[main_emoji] = {
'count': 0,
'me': False,
'emoji': partial_emoji(etype, eid, etext)
}
# then the 2nd pass, where we insert
# the info for each reaction in the react_stats
# dictionary
for row in reactions:
etype = EmojiType(row['emoji_type'])
eid, etext = row['emoji_id'], row['emoji_text']
# same thing as the last loop,
# extracting main key
_, main_emoji = emoji_sql(etype, eid, etext)
stats = react_stats[main_emoji]
stats['count'] += 1
if row['user_id'] == user_id:
stats['me'] = True
# after processing reaction counts,
# we get them in the same order
# they were defined in the first loop.
return list(map(react_stats.get, emoji))
async def get_message(self, message_id: int, user_id=None) -> Dict:
"""Get a single message's payload.""" """Get a single message's payload."""
row = await self.db.fetchrow(""" row = await self.db.fetchrow("""
SELECT id::text, channel_id::text, author_id, content, SELECT id::text, channel_id::text, author_id, content,
@ -596,6 +688,8 @@ class Storage:
res['mention_roles'] = await self._msg_regex( res['mention_roles'] = await self._msg_regex(
ROLE_MENTION, _get_role_mention, content) ROLE_MENTION, _get_role_mention, content)
res['reactions'] = await self.get_reactions(message_id, user_id)
# TODO: handle webhook authors # TODO: handle webhook authors
res['author'] = await self.get_user(res['author_id']) res['author'] = await self.get_user(res['author_id'])
res.pop('author_id') res.pop('author_id')
@ -606,9 +700,6 @@ class Storage:
# TODO: res['embeds'] # TODO: res['embeds']
res['embeds'] = [] res['embeds'] = []
# TODO: res['reactions']
res['reactions'] = []
# TODO: res['pinned'] # TODO: res['pinned']
res['pinned'] = False res['pinned'] = False
@ -966,7 +1057,6 @@ class Storage:
""", user_id) """, user_id)
for row in settings: for row in settings:
print(dict(row))
gid = int(row['guild_id']) gid = int(row['guild_id'])
drow = dict(row) drow = dict(row)

15
litecord/types.py Normal file
View File

@ -0,0 +1,15 @@
class Color:
"""Custom color class"""
def __init__(self, val: int):
self.blue = val & 255
self.green = (val >> 8) & 255
self.red = (val >> 16) & 255
@property
def value(self):
"""Give the actual RGB integer encoding this color."""
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16)
def __int__(self):
return self.value

View File

@ -22,3 +22,18 @@ async def task_wrapper(name: str, coro):
pass pass
except: except:
log.exception('{} task error', name) log.exception('{} task error', name)
def dict_get(mapping, key, default):
"""Return `default` even when mapping[key] is None."""
return mapping.get(key) or default
def index_by_func(function, indexable: iter) -> int:
"""Search in an idexable and return the index number
for an iterm that has func(item) = True."""
for index, item in enumerate(indexable):
if function(item):
return index
return None

12
manage.py Executable file
View File

@ -0,0 +1,12 @@
#!/usr/bin/env python3
import logging
import sys
from manage.main import main
import config
logging.basicConfig(level=logging.DEBUG)
if __name__ == '__main__':
sys.exit(main(config))

0
manage/__init__.py Normal file
View File

View File

@ -0,0 +1 @@
from .command import setup as migration

View File

@ -0,0 +1,143 @@
import inspect
from pathlib import Path
from dataclasses import dataclass
from collections import namedtuple
from typing import Dict
import asyncpg
from logbook import Logger
log = Logger(__name__)
Migration = namedtuple('Migration', 'id name path')
@dataclass
class MigrationContext:
"""Hold information about migration."""
migration_folder: Path
scripts: Dict[int, Migration]
@property
def latest(self):
"""Return the latest migration ID."""
return max(self.scripts.keys())
def make_migration_ctx() -> MigrationContext:
"""Create the MigrationContext instance."""
# taken from https://stackoverflow.com/a/6628348
script_path = inspect.stack()[0][1]
script_folder = '/'.join(script_path.split('/')[:-1])
script_folder = Path(script_folder)
migration_folder = script_folder / 'scripts'
mctx = MigrationContext(migration_folder, {})
for mig_path in migration_folder.glob('*.sql'):
mig_path_str = str(mig_path)
# extract migration script id and name
mig_filename = mig_path_str.split('/')[-1].split('.')[0]
name_fragments = mig_filename.split('_')
mig_id = int(name_fragments[0])
mig_name = '_'.join(name_fragments[1:])
mctx.scripts[mig_id] = Migration(
mig_id, mig_name, mig_path)
return mctx
async def _ensure_changelog(app, ctx):
# make sure we have the migration table up
try:
await app.db.execute("""
CREATE TABLE migration_log (
change_num bigint NOT NULL,
apply_ts timestamp without time zone default
(now() at time zone 'utc'),
description text,
PRIMARY KEY (change_num)
);
""")
# if we were able to create the
# migration_log table, insert that we are
# on the latest version.
await app.db.execute("""
INSERT INTO migration_log (change_num, description)
VALUES ($1, $2)
""", ctx.latest, 'migration setup')
except asyncpg.DuplicateTableError:
log.debug('existing migration table')
async def apply_migration(app, migration: Migration):
"""Apply a single migration."""
migration_sql = migration.path.read_text(encoding='utf-8')
try:
await app.db.execute("""
INSERT INTO migration_log (change_num, description)
VALUES ($1, $2)
""", migration.id, f'migration: {migration.name}')
except asyncpg.UniqueViolationError:
log.warning('already applied {}', migration.id)
return
await app.db.execute(migration_sql)
log.info('applied {}', migration.id)
async def migrate_cmd(app, _args):
"""Main migration command.
This makes sure the database
is updated.
"""
ctx = make_migration_ctx()
await _ensure_changelog(app, ctx)
# local point in the changelog
local_change = await app.db.fetchval("""
SELECT max(change_num)
FROM migration_log
""")
local_change = local_change or 0
latest_change = ctx.latest
log.debug('local: {}, latest: {}', local_change, latest_change)
if local_change == latest_change:
print('no changes to do, exiting')
return
# we do local_change + 1 so we start from the
# next migration to do, end in latest_change + 1
# because of how range() works.
for idx in range(local_change + 1, latest_change + 1):
migration = ctx.scripts.get(idx)
print('applying', migration.id, migration.name)
await apply_migration(app, migration)
def setup(subparser):
migrate_parser = subparser.add_parser(
'migrate',
help='Run migration tasks',
description=migrate_cmd.__doc__
)
migrate_parser.set_defaults(func=migrate_cmd)

View File

@ -0,0 +1,6 @@
-- unused tables
DROP TABLE message_embeds;
DROP TABLE embeds;
ALTER TABLE messages
ADD COLUMN embeds jsonb DEFAULT '[]'

58
manage/main.py Normal file
View File

@ -0,0 +1,58 @@
import asyncio
import argparse
from sys import argv
from dataclasses import dataclass
from logbook import Logger
from run import init_app_managers, init_app_db
from manage.cmd.migration import migration
log = Logger(__name__)
@dataclass
class FakeApp:
"""Fake app instance."""
config: dict
db = None
loop: asyncio.BaseEventLoop = None
ratelimiter = None
state_manager = None
storage = None
dispatcher = None
presence = None
def init_parser():
parser = argparse.ArgumentParser()
subparser = parser.add_subparsers(help='operations')
migration(subparser)
return parser
def main(config):
"""Start the script"""
loop = asyncio.get_event_loop()
cfg = getattr(config, config.MODE)
app = FakeApp(cfg.__dict__)
loop.run_until_complete(init_app_db(app))
init_app_managers(app)
# initialize argparser
parser = init_parser()
try:
if len(argv) < 2:
parser.print_help()
return
args = parser.parse_args()
loop.run_until_complete(args.func(app, args))
except Exception:
log.exception('error while running command')
finally:
loop.run_until_complete(app.db.close())

View File

@ -5,13 +5,11 @@ server {
location / { location / {
proxy_pass http://localhost:5000; proxy_pass http://localhost:5000;
} }
}
# Main litecord websocket proxy. # if you don't want to keep the gateway
server { # domain as the main domain, you can
server_name websocket.somewhere; # keep a separate server block
location /ws {
location / {
proxy_pass http://localhost:5001; proxy_pass http://localhost:5001;
# those options are required for websockets # those options are required for websockets

106
run.py
View File

@ -9,9 +9,28 @@ from quart import Quart, g, jsonify, request
from logbook import StreamHandler, Logger from logbook import StreamHandler, Logger
from logbook.compat import redirect_logging from logbook.compat import redirect_logging
# import the config set by instance owner
import config import config
from litecord.blueprints import gateway, auth, users, guilds, channels, \
webhooks, science, voice, invites, relationships, dms from litecord.blueprints import (
gateway, auth, users, guilds, channels, webhooks, science,
voice, invites, relationships, dms
)
# those blueprints are separated from the "main" ones
# for code readability if people want to dig through
# the codebase.
from litecord.blueprints.guild import (
guild_roles, guild_members, guild_channels, guild_mod
)
from litecord.blueprints.channel import (
channel_messages, channel_reactions, channel_pins
)
from litecord.ratelimits.handler import ratelimit_handler
from litecord.ratelimits.main import RatelimitManager
from litecord.gateway import websocket_handler from litecord.gateway import websocket_handler
from litecord.errors import LitecordError from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager from litecord.gateway.state_manager import StateManager
@ -50,8 +69,18 @@ bps = {
auth: '/auth', auth: '/auth',
users: '/users', users: '/users',
relationships: '/users', relationships: '/users',
guilds: '/guilds', guilds: '/guilds',
guild_roles: '/guilds',
guild_members: '/guilds',
guild_channels: '/guilds',
guild_mod: '/guilds',
channels: '/channels', channels: '/channels',
channel_messages: '/channels',
channel_reactions: '/channels',
channel_pins: '/channels',
webhooks: None, webhooks: None,
science: None, science: None,
voice: '/voice', voice: '/voice',
@ -64,6 +93,11 @@ for bp, suffix in bps.items():
app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}') app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}')
@app.before_request
async def app_before_request():
await ratelimit_handler()
@app.after_request @app.after_request
async def app_after_request(resp): async def app_after_request(resp):
origin = request.headers.get('Origin', '*') origin = request.headers.get('Origin', '*')
@ -80,19 +114,44 @@ async def app_after_request(resp):
# resp.headers['Access-Control-Allow-Methods'] = '*' # resp.headers['Access-Control-Allow-Methods'] = '*'
resp.headers['Access-Control-Allow-Methods'] = \ resp.headers['Access-Control-Allow-Methods'] = \
resp.headers.get('allow', '*') resp.headers.get('allow', '*')
return resp return resp
@app.before_serving @app.after_request
async def app_before_serving(): async def app_set_ratelimit_headers(resp):
log.info('opening db') """Set the specific ratelimit headers."""
try:
bucket = request.bucket
if bucket is None:
raise AttributeError()
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second)
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower()
# only add Retry-After if we actually hit a ratelimit
retry_after = request.retry_after
if request.retry_after:
resp.headers['Retry-After'] = str(retry_after)
except AttributeError:
pass
return resp
async def init_app_db(app):
"""Connect to databases"""
app.db = await asyncpg.create_pool(**app.config['POSTGRES']) app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
g.app = app
def init_app_managers(app):
"""Initialize singleton classes."""
app.loop = asyncio.get_event_loop() app.loop = asyncio.get_event_loop()
g.loop = asyncio.get_event_loop() app.ratelimiter = RatelimitManager()
app.state_manager = StateManager() app.state_manager = StateManager()
app.storage = Storage(app.db) app.storage = Storage(app.db)
@ -101,6 +160,17 @@ async def app_before_serving():
app.state_manager, app.dispatcher) app.state_manager, app.dispatcher)
app.storage.presence = app.presence app.storage.presence = app.presence
@app.before_serving
async def app_before_serving():
log.info('opening db')
await init_app_db(app)
g.app = app
g.loop = asyncio.get_event_loop()
init_app_managers(app)
# start the websocket, etc # start the websocket, etc
host, port = app.config['WS_HOST'], app.config['WS_PORT'] host, port = app.config['WS_HOST'], app.config['WS_PORT']
log.info(f'starting websocket at {host} {port}') log.info(f'starting websocket at {host} {port}')
@ -108,8 +178,11 @@ async def app_before_serving():
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
# so we can pass quart's app object. # so we can pass quart's app object.
# TODO: pass just the app object
await websocket_handler((app.db, app.state_manager, app.storage, await websocket_handler((app.db, app.state_manager, app.storage,
app.loop, app.dispatcher, app.presence), app.loop, app.dispatcher, app.presence,
app.ratelimiter),
ws, url) ws, url)
ws_future = websockets.serve(_wrapper, host, port) ws_future = websockets.serve(_wrapper, host, port)
@ -119,6 +192,15 @@ async def app_before_serving():
@app.after_serving @app.after_serving
async def app_after_serving(): async def app_after_serving():
"""Shutdown tasks for the server."""
# first close all clients, then close db
tasks = app.state_manager.gen_close_tasks()
if tasks:
await asyncio.wait(tasks, loop=app.loop)
app.state_manager.close()
log.info('closing db') log.info('closing db')
await app.db.close() await app.db.close()
@ -130,9 +212,13 @@ async def handle_litecord_err(err):
except IndexError: except IndexError:
ejson = {} ejson = {}
try:
ejson['code'] = err.error_code
except AttributeError:
pass
return jsonify({ return jsonify({
'error': True, 'error': True,
# 'code': err.code,
'status': err.status_code, 'status': err.status_code,
'message': err.message, 'message': err.message,
**ejson **ejson

View File

@ -75,6 +75,9 @@ CREATE TABLE IF NOT EXISTS users (
phone varchar(60) DEFAULT '', phone varchar(60) DEFAULT '',
password_hash text NOT NULL, password_hash text NOT NULL,
-- store the last time the user logged in via the gateway
last_session timestamp without time zone default (now() at time zone 'utc'),
PRIMARY KEY (id, username, discriminator) PRIMARY KEY (id, username, discriminator)
); );
@ -131,6 +134,10 @@ CREATE TABLE IF NOT EXISTS user_settings (
-- appearance -- appearance
message_display_compact bool DEFAULT false, message_display_compact bool DEFAULT false,
-- for now we store status but don't
-- actively use it, since the official client
-- sends its own presence on IDENTIFY
status text DEFAULT 'online' NOT NULL, status text DEFAULT 'online' NOT NULL,
theme text DEFAULT 'dark' NOT NULL, theme text DEFAULT 'dark' NOT NULL,
developer_mode bool DEFAULT true, developer_mode bool DEFAULT true,
@ -328,7 +335,7 @@ CREATE TABLE IF NOT EXISTS channel_overwrites (
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
-- target_type = 0 -> use target_user -- target_type = 0 -> use target_user
-- target_type = 1 -> user target_role -- target_type = 1 -> use target_role
-- discord already has overwrite.type = 'role' | 'member' -- discord already has overwrite.type = 'role' | 'member'
-- so this allows us to be more compliant with the API -- so this allows us to be more compliant with the API
target_type integer default null, target_type integer default null,
@ -344,11 +351,15 @@ CREATE TABLE IF NOT EXISTS channel_overwrites (
-- they're bigints (64bits), discord, -- they're bigints (64bits), discord,
-- for now, only needs 53. -- for now, only needs 53.
allow bigint DEFAULT 0, allow bigint DEFAULT 0,
deny bigint DEFAULT 0, deny bigint DEFAULT 0
PRIMARY KEY (channel_id, target_role, target_user)
); );
-- columns in private keys can't have NULL values,
-- so instead we use a custom constraint with UNIQUE
ALTER TABLE channel_overwrites ADD CONSTRAINT channel_overwrites_uniq
UNIQUE (channel_id, target_role, target_user);
CREATE TABLE IF NOT EXISTS features ( CREATE TABLE IF NOT EXISTS features (
id serial PRIMARY KEY, id serial PRIMARY KEY,
@ -479,11 +490,6 @@ CREATE TABLE IF NOT EXISTS bans (
); );
CREATE TABLE IF NOT EXISTS embeds (
-- TODO: this table
id bigint PRIMARY KEY
);
CREATE TABLE IF NOT EXISTS messages ( CREATE TABLE IF NOT EXISTS messages (
id bigint PRIMARY KEY, id bigint PRIMARY KEY,
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
@ -504,6 +510,8 @@ CREATE TABLE IF NOT EXISTS messages (
tts bool default false, tts bool default false,
mention_everyone bool default false, mention_everyone bool default false,
embeds jsonb DEFAULT '[]',
nonce bigint default 0, nonce bigint default 0,
message_type int NOT NULL message_type int NOT NULL
@ -515,22 +523,22 @@ CREATE TABLE IF NOT EXISTS message_attachments (
PRIMARY KEY (message_id, attachment) PRIMARY KEY (message_id, attachment)
); );
CREATE TABLE IF NOT EXISTS message_embeds (
message_id bigint REFERENCES messages (id) UNIQUE,
embed_id bigint REFERENCES embeds (id),
PRIMARY KEY (message_id, embed_id)
);
CREATE TABLE IF NOT EXISTS message_reactions ( CREATE TABLE IF NOT EXISTS message_reactions (
message_id bigint REFERENCES messages (id), message_id bigint REFERENCES messages (id),
user_id bigint REFERENCES users (id), user_id bigint REFERENCES users (id),
-- since it can be a custom emote, or unicode emoji react_ts timestamp without time zone default (now() at time zone 'utc'),
-- emoji_type = 0 -> custom emoji
-- emoji_type = 1 -> unicode emoji
emoji_type int DEFAULT 0,
emoji_id bigint REFERENCES guild_emoji (id), emoji_id bigint REFERENCES guild_emoji (id),
emoji_text text NOT NULL, emoji_text text
PRIMARY KEY (message_id, user_id, emoji_id, emoji_text)
); );
ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq
UNIQUE (message_id, user_id, emoji_id, emoji_text);
CREATE TABLE IF NOT EXISTS channel_pins ( CREATE TABLE IF NOT EXISTS channel_pins (
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
message_id bigint REFERENCES messages (id) ON DELETE CASCADE, message_id bigint REFERENCES messages (id) ON DELETE CASCADE,