mirror of https://gitlab.com/litecord/litecord.git
blueprints.channels: add channel and guild ack routes
SQL for instances: Rerun `schema.sql` for the new table. - gateway.websocket: add get_read_state to read_state's ready - gateway.websocket: add get_dms on private_channels' ready - storage: fix get_dms() - storage: add Storage.get_channel_ids() - storage: add Storage.get_read_state() - schema.sql: add user_read_state table
This commit is contained in:
parent
5edcc62be4
commit
adc6a58179
|
|
@ -6,47 +6,15 @@ from logbook import Logger
|
|||
from ..auth import token_check
|
||||
from ..snowflake import get_snowflake, snowflake_datetime
|
||||
from ..enums import ChannelType, MessageType, GUILD_CHANS
|
||||
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
|
||||
from ..errors import Forbidden, ChannelNotFound, MessageNotFound
|
||||
from ..schemas import validate, MESSAGE_CREATE
|
||||
|
||||
from .guilds import guild_check
|
||||
from .checks import channel_check, guild_check
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channels', __name__)
|
||||
|
||||
|
||||
async def channel_check(user_id, channel_id):
|
||||
"""Check if the current user is authorized
|
||||
to read the channel's information."""
|
||||
chan_type = await app.storage.get_chan_type(channel_id)
|
||||
|
||||
if chan_type is None:
|
||||
raise ChannelNotFound(f'channel type not found')
|
||||
|
||||
ctype = ChannelType(chan_type)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
guild_id = await app.db.fetchval("""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
WHERE guild_channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
return guild_id
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
parties = await app.db.fetchval("""
|
||||
SELECT party1_id, party2_id
|
||||
FROM dm_channels
|
||||
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||
""", channel_id, user_id)
|
||||
|
||||
# get the id of the other party
|
||||
parties.remove(user_id)
|
||||
return parties[0]
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>', methods=['GET'])
|
||||
async def get_channel(channel_id):
|
||||
"""Get a single channel's information"""
|
||||
|
|
@ -263,6 +231,17 @@ async def create_message(channel_id):
|
|||
payload = await app.storage.get_message(message_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
|
||||
|
||||
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
||||
|
||||
for str_uid in payload['mentions']:
|
||||
uid = int(str_uid)
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE user_read_state
|
||||
SET mention_count += 1
|
||||
WHERE user_id = $1 AND channel_id = $2
|
||||
""", uid, channel_id)
|
||||
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
|
|
@ -431,3 +410,55 @@ async def trigger_typing(channel_id):
|
|||
})
|
||||
|
||||
return '', 204
|
||||
|
||||
|
||||
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||
"""ACK a channel."""
|
||||
|
||||
if not message_id:
|
||||
message_id = await app.storage.chan_last_message(channel_id)
|
||||
|
||||
res = await app.db.execute("""
|
||||
UPDATE user_read_state
|
||||
|
||||
SET last_message_id = $1,
|
||||
mention_count = 0
|
||||
|
||||
WHERE user_id = $2 AND channel_id = $3
|
||||
""", message_id, user_id, channel_id)
|
||||
|
||||
if res == 'UPDATE 0':
|
||||
await app.db.execute("""
|
||||
INSERT INTO user_read_state
|
||||
(user_id, channel_id, last_message_id, mention_count)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", user_id, channel_id, message_id, 0)
|
||||
|
||||
if guild_id:
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
user_id, guild_id, 'MESSAGE_ACK', {
|
||||
'message_id': str(message_id),
|
||||
'channel_id': str(channel_id)
|
||||
})
|
||||
else:
|
||||
# TODO: use ChannelDispatcher
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, 'MESSAGE_ACK', {
|
||||
'message_id': str(message_id),
|
||||
'channel_id': str(channel_id)
|
||||
})
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
||||
async def ack_channel(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await channel_ack(user_id, guild_id, channel_id, message_id)
|
||||
|
||||
return jsonify({
|
||||
# token seems to be used for
|
||||
# data collection activities,
|
||||
# so we never use it.
|
||||
'token': None
|
||||
})
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
from quart import current_app as app
|
||||
|
||||
from ..enums import ChannelType, GUILD_CHANS
|
||||
from ..errors import GuildNotFound, ChannelNotFound
|
||||
|
||||
|
||||
async def guild_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is in a guild."""
|
||||
joined_at = await app.db.execute("""
|
||||
SELECT joined_at
|
||||
FROM members
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
|
||||
if not joined_at:
|
||||
raise GuildNotFound('guild not found')
|
||||
|
||||
|
||||
async def channel_check(user_id, channel_id):
|
||||
"""Check if the current user is authorized
|
||||
to read the channel's information."""
|
||||
chan_type = await app.storage.get_chan_type(channel_id)
|
||||
|
||||
if chan_type is None:
|
||||
raise ChannelNotFound(f'channel type not found')
|
||||
|
||||
ctype = ChannelType(chan_type)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
guild_id = await app.db.fetchval("""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
WHERE guild_channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
return guild_id
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
parties = await app.db.fetchval("""
|
||||
SELECT party1_id, party2_id
|
||||
FROM dm_channels
|
||||
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||
""", channel_id, user_id)
|
||||
|
||||
# get the id of the other party
|
||||
parties.remove(user_id)
|
||||
return parties[0]
|
||||
|
|
@ -5,22 +5,12 @@ from ..snowflake import get_snowflake
|
|||
from ..enums import ChannelType
|
||||
from ..errors import Forbidden, GuildNotFound, BadRequest
|
||||
from ..schemas import validate, GUILD_UPDATE
|
||||
from .channels import channel_ack
|
||||
from .checks import guild_check, channel_check
|
||||
|
||||
bp = Blueprint('guilds', __name__)
|
||||
|
||||
|
||||
async def guild_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is in a guild."""
|
||||
joined_at = await app.db.execute("""
|
||||
SELECT joined_at
|
||||
FROM members
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
|
||||
if not joined_at:
|
||||
raise GuildNotFound('guild not found')
|
||||
|
||||
|
||||
async def guild_owner_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is the owner of the guild."""
|
||||
owner_id = await app.db.fetchval("""
|
||||
|
|
@ -469,3 +459,16 @@ async def search_messages(guild_id):
|
|||
'messages': [],
|
||||
'analytics_id': 'ass',
|
||||
})
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
||||
async def ack_guild(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||
|
||||
for chan_id in chan_ids:
|
||||
await channel_ack(user_id, guild_id, chan_id)
|
||||
|
||||
return '', 204
|
||||
|
|
|
|||
|
|
@ -207,6 +207,8 @@ class GatewayWebsocket:
|
|||
'user_settings': await self.storage.get_user_settings(user_id),
|
||||
'notes': await self.storage.fetch_notes(user_id),
|
||||
'relationships': await self.storage.get_relationships(user_id),
|
||||
'read_state': await self.storage.get_read_state(user_id),
|
||||
|
||||
'friend_suggestion_count': 0,
|
||||
|
||||
# TODO
|
||||
|
|
@ -215,9 +217,6 @@ class GatewayWebsocket:
|
|||
# TODO
|
||||
'presences': [],
|
||||
|
||||
# TODO
|
||||
'read_state': [],
|
||||
|
||||
# TODO
|
||||
'connected_accounts': [],
|
||||
|
||||
|
|
@ -229,7 +228,9 @@ class GatewayWebsocket:
|
|||
async def dispatch_ready(self):
|
||||
"""Dispatch the READY packet for a connecting account."""
|
||||
guilds = await self._make_guild_list()
|
||||
user = await self.storage.get_user(self.state.user_id, True)
|
||||
|
||||
user_id = self.state.user_id
|
||||
user = await self.storage.get_user(user_id, True)
|
||||
|
||||
uready = {}
|
||||
if not self.state.bot:
|
||||
|
|
@ -240,8 +241,8 @@ class GatewayWebsocket:
|
|||
'v': 6,
|
||||
'user': user,
|
||||
|
||||
# TODO: dms
|
||||
'private_channels': [],
|
||||
'private_channels': await self.storage.get_dms(user_id),
|
||||
|
||||
'guilds': guilds,
|
||||
'session_id': self.state.session_id,
|
||||
'_trace': ['transbian']
|
||||
|
|
|
|||
|
|
@ -225,7 +225,7 @@ class Storage:
|
|||
members = await self.get_member_multi(guild_id, mids)
|
||||
return members
|
||||
|
||||
async def _chan_last_message(self, channel_id: int):
|
||||
async def chan_last_message(self, channel_id: int):
|
||||
return await self.db.fetchval("""
|
||||
SELECT MAX(id)
|
||||
FROM messages
|
||||
|
|
@ -247,7 +247,7 @@ class Storage:
|
|||
return {**row, **{
|
||||
'topic': topic,
|
||||
'last_message_id': str(
|
||||
await self._chan_last_message(row['id']))
|
||||
await self.chan_last_message(row['id']))
|
||||
}}
|
||||
elif chan_type == ChannelType.GUILD_VOICE:
|
||||
vrow = await self.db.fetchval("""
|
||||
|
|
@ -329,7 +329,7 @@ class Storage:
|
|||
drow['type'] = chan_type
|
||||
|
||||
drow['last_message_id'] = str(
|
||||
await self._chan_last_message(channel_id))
|
||||
await self.chan_last_message(channel_id))
|
||||
|
||||
# dms have just two recipients.
|
||||
drow['recipients'] = [
|
||||
|
|
@ -347,6 +347,15 @@ class Storage:
|
|||
|
||||
return None
|
||||
|
||||
async def get_channel_ids(self, guild_id: int) -> List[int]:
|
||||
rows = await self.db.fetch("""
|
||||
SELECT id
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
return [r['id'] for r in rows]
|
||||
|
||||
async def get_channel_data(self, guild_id) -> List[Dict]:
|
||||
"""Get channel information on a guild"""
|
||||
channel_basics = await self.db.fetch("""
|
||||
|
|
@ -753,7 +762,7 @@ class Storage:
|
|||
which is different than the whole list of DM channels.
|
||||
"""
|
||||
dm_ids = await self.db.fetch("""
|
||||
SELECT id
|
||||
SELECT dm_id
|
||||
FROM dm_channel_state
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
|
@ -781,3 +790,25 @@ class Storage:
|
|||
res.append(dm_chan)
|
||||
|
||||
return res
|
||||
|
||||
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get the read state for a user."""
|
||||
rows = await self.db.fetch("""
|
||||
SELECT channel_id, last_message_id, mention_count
|
||||
FROM user_read_state
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
res = []
|
||||
|
||||
for row in rows:
|
||||
drow = dict(row)
|
||||
|
||||
drow['id'] = str(drow['channel_id'])
|
||||
drow.pop('channel_id')
|
||||
|
||||
drow['last_message_id'] = str(drow['last_message_id'])
|
||||
|
||||
res.append(drow)
|
||||
|
||||
return res
|
||||
|
|
|
|||
14
schema.sql
14
schema.sql
|
|
@ -180,6 +180,20 @@ CREATE TABLE IF NOT EXISTS channels (
|
|||
channel_type int NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_read_state (
|
||||
user_id bigint REFERENCES users (id),
|
||||
channel_id bigint REFERENCES channels (id),
|
||||
|
||||
-- we don't really need to link
|
||||
-- this column to the messages table
|
||||
last_message_id bigint,
|
||||
|
||||
-- counts are always positive
|
||||
mention_count bigint CHECK (mention_count > -1),
|
||||
|
||||
PRIMARY KEY (user_id, channel_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS guilds (
|
||||
id bigint PRIMARY KEY NOT NULL,
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue