basic invites

- add invites blueprint
 - add POST /api/v6/channels/<channel_id>/invites
 - add GET /api/v6/invites/<invite_code>
 - schema: add INVITE
 - storage: add get_member_ids, get_invite, get_invite_extra,
 get_invite_metadata
This commit is contained in:
Luna Mendes 2018-09-26 20:29:22 -03:00
parent ea9bb52e4f
commit dc62de37b2
5 changed files with 183 additions and 6 deletions

View File

@ -6,3 +6,4 @@ from .channels import bp as channels
from .webhooks import bp as webhooks
from .science import bp as science
from .voice import bp as voice
from .invites import bp as invites

View File

@ -0,0 +1,56 @@
import hashlib
import os
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from ..auth import token_check
from ..schemas import validate, INVITE
from ..enums import ChannelType
from ..errors import BadRequest
from .channels import channel_check
log = Logger(__name__)
bp = Blueprint('invites', __name__)
@bp.route('/channels/<int:channel_id>/invites', methods=['POST'])
async def create_invite(channel_id):
user_id = await token_check()
j = validate(await request.get_json(), INVITE)
guild_id = await channel_check(user_id, channel_id)
# TODO: check CREATE_INSTANT_INVITE permission
chantype = await app.storage.get_chan_type(channel_id)
if chantype not in (ChannelType.GUILD_TEXT.value,
ChannelType.GUILD_VOICE.value):
raise BadRequest('Invalid channel type')
invite_code = hashlib.md5(os.urandom(64)).hexdigest()[:16]
await app.db.execute(
"""
INSERT INTO invites
(code, guild_id, channel_id, inviter, max_uses,
max_age, temporary)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
invite_code, guild_id, channel_id, user_id,
j['max_uses'], j['max_age'], j['temporary']
)
invite = await app.storage.get_invite(invite_code)
return jsonify(invite)
@bp.route('/invites/<invite_code>', methods=['GET'])
async def get_invite(invite_code: str):
inv = await app.storage.get_invite(invite_code)
if request.args.get('with_counts'):
extra = await app.storage.get_invite_extra(invite_code)
inv.update(extra)
return jsonify(inv)

View File

@ -1,10 +1,14 @@
import re
from cerberus import Validator
from logbook import Logger
from .errors import BadRequest
from .enums import ActivityType, StatusType
log = Logger(__name__)
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
re.A)
@ -48,6 +52,8 @@ class LitecordValidator(Validator):
def validate(reqjson, schema, raise_err: bool = True):
validator = LitecordValidator(schema)
log.debug('Validating {}', reqjson)
if not validator.validate(reqjson):
errs = validator.errors
@ -174,3 +180,31 @@ GW_STATUS_UPDATE = {
'schema': GW_ACTIVITY,
},
}
INVITE = {
# max_age in seconds
# 0 for infinite
'max_age': {
'type': 'number',
'min': 0,
'max': 86400,
# a day
'default': 86400
},
# max invite uses
'max_uses': {
'type': 'number',
'min': 0,
# idk
'max': 1000,
# default infinite
'default': 0
},
'temporary': {'type': 'boolean', 'required': False, 'default': False},
'unique': {'type': 'boolean', 'required': False, 'default': True},
}

View File

@ -13,6 +13,10 @@ async def _dummy(any_id):
return str(any_id)
def dict_(val):
return dict(val) if val else None
class Storage:
"""Class for common SQL statements."""
def __init__(self, db):
@ -73,10 +77,9 @@ class Storage:
drow['system_channel_id'] = str(drow['system_channel_id']) \
if drow['system_channel_id'] else None
return {**drow, **{
# TODO: those
'emojis': [],
}}
# TODO: emojis
drow['emojis'] = []
return drow
async def get_user_guilds(self, user_id: int) -> List[int]:
"""Get all guild IDs a user is on."""
@ -320,6 +323,15 @@ class Storage:
),
}}
async def get_member_ids(self, guild_id: int) -> List[int]:
rows = await self.db.fetch("""
SELECT user_id
FROM members
WHERE guild_id = $1
""", guild_id)
return [r[0] for r in rows]
async def _msg_regex(self, regex, method, content) -> List[Dict]:
res = []
@ -393,3 +405,76 @@ class Storage:
return {str(row['target_id']): row['note']
for row in note_rows}
async def get_invite(self, invite_code: str) -> dict:
"""Fetch invite information given its code."""
invite = await self.db.fetchrow("""
SELECT code, guild_id, channel_id
FROM invites
WHERE code = $1
""", invite_code)
if invite is None:
return None
dinv = dict_(invite)
# fetch some guild info
guild = await self.db.fetchrow("""
SELECT id::text, name, splash, icon
FROM guilds
WHERE id = $1
""", invite['guild_id'])
dinv['guild'] = dict(guild)
chan = await self.get_channel(invite['channel_id'])
dinv['channel'] = {
'id': chan['id'],
'name': chan['name'],
'type': chan['type'],
}
dinv.pop('guild_id')
dinv.pop('channel_id')
return dinv
async def get_invite_extra(self, invite_code: str) -> dict:
"""Extra information about the invite, such as
approximate guild and presence counts."""
guild_id = await self.db.fetchval("""
SELECT guild_id
FROM invites
WHERE code = $1
""", invite_code)
if guild_id is None:
return {}
mids = await self.get_member_ids(guild_id)
pres = await self.presence.guild_presences(mids, guild_id)
online_count = sum(1 for p in pres if p['status'] == 'online')
return {
'approximate_presence_count': online_count,
'approximate_member_count': len(mids),
}
async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]:
"""Fetch invite metadata (max_age and friends)."""
invite = await self.db.fetchrow("""
SELECT code, inviter, created_at, uses,
max_uses, max_age, temporary, created_at, revoked
FROM invites
WHERE code = $1
""", invite_code)
if invite is None:
return
dinv = dict_(invite)
inviter = await self.get_user(invite['inviter'])
dinv['inviter'] = inviter
return dinv

5
run.py
View File

@ -11,7 +11,7 @@ from logbook.compat import redirect_logging
import config
from litecord.blueprints import gateway, auth, users, guilds, channels, \
webhooks, science, voice
webhooks, science, voice, invites
from litecord.gateway import websocket_handler
from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager
@ -53,7 +53,8 @@ bps = {
channels: '/channels',
webhooks: None,
science: None,
voice: '/voice'
voice: '/voice',
invites: None
}
for bp, suffix in bps.items():