diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 80b6a6f..d6266ee 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -6,47 +6,15 @@ from logbook import Logger from ..auth import token_check from ..snowflake import get_snowflake, snowflake_datetime from ..enums import ChannelType, MessageType, GUILD_CHANS -from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound +from ..errors import Forbidden, ChannelNotFound, MessageNotFound from ..schemas import validate, MESSAGE_CREATE -from .guilds import guild_check +from .checks import channel_check, guild_check log = Logger(__name__) bp = Blueprint('channels', __name__) -async def channel_check(user_id, channel_id): - """Check if the current user is authorized - to read the channel's information.""" - chan_type = await app.storage.get_chan_type(channel_id) - - if chan_type is None: - raise ChannelNotFound(f'channel type not found') - - ctype = ChannelType(chan_type) - - if ctype in GUILD_CHANS: - guild_id = await app.db.fetchval(""" - SELECT guild_id - FROM guild_channels - WHERE guild_channels.id = $1 - """, channel_id) - - await guild_check(user_id, guild_id) - return guild_id - - if ctype == ChannelType.DM: - parties = await app.db.fetchval(""" - SELECT party1_id, party2_id - FROM dm_channels - WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) - """, channel_id, user_id) - - # get the id of the other party - parties.remove(user_id) - return parties[0] - - @bp.route('/', methods=['GET']) async def get_channel(channel_id): """Get a single channel's information""" @@ -263,6 +231,17 @@ async def create_message(channel_id): payload = await app.storage.get_message(message_id) await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload) + # TODO: dispatch the MESSAGE_CREATE to any mentioning user. + + for str_uid in payload['mentions']: + uid = int(str_uid) + + await app.db.execute(""" + UPDATE user_read_state + SET mention_count += 1 + WHERE user_id = $1 AND channel_id = $2 + """, uid, channel_id) + return jsonify(payload) @@ -431,3 +410,55 @@ async def trigger_typing(channel_id): }) return '', 204 + + +async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): + """ACK a channel.""" + + if not message_id: + message_id = await app.storage.chan_last_message(channel_id) + + res = await app.db.execute(""" + UPDATE user_read_state + + SET last_message_id = $1, + mention_count = 0 + + WHERE user_id = $2 AND channel_id = $3 + """, message_id, user_id, channel_id) + + if res == 'UPDATE 0': + await app.db.execute(""" + INSERT INTO user_read_state + (user_id, channel_id, last_message_id, mention_count) + VALUES ($1, $2, $3, $4) + """, user_id, channel_id, message_id, 0) + + if guild_id: + await app.dispatcher.dispatch_user_guild( + user_id, guild_id, 'MESSAGE_ACK', { + 'message_id': str(message_id), + 'channel_id': str(channel_id) + }) + else: + # TODO: use ChannelDispatcher + await app.dispatcher.dispatch_user( + user_id, 'MESSAGE_ACK', { + 'message_id': str(message_id), + 'channel_id': str(channel_id) + }) + + +@bp.route('//messages//ack', methods=['POST']) +async def ack_channel(channel_id, message_id): + user_id = await token_check() + guild_id = await channel_check(user_id, channel_id) + + await channel_ack(user_id, guild_id, channel_id, message_id) + + return jsonify({ + # token seems to be used for + # data collection activities, + # so we never use it. + 'token': None + }) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py new file mode 100644 index 0000000..aa92b7a --- /dev/null +++ b/litecord/blueprints/checks.py @@ -0,0 +1,48 @@ +from quart import current_app as app + +from ..enums import ChannelType, GUILD_CHANS +from ..errors import GuildNotFound, ChannelNotFound + + +async def guild_check(user_id: int, guild_id: int): + """Check if a user is in a guild.""" + joined_at = await app.db.execute(""" + SELECT joined_at + FROM members + WHERE user_id = $1 AND guild_id = $2 + """, user_id, guild_id) + + if not joined_at: + raise GuildNotFound('guild not found') + + +async def channel_check(user_id, channel_id): + """Check if the current user is authorized + to read the channel's information.""" + chan_type = await app.storage.get_chan_type(channel_id) + + if chan_type is None: + raise ChannelNotFound(f'channel type not found') + + ctype = ChannelType(chan_type) + + if ctype in GUILD_CHANS: + guild_id = await app.db.fetchval(""" + SELECT guild_id + FROM guild_channels + WHERE guild_channels.id = $1 + """, channel_id) + + await guild_check(user_id, guild_id) + return guild_id + + if ctype == ChannelType.DM: + parties = await app.db.fetchval(""" + SELECT party1_id, party2_id + FROM dm_channels + WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) + """, channel_id, user_id) + + # get the id of the other party + parties.remove(user_id) + return parties[0] diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 1b818b7..d14f8f2 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -5,22 +5,12 @@ from ..snowflake import get_snowflake from ..enums import ChannelType from ..errors import Forbidden, GuildNotFound, BadRequest from ..schemas import validate, GUILD_UPDATE +from .channels import channel_ack +from .checks import guild_check, channel_check bp = Blueprint('guilds', __name__) -async def guild_check(user_id: int, guild_id: int): - """Check if a user is in a guild.""" - joined_at = await app.db.execute(""" - SELECT joined_at - FROM members - WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) - - if not joined_at: - raise GuildNotFound('guild not found') - - async def guild_owner_check(user_id: int, guild_id: int): """Check if a user is the owner of the guild.""" owner_id = await app.db.fetchval(""" @@ -469,3 +459,16 @@ async def search_messages(guild_id): 'messages': [], 'analytics_id': 'ass', }) + + +@bp.route('//ack', methods=['POST']) +async def ack_guild(guild_id): + user_id = await token_check() + await guild_check(user_id, guild_id) + + chan_ids = await app.storage.get_channel_ids(guild_id) + + for chan_id in chan_ids: + await channel_ack(user_id, guild_id, chan_id) + + return '', 204 diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index cecf0e2..fed011d 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -207,6 +207,8 @@ class GatewayWebsocket: 'user_settings': await self.storage.get_user_settings(user_id), 'notes': await self.storage.fetch_notes(user_id), 'relationships': await self.storage.get_relationships(user_id), + 'read_state': await self.storage.get_read_state(user_id), + 'friend_suggestion_count': 0, # TODO @@ -215,9 +217,6 @@ class GatewayWebsocket: # TODO 'presences': [], - # TODO - 'read_state': [], - # TODO 'connected_accounts': [], @@ -229,7 +228,9 @@ class GatewayWebsocket: async def dispatch_ready(self): """Dispatch the READY packet for a connecting account.""" guilds = await self._make_guild_list() - user = await self.storage.get_user(self.state.user_id, True) + + user_id = self.state.user_id + user = await self.storage.get_user(user_id, True) uready = {} if not self.state.bot: @@ -240,8 +241,8 @@ class GatewayWebsocket: 'v': 6, 'user': user, - # TODO: dms - 'private_channels': [], + 'private_channels': await self.storage.get_dms(user_id), + 'guilds': guilds, 'session_id': self.state.session_id, '_trace': ['transbian'] diff --git a/litecord/storage.py b/litecord/storage.py index 46844c4..9ebc051 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -225,7 +225,7 @@ class Storage: members = await self.get_member_multi(guild_id, mids) return members - async def _chan_last_message(self, channel_id: int): + async def chan_last_message(self, channel_id: int): return await self.db.fetchval(""" SELECT MAX(id) FROM messages @@ -247,7 +247,7 @@ class Storage: return {**row, **{ 'topic': topic, 'last_message_id': str( - await self._chan_last_message(row['id'])) + await self.chan_last_message(row['id'])) }} elif chan_type == ChannelType.GUILD_VOICE: vrow = await self.db.fetchval(""" @@ -329,7 +329,7 @@ class Storage: drow['type'] = chan_type drow['last_message_id'] = str( - await self._chan_last_message(channel_id)) + await self.chan_last_message(channel_id)) # dms have just two recipients. drow['recipients'] = [ @@ -347,6 +347,15 @@ class Storage: return None + async def get_channel_ids(self, guild_id: int) -> List[int]: + rows = await self.db.fetch(""" + SELECT id + FROM guild_channels + WHERE guild_id = $1 + """, guild_id) + + return [r['id'] for r in rows] + async def get_channel_data(self, guild_id) -> List[Dict]: """Get channel information on a guild""" channel_basics = await self.db.fetch(""" @@ -753,7 +762,7 @@ class Storage: which is different than the whole list of DM channels. """ dm_ids = await self.db.fetch(""" - SELECT id + SELECT dm_id FROM dm_channel_state WHERE user_id = $1 """, user_id) @@ -781,3 +790,25 @@ class Storage: res.append(dm_chan) return res + + async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: + """Get the read state for a user.""" + rows = await self.db.fetch(""" + SELECT channel_id, last_message_id, mention_count + FROM user_read_state + WHERE user_id = $1 + """, user_id) + + res = [] + + for row in rows: + drow = dict(row) + + drow['id'] = str(drow['channel_id']) + drow.pop('channel_id') + + drow['last_message_id'] = str(drow['last_message_id']) + + res.append(drow) + + return res diff --git a/schema.sql b/schema.sql index ca86a63..846eaff 100644 --- a/schema.sql +++ b/schema.sql @@ -180,6 +180,20 @@ CREATE TABLE IF NOT EXISTS channels ( channel_type int NOT NULL ); +CREATE TABLE IF NOT EXISTS user_read_state ( + user_id bigint REFERENCES users (id), + channel_id bigint REFERENCES channels (id), + + -- we don't really need to link + -- this column to the messages table + last_message_id bigint, + + -- counts are always positive + mention_count bigint CHECK (mention_count > -1), + + PRIMARY KEY (user_id, channel_id) +); + CREATE TABLE IF NOT EXISTS guilds ( id bigint PRIMARY KEY NOT NULL,