diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index df479ee..b5cdaea 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -191,15 +191,38 @@ async def create_message(channel_id): 'MESSAGE_CREATE', payload) if ctype == ChannelType.GUILD_TEXT: - for mention in payload['mentions']: - uid = int(mention['id']) + # calculate the user ids we'll bump the mention count for + uids = set() + # first is extracting user mentions + for mention in payload['mentions']: + uids.add(int(mention['id'])) + + # then role mentions + for role_mention in payload['mention_roles']: + role_id = int(role_mention) + member_ids = await app.storage.get_role_members(role_id) + + for member_id in member_ids: + uids.add(member_id) + + # if we're on an at-everyone / at-here, just update + # the read state for everyone. + if mentions_everyone: + uids = [] + await app.db.execute(""" + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE channel_id = $1 + """, channel_id) + + for user_id in uids: await app.db.execute(""" UPDATE user_read_state SET mention_count = mention_count + 1 WHERE user_id = $1 AND channel_id = $2 - """, uid, channel_id) + """, user_id, channel_id) return jsonify(payload) diff --git a/litecord/storage.py b/litecord/storage.py index ea94a0b..abf4b59 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -852,3 +852,13 @@ class Storage: res.append(emoji) return res + + async def get_role_members(self, role_id: int) -> List[int]: + """Get all members with a role.""" + rows = await self.db.fetch(""" + SELECT user_id + FROM member_roles + WHERE role_id = $1 + """, role_id) + + return [r['id'] for r in rows]