mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'group-dms' into 'master'
Group DMs See merge request litecord/litecord!18
This commit is contained in:
commit
da7ef70458
|
|
@ -32,7 +32,9 @@ from .icons import bp as icons
|
||||||
from .nodeinfo import bp as nodeinfo
|
from .nodeinfo import bp as nodeinfo
|
||||||
from .static import bp as static
|
from .static import bp as static
|
||||||
from .attachments import bp as attachments
|
from .attachments import bp as attachments
|
||||||
|
from .dm_channels import bp as dm_channels
|
||||||
|
|
||||||
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
|
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
|
||||||
'webhooks', 'science', 'voice', 'invites', 'relationships',
|
'webhooks', 'science', 'voice', 'invites', 'relationships',
|
||||||
'dms', 'icons', 'nodeinfo', 'static', 'attachments']
|
'dms', 'icons', 'nodeinfo', 'static', 'attachments',
|
||||||
|
'dm_channels']
|
||||||
|
|
|
||||||
|
|
@ -23,13 +23,14 @@ from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.auth import token_check
|
from litecord.auth import token_check
|
||||||
from litecord.enums import ChannelType, GUILD_CHANS
|
from litecord.enums import ChannelType, GUILD_CHANS, MessageType
|
||||||
from litecord.errors import ChannelNotFound
|
from litecord.errors import ChannelNotFound
|
||||||
from litecord.schemas import (
|
from litecord.schemas import (
|
||||||
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL
|
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
|
from litecord.system_messages import send_sys_message
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channels', __name__)
|
bp = Blueprint('channels', __name__)
|
||||||
|
|
@ -405,31 +406,73 @@ async def _update_voice_channel(channel_id: int, j: dict):
|
||||||
await _common_guild_chan(channel_id, j)
|
await _common_guild_chan(channel_id, j)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
|
||||||
|
if 'name' in j:
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE group_dm_channels
|
||||||
|
SET name = $1
|
||||||
|
WHERE id = $2
|
||||||
|
""", j['name'], channel_id)
|
||||||
|
|
||||||
|
await send_sys_message(
|
||||||
|
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if 'icon' in j:
|
||||||
|
new_icon = await app.icons.update(
|
||||||
|
'channel-icons', channel_id, j['icon'], always_icon=True
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE group_dm_channels
|
||||||
|
SET icon = $1
|
||||||
|
WHERE id = $2
|
||||||
|
""", new_icon.icon_hash, channel_id)
|
||||||
|
|
||||||
|
await send_sys_message(
|
||||||
|
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
|
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
|
||||||
async def update_channel(channel_id):
|
async def update_channel(channel_id):
|
||||||
"""Update a channel's information"""
|
"""Update a channel's information"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
if ctype not in GUILD_CHANS:
|
if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
||||||
raise ChannelNotFound('Can not edit non-guild channels.')
|
ChannelType.GROUP_DM):
|
||||||
|
raise ChannelNotFound('unable to edit unsupported chan type')
|
||||||
|
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_channels')
|
is_guild = ctype in GUILD_CHANS
|
||||||
j = validate(await request.get_json(), CHAN_UPDATE)
|
|
||||||
|
|
||||||
# TODO: categories?
|
if is_guild:
|
||||||
|
await channel_perm_check(user_id, channel_id, 'manage_channels')
|
||||||
|
|
||||||
|
j = validate(await request.get_json(),
|
||||||
|
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
|
||||||
|
|
||||||
|
# TODO: categories
|
||||||
update_handler = {
|
update_handler = {
|
||||||
ChannelType.GUILD_TEXT: _update_text_channel,
|
ChannelType.GUILD_TEXT: _update_text_channel,
|
||||||
ChannelType.GUILD_VOICE: _update_voice_channel,
|
ChannelType.GUILD_VOICE: _update_voice_channel,
|
||||||
|
ChannelType.GROUP_DM: _update_group_dm,
|
||||||
}[ctype]
|
}[ctype]
|
||||||
|
|
||||||
await _update_channel_common(channel_id, guild_id, j)
|
if is_guild:
|
||||||
await update_handler(channel_id, j)
|
await _update_channel_common(channel_id, guild_id, j)
|
||||||
|
|
||||||
|
await update_handler(channel_id, j, user_id)
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
if is_guild:
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||||
|
else:
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'CHANNEL_UPDATE', chan)
|
||||||
|
|
||||||
await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan)
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from litecord.enums import ChannelType, GUILD_CHANS
|
from litecord.enums import ChannelType, GUILD_CHANS
|
||||||
|
|
@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int):
|
||||||
raise Forbidden('You are not the owner of the guild')
|
raise Forbidden('You are not the owner of the guild')
|
||||||
|
|
||||||
|
|
||||||
async def channel_check(user_id, channel_id):
|
async def channel_check(user_id, channel_id, *,
|
||||||
|
only: Union[ChannelType, List[ChannelType]] = None):
|
||||||
"""Check if the current user is authorized
|
"""Check if the current user is authorized
|
||||||
to read the channel's information."""
|
to read the channel's information."""
|
||||||
chan_type = await app.storage.get_chan_type(channel_id)
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
|
||||||
if chan_type is None:
|
if chan_type is None:
|
||||||
raise ChannelNotFound(f'channel type not found')
|
raise ChannelNotFound('channel type not found')
|
||||||
|
|
||||||
ctype = ChannelType(chan_type)
|
ctype = ChannelType(chan_type)
|
||||||
|
|
||||||
|
if only and not isinstance(only, list):
|
||||||
|
only = [only]
|
||||||
|
|
||||||
|
if only and ctype not in only:
|
||||||
|
raise ChannelNotFound('invalid channel type')
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
guild_id = await app.db.fetchval("""
|
guild_id = await app.db.fetchval("""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
|
|
@ -77,6 +86,15 @@ async def channel_check(user_id, channel_id):
|
||||||
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||||
return ctype, peer_id
|
return ctype, peer_id
|
||||||
|
|
||||||
|
if ctype == ChannelType.GROUP_DM:
|
||||||
|
owner_id = await app.db.fetchval("""
|
||||||
|
SELECT owner_id
|
||||||
|
FROM group_dm_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
return ctype, owner_id
|
||||||
|
|
||||||
|
|
||||||
async def guild_perm_check(user_id, guild_id, permission: str):
|
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||||
"""Check guild permissions for a user."""
|
"""Check guild permissions for a user."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
Litecord
|
||||||
|
Copyright (C) 2018-2019 Luna Mendes
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, version 3 of the License.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from quart import Blueprint, current_app as app, jsonify
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import channel_check
|
||||||
|
from litecord.enums import ChannelType, MessageType
|
||||||
|
from litecord.errors import BadRequest, Forbidden
|
||||||
|
from litecord.snowflake import get_snowflake
|
||||||
|
from litecord.system_messages import send_sys_message
|
||||||
|
from litecord.pubsub.channel import gdm_recipient_view
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
bp = Blueprint('dm_channels', __name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _raw_gdm_add(channel_id, user_id):
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO group_dm_members (id, member_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _raw_gdm_remove(channel_id, user_id):
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM group_dm_members
|
||||||
|
WHERE id = $1 AND member_id = $2
|
||||||
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _gdm_create(user_id, peer_id) -> int:
|
||||||
|
"""Create a group dm, given two users.
|
||||||
|
|
||||||
|
Returns the new GDM id.
|
||||||
|
"""
|
||||||
|
channel_id = get_snowflake()
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO channels (id, channel_type)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", channel_id, ChannelType.GROUP_DM.value)
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO group_dm_channels (id, owner_id, name, icon)
|
||||||
|
VALUES ($1, $2, NULL, NULL)
|
||||||
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
await _raw_gdm_add(channel_id, user_id)
|
||||||
|
await _raw_gdm_add(channel_id, peer_id)
|
||||||
|
|
||||||
|
await app.dispatcher.sub('channel', channel_id, user_id)
|
||||||
|
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||||
|
|
||||||
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
|
||||||
|
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||||
|
"""Add a recipient to a Group DM.
|
||||||
|
|
||||||
|
Dispatches:
|
||||||
|
- A system message with the join (depending of user_id)
|
||||||
|
- A CHANNEL_CREATE to the peer.
|
||||||
|
- A CHANNEL_UPDATE to all remaining recipients.
|
||||||
|
"""
|
||||||
|
await _raw_gdm_add(channel_id, peer_id)
|
||||||
|
|
||||||
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
# the reasoning behind gdm_recipient_view is in its docstring.
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'CHANNEL_UPDATE', chan)
|
||||||
|
|
||||||
|
await app.dispatcher.sub('channel', peer_id)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
await send_sys_message(
|
||||||
|
app, channel_id, MessageType.RECIPIENT_ADD,
|
||||||
|
user_id, peer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||||
|
"""Remove a member from a GDM.
|
||||||
|
|
||||||
|
Dispatches:
|
||||||
|
- A system message with the leave or forced removal (depending if user_id)
|
||||||
|
exists or not.
|
||||||
|
- A CHANNEL_DELETE to the peer.
|
||||||
|
- A CHANNEL_UPDATE to all remaining recipients.
|
||||||
|
"""
|
||||||
|
await _raw_gdm_remove(channel_id, peer_id)
|
||||||
|
|
||||||
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
|
||||||
|
|
||||||
|
await app.dispatcher.unsub('channel', peer_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
|
||||||
|
'channel_id': str(channel_id),
|
||||||
|
'user': await app.storage.get_user(peer_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
await send_sys_message(
|
||||||
|
app, channel_id, MessageType.RECIPIENT_REMOVE,
|
||||||
|
user_id, peer_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT'])
|
||||||
|
async def add_to_group_dm(dm_chan, peer_id):
|
||||||
|
"""Adds a member to a group dm OR creates a group dm."""
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# other_id is the owner of the group dm (gdm) if the
|
||||||
|
# given channel is a gdm
|
||||||
|
|
||||||
|
# other_id is the peer of the dm if the given channel is a dm
|
||||||
|
ctype, other_id = await channel_check(
|
||||||
|
user_id, dm_chan,
|
||||||
|
only=[ChannelType.DM, ChannelType.GROUP_DM]
|
||||||
|
)
|
||||||
|
|
||||||
|
# check relationship with the given user id
|
||||||
|
# and the user id making the request
|
||||||
|
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
||||||
|
|
||||||
|
if not friends:
|
||||||
|
raise BadRequest('Cant insert peer into dm')
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
dm_chan = await _gdm_create(
|
||||||
|
user_id, other_id
|
||||||
|
)
|
||||||
|
|
||||||
|
await _gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
await app.storage.get_channel(dm_chan)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE'])
|
||||||
|
async def remove_from_group_dm(dm_chan, peer_id):
|
||||||
|
"""Remove users from group dm."""
|
||||||
|
user_id = await token_check()
|
||||||
|
_ctype, owner_id = await channel_check(
|
||||||
|
user_id, dm_chan, only=ChannelType.GROUP_DM
|
||||||
|
)
|
||||||
|
|
||||||
|
if owner_id != user_id:
|
||||||
|
raise Forbidden('You are now the owner of the group DM')
|
||||||
|
|
||||||
|
await _gdm_remove_recipient(dm_chan, peer_id)
|
||||||
|
return '', 204
|
||||||
|
|
@ -76,3 +76,9 @@ async def _get_user_avatar(user_id, avatar_file):
|
||||||
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
|
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
|
||||||
async def get_app_icon(application_id, icon_hash, ext):
|
async def get_app_icon(application_id, icon_hash, ext):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET'])
|
||||||
|
async def _get_gdm_icon(guild_id: int, icon_file: str):
|
||||||
|
icon_hash, ext = splitext_(icon_file)
|
||||||
|
return await send_icon('channel-icons', guild_id, icon_hash, ext=ext)
|
||||||
|
|
|
||||||
|
|
@ -384,9 +384,10 @@ async def get_profile(peer_id: int):
|
||||||
return '', 404
|
return '', 404
|
||||||
|
|
||||||
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
||||||
|
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
||||||
|
|
||||||
# don't return a proper card if no guilds are being shared.
|
# don't return a proper card if no guilds are being shared.
|
||||||
if not mutuals:
|
if not mutuals and not friends:
|
||||||
return '', 404
|
return '', 404
|
||||||
|
|
||||||
# actual premium status is determined by that
|
# actual premium status is determined by that
|
||||||
|
|
|
||||||
|
|
@ -103,6 +103,15 @@ class EventDispatcher:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.subscribe(backend_str, key, identifier)
|
await self.subscribe(backend_str, key, identifier)
|
||||||
|
|
||||||
|
async def mass_sub(self, identifier: Any,
|
||||||
|
backends: List[tuple]):
|
||||||
|
"""Mass subscribe to many backends at once."""
|
||||||
|
for backend_str, keys in backends:
|
||||||
|
log.debug('subscribing {} to {} keys in backend {}',
|
||||||
|
identifier, len(keys), backend_str)
|
||||||
|
|
||||||
|
await self.sub_many(backend_str, identifier, keys)
|
||||||
|
|
||||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
||||||
"""Dispatch an event to the backend.
|
"""Dispatch an event to the backend.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,14 +77,14 @@ class LitecordError(Exception):
|
||||||
@property
|
@property
|
||||||
def message(self) -> str:
|
def message(self) -> str:
|
||||||
"""Get an error's message string."""
|
"""Get an error's message string."""
|
||||||
err_code = getattr(self, 'error_code', None)
|
|
||||||
|
|
||||||
if err_code is not None:
|
|
||||||
return ERR_MSG_MAP.get(err_code) or self.args[0]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return self.args[0]
|
return self.args[0]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
|
err_code = getattr(self, 'error_code', None)
|
||||||
|
|
||||||
|
if err_code is not None:
|
||||||
|
return ERR_MSG_MAP.get(err_code) or self.args[0]
|
||||||
|
|
||||||
return repr(self)
|
return repr(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
|
|
@ -372,11 +372,16 @@ class GatewayWebsocket:
|
||||||
# user, fetch info
|
# user, fetch info
|
||||||
uready = await self._user_ready()
|
uready = await self._user_ready()
|
||||||
|
|
||||||
|
private_channels = (
|
||||||
|
await self.user_storage.get_dms(user_id) +
|
||||||
|
await self.user_storage.get_gdms(user_id)
|
||||||
|
)
|
||||||
|
|
||||||
await self.dispatch('READY', {**{
|
await self.dispatch('READY', {**{
|
||||||
'v': 6,
|
'v': 6,
|
||||||
'user': user,
|
'user': user,
|
||||||
|
|
||||||
'private_channels': await self.user_storage.get_dms(user_id),
|
'private_channels': private_channels,
|
||||||
|
|
||||||
'guilds': guilds,
|
'guilds': guilds,
|
||||||
'session_id': self.state.session_id,
|
'session_id': self.state.session_id,
|
||||||
|
|
@ -437,17 +442,24 @@ class GatewayWebsocket:
|
||||||
by GuildDispatcher.sub
|
by GuildDispatcher.sub
|
||||||
"""
|
"""
|
||||||
user_id = self.state.user_id
|
user_id = self.state.user_id
|
||||||
|
|
||||||
guild_ids = await self._guild_ids()
|
guild_ids = await self._guild_ids()
|
||||||
log.info('subscribing to {} guilds', len(guild_ids))
|
|
||||||
await self.ext.dispatcher.sub_many('guild', user_id, guild_ids)
|
|
||||||
|
|
||||||
# subscribe the user to all dms they have OPENED.
|
# subscribe the user to all dms they have OPENED.
|
||||||
dms = await self.user_storage.get_dms(user_id)
|
dms = await self.user_storage.get_dms(user_id)
|
||||||
dm_ids = [int(dm['id']) for dm in dms]
|
dm_ids = [int(dm['id']) for dm in dms]
|
||||||
|
|
||||||
|
# fetch all group dms the user is a member of.
|
||||||
|
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
||||||
|
|
||||||
|
log.info('subscribing to {} guilds', len(guild_ids))
|
||||||
log.info('subscribing to {} dms', len(dm_ids))
|
log.info('subscribing to {} dms', len(dm_ids))
|
||||||
await self.ext.dispatcher.sub_many('channel', user_id, dm_ids)
|
log.info('subscribing to {} group dms', len(gdm_ids))
|
||||||
|
|
||||||
|
await self.ext.dispatcher.mass_sub(user_id, [
|
||||||
|
('guild', guild_ids),
|
||||||
|
('channel', dm_ids),
|
||||||
|
('channel', gdm_ids)
|
||||||
|
])
|
||||||
|
|
||||||
if not self.state.bot:
|
if not self.state.bot:
|
||||||
# subscribe to all friends
|
# subscribe to all friends
|
||||||
|
|
|
||||||
|
|
@ -161,21 +161,18 @@ def parse_data_uri(string) -> tuple:
|
||||||
def _gen_update_sql(scope: str) -> str:
|
def _gen_update_sql(scope: str) -> str:
|
||||||
field = {
|
field = {
|
||||||
'user': 'avatar',
|
'user': 'avatar',
|
||||||
'guild': 'icon'
|
'guild': 'icon',
|
||||||
|
'channel-icons': 'icon',
|
||||||
}[scope]
|
}[scope]
|
||||||
|
|
||||||
table = {
|
table = {
|
||||||
'user': 'users',
|
'user': 'users',
|
||||||
'guild': 'guilds'
|
'guild': 'guilds',
|
||||||
}[scope]
|
'channel-icons': 'group_dm_channels'
|
||||||
|
|
||||||
col = {
|
|
||||||
'user': 'id',
|
|
||||||
'guild': 'id'
|
|
||||||
}[scope]
|
}[scope]
|
||||||
|
|
||||||
return f"""
|
return f"""
|
||||||
SELECT {field} FROM {table} WHERE {col} = $1
|
SELECT {field} FROM {table} WHERE id = $1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -277,6 +274,9 @@ class IconManager:
|
||||||
|
|
||||||
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
|
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
|
||||||
"""Get any icon."""
|
"""Get any icon."""
|
||||||
|
if icon_hash is None:
|
||||||
|
return None
|
||||||
|
|
||||||
log.debug('GET {} {} {}', scope, key, icon_hash)
|
log.debug('GET {} {} {}', scope, key, icon_hash)
|
||||||
key = str(key)
|
key = str(key)
|
||||||
|
|
||||||
|
|
@ -409,6 +409,12 @@ class IconManager:
|
||||||
WHERE icon = $1
|
WHERE icon = $1
|
||||||
""", icon.icon_hash)
|
""", icon.icon_hash)
|
||||||
|
|
||||||
|
await self.storage.db.execute("""
|
||||||
|
UPDATE group_dm_channels
|
||||||
|
SET icon = NULL
|
||||||
|
WHERE icon = $1
|
||||||
|
""", icon.icon_hash)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute("""
|
||||||
DELETE FROM icons
|
DELETE FROM icons
|
||||||
WHERE hash = $1
|
WHERE hash = $1
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,32 @@ from typing import Any
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .dispatcher import DispatcherWithState
|
from .dispatcher import DispatcherWithState
|
||||||
|
from litecord.enums import ChannelType
|
||||||
|
from litecord.utils import index_by_func
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
||||||
|
"""Create a copy of the original channel object that doesn't
|
||||||
|
show the user we are dispatching it to.
|
||||||
|
|
||||||
|
this only applies to group dms and discords' api design that says
|
||||||
|
a group dms' recipients must not show the original user.
|
||||||
|
"""
|
||||||
|
# make a copy or the original channel object
|
||||||
|
data = dict(orig)
|
||||||
|
|
||||||
|
idx = index_by_func(
|
||||||
|
lambda user: user['id'] == str(user_id),
|
||||||
|
data['recipients']
|
||||||
|
)
|
||||||
|
|
||||||
|
data['recipients'].pop(idx)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ChannelDispatcher(DispatcherWithState):
|
class ChannelDispatcher(DispatcherWithState):
|
||||||
"""Main channel Pub/Sub logic."""
|
"""Main channel Pub/Sub logic."""
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
|
|
@ -62,7 +84,19 @@ class ChannelDispatcher(DispatcherWithState):
|
||||||
await self.unsub(channel_id, user_id)
|
await self.unsub(channel_id, user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cur_sess = await self._dispatch_states(states, event, data)
|
cur_sess = 0
|
||||||
|
|
||||||
|
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
||||||
|
and data.get('type') == ChannelType.GROUP_DM.value:
|
||||||
|
# we edit the channel payload so it doesn't show
|
||||||
|
# the user as a recipient
|
||||||
|
|
||||||
|
new_data = gdm_recipient_view(data, user_id)
|
||||||
|
cur_sess = await self._dispatch_states(
|
||||||
|
states, event, new_data)
|
||||||
|
else:
|
||||||
|
cur_sess = await self._dispatch_states(
|
||||||
|
states, event, data)
|
||||||
|
|
||||||
sessions.extend(cur_sess)
|
sessions.extend(cur_sess)
|
||||||
dispatched += len(cur_sess)
|
dispatched += len(cur_sess)
|
||||||
|
|
|
||||||
|
|
@ -581,6 +581,14 @@ CREATE_GROUP_DM = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GROUP_DM_UPDATE = {
|
||||||
|
'name': {
|
||||||
|
'type': 'guild_name',
|
||||||
|
'required': False
|
||||||
|
},
|
||||||
|
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
||||||
|
}
|
||||||
|
|
||||||
SPECIFIC_FRIEND = {
|
SPECIFIC_FRIEND = {
|
||||||
'username': {'type': 'username'},
|
'username': {'type': 'username'},
|
||||||
'discriminator': {'type': 'discriminator'}
|
'discriminator': {'type': 'discriminator'}
|
||||||
|
|
|
||||||
|
|
@ -353,7 +353,38 @@ class Storage:
|
||||||
|
|
||||||
return list(map(_overwrite_convert, overwrite_rows))
|
return list(map(_overwrite_convert, overwrite_rows))
|
||||||
|
|
||||||
async def get_channel(self, channel_id: int) -> Dict[str, Any]:
|
async def gdm_recipient_ids(self, channel_id: int) -> List[int]:
|
||||||
|
"""Get the list of user IDs that are recipients of the
|
||||||
|
given Group DM."""
|
||||||
|
user_ids = await self.db.fetch("""
|
||||||
|
SELECT member_id
|
||||||
|
FROM group_dm_members
|
||||||
|
JOIN users
|
||||||
|
ON member_id = users.id
|
||||||
|
WHERE group_dm_members.id = $1
|
||||||
|
ORDER BY username DESC
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
return [r['member_id'] for r in user_ids]
|
||||||
|
|
||||||
|
async def _gdm_recipients(self, channel_id: int,
|
||||||
|
reference_id: int = None) -> List[int]:
|
||||||
|
"""Get the list of users that are recipients of the
|
||||||
|
given Group DM."""
|
||||||
|
recipients = await self.gdm_recipient_ids(channel_id)
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for user_id in recipients:
|
||||||
|
if user_id == reference_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
res.append(
|
||||||
|
await self.get_user(user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]:
|
||||||
"""Fetch a single channel's information."""
|
"""Fetch a single channel's information."""
|
||||||
chan_type = await self.get_chan_type(channel_id)
|
chan_type = await self.get_chan_type(channel_id)
|
||||||
ctype = ChannelType(chan_type)
|
ctype = ChannelType(chan_type)
|
||||||
|
|
@ -387,7 +418,8 @@ class Storage:
|
||||||
drow['type'] = chan_type
|
drow['type'] = chan_type
|
||||||
|
|
||||||
drow['last_message_id'] = await self.chan_last_message_str(
|
drow['last_message_id'] = await self.chan_last_message_str(
|
||||||
channel_id)
|
channel_id
|
||||||
|
)
|
||||||
|
|
||||||
# dms have just two recipients.
|
# dms have just two recipients.
|
||||||
drow['recipients'] = [
|
drow['recipients'] = [
|
||||||
|
|
@ -401,8 +433,22 @@ class Storage:
|
||||||
drow['id'] = str(drow['id'])
|
drow['id'] = str(drow['id'])
|
||||||
return drow
|
return drow
|
||||||
elif ctype == ChannelType.GROUP_DM:
|
elif ctype == ChannelType.GROUP_DM:
|
||||||
# TODO: group dms
|
gdm_row = await self.db.fetchrow("""
|
||||||
pass
|
SELECT id::text, owner_id::text, name, icon
|
||||||
|
FROM group_dm_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
drow = dict(gdm_row)
|
||||||
|
drow['type'] = chan_type
|
||||||
|
drow['recipients'] = await self._gdm_recipients(
|
||||||
|
channel_id, kwargs.get('user_id')
|
||||||
|
)
|
||||||
|
drow['last_message_id'] = await self.chan_last_message_str(
|
||||||
|
channel_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return drow
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,11 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
from litecord.enums import MessageType
|
from litecord.enums import MessageType
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
async def _handle_pin_msg(app, channel_id, pinned_id, author_id):
|
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
||||||
"""Handle a message pin."""
|
"""Handle a message pin."""
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
|
@ -41,12 +44,108 @@ async def _handle_pin_msg(app, channel_id, pinned_id, author_id):
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: decrease repetition between add and remove handlers
|
||||||
|
async def _handle_recp_add(app, channel_id, author_id, peer_id):
|
||||||
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages
|
||||||
|
(id, channel_id, author_id, webhook_id,
|
||||||
|
content, message_type)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, NULL, $4, $5)
|
||||||
|
""",
|
||||||
|
new_id, channel_id, author_id,
|
||||||
|
f'<@{peer_id}>',
|
||||||
|
MessageType.RECIPIENT_ADD.value
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
||||||
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages
|
||||||
|
(id, channel_id, author_id, webhook_id,
|
||||||
|
content, message_type)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, NULL, $4, $5)
|
||||||
|
""",
|
||||||
|
new_id, channel_id, author_id,
|
||||||
|
f'<@{peer_id}>',
|
||||||
|
MessageType.RECIPIENT_REMOVE.value
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_gdm_name_edit(app, channel_id, author_id):
|
||||||
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
gdm_name = await app.db.fetchval("""
|
||||||
|
SELECT name FROM group_dm_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
if not gdm_name:
|
||||||
|
log.warning('no gdm name found for sys message')
|
||||||
|
return
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages
|
||||||
|
(id, channel_id, author_id, webhook_id,
|
||||||
|
content, message_type)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, NULL, $4, $5)
|
||||||
|
""",
|
||||||
|
new_id, channel_id, author_id,
|
||||||
|
gdm_name,
|
||||||
|
MessageType.CHANNEL_NAME_CHANGE.value
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_gdm_icon_edit(app, channel_id, author_id):
|
||||||
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages
|
||||||
|
(id, channel_id, author_id, webhook_id,
|
||||||
|
content, message_type)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, NULL, $4, $5)
|
||||||
|
""",
|
||||||
|
new_id, channel_id, author_id,
|
||||||
|
'',
|
||||||
|
MessageType.CHANNEL_ICON_CHANGE.value
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
async def send_sys_message(app, channel_id: int, m_type: MessageType,
|
async def send_sys_message(app, channel_id: int, m_type: MessageType,
|
||||||
*args, **kwargs) -> int:
|
*args, **kwargs) -> int:
|
||||||
"""Send a system message."""
|
"""Send a system message."""
|
||||||
handler = {
|
try:
|
||||||
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
|
handler = {
|
||||||
}[m_type]
|
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
|
||||||
|
|
||||||
|
# gdm specific
|
||||||
|
MessageType.RECIPIENT_ADD: _handle_recp_add,
|
||||||
|
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
|
||||||
|
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
|
||||||
|
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit
|
||||||
|
}[m_type]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError('Invalid system message type')
|
||||||
|
|
||||||
message_id = await handler(app, channel_id, *args, **kwargs)
|
message_id = await handler(app, channel_id, *args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -338,3 +338,26 @@ class UserStorage:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
""", user_id, peer_id)
|
""", user_id, peer_id)
|
||||||
|
|
||||||
|
async def get_gdms_internal(self, user_id) -> List[int]:
|
||||||
|
"""Return a list of Group DM IDs the user is a member of."""
|
||||||
|
rows = await self.db.fetch("""
|
||||||
|
SELECT id
|
||||||
|
FROM group_dm_members
|
||||||
|
WHERE member_id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
|
return [r['id'] for r in rows]
|
||||||
|
|
||||||
|
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
|
||||||
|
"""Get list of group DMs a user is in."""
|
||||||
|
gdm_ids = await self.get_gdms_internal(user_id)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for gdm_id in gdm_ids:
|
||||||
|
res.append(
|
||||||
|
await self.storage.get_channel(gdm_id, user_id=user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
||||||
3
run.py
3
run.py
|
|
@ -35,7 +35,7 @@ import config
|
||||||
from litecord.blueprints import (
|
from litecord.blueprints import (
|
||||||
gateway, auth, users, guilds, channels, webhooks, science,
|
gateway, auth, users, guilds, channels, webhooks, science,
|
||||||
voice, invites, relationships, dms, icons, nodeinfo, static,
|
voice, invites, relationships, dms, icons, nodeinfo, static,
|
||||||
attachments
|
attachments, dm_channels
|
||||||
)
|
)
|
||||||
|
|
||||||
# those blueprints are separated from the "main" ones
|
# those blueprints are separated from the "main" ones
|
||||||
|
|
@ -128,6 +128,7 @@ def set_blueprints(app_):
|
||||||
voice: '/voice',
|
voice: '/voice',
|
||||||
invites: None,
|
invites: None,
|
||||||
dms: '/users',
|
dms: '/users',
|
||||||
|
dm_channels: '/channels',
|
||||||
|
|
||||||
fake_store: None,
|
fake_store: None,
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue