mirror of https://gitlab.com/litecord/litecord.git
blueprints.channels: add message sending
It is clunky when sending messages as Atomic, nor we have proper channel management, but it works. - dispatcher: add sub_many - errors: add failsafe on LitecordError.message - errors: add ChannelNotFound - gateway.websocket: add dispatcher to WebsocketObjects - schemas: add regexes for mentions - storage: add get_channel, get_message
This commit is contained in:
parent
4ea3d353b3
commit
59127ad197
|
|
@ -1,39 +1,50 @@
|
|||
import time
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
||||
from ..auth import token_check
|
||||
from ..snowflake import get_snowflake
|
||||
from ..enums import ChannelType
|
||||
from ..errors import Forbidden, BadRequest, MessageNotFound
|
||||
from ..schemas import validate
|
||||
from ..enums import ChannelType, MessageType
|
||||
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
|
||||
from ..schemas import validate, MESSAGE_CREATE
|
||||
|
||||
from .guilds import guild_check
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channels', __name__)
|
||||
|
||||
|
||||
async def channel_check(user_id, channel_id):
|
||||
ctype = await app.db.fetchval("""
|
||||
SELECT channel_type
|
||||
FROM channels
|
||||
WHERE channels.id = $1
|
||||
""", channel_id)
|
||||
"""Check if the current user is authorized
|
||||
to read the channel's information."""
|
||||
ctype = await app.storage.get_chan_type(channel_id)
|
||||
|
||||
if ctype is None:
|
||||
raise ChannelNotFound(f'channel type not found')
|
||||
|
||||
if ctype in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
||||
ChannelType.GUILD_CATEGORY):
|
||||
guild_id = await app.db.fetchval("""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
WHERE channel_id = $1
|
||||
WHERE guild_channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
return guild_id
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>', methods=['GET'])
|
||||
async def get_channel(channel_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
return '', 204
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
if not chan:
|
||||
raise ChannelNotFound('single channel not found')
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
||||
|
|
@ -43,13 +54,27 @@ async def get_messages(channel_id):
|
|||
|
||||
# TODO: before, after, around keys
|
||||
|
||||
await app.db.fetch(f"""
|
||||
SELECT *
|
||||
message_ids = await app.db.fetch(f"""
|
||||
SELECT id
|
||||
FROM messages
|
||||
WHERE channel_id = $1
|
||||
ORDER BY id ASC
|
||||
ORDER BY id DESC
|
||||
LIMIT 100
|
||||
""")
|
||||
""", channel_id)
|
||||
|
||||
result = []
|
||||
|
||||
for message_id in message_ids:
|
||||
msg = await app.storage.get_message(message_id['id'])
|
||||
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
|
||||
log.info('Fetched {} messages', len(result))
|
||||
print(result)
|
||||
return jsonify(result)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
||||
|
|
@ -59,25 +84,50 @@ async def get_single_message(channel_id, message_id):
|
|||
|
||||
# TODO: check READ_MESSAGE_HISTORY permissions
|
||||
|
||||
message = await app.db.fetchrow("""
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE channel_id = $1 AND messages.id = $2
|
||||
""", channel_id, message_id)
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
if not message:
|
||||
raise MessageNotFound()
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||
async def create_message(channel_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||
message_id = get_snowflake()
|
||||
|
||||
# TODO: check SEND_MESSAGES permission
|
||||
# TODO: check SEND_TTS_MESSAGES
|
||||
# TODO: check connection to the gateway
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
||||
mention_everyone, nonce, message_type)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""", message_id, channel_id, user_id, j['content'], j.get('tts', False),
|
||||
'@everyone' in j['content'], j.get('nonce', 0), MessageType.DEFAULT)
|
||||
|
||||
# TODO: parse payload, make schema
|
||||
# TODO: insert and dispatch message
|
||||
# TODO: dispatch_channel
|
||||
payload = await app.storage.get_message(message_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
|
||||
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/typing', methods=['POST'])
|
||||
async def trigger_typing(channel_id):
|
||||
user_id = await token_check()
|
||||
guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', {
|
||||
'channel_id': channel_id,
|
||||
'user_id': user_id,
|
||||
'timestamp': int(time.time()),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ async def guild_check(user_id: int, guild_id: int):
|
|||
""", user_id, guild_id)
|
||||
|
||||
if not joined_at:
|
||||
raise GuildNotFound()
|
||||
raise GuildNotFound('guild not found')
|
||||
|
||||
|
||||
async def guild_owner_check(user_id: int, guild_id: int):
|
||||
|
|
|
|||
|
|
@ -24,6 +24,11 @@ class EventDispatcher:
|
|||
"""Reset the guild bucket."""
|
||||
self.guild_buckets[guild_id] = set()
|
||||
|
||||
def sub_many(self, user_id: int, guild_ids: list):
|
||||
"""Subscribe to many guilds at a time."""
|
||||
for guild_id in guild_ids:
|
||||
self.sub_guild(guild_id, user_id)
|
||||
|
||||
async def dispatch_guild(self, guild_id: int,
|
||||
event_name: str, event_payload: Any):
|
||||
"""Dispatch an event to a guild"""
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ class LitecordError(Exception):
|
|||
|
||||
@property
|
||||
def message(self):
|
||||
try:
|
||||
return self.args[0]
|
||||
except IndexError:
|
||||
return repr(self)
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
|
|
@ -30,6 +33,10 @@ class GuildNotFound(LitecordError):
|
|||
status_code = 404
|
||||
|
||||
|
||||
class ChannelNotFound(LitecordError):
|
||||
status_code = 404
|
||||
|
||||
|
||||
class MessageNotFound(LitecordError):
|
||||
status_code = 404
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ WebsocketProperties = collections.namedtuple(
|
|||
)
|
||||
|
||||
WebsocketObjects = collections.namedtuple(
|
||||
'WebsocketObjects', 'db state_manager storage loop'
|
||||
'WebsocketObjects', 'db state_manager storage loop dispatcher'
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -81,7 +81,6 @@ class GatewayWebsocket:
|
|||
if not isinstance(encoded, bytes):
|
||||
encoded = encoded.encode()
|
||||
|
||||
print(self.wsp.compress)
|
||||
if self.wsp.compress == 'zlib-stream':
|
||||
data1 = self.wsp.zctx.compress(encoded)
|
||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||
|
|
@ -246,6 +245,17 @@ class GatewayWebsocket:
|
|||
if current_shard > shard_count:
|
||||
raise InvalidShard('Shard count > Total shards')
|
||||
|
||||
async def subscribe_guilds(self):
|
||||
"""Subscribe to all available guilds"""
|
||||
guild_ids = await self.ext.db.fetch("""
|
||||
SELECT guild_id
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", self.state.user_id)
|
||||
|
||||
guild_ids = [r['guild_id'] for r in guild_ids]
|
||||
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
|
||||
|
||||
async def handle_1(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 1 Heartbeat packets."""
|
||||
pass
|
||||
|
|
@ -291,6 +301,7 @@ class GatewayWebsocket:
|
|||
|
||||
self.ext.state_manager.insert(self.state)
|
||||
await self.dispatch_ready()
|
||||
await self.subscribe_guilds()
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 3 Status Update."""
|
||||
|
|
@ -386,6 +397,9 @@ class GatewayWebsocket:
|
|||
log.warning('closed a client, state={} err={}', self.state, err)
|
||||
|
||||
await self.ws.close(code=err.code, reason=err.reason)
|
||||
except Exception as err:
|
||||
log.exception('An exception has occoured. state={}', self.state)
|
||||
await self.ws.close(code=4000, reason=repr(err))
|
||||
finally:
|
||||
if self.state:
|
||||
self.state.ws = None
|
||||
|
|
|
|||
|
|
@ -8,6 +8,15 @@ USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
|
|||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
||||
re.A)
|
||||
|
||||
|
||||
# collection of regexes
|
||||
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M)
|
||||
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M)
|
||||
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M)
|
||||
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
|
||||
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
|
||||
|
||||
|
||||
class LitecordValidator(Validator):
|
||||
def _validate_type_username(self, value: str) -> bool:
|
||||
"""Validate against the username regex."""
|
||||
|
|
@ -61,3 +70,11 @@ MEMBER_UPDATE = {
|
|||
'deaf': {'type': 'bool', 'required': False},
|
||||
'channel_id': {'type': 'snowflake', 'required': False},
|
||||
}
|
||||
|
||||
MESSAGE_CREATE = {
|
||||
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
|
||||
'nonce': {'type': 'number', 'required': False},
|
||||
'tts': {'type': 'boolean', 'required': False},
|
||||
|
||||
# TODO: file, embed, payload_json
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
from typing import List, Dict, Any
|
||||
|
||||
from .enums import ChannelType
|
||||
from .schemas import USER_MENTION, ROLE_MENTION
|
||||
|
||||
|
||||
async def _dummy(any_id):
|
||||
return str(any_id)
|
||||
|
||||
|
||||
class Storage:
|
||||
|
|
@ -26,8 +31,8 @@ class Storage:
|
|||
|
||||
if not secure:
|
||||
duser.pop('email')
|
||||
duser.pop('mfa_enabled')
|
||||
duser.pop('verified')
|
||||
duser.pop('mfa_enabled')
|
||||
|
||||
return duser
|
||||
|
||||
|
|
@ -174,6 +179,51 @@ class Storage:
|
|||
|
||||
return {**row, **dict(vrow)}
|
||||
|
||||
async def get_chan_type(self, channel_id) -> int:
|
||||
return await self.db.fetchval("""
|
||||
SELECT channel_type
|
||||
FROM channels
|
||||
WHERE channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
async def _chan_overwrites(self, channel_id):
|
||||
overwrite_rows = await self.db.fetch("""
|
||||
SELECT target_id::text AS id, overwrite_type, allow, deny
|
||||
FROM channel_overwrites
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
|
||||
def _overwrite_convert(ov_row):
|
||||
drow = dict(ov_row)
|
||||
drow['type'] = drow['overwrite_type']
|
||||
drow.pop('overwrite_type')
|
||||
return drow
|
||||
|
||||
return map(_overwrite_convert, overwrite_rows)
|
||||
|
||||
async def get_channel(self, channel_id) -> Dict[str, Any]:
|
||||
"""Fetch a single channel's information."""
|
||||
chan_type = await self.get_chan_type(channel_id)
|
||||
|
||||
if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
||||
ChannelType.GUILD_CATEGORY):
|
||||
base = await self.db.fetchrow("""
|
||||
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
||||
FROM guild_channels
|
||||
WHERE guild_channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
res = await self._channels_extra(dict(base), chan_type)
|
||||
res['type'] = chan_type
|
||||
res['permission_overwrites'] = \
|
||||
list(await self._chan_overwrites(channel_id))
|
||||
|
||||
res['id'] = str(res['id'])
|
||||
return res
|
||||
else:
|
||||
# TODO: dms and group dms
|
||||
pass
|
||||
|
||||
async def get_channel_data(self, guild_id) -> List[Dict]:
|
||||
"""Get channel information on a guild"""
|
||||
channel_basics = await self.db.fetch("""
|
||||
|
|
@ -193,22 +243,8 @@ class Storage:
|
|||
res = await self._channels_extra(dict(row), ctype)
|
||||
res['type'] = ctype
|
||||
|
||||
# type is a SQL keyword, so we can't do
|
||||
# 'overwrite_type AS type'
|
||||
overwrite_rows = await self.db.fetch("""
|
||||
SELECT target_id::text AS id, overwrite_type, allow, deny
|
||||
FROM channel_overwrites
|
||||
WHERE channel_id = $1
|
||||
""", row['id'])
|
||||
|
||||
def _overwrite_convert(ov_row):
|
||||
drow = dict(ov_row)
|
||||
drow['type'] = drow['overwrite_type']
|
||||
drow.pop('overwrite_type')
|
||||
return drow
|
||||
|
||||
res['permission_overwrites'] = list(map(_overwrite_convert,
|
||||
overwrite_rows))
|
||||
res['permission_overwrites'] = \
|
||||
list(await self._chan_overwrites(row['id']))
|
||||
|
||||
# Making sure.
|
||||
res['id'] = str(res['id'])
|
||||
|
|
@ -262,6 +298,69 @@ class Storage:
|
|||
'voice_states': [],
|
||||
'channels': channels,
|
||||
'roles': roles,
|
||||
# TODO: finish those
|
||||
|
||||
# TODO: finish presences
|
||||
'presences': [],
|
||||
}}
|
||||
|
||||
async def _msg_regex(self, regex, method, content) -> List[Dict]:
|
||||
res = []
|
||||
|
||||
for match in regex.finditer(content):
|
||||
found_id = match.group(1)
|
||||
|
||||
try:
|
||||
found_id = int(found_id)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
obj = await method(found_id)
|
||||
if obj:
|
||||
res.append(obj)
|
||||
|
||||
return res
|
||||
|
||||
async def get_message(self, message_id: int) -> Dict:
|
||||
"""Get a single message's payload."""
|
||||
row = await self.db.fetchrow("""
|
||||
SELECT id::text, channel_id::text, author_id, content,
|
||||
created_at AS timestamp, edited_at AS edited_timestamp,
|
||||
tts, mention_everyone, nonce, message_type
|
||||
FROM messages
|
||||
WHERE id = $1
|
||||
""", message_id)
|
||||
|
||||
if not row:
|
||||
return
|
||||
|
||||
res = dict(row)
|
||||
res['timestamp'] = res['timestamp'].isoformat()
|
||||
res['type'] = res['message_type']
|
||||
res.pop('message_type')
|
||||
|
||||
# calculate user mentions and role mentions by regex
|
||||
res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user,
|
||||
row['content'])
|
||||
|
||||
# _dummy just returns the string of the id, since we don't
|
||||
# actually use the role objects in mention_roles, just their ids.
|
||||
res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy,
|
||||
row['content'])
|
||||
|
||||
# TODO: handle webhook authors
|
||||
res['author'] = await self.get_user(res['author_id'])
|
||||
res.pop('author_id')
|
||||
|
||||
# TODO: res['attachments']
|
||||
res['attachments'] = []
|
||||
|
||||
# TODO: res['embeds']
|
||||
res['embeds'] = []
|
||||
|
||||
# TODO: res['reactions']
|
||||
res['reactions'] = []
|
||||
|
||||
# TODO: res['pinned']
|
||||
res['pinned'] = False
|
||||
|
||||
return res
|
||||
|
|
|
|||
4
run.py
4
run.py
|
|
@ -63,8 +63,8 @@ async def app_before_serving():
|
|||
async def _wrapper(ws, url):
|
||||
# We wrap the main websocket_handler
|
||||
# so we can pass quart's app object.
|
||||
await websocket_handler((app.db, app.state_manager,
|
||||
app.storage, app.loop), ws, url)
|
||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
||||
app.loop, app.dispatcher), ws, url)
|
||||
|
||||
ws_future = websockets.serve(_wrapper, host, port)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue