mirror of https://gitlab.com/litecord/litecord.git
blueprints.users, channels: basic dm operations
SQL for instances:
```sql
ALTER TABLE messages
ADD CONSTRAINT messages_channels_fkey
FOREIGN KEY (channel_id)
REFERENCES channels (id)
ON DELETE CASCADE;
ALTER TABLE channel_pins ADD CONSTRAINT pins_channels_fkey
FOREIGN KEY (channel_id)
REFERENCES channels (id)
ON DELETE CASCADE;
ALTER TABLE channel_pins ADD CONSTRAINT pins_messages_fkey
FOREIGN KEY (message_id)
REFERENCES messages (id)
ON DELETE CASCADE;
```
After that, rerun `schema.sql`.
blueprints.channels:
- check dms on channel_check
- add DELETE /api/v6/channels/<int:channel_id>
blueprints.users:
- add event dispatching for leaving guilds
- add GET /api/v6/users/@me/channels, for DM fetching
- add POST /api/v6/users/@me/channels, for DM creation
- add POST /api/v6/users/<int:user_id>/channels for DM / Group DM
creation
- schemas: add CREATE, CREATE_GROUP_DM
- storage: add last_message_id fetching for channels
- storage: add support for DMs in get_channel
- storage: add Storage.get_dm, Storage.get_dms, Storage.get_all_dms
- schema.sql: add dm_channel_state table
- schema.sql: add constriants for messages.channel_id and channel_pins
This commit is contained in:
parent
88bc4ceca8
commit
61e36f244b
|
|
@ -5,7 +5,7 @@ from logbook import Logger
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
from ..snowflake import get_snowflake, snowflake_datetime
|
from ..snowflake import get_snowflake, snowflake_datetime
|
||||||
from ..enums import ChannelType, MessageType
|
from ..enums import ChannelType, MessageType, GUILD_CHANS
|
||||||
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
|
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
|
||||||
from ..schemas import validate, MESSAGE_CREATE
|
from ..schemas import validate, MESSAGE_CREATE
|
||||||
|
|
||||||
|
|
@ -18,13 +18,14 @@ bp = Blueprint('channels', __name__)
|
||||||
async def channel_check(user_id, channel_id):
|
async def channel_check(user_id, channel_id):
|
||||||
"""Check if the current user is authorized
|
"""Check if the current user is authorized
|
||||||
to read the channel's information."""
|
to read the channel's information."""
|
||||||
ctype = await app.storage.get_chan_type(channel_id)
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
|
||||||
if ctype is None:
|
if chan_type is None:
|
||||||
raise ChannelNotFound(f'channel type not found')
|
raise ChannelNotFound(f'channel type not found')
|
||||||
|
|
||||||
if ChannelType(ctype) in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
ctype = ChannelType(chan_type)
|
||||||
ChannelType.GUILD_CATEGORY):
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
guild_id = await app.db.fetchval("""
|
guild_id = await app.db.fetchval("""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
|
|
@ -34,10 +35,25 @@ async def channel_check(user_id, channel_id):
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
return guild_id
|
return guild_id
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
parties = await app.db.fetchval("""
|
||||||
|
SELECT party1_id, party2_id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||||
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
# get the id of the other party
|
||||||
|
parties.remove(user_id)
|
||||||
|
return parties[0]
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['GET'])
|
@bp.route('/<int:channel_id>', methods=['GET'])
|
||||||
async def get_channel(channel_id):
|
async def get_channel(channel_id):
|
||||||
|
"""Get a single channel's information"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# channel_check takes care of checking
|
||||||
|
# DMs and group DMs
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
|
@ -47,6 +63,129 @@ async def get_channel(channel_id):
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
||||||
|
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
||||||
|
"""Update a guild's channel id field to NULL,
|
||||||
|
if it was set to the given channel id before."""
|
||||||
|
return await app.db.execute(f"""
|
||||||
|
UPDATE guilds
|
||||||
|
SET {field} = NULL
|
||||||
|
WHERE guilds.id = $1 AND {field} = $2
|
||||||
|
""", guild_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||||
|
res_embed = await __guild_chan_sql(
|
||||||
|
guild_id, channel_id, 'embed_channel_id')
|
||||||
|
|
||||||
|
res_widget = await __guild_chan_sql(
|
||||||
|
guild_id, channel_id, 'widget_channel_id')
|
||||||
|
|
||||||
|
res_system = await __guild_chan_sql(
|
||||||
|
guild_id, channel_id, 'system_channel_id')
|
||||||
|
|
||||||
|
# if none of them were actually updated,
|
||||||
|
# ignore and dont dispatch anything
|
||||||
|
if 'UPDATE 1' not in (res_embed, res_widget, res_system):
|
||||||
|
return
|
||||||
|
|
||||||
|
# at least one of the fields were updated,
|
||||||
|
# dispatch GUILD_UPDATE
|
||||||
|
guild = await app.storage.get_guild(guild_id)
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'GUILD_UPDATE', guild)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||||
|
res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id')
|
||||||
|
|
||||||
|
# guild didnt update
|
||||||
|
if res == 'UPDATE 0':
|
||||||
|
return
|
||||||
|
|
||||||
|
guild = await app.storage.get_guild(guild_id)
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'GUILD_UPDATE', guild)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||||
|
# get all channels that were childs of the category
|
||||||
|
childs = await app.db.fetch("""
|
||||||
|
SELECT id
|
||||||
|
FROM guild_channels
|
||||||
|
WHERE guild_id = $1 AND parent_id = $2
|
||||||
|
""", guild_id, channel_id)
|
||||||
|
childs = [c['id'] for c in childs]
|
||||||
|
|
||||||
|
# update every child channel to parent_id = NULL
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE guild_channels
|
||||||
|
SET parent_id = NULL
|
||||||
|
WHERE guild_id = $1 AND parent_id = $2
|
||||||
|
""", guild_id, channel_id)
|
||||||
|
|
||||||
|
# tell all people in the guild of the category removal
|
||||||
|
for child_id in childs:
|
||||||
|
child = await app.storage.get_channel(child_id)
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'CHANNEL_UPDATE', child
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>', methods=['DELETE'])
|
||||||
|
async def close_channel(channel_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
|
||||||
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
ctype = ChannelType(chan_type)
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
guild_id = await channel_check(user_id, channel_id)
|
||||||
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
# the selected function will take care of checking
|
||||||
|
# the sanity of tables once the channel becomes deleted.
|
||||||
|
_update_func = {
|
||||||
|
ChannelType.GUILD_TEXT: _update_guild_chan_text,
|
||||||
|
ChannelType.GUILD_VOICE: _update_guild_chan_voice,
|
||||||
|
ChannelType.GUILD_CATEGORY: _update_guild_chan_cat,
|
||||||
|
}[ctype]
|
||||||
|
|
||||||
|
await _update_func(guild_id, channel_id)
|
||||||
|
|
||||||
|
# this should take care of deleting all messages as well
|
||||||
|
# (if any)
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM guild_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'CHANNEL_DELETE', chan)
|
||||||
|
return jsonify(chan)
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
# we don't ever actually delete DM channels off the database.
|
||||||
|
# instead, we close the channel for the user that is making
|
||||||
|
# the request via removing the link between them and
|
||||||
|
# the channel on dm_channel_state
|
||||||
|
await app.db.execute("""
|
||||||
|
DELETE FROM dm_channel_state (user_id, dm_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", user_id, channel_id)
|
||||||
|
|
||||||
|
# nothing happens to the other party of the dm channel
|
||||||
|
await app.dispacher.dispatch_user(user_id, 'CHANNEL_DELETE', chan)
|
||||||
|
return jsonify(chan)
|
||||||
|
|
||||||
|
if ctype == ChannelType.GROUP_DM:
|
||||||
|
# TODO: group dm
|
||||||
|
pass
|
||||||
|
|
||||||
|
return '', 404
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
||||||
async def get_messages(channel_id):
|
async def get_messages(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -82,7 +221,6 @@ async def get_single_message(channel_id, message_id):
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
# TODO: check READ_MESSAGE_HISTORY permissions
|
# TODO: check READ_MESSAGE_HISTORY permissions
|
||||||
|
|
||||||
message = await app.storage.get_message(message_id)
|
message = await app.storage.get_message(message_id)
|
||||||
|
|
||||||
if not message:
|
if not message:
|
||||||
|
|
@ -120,6 +258,8 @@ async def create_message(channel_id):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: dispatch_channel
|
# TODO: dispatch_channel
|
||||||
|
# we really need dispatch_channel to make dm messages work,
|
||||||
|
# since they aren't part of any existing guild.
|
||||||
payload = await app.storage.get_message(message_id)
|
payload = await app.storage.get_message(message_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
|
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
|
||||||
|
|
||||||
|
|
@ -266,6 +406,7 @@ async def delete_pin(channel_id, message_id):
|
||||||
|
|
||||||
timestamp = snowflake_datetime(row['message_id'])
|
timestamp = snowflake_datetime(row['message_id'])
|
||||||
|
|
||||||
|
# TODO: dispatch_channel
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
|
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
|
||||||
'channel_id': str(channel_id),
|
'channel_id': str(channel_id),
|
||||||
'last_pin_timestamp': timestamp.isoformat()
|
'last_pin_timestamp': timestamp.isoformat()
|
||||||
|
|
@ -279,6 +420,7 @@ async def trigger_typing(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
guild_id = await channel_check(user_id, channel_id)
|
guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
# TODO: dispatch_channel
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', {
|
await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', {
|
||||||
'channel_id': str(channel_id),
|
'channel_id': str(channel_id),
|
||||||
'user_id': str(user_id),
|
'user_id': str(user_id),
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,11 @@ from quart import Blueprint, jsonify, request, current_app as app
|
||||||
from asyncpg import UniqueViolationError
|
from asyncpg import UniqueViolationError
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
|
from ..snowflake import get_snowflake
|
||||||
from ..errors import Forbidden, BadRequest
|
from ..errors import Forbidden, BadRequest
|
||||||
from ..schemas import validate, USER_SETTINGS
|
from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
|
||||||
|
|
||||||
|
from .guilds import guild_check
|
||||||
|
|
||||||
bp = Blueprint('user', __name__)
|
bp = Blueprint('user', __name__)
|
||||||
|
|
||||||
|
|
@ -84,15 +87,31 @@ async def get_me_guilds():
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
|
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
|
||||||
async def leave_guild(guild_id):
|
async def leave_guild(guild_id: int):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
DELETE FROM members
|
DELETE FROM members
|
||||||
WHERE user_id = $1 AND guild_id = $2
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
""", user_id, guild_id)
|
""", user_id, guild_id)
|
||||||
|
|
||||||
# TODO: something to dispatch events to the users
|
# first dispatch guild delete to the user,
|
||||||
|
# then remove from the guild,
|
||||||
|
# then tell the others that the member was removed
|
||||||
|
await app.dispatcher.dispatch_user_guild(
|
||||||
|
user_id, guild_id, 'GUILD_DELETE', {
|
||||||
|
'id': str(guild_id),
|
||||||
|
'unavailable': False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.dispatcher.unsub_guild(guild_id, user_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild('GUILD_MEMBER_REMOVE', {
|
||||||
|
'guild_id': str(guild_id),
|
||||||
|
'user': await app.storage.get_user(user_id)
|
||||||
|
})
|
||||||
|
|
||||||
return '', 204
|
return '', 204
|
||||||
|
|
||||||
|
|
@ -102,14 +121,75 @@ async def get_connections():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# @bp.route('/@me/channels', methods=['GET'])
|
@bp.route('/@me/channels', methods=['GET'])
|
||||||
async def get_dms():
|
async def get_dms():
|
||||||
pass
|
user_id = await token_check()
|
||||||
|
dms = await app.storage.get_dms(user_id)
|
||||||
|
return jsonify(dms)
|
||||||
|
|
||||||
|
|
||||||
# @bp.route('/@me/channels', methods=['POST'])
|
async def try_dm_state(user_id, dm_id):
|
||||||
|
"""Try insertin the user into the dm state
|
||||||
|
for the given DM."""
|
||||||
|
try:
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO dm_channel_state (id, dm_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", user_id, dm_id)
|
||||||
|
except UniqueViolationError:
|
||||||
|
# if already in state, ignore
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def create_dm(user_id, recipient_id):
|
||||||
|
dm_id = get_snowflake()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""", dm_id, user_id, recipient_id)
|
||||||
|
|
||||||
|
await try_dm_state(user_id, dm_id)
|
||||||
|
|
||||||
|
except UniqueViolationError:
|
||||||
|
# the dm already exists
|
||||||
|
dm_id = await app.db.fetchval("""
|
||||||
|
SELECT id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE (party1_id = $1 OR party2_id = $1) AND
|
||||||
|
(party2_id = $2 OR party2_id = $2)
|
||||||
|
""", user_id, recipient_id)
|
||||||
|
|
||||||
|
dm = await app.storage.get_dm(dm_id, user_id)
|
||||||
|
return jsonify(dm)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/@me/channels', methods=['POST'])
|
||||||
async def start_dm():
|
async def start_dm():
|
||||||
pass
|
"""Create a DM with a user."""
|
||||||
|
user_id = await token_check()
|
||||||
|
j = validate(await request.get_json(), CREATE_DM)
|
||||||
|
recipient_id = j['recipient_id']
|
||||||
|
|
||||||
|
return await create_dm(user_id, recipient_id)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:user_id>/channels', methods=['POST'])
|
||||||
|
async def create_group_dm(p_user_id: int):
|
||||||
|
"""Create a DM or a Group DM with user(s)."""
|
||||||
|
user_id = await token_check()
|
||||||
|
assert user_id == p_user_id
|
||||||
|
|
||||||
|
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
||||||
|
recipients = j['recipients']
|
||||||
|
|
||||||
|
if list(recipients) == 1:
|
||||||
|
# its a group dm with 1 user... a dm!
|
||||||
|
return await create_dm(user_id, int(recipients[0]))
|
||||||
|
|
||||||
|
# TODO: group dms
|
||||||
|
return '', 500
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,11 @@ class ChannelType(EasyEnum):
|
||||||
GUILD_CATEGORY = 4
|
GUILD_CATEGORY = 4
|
||||||
|
|
||||||
|
|
||||||
|
GUILD_CHANS = (ChannelType.GUILD_TEXT,
|
||||||
|
ChannelType.GUILD_VOICE,
|
||||||
|
ChannelType.GUILD_CATEGORY)
|
||||||
|
|
||||||
|
|
||||||
class ActivityType(EasyEnum):
|
class ActivityType(EasyEnum):
|
||||||
PLAYING = 0
|
PLAYING = 0
|
||||||
STREAMING = 1
|
STREAMING = 1
|
||||||
|
|
|
||||||
|
|
@ -279,3 +279,16 @@ RELATIONSHIP = {
|
||||||
'default': RelationshipType.FRIEND.value
|
'default': RelationshipType.FRIEND.value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CREATE_DM = {
|
||||||
|
'recipient_id': {
|
||||||
|
'type': 'snowflake',
|
||||||
|
'required': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CREATE_GROUP_DM = {
|
||||||
|
'type': 'list',
|
||||||
|
'required': True,
|
||||||
|
'schema': {'type': 'snowflake'}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,16 @@ async def _set_json(con):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
|
||||||
|
"""Filter recipients in a list of recipients, removing
|
||||||
|
the one that is reundant (ourselves)."""
|
||||||
|
user_id = str(user_id)
|
||||||
|
|
||||||
|
return filter(
|
||||||
|
lambda recipient: recipient['id'] != user_id,
|
||||||
|
recipients)
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
"""Class for common SQL statements."""
|
"""Class for common SQL statements."""
|
||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
|
|
@ -215,12 +225,19 @@ class Storage:
|
||||||
members = await self.get_member_multi(guild_id, mids)
|
members = await self.get_member_multi(guild_id, mids)
|
||||||
return members
|
return members
|
||||||
|
|
||||||
|
async def _chan_last_message(self, channel_id: int):
|
||||||
|
return await self.db.fetch("""
|
||||||
|
SELECT MAX(id)
|
||||||
|
FROM messages
|
||||||
|
WHERE channel_id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
async def _channels_extra(self, row) -> Dict:
|
async def _channels_extra(self, row) -> Dict:
|
||||||
"""Fill in more information about a channel."""
|
"""Fill in more information about a channel."""
|
||||||
channel_type = row['type']
|
channel_type = row['type']
|
||||||
|
|
||||||
# TODO: dm and group dm?
|
|
||||||
chan_type = ChannelType(channel_type)
|
chan_type = ChannelType(channel_type)
|
||||||
|
|
||||||
if chan_type == ChannelType.GUILD_TEXT:
|
if chan_type == ChannelType.GUILD_TEXT:
|
||||||
topic = await self.db.fetchval("""
|
topic = await self.db.fetchval("""
|
||||||
SELECT topic FROM guild_text_channels
|
SELECT topic FROM guild_text_channels
|
||||||
|
|
@ -229,6 +246,8 @@ class Storage:
|
||||||
|
|
||||||
return {**row, **{
|
return {**row, **{
|
||||||
'topic': topic,
|
'topic': topic,
|
||||||
|
'last_message_id': str(
|
||||||
|
await self._chan_last_message(row['id']))
|
||||||
}}
|
}}
|
||||||
elif chan_type == ChannelType.GUILD_VOICE:
|
elif chan_type == ChannelType.GUILD_VOICE:
|
||||||
vrow = await self.db.fetchval("""
|
vrow = await self.db.fetchval("""
|
||||||
|
|
@ -240,7 +259,8 @@ class Storage:
|
||||||
|
|
||||||
log.warning('unknown channel type: {}', chan_type)
|
log.warning('unknown channel type: {}', chan_type)
|
||||||
|
|
||||||
async def get_chan_type(self, channel_id) -> int:
|
async def get_chan_type(self, channel_id: int) -> int:
|
||||||
|
"""Get the channel type integer, given channel ID."""
|
||||||
return await self.db.fetchval("""
|
return await self.db.fetchval("""
|
||||||
SELECT channel_type
|
SELECT channel_type
|
||||||
FROM channels
|
FROM channels
|
||||||
|
|
@ -275,13 +295,14 @@ class Storage:
|
||||||
|
|
||||||
return list(map(_overwrite_convert, overwrite_rows))
|
return list(map(_overwrite_convert, overwrite_rows))
|
||||||
|
|
||||||
async def get_channel(self, channel_id) -> Dict[str, Any]:
|
async def get_channel(self, channel_id: int) -> Dict[str, Any]:
|
||||||
"""Fetch a single channel's information."""
|
"""Fetch a single channel's information."""
|
||||||
chan_type = await self.get_chan_type(channel_id)
|
chan_type = await self.get_chan_type(channel_id)
|
||||||
|
ctype = ChannelType(chan_type)
|
||||||
|
|
||||||
if ChannelType(chan_type) in (ChannelType.GUILD_TEXT,
|
if ctype in (ChannelType.GUILD_TEXT,
|
||||||
ChannelType.GUILD_VOICE,
|
ChannelType.GUILD_VOICE,
|
||||||
ChannelType.GUILD_CATEGORY):
|
ChannelType.GUILD_CATEGORY):
|
||||||
base = await self.db.fetchrow("""
|
base = await self.db.fetchrow("""
|
||||||
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
|
|
@ -297,10 +318,35 @@ class Storage:
|
||||||
|
|
||||||
res['id'] = str(res['id'])
|
res['id'] = str(res['id'])
|
||||||
return res
|
return res
|
||||||
else:
|
elif ctype == ChannelType.DM:
|
||||||
# TODO: dms and group dms
|
dm_row = await self.db.fetchrow("""
|
||||||
|
SELECT party1_id, party2_id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
drow = dict(dm_row)
|
||||||
|
drow['type'] = chan_type
|
||||||
|
|
||||||
|
drow['last_message_id'] = str(
|
||||||
|
await self._chan_last_message(channel_id))
|
||||||
|
|
||||||
|
# dms have just two recipients.
|
||||||
|
drow['recipients'] = [
|
||||||
|
await self.get_user(drow['party1_id']),
|
||||||
|
await self.get_user(drow['party2_id'])
|
||||||
|
]
|
||||||
|
|
||||||
|
drow.pop('party1_id')
|
||||||
|
drow.pop('party2_id')
|
||||||
|
|
||||||
|
drow['id'] = str(drow['id'])
|
||||||
|
return drow
|
||||||
|
elif ctype == ChannelType.GROUP_DM:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_channel_data(self, guild_id) -> List[Dict]:
|
async def get_channel_data(self, guild_id) -> List[Dict]:
|
||||||
"""Get channel information on a guild"""
|
"""Get channel information on a guild"""
|
||||||
channel_basics = await self.db.fetch("""
|
channel_basics = await self.db.fetch("""
|
||||||
|
|
@ -574,6 +620,7 @@ class Storage:
|
||||||
return dinv
|
return dinv
|
||||||
|
|
||||||
async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
|
async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
|
||||||
|
"""Get current user settings."""
|
||||||
row = await self._fetchrow_with_json("""
|
row = await self._fetchrow_with_json("""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM user_settings
|
FROM user_settings
|
||||||
|
|
@ -688,3 +735,49 @@ class Storage:
|
||||||
res.append(drow)
|
res.append(drow)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def get_dm(self, dm_id: int, user_id: int = None):
|
||||||
|
dm_chan = await self.get_channel(dm_id)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
dm_chan['recipients'] = _filter_recipients(
|
||||||
|
dm_chan['recipients'], user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return dm_chan
|
||||||
|
|
||||||
|
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all DM channels for a user, including group DMs.
|
||||||
|
|
||||||
|
This will only fetch channels the user has in their state,
|
||||||
|
which is different than the whole list of DM channels.
|
||||||
|
"""
|
||||||
|
dm_ids = await self.db.fetch("""
|
||||||
|
SELECT id
|
||||||
|
FROM dm_channel_state
|
||||||
|
WHERE user_id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for dm_id in dm_ids:
|
||||||
|
dm_chan = await self.get_dm(dm_id, user_id)
|
||||||
|
res.append(dm_chan)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
async def get_all_dms(self, user_id: int) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all DMs for a user, regardless of the DM state."""
|
||||||
|
dm_ids = await self.db.fetch("""
|
||||||
|
SELECT id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE party1_id = $1 OR party2_id = $2
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for dm_id in dm_ids:
|
||||||
|
dm_chan = await self.get_dm(dm_id, user_id)
|
||||||
|
res.append(dm_chan)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
||||||
13
schema.sql
13
schema.sql
|
|
@ -261,6 +261,13 @@ CREATE TABLE IF NOT EXISTS dm_channels (
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS dm_channel_state (
|
||||||
|
user_id bigint REFERENCES users (id) ON DELETE CASCADE,
|
||||||
|
dm_id bigint REFERENCES dm_channels (id) ON DELETE CASCADE,
|
||||||
|
PRIMARY KEY (user_id, dm_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS group_dm_channels (
|
CREATE TABLE IF NOT EXISTS group_dm_channels (
|
||||||
id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||||
owner_id bigint REFERENCES users (id),
|
owner_id bigint REFERENCES users (id),
|
||||||
|
|
@ -440,7 +447,7 @@ CREATE TABLE IF NOT EXISTS embeds (
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
id bigint PRIMARY KEY,
|
id bigint PRIMARY KEY,
|
||||||
channel_id bigint REFERENCES channels (id),
|
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||||
|
|
||||||
-- those are mutually exclusive, only one of them
|
-- those are mutually exclusive, only one of them
|
||||||
-- can NOT be NULL at a time.
|
-- can NOT be NULL at a time.
|
||||||
|
|
@ -486,7 +493,7 @@ CREATE TABLE IF NOT EXISTS message_reactions (
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS channel_pins (
|
CREATE TABLE IF NOT EXISTS channel_pins (
|
||||||
channel_id bigint REFERENCES channels (id),
|
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
|
||||||
message_id bigint REFERENCES messages (id),
|
message_id bigint REFERENCES messages (id) ON DELETE CASCADE,
|
||||||
PRIMARY KEY (channel_id, message_id)
|
PRIMARY KEY (channel_id, message_id)
|
||||||
);
|
);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue