mirror of https://gitlab.com/litecord/litecord.git
blueprints.channels: add channel and guild ack routes
SQL for instances: Rerun `schema.sql` for the new table. - gateway.websocket: add get_read_state to read_state's ready - gateway.websocket: add get_dms on private_channels' ready - storage: fix get_dms() - storage: add Storage.get_channel_ids() - storage: add Storage.get_read_state() - schema.sql: add user_read_state table
This commit is contained in:
parent
5edcc62be4
commit
adc6a58179
|
|
@ -6,47 +6,15 @@ from logbook import Logger
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
from ..snowflake import get_snowflake, snowflake_datetime
|
from ..snowflake import get_snowflake, snowflake_datetime
|
||||||
from ..enums import ChannelType, MessageType, GUILD_CHANS
|
from ..enums import ChannelType, MessageType, GUILD_CHANS
|
||||||
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
|
from ..errors import Forbidden, ChannelNotFound, MessageNotFound
|
||||||
from ..schemas import validate, MESSAGE_CREATE
|
from ..schemas import validate, MESSAGE_CREATE
|
||||||
|
|
||||||
from .guilds import guild_check
|
from .checks import channel_check, guild_check
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channels', __name__)
|
bp = Blueprint('channels', __name__)
|
||||||
|
|
||||||
|
|
||||||
async def channel_check(user_id, channel_id):
|
|
||||||
"""Check if the current user is authorized
|
|
||||||
to read the channel's information."""
|
|
||||||
chan_type = await app.storage.get_chan_type(channel_id)
|
|
||||||
|
|
||||||
if chan_type is None:
|
|
||||||
raise ChannelNotFound(f'channel type not found')
|
|
||||||
|
|
||||||
ctype = ChannelType(chan_type)
|
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
|
||||||
guild_id = await app.db.fetchval("""
|
|
||||||
SELECT guild_id
|
|
||||||
FROM guild_channels
|
|
||||||
WHERE guild_channels.id = $1
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
|
||||||
return guild_id
|
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
|
||||||
parties = await app.db.fetchval("""
|
|
||||||
SELECT party1_id, party2_id
|
|
||||||
FROM dm_channels
|
|
||||||
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
|
||||||
""", channel_id, user_id)
|
|
||||||
|
|
||||||
# get the id of the other party
|
|
||||||
parties.remove(user_id)
|
|
||||||
return parties[0]
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['GET'])
|
@bp.route('/<int:channel_id>', methods=['GET'])
|
||||||
async def get_channel(channel_id):
|
async def get_channel(channel_id):
|
||||||
"""Get a single channel's information"""
|
"""Get a single channel's information"""
|
||||||
|
|
@ -263,6 +231,17 @@ async def create_message(channel_id):
|
||||||
payload = await app.storage.get_message(message_id)
|
payload = await app.storage.get_message(message_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
|
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
|
||||||
|
|
||||||
|
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
||||||
|
|
||||||
|
for str_uid in payload['mentions']:
|
||||||
|
uid = int(str_uid)
|
||||||
|
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count += 1
|
||||||
|
WHERE user_id = $1 AND channel_id = $2
|
||||||
|
""", uid, channel_id)
|
||||||
|
|
||||||
return jsonify(payload)
|
return jsonify(payload)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -431,3 +410,55 @@ async def trigger_typing(channel_id):
|
||||||
})
|
})
|
||||||
|
|
||||||
return '', 204
|
return '', 204
|
||||||
|
|
||||||
|
|
||||||
|
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||||
|
"""ACK a channel."""
|
||||||
|
|
||||||
|
if not message_id:
|
||||||
|
message_id = await app.storage.chan_last_message(channel_id)
|
||||||
|
|
||||||
|
res = await app.db.execute("""
|
||||||
|
UPDATE user_read_state
|
||||||
|
|
||||||
|
SET last_message_id = $1,
|
||||||
|
mention_count = 0
|
||||||
|
|
||||||
|
WHERE user_id = $2 AND channel_id = $3
|
||||||
|
""", message_id, user_id, channel_id)
|
||||||
|
|
||||||
|
if res == 'UPDATE 0':
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO user_read_state
|
||||||
|
(user_id, channel_id, last_message_id, mention_count)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
""", user_id, channel_id, message_id, 0)
|
||||||
|
|
||||||
|
if guild_id:
|
||||||
|
await app.dispatcher.dispatch_user_guild(
|
||||||
|
user_id, guild_id, 'MESSAGE_ACK', {
|
||||||
|
'message_id': str(message_id),
|
||||||
|
'channel_id': str(channel_id)
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# TODO: use ChannelDispatcher
|
||||||
|
await app.dispatcher.dispatch_user(
|
||||||
|
user_id, 'MESSAGE_ACK', {
|
||||||
|
'message_id': str(message_id),
|
||||||
|
'channel_id': str(channel_id)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
||||||
|
async def ack_channel(channel_id, message_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
await channel_ack(user_id, guild_id, channel_id, message_id)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
# token seems to be used for
|
||||||
|
# data collection activities,
|
||||||
|
# so we never use it.
|
||||||
|
'token': None
|
||||||
|
})
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
|
from ..enums import ChannelType, GUILD_CHANS
|
||||||
|
from ..errors import GuildNotFound, ChannelNotFound
|
||||||
|
|
||||||
|
|
||||||
|
async def guild_check(user_id: int, guild_id: int):
|
||||||
|
"""Check if a user is in a guild."""
|
||||||
|
joined_at = await app.db.execute("""
|
||||||
|
SELECT joined_at
|
||||||
|
FROM members
|
||||||
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
|
""", user_id, guild_id)
|
||||||
|
|
||||||
|
if not joined_at:
|
||||||
|
raise GuildNotFound('guild not found')
|
||||||
|
|
||||||
|
|
||||||
|
async def channel_check(user_id, channel_id):
|
||||||
|
"""Check if the current user is authorized
|
||||||
|
to read the channel's information."""
|
||||||
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
|
||||||
|
if chan_type is None:
|
||||||
|
raise ChannelNotFound(f'channel type not found')
|
||||||
|
|
||||||
|
ctype = ChannelType(chan_type)
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
guild_id = await app.db.fetchval("""
|
||||||
|
SELECT guild_id
|
||||||
|
FROM guild_channels
|
||||||
|
WHERE guild_channels.id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
return guild_id
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
parties = await app.db.fetchval("""
|
||||||
|
SELECT party1_id, party2_id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||||
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
# get the id of the other party
|
||||||
|
parties.remove(user_id)
|
||||||
|
return parties[0]
|
||||||
|
|
@ -5,22 +5,12 @@ from ..snowflake import get_snowflake
|
||||||
from ..enums import ChannelType
|
from ..enums import ChannelType
|
||||||
from ..errors import Forbidden, GuildNotFound, BadRequest
|
from ..errors import Forbidden, GuildNotFound, BadRequest
|
||||||
from ..schemas import validate, GUILD_UPDATE
|
from ..schemas import validate, GUILD_UPDATE
|
||||||
|
from .channels import channel_ack
|
||||||
|
from .checks import guild_check, channel_check
|
||||||
|
|
||||||
bp = Blueprint('guilds', __name__)
|
bp = Blueprint('guilds', __name__)
|
||||||
|
|
||||||
|
|
||||||
async def guild_check(user_id: int, guild_id: int):
|
|
||||||
"""Check if a user is in a guild."""
|
|
||||||
joined_at = await app.db.execute("""
|
|
||||||
SELECT joined_at
|
|
||||||
FROM members
|
|
||||||
WHERE user_id = $1 AND guild_id = $2
|
|
||||||
""", user_id, guild_id)
|
|
||||||
|
|
||||||
if not joined_at:
|
|
||||||
raise GuildNotFound('guild not found')
|
|
||||||
|
|
||||||
|
|
||||||
async def guild_owner_check(user_id: int, guild_id: int):
|
async def guild_owner_check(user_id: int, guild_id: int):
|
||||||
"""Check if a user is the owner of the guild."""
|
"""Check if a user is the owner of the guild."""
|
||||||
owner_id = await app.db.fetchval("""
|
owner_id = await app.db.fetchval("""
|
||||||
|
|
@ -469,3 +459,16 @@ async def search_messages(guild_id):
|
||||||
'messages': [],
|
'messages': [],
|
||||||
'analytics_id': 'ass',
|
'analytics_id': 'ass',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
||||||
|
async def ack_guild(guild_id):
|
||||||
|
user_id = await token_check()
|
||||||
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
|
chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||||
|
|
||||||
|
for chan_id in chan_ids:
|
||||||
|
await channel_ack(user_id, guild_id, chan_id)
|
||||||
|
|
||||||
|
return '', 204
|
||||||
|
|
|
||||||
|
|
@ -207,6 +207,8 @@ class GatewayWebsocket:
|
||||||
'user_settings': await self.storage.get_user_settings(user_id),
|
'user_settings': await self.storage.get_user_settings(user_id),
|
||||||
'notes': await self.storage.fetch_notes(user_id),
|
'notes': await self.storage.fetch_notes(user_id),
|
||||||
'relationships': await self.storage.get_relationships(user_id),
|
'relationships': await self.storage.get_relationships(user_id),
|
||||||
|
'read_state': await self.storage.get_read_state(user_id),
|
||||||
|
|
||||||
'friend_suggestion_count': 0,
|
'friend_suggestion_count': 0,
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
|
|
@ -215,9 +217,6 @@ class GatewayWebsocket:
|
||||||
# TODO
|
# TODO
|
||||||
'presences': [],
|
'presences': [],
|
||||||
|
|
||||||
# TODO
|
|
||||||
'read_state': [],
|
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
'connected_accounts': [],
|
'connected_accounts': [],
|
||||||
|
|
||||||
|
|
@ -229,7 +228,9 @@ class GatewayWebsocket:
|
||||||
async def dispatch_ready(self):
|
async def dispatch_ready(self):
|
||||||
"""Dispatch the READY packet for a connecting account."""
|
"""Dispatch the READY packet for a connecting account."""
|
||||||
guilds = await self._make_guild_list()
|
guilds = await self._make_guild_list()
|
||||||
user = await self.storage.get_user(self.state.user_id, True)
|
|
||||||
|
user_id = self.state.user_id
|
||||||
|
user = await self.storage.get_user(user_id, True)
|
||||||
|
|
||||||
uready = {}
|
uready = {}
|
||||||
if not self.state.bot:
|
if not self.state.bot:
|
||||||
|
|
@ -240,8 +241,8 @@ class GatewayWebsocket:
|
||||||
'v': 6,
|
'v': 6,
|
||||||
'user': user,
|
'user': user,
|
||||||
|
|
||||||
# TODO: dms
|
'private_channels': await self.storage.get_dms(user_id),
|
||||||
'private_channels': [],
|
|
||||||
'guilds': guilds,
|
'guilds': guilds,
|
||||||
'session_id': self.state.session_id,
|
'session_id': self.state.session_id,
|
||||||
'_trace': ['transbian']
|
'_trace': ['transbian']
|
||||||
|
|
|
||||||
|
|
@ -225,7 +225,7 @@ class Storage:
|
||||||
members = await self.get_member_multi(guild_id, mids)
|
members = await self.get_member_multi(guild_id, mids)
|
||||||
return members
|
return members
|
||||||
|
|
||||||
async def _chan_last_message(self, channel_id: int):
|
async def chan_last_message(self, channel_id: int):
|
||||||
return await self.db.fetchval("""
|
return await self.db.fetchval("""
|
||||||
SELECT MAX(id)
|
SELECT MAX(id)
|
||||||
FROM messages
|
FROM messages
|
||||||
|
|
@ -247,7 +247,7 @@ class Storage:
|
||||||
return {**row, **{
|
return {**row, **{
|
||||||
'topic': topic,
|
'topic': topic,
|
||||||
'last_message_id': str(
|
'last_message_id': str(
|
||||||
await self._chan_last_message(row['id']))
|
await self.chan_last_message(row['id']))
|
||||||
}}
|
}}
|
||||||
elif chan_type == ChannelType.GUILD_VOICE:
|
elif chan_type == ChannelType.GUILD_VOICE:
|
||||||
vrow = await self.db.fetchval("""
|
vrow = await self.db.fetchval("""
|
||||||
|
|
@ -329,7 +329,7 @@ class Storage:
|
||||||
drow['type'] = chan_type
|
drow['type'] = chan_type
|
||||||
|
|
||||||
drow['last_message_id'] = str(
|
drow['last_message_id'] = str(
|
||||||
await self._chan_last_message(channel_id))
|
await self.chan_last_message(channel_id))
|
||||||
|
|
||||||
# dms have just two recipients.
|
# dms have just two recipients.
|
||||||
drow['recipients'] = [
|
drow['recipients'] = [
|
||||||
|
|
@ -347,6 +347,15 @@ class Storage:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_channel_ids(self, guild_id: int) -> List[int]:
|
||||||
|
rows = await self.db.fetch("""
|
||||||
|
SELECT id
|
||||||
|
FROM guild_channels
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""", guild_id)
|
||||||
|
|
||||||
|
return [r['id'] for r in rows]
|
||||||
|
|
||||||
async def get_channel_data(self, guild_id) -> List[Dict]:
|
async def get_channel_data(self, guild_id) -> List[Dict]:
|
||||||
"""Get channel information on a guild"""
|
"""Get channel information on a guild"""
|
||||||
channel_basics = await self.db.fetch("""
|
channel_basics = await self.db.fetch("""
|
||||||
|
|
@ -753,7 +762,7 @@ class Storage:
|
||||||
which is different than the whole list of DM channels.
|
which is different than the whole list of DM channels.
|
||||||
"""
|
"""
|
||||||
dm_ids = await self.db.fetch("""
|
dm_ids = await self.db.fetch("""
|
||||||
SELECT id
|
SELECT dm_id
|
||||||
FROM dm_channel_state
|
FROM dm_channel_state
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
@ -781,3 +790,25 @@ class Storage:
|
||||||
res.append(dm_chan)
|
res.append(dm_chan)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
|
||||||
|
"""Get the read state for a user."""
|
||||||
|
rows = await self.db.fetch("""
|
||||||
|
SELECT channel_id, last_message_id, mention_count
|
||||||
|
FROM user_read_state
|
||||||
|
WHERE user_id = $1
|
||||||
|
""", user_id)
|
||||||
|
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
drow = dict(row)
|
||||||
|
|
||||||
|
drow['id'] = str(drow['channel_id'])
|
||||||
|
drow.pop('channel_id')
|
||||||
|
|
||||||
|
drow['last_message_id'] = str(drow['last_message_id'])
|
||||||
|
|
||||||
|
res.append(drow)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
||||||
14
schema.sql
14
schema.sql
|
|
@ -180,6 +180,20 @@ CREATE TABLE IF NOT EXISTS channels (
|
||||||
channel_type int NOT NULL
|
channel_type int NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_read_state (
|
||||||
|
user_id bigint REFERENCES users (id),
|
||||||
|
channel_id bigint REFERENCES channels (id),
|
||||||
|
|
||||||
|
-- we don't really need to link
|
||||||
|
-- this column to the messages table
|
||||||
|
last_message_id bigint,
|
||||||
|
|
||||||
|
-- counts are always positive
|
||||||
|
mention_count bigint CHECK (mention_count > -1),
|
||||||
|
|
||||||
|
PRIMARY KEY (user_id, channel_id)
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS guilds (
|
CREATE TABLE IF NOT EXISTS guilds (
|
||||||
id bigint PRIMARY KEY NOT NULL,
|
id bigint PRIMARY KEY NOT NULL,
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue