From 1d33e46fd80486fd322bab50953ae2012f0c262e Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 19 Oct 2018 19:16:29 -0300 Subject: [PATCH 01/69] pubsub: add lazy_guild --- README.md | 8 +++ litecord/pubsub/channel.py | 1 - litecord/pubsub/guild.py | 1 - litecord/pubsub/lazy_guild.py | 94 +++++++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 litecord/pubsub/lazy_guild.py diff --git a/README.md b/README.md index 4658055..5f5c9c4 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,16 @@ This project is a rewrite of [litecord-reference]. [litecord-reference]: https://gitlab.com/luna/litecord-reference +## Notes + + - There are no testing being run on the current codebase. Which means the code is definitely unstable. + - No voice is planned to be developed, for now. + - You must figure out connecting to the server yourself. Litecord will not distribute + Discord's official client code nor provide ways to modify the client. + ## Install +Requirements: - Python 3.6 or higher - PostgreSQL - [Pipenv] diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 3c3cb2c..621eeff 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -1,5 +1,4 @@ from typing import Any -from collections import defaultdict from logbook import Logger diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index a05373a..6896d93 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -1,4 +1,3 @@ -from collections import defaultdict from typing import Any from logbook import Logger diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py new file mode 100644 index 0000000..ba8a451 --- /dev/null +++ b/litecord/pubsub/lazy_guild.py @@ -0,0 +1,94 @@ +from collections import defaultdict +from typing import Any + +from logbook import Logger + +from .dispatcher import Dispatcher + +log = Logger(__name__) + + +class GuildMemberList(): + def __init__(self, guild_id: int): + self.guild_id = guild_id + + # TODO: initialize list with actual member info + self._uninitialized = True + self.member_list = [] + + #: holds the state of subscribed users + self.state = set() + + async def _init_check(self): + """Check if the member list is initialized before + messing with it.""" + if self._uninitialized: + await self._init_member_list() + + async def _init_member_list(self): + """Fill in :attr:`GuildMemberList.member_list` + with information about the guilds' members.""" + pass + + async def sub(self, user_id: int): + """Subscribe a user to the member list.""" + await self._init_check() + self.state.add(user_id) + + async def unsub(self, user_id: int): + """Unsubscribe a user from the member list""" + self.state.discard(user_id) + + # once we reach 0 subscribers, + # we drop the current member list we have (for memory) + # but keep the GuildMemberList running (as + # uninitialized) for a future subscriber. + + if not self.state: + self.member_list = [] + self._uninitialized = True + + async def dispatch(self, event: str, data: Any): + """The dispatch() method here, instead of being + about dispatching a single event to the subscribed + users and forgetting about it, is about storing + the actual member list information so that we + can generate the respective events to the users. + + GuildMemberList stores the current guilds' list + in its :attr:`GuildMemberList.member_list` attribute, + with that attribute being modified via different + calls to :meth:`GuildMemberList.dispatch` + """ + + if self._uninitialized: + # if the list is currently uninitialized, + # no subscribers actually happened, so + # we can safely drop the incoming event. + return + + +class LazyGuildDispatcher(Dispatcher): + """Main class holding the member lists for lazy guilds.""" + KEY_TYPE = int + VAL_TYPE = int + + def __init__(self, main): + super().__init__(main) + self.state = defaultdict(GuildMemberList) + + async def sub(self, guild_id, user_id): + await self.state[guild_id].sub(user_id) + + async def unsub(self, guild_id, user_id): + await self.state[guild_id].unsub(user_id) + + async def dispatch(self, guild_id: int, event: str, data): + """Dispatch an event to the member list. + + GuildMemberList will make sure of converting it to + GUILD_MEMBER_LIST_UPDATE events. + """ + member_list = self.state[guild_id] + await member_list.dispatch(event, data) + From cd4181c32719d1a8dc2fb53c1e7b6550c435c100 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 24 Oct 2018 16:36:24 -0300 Subject: [PATCH 02/69] litecord.pubsub: add more functionality to GuildMemberList GuildMemberList, as of this commit, can generate a correct list and handle (some of) the data given in OP 14. The implementation is still rudimentary and there's a lot of work to finish. - dispatcher: add LazyGuildDispatcher - gateway.state_manager: add states_raw to fetch a single state without uid - gateway.websocket: remove rudimentary implementation (moved it to GuildMemberList in litecord.pubsub.lazy_guild) --- litecord/dispatcher.py | 4 +- litecord/gateway/state_manager.py | 13 ++ litecord/gateway/websocket.py | 71 ++----- litecord/pubsub/__init__.py | 3 +- litecord/pubsub/lazy_guild.py | 317 +++++++++++++++++++++++++++--- litecord/storage.py | 2 +- 6 files changed, 317 insertions(+), 93 deletions(-) diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 009ac5e..e7e392b 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -4,7 +4,8 @@ from typing import List, Any from logbook import Logger from .pubsub import GuildDispatcher, MemberDispatcher, \ - UserDispatcher, ChannelDispatcher, FriendDispatcher + UserDispatcher, ChannelDispatcher, FriendDispatcher, \ + LazyGuildDispatcher log = Logger(__name__) @@ -35,6 +36,7 @@ class EventDispatcher: 'channel': ChannelDispatcher(self), 'user': UserDispatcher(self), 'friend': FriendDispatcher(self), + 'lazy_guild': LazyGuildDispatcher(self), } async def action(self, backend_str: str, action: str, key, identifier): diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 56d00bc..e393a79 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -22,12 +22,16 @@ class StateManager: # } self.states = defaultdict(dict) + #: raw mapping from session ids to GatewayState + self.states_raw = {} + def insert(self, state: GatewayState): """Insert a new state object.""" user_states = self.states[state.user_id] log.debug('inserting state: {!r}', state) user_states[state.session_id] = state + self.states_raw[state.session_id] = state def fetch(self, user_id: int, session_id: str) -> GatewayState: """Fetch a state object from the manager. @@ -40,11 +44,20 @@ class StateManager: """ return self.states[user_id][session_id] + def fetch_raw(self, session_id: str) -> GatewayState: + """Fetch a single state given the Session ID.""" + return self.states_raw[session_id] + def remove(self, state): """Remove a state from the registry""" if not state: return + try: + self.states_raw.pop(state.session_id) + except KeyError: + pass + try: log.debug('removing state: {!r}', state) self.states[state.user_id].pop(state.session_id) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 9c6bd4d..7fce369 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -641,9 +641,11 @@ class GatewayWebsocket: This is the known structure of GUILD_MEMBER_LIST_UPDATE: + group_id = 'online' | 'offline' | role_id (string) + sync_item = { 'group': { - 'id': string, // 'online' | 'offline' | any role id + 'id': group_id, 'count': num } } | { @@ -678,7 +680,7 @@ class GatewayWebsocket: // separately from the online list? 'groups': [ { - 'id': string // 'online' | 'offline' | any role id + 'id': group_id 'count': num }, ... ] @@ -713,65 +715,16 @@ class GatewayWebsocket: if guild_id not in gids: return - member_ids = await self.storage.get_member_ids(guild_id) - log.debug('lazy: loading {} members', len(member_ids)) + # make shard query + lazy_guilds = self.ext.dispatcher.backends['lazy_guild'] - # the current implementation is rudimentary and only - # generates two groups: online and offline, using - # PresenceManager.guild_presences to fill list_data. + for chan_id, ranges in data['channels'].items(): + chan_id = int(chan_id) + member_list = await lazy_guilds.get_gml(chan_id) - # this also doesn't take account the channels in lazy_request. - - guild_presences = await self.presence.guild_presences(member_ids, - guild_id) - - online = [{'member': p} - for p in guild_presences - if p['status'] == 'online'] - offline = [{'member': p} - for p in guild_presences - if p['status'] == 'offline'] - - log.debug('lazy: {} presences, online={}, offline={}', - len(guild_presences), - len(online), - len(offline)) - - # construct items in the WORST WAY POSSIBLE. - items = [{ - 'group': { - 'id': 'online', - 'count': len(online), - } - }] + online + [{ - 'group': { - 'id': 'offline', - 'count': len(offline), - } - }] + offline - - await self.dispatch('GUILD_MEMBER_LIST_UPDATE', { - 'id': 'everyone', - 'guild_id': data['guild_id'], - 'groups': [ - { - 'id': 'online', - 'count': len(online), - }, - { - 'id': 'offline', - 'count': len(offline), - } - ], - - 'ops': [ - { - 'range': [0, 99], - 'op': 'SYNC', - 'items': items - } - ] - }) + await member_list.shard_query( + self.state.session_id, ranges + ) async def process_message(self, payload): """Process a single message coming in from the client.""" diff --git a/litecord/pubsub/__init__.py b/litecord/pubsub/__init__.py index 7320867..31388de 100644 --- a/litecord/pubsub/__init__.py +++ b/litecord/pubsub/__init__.py @@ -3,7 +3,8 @@ from .member import MemberDispatcher from .user import UserDispatcher from .channel import ChannelDispatcher from .friend import FriendDispatcher +from .lazy_guild import LazyGuildDispatcher __all__ = ['GuildDispatcher', 'MemberDispatcher', 'UserDispatcher', 'ChannelDispatcher', - 'FriendDispatcher'] + 'FriendDispatcher', 'LazyGuildDispatcher'] diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index ba8a451..ef7e0b2 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -1,5 +1,9 @@ +""" +Main code for Lazy Guild implementation in litecord. +""" +import pprint from collections import defaultdict -from typing import Any +from typing import Any, List, Dict from logbook import Logger @@ -8,36 +12,214 @@ from .dispatcher import Dispatcher log = Logger(__name__) -class GuildMemberList(): - def __init__(self, guild_id: int): +class GuildMemberList: + """This class stores the current member list information + for a guild (by channel). + + As channels can have different sets of roles that can + read them and so, different lists, this is more of a + "channel member list" than a guild member list. + + Attributes + ---------- + main_lg: LazyGuildDispatcher + Main instance of :class:`LazyGuildDispatcher`, + so that we're able to use things such as :class:`Storage`. + guild_id: int + The Guild ID this instance is referring to. + channel_id: int + The Channel ID this instance is referring to. + member_list: List + The actual member list information. + state: set + The set of session IDs that are subscribed to the guild. + + User IDs being used as the identifier in GuildMemberList + is a wrong assumption. It is true Discord rolled out + lazy guilds to all of the userbase, but users that are bots, + for example, can still rely on PRESENCE_UPDATEs. + """ + def __init__(self, guild_id: int, + channel_id: int, main_lg): + self.main_lg = main_lg self.guild_id = guild_id + self.channel_id = channel_id - # TODO: initialize list with actual member info - self._uninitialized = True - self.member_list = [] + # a really long chain of classes to get + # to the storage instance... + main = main_lg.main_dispatcher + self.storage = main.app.storage + self.presence = main.app.presence + self.state_man = main.app.state_manager - #: holds the state of subscribed users + self.member_list = None + self.items = None + + #: holds the state of subscribed shards + # to this channels' member list self.state = set() async def _init_check(self): """Check if the member list is initialized before messing with it.""" - if self._uninitialized: + if self.member_list is None: await self._init_member_list() + async def get_roles(self) -> List[Dict[str, Any]]: + """Get role information, but only: + - the ID + - the name + - the position + + of all HOISTED roles.""" + # TODO: write own query for this + # TODO: calculate channel overrides + roles = await self.storage.get_role_data(self.guild_id) + + return [{ + 'id': role['id'], + 'name': role['name'], + 'position': role['position'] + } for role in roles if role['hoist']] + async def _init_member_list(self): """Fill in :attr:`GuildMemberList.member_list` with information about the guilds' members.""" - pass + member_ids = await self.storage.get_member_ids(self.guild_id) - async def sub(self, user_id: int): - """Subscribe a user to the member list.""" + guild_presences = await self.presence.guild_presences( + member_ids, self.guild_id) + + guild_roles = await self.get_roles() + + # sort by position + guild_roles.sort(key=lambda role: role['position']) + roleids = [r['id'] for r in guild_roles] + + # groups are: + # - roles that are hoisted + # - "online" and "offline", with "online" + # being for people without any roles. + + groups = roleids + ['online', 'offline'] + + log.debug('{} presences, {} groups', + len(guild_presences), len(groups)) + + group_data = {group: [] for group in groups} + + print('group data', group_data) + + def _try_hier(role_id: str, roleids: list): + """Try to fetch a role's position in the hierarchy""" + try: + return roleids.index(role_id) + except ValueError: + # the given role isn't on a group + # so it doesn't count for our purposes. + return 0 + + for presence in guild_presences: + # simple group (online or offline) + # we'll decide on the best group for the presence later on + simple_group = ('offline' + if presence['status'] == 'offline' + else 'online') + + # get the best possible role + roles = sorted( + presence['roles'], + key=lambda role_id: _try_hier(role_id, roleids) + ) + + try: + best_role_id = roles[0] + except IndexError: + # no hoisted roles exist in the guild, assign + # the @everyone role + best_role_id = str(self.guild_id) + + print('best role', best_role_id, str(self.guild_id)) + print('simple group assign', simple_group) + + # if the best role is literally the @everyone role, + # this user has no hoisted roles + if best_role_id == str(self.guild_id): + # this user has no roles, put it on online/offline + group_data[simple_group].append(presence) + continue + + # this user has a best_role that isn't the + # @everyone role, so we'll put them in the respective group + group_data[best_role_id].append(presence) + + # go through each group and sort the resulting members by display name + + members = await self.storage.get_member_data(self.guild_id) + member_nicks = {member['user']['id']: member.get('nick') + for member in members} + + # now we'll sort each group by their display name + # (can be their current nickname OR their username + # if no nickname is set) + print('pre-sorted group data') + pprint.pprint(group_data) + + for _, group_list in group_data.items(): + def display_name(presence: dict) -> str: + uid = presence['user']['id'] + + uname = presence['user']['username'] + nick = member_nicks[uid] + + return nick or uname + + group_list.sort(key=display_name) + + pprint.pprint(group_data) + + self.member_list = { + 'groups': groups, + 'data': group_data + } + + def get_items(self) -> list: + """Generate the main items list,""" + if self.member_list is None: + return [] + + if self.items: + return self.items + + groups = self.member_list['groups'] + + res = [] + for group in groups: + members = self.member_list['data'][group] + + res.append({ + 'group': { + 'id': group, + 'count': len(members), + } + }) + + for member in members: + res.append({ + 'member': member + }) + + self.items = res + return res + + async def sub(self, session_id: str): + """Subscribe a shard to the member list.""" await self._init_check() - self.state.add(user_id) + self.state.add(session_id) - async def unsub(self, user_id: int): - """Unsubscribe a user from the member list""" - self.state.discard(user_id) + async def unsub(self, session_id: str): + """Unsubscribe a shard from the member list""" + self.state.discard(session_id) # once we reach 0 subscribers, # we drop the current member list we have (for memory) @@ -45,8 +227,70 @@ class GuildMemberList(): # uninitialized) for a future subscriber. if not self.state: - self.member_list = [] - self._uninitialized = True + self.member_list = None + + async def shard_query(self, session_id: str, ranges: list): + """Send a GUILD_MEMBER_LIST_UPDATE event + for a shard that is querying about the member list. + + Paramteters + ----------- + session_id: str + The Session ID querying information. + channel_id: int + The Channel ID that we want information on. + ranges: List[List[int]] + ranges of the list that we want. + """ + + await self._init_check() + + # make sure this is a sane state + state = self.state_man.fetch_raw(session_id) + if not state: + await self.unsub(session_id) + return + + # since this is a sane state AND + # trying to query, we automatically + # subscribe the state to this list + await self.sub(session_id) + + reply = { + 'guild_id': str(self.guild_id), + + # TODO: everyone for channels without overrides + # channel_id for channels WITH overrides. + + 'id': 'everyone', + # 'id': str(self.channel_id), + + 'groups': [ + { + 'count': len(self.member_list['data'][group]), + 'id': group + } for group in self.member_list['groups'] + ], + + 'ops': [], + } + + for start, end in ranges: + itemcount = end - start + + # ignore incorrect ranges + if itemcount < 0: + continue + + items = self.get_items() + + reply['ops'].append({ + 'op': 'SYNC', + 'range': [start, end], + 'items': items[start:end], + }) + + await state.ws.dispatch('GUILD_MEMBER_LIST_UPDATE', reply) async def dispatch(self, event: str, data: Any): """The dispatch() method here, instead of being @@ -61,7 +305,7 @@ class GuildMemberList(): calls to :meth:`GuildMemberList.dispatch` """ - if self._uninitialized: + if self.member_list is None: # if the list is currently uninitialized, # no subscribers actually happened, so # we can safely drop the incoming event. @@ -70,25 +314,36 @@ class GuildMemberList(): class LazyGuildDispatcher(Dispatcher): """Main class holding the member lists for lazy guilds.""" + # channel ids KEY_TYPE = int - VAL_TYPE = int + + # the session ids subscribing to channels + VAL_TYPE = str def __init__(self, main): super().__init__(main) - self.state = defaultdict(GuildMemberList) - async def sub(self, guild_id, user_id): - await self.state[guild_id].sub(user_id) + self.storage = main.app.storage - async def unsub(self, guild_id, user_id): - await self.state[guild_id].unsub(user_id) + # {chan_id: gml, ...} + self.state = {} - async def dispatch(self, guild_id: int, event: str, data): - """Dispatch an event to the member list. + async def get_gml(self, channel_id: int): + try: + return self.state[channel_id] + except KeyError: + guild_id = await self.storage.guild_from_channel( + channel_id + ) - GuildMemberList will make sure of converting it to - GUILD_MEMBER_LIST_UPDATE events. - """ - member_list = self.state[guild_id] - await member_list.dispatch(event, data) + gml = GuildMemberList(guild_id, channel_id, self) + self.state[channel_id] = gml + return gml + async def sub(self, chan_id, session_id): + gml = await self.get_gml(chan_id) + await gml.sub(session_id) + + async def unsub(self, chan_id, session_id): + gml = await self.get_gml(chan_id) + await gml.unsub(session_id) diff --git a/litecord/storage.py b/litecord/storage.py index a7e37eb..7337e54 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -440,6 +440,7 @@ class Storage: permissions, managed, mentionable FROM roles WHERE guild_id = $1 + ORDER BY position ASC """, guild_id) return list(map(dict, roledata)) @@ -966,7 +967,6 @@ class Storage: """, user_id) for row in settings: - print(dict(row)) gid = int(row['guild_id']) drow = dict(row) From 26058367908952e591b5077b53f5eccc5bb65f0b Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 24 Oct 2018 19:09:05 -0300 Subject: [PATCH 03/69] dispatcher: add dispatch_filter - presence: (basic) handle member lists when presence update. Note that the respective GUILD_UPDATE_MEMBER_LIST doesn't happen yet. we'll need roles beforehand. --- litecord/dispatcher.py | 9 ++++++ litecord/gateway/websocket.py | 20 ------------ litecord/presence.py | 57 ++++++++++++++++++++++++++++------- litecord/pubsub/dispatcher.py | 8 +++++ litecord/pubsub/guild.py | 21 +++++++++++-- litecord/pubsub/lazy_guild.py | 24 +++++++++++++++ 6 files changed, 105 insertions(+), 34 deletions(-) diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index e7e392b..10d2f7a 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -106,6 +106,15 @@ class EventDispatcher: for key in keys: await self.dispatch(backend_str, key, *args, **kwargs) + async def dispatch_filter(self, backend_str: str, + key: Any, func, *args): + """Dispatch to a backend that only accepts + (event, data) arguments with an optional filter + function.""" + backend = self.backends[backend_str] + key = backend.KEY_TYPE(key) + return await backend.dispatch_filter(key, func, *args) + async def reset(self, backend_str: str, key: Any): """Reset the bucket in the given backend.""" backend = self.backends[backend_str] diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 7fce369..6ad0b22 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -685,26 +685,6 @@ class GatewayWebsocket: }, ... ] } - - # Implementation defails. - - Lazy guilds are complicated to deal with in the backend level - as there are a lot of computation to be done for each request. - - The current implementation is rudimentary and does not account - for any roles inside the guild. - - A correct implementation would take account of roles and make - the correct groups on list_data: - - For each channel in lazy_request['channels']: - - get all roles that have Read Messages on the channel: - - Also fetch their member counts, as it'll be important - - with the role list, order them like you normally would - (by their role priority) - - based on the channel's range's min and max and the ordered - role list, you can get the roles wanted for your list_data reply. - - make new groups ONLY when the role is hoisted. """ data = payload['d'] diff --git a/litecord/presence.py b/litecord/presence.py index 3666f3e..8e107c9 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -1,8 +1,11 @@ from typing import List, Dict, Any from random import choice +from logbook import Logger from quart import current_app as app +log = Logger(__name__) + def status_cmp(status: str, other_status: str) -> bool: """Compare if `status` is better than the `other_status` @@ -100,20 +103,50 @@ class PresenceManager: game = state['game'] - await self.dispatcher.dispatch_guild( - guild_id, 'PRESENCE_UPDATE', { - 'user': member['user'], - 'roles': member['roles'], - 'guild_id': guild_id, + lazy_guild_store = self.dispatcher.backends['lazy_guild'] + lists = lazy_guild_store.get_gml_guild(guild_id) - 'status': state['status'], + # shards that are in lazy guilds with 'everyone' + # enabled + in_lazy = [] - # rich presence stuff - 'game': game, - 'activities': [game] if game else [] - } + for member_list in lists: + session_ids = await member_list.pres_update( + int(member['user']['id']), + member['roles'], + state['status'], + game + ) + + log.debug('Lazy Dispatch to {}', + len(session_ids)) + + if member_list.channel_id == 'everyone': + in_lazy.extend(session_ids) + + pres_update_payload = { + 'user': member['user'], + 'roles': member['roles'], + 'guild_id': str(guild_id), + + 'status': state['status'], + + # rich presence stuff + 'game': game, + 'activities': [game] if game else [] + } + + # everyone not in lazy guild mode + # gets a PRESENCE_UPDATE + await self.dispatcher.dispatch_filter( + 'guild', guild_id, + lambda session_id: session_id not in in_lazy, + + 'PRESENCE_UPDATE', pres_update_payload ) + return in_lazy + async def dispatch_pres(self, user_id: int, state: dict): """Dispatch a new presence to all guilds the user is in. @@ -122,10 +155,12 @@ class PresenceManager: if state['status'] == 'invisible': state['status'] = 'offline' + # TODO: shard-aware guild_ids = await self.storage.get_user_guilds(user_id) for guild_id in guild_ids: - await self.dispatch_guild_pres(guild_id, user_id, state) + await self.dispatch_guild_pres( + guild_id, user_id, state) # dispatch to all friends that are subscribed to them user = await self.storage.get_user(user_id) diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index f65104d..c9da2e7 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -37,6 +37,14 @@ class Dispatcher: """Unsubscribe an elemtnt from the channel/key.""" raise NotImplementedError + async def dispatch_filter(self, _key, _func, *_args): + """Selectively dispatch to the list of subscribed users. + + The selection logic is completly arbitraty and up to the + Pub/Sub backend. + """ + raise NotImplementedError + async def dispatch(self, _key, *_args): """Dispatch an event to the given channel/key.""" raise NotImplementedError diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 6896d93..6613fcb 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -55,9 +55,10 @@ class GuildDispatcher(DispatcherWithState): # same thing happening from sub() happens on unsub() await self._chan_action('unsub', guild_id, user_id) - async def dispatch(self, guild_id: int, - event: str, data: Any): - """Dispatch an event to all subscribers of the guild.""" + async def dispatch_filter(self, guild_id: int, func, + event: str, data: Any): + """Selectively dispatch to session ids that have + func(session_id) true.""" user_ids = self.state[guild_id] dispatched = 0 @@ -74,8 +75,22 @@ class GuildDispatcher(DispatcherWithState): await self.unsub(guild_id, user_id) continue + # filter the ones that matter + states = list(filter( + lambda state: func(state.session_id), states + )) + dispatched += await self._dispatch_states( states, event, data) log.info('Dispatched {} {!r} to {} states', guild_id, event, dispatched) + + async def dispatch(self, guild_id: int, + event: str, data: Any): + """Dispatch an event to all subscribers of the guild.""" + await self.dispatch_filter( + guild_id, + lambda sess_id: True, + event, data, + ) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index ef7e0b2..0bcf1f5 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -256,6 +256,9 @@ class GuildMemberList: # subscribe the state to this list await self.sub(session_id) + # TODO: subscribe shard to 'everyone' + # and forward the query to that list + reply = { 'guild_id': str(self.guild_id), @@ -290,8 +293,14 @@ class GuildMemberList: 'items': items[start:end], }) + # the first GUILD_MEMBER_LIST_UPDATE for a shard + # is dispatched here. await state.ws.dispatch('GUILD_MEMBER_LIST_UPDATE', reply) + async def pres_update(self, user_id: int, roles: List[str], + status: str, game: dict) -> List[str]: + return list(self.state) + async def dispatch(self, event: str, data: Any): """The dispatch() method here, instead of being about dispatching a single event to the subscribed @@ -328,7 +337,14 @@ class LazyGuildDispatcher(Dispatcher): # {chan_id: gml, ...} self.state = {} + #: store which guilds have their + # respective GMLs + # {guild_id: [chan_id, ...], ...} + self.guild_map = defaultdict(list) + async def get_gml(self, channel_id: int): + """Get a guild list for a channel ID, + generating it if it doesn't exist.""" try: return self.state[channel_id] except KeyError: @@ -338,8 +354,16 @@ class LazyGuildDispatcher(Dispatcher): gml = GuildMemberList(guild_id, channel_id, self) self.state[channel_id] = gml + self.guild_map[guild_id].append(channel_id) return gml + def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]: + """Get all member lists for a given guild.""" + return list(map( + self.state.get, + self.guild_map[guild_id] + )) + async def sub(self, chan_id, session_id): gml = await self.get_gml(chan_id) await gml.sub(session_id) From 75a8e77a21b6c8ba2e9a3716c138859028116b13 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 24 Oct 2018 19:25:44 -0300 Subject: [PATCH 04/69] blueprints.dms: make sure no double dms happen by first checking existance, then inserting if none was found. --- litecord/blueprints/dms.py | 60 +++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index 7d625a2..2a83c54 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -38,41 +38,47 @@ async def try_dm_state(user_id: int, dm_id: int): """, user_id, dm_id) +async def jsonify_dm(dm_id: int, user_id: int): + dm_chan = await app.storage.get_dm(dm_id, user_id) + return jsonify(dm_chan) + + async def create_dm(user_id, recipient_id): """Create a new dm with a user, or get the existing DM id if it already exists.""" + + dm_id = await app.db.fetchval(""" + SELECT id + FROM dm_channels + WHERE (party1_id = $1 OR party2_id = $1) AND + (party1_id = $2 OR party2_id = $2) + """, user_id, recipient_id) + + if dm_id: + return await jsonify_dm(dm_id, user_id) + + # if no dm was found, create a new one + dm_id = get_snowflake() + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, dm_id, ChannelType.DM.value) - try: - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, dm_id, ChannelType.DM.value) + await app.db.execute(""" + INSERT INTO dm_channels (id, party1_id, party2_id) + VALUES ($1, $2, $3) + """, dm_id, user_id, recipient_id) - await app.db.execute(""" - INSERT INTO dm_channels (id, party1_id, party2_id) - VALUES ($1, $2, $3) - """, dm_id, user_id, recipient_id) + # the dm state is something we use + # to give the currently "open dms" + # on the client. - # the dm state is something we use - # to give the currently "open dms" - # on the client. + # we don't open a dm for the peer/recipient + # until the user sends a message. + await try_dm_state(user_id, dm_id) - # we don't open a dm for the peer/recipient - # until the user sends a message. - await try_dm_state(user_id, dm_id) - - except UniqueViolationError: - # the dm already exists - dm_id = await app.db.fetchval(""" - SELECT id - FROM dm_channels - WHERE (party1_id = $1 OR party2_id = $1) AND - (party2_id = $2 OR party2_id = $2) - """, user_id, recipient_id) - - dm = await app.storage.get_dm(dm_id, user_id) - return jsonify(dm) + return await jsonify_dm(dm_id, user_id) @bp.route('/@me/channels', methods=['POST']) From aaa11be2587790d2d876a2c493803aef522a1d36 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 02:34:17 -0300 Subject: [PATCH 05/69] blueprints.guilds: add auto-role and auto-channel creation also simplify a lot of repeated code on the blueprint. - litecord: add permissions module for future role code - schemas: add channel_type, guild_name, channel_name types - schemas: add GUILD_CREATE schema --- litecord/blueprints/guilds.py | 248 ++++++++++++++++++++++------------ litecord/permissions.py | 54 ++++++++ litecord/schemas.py | 58 +++++++- 3 files changed, 269 insertions(+), 91 deletions(-) create mode 100644 litecord/permissions.py diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index d43b033..8c4bbb6 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -9,6 +9,7 @@ from .channels import channel_ack from .checks import guild_check bp = Blueprint('guilds', __name__) +DEFAULT_EVERYONE_PERMS = 104324161 async def guild_owner_check(user_id: int, guild_id: int): @@ -48,8 +49,116 @@ async def create_guild_settings(guild_id: int, user_id: int): """, m_notifs, user_id, guild_id) +async def add_member(guild_id: int, user_id: int): + """Add a user to a guild.""" + await app.db.execute(""" + INSERT INTO members (user_id, guild_id) + VALUES ($1, $2) + """, user_id, guild_id) + + await create_guild_settings(guild_id, user_id) + + +async def guild_create_roles_prep(guild_id: int, roles: list): + """Create roles in preparation in guild create.""" + # by reaching this point in the code that means + # roles is not nullable, which means + # roles has at least one element, so we can access safely. + + # the first member in the roles array + # are patches to the @everyone role + everyone_patches = roles[0] + for field in everyone_patches: + await app.db.execute(f""" + UPDATE roles + SET {field}={everyone_patches[field]} + WHERE roles.id = $1 + """, guild_id) + + default_perms = (everyone_patches.get('permissions') + or DEFAULT_EVERYONE_PERMS) + + # from the 2nd and forward, + # should be treated as new roles + for role in roles[1:]: + new_role_id = get_snowflake() + + await app.db.execute( + """ + INSERT INTO roles (id, guild_id, name, color, + hoist, position, permissions, managed, mentionable) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + new_role_id, + guild_id, + role['name'], + role.get('color', 0), + role.get('hoist', False), + role.get('permissions', default_perms), + False, + role.get('mentionable', False) + ) + + +async def _specific_chan_create(channel_id, ctype, **kwargs): + if ctype == ChannelType.GUILD_TEXT: + await app.db.execute(""" + INSERT INTO guild_text_channels (id, topic) + VALUES ($1) + """, channel_id, kwargs.get('topic', '')) + elif ctype == ChannelType.GUILD_VOICE: + await app.db.execute( + """ + INSERT INTO guild_voice_channels (id, bitrate, user_limit) + VALUES ($1, $2, $3) + """, + channel_id, + kwargs.get('bitrate', 64), + kwargs.get('user_limit', 0) + ) + + +async def create_guild_channel(guild_id: int, channel_id: int, + ctype: ChannelType, **kwargs): + """Create a channel in a guild.""" + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, channel_id, ctype.value) + + # calc new pos + max_pos = await app.db.fetchval(""" + SELECT MAX(position) + FROM guild_channels + WHERE guild_id = $1 + """, guild_id) + + # all channels go to guild_channels + await app.db.execute(""" + INSERT INTO guild_channels (id, guild_id, name, position) + VALUES ($1, $2, $3, $4) + """, channel_id, guild_id, kwargs['name'], max_pos + 1) + + # the rest of sql magic is dependant on the channel + # we're creating (a text or voice or category), + # so we use this function. + await _specific_chan_create(channel_id, ctype, **kwargs) + + +async def guild_create_channels_prep(guild_id: int, channels: list): + """Create channels pre-guild create""" + for channel_raw in channels: + channel_id = get_snowflake() + ctype = ChannelType(channel_raw['type']) + + await create_guild_channel(guild_id, channel_id, ctype) + + @bp.route('', methods=['POST']) async def create_guild(): + """Create a new guild, assigning + the user creating it as the owner and + making them join.""" user_id = await token_check() j = await request.get_json() @@ -66,36 +175,27 @@ async def create_guild(): j.get('default_message_notifications', 0), j.get('explicit_content_filter', 0)) - await app.db.execute(""" - INSERT INTO members (user_id, guild_id) - VALUES ($1, $2) - """, user_id, guild_id) - - await create_guild_settings(guild_id, user_id) + await add_member(guild_id, user_id) + # create the default @everyone role (everyone has it by default, + # so we don't insert that in the table) await app.db.execute(""" INSERT INTO roles (id, guild_id, name, position, permissions) VALUES ($1, $2, $3, $4, $5) - """, guild_id, guild_id, '@everyone', 0, 104324161) + """, guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS) + # create a single #general channel. general_id = get_snowflake() - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, general_id, ChannelType.GUILD_TEXT.value) + await create_guild_channel( + guild_id, general_id, ChannelType.GUILD_TEXT, + name='general') - await app.db.execute(""" - INSERT INTO guild_channels (id, guild_id, name, position) - VALUES ($1, $2, $3, $4) - """, general_id, guild_id, 'general', 0) + if j.get('roles'): + await guild_create_roles_prep(guild_id, j['roles']) - await app.db.execute(""" - INSERT INTO guild_text_channels (id) - VALUES ($1) - """, general_id) - - # TODO: j['roles'] and j['channels'] + if j.get('channels'): + await guild_create_channels_prep(guild_id, j['channels']) guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) @@ -106,12 +206,13 @@ async def create_guild(): @bp.route('/', methods=['GET']) async def get_guild(guild_id): + """Get a single guilds' information.""" user_id = await token_check() + await guild_check(user_id, guild_id) - gj = await app.storage.get_guild(guild_id, user_id) - gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) - - return jsonify({**gj, **gj_extra}) + return jsonify( + await app.storage.get_guild_full(guild_id, user_id, 250) + ) @bp.route('/', methods=['UPDATE']) @@ -139,8 +240,6 @@ async def update_guild(guild_id): """, j['name'], guild_id) if 'region' in j: - # TODO: check region value - await app.db.execute(""" UPDATE guilds SET region = $1 @@ -167,15 +266,14 @@ async def update_guild(guild_id): WHERE guild_id = $2 """, j[field], guild_id) - # return guild object - gj = await app.storage.get_guild(guild_id, user_id) - gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) + guild = await app.storage.get_guild_full( + guild_id, user_id + ) - gj_total = {**gj, **gj_extra} + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_UPDATE', guild) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_UPDATE', gj_total) - - return jsonify({**gj, **gj_extra}) + return jsonify(guild) @bp.route('/', methods=['DELETE']) @@ -185,7 +283,7 @@ async def delete_guild(guild_id): await guild_owner_check(user_id, guild_id) await app.db.execute(""" - DELETE FROM guild + DELETE FROM guilds WHERE guilds.id = $1 """, guild_id) @@ -219,42 +317,19 @@ async def create_channel(guild_id): # TODO: check permissions for MANAGE_CHANNELS await guild_check(user_id, guild_id) - new_channel_id = get_snowflake() channel_type = j.get('type', ChannelType.GUILD_TEXT) - channel_type = ChannelType(channel_type) if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE): raise BadRequest('Invalid channel type') - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, new_channel_id, channel_type.value) - - max_pos = await app.db.fetchval(""" - SELECT MAX(position) - FROM guild_channels - WHERE guild_id = $1 - """, guild_id) - - if channel_type == ChannelType.GUILD_TEXT: - await app.db.execute(""" - INSERT INTO guild_channels (id, guild_id, name, position) - VALUES ($1, $2, $3, $4) - """, new_channel_id, guild_id, j['name'], max_pos + 1) - - await app.db.execute(""" - INSERT INTO guild_text_channels (id) - VALUES ($1) - """, new_channel_id) - - elif channel_type == ChannelType.GUILD_VOICE: - raise NotImplementedError() + new_channel_id = get_snowflake() + await create_guild_channel(guild_id, new_channel_id, channel_type,) chan = await app.storage.get_channel(new_channel_id) - await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_CREATE', chan) + await app.dispatcher.dispatch_guild( + guild_id, 'CHANNEL_CREATE', chan) return jsonify(chan) @@ -271,15 +346,16 @@ async def modify_channel_pos(guild_id): @bp.route('//members/', methods=['GET']) async def get_guild_member(guild_id, member_id): + """Get a member's information in a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) - member = await app.storage.get_single_member(guild_id, member_id) return jsonify(member) @bp.route('//members', methods=['GET']) async def get_members(guild_id): + """Get members inside a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) @@ -304,6 +380,7 @@ async def get_members(guild_id): @bp.route('//members/', methods=['PATCH']) async def modify_guild_member(guild_id, member_id): + """Modify a members' information in a guild.""" j = await request.get_json() if 'nick' in j: @@ -350,6 +427,7 @@ async def modify_guild_member(guild_id, member_id): @bp.route('//members/@me/nick', methods=['PATCH']) async def update_nickname(guild_id): + """Update a member's nickname in a guild.""" user_id = await token_check() await guild_check(user_id, guild_id) @@ -371,28 +449,36 @@ async def update_nickname(guild_id): return j['nick'] -@bp.route('//members/', methods=['DELETE']) -async def kick_member(guild_id, member_id): - user_id = await token_check() - - # TODO: check KICK_MEMBERS permission - await guild_owner_check(user_id, guild_id) +async def remove_member(guild_id: int, member_id: int): + """Do common tasks related to deleting a member from the guild, + such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" await app.db.execute(""" DELETE FROM members WHERE guild_id = $1 AND user_id = $2 """, guild_id, member_id) - await app.dispatcher.dispatch_user(user_id, 'GUILD_DELETE', { + await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', { 'guild_id': guild_id, 'unavailable': False, }) + await app.dispatcher.unsub('guild', guild_id, member_id) + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { 'guild': guild_id, 'user': await app.storage.get_user(member_id), }) + +@bp.route('//members/', methods=['DELETE']) +async def kick_member(guild_id, member_id): + """Remove a member from a guild.""" + user_id = await token_check() + + # TODO: check KICK_MEMBERS permission + await guild_owner_check(user_id, guild_id) + await remove_member(guild_id, member_id) return '', 204 @@ -434,22 +520,7 @@ async def create_ban(guild_id, member_id): VALUES ($1, $2, $3) """, guild_id, member_id, j.get('reason', '')) - await app.db.execute(""" - DELETE FROM members - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, user_id) - - await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', { - 'guild_id': guild_id, - 'unavailable': False, - }) - - await app.dispatcher.unsub('guild', guild_id, member_id) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { - 'guild': guild_id, - 'user': await app.storage.get_user(member_id), - }) + await remove_member(guild_id, member_id) await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {**{ 'guild': guild_id, @@ -460,6 +531,10 @@ async def create_ban(guild_id, member_id): @bp.route('//messages/search') async def search_messages(guild_id): + """Search messages in a guild. + + This is an undocumented route. + """ user_id = await token_check() await guild_check(user_id, guild_id) @@ -474,6 +549,7 @@ async def search_messages(guild_id): @bp.route('//ack', methods=['POST']) async def ack_guild(guild_id): + """ACKnowledge all messages in the guild.""" user_id = await token_check() await guild_check(user_id, guild_id) diff --git a/litecord/permissions.py b/litecord/permissions.py new file mode 100644 index 0000000..1f7a020 --- /dev/null +++ b/litecord/permissions.py @@ -0,0 +1,54 @@ +import ctypes + +# so we don't keep repeating the same +# type for all the fields +_i = ctypes.c_uint8 + +class _RawPermsBits(ctypes.LittleEndianStructure): + """raw bitfield for discord's permission number.""" + _fields_ = [ + ('create_invites', _i, 1), + ('kick_members', _i, 1), + ('ban_members', _i, 1), + ('administrator', _i, 1), + ('manage_channels', _i, 1), + ('manage_guild', _i, 1), + ('add_reactions', _i, 1), + ('view_audit_log', _i, 1), + ('priority_speaker', _i, 1), + ('_unused1', _i, 1), + ('read_messages', _i, 1), + ('send_messages', _i, 1), + ('send_tts', _i, 1), + ('manage_messages', _i, 1), + ('embed_links', _i, 1), + ('attach_files', _i, 1), + ('read_history', _i, 1), + ('mention_everyone', _i, 1), + ('external_emojis', _i, 1), + ('_unused2', _i, 1), + ('connect', _i, 1), + ('speak', _i, 1), + ('mute_members', _i, 1), + ('deafen_members', _i, 1), + ('move_members', _i, 1), + ('use_voice_activation', _i, 1), + ('change_nickname', _i, 1), + ('manage_nicknames', _i, 1), + ('manage_roles', _i, 1), + ('manage_webhooks', _i, 1), + ('manage_emojis', _i, 1), + ] + + +class Permissions(ctypes.Union): + _fields_ = [ + ('bits', _RawPermsBits), + ('binary', ctypes.c_uint64), + ] + + def __init__(self, val: int): + self.binary = val + + def numby(self): + return self.binary diff --git a/litecord/schemas.py b/litecord/schemas.py index 1c3eda3..98f9134 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -1,11 +1,13 @@ import re +from typing import Union, Dict, List, Any from cerberus import Validator from logbook import Logger from .errors import BadRequest +from .permissions import Permissions from .enums import ActivityType, StatusType, ExplicitFilter, \ - RelationshipType, MessageNotifications + RelationshipType, MessageNotifications, ChannelType log = Logger(__name__) @@ -61,6 +63,9 @@ class LitecordValidator(Validator): def _validate_type_activity_type(self, value: int) -> bool: return value in ActivityType.values() + def _validate_type_channel_type(self, value: int) -> bool: + return value in ChannelType.values() + def _validate_type_status_external(self, value: str) -> bool: statuses = StatusType.values() @@ -94,8 +99,19 @@ class LitecordValidator(Validator): return val in MessageNotifications.values() + def _validate_type_guild_name(self, value: str) -> bool: + return 2 <= len(value) <= 100 -def validate(reqjson, schema, raise_err: bool = True): + def _validate_type_channel_name(self, value: str) -> bool: + # for now, we'll use the same validation for guild_name + return self._validate_type_guild_name(value) + + +def validate(reqjson: Union[Dict, List], schema: Dict, + raise_err: bool = True) -> Union[Dict, List]: + """Validate a given document (user-input) and give + the correct document as a result. + """ validator = LitecordValidator(schema) if not validator.validate(reqjson): @@ -146,12 +162,44 @@ USER_UPDATE = { } +PARTIAL_ROLE_GUILD_CREATE = { + 'name': {'type': 'role_name'}, + 'color': {'type': 'number', 'default': 0}, + 'hoist': {'type': 'boolean', 'default': False}, + + # NOTE: no position on partial role (on guild create) + + 'permissions': {'coerce': Permissions, 'required': False}, + 'mentionable': {'type': 'boolean', 'default': False}, +} + +PARTIAL_CHANNEL_GUILD_CREATE = { + 'name': {'type': 'channel_name'}, + 'type': {'type': 'channel_type'} +} + +GUILD_CREATE = { + 'name': {'type': 'guild_name'}, + 'region': {'type': 'voice_region'}, + 'icon': {'type': 'icon', 'required': False, 'nullable': True}, + + 'verification_level': { + 'type': 'verification_level', 'default': 0}, + 'default_message_notifications': { + 'type': 'msg_notifications', 'default': 0}, + 'explicit_content_filter': { + 'type': 'explicit', 'default': 0}, + + 'roles': { + 'type': 'list', 'required': False, + 'schema': PARTIAL_ROLE_GUILD_CREATE}, + 'channels': { + 'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE}, +} GUILD_UPDATE = { 'name': { - 'type': 'string', - 'minlength': 2, - 'maxlength': 100, + 'type': 'guild_name', 'required': False }, 'region': {'type': 'voice_region', 'required': False}, From e8ebfe6eeb191ee4440d20348c9e598c32756010 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 03:39:11 -0300 Subject: [PATCH 06/69] schemas: fix GW_ACTIVITY schema - blueprints.guilds: use GUILD_CREATE schema --- litecord/blueprints/guilds.py | 4 +- litecord/schemas.py | 121 +++++++++++++++++++--------------- 2 files changed, 70 insertions(+), 55 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 8c4bbb6..295167d 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -4,7 +4,7 @@ from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType from ..errors import Forbidden, GuildNotFound, BadRequest -from ..schemas import validate, GUILD_UPDATE +from ..schemas import validate, GUILD_CREATE, GUILD_UPDATE from .channels import channel_ack from .checks import guild_check @@ -160,7 +160,7 @@ async def create_guild(): the user creating it as the owner and making them join.""" user_id = await token_check() - j = await request.get_json() + j = validate(await request.get_json(), GUILD_CREATE) guild_id = get_snowflake() diff --git a/litecord/schemas.py b/litecord/schemas.py index 98f9134..342c87d 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -32,7 +32,7 @@ class LitecordValidator(Validator): return bool(USERNAME_REGEX.match(value)) def _validate_type_email(self, value: str) -> bool: - """Validate against the username regex.""" + """Validate against the email regex.""" return bool(EMAIL_REGEX.match(value)) def _validate_type_b64_icon(self, value: str) -> bool: @@ -114,7 +114,13 @@ def validate(reqjson: Union[Dict, List], schema: Dict, """ validator = LitecordValidator(schema) - if not validator.validate(reqjson): + try: + valid = validator.validate(reqjson) + except Exception: + log.exception('Error while validating') + raise Exception(f'Error while validating: {reqjson}') + + if not valid: errs = validator.errors log.warning('Error validating doc {!r}: {!r}', reqjson, errs) @@ -163,19 +169,25 @@ USER_UPDATE = { } PARTIAL_ROLE_GUILD_CREATE = { - 'name': {'type': 'role_name'}, - 'color': {'type': 'number', 'default': 0}, - 'hoist': {'type': 'boolean', 'default': False}, + 'type': 'dict', + 'schema': { + 'name': {'type': 'role_name'}, + 'color': {'type': 'number', 'default': 0}, + 'hoist': {'type': 'boolean', 'default': False}, - # NOTE: no position on partial role (on guild create) + # NOTE: no position on partial role (on guild create) - 'permissions': {'coerce': Permissions, 'required': False}, - 'mentionable': {'type': 'boolean', 'default': False}, + 'permissions': {'coerce': Permissions, 'required': False}, + 'mentionable': {'type': 'boolean', 'default': False}, + } } PARTIAL_CHANNEL_GUILD_CREATE = { - 'name': {'type': 'channel_name'}, - 'type': {'type': 'channel_type'} + 'type': 'dict', + 'schema': { + 'name': {'type': 'channel_name'}, + 'type': {'type': 'channel_type'}, + } } GUILD_CREATE = { @@ -244,57 +256,60 @@ MESSAGE_CREATE = { GW_ACTIVITY = { - 'name': {'type': 'string', 'required': True}, - 'type': {'type': 'activity_type', 'required': True}, + 'type': 'dict', + 'schema': { + 'name': {'type': 'string', 'required': True}, + 'type': {'type': 'activity_type', 'required': True}, - 'url': {'type': 'string', 'required': False, 'nullable': True}, + 'url': {'type': 'string', 'required': False, 'nullable': True}, - 'timestamps': { - 'type': 'dict', - 'required': False, - 'schema': { - 'start': {'type': 'number', 'required': True}, - 'end': {'type': 'number', 'required': True}, + 'timestamps': { + 'type': 'dict', + 'required': False, + 'schema': { + 'start': {'type': 'number', 'required': True}, + 'end': {'type': 'number', 'required': False}, + }, }, - }, - 'application_id': {'type': 'snowflake', 'required': False, - 'nullable': False}, - 'details': {'type': 'string', 'required': False, 'nullable': True}, - 'state': {'type': 'string', 'required': False, 'nullable': True}, + 'application_id': {'type': 'snowflake', 'required': False, + 'nullable': False}, + 'details': {'type': 'string', 'required': False, 'nullable': True}, + 'state': {'type': 'string', 'required': False, 'nullable': True}, - 'party': { - 'type': 'dict', - 'required': False, - 'schema': { - 'id': {'type': 'snowflake', 'required': False}, - 'size': {'type': 'list', 'required': False}, - } - }, + 'party': { + 'type': 'dict', + 'required': False, + 'schema': { + 'id': {'type': 'snowflake', 'required': False}, + 'size': {'type': 'list', 'required': False}, + } + }, - 'assets': { - 'type': 'dict', - 'required': False, - 'schema': { - 'large_image': {'type': 'snowflake', 'required': False}, - 'large_text': {'type': 'string', 'required': False}, - 'small_image': {'type': 'snowflake', 'required': False}, - 'small_text': {'type': 'string', 'required': False}, - } - }, + 'assets': { + 'type': 'dict', + 'required': False, + 'schema': { + 'large_image': {'type': 'snowflake', 'required': False}, + 'large_text': {'type': 'string', 'required': False}, + 'small_image': {'type': 'snowflake', 'required': False}, + 'small_text': {'type': 'string', 'required': False}, + } + }, - 'secrets': { - 'type': 'dict', - 'required': False, - 'schema': { - 'join': {'type': 'string', 'required': False}, - 'spectate': {'type': 'string', 'required': False}, - 'match': {'type': 'string', 'required': False}, - } - }, + 'secrets': { + 'type': 'dict', + 'required': False, + 'schema': { + 'join': {'type': 'string', 'required': False}, + 'spectate': {'type': 'string', 'required': False}, + 'match': {'type': 'string', 'required': False}, + } + }, - 'instance': {'type': 'boolean', 'required': False}, - 'flags': {'type': 'number', 'required': False}, + 'instance': {'type': 'boolean', 'required': False}, + 'flags': {'type': 'number', 'required': False}, + } } GW_STATUS_UPDATE = { From 856839d9e7c6cb6ecef7e8824085543f706bbf58 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 03:50:29 -0300 Subject: [PATCH 07/69] presence: don't send PRESENCE_UPDATEs about the same user --- litecord/presence.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/litecord/presence.py b/litecord/presence.py index 8e107c9..48ad80f 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -136,11 +136,23 @@ class PresenceManager: 'activities': [game] if game else [] } + def _sane_session(session_id): + state = self.state_manager.fetch_raw(session_id) + uid = int(member['user']['id']) + + if not state: + return False + + # we don't want to send a presence update + # to the same user + return (state.user_id != uid and + session_id not in in_lazy) + # everyone not in lazy guild mode # gets a PRESENCE_UPDATE await self.dispatcher.dispatch_filter( 'guild', guild_id, - lambda session_id: session_id not in in_lazy, + _sane_session, 'PRESENCE_UPDATE', pres_update_payload ) From d2562d3262d3ef751e649e460901db01f04cb2ad Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 04:15:51 -0300 Subject: [PATCH 08/69] schemas: add role_name and verification_level types - schemas: also fix GUILD_UPDATE.icon and GUILD_CREATE.icon --- litecord/schemas.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/litecord/schemas.py b/litecord/schemas.py index 342c87d..6526314 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -6,8 +6,10 @@ from logbook import Logger from .errors import BadRequest from .permissions import Permissions -from .enums import ActivityType, StatusType, ExplicitFilter, \ - RelationshipType, MessageNotifications, ChannelType +from .enums import ( + ActivityType, StatusType, ExplicitFilter, RelationshipType, + MessageNotifications, ChannelType, VerificationLevel +) log = Logger(__name__) @@ -26,6 +28,14 @@ EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M) ANIMOJI_MENTION = re.compile(r'', re.A | re.M) +def _in_enum(enum, value: int): + try: + enum(value) + return True + except ValueError: + return False + + class LitecordValidator(Validator): def _validate_type_username(self, value: str) -> bool: """Validate against the username regex.""" @@ -58,7 +68,10 @@ class LitecordValidator(Validator): def _validate_type_voice_region(self, value: str) -> bool: # TODO: complete this list - return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia') + return value.lower() in ('brazil', 'us-east', 'us-west', 'us-south', 'russia') + + def _validate_type_verification_level(self, value: int) -> bool: + return _in_enum(VerificationLevel, value) def _validate_type_activity_type(self, value: int) -> bool: return value in ActivityType.values() @@ -102,6 +115,9 @@ class LitecordValidator(Validator): def _validate_type_guild_name(self, value: str) -> bool: return 2 <= len(value) <= 100 + def _validate_type_role_name(self, value: str) -> bool: + return 1 <= len(value) <= 100 + def _validate_type_channel_name(self, value: str) -> bool: # for now, we'll use the same validation for guild_name return self._validate_type_guild_name(value) @@ -193,7 +209,7 @@ PARTIAL_CHANNEL_GUILD_CREATE = { GUILD_CREATE = { 'name': {'type': 'guild_name'}, 'region': {'type': 'voice_region'}, - 'icon': {'type': 'icon', 'required': False, 'nullable': True}, + 'icon': {'type': 'b64_icon', 'required': False, 'nullable': True}, 'verification_level': { 'type': 'verification_level', 'default': 0}, @@ -215,7 +231,7 @@ GUILD_UPDATE = { 'required': False }, 'region': {'type': 'voice_region', 'required': False}, - 'icon': {'type': 'icon', 'required': False}, + 'icon': {'type': 'b64_icon', 'required': False}, 'verification_level': {'type': 'verification_level', 'required': False}, 'default_message_notifications': { From 888853458047cf3e815f8e7769a929ec9bca7252 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 04:27:48 -0300 Subject: [PATCH 09/69] blueprints.guilds: misc fixes to channel creation Instances should run this SQL to maintain consistency with `schema.sql` ```sql ALTER TABLE guild_channels DROP CONSTRAINT guild_channels_guild_id_fkey; ALTER TABLE guild_channels ADD CONSTRAINT guild_id_fkey FOREIGN KEY (guild_id) REFERENCES guilds (id) ON DELETE CASCADE; ``` --- litecord/blueprints/guilds.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 295167d..414edef 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -104,7 +104,7 @@ async def _specific_chan_create(channel_id, ctype, **kwargs): if ctype == ChannelType.GUILD_TEXT: await app.db.execute(""" INSERT INTO guild_text_channels (id, topic) - VALUES ($1) + VALUES ($1, $2) """, channel_id, kwargs.get('topic', '')) elif ctype == ChannelType.GUILD_VOICE: await app.db.execute( @@ -133,6 +133,9 @@ async def create_guild_channel(guild_id: int, channel_id: int, WHERE guild_id = $1 """, guild_id) + # account for the first channel in a guild too + max_pos = max_pos or 0 + # all channels go to guild_channels await app.db.execute(""" INSERT INTO guild_channels (id, guild_id, name, position) From f8e44d62bd167ed246be7aebb6ebee26738795e6 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 04:30:56 -0300 Subject: [PATCH 10/69] gateway.websocket: handle missing 'channels' on op 14 --- litecord/gateway/websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 6ad0b22..515eac6 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -698,7 +698,7 @@ class GatewayWebsocket: # make shard query lazy_guilds = self.ext.dispatcher.backends['lazy_guild'] - for chan_id, ranges in data['channels'].items(): + for chan_id, ranges in data.get('channels', {}).items(): chan_id = int(chan_id) member_list = await lazy_guilds.get_gml(chan_id) From dca1adc6f8410a03511fcf6770bb8e9fab357ba9 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 15:57:24 -0300 Subject: [PATCH 11/69] blueprints.guilds: add create_role function --- litecord/blueprints/guilds.py | 50 ++++++++++++++++++++++++----------- litecord/utils.py | 5 ++++ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 414edef..abab226 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -5,6 +5,7 @@ from ..snowflake import get_snowflake from ..enums import ChannelType from ..errors import Forbidden, GuildNotFound, BadRequest from ..schemas import validate, GUILD_CREATE, GUILD_UPDATE +from ..utils import dict_get from .channels import channel_ack from .checks import guild_check @@ -59,6 +60,37 @@ async def add_member(guild_id: int, user_id: int): await create_guild_settings(guild_id, user_id) +async def create_role(guild_id, name: str, **kwargs): + """Create a role in a guild.""" + new_role_id = get_snowflake() + + # TODO: use @everyone's perm number + default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS) + + await app.db.execute( + """ + INSERT INTO roles (id, guild_id, name, color, + hoist, position, permissions, managed, mentionable) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + new_role_id, + guild_id, + name, + dict_get(kwargs, 'color', 0), + dict_get(kwargs, 'hoist', False), + dict_get(kwargs, 'permissions', default_perms), + False, + dict_get(kwargs, 'mentionable', False) + ) + + role = await app.storage.get_role(new_role_id, guild_id) + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_ROLE_CREATE', { + 'guild_id': guild_id, + 'role': role, + }) + + async def guild_create_roles_prep(guild_id: int, roles: list): """Create roles in preparation in guild create.""" # by reaching this point in the code that means @@ -81,22 +113,8 @@ async def guild_create_roles_prep(guild_id: int, roles: list): # from the 2nd and forward, # should be treated as new roles for role in roles[1:]: - new_role_id = get_snowflake() - - await app.db.execute( - """ - INSERT INTO roles (id, guild_id, name, color, - hoist, position, permissions, managed, mentionable) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - new_role_id, - guild_id, - role['name'], - role.get('color', 0), - role.get('hoist', False), - role.get('permissions', default_perms), - False, - role.get('mentionable', False) + await create_role( + guild_id, role['name'], default_perms=default_perms, **role ) diff --git a/litecord/utils.py b/litecord/utils.py index a350dad..3949194 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -22,3 +22,8 @@ async def task_wrapper(name: str, coro): pass except: log.exception('{} task error', name) + + +def dict_get(mapping, key, default): + """Return `default` even when mapping[key] is None.""" + return mapping.get(key) or default From a08eb0d0689784ab1b5325189ccc6cd40b671902 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 16:21:27 -0300 Subject: [PATCH 12/69] blueprints.guilds: calculate new role position - blueprints.guilds: add basic POST /api/v6/guilds/:id/roles - schemas: add ROLE_CREATE - litecord: add types --- litecord/blueprints/guilds.py | 28 +++++++++++++++++++++++++++- litecord/schemas.py | 14 +++++++++++++- litecord/types.py | 15 +++++++++++++++ 3 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 litecord/types.py diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index abab226..691598d 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -4,7 +4,7 @@ from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType from ..errors import Forbidden, GuildNotFound, BadRequest -from ..schemas import validate, GUILD_CREATE, GUILD_UPDATE +from ..schemas import validate, GUILD_CREATE, GUILD_UPDATE, ROLE_CREATE from ..utils import dict_get from .channels import channel_ack from .checks import guild_check @@ -67,6 +67,14 @@ async def create_role(guild_id, name: str, **kwargs): # TODO: use @everyone's perm number default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS) + max_pos = await app.db.fetchval(""" + SELECT MAX(position) + FROM roles + WHERE guild_id = $1 + """, guild_id) + + max_pos = max_pos or 0 + await app.db.execute( """ INSERT INTO roles (id, guild_id, name, color, @@ -78,6 +86,7 @@ async def create_role(guild_id, name: str, **kwargs): name, dict_get(kwargs, 'color', 0), dict_get(kwargs, 'hoist', False), + max_pos + 1, dict_get(kwargs, 'permissions', default_perms), False, dict_get(kwargs, 'mentionable', False) @@ -90,6 +99,8 @@ async def create_role(guild_id, name: str, **kwargs): 'role': role, }) + return role + async def guild_create_roles_prep(guild_id: int, roles: list): """Create roles in preparation in guild create.""" @@ -365,6 +376,21 @@ async def modify_channel_pos(guild_id): raise NotImplementedError +@bp.route('//roles', methods=['POST']) +async def create_guild_role(guild_id: int): + """Add a role to a guild""" + user_id = await token_check() + + # TODO: use check_guild and MANAGE_ROLES permission + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), ROLE_CREATE) + + role = await create_role(guild_id, j.get('name', 'new role'), **j) + + return jsonify(role) + + @bp.route('//members/', methods=['GET']) async def get_guild_member(guild_id, member_id): """Get a member's information in a guild.""" diff --git a/litecord/schemas.py b/litecord/schemas.py index 6526314..8eb4bfc 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -6,6 +6,7 @@ from logbook import Logger from .errors import BadRequest from .permissions import Permissions +from .types import Color from .enums import ( ActivityType, StatusType, ExplicitFilter, RelationshipType, MessageNotifications, ChannelType, VerificationLevel @@ -225,6 +226,7 @@ GUILD_CREATE = { 'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE}, } + GUILD_UPDATE = { 'name': { 'type': 'guild_name', @@ -249,13 +251,23 @@ GUILD_UPDATE = { } +ROLE_CREATE = { + 'name': {'type': 'string', 'default': 'new role'}, + 'permissions': {'coerce': Permissions, 'nullable': True}, + 'color': {'coerce': Color, 'default': 0}, + 'hoist': {'type': 'boolean', 'default': False}, + 'mentionable': {'type': 'boolean', 'default': False}, +} + + MEMBER_UPDATE = { 'nick': { 'type': 'nickname', 'minlength': 1, 'maxlength': 100, 'required': False, }, - 'roles': {'type': 'list', 'required': False}, + 'roles': {'type': 'list', 'required': False, + 'schema': {'coerce': int}}, 'mute': {'type': 'boolean', 'required': False}, 'deaf': {'type': 'boolean', 'required': False}, 'channel_id': {'type': 'snowflake', 'required': False}, diff --git a/litecord/types.py b/litecord/types.py new file mode 100644 index 0000000..4fddfff --- /dev/null +++ b/litecord/types.py @@ -0,0 +1,15 @@ + +class Color: + """Custom color class""" + def __init__(self, val: int): + self.blue = val & 255 + self.green = (val >> 8) & 255 + self.red = (val >> 16) & 255 + + @property + def value(self): + """Give the actual RGB integer encoding this color.""" + return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16) + + def __int__(self): + return self.value From 956498ac65c7585f4fa251c1aac18919412de121 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 16:35:00 -0300 Subject: [PATCH 13/69] blueprints.guilds: fix guild_owner_check - blueprints.guilds: fix create_guild_role - blueprints.guilds: fix giving ints in the place of snowflakes in some events --- litecord/blueprints/guilds.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 691598d..5410551 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -18,7 +18,7 @@ async def guild_owner_check(user_id: int, guild_id: int): owner_id = await app.db.fetchval(""" SELECT owner_id FROM guilds - WHERE guild_id = $1 + WHERE guilds.id = $1 """, guild_id) if not owner_id: @@ -73,8 +73,6 @@ async def create_role(guild_id, name: str, **kwargs): WHERE guild_id = $1 """, guild_id) - max_pos = max_pos or 0 - await app.db.execute( """ INSERT INTO roles (id, guild_id, name, color, @@ -86,7 +84,11 @@ async def create_role(guild_id, name: str, **kwargs): name, dict_get(kwargs, 'color', 0), dict_get(kwargs, 'hoist', False), - max_pos + 1, + + # set position = 0 when there isn't any + # other role (when we're creating the + # @everyone role) + max_pos + 1 if max_pos else 0, dict_get(kwargs, 'permissions', default_perms), False, dict_get(kwargs, 'mentionable', False) @@ -95,7 +97,7 @@ async def create_role(guild_id, name: str, **kwargs): role = await app.storage.get_role(new_role_id, guild_id) await app.dispatcher.dispatch_guild( guild_id, 'GUILD_ROLE_CREATE', { - 'guild_id': guild_id, + 'guild_id': str(guild_id), 'role': role, }) @@ -384,9 +386,13 @@ async def create_guild_role(guild_id: int): # TODO: use check_guild and MANAGE_ROLES permission await guild_owner_check(user_id, guild_id) - j = validate(await request.get_json(), ROLE_CREATE) + # client can just send null + j = validate(await request.get_json() or {}, ROLE_CREATE) - role = await create_role(guild_id, j.get('name', 'new role'), **j) + role_name = j['name'] + j.pop('name') + + role = await create_role(guild_id, role_name, **j) return jsonify(role) @@ -513,7 +519,7 @@ async def remove_member(guild_id: int, member_id: int): await app.dispatcher.unsub('guild', guild_id, member_id) await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { - 'guild': guild_id, + 'guild_id': str(guild_id), 'user': await app.storage.get_user(member_id), }) @@ -569,9 +575,10 @@ async def create_ban(guild_id, member_id): await remove_member(guild_id, member_id) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {**{ - 'guild': guild_id, - }, **(await app.storage.get_user(member_id))}) + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(member_id) + }) return '', 204 From 86705b0645b9449bcfdce5681b2d3f72efafdd04 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 26 Oct 2018 21:03:44 -0300 Subject: [PATCH 14/69] blueprints.guilds: add 2 role endpoints Add PATCH /api/v6/guilds/:id/roles for multiple position changes for roles and PATCH /api/v6/guilds/:id/roles/:id for single guild role changes - permissions: add int maigc method - schemas: add ROLE_UPDATE and ROLE_UPDATE_POSITION --- litecord/blueprints/guilds.py | 153 +++++++++++++++++++++++++++++++++- litecord/permissions.py | 3 + litecord/schemas.py | 22 +++++ 3 files changed, 175 insertions(+), 3 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 5410551..90b359b 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -4,7 +4,10 @@ from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType from ..errors import Forbidden, GuildNotFound, BadRequest -from ..schemas import validate, GUILD_CREATE, GUILD_UPDATE, ROLE_CREATE +from ..schemas import ( + validate, GUILD_CREATE, GUILD_UPDATE, ROLE_CREATE, ROLE_UPDATE, + ROLE_UPDATE_POSITION +) from ..utils import dict_get from .channels import channel_ack from .checks import guild_check @@ -88,8 +91,8 @@ async def create_role(guild_id, name: str, **kwargs): # set position = 0 when there isn't any # other role (when we're creating the # @everyone role) - max_pos + 1 if max_pos else 0, - dict_get(kwargs, 'permissions', default_perms), + max_pos + 1 if max_pos is not None else 0, + int(dict_get(kwargs, 'permissions', default_perms)), False, dict_get(kwargs, 'mentionable', False) ) @@ -397,6 +400,150 @@ async def create_guild_role(guild_id: int): return jsonify(role) +async def _role_update_dispatch(role_id: int, guild_id: int): + """Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" + role = await app.storage.get_role(role_id, guild_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', { + 'guild_id': str(guild_id), + 'role': role, + }) + + return role + + +async def _role_pairs_update(guild_id: int, pairs: list): + """Update the roles' positions. + + Dispatches GUILD_ROLE_UPDATE for all roles being updated. + """ + for pair in pairs: + pair_1, pair_2 = pair + + role_1, new_pos_1 = pair_1 + role_2, new_pos_2 = pair_2 + + conn = await app.db.acquire() + async with conn.transaction(): + # update happens in a transaction + # so we don't fuck it up + await conn.execute(""" + UPDATE roles + SET position = $1 + WHERE roles.id = $2 + """, new_pos_1, role_1) + + await conn.execute(""" + UPDATE roles + SET position = $1 + WHERE roles.id = $2 + """, new_pos_2, role_2) + + await app.db.release(conn) + + # the route fires multiple Guild Role Update. + await _role_update_dispatch(role_1, guild_id) + await _role_update_dispatch(role_2, guild_id) + + +@bp.route('//roles', methods=['PATCH']) +async def update_guild_role_positions(guild_id): + """Update the positions for a bunch of roles.""" + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + raw_j = await request.get_json() + + # we need to do this hackiness because thats + # cerberus for ya. + j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) + + # extract the list out + j = j['roles'] + print(j) + + all_roles = await app.storage.get_role_data(guild_id) + + # we'll have to calculate pairs of changing roles, + # then do the changes, etc. + roles_pos = {role['position']: int(role['id']) for role in all_roles} + new_positions = {role['id']: role['position'] for role in j} + + # always ignore people trying to change the @everyone role + # TODO: check if the user can even change the roles in the first place, + # preferrably when we have a proper perms system. + try: + new_positions.pop(guild_id) + except KeyError: + pass + + pairs = [] + + # we want to find pairs of (role_1, new_position_1) + # where new_position_1 is actually pointing to position_2 (for a role 2) + # AND we have (role_2, new_position_2) in the list of new_positions. + + # I hope the explanation went through. + + for change in j: + role_1, new_pos_1 = change['id'], change['position'] + + # check current pairs + # so we don't repeat a role + flag = False + + for pair in pairs: + if (role_1, new_pos_1) in pair: + flag = True + + # skip if found + if flag: + continue + + # find a role that is in that new position + role_2 = roles_pos.get(new_pos_1) + + # search role_2 in the new_positions list + new_pos_2 = new_positions.get(role_2) + + # if we found it, add it to the pairs array. + if new_pos_2: + pairs.append( + ((role_1, new_pos_1), (role_2, new_pos_2)) + ) + + await _role_pairs_update(guild_id, pairs) + + # return the list of all roles back + return jsonify(await app.storage.get_role_data(guild_id)) + + +@bp.route('//roles/', methods=['PATCH']) +async def update_guild_role(guild_id, role_id): + """Update a single role's information.""" + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), ROLE_UPDATE) + + # we only update ints on the db, not Permissions + j['permissions'] = int(j['permissions']) + + for field in j: + await app.db.execute(f""" + UPDATE roles + SET {field} = $1 + WHERE roles.id = $2 AND roles.guild_id = $3 + """, j[field], role_id, guild_id) + + role = await _role_update_dispatch(role_id, guild_id) + return jsonify(role) + + @bp.route('//members/', methods=['GET']) async def get_guild_member(guild_id, member_id): """Get a member's information in a guild.""" diff --git a/litecord/permissions.py b/litecord/permissions.py index 1f7a020..c5c5966 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -50,5 +50,8 @@ class Permissions(ctypes.Union): def __init__(self, val: int): self.binary = val + def __int__(self): + return self.binary + def numby(self): return self.binary diff --git a/litecord/schemas.py b/litecord/schemas.py index 8eb4bfc..77147ab 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -259,6 +259,28 @@ ROLE_CREATE = { 'mentionable': {'type': 'boolean', 'default': False}, } +ROLE_UPDATE = { + 'name': {'type': 'string', 'required': False}, + 'permissions': {'coerce': Permissions, 'required': False}, + 'color': {'coerce': Color, 'required': False}, + 'hoist': {'type': 'boolean', 'required': False}, + 'mentionable': {'type': 'boolean', 'required': False}, +} + + +ROLE_UPDATE_POSITION = { + 'roles': { + 'type': 'list', + 'schema': { + 'type': 'dict', + 'schema': { + 'id': {'coerce': int}, + 'position': {'coerce': int}, + }, + } + } +} + MEMBER_UPDATE = { 'nick': { From 43b0482581d2a22016155e00b797217df3bded2a Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 02:04:47 -0300 Subject: [PATCH 15/69] blueprints.guild: split blueprint into channels, members, roles --- litecord/blueprints/checks.py | 17 +- litecord/blueprints/guild/__init__.py | 3 + litecord/blueprints/guild/channels.py | 111 +++++++ litecord/blueprints/guild/members.py | 113 +++++++ litecord/blueprints/guild/roles.py | 222 +++++++++++++ litecord/blueprints/guilds.py | 436 +------------------------- run.py | 19 +- 7 files changed, 492 insertions(+), 429 deletions(-) create mode 100644 litecord/blueprints/guild/__init__.py create mode 100644 litecord/blueprints/guild/channels.py create mode 100644 litecord/blueprints/guild/members.py create mode 100644 litecord/blueprints/guild/roles.py diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 5cfc225..17bd337 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -1,7 +1,7 @@ from quart import current_app as app from ..enums import ChannelType, GUILD_CHANS -from ..errors import GuildNotFound, ChannelNotFound +from ..errors import GuildNotFound, ChannelNotFound, Forbidden async def guild_check(user_id: int, guild_id: int): @@ -16,6 +16,21 @@ async def guild_check(user_id: int, guild_id: int): raise GuildNotFound('guild not found') +async def guild_owner_check(user_id: int, guild_id: int): + """Check if a user is the owner of the guild.""" + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM guilds + WHERE guilds.id = $1 + """, guild_id) + + if not owner_id: + raise GuildNotFound() + + if user_id != owner_id: + raise Forbidden('You are not the owner of the guild') + + async def channel_check(user_id, channel_id): """Check if the current user is authorized to read the channel's information.""" diff --git a/litecord/blueprints/guild/__init__.py b/litecord/blueprints/guild/__init__.py new file mode 100644 index 0000000..f4f8356 --- /dev/null +++ b/litecord/blueprints/guild/__init__.py @@ -0,0 +1,3 @@ +from .roles import bp as guild_roles +from .members import bp as guild_members +from .channels import bp as guild_channels diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py new file mode 100644 index 0000000..f8c5132 --- /dev/null +++ b/litecord/blueprints/guild/channels.py @@ -0,0 +1,111 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import guild_check, guild_owner_check +from litecord.snowflake import get_snowflake +from litecord.errors import BadRequest +from litecord.enums import ChannelType +# from litecord.schemas import ( +# validate, CHAN_UPDATE_POSITION +# ) + + +bp = Blueprint('guild_channels', __name__) + + +async def _specific_chan_create(channel_id, ctype, **kwargs): + if ctype == ChannelType.GUILD_TEXT: + await app.db.execute(""" + INSERT INTO guild_text_channels (id, topic) + VALUES ($1, $2) + """, channel_id, kwargs.get('topic', '')) + elif ctype == ChannelType.GUILD_VOICE: + await app.db.execute( + """ + INSERT INTO guild_voice_channels (id, bitrate, user_limit) + VALUES ($1, $2, $3) + """, + channel_id, + kwargs.get('bitrate', 64), + kwargs.get('user_limit', 0) + ) + + +async def create_guild_channel(guild_id: int, channel_id: int, + ctype: ChannelType, **kwargs): + """Create a channel in a guild.""" + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, channel_id, ctype.value) + + # calc new pos + max_pos = await app.db.fetchval(""" + SELECT MAX(position) + FROM guild_channels + WHERE guild_id = $1 + """, guild_id) + + # account for the first channel in a guild too + max_pos = max_pos or 0 + + # all channels go to guild_channels + await app.db.execute(""" + INSERT INTO guild_channels (id, guild_id, name, position) + VALUES ($1, $2, $3, $4) + """, channel_id, guild_id, kwargs['name'], max_pos + 1) + + # the rest of sql magic is dependant on the channel + # we're creating (a text or voice or category), + # so we use this function. + await _specific_chan_create(channel_id, ctype, **kwargs) + + +@bp.route('//channels', methods=['GET']) +async def get_guild_channels(guild_id): + """Get the list of channels in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + return jsonify( + await app.storage.get_channel_data(guild_id)) + + +@bp.route('//channels', methods=['POST']) +async def create_channel(guild_id): + """Create a channel in a guild.""" + user_id = await token_check() + j = await request.get_json() + + # TODO: check permissions for MANAGE_CHANNELS + await guild_check(user_id, guild_id) + + channel_type = j.get('type', ChannelType.GUILD_TEXT) + channel_type = ChannelType(channel_type) + + if channel_type not in (ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE): + raise BadRequest('Invalid channel type') + + new_channel_id = get_snowflake() + await create_guild_channel( + guild_id, new_channel_id, channel_type, **j) + + chan = await app.storage.get_channel(new_channel_id) + await app.dispatcher.dispatch_guild( + guild_id, 'CHANNEL_CREATE', chan) + return jsonify(chan) + + +@bp.route('//channels', methods=['PATCH']) +async def modify_channel_pos(guild_id): + user_id = await token_check() + + # TODO: check MANAGE_CHANNELS + await guild_owner_check(user_id, guild_id) + + # TODO: this route + # raw_j = await request.get_json() + # j = validate({'channels': raw_j}, CHAN_UPDATE_POSITION) + + raise NotImplementedError diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py new file mode 100644 index 0000000..6fd2ad2 --- /dev/null +++ b/litecord/blueprints/guild/members.py @@ -0,0 +1,113 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import guild_check +from litecord.errors import BadRequest + + +bp = Blueprint('guild_members', __name__) + + +@bp.route('//members/', methods=['GET']) +async def get_guild_member(guild_id, member_id): + """Get a member's information in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + member = await app.storage.get_single_member(guild_id, member_id) + return jsonify(member) + + +@bp.route('//members', methods=['GET']) +async def get_members(guild_id): + """Get members inside a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + j = await request.get_json() + + limit, after = int(j.get('limit', 1)), j.get('after', 0) + + if limit < 1 or limit > 1000: + raise BadRequest('limit not in 1-1000 range') + + user_ids = await app.db.fetch(f""" + SELECT user_id + WHERE guild_id = $1, user_id > $2 + LIMIT {limit} + ORDER BY user_id ASC + """, guild_id, after) + + user_ids = [r[0] for r in user_ids] + members = await app.storage.get_member_multi(guild_id, user_ids) + return jsonify(members) + + +@bp.route('//members/', methods=['PATCH']) +async def modify_guild_member(guild_id, member_id): + """Modify a members' information in a guild.""" + j = await request.get_json() + + if 'nick' in j: + # TODO: check MANAGE_NICKNAMES + + await app.db.execute(""" + UPDATE members + SET nickname = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['nick'], member_id, guild_id) + + if 'mute' in j: + # TODO: check MUTE_MEMBERS + + await app.db.execute(""" + UPDATE members + SET muted = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['mute'], member_id, guild_id) + + if 'deaf' in j: + # TODO: check DEAFEN_MEMBERS + + await app.db.execute(""" + UPDATE members + SET deafened = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['deaf'], member_id, guild_id) + + if 'channel_id' in j: + # TODO: check MOVE_MEMBERS + # TODO: change the member's voice channel + pass + + member = await app.storage.get_member_data_one(guild_id, member_id) + member.pop('joined_at') + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ + 'guild_id': str(guild_id) + }, **member}) + + return '', 204 + + +@bp.route('//members/@me/nick', methods=['PATCH']) +async def update_nickname(guild_id): + """Update a member's nickname in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + j = await request.get_json() + + await app.db.execute(""" + UPDATE members + SET nickname = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['nick'], user_id, guild_id) + + member = await app.storage.get_member_data_one(guild_id, user_id) + member.pop('joined_at') + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ + 'guild_id': str(guild_id) + }, **member}) + + return j['nick'] diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py new file mode 100644 index 0000000..d83b33f --- /dev/null +++ b/litecord/blueprints/guild/roles.py @@ -0,0 +1,222 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.auth import token_check + +# from litecord.blueprints.checks import guild_check +from litecord.blueprints.checks import guild_owner_check +from litecord.snowflake import get_snowflake +from litecord.utils import dict_get + +from litecord.schemas import ( + validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION +) + +DEFAULT_EVERYONE_PERMS = 104324161 +bp = Blueprint('guild_roles', __name__) + + +async def create_role(guild_id, name: str, **kwargs): + """Create a role in a guild.""" + new_role_id = get_snowflake() + + # TODO: use @everyone's perm number + default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS) + + max_pos = await app.db.fetchval(""" + SELECT MAX(position) + FROM roles + WHERE guild_id = $1 + """, guild_id) + + await app.db.execute( + """ + INSERT INTO roles (id, guild_id, name, color, + hoist, position, permissions, managed, mentionable) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + new_role_id, + guild_id, + name, + dict_get(kwargs, 'color', 0), + dict_get(kwargs, 'hoist', False), + + # set position = 0 when there isn't any + # other role (when we're creating the + # @everyone role) + max_pos + 1 if max_pos is not None else 0, + int(dict_get(kwargs, 'permissions', default_perms)), + False, + dict_get(kwargs, 'mentionable', False) + ) + + role = await app.storage.get_role(new_role_id, guild_id) + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_ROLE_CREATE', { + 'guild_id': str(guild_id), + 'role': role, + }) + + return role + + +@bp.route('//roles', methods=['POST']) +async def create_guild_role(guild_id: int): + """Add a role to a guild""" + user_id = await token_check() + + # TODO: use check_guild and MANAGE_ROLES permission + await guild_owner_check(user_id, guild_id) + + # client can just send null + j = validate(await request.get_json() or {}, ROLE_CREATE) + + role_name = j['name'] + j.pop('name') + + role = await create_role(guild_id, role_name, **j) + + return jsonify(role) + + +async def _role_update_dispatch(role_id: int, guild_id: int): + """Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" + role = await app.storage.get_role(role_id, guild_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', { + 'guild_id': str(guild_id), + 'role': role, + }) + + return role + + +async def _role_pairs_update(guild_id: int, pairs: list): + """Update the roles' positions. + + Dispatches GUILD_ROLE_UPDATE for all roles being updated. + """ + for pair in pairs: + pair_1, pair_2 = pair + + role_1, new_pos_1 = pair_1 + role_2, new_pos_2 = pair_2 + + conn = await app.db.acquire() + async with conn.transaction(): + # update happens in a transaction + # so we don't fuck it up + await conn.execute(""" + UPDATE roles + SET position = $1 + WHERE roles.id = $2 + """, new_pos_1, role_1) + + await conn.execute(""" + UPDATE roles + SET position = $1 + WHERE roles.id = $2 + """, new_pos_2, role_2) + + await app.db.release(conn) + + # the route fires multiple Guild Role Update. + await _role_update_dispatch(role_1, guild_id) + await _role_update_dispatch(role_2, guild_id) + + +@bp.route('//roles', methods=['PATCH']) +async def update_guild_role_positions(guild_id): + """Update the positions for a bunch of roles.""" + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + raw_j = await request.get_json() + + # we need to do this hackiness because thats + # cerberus for ya. + j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) + + # extract the list out + j = j['roles'] + print(j) + + all_roles = await app.storage.get_role_data(guild_id) + + # we'll have to calculate pairs of changing roles, + # then do the changes, etc. + roles_pos = {role['position']: int(role['id']) for role in all_roles} + new_positions = {role['id']: role['position'] for role in j} + + # always ignore people trying to change the @everyone role + # TODO: check if the user can even change the roles in the first place, + # preferrably when we have a proper perms system. + try: + new_positions.pop(guild_id) + except KeyError: + pass + + pairs = [] + + # we want to find pairs of (role_1, new_position_1) + # where new_position_1 is actually pointing to position_2 (for a role 2) + # AND we have (role_2, new_position_2) in the list of new_positions. + + # I hope the explanation went through. + + for change in j: + role_1, new_pos_1 = change['id'], change['position'] + + # check current pairs + # so we don't repeat a role + flag = False + + for pair in pairs: + if (role_1, new_pos_1) in pair: + flag = True + + # skip if found + if flag: + continue + + # find a role that is in that new position + role_2 = roles_pos.get(new_pos_1) + + # search role_2 in the new_positions list + new_pos_2 = new_positions.get(role_2) + + # if we found it, add it to the pairs array. + if new_pos_2: + pairs.append( + ((role_1, new_pos_1), (role_2, new_pos_2)) + ) + + await _role_pairs_update(guild_id, pairs) + + # return the list of all roles back + return jsonify(await app.storage.get_role_data(guild_id)) + + +@bp.route('//roles/', methods=['PATCH']) +async def update_guild_role(guild_id, role_id): + """Update a single role's information.""" + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), ROLE_UPDATE) + + # we only update ints on the db, not Permissions + j['permissions'] = int(j['permissions']) + + for field in j: + await app.db.execute(f""" + UPDATE roles + SET {field} = $1 + WHERE roles.id = $2 AND roles.guild_id = $3 + """, j[field], role_id, guild_id) + + role = await _role_update_dispatch(role_id, guild_id) + return jsonify(role) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 90b359b..23ee91c 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -1,34 +1,22 @@ from quart import Blueprint, request, current_app as app, jsonify +from litecord.blueprints.guild.channels import create_guild_channel +from litecord.blueprints.guild.roles import ( + create_role, DEFAULT_EVERYONE_PERMS +) + from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType -from ..errors import Forbidden, GuildNotFound, BadRequest from ..schemas import ( - validate, GUILD_CREATE, GUILD_UPDATE, ROLE_CREATE, ROLE_UPDATE, - ROLE_UPDATE_POSITION + validate, GUILD_CREATE, GUILD_UPDATE ) from ..utils import dict_get from .channels import channel_ack -from .checks import guild_check +from .checks import guild_check, guild_owner_check + bp = Blueprint('guilds', __name__) -DEFAULT_EVERYONE_PERMS = 104324161 - - -async def guild_owner_check(user_id: int, guild_id: int): - """Check if a user is the owner of the guild.""" - owner_id = await app.db.fetchval(""" - SELECT owner_id - FROM guilds - WHERE guilds.id = $1 - """, guild_id) - - if not owner_id: - raise GuildNotFound() - - if user_id != owner_id: - raise Forbidden('You are not the owner of the guild') async def create_guild_settings(guild_id: int, user_id: int): @@ -63,50 +51,6 @@ async def add_member(guild_id: int, user_id: int): await create_guild_settings(guild_id, user_id) -async def create_role(guild_id, name: str, **kwargs): - """Create a role in a guild.""" - new_role_id = get_snowflake() - - # TODO: use @everyone's perm number - default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS) - - max_pos = await app.db.fetchval(""" - SELECT MAX(position) - FROM roles - WHERE guild_id = $1 - """, guild_id) - - await app.db.execute( - """ - INSERT INTO roles (id, guild_id, name, color, - hoist, position, permissions, managed, mentionable) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - new_role_id, - guild_id, - name, - dict_get(kwargs, 'color', 0), - dict_get(kwargs, 'hoist', False), - - # set position = 0 when there isn't any - # other role (when we're creating the - # @everyone role) - max_pos + 1 if max_pos is not None else 0, - int(dict_get(kwargs, 'permissions', default_perms)), - False, - dict_get(kwargs, 'mentionable', False) - ) - - role = await app.storage.get_role(new_role_id, guild_id) - await app.dispatcher.dispatch_guild( - guild_id, 'GUILD_ROLE_CREATE', { - 'guild_id': str(guild_id), - 'role': role, - }) - - return role - - async def guild_create_roles_prep(guild_id: int, roles: list): """Create roles in preparation in guild create.""" # by reaching this point in the code that means @@ -134,54 +78,6 @@ async def guild_create_roles_prep(guild_id: int, roles: list): ) -async def _specific_chan_create(channel_id, ctype, **kwargs): - if ctype == ChannelType.GUILD_TEXT: - await app.db.execute(""" - INSERT INTO guild_text_channels (id, topic) - VALUES ($1, $2) - """, channel_id, kwargs.get('topic', '')) - elif ctype == ChannelType.GUILD_VOICE: - await app.db.execute( - """ - INSERT INTO guild_voice_channels (id, bitrate, user_limit) - VALUES ($1, $2, $3) - """, - channel_id, - kwargs.get('bitrate', 64), - kwargs.get('user_limit', 0) - ) - - -async def create_guild_channel(guild_id: int, channel_id: int, - ctype: ChannelType, **kwargs): - """Create a channel in a guild.""" - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, channel_id, ctype.value) - - # calc new pos - max_pos = await app.db.fetchval(""" - SELECT MAX(position) - FROM guild_channels - WHERE guild_id = $1 - """, guild_id) - - # account for the first channel in a guild too - max_pos = max_pos or 0 - - # all channels go to guild_channels - await app.db.execute(""" - INSERT INTO guild_channels (id, guild_id, name, position) - VALUES ($1, $2, $3, $4) - """, channel_id, guild_id, kwargs['name'], max_pos + 1) - - # the rest of sql magic is dependant on the channel - # we're creating (a text or voice or category), - # so we use this function. - await _specific_chan_create(channel_id, ctype, **kwargs) - - async def guild_create_channels_prep(guild_id: int, channels: list): """Create channels pre-guild create""" for channel_raw in channels: @@ -255,10 +151,10 @@ async def get_guild(guild_id): @bp.route('/', methods=['UPDATE']) async def update_guild(guild_id): user_id = await token_check() - await guild_check(user_id, guild_id) - j = validate(await request.get_json(), GUILD_UPDATE) # TODO: check MANAGE_GUILD + await guild_check(user_id, guild_id) + j = validate(await request.get_json(), GUILD_UPDATE) if 'owner_id' in j: await guild_owner_check(user_id, guild_id) @@ -337,318 +233,6 @@ async def delete_guild(guild_id): return '', 204 -@bp.route('//channels', methods=['GET']) -async def get_guild_channels(guild_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - - channels = await app.storage.get_channel_data(guild_id) - return jsonify(channels) - - -@bp.route('//channels', methods=['POST']) -async def create_channel(guild_id): - user_id = await token_check() - j = await request.get_json() - - # TODO: check permissions for MANAGE_CHANNELS - await guild_check(user_id, guild_id) - - channel_type = j.get('type', ChannelType.GUILD_TEXT) - channel_type = ChannelType(channel_type) - - if channel_type not in (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE): - raise BadRequest('Invalid channel type') - - new_channel_id = get_snowflake() - await create_guild_channel(guild_id, new_channel_id, channel_type,) - - chan = await app.storage.get_channel(new_channel_id) - await app.dispatcher.dispatch_guild( - guild_id, 'CHANNEL_CREATE', chan) - return jsonify(chan) - - -@bp.route('//channels', methods=['PATCH']) -async def modify_channel_pos(guild_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - await request.get_json() - - # TODO: this route - - raise NotImplementedError - - -@bp.route('//roles', methods=['POST']) -async def create_guild_role(guild_id: int): - """Add a role to a guild""" - user_id = await token_check() - - # TODO: use check_guild and MANAGE_ROLES permission - await guild_owner_check(user_id, guild_id) - - # client can just send null - j = validate(await request.get_json() or {}, ROLE_CREATE) - - role_name = j['name'] - j.pop('name') - - role = await create_role(guild_id, role_name, **j) - - return jsonify(role) - - -async def _role_update_dispatch(role_id: int, guild_id: int): - """Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" - role = await app.storage.get_role(role_id, guild_id) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', { - 'guild_id': str(guild_id), - 'role': role, - }) - - return role - - -async def _role_pairs_update(guild_id: int, pairs: list): - """Update the roles' positions. - - Dispatches GUILD_ROLE_UPDATE for all roles being updated. - """ - for pair in pairs: - pair_1, pair_2 = pair - - role_1, new_pos_1 = pair_1 - role_2, new_pos_2 = pair_2 - - conn = await app.db.acquire() - async with conn.transaction(): - # update happens in a transaction - # so we don't fuck it up - await conn.execute(""" - UPDATE roles - SET position = $1 - WHERE roles.id = $2 - """, new_pos_1, role_1) - - await conn.execute(""" - UPDATE roles - SET position = $1 - WHERE roles.id = $2 - """, new_pos_2, role_2) - - await app.db.release(conn) - - # the route fires multiple Guild Role Update. - await _role_update_dispatch(role_1, guild_id) - await _role_update_dispatch(role_2, guild_id) - - -@bp.route('//roles', methods=['PATCH']) -async def update_guild_role_positions(guild_id): - """Update the positions for a bunch of roles.""" - user_id = await token_check() - - # TODO: check MANAGE_ROLES - await guild_owner_check(user_id, guild_id) - - raw_j = await request.get_json() - - # we need to do this hackiness because thats - # cerberus for ya. - j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) - - # extract the list out - j = j['roles'] - print(j) - - all_roles = await app.storage.get_role_data(guild_id) - - # we'll have to calculate pairs of changing roles, - # then do the changes, etc. - roles_pos = {role['position']: int(role['id']) for role in all_roles} - new_positions = {role['id']: role['position'] for role in j} - - # always ignore people trying to change the @everyone role - # TODO: check if the user can even change the roles in the first place, - # preferrably when we have a proper perms system. - try: - new_positions.pop(guild_id) - except KeyError: - pass - - pairs = [] - - # we want to find pairs of (role_1, new_position_1) - # where new_position_1 is actually pointing to position_2 (for a role 2) - # AND we have (role_2, new_position_2) in the list of new_positions. - - # I hope the explanation went through. - - for change in j: - role_1, new_pos_1 = change['id'], change['position'] - - # check current pairs - # so we don't repeat a role - flag = False - - for pair in pairs: - if (role_1, new_pos_1) in pair: - flag = True - - # skip if found - if flag: - continue - - # find a role that is in that new position - role_2 = roles_pos.get(new_pos_1) - - # search role_2 in the new_positions list - new_pos_2 = new_positions.get(role_2) - - # if we found it, add it to the pairs array. - if new_pos_2: - pairs.append( - ((role_1, new_pos_1), (role_2, new_pos_2)) - ) - - await _role_pairs_update(guild_id, pairs) - - # return the list of all roles back - return jsonify(await app.storage.get_role_data(guild_id)) - - -@bp.route('//roles/', methods=['PATCH']) -async def update_guild_role(guild_id, role_id): - """Update a single role's information.""" - user_id = await token_check() - - # TODO: check MANAGE_ROLES - await guild_owner_check(user_id, guild_id) - - j = validate(await request.get_json(), ROLE_UPDATE) - - # we only update ints on the db, not Permissions - j['permissions'] = int(j['permissions']) - - for field in j: - await app.db.execute(f""" - UPDATE roles - SET {field} = $1 - WHERE roles.id = $2 AND roles.guild_id = $3 - """, j[field], role_id, guild_id) - - role = await _role_update_dispatch(role_id, guild_id) - return jsonify(role) - - -@bp.route('//members/', methods=['GET']) -async def get_guild_member(guild_id, member_id): - """Get a member's information in a guild.""" - user_id = await token_check() - await guild_check(user_id, guild_id) - member = await app.storage.get_single_member(guild_id, member_id) - return jsonify(member) - - -@bp.route('//members', methods=['GET']) -async def get_members(guild_id): - """Get members inside a guild.""" - user_id = await token_check() - await guild_check(user_id, guild_id) - - j = await request.get_json() - - limit, after = int(j.get('limit', 1)), j.get('after', 0) - - if limit < 1 or limit > 1000: - raise BadRequest('limit not in 1-1000 range') - - user_ids = await app.db.fetch(f""" - SELECT user_id - WHERE guild_id = $1, user_id > $2 - LIMIT {limit} - ORDER BY user_id ASC - """, guild_id, after) - - user_ids = [r[0] for r in user_ids] - members = await app.storage.get_member_multi(guild_id, user_ids) - return jsonify(members) - - -@bp.route('//members/', methods=['PATCH']) -async def modify_guild_member(guild_id, member_id): - """Modify a members' information in a guild.""" - j = await request.get_json() - - if 'nick' in j: - # TODO: check MANAGE_NICKNAMES - - await app.db.execute(""" - UPDATE members - SET nickname = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['nick'], member_id, guild_id) - - if 'mute' in j: - # TODO: check MUTE_MEMBERS - - await app.db.execute(""" - UPDATE members - SET muted = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['mute'], member_id, guild_id) - - if 'deaf' in j: - # TODO: check DEAFEN_MEMBERS - - await app.db.execute(""" - UPDATE members - SET deafened = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['deaf'], member_id, guild_id) - - if 'channel_id' in j: - # TODO: check MOVE_MEMBERS - # TODO: change the member's voice channel - pass - - member = await app.storage.get_member_data_one(guild_id, member_id) - member.pop('joined_at') - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ - 'guild_id': str(guild_id) - }, **member}) - - return '', 204 - - -@bp.route('//members/@me/nick', methods=['PATCH']) -async def update_nickname(guild_id): - """Update a member's nickname in a guild.""" - user_id = await token_check() - await guild_check(user_id, guild_id) - - j = await request.get_json() - - await app.db.execute(""" - UPDATE members - SET nickname = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['nick'], user_id, guild_id) - - member = await app.storage.get_member_data_one(guild_id, user_id) - member.pop('joined_at') - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ - 'guild_id': str(guild_id) - }, **member}) - - return j['nick'] - - async def remove_member(guild_id: int, member_id: int): """Do common tasks related to deleting a member from the guild, such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" diff --git a/run.py b/run.py index 98f674d..78ebb09 100644 --- a/run.py +++ b/run.py @@ -10,8 +10,18 @@ from logbook import StreamHandler, Logger from logbook.compat import redirect_logging import config -from litecord.blueprints import gateway, auth, users, guilds, channels, \ - webhooks, science, voice, invites, relationships, dms +from litecord.blueprints import ( + gateway, auth, users, guilds, channels, webhooks, science, + voice, invites, relationships, dms +) + +# those blueprints are separated from the "main" ones +# for code readability if people want to dig through +# the codebase. +from litecord.blueprints.guild import ( + guild_roles, guild_members, guild_channels +) + from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -50,7 +60,12 @@ bps = { auth: '/auth', users: '/users', relationships: '/users', + guilds: '/guilds', + guild_roles: '/guilds', + guild_members: '/guilds', + guild_channels: '/guilds', + channels: '/channels', webhooks: None, science: None, From 0e3db251b867f6ef3ef56c645f355c1cdbf76ccc Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 02:17:29 -0300 Subject: [PATCH 16/69] guild.channels: subscribe users to the newly created channel --- litecord/blueprints/guild/channels.py | 11 +++++++++++ litecord/pubsub/guild.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index f8c5132..f453b75 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -91,6 +91,17 @@ async def create_channel(guild_id): await create_guild_channel( guild_id, new_channel_id, channel_type, **j) + # TODO: do a better method + # subscribe the currently subscribed users to the new channel + # by getting all user ids and subscribing each one by one. + + # since GuildDispatcher calls Storage.get_channel_ids, + # it will subscribe all users to the newly created channel. + guild_pubsub = app.dispatcher.backends['guild'] + user_ids = guild_pubsub.state[guild_id] + for uid in user_ids: + await app.dispatcher.sub('guild', guild_id, uid) + chan = await app.storage.get_channel(new_channel_id) await app.dispatcher.dispatch_guild( guild_id, 'CHANNEL_CREATE', chan) diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 6613fcb..fcb0f05 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -46,6 +46,8 @@ class GuildDispatcher(DispatcherWithState): # when subbing a user to the guild, we should sub them # to every channel they have access to, in the guild. + # TODO: check for permissions + await self._chan_action('sub', guild_id, user_id) async def unsub(self, guild_id: int, user_id: int): From 660c8d43d9b3c64300265db00b03d9f460c9a9df Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 19:36:00 -0300 Subject: [PATCH 17/69] blueprints.guild.roles: move main pairs algorithm to gen_pairs --- litecord/blueprints/guild/roles.py | 134 ++++++++++++++++++++--------- 1 file changed, 94 insertions(+), 40 deletions(-) diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index d83b33f..c6cb7d2 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -1,3 +1,5 @@ +from typing import List, Dict, Any, Union + from quart import Blueprint, request, current_app as app, jsonify from litecord.auth import token_check @@ -124,6 +126,91 @@ async def _role_pairs_update(guild_id: int, pairs: list): await _role_update_dispatch(role_2, guild_id) +def gen_pairs(list_of_changes: List[Dict[str, int]], + current_state: Dict[int, int], + blacklist: List[int] = None) -> List[tuple]: + """Generate a list of pairs that, when applied to the database, + will generate the desired state given in list_of_changes. + + We must check if the given list_of_changes isn't overwriting an + element's (such as a role or a channel) position to an existing one, + without there having an already existing change for the other one. + + Here's a pratical explanation with roles: + + R1 (in position RP1) wants to be in the same position + as R2 (currently in position RP2). + + So, if we did the simpler approach, list_of_changes + would just contain the preferred change: (R1, RP2). + + With gen_pairs, there MUST be a (R2, RP1) in list_of_changes, + if there is, the given result in gen_pairs will be a pair + ((R1, RP2), (R2, RP1)) which is then used to actually + update the roles' positions in a transaction. + + Parameters + ---------- + list_of_changes: + A list of dictionaries with ``id`` and ``position`` + fields, describing the preferred changes. + current_state: + Dictionary containing the current state of the list + of elements (roles or channels). Points position + to element ID. + blacklist: + List of IDs that shouldn't be moved. + + Returns + ------- + list + List of swaps to do to achieve the preferred + state given by ``list_of_changes``. + """ + pairs = [] + blacklist = blacklist or [] + + preferred_state = {element['id']: element['position'] + for element in list_of_changes} + + for blacklisted_id in blacklist: + preferred_state.pop(blacklisted_id) + + # for each change, we must find a matching change + # in the same list, so we can make a swap pair + for change in list_of_changes: + element_1, new_pos_1 = change['id'], change['position'] + + # check current pairs + # so we don't repeat an element + flag = False + + for pair in pairs: + if (element_1, new_pos_1) in pair: + flag = True + + # skip if found + if flag: + continue + + # search if there is a role/channel in the + # position we want to change to + element_2 = current_state.get(new_pos_1) + + # if there is, is that existing channel being + # swapped to another position? + new_pos_2 = preferred_state.get(element_2) + + # if its being swapped to leave space, add it + # to the pairs list + if new_pos_2: + pairs.append( + ((element_1, new_pos_1), (element_2, new_pos_2)) + ) + + return pairs + + @bp.route('//roles', methods=['PATCH']) async def update_guild_role_positions(guild_id): """Update the positions for a bunch of roles.""" @@ -140,57 +227,24 @@ async def update_guild_role_positions(guild_id): # extract the list out j = j['roles'] - print(j) all_roles = await app.storage.get_role_data(guild_id) # we'll have to calculate pairs of changing roles, # then do the changes, etc. roles_pos = {role['position']: int(role['id']) for role in all_roles} - new_positions = {role['id']: role['position'] for role in j} - # always ignore people trying to change the @everyone role # TODO: check if the user can even change the roles in the first place, # preferrably when we have a proper perms system. - try: - new_positions.pop(guild_id) - except KeyError: - pass - pairs = [] + pairs = gen_pairs( + j, + roles_pos, - # we want to find pairs of (role_1, new_position_1) - # where new_position_1 is actually pointing to position_2 (for a role 2) - # AND we have (role_2, new_position_2) in the list of new_positions. - - # I hope the explanation went through. - - for change in j: - role_1, new_pos_1 = change['id'], change['position'] - - # check current pairs - # so we don't repeat a role - flag = False - - for pair in pairs: - if (role_1, new_pos_1) in pair: - flag = True - - # skip if found - if flag: - continue - - # find a role that is in that new position - role_2 = roles_pos.get(new_pos_1) - - # search role_2 in the new_positions list - new_pos_2 = new_positions.get(role_2) - - # if we found it, add it to the pairs array. - if new_pos_2: - pairs.append( - ((role_1, new_pos_1), (role_2, new_pos_2)) - ) + # always ignore people trying to change + # the @everyone's role position + [guild_id] + ) await _role_pairs_update(guild_id, pairs) From ba794de47ad9fea7b3461f76bf2e7302d303e1b5 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 20:36:04 -0300 Subject: [PATCH 18/69] channels: add implementation for change channel position --- litecord/blueprints/guild/channels.py | 70 ++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index f453b75..412ec86 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -5,9 +5,10 @@ from litecord.blueprints.checks import guild_check, guild_owner_check from litecord.snowflake import get_snowflake from litecord.errors import BadRequest from litecord.enums import ChannelType -# from litecord.schemas import ( -# validate, CHAN_UPDATE_POSITION -# ) +from litecord.schemas import ( + validate, ROLE_UPDATE_POSITION +) +from litecord.blueprints.guild.roles import gen_pairs bp = Blueprint('guild_channels', __name__) @@ -108,15 +109,70 @@ async def create_channel(guild_id): return jsonify(chan) +async def _chan_update_dispatch(guild_id: int, channel_id: int): + """Fetch new information about the channel and dispatch + a single CHANNEL_UPDATE event to the guild.""" + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan) + + +async def _do_single_swap(guild_id: int, pair: tuple): + """Do a single channel swap, dispatching + the CHANNEL_UPDATE events for after the swap""" + pair1, pair2 = pair + channel_1, new_pos_1 = pair1 + channel_2, new_pos_2 = pair2 + + # do the swap in a transaction. + conn = await app.db.acquire() + + async with conn.transaction(): + await conn.executemany(""" + UPDATE guild_channels + SET position = $1 + WHERE id = $2 AND guild_id = $3 + """, [ + (new_pos_1, channel_1, guild_id), + (new_pos_2, channel_2, guild_id)]) + + await _chan_update_dispatch(guild_id, channel_1) + await _chan_update_dispatch(guild_id, channel_2) + + +async def _do_channel_swaps(guild_id: int, swap_pairs: list): + """Swap channel pairs' positions, given the list + of pairs to do. + + Dispatches CHANNEL_UPDATEs to the guild. + """ + for pair in swap_pairs: + await _do_single_swap(guild_id, pair) + + @bp.route('//channels', methods=['PATCH']) async def modify_channel_pos(guild_id): + """Change positions of channels in a guild.""" user_id = await token_check() # TODO: check MANAGE_CHANNELS await guild_owner_check(user_id, guild_id) - # TODO: this route - # raw_j = await request.get_json() - # j = validate({'channels': raw_j}, CHAN_UPDATE_POSITION) + # same thing as guild.roles, so we use + # the same schema and all. + raw_j = await request.get_json() + j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) + j = j['roles'] - raise NotImplementedError + channels = await app.storage.get_channel_data(guild_id) + + channel_positions = {chan['position']: int(chan['id']) + for chan in channels} + + swap_pairs = gen_pairs( + j, + channel_positions + ) + + await _do_channel_swaps(guild_id, swap_pairs) + + return '', 204 From 72465abdcb3b0b203f32214a0eb5f0751ac0c4a8 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 20:44:59 -0300 Subject: [PATCH 19/69] channels: make sure we release the connection --- litecord/blueprints/guild/channels.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index 412ec86..21d89ec 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -135,6 +135,8 @@ async def _do_single_swap(guild_id: int, pair: tuple): (new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)]) + await app.db.release(conn) + await _chan_update_dispatch(guild_id, channel_1) await _chan_update_dispatch(guild_id, channel_2) From ce5b75921a7c83ba96b8b33b5e6eb11ca39b9145 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 21:50:15 -0300 Subject: [PATCH 20/69] members: add role change impl - lazy_guild: add to online/offline groups when role isnt hoisted - schemas: fix MEMBER_UPDATE.nick --- litecord/blueprints/guild/members.py | 63 +++++++++++++++++++++++++++- litecord/pubsub/lazy_guild.py | 5 ++- litecord/schemas.py | 2 +- litecord/storage.py | 4 +- 4 files changed, 68 insertions(+), 6 deletions(-) diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index 6fd2ad2..8d90435 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -3,6 +3,10 @@ from quart import Blueprint, request, current_app as app, jsonify from litecord.blueprints.auth import token_check from litecord.blueprints.checks import guild_check from litecord.errors import BadRequest +from litecord.schemas import ( + validate, MEMBER_UPDATE +) +from litecord.blueprints.checks import guild_owner_check bp = Blueprint('guild_members', __name__) @@ -42,10 +46,61 @@ async def get_members(guild_id): return jsonify(members) +async def _update_member_roles(guild_id: int, member_id: int, + wanted_roles: list): + """Update the roles a member has.""" + + # first, fetch all current roles + roles = await app.db.fetch(""" + SELECT role_id from member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + roles = [r['role_id'] for r in roles] + + roles = set(roles) + wanted_roles = set(wanted_roles) + + # first, we need to find all added roles: + # roles that are on wanted_roles but + # not on roles + added_roles = wanted_roles - roles + + # and then the removed roles + # which are roles in roles, but not + # in wanted_roles + removed_roles = roles - wanted_roles + + conn = await app.db.acquire() + + async with conn.transaction(): + # add roles + await app.db.executemany(""" + INSERT INTO member_roles (user_id, guild_id, role_id) + VALUES ($1, $2, $3) + """, [(member_id, guild_id, role_id) + for role_id in added_roles]) + + # remove roles + await app.db.executemany(""" + DELETE FROM member_roles + WHERE + user_id = $1 + AND guild_id = $2 + AND role_id = $3 + """, [(member_id, guild_id, role_id) + for role_id in removed_roles]) + + await app.db.release(conn) + + @bp.route('//members/', methods=['PATCH']) async def modify_guild_member(guild_id, member_id): """Modify a members' information in a guild.""" - j = await request.get_json() + user_id = await token_check() + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), MEMBER_UPDATE) if 'nick' in j: # TODO: check MANAGE_NICKNAMES @@ -75,10 +130,14 @@ async def modify_guild_member(guild_id, member_id): """, j['deaf'], member_id, guild_id) if 'channel_id' in j: - # TODO: check MOVE_MEMBERS + # TODO: check MOVE_MEMBERS and CONNECT to the channel # TODO: change the member's voice channel pass + if 'roles' in j: + # TODO: check permissions + await _update_member_roles(guild_id, member_id, j['roles']) + member = await app.storage.get_member_data_one(guild_id, member_id) member.pop('joined_at') diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 0bcf1f5..d010f78 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -151,7 +151,10 @@ class GuildMemberList: # this user has a best_role that isn't the # @everyone role, so we'll put them in the respective group - group_data[best_role_id].append(presence) + try: + group_data[best_role_id].append(presence) + except KeyError: + group_data[simple_group].append(presence) # go through each group and sort the resulting members by display name diff --git a/litecord/schemas.py b/litecord/schemas.py index 77147ab..822e7b6 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -284,7 +284,7 @@ ROLE_UPDATE_POSITION = { MEMBER_UPDATE = { 'nick': { - 'type': 'nickname', + 'type': 'username', 'minlength': 1, 'maxlength': 100, 'required': False, }, diff --git a/litecord/storage.py b/litecord/storage.py index 7337e54..0c141b3 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -164,7 +164,7 @@ class Storage: """, guild_id, member_id) async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: - members_roles = await self.db.fetch(""" + roles = await self.db.fetch(""" SELECT role_id::text FROM member_roles WHERE guild_id = $1 AND user_id = $2 @@ -173,7 +173,7 @@ class Storage: return { 'user': await self.get_user(member_id), 'nick': row['nickname'], - 'roles': [row[0] for row in members_roles], + 'roles': [r['role_id'] for r in roles], 'joined_at': row['joined_at'].isoformat(), 'deaf': row['deafened'], 'mute': row['muted'], From 8678c84355deab4432eb9d6854c16441bb0f7a2f Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 22:16:03 -0300 Subject: [PATCH 21/69] roles: add GET /api/v6/guilds/:id/roles - guild: add DELETE /api/v6/guilds/:id/bans/:uid --- litecord/blueprints/guild/roles.py | 22 +++++++++++++++++----- litecord/blueprints/guilds.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index c6cb7d2..aaece26 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -4,19 +4,31 @@ from quart import Blueprint, request, current_app as app, jsonify from litecord.auth import token_check -# from litecord.blueprints.checks import guild_check -from litecord.blueprints.checks import guild_owner_check -from litecord.snowflake import get_snowflake -from litecord.utils import dict_get - +from litecord.blueprints.checks import ( + guild_check, guild_owner_check +) from litecord.schemas import ( validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION ) +from litecord.snowflake import get_snowflake +from litecord.utils import dict_get + DEFAULT_EVERYONE_PERMS = 104324161 bp = Blueprint('guild_roles', __name__) +@bp.route('//roles', methods=['GET']) +async def get_guild_roles(guild_id): + """Get all roles in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + return jsonify( + await app.storage.get_role_data(guild_id) + ) + + async def create_role(guild_id, name: str, **kwargs): """Create a role in a guild.""" new_role_id = get_snowflake() diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 23ee91c..84eb129 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -112,6 +112,10 @@ async def create_guild(): # create the default @everyone role (everyone has it by default, # so we don't insert that in the table) + + # we also don't use create_role because the id of the role + # is the same as the id of the guild, and create_role + # generates a new snowflake. await app.db.execute(""" INSERT INTO roles (id, guild_id, name, position, permissions) VALUES ($1, $2, $3, $4, $5) @@ -314,6 +318,31 @@ async def create_ban(guild_id, member_id): return '', 204 +@bp.route('//bans/', methods=['DELETE']) +async def remove_ban(guild_id, banned_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(guild_id, user_id) + + res = await app.db.execute(""" + DELETE FROM bans + WHERE guild_id = $1 AND user_id = $@ + """, guild_id, banned_id) + + # we don't really need to dispatch GUILD_BAN_REMOVE + # when no bans were actually removed. + if res == 'DELETE 0': + return '', 204 + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(banned_id) + }) + + return '', 204 + + @bp.route('//messages/search') async def search_messages(guild_id): """Search messages in a guild. From bbc39a953aa811a7f125bc2dfd093b091f125f9c Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 22:20:36 -0300 Subject: [PATCH 22/69] roles: add DELETE /guilds/:gid/roles/:rid --- litecord/blueprints/guild/roles.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index aaece26..003b4e4 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -286,3 +286,30 @@ async def update_guild_role(guild_id, role_id): role = await _role_update_dispatch(role_id, guild_id) return jsonify(role) + + +@bp.route('//roles/', methods=['DELETE']) +async def delete_guild_role(guild_id, role_id): + """Delete a role. + + Dispatches GUILD_ROLE_DELETE. + """ + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + res = await app.db.execute(""" + DELETE FROM roles + WHERE guild_id = $1 AND id = $2 + """, guild_id, role_id) + + if res == 'DELETE 0': + return '', 204 + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', { + 'guild_id': str(guild_id), + 'role_id': str(role_id), + }) + + return '', 204 From 65b47e74bf55c91e1f91fe6d302f7e4bf266176c Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 22:27:05 -0300 Subject: [PATCH 23/69] guild: add mod blueprint to keep moderation-related routes in a single blueprint --- litecord/blueprints/guild/__init__.py | 1 + litecord/blueprints/guild/mod.py | 112 ++++++++++++++++++++++++++ litecord/blueprints/guilds.py | 107 ------------------------ run.py | 4 +- 4 files changed, 116 insertions(+), 108 deletions(-) create mode 100644 litecord/blueprints/guild/mod.py diff --git a/litecord/blueprints/guild/__init__.py b/litecord/blueprints/guild/__init__.py index f4f8356..36c36b6 100644 --- a/litecord/blueprints/guild/__init__.py +++ b/litecord/blueprints/guild/__init__.py @@ -1,3 +1,4 @@ from .roles import bp as guild_roles from .members import bp as guild_members from .channels import bp as guild_channels +from .mod import bp as guild_mod diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py new file mode 100644 index 0000000..92a427f --- /dev/null +++ b/litecord/blueprints/guild/mod.py @@ -0,0 +1,112 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import guild_owner_check + +bp = Blueprint('guild_moderation', __name__) + + +async def remove_member(guild_id: int, member_id: int): + """Do common tasks related to deleting a member from the guild, + such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" + + await app.db.execute(""" + DELETE FROM members + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', { + 'guild_id': guild_id, + 'unavailable': False, + }) + + await app.dispatcher.unsub('guild', guild_id, member_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(member_id), + }) + + +@bp.route('//members/', methods=['DELETE']) +async def kick_member(guild_id, member_id): + """Remove a member from a guild.""" + user_id = await token_check() + + # TODO: check KICK_MEMBERS permission + await guild_owner_check(user_id, guild_id) + await remove_member(guild_id, member_id) + return '', 204 + + +@bp.route('//bans', methods=['GET']) +async def get_bans(guild_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(user_id, guild_id) + + bans = await app.db.fetch(""" + SELECT user_id, reason + FROM bans + WHERE bans.guild_id = $1 + """, guild_id) + + res = [] + + for ban in bans: + res.append({ + 'reason': ban['reason'], + 'user': await app.storage.get_user(ban['user_id']) + }) + + return jsonify(res) + + +@bp.route('//bans/', methods=['PUT']) +async def create_ban(guild_id, member_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(user_id, guild_id) + + j = await request.get_json() + + await app.db.execute(""" + INSERT INTO bans (guild_id, user_id, reason) + VALUES ($1, $2, $3) + """, guild_id, member_id, j.get('reason', '')) + + await remove_member(guild_id, member_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(member_id) + }) + + return '', 204 + + +@bp.route('//bans/', methods=['DELETE']) +async def remove_ban(guild_id, banned_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(guild_id, user_id) + + res = await app.db.execute(""" + DELETE FROM bans + WHERE guild_id = $1 AND user_id = $@ + """, guild_id, banned_id) + + # we don't really need to dispatch GUILD_BAN_REMOVE + # when no bans were actually removed. + if res == 'DELETE 0': + return '', 204 + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(banned_id) + }) + + return '', 204 diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 84eb129..caaf319 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -11,7 +11,6 @@ from ..enums import ChannelType from ..schemas import ( validate, GUILD_CREATE, GUILD_UPDATE ) -from ..utils import dict_get from .channels import channel_ack from .checks import guild_check, guild_owner_check @@ -237,112 +236,6 @@ async def delete_guild(guild_id): return '', 204 -async def remove_member(guild_id: int, member_id: int): - """Do common tasks related to deleting a member from the guild, - such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" - - await app.db.execute(""" - DELETE FROM members - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) - - await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', { - 'guild_id': guild_id, - 'unavailable': False, - }) - - await app.dispatcher.unsub('guild', guild_id, member_id) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { - 'guild_id': str(guild_id), - 'user': await app.storage.get_user(member_id), - }) - - -@bp.route('//members/', methods=['DELETE']) -async def kick_member(guild_id, member_id): - """Remove a member from a guild.""" - user_id = await token_check() - - # TODO: check KICK_MEMBERS permission - await guild_owner_check(user_id, guild_id) - await remove_member(guild_id, member_id) - return '', 204 - - -@bp.route('//bans', methods=['GET']) -async def get_bans(guild_id): - user_id = await token_check() - - # TODO: check BAN_MEMBERS permission - await guild_owner_check(user_id, guild_id) - - bans = await app.db.fetch(""" - SELECT user_id, reason - FROM bans - WHERE bans.guild_id = $1 - """, guild_id) - - res = [] - - for ban in bans: - res.append({ - 'reason': ban['reason'], - 'user': await app.storage.get_user(ban['user_id']) - }) - - return jsonify(res) - - -@bp.route('//bans/', methods=['PUT']) -async def create_ban(guild_id, member_id): - user_id = await token_check() - - # TODO: check BAN_MEMBERS permission - await guild_owner_check(user_id, guild_id) - - j = await request.get_json() - - await app.db.execute(""" - INSERT INTO bans (guild_id, user_id, reason) - VALUES ($1, $2, $3) - """, guild_id, member_id, j.get('reason', '')) - - await remove_member(guild_id, member_id) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', { - 'guild_id': str(guild_id), - 'user': await app.storage.get_user(member_id) - }) - - return '', 204 - - -@bp.route('//bans/', methods=['DELETE']) -async def remove_ban(guild_id, banned_id): - user_id = await token_check() - - # TODO: check BAN_MEMBERS permission - await guild_owner_check(guild_id, user_id) - - res = await app.db.execute(""" - DELETE FROM bans - WHERE guild_id = $1 AND user_id = $@ - """, guild_id, banned_id) - - # we don't really need to dispatch GUILD_BAN_REMOVE - # when no bans were actually removed. - if res == 'DELETE 0': - return '', 204 - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', { - 'guild_id': str(guild_id), - 'user': await app.storage.get_user(banned_id) - }) - - return '', 204 - - @bp.route('//messages/search') async def search_messages(guild_id): """Search messages in a guild. diff --git a/run.py b/run.py index 78ebb09..e52002b 100644 --- a/run.py +++ b/run.py @@ -19,7 +19,7 @@ from litecord.blueprints import ( # for code readability if people want to dig through # the codebase. from litecord.blueprints.guild import ( - guild_roles, guild_members, guild_channels + guild_roles, guild_members, guild_channels, mod ) from litecord.gateway import websocket_handler @@ -65,6 +65,8 @@ bps = { guild_roles: '/guilds', guild_members: '/guilds', guild_channels: '/guilds', + # mod for moderation + mod: '/guilds', channels: '/channels', webhooks: None, From 9aa9b3839be54df8e7788db127311c6157d4d034 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 23:32:50 -0300 Subject: [PATCH 24/69] guild.mod: add GET /guilds/:gid/prune - users: fix get_me_guilds - run: fix importing mod blueprint --- litecord/blueprints/guild/mod.py | 49 ++++++++++++++++++++++++++++++++ litecord/blueprints/users.py | 4 ++- litecord/schemas.py | 4 +++ run.py | 5 ++-- 4 files changed, 58 insertions(+), 4 deletions(-) diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 92a427f..0e8c817 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -3,6 +3,8 @@ from quart import Blueprint, request, current_app as app, jsonify from litecord.blueprints.auth import token_check from litecord.blueprints.checks import guild_owner_check +from litecord.schemas import validate, GUILD_PRUNE + bp = Blueprint('guild_moderation', __name__) @@ -110,3 +112,50 @@ async def remove_ban(guild_id, banned_id): }) return '', 204 + + +async def get_prune(guild_id: int, days: int) -> list: + """Get all members in a guild that: + + - did not login in ``days`` days. + - don't have any roles. + """ + # a good solution would be in pure sql. + member_ids = await app.storage.fetch(f""" + SELECT id + FROM users + JOIN members + ON member.guild_id = $1 AND member.user_id = users.id + WHERE users.last_session < (now() - (interval '{days} days')) + """, guild_id) + + member_ids = [r['id'] for r in member_ids] + members = [] + + for member_id in member_ids: + role_count = await app.db.fetchval(""" + SELECT COUNT(*) + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + if role_count == 0: + members.append(member_id) + + return members + + +@bp.route('//prune', methods=['GET']) +async def get_guild_prune_count(guild_id): + user_id = await token_check() + + # TODO: check KICK_MEMBERS + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), GUILD_PRUNE) + days = j['days'] + member_ids = await get_prune(guild_id, days) + + return jsonify({ + 'pruned': len(member_ids), + }) diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 8ab4706..49d8cb6 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -219,9 +219,11 @@ async def get_me_guilds(): partial = await app.db.fetchrow(""" SELECT id::text, name, icon, owner_id FROM guilds - WHERE guild_id = $1 + WHERE guilds.id = $1 """, guild_id) + partial = dict(partial) + # TODO: partial['permissions'] partial['owner'] = partial['owner_id'] == user_id partial.pop('owner_id') diff --git a/litecord/schemas.py b/litecord/schemas.py index 822e7b6..182420c 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -508,3 +508,7 @@ GUILD_SETTINGS = { 'required': False, } } + +GUILD_PRUNE = { + 'days': {'type': 'number', 'coerce': int, 'min': 1} +} diff --git a/run.py b/run.py index e52002b..2144937 100644 --- a/run.py +++ b/run.py @@ -19,7 +19,7 @@ from litecord.blueprints import ( # for code readability if people want to dig through # the codebase. from litecord.blueprints.guild import ( - guild_roles, guild_members, guild_channels, mod + guild_roles, guild_members, guild_channels, guild_mod ) from litecord.gateway import websocket_handler @@ -65,8 +65,7 @@ bps = { guild_roles: '/guilds', guild_members: '/guilds', guild_channels: '/guilds', - # mod for moderation - mod: '/guilds', + guild_mod: '/guilds', channels: '/channels', webhooks: None, From 80c29265f389c2dd3556bba27e196738739fa770 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 23:40:35 -0300 Subject: [PATCH 25/69] schema.sql: add users.last_session - gateway.websocket: update users.last_session SQL for instances: ```sql ALTER TABLE users ADD COLUMN last_session timestamp without time zone default (now() at time zone 'utc'); ``` --- litecord/gateway/websocket.py | 7 +++++++ schema.sql | 3 +++ 2 files changed, 10 insertions(+) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 515eac6..e32482e 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -406,6 +406,13 @@ class GatewayWebsocket: # link the state to the user self.ext.state_manager.insert(self.state) + # update last_session + await self.ext.db.execute(""" + UPDATE users + SET last_session = (now() at time zone 'utc') + WHERE id = $1 + """, user_id) + await self.update_status(presence) await self.subscribe_all() await self.dispatch_ready() diff --git a/schema.sql b/schema.sql index e43ba45..b520d18 100644 --- a/schema.sql +++ b/schema.sql @@ -75,6 +75,9 @@ CREATE TABLE IF NOT EXISTS users ( phone varchar(60) DEFAULT '', password_hash text NOT NULL, + -- store the last time the user logged in via the gateway + last_session timestamp without time zone default (now() at time zone 'utc'), + PRIMARY KEY (id, username, discriminator) ); From 4d4b075de901cd80ee68da92a9f13391afb55dba Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 27 Oct 2018 23:48:28 -0300 Subject: [PATCH 26/69] auth: move last_session update from gateway to auth this should help in cases where the client has long-lived sessions (more than a day) --- litecord/auth.py | 11 +++++++++++ litecord/gateway/websocket.py | 7 ------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index fa8404b..9b5e9c7 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -35,6 +35,17 @@ async def raw_token_check(token, db=None): try: signer.unsign(token) log.debug('login for uid {} successful', user_id) + + # update the user's last_session field + # so that we can keep an exact track of activity, + # even on long-lived single sessions (that can happen + # with people leaving their clients open forever) + await db.execute(""" + UPDATE users + SET last_session = (now() at time zone 'utc') + WHERE id = $1 + """, user_id) + return user_id except BadSignature: log.warning('token failed for uid {}', user_id) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index e32482e..515eac6 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -406,13 +406,6 @@ class GatewayWebsocket: # link the state to the user self.ext.state_manager.insert(self.state) - # update last_session - await self.ext.db.execute(""" - UPDATE users - SET last_session = (now() at time zone 'utc') - WHERE id = $1 - """, user_id) - await self.update_status(presence) await self.subscribe_all() await self.dispatch_ready() From 5133aab849e627f24e59f660aafc1262e2140510 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 28 Oct 2018 00:01:59 -0300 Subject: [PATCH 27/69] guild.mod: fix get_prune --- litecord/blueprints/guild/mod.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 0e8c817..3d8cb3b 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -121,11 +121,11 @@ async def get_prune(guild_id: int, days: int) -> list: - don't have any roles. """ # a good solution would be in pure sql. - member_ids = await app.storage.fetch(f""" + member_ids = await app.db.fetch(f""" SELECT id FROM users JOIN members - ON member.guild_id = $1 AND member.user_id = users.id + ON members.guild_id = $1 AND members.user_id = users.id WHERE users.last_session < (now() - (interval '{days} days')) """, guild_id) From 8352a3cab472fb5e40822a479163a2306acc6df6 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 28 Oct 2018 00:15:56 -0300 Subject: [PATCH 28/69] guild.mod: add POST /guilds/:id/prune --- litecord/blueprints/guild/mod.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 3d8cb3b..1c865af 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -30,8 +30,14 @@ async def remove_member(guild_id: int, member_id: int): }) +async def remove_member_multi(guild_id: int, members: list): + """Remove multiple members.""" + for member_id in members: + await remove_member(guild_id, member_id) + + @bp.route('//members/', methods=['DELETE']) -async def kick_member(guild_id, member_id): +async def kick_guild_member(guild_id, member_id): """Remove a member from a guild.""" user_id = await token_check() @@ -159,3 +165,22 @@ async def get_guild_prune_count(guild_id): return jsonify({ 'pruned': len(member_ids), }) + + +@bp.route('//prune', methods=['POST']) +async def begin_guild_prune(guild_id): + user_id = await token_check() + + # TODO: check KICK_MEMBERS + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), GUILD_PRUNE) + days = j['days'] + member_ids = await get_prune(guild_id, days) + + app.loop.create_task(remove_member_multi(guild_id, member_ids)) + + return jsonify({ + 'pruned': len(member_ids) + }) + From f2d591367202adbeac2876995f08fe500748bcb3 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 28 Oct 2018 17:29:43 -0300 Subject: [PATCH 29/69] blueprints: split channels to channel.messages bp --- litecord/blueprints/channel/__init__.py | 1 + litecord/blueprints/channel/messages.py | 212 ++++++++++++++++++++++++ litecord/blueprints/channels.py | 206 +---------------------- litecord/blueprints/guild/mod.py | 1 - run.py | 6 + 5 files changed, 223 insertions(+), 203 deletions(-) create mode 100644 litecord/blueprints/channel/__init__.py create mode 100644 litecord/blueprints/channel/messages.py diff --git a/litecord/blueprints/channel/__init__.py b/litecord/blueprints/channel/__init__.py new file mode 100644 index 0000000..999ed70 --- /dev/null +++ b/litecord/blueprints/channel/__init__.py @@ -0,0 +1 @@ +from .messages import bp as channel_messages diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py new file mode 100644 index 0000000..bdcff92 --- /dev/null +++ b/litecord/blueprints/channel/messages.py @@ -0,0 +1,212 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from logbook import Logger + + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.blueprints.dms import try_dm_state +from litecord.errors import MessageNotFound, Forbidden +from litecord.enums import MessageType, ChannelType, GUILD_CHANS +from litecord.snowflake import get_snowflake +from litecord.schemas import validate, MESSAGE_CREATE + + +log = Logger(__name__) +bp = Blueprint('channel_messages', __name__) + + +@bp.route('//messages', methods=['GET']) +async def get_messages(channel_id): + user_id = await token_check() + await channel_check(user_id, channel_id) + + # TODO: before, after, around keys + + message_ids = await app.db.fetch(f""" + SELECT id + FROM messages + WHERE channel_id = $1 + ORDER BY id DESC + LIMIT 100 + """, channel_id) + + result = [] + + for message_id in message_ids: + msg = await app.storage.get_message(message_id['id']) + + if msg is None: + continue + + result.append(msg) + + log.info('Fetched {} messages', len(result)) + return jsonify(result) + + +@bp.route('//messages/', methods=['GET']) +async def get_single_message(channel_id, message_id): + user_id = await token_check() + await channel_check(user_id, channel_id) + + # TODO: check READ_MESSAGE_HISTORY permissions + message = await app.storage.get_message(message_id) + + if not message: + raise MessageNotFound() + + return jsonify(message) + + +async def _dm_pre_dispatch(channel_id, peer_id): + """Do some checks pre-MESSAGE_CREATE so we + make sure the receiving party will handle everything.""" + + # check the other party's dm_channel_state + + dm_state = await app.db.fetchval(""" + SELECT dm_id + FROM dm_channel_state + WHERE user_id = $1 AND dm_id = $2 + """, peer_id, channel_id) + + if dm_state: + # the peer already has the channel + # opened, so we don't need to do anything + return + + dm_chan = await app.storage.get_channel(channel_id) + + # dispatch CHANNEL_CREATE so the client knows which + # channel the future event is about + await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) + + # subscribe the peer to the channel + await app.dispatcher.sub('channel', channel_id, peer_id) + + # insert it on dm_channel_state so the client + # is subscribed on the future + await try_dm_state(peer_id, channel_id) + + +@bp.route('//messages', methods=['POST']) +async def create_message(channel_id): + user_id = await token_check() + ctype, guild_id = await channel_check(user_id, channel_id) + + j = validate(await request.get_json(), MESSAGE_CREATE) + message_id = get_snowflake() + + # TODO: check SEND_MESSAGES permission + # TODO: check connection to the gateway + + await app.db.execute( + """ + INSERT INTO messages (id, channel_id, author_id, content, tts, + mention_everyone, nonce, message_type) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + """, + message_id, + channel_id, + user_id, + j['content'], + + # TODO: check SEND_TTS_MESSAGES + j.get('tts', False), + + # TODO: check MENTION_EVERYONE permissions + '@everyone' in j['content'], + int(j.get('nonce', 0)), + MessageType.DEFAULT.value + ) + + payload = await app.storage.get_message(message_id) + + if ctype == ChannelType.DM: + # guild id here is the peer's ID. + await _dm_pre_dispatch(channel_id, guild_id) + + await app.dispatcher.dispatch('channel', channel_id, + 'MESSAGE_CREATE', payload) + + # TODO: dispatch the MESSAGE_CREATE to any mentioning user. + + if ctype == ChannelType.GUILD_TEXT: + for str_uid in payload['mentions']: + uid = int(str_uid) + + await app.db.execute(""" + UPDATE user_read_state + SET mention_count += 1 + WHERE user_id = $1 AND channel_id = $2 + """, uid, channel_id) + + return jsonify(payload) + + +@bp.route('//messages/', methods=['PATCH']) +async def edit_message(channel_id, message_id): + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + author_id = await app.db.fetchval(""" + SELECT author_id FROM messages + WHERE messages.id = $1 + """, message_id) + + if not author_id == user_id: + raise Forbidden('You can not edit this message') + + j = await request.get_json() + updated = 'content' in j or 'embed' in j + + if 'content' in j: + await app.db.execute(""" + UPDATE messages + SET content=$1 + WHERE messages.id = $2 + """, j['content'], message_id) + + # TODO: update embed + + message = await app.storage.get_message(message_id) + + # only dispatch MESSAGE_UPDATE if we actually had any update to start with + if updated: + await app.dispatcher.dispatch('channel', channel_id, + 'MESSAGE_UPDATE', message) + + return jsonify(message) + + +@bp.route('//messages/', methods=['DELETE']) +async def delete_message(channel_id, message_id): + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + author_id = await app.db.fetchval(""" + SELECT author_id FROM messages + WHERE messages.id = $1 + """, message_id) + + # TODO: MANAGE_MESSAGES permission check + if author_id != user_id: + raise Forbidden('You can not delete this message') + + await app.db.execute(""" + DELETE FROM messages + WHERE messages.id = $1 + """, message_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, + 'MESSAGE_DELETE', { + 'id': str(message_id), + 'channel_id': str(channel_id), + + # for lazy guilds + 'guild_id': str(guild_id), + }) + + return '', 204 diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2ab084f..e982007 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -4,13 +4,11 @@ from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger from ..auth import token_check -from ..snowflake import get_snowflake, snowflake_datetime -from ..enums import ChannelType, MessageType, GUILD_CHANS -from ..errors import Forbidden, ChannelNotFound, MessageNotFound -from ..schemas import validate, MESSAGE_CREATE +from ..snowflake import snowflake_datetime +from ..enums import ChannelType, GUILD_CHANS +from ..errors import ChannelNotFound -from .checks import channel_check, guild_check -from .dms import try_dm_state +from .checks import channel_check log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -215,202 +213,6 @@ async def close_channel(channel_id): return '', 404 -@bp.route('//messages', methods=['GET']) -async def get_messages(channel_id): - user_id = await token_check() - await channel_check(user_id, channel_id) - - # TODO: before, after, around keys - - message_ids = await app.db.fetch(f""" - SELECT id - FROM messages - WHERE channel_id = $1 - ORDER BY id DESC - LIMIT 100 - """, channel_id) - - result = [] - - for message_id in message_ids: - msg = await app.storage.get_message(message_id['id']) - - if msg is None: - continue - - result.append(msg) - - log.info('Fetched {} messages', len(result)) - return jsonify(result) - - -@bp.route('//messages/', methods=['GET']) -async def get_single_message(channel_id, message_id): - user_id = await token_check() - await channel_check(user_id, channel_id) - - # TODO: check READ_MESSAGE_HISTORY permissions - message = await app.storage.get_message(message_id) - - if not message: - raise MessageNotFound() - - return jsonify(message) - - -async def _dm_pre_dispatch(channel_id, peer_id): - """Do some checks pre-MESSAGE_CREATE so we - make sure the receiving party will handle everything.""" - - # check the other party's dm_channel_state - - dm_state = await app.db.fetchval(""" - SELECT dm_id - FROM dm_channel_state - WHERE user_id = $1 AND dm_id = $2 - """, peer_id, channel_id) - - if dm_state: - # the peer already has the channel - # opened, so we don't need to do anything - return - - dm_chan = await app.storage.get_channel(channel_id) - - # dispatch CHANNEL_CREATE so the client knows which - # channel the future event is about - await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) - - # subscribe the peer to the channel - await app.dispatcher.sub('channel', channel_id, peer_id) - - # insert it on dm_channel_state so the client - # is subscribed on the future - await try_dm_state(peer_id, channel_id) - - -@bp.route('//messages', methods=['POST']) -async def create_message(channel_id): - user_id = await token_check() - ctype, guild_id = await channel_check(user_id, channel_id) - - j = validate(await request.get_json(), MESSAGE_CREATE) - message_id = get_snowflake() - - # TODO: check SEND_MESSAGES permission - # TODO: check connection to the gateway - - await app.db.execute( - """ - INSERT INTO messages (id, channel_id, author_id, content, tts, - mention_everyone, nonce, message_type) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """, - message_id, - channel_id, - user_id, - j['content'], - - # TODO: check SEND_TTS_MESSAGES - j.get('tts', False), - - # TODO: check MENTION_EVERYONE permissions - '@everyone' in j['content'], - int(j.get('nonce', 0)), - MessageType.DEFAULT.value - ) - - payload = await app.storage.get_message(message_id) - - if ctype == ChannelType.DM: - # guild id here is the peer's ID. - await _dm_pre_dispatch(channel_id, guild_id) - - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_CREATE', payload) - - # TODO: dispatch the MESSAGE_CREATE to any mentioning user. - - if ctype == ChannelType.GUILD_TEXT: - for str_uid in payload['mentions']: - uid = int(str_uid) - - await app.db.execute(""" - UPDATE user_read_state - SET mention_count += 1 - WHERE user_id = $1 AND channel_id = $2 - """, uid, channel_id) - - return jsonify(payload) - - -@bp.route('//messages/', methods=['PATCH']) -async def edit_message(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - author_id = await app.db.fetchval(""" - SELECT author_id FROM messages - WHERE messages.id = $1 - """, message_id) - - if not author_id == user_id: - raise Forbidden('You can not edit this message') - - j = await request.get_json() - updated = 'content' in j or 'embed' in j - - if 'content' in j: - await app.db.execute(""" - UPDATE messages - SET content=$1 - WHERE messages.id = $2 - """, j['content'], message_id) - - # TODO: update embed - - message = await app.storage.get_message(message_id) - - # only dispatch MESSAGE_UPDATE if we actually had any update to start with - if updated: - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_UPDATE', message) - - return jsonify(message) - - -@bp.route('//messages/', methods=['DELETE']) -async def delete_message(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - author_id = await app.db.fetchval(""" - SELECT author_id FROM messages - WHERE messages.id = $1 - """, message_id) - - # TODO: MANAGE_MESSAGES permission check - if author_id != user_id: - raise Forbidden('You can not delete this message') - - await app.db.execute(""" - DELETE FROM messages - WHERE messages.id = $1 - """, message_id) - - await app.dispatcher.dispatch( - 'channel', channel_id, - 'MESSAGE_DELETE', { - 'id': str(message_id), - 'channel_id': str(channel_id), - - # for lazy guilds - 'guild_id': str(guild_id), - }) - - return '', 204 - - @bp.route('//pins', methods=['GET']) async def get_pins(channel_id): user_id = await token_check() diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 1c865af..0461da3 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -183,4 +183,3 @@ async def begin_guild_prune(guild_id): return jsonify({ 'pruned': len(member_ids) }) - diff --git a/run.py b/run.py index 2144937..1e5c936 100644 --- a/run.py +++ b/run.py @@ -22,6 +22,10 @@ from litecord.blueprints.guild import ( guild_roles, guild_members, guild_channels, guild_mod ) +from litecord.blueprints.channel import ( + channel_messages +) + from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -68,6 +72,8 @@ bps = { guild_mod: '/guilds', channels: '/channels', + channel_messages: '/channels', + webhooks: None, science: None, voice: '/voice', From 2f822408b91aa33ad73478425ede1c36ba8c7fd7 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 15:26:31 -0300 Subject: [PATCH 30/69] channel.messages: add implementation for before, after and around for GET /api/v6/channels/:id/messages. --- litecord/blueprints/channel/messages.py | 43 ++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index bdcff92..4663693 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -6,7 +6,7 @@ from logbook import Logger from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check from litecord.blueprints.dms import try_dm_state -from litecord.errors import MessageNotFound, Forbidden +from litecord.errors import MessageNotFound, Forbidden, BadRequest from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE @@ -16,19 +16,54 @@ log = Logger(__name__) bp = Blueprint('channel_messages', __name__) +def query_tuple_from_args(args: dict, limit: int) -> tuple: + before, after = None, None + + if 'around' in request.args: + average = int(limit / 2) + around = int(request.args['around']) + + after = around - average + before = around + average + + elif 'before' in request.args: + before = int(request.args['before']) + elif 'after' in request.args: + before = int(request.args['after']) + + return before, after + + @bp.route('//messages', methods=['GET']) async def get_messages(channel_id): user_id = await token_check() + + # TODO: check READ_MESSAGE_HISTORY permission await channel_check(user_id, channel_id) - # TODO: before, after, around keys + try: + limit = int(request.args.get('limit', 50)) + + if limit not in range(0, 100): + raise ValueError() + except (TypeError, ValueError): + raise BadRequest('limit not int') + + where_clause = '' + before, after = query_tuple_from_args(request.args, limit) + + if before: + where_clause += f'AND id < {before}' + + if after: + where_clause += f'AND id > {after}' message_ids = await app.db.fetch(f""" SELECT id FROM messages - WHERE channel_id = $1 + WHERE channel_id = $1 {where_clause} ORDER BY id DESC - LIMIT 100 + LIMIT {limit} """, channel_id) result = [] From 2b1f9489b73b023739ba0e5a0337e83c1d0cbb7f Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 18:29:07 -0300 Subject: [PATCH 31/69] channel: add reactions blueprint SQL for instances: ```sql DROP TABLE message_reactions; ``` Then rerun `schema.sql`. --- litecord/blueprints/channel/messages.py | 20 +- litecord/blueprints/channel/reactions.py | 224 +++++++++++++++++++++++ nginx.conf | 10 +- schema.sql | 7 +- 4 files changed, 246 insertions(+), 15 deletions(-) create mode 100644 litecord/blueprints/channel/reactions.py diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 4663693..7ea70fb 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -16,6 +16,18 @@ log = Logger(__name__) bp = Blueprint('channel_messages', __name__) +def extract_limit(request, default: int = 50): + try: + limit = int(request.args.get('limit', 50)) + + if limit not in range(0, 100): + raise ValueError() + except (TypeError, ValueError): + raise BadRequest('limit not int') + + return limit + + def query_tuple_from_args(args: dict, limit: int) -> tuple: before, after = None, None @@ -41,13 +53,7 @@ async def get_messages(channel_id): # TODO: check READ_MESSAGE_HISTORY permission await channel_check(user_id, channel_id) - try: - limit = int(request.args.get('limit', 50)) - - if limit not in range(0, 100): - raise ValueError() - except (TypeError, ValueError): - raise BadRequest('limit not int') + limit = extract_limit(request, 50) where_clause = '' before, after = query_tuple_from_args(request.args, limit) diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py new file mode 100644 index 0000000..2c01977 --- /dev/null +++ b/litecord/blueprints/channel/reactions.py @@ -0,0 +1,224 @@ +from enum import IntEnum + +from quart import Blueprint, request, current_app as app, jsonify +from logbook import Logger + + +from litecord.utils import async_map +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.blueprints.channel.messages import ( + query_tuple_from_args, extract_limit +) + +from litecord.errors import MessageNotFound, Forbidden, BadRequest +from litecord.enums import GUILD_CHANS + + +log = Logger(__name__) +bp = Blueprint('channel_reactions', __name__) + +BASEPATH = '//messages//reactions' + + +class EmojiType(IntEnum): + CUSTOM = 0 + UNICODE = 1 + + +def emoji_info_from_str(emoji: str) -> tuple: + """Extract emoji information from an emoji string + given on the reaction endpoints.""" + # custom emoji have an emoji of name:id + # unicode emoji just have the raw unicode. + + # try checking if the emoji is custom or unicode + emoji_type = 0 if ':' in emoji else 1 + emoji_type = EmojiType(emoji_type) + + # extract the emoji id OR the unicode value of the emoji + # depending if it is custom or not + emoji_id = (int(emoji.split(':')[1]) + if emoji_type == EmojiType.CUSTOM + else emoji) + + emoji_name = emoji.split(':')[0] + + return emoji_type, emoji_id, emoji_name + + +def _partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: + return { + 'id': None if emoji_type.UNICODE else emoji_id, + 'name': emoji_id if emoji_type.UNICODE else emoji_name + } + + +def _make_payload(user_id, channel_id, message_id, partial): + return { + 'user_id': str(user_id), + 'channel_id': str(channel_id), + 'message_id': str(message_id), + 'emoji': partial + } + + +@bp.route(f'{BASEPATH}//@me', methods=['PUT']) +async def add_reaction(channel_id: int, message_id: int, emoji: str): + """Put a reaction.""" + user_id = await token_check() + + # TODO: check READ_MESSAGE_HISTORY permission + # and ADD_REACTIONS. look on route docs. + ctype, guild_id = await channel_check(user_id, channel_id) + + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + + await app.db.execute( + """ + INSERT INTO message_reactions (message_id, user_id, + emoji_type, emoji_id, emoji_text) + VALUES ($1, $2, $3, $4, $5) + """, message_id, user_id, emoji_type, + + # if it is custom, we put the emoji_id on emoji_id + # column, if it isn't, we put it on emoji_text + # column. + emoji_id if emoji_type == EmojiType.CUSTOM else None, + emoji_id if emoji_type == EmojiType.UNICODE else None + ) + + partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + payload = _make_payload(user_id, channel_id, message_id, partial) + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_ADD', payload) + + return '', 204 + + +def _emoji_sql(emoji_type, emoji_id, emoji_name, param=4): + """Extract SQL clauses to search for specific emoji + in the message_reactions table.""" + param = f'${param}' + + # know which column to filter with + where_ext = (f'AND emoji_id = {param}' + if emoji_type == EmojiType.CUSTOM else + f'AND emoji_text = {param}') + + # which emoji to remove (custom or unicode) + main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name + + return where_ext, main_emoji + + +def _emoji_sql_simple(emoji: str, param=4): + """Simpler version of _emoji_sql for functions that + don't need the results from emoji_info_from_str.""" + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + return _emoji_sql(emoji_type, emoji_id, emoji_name, param) + + +async def remove_reaction(channel_id: int, message_id: int, + user_id: int, emoji: str): + ctype, guild_id = await channel_check(user_id, channel_id) + + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + where_ext, main_emoji = _emoji_sql(emoji_type, emoji_id, emoji_name) + + await app.db.execute( + f""" + DELETE FROM message_reactions + WHERE message_id = $1 + AND user_id = $2 + AND emoji_type = $3 + {where_ext} + """, message_id, user_id, emoji_type, main_emoji) + + partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + payload = _make_payload(user_id, channel_id, message_id, partial) + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload) + + +@bp.route(f'{BASEPATH}//@me', methods=['DELETE']) +async def remove_own_reaction(channel_id, message_id, emoji): + """Remove a reaction.""" + user_id = await token_check() + + await remove_reaction(channel_id, message_id, user_id, emoji) + + return '', 204 + + +@bp.route(f'{BASEPATH}//', methods=['DELETE']) +async def remove_user_reaction(channel_id, message_id, emoji, other_id): + """Remove a reaction made by another user.""" + await token_check() + + # TODO: check MANAGE_MESSAGES permission (and use user_id + # from token_check to do it) + await remove_reaction(channel_id, message_id, other_id, emoji) + + return '', 204 + + +@bp.route(f'{BASEPATH}/', methods=['GET']) +async def list_users_reaction(channel_id, message_id, emoji): + """Get the list of all users who reacted with a certain emoji.""" + user_id = await token_check() + + # this is not using either ctype or guild_id + # that are returned by channel_check + await channel_check(user_id, channel_id) + + limit = extract_limit(request, 25) + before, after = query_tuple_from_args(request.args, limit) + + before_clause = 'AND user_id < $2' if before else '' + after_clause = 'AND user_id > $3' if after else '' + + where_ext, main_emoji = _emoji_sql_simple(emoji, 4) + + rows = await app.db.fetch(f""" + SELECT user_id + FROM message_reactions + WHERE message_id = $1 {before_clause} {after_clause} {where_ext} + """, message_id, before, after, main_emoji) + + user_ids = [r['user_id'] for r in rows] + users = await async_map(app.storage.get_user, user_ids) + return jsonify(users) + + +@bp.route(f'{BASEPATH}', methods=['DELETE']) +async def remove_all_reactions(channel_id, message_id): + """Remove all reactions in a message.""" + user_id = await token_check() + + # TODO: check MANAGE_MESSAGES permission + ctype, guild_id = await channel_check(user_id, channel_id) + + await app.db.execute(""" + DELETE FROM message_reactions + WHERE message_id = $1 + """, message_id) + + payload = { + 'channel_id': str(channel_id), + 'message_id': str(message_id), + } + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload) diff --git a/nginx.conf b/nginx.conf index 9590d5c..d42d2df 100644 --- a/nginx.conf +++ b/nginx.conf @@ -5,13 +5,11 @@ server { location / { proxy_pass http://localhost:5000; } -} -# Main litecord websocket proxy. -server { - server_name websocket.somewhere; - - location / { + # if you don't want to keep the gateway + # domain as the main domain, you can + # keep a separate server block + location /ws { proxy_pass http://localhost:5001; # those options are required for websockets diff --git a/schema.sql b/schema.sql index b520d18..e8d101c 100644 --- a/schema.sql +++ b/schema.sql @@ -528,9 +528,12 @@ CREATE TABLE IF NOT EXISTS message_reactions ( message_id bigint REFERENCES messages (id), user_id bigint REFERENCES users (id), - -- since it can be a custom emote, or unicode emoji + -- emoji_type = 0 -> custom emoji + -- emoji_type = 1 -> unicode emoji + emoji_type int DEFAULT 0, emoji_id bigint REFERENCES guild_emoji (id), - emoji_text text NOT NULL, + emoji_text text, + PRIMARY KEY (message_id, user_id, emoji_id, emoji_text) ); From 378809bdd6d7ce6a31dd30ca39203afd5aacd0eb Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 19:14:41 -0300 Subject: [PATCH 32/69] all: misc fixes - gateway.websocket: fix etf dict decode - auth: better token parsing - auth: fix new_discrim calc - channel.messages: call _dm_pre_dispatch on get_messages - channels: fix get_pins - guilds: make sure guild owner has guild everyone role - invites: replace sub_guild to sub --- litecord/auth.py | 6 ++++- litecord/blueprints/auth.py | 2 +- litecord/blueprints/channel/messages.py | 9 ++++++- litecord/blueprints/channels.py | 2 +- litecord/blueprints/guilds.py | 6 +++++ litecord/blueprints/invites.py | 2 +- litecord/gateway/websocket.py | 32 ++++++++++++++++++++++++- litecord/storage.py | 4 ++++ 8 files changed, 57 insertions(+), 6 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index 9b5e9c7..af241b5 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -13,7 +13,11 @@ log = Logger(__name__) async def raw_token_check(token, db=None): db = db or app.db - user_id, _hmac = token.split('.') + + # just try by fragments instead of + # unpacking + fragments = token.split('.') + user_id = fragments[0] try: user_id = base64.b64decode(user_id.encode()) diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 17e77f3..ce5544e 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -65,7 +65,7 @@ async def register(): new_id = get_snowflake() - new_discrim = str(random.randint(1, 9999)) + new_discrim = random.randint(1, 9999) new_discrim = '%04d' % new_discrim pwd_hash = await hash_data(password) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 7ea70fb..231f83d 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -51,7 +51,13 @@ async def get_messages(channel_id): user_id = await token_check() # TODO: check READ_MESSAGE_HISTORY permission - await channel_check(user_id, channel_id) + ctype, peer_id = await channel_check(user_id, channel_id) + + if ctype == ChannelType.DM: + # make sure both parties will be subbed + # to a dm + await _dm_pre_dispatch(channel_id, user_id) + await _dm_pre_dispatch(channel_id, peer_id) limit = extract_limit(request, 50) @@ -166,6 +172,7 @@ async def create_message(channel_id): if ctype == ChannelType.DM: # guild id here is the peer's ID. + await _dm_pre_dispatch(channel_id, user_id) await _dm_pre_dispatch(channel_id, guild_id) await app.dispatcher.dispatch('channel', channel_id, diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index e982007..2259b3e 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -233,7 +233,7 @@ async def get_pins(channel_id): if message is not None: res.append(message) - return jsonify(message) + return jsonify(res) @bp.route('//pins/', methods=['PUT']) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index caaf319..cd40ec2 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -120,6 +120,12 @@ async def create_guild(): VALUES ($1, $2, $3, $4, $5) """, guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS) + # add the @everyone role to the guild creator + await app.db.execute(""" + INSERT INTO member_roles (user_id, guild_id, role_id) + VALUES ($1, $2, $3) + """, user_id, guild_id, guild_id) + # create a single #general channel. general_id = get_snowflake() diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 246b8d0..2172ea7 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -185,7 +185,7 @@ async def use_invite(invite_code): }) # subscribe new member to guild, so they get events n stuff - app.dispatcher.sub_guild(guild_id, user_id) + await app.dispatcher.sub('guild', guild_id, user_id) # tell the new member that theres the guild it just joined. # we use dispatch_user_guild so that we send the GUILD_CREATE diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 515eac6..90bbc07 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -44,8 +44,38 @@ def encode_etf(payload) -> str: return earl.pack(payload) +def _etf_decode_dict(data): + # NOTE: this is a very slow implementation to + # decode the dictionary. + + if isinstance(data, bytes): + return data.decode() + + if not isinstance(data, dict): + return data + + _copy = dict(data) + result = {} + + for key in _copy.keys(): + # assuming key is bytes rn. + new_k = key.decode() + + # maybe nested dicts, so... + result[new_k] = _etf_decode_dict(data[key]) + + return result + def decode_etf(data: bytes): - return earl.unpack(data) + res = earl.unpack(data) + + if isinstance(res, bytes): + return data.decode() + + if isinstance(res, dict): + return _etf_decode_dict(res) + + return res class GatewayWebsocket: diff --git a/litecord/storage.py b/litecord/storage.py index 0c141b3..61792da 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -173,6 +173,10 @@ class Storage: return { 'user': await self.get_user(member_id), 'nick': row['nickname'], + + # we don't send the @everyone role's id to + # the user since it is known that everyone has + # that role. 'roles': [r['role_id'] for r in roles], 'joined_at': row['joined_at'].isoformat(), 'deaf': row['deafened'], From 1db27a811fd95e66e3e6c1f671c49a5be6ebad3d Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 19:20:39 -0300 Subject: [PATCH 33/69] litecord.storage: proper fix for missing guild everyone role --- litecord/storage.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/litecord/storage.py b/litecord/storage.py index 61792da..2de2b99 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -170,6 +170,19 @@ class Storage: WHERE guild_id = $1 AND user_id = $2 """, guild_id, member_id) + roles = [r['role_id'] for r in roles] + + try: + roles.remove(str(guild_id)) + except ValueError: + # if the @everyone role isn't in, we add it + # to member_roles automatically (it won't + # be shown on the API, though). + await self.db.execute(""" + INSERT INTO member_roles (user_id, guild_id, role_id) + VALUES ($1, $2, $3) + """, member_id, guild_id, guild_id) + return { 'user': await self.get_user(member_id), 'nick': row['nickname'], @@ -177,7 +190,7 @@ class Storage: # we don't send the @everyone role's id to # the user since it is known that everyone has # that role. - 'roles': [r['role_id'] for r in roles], + 'roles': roles, 'joined_at': row['joined_at'].isoformat(), 'deaf': row['deafened'], 'mute': row['muted'], From db7fbdb95477cf78de9ac5c9e2677af9e5563fd9 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 20:11:23 -0300 Subject: [PATCH 34/69] run: load channel_reactions bp --- litecord/blueprints/channel/__init__.py | 1 + run.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/channel/__init__.py b/litecord/blueprints/channel/__init__.py index 999ed70..de5839a 100644 --- a/litecord/blueprints/channel/__init__.py +++ b/litecord/blueprints/channel/__init__.py @@ -1 +1,2 @@ from .messages import bp as channel_messages +from .reactions import bp as channel_reactions diff --git a/run.py b/run.py index 1e5c936..27753dd 100644 --- a/run.py +++ b/run.py @@ -23,7 +23,7 @@ from litecord.blueprints.guild import ( ) from litecord.blueprints.channel import ( - channel_messages + channel_messages, channel_reactions ) from litecord.gateway import websocket_handler @@ -73,6 +73,7 @@ bps = { channels: '/channels', channel_messages: '/channels', + channel_reactions: '/channels', webhooks: None, science: None, From 0f7ffaf717d6feb264621d00512dc16169476355 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 22:07:32 -0300 Subject: [PATCH 35/69] storage: add Storage.get_reactions This finishes basic reaction code (both inserting and putting a reaction). SQL for instances: ```sql DROP TABLE message_reactions; ``` Then rerun `schema.sql` - channel.reactions: fix partial_emoji - schema.sql: add message_reactions.react_ts and unique constraint instead of primary key --- litecord/blueprints/channel/messages.py | 8 +-- litecord/blueprints/channel/reactions.py | 17 +++--- litecord/storage.py | 75 ++++++++++++++++++++++-- schema.sql | 9 ++- 4 files changed, 90 insertions(+), 19 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 231f83d..1e54fdf 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -81,7 +81,7 @@ async def get_messages(channel_id): result = [] for message_id in message_ids: - msg = await app.storage.get_message(message_id['id']) + msg = await app.storage.get_message(message_id['id'], user_id) if msg is None: continue @@ -98,7 +98,7 @@ async def get_single_message(channel_id, message_id): await channel_check(user_id, channel_id) # TODO: check READ_MESSAGE_HISTORY permissions - message = await app.storage.get_message(message_id) + message = await app.storage.get_message(message_id, user_id) if not message: raise MessageNotFound() @@ -168,7 +168,7 @@ async def create_message(channel_id): MessageType.DEFAULT.value ) - payload = await app.storage.get_message(message_id) + payload = await app.storage.get_message(message_id, user_id) if ctype == ChannelType.DM: # guild id here is the peer's ID. @@ -218,7 +218,7 @@ async def edit_message(channel_id, message_id): # TODO: update embed - message = await app.storage.get_message(message_id) + message = await app.storage.get_message(message_id, user_id) # only dispatch MESSAGE_UPDATE if we actually had any update to start with if updated: diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index 2c01977..6db9e6b 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -47,10 +47,11 @@ def emoji_info_from_str(emoji: str) -> tuple: return emoji_type, emoji_id, emoji_name -def _partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: +def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: + print(emoji_type, emoji_id, emoji_name) return { - 'id': None if emoji_type.UNICODE else emoji_id, - 'name': emoji_id if emoji_type.UNICODE else emoji_name + 'id': None if emoji_type == EmojiType.UNICODE else emoji_id, + 'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id } @@ -88,7 +89,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): emoji_id if emoji_type == EmojiType.UNICODE else None ) - partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + partial = partial_emoji(emoji_type, emoji_id, emoji_name) payload = _make_payload(user_id, channel_id, message_id, partial) if ctype in GUILD_CHANS: @@ -100,7 +101,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): return '', 204 -def _emoji_sql(emoji_type, emoji_id, emoji_name, param=4): +def emoji_sql(emoji_type, emoji_id, emoji_name, param=4): """Extract SQL clauses to search for specific emoji in the message_reactions table.""" param = f'${param}' @@ -120,7 +121,7 @@ def _emoji_sql_simple(emoji: str, param=4): """Simpler version of _emoji_sql for functions that don't need the results from emoji_info_from_str.""" emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) - return _emoji_sql(emoji_type, emoji_id, emoji_name, param) + return emoji_sql(emoji_type, emoji_id, emoji_name, param) async def remove_reaction(channel_id: int, message_id: int, @@ -128,7 +129,7 @@ async def remove_reaction(channel_id: int, message_id: int, ctype, guild_id = await channel_check(user_id, channel_id) emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) - where_ext, main_emoji = _emoji_sql(emoji_type, emoji_id, emoji_name) + where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name) await app.db.execute( f""" @@ -139,7 +140,7 @@ async def remove_reaction(channel_id: int, message_id: int, {where_ext} """, message_id, user_id, emoji_type, main_emoji) - partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + partial = partial_emoji(emoji_type, emoji_id, emoji_name) payload = _make_payload(user_id, channel_id, message_id, partial) if ctype in GUILD_CHANS: diff --git a/litecord/storage.py b/litecord/storage.py index 2de2b99..313b4f2 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -5,6 +5,9 @@ from logbook import Logger from .enums import ChannelType, RelationshipType from .schemas import USER_MENTION, ROLE_MENTION +from litecord.blueprints.channel.reactions import ( + emoji_info_from_str, EmojiType, emoji_sql, partial_emoji +) log = Logger(__name__) @@ -553,7 +556,72 @@ class Storage: return res - async def get_message(self, message_id: int) -> Dict: + async def get_reactions(self, message_id: int, user_id=None) -> List: + """Get all reactions in a message.""" + reactions = await self.db.fetch(""" + SELECT user_id, emoji_type, emoji_id, emoji_text + FROM message_reactions + ORDER BY react_ts + """) + + # ordered list of emoji + emoji = [] + + # the current state of emoji info + react_stats = {} + + # to generate the list, we pass through all + # all reactions and insert them all. + + # we can't use a set() because that + # doesn't guarantee any order. + for row in reactions: + etype = EmojiType(row['emoji_type']) + eid, etext = row['emoji_id'], row['emoji_text'] + + # get the main key to use, given + # the emoji information + _, main_emoji = emoji_sql(etype, eid, etext) + + if main_emoji in emoji: + continue + + # maintain order (first reacted comes first + # on the reaction list) + emoji.append(main_emoji) + + react_stats[main_emoji] = { + 'count': 0, + 'me': False, + 'emoji': partial_emoji(etype, eid, etext) + } + + # then the 2nd pass, where we insert + # the info for each reaction in the react_stats + # dictionary + for row in reactions: + etype = EmojiType(row['emoji_type']) + eid, etext = row['emoji_id'], row['emoji_text'] + + # same thing as the last loop, + # extracting main key + _, main_emoji = emoji_sql(etype, eid, etext) + + stats = react_stats[main_emoji] + stats['count'] += 1 + + print(row['user_id'], user_id) + if row['user_id'] == user_id: + stats['me'] = True + + # after processing reaction counts, + # we get them in the same order + # they were defined in the first loop. + print(emoji) + print(react_stats) + return list(map(react_stats.get, emoji)) + + async def get_message(self, message_id: int, user_id=None) -> Dict: """Get a single message's payload.""" row = await self.db.fetchrow(""" SELECT id::text, channel_id::text, author_id, content, @@ -614,6 +682,8 @@ class Storage: res['mention_roles'] = await self._msg_regex( ROLE_MENTION, _get_role_mention, content) + res['reactions'] = await self.get_reactions(message_id, user_id) + # TODO: handle webhook authors res['author'] = await self.get_user(res['author_id']) res.pop('author_id') @@ -624,9 +694,6 @@ class Storage: # TODO: res['embeds'] res['embeds'] = [] - # TODO: res['reactions'] - res['reactions'] = [] - # TODO: res['pinned'] res['pinned'] = False diff --git a/schema.sql b/schema.sql index e8d101c..3a654c3 100644 --- a/schema.sql +++ b/schema.sql @@ -528,15 +528,18 @@ CREATE TABLE IF NOT EXISTS message_reactions ( message_id bigint REFERENCES messages (id), user_id bigint REFERENCES users (id), + react_ts timestamp without time zone default (now() at time zone 'utc'), + -- emoji_type = 0 -> custom emoji -- emoji_type = 1 -> unicode emoji emoji_type int DEFAULT 0, emoji_id bigint REFERENCES guild_emoji (id), - emoji_text text, - - PRIMARY KEY (message_id, user_id, emoji_id, emoji_text) + emoji_text text ); +ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq + UNIQUE (message_id, user_id, emoji_id, emoji_text); + CREATE TABLE IF NOT EXISTS channel_pins ( channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, message_id bigint REFERENCES messages (id) ON DELETE CASCADE, From afb429ec7726fcffeb9787ed37a0d20797c21ccd Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 3 Nov 2018 20:58:01 -0300 Subject: [PATCH 36/69] gateway: change from 1200 guilds per shard to 1000 to more closely reproduce discord. --- litecord/blueprints/gateway.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py index 9f1c4df..76567aa 100644 --- a/litecord/blueprints/gateway.py +++ b/litecord/blueprints/gateway.py @@ -6,12 +6,14 @@ bp = Blueprint('gateway', __name__) def get_gw(): + """Get the gateway's web""" proto = 'wss://' if app.config['IS_SSL'] else 'ws://' return f'{proto}{app.config["WEBSOCKET_URL"]}/ws' @bp.route('/gateway') def api_gateway(): + """Get the raw URL.""" return jsonify({ 'url': get_gw() }) @@ -27,8 +29,9 @@ async def api_gateway_bot(): WHERE user_id = $1 """, user_id) - shards = max(int(guild_count / 1200), 1) + shards = max(int(guild_count / 1000), 1) + # TODO: session_start_limit (depends on ratelimits) return jsonify({ 'url': get_gw(), 'shards': shards, From 69fbd9c117aa48a9236838cd9b9d3379a5497518 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 3 Nov 2018 21:58:51 -0300 Subject: [PATCH 37/69] gateway.state_manager: send OP 7 Reconnect to clients - gateway.websocket: check StateManager flags on new connections - gateway.websocket: cancel all tasks on GatewayWebsocket.wsp.tasks - run: call StateManager.gen_close_tasks() and StateManager.close() on app shutdown --- litecord/gateway/state_manager.py | 107 +++++++++++++++++++++++++++++- litecord/gateway/websocket.py | 12 ++++ run.py | 11 +++ 3 files changed, 127 insertions(+), 3 deletions(-) diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index e393a79..5185b1d 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -1,18 +1,68 @@ +import asyncio + from typing import List, Dict, Any from collections import defaultdict +from websockets.exceptions import ConnectionClosed from logbook import Logger -from .state import GatewayState +from litecord.gateway.state import GatewayState +from litecord.gateway.opcodes import OP log = Logger(__name__) +class ManagerClose(Exception): + pass + + +class StateDictWrapper: + """Wrap a mapping so that any kind of access to the mapping while the + state manager is closed raises a ManagerClose error""" + def __init__(self, state_manager, mapping): + self.state_manager = state_manager + self._map = mapping + + def _check_closed(self): + if self.state_manager.closed: + raise ManagerClose() + + def __getitem__(self, key): + self._check_closed() + return self._map[key] + + def __delitem__(self, key): + self._check_closed() + del self._map[key] + + def __setitem__(self, key, value): + if not self.state_manager.accept_new: + raise ManagerClose() + + self._check_closed() + self._map[key] = value + + def __iter__(self): + return self._map.__iter__() + + def pop(self, key): + return self._map.pop(key) + + def values(self): + return self._map.values() + + class StateManager: """Manager for gateway state information.""" def __init__(self): + #: closed flag + self.closed = False + + #: accept new states? + self.accept_new = True + # { # user_id: { # session_id: GatewayState, @@ -20,10 +70,10 @@ class StateManager: # }, # user_id_2: {}, ... # } - self.states = defaultdict(dict) + self.states = StateDictWrapper(self, defaultdict(dict)) #: raw mapping from session ids to GatewayState - self.states_raw = {} + self.states_raw = StateDictWrapper(self, {}) def insert(self, state: GatewayState): """Insert a new state object.""" @@ -113,3 +163,54 @@ class StateManager: states.extend(member_states) return states + + async def shutdown_single(self, state: GatewayState): + """Send OP Reconnect to a single connection.""" + websocket = state.ws + + await websocket.send({ + 'op': OP.RECONNECT + }) + + # wait 200ms + # so that the client has time to process + # our payload then close the connection + await asyncio.sleep(0.2) + + try: + # try to close the connection ourselves + await websocket.ws.close( + code=4000, + reason='litecord shutting down' + ) + except ConnectionClosed: + log.info('client {} already closed', state) + + def gen_close_tasks(self): + """Generate the tasks that will order the clients + to reconnect. + + This is required to be ran before :meth:`StateManager.close`, + since this function doesn't wait for the tasks to complete. + """ + + self.accept_new = False + + #: store the shutdown tasks + tasks = [] + + for state in self.states_raw.values(): + if not state.ws: + continue + + tasks.append( + self.shutdown_single(state) + ) + + log.info('made {} shutdown tasks', len(tasks)) + + return tasks + + def close(self): + """Close the state manager.""" + self.closed = True diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 90bbc07..0d54ad3 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -753,6 +753,15 @@ class GatewayWebsocket: async def listen_messages(self): """Listen for messages coming in from the websocket.""" + + # close anyone trying to login while the + # server is shutting down + if self.ext.state_manager.closed: + raise WebsocketClose(4000, 'state manager closed') + + if not self.ext.state_manager.accept_new: + raise WebsocketClose(4000, 'state manager closed for new') + while True: message = await self.ws.recv() if len(message) > 4096: @@ -762,6 +771,9 @@ class GatewayWebsocket: await self.process_message(payload) def _cleanup(self): + for task in self.wsp.tasks.values(): + task.cancel() + if self.state: self.ext.state_manager.remove(self.state) self.state.ws = None diff --git a/run.py b/run.py index 27753dd..6f9f748 100644 --- a/run.py +++ b/run.py @@ -131,6 +131,8 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. + + # TODO: pass just the app object await websocket_handler((app.db, app.state_manager, app.storage, app.loop, app.dispatcher, app.presence), ws, url) @@ -142,6 +144,15 @@ async def app_before_serving(): @app.after_serving async def app_after_serving(): + """Shutdown tasks for the server.""" + + # first close all clients, then close db + tasks = app.state_manager.gen_close_tasks() + if tasks: + await asyncio.wait(tasks, loop=app.loop) + + app.state_manager.close() + log.info('closing db') await app.db.close() From b17cfd46eb53cb9842741efe9c02d6da7c2ba87d Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 3 Nov 2018 22:48:36 -0300 Subject: [PATCH 38/69] run: add dummy ratelimiting handler --- litecord/ratelimits/main.py | 5 +++++ run.py | 7 +++++++ 2 files changed, 12 insertions(+) create mode 100644 litecord/ratelimits/main.py diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py new file mode 100644 index 0000000..4ab71d6 --- /dev/null +++ b/litecord/ratelimits/main.py @@ -0,0 +1,5 @@ +from quart import current_app as app, request + +async def ratelimit_handler(): + # dummy handler for future code + print(request.headers) diff --git a/run.py b/run.py index 6f9f748..195aa0c 100644 --- a/run.py +++ b/run.py @@ -15,6 +15,8 @@ from litecord.blueprints import ( voice, invites, relationships, dms ) +from litecord.ratelimits.main import ratelimit_handler + # those blueprints are separated from the "main" ones # for code readability if people want to dig through # the codebase. @@ -87,6 +89,11 @@ for bp, suffix in bps.items(): app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}') +@app.before_request +async def app_before_request(): + await ratelimit_handler() + + @app.after_request async def app_after_request(resp): origin = request.headers.get('Origin', '*') From 33f893c0ff5c92a5a037d45bd517016c65f127e5 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 02:06:40 -0300 Subject: [PATCH 39/69] all: add ratelimit implementation haven't tested yet, but it should work in theory. - gateway.websocket: add the 3 main ws ratelimits - litecord: add ratelimits package - ratelimits.main: add implementation - run: add app_set_ratelimit_headers --- litecord/errors.py | 4 ++ litecord/gateway/websocket.py | 30 ++++++++- litecord/ratelimits/bucket.py | 113 +++++++++++++++++++++++++++++++++ litecord/ratelimits/handler.py | 67 +++++++++++++++++++ litecord/ratelimits/main.py | 56 ++++++++++++++-- run.py | 26 +++++++- 6 files changed, 288 insertions(+), 8 deletions(-) create mode 100644 litecord/ratelimits/bucket.py create mode 100644 litecord/ratelimits/handler.py diff --git a/litecord/errors.py b/litecord/errors.py index fe4f130..e1fe1ac 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -41,6 +41,10 @@ class MessageNotFound(LitecordError): status_code = 404 +class Ratelimited(LitecordError): + status_code = 429 + + class WebsocketClose(Exception): @property def code(self): diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 0d54ad3..e03e12c 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple( ) WebsocketObjects = collections.namedtuple( - 'WebsocketObjects', 'db state_manager storage loop dispatcher presence' + 'WebsocketObjects', ('db', 'state_manager', 'storage', + 'loop', 'dispatcher', 'presence', 'ratelimiter') ) @@ -138,6 +139,11 @@ class GatewayWebsocket: else: await self.ws.send(encoded.decode()) + def _check_ratelimit(self, key: str, ratelimit_key: str): + ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') + bucket = ratelimit.get_bucket(ratelimit_key) + return bucket.update_rate_limit() + async def _hb_wait(self, interval: int): """Wait heartbeat""" # if the client heartbeats in time, @@ -342,6 +348,14 @@ class GatewayWebsocket: async def update_status(self, status: dict): """Update the status of the current websocket connection.""" + if not self.state: + return + + if self._check_ratelimit('presence', self.state.session_id): + # Presence Updates beyond the ratelimit + # are just silently dropped. + return + if status is None: status = { 'afk': False, @@ -395,6 +409,11 @@ class GatewayWebsocket: 'op': OP.HEARTBEAT_ACK, }) + async def _connect_ratelimit(self, user_id: int): + if self._check_ratelimit('connect', user_id): + await self.invalidate_session(False) + raise WebsocketClose(4009, 'You are being ratelimited.') + async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" try: @@ -414,6 +433,8 @@ class GatewayWebsocket: except (Unauthorized, Forbidden): raise WebsocketClose(4004, 'Authentication failed') + await self._connect_ratelimit(user_id) + bot = await self.ext.db.fetchval(""" SELECT bot FROM users WHERE id = $1 @@ -751,6 +772,10 @@ class GatewayWebsocket: await handler(payload) + async def _msg_ratelimit(self): + if self._check_ratelimit('messages', self.state.session_id): + raise WebsocketClose(4008, 'You are being ratelimited.') + async def listen_messages(self): """Listen for messages coming in from the websocket.""" @@ -767,6 +792,9 @@ class GatewayWebsocket: if len(message) > 4096: raise DecodeError('Payload length exceeded') + if self.state: + await self._msg_ratelimit() + payload = self.decoder(message) await self.process_message(payload) diff --git a/litecord/ratelimits/bucket.py b/litecord/ratelimits/bucket.py new file mode 100644 index 0000000..dabb0ae --- /dev/null +++ b/litecord/ratelimits/bucket.py @@ -0,0 +1,113 @@ +""" +main litecord ratelimiting code + + This code was copied from elixire's ratelimiting, + which in turn is a work on top of discord.py's ratelimiting. +""" +import time + + +class RatelimitBucket: + """Main ratelimit bucket class.""" + def __init__(self, tokens, second): + self.requests = tokens + self.second = second + + self._window = 0.0 + self._tokens = self.requests + self.retries = 0 + self._last = 0.0 + + def get_tokens(self, current): + """Get the current amount of available tokens.""" + if not current: + current = time.time() + + # by default, use _tokens + tokens = self._tokens + + # if current timestamp is above _window + seconds + # reset tokens to self.requests (default) + if current > self._window + self.second: + tokens = self.requests + + return tokens + + def update_rate_limit(self): + """Update current ratelimit state.""" + current = time.time() + self._last = current + self._tokens = self.get_tokens(current) + + # we are using the ratelimit for the first time + # so set current ratelimit window to right now + if self._tokens == self.requests: + self._window = current + + # Are we currently ratelimited? + if self._tokens == 0: + self.retries += 1 + return self.second - (current - self._window) + + # if not ratelimited, remove a token + self.retries = 0 + self._tokens -= 1 + + # if we got ratelimited after that token removal, + # set window to now + if self._tokens == 0: + self._window = current + + def reset(self): + """Reset current ratelimit to default state.""" + self._tokens = self.requests + self._last = 0.0 + self.retries = 0 + + def copy(self): + """Create a copy of this ratelimit. + + Used to manage multiple ratelimits to users. + """ + return RatelimitBucket(self.requests, + self.second) + + def __repr__(self): + return (f'') + + +class Ratelimit: + """Manages buckets.""" + def __init__(self, tokens, second, keys=None): + self._cache = {} + if keys is None: + keys = tuple() + self.keys = keys + self._cooldown = RatelimitBucket(tokens, second) + + def __repr__(self): + return (f'') + + def _verify_cache(self): + current = time.time() + dead_keys = [k for k, v in self._cache.items() + if current > v._last + v.second] + + for k in dead_keys: + del self._cache[k] + + def get_bucket(self, key) -> RatelimitBucket: + if not self._cooldown: + return None + + self._verify_cache() + + if key not in self._cache: + bucket = self._cooldown.copy() + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py new file mode 100644 index 0000000..db896bf --- /dev/null +++ b/litecord/ratelimits/handler.py @@ -0,0 +1,67 @@ +from quart import current_app as app, request, g + +from litecord.errors import Ratelimited +from litecord.auth import token_check, Unauthorized + + +async def _check_bucket(bucket): + retry_after = bucket.update_rate_limit() + + request.bucket = bucket + + if retry_after: + raise Ratelimited('You are being ratelimited.', { + 'retry_after': retry_after + }) + + +async def _handle_global(ratelimit): + """Global ratelimit is per-user.""" + try: + user_id = await token_check() + except Unauthorized: + user_id = request.remote_addr + + bucket = ratelimit.get_bucket(user_id) + await _check_bucket(bucket) + + +async def _handle_specific(ratelimit): + try: + user_id = await token_check() + except Unauthorized: + user_id = request.remote_addr + + # construct the key based on the ratelimit.keys + keys = ratelimit.keys + + # base key is the user id + key_components = [f'user_id:{user_id}'] + + for key in keys: + val = request.view_args[key] + key_components.append(f'{key}:{val}') + + bucket_key = ':'.join(key_components) + bucket = ratelimit.get_bucket(bucket_key) + await _check_bucket(bucket) + + +async def ratelimit_handler(): + """Main ratelimit handler. + + Decides on which ratelimit to use. + """ + rule = request.url_rule + + # rule.endpoint is composed of '.' + # and so we can use that to make routes with different + # methods have different ratelimits + rule_path = rule.endpoint + + try: + ratelimit = app.ratelimiter.get_ratelimit(rule_path) + await _handle_specific(ratelimit) + except KeyError: + ratelimit = app.ratelimiter.global_bucket + await _handle_global(ratelimit) diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index 4ab71d6..ffcc54d 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -1,5 +1,53 @@ -from quart import current_app as app, request +from litecord.ratelimits.bucket import Ratelimit -async def ratelimit_handler(): - # dummy handler for future code - print(request.headers) +""" +REST: + POST Message | 5/5s | per-channel + DELETE Message | 5/1s | per-channel + PUT/DELETE Reaction | 1/0.25s | per-channel + PATCH Member | 10/10s | per-guild + PATCH Member Nick | 1/1s | per-guild + PATCH Username | 2/3600s | per-account + |All Requests| | 50/1s | per-account +WS: + Gateway Connect | 1/5s | per-account + Presence Update | 5/60s | per-session + |All Sent Messages| | 120/60s | per-session +""" + +REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id')) + +RATELIMITS = { + 'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')), + 'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')), + + # all of those share the same bucket. + 'channel_reactions.add_reaction': REACTION_BUCKET, + 'channel_reactions.remove_own_reaction': REACTION_BUCKET, + 'channel_reactions.remove_user_reaction': REACTION_BUCKET, + + 'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')), + 'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')), + + # this only applies to username. + # 'users.patch_me': Ratelimit(2, 3600), + + '_ws.connect': Ratelimit(1, 5), + '_ws.presence': Ratelimit(5, 60), + '_ws.messages': Ratelimit(120, 60), +} + +class RatelimitManager: + """Manager for the bucket managers""" + def __init__(self): + self._ratelimiters = {} + self.global_bucket = Ratelimit(50, 1) + self._fill_rtl() + + def _fill_rtl(self): + for path, rtl in RATELIMITS.items(): + self._ratelimiters[path] = rtl + + def get_ratelimit(self, key: str) -> Ratelimit: + """Get the :class:`Ratelimit` instance for a given path.""" + return self._ratelimiters.get(key, self.global_bucket) diff --git a/run.py b/run.py index 195aa0c..ff3e623 100644 --- a/run.py +++ b/run.py @@ -9,14 +9,14 @@ from quart import Quart, g, jsonify, request from logbook import StreamHandler, Logger from logbook.compat import redirect_logging +# import the config set by instance owner import config + from litecord.blueprints import ( gateway, auth, users, guilds, channels, webhooks, science, voice, invites, relationships, dms ) -from litecord.ratelimits.main import ratelimit_handler - # those blueprints are separated from the "main" ones # for code readability if people want to dig through # the codebase. @@ -28,6 +28,9 @@ from litecord.blueprints.channel import ( channel_messages, channel_reactions ) +from litecord.ratelimits.handler import ratelimit_handler +from litecord.ratelimits.main import RatelimitManager + from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -110,6 +113,21 @@ async def app_after_request(resp): # resp.headers['Access-Control-Allow-Methods'] = '*' resp.headers['Access-Control-Allow-Methods'] = \ resp.headers.get('allow', '*') + + return resp + + +@app.after_request +async def app_set_ratelimit_headers(resp): + """Set the specific ratelimit headers.""" + try: + bucket = request.bucket + resp.headers['X-RateLimit-Limit'] = str(bucket.requests) + resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens) + resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second) + except AttributeError: + pass + return resp @@ -123,6 +141,7 @@ async def app_before_serving(): app.loop = asyncio.get_event_loop() g.loop = asyncio.get_event_loop() + app.ratelimiter = RatelimitManager() app.state_manager = StateManager() app.storage = Storage(app.db) @@ -141,7 +160,8 @@ async def app_before_serving(): # TODO: pass just the app object await websocket_handler((app.db, app.state_manager, app.storage, - app.loop, app.dispatcher, app.presence), + app.loop, app.dispatcher, app.presence, + app.ratelimiter), ws, url) ws_future = websockets.serve(_wrapper, host, port) From a96b9c5e7f2122b59315d46f47d1bd726d511f7a Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 02:23:26 -0300 Subject: [PATCH 40/69] ratelimits.handler: five better retry_after and global flag - run: add X-RateLimit-Global and Retry-After headers --- litecord/ratelimits/handler.py | 14 ++++++++++++-- litecord/ratelimits/main.py | 2 +- run.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py index db896bf..715a283 100644 --- a/litecord/ratelimits/handler.py +++ b/litecord/ratelimits/handler.py @@ -10,8 +10,11 @@ async def _check_bucket(bucket): request.bucket = bucket if retry_after: - raise Ratelimited('You are being ratelimited.', { - 'retry_after': retry_after + request.retry_after = retry_after + + raise Ratelimited('You are being rate limited.', { + 'retry_after': int(retry_after * 1000), + 'global': request.bucket_global, }) @@ -22,6 +25,7 @@ async def _handle_global(ratelimit): except Unauthorized: user_id = request.remote_addr + request.bucket_global = True bucket = ratelimit.get_bucket(user_id) await _check_bucket(bucket) @@ -59,6 +63,12 @@ async def ratelimit_handler(): # methods have different ratelimits rule_path = rule.endpoint + # some request ratelimit context. + # TODO: maybe put those in a namedtuple or contextvar of sorts? + request.bucket = None + request.retry_after = None + request.bucket_global = False + try: ratelimit = app.ratelimiter.get_ratelimit(rule_path) await _handle_specific(ratelimit) diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index ffcc54d..f3e6bfb 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -41,7 +41,7 @@ class RatelimitManager: """Manager for the bucket managers""" def __init__(self): self._ratelimiters = {} - self.global_bucket = Ratelimit(50, 1) + self.global_bucket = Ratelimit(1, 1) self._fill_rtl() def _fill_rtl(self): diff --git a/run.py b/run.py index ff3e623..a8059c6 100644 --- a/run.py +++ b/run.py @@ -122,9 +122,20 @@ async def app_set_ratelimit_headers(resp): """Set the specific ratelimit headers.""" try: bucket = request.bucket + + if bucket is None: + raise AttributeError() + resp.headers['X-RateLimit-Limit'] = str(bucket.requests) resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens) resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second) + + resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower() + + # only add Retry-After if we actually hit a ratelimit + retry_after = request.retry_after + if request.retry_after: + resp.headers['Retry-After'] = str(retry_after) except AttributeError: pass From 1f5f736a8e5c46006f2e35f4d114f3e3b9de40b9 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 02:24:13 -0300 Subject: [PATCH 41/69] ratelimits.main: rollback global_bucket to 50/1 --- litecord/ratelimits/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index f3e6bfb..ffcc54d 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -41,7 +41,7 @@ class RatelimitManager: """Manager for the bucket managers""" def __init__(self): self._ratelimiters = {} - self.global_bucket = Ratelimit(1, 1) + self.global_bucket = Ratelimit(50, 1) self._fill_rtl() def _fill_rtl(self): From c710ad5aaf6aa56f137dbf510072fd0271283df2 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 16:36:11 -0300 Subject: [PATCH 42/69] gateway.websocket: add _ws.session ratelimit --- litecord/gateway/websocket.py | 4 ++++ litecord/ratelimits/main.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index e03e12c..6226fe6 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -414,6 +414,10 @@ class GatewayWebsocket: await self.invalidate_session(False) raise WebsocketClose(4009, 'You are being ratelimited.') + if self._check_ratelimit('session', user_id): + await self.invalidate_session(False) + raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.') + async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" try: diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index ffcc54d..10d219b 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -35,6 +35,9 @@ RATELIMITS = { '_ws.connect': Ratelimit(1, 5), '_ws.presence': Ratelimit(5, 60), '_ws.messages': Ratelimit(120, 60), + + # 1000 / 4h for new session issuing + '_ws.session': Ratelimit(1000, 14400) } class RatelimitManager: From 818571d336a63c6050bbe2c05de0491414dd72aa Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 16:43:38 -0300 Subject: [PATCH 43/69] gateway: add session_start_limit --- litecord/blueprints/gateway.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py index 76567aa..a301cac 100644 --- a/litecord/blueprints/gateway.py +++ b/litecord/blueprints/gateway.py @@ -1,3 +1,5 @@ +import time + from quart import Blueprint, jsonify, current_app as app from ..auth import token_check @@ -31,8 +33,23 @@ async def api_gateway_bot(): shards = max(int(guild_count / 1000), 1) - # TODO: session_start_limit (depends on ratelimits) + # get _ws.session ratelimit + ratelimit = app.ratelimiter.get_ratelimit('_ws.session') + bucket = ratelimit.get_bucket(user_id) + + # timestamp of bucket reset + reset_ts = bucket._window + bucket.second + + # how many seconds until bucket reset + reset_after_ts = reset_ts - time.time() + return jsonify({ 'url': get_gw(), 'shards': shards, + + 'session_start_limit': { + 'total': bucket.requests, + 'remaining': bucket._tokens, + 'reset_after': int(reset_after_ts * 1000), + } }) From 87dd70b4d9a16ae1cd60d0163219878006a2383f Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 17:54:48 -0300 Subject: [PATCH 44/69] permissions: add basic permission api - litecord.auth: insert request.user_id - storage: add get_member_role_ids --- litecord/auth.py | 10 +++- litecord/permissions.py | 129 ++++++++++++++++++++++++++++++++++++++++ litecord/storage.py | 18 ++++-- 3 files changed, 150 insertions(+), 7 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index af241b5..498aa59 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -58,6 +58,12 @@ async def raw_token_check(token, db=None): async def token_check(): """Check token information.""" + # first, check if the request info already has a uid + try: + return request.user_id + except AttributeError: + pass + try: token = request.headers['Authorization'] except KeyError: @@ -66,4 +72,6 @@ async def token_check(): if token.startswith('Bot '): token = token.replace('Bot ', '') - return await raw_token_check(token) + user_id = await raw_token_check(token) + request.user_id = user_id + return user_id diff --git a/litecord/permissions.py b/litecord/permissions.py index c5c5966..850759f 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -1,5 +1,7 @@ import ctypes +from quart import current_app as app, request + # so we don't keep repeating the same # type for all the fields _i = ctypes.c_uint8 @@ -55,3 +57,130 @@ class Permissions(ctypes.Union): def numby(self): return self.binary + + +ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) + + +async def base_permissions(member_id, guild_id) -> Permissions: + """Compute the base permissions for a given user. + + Base permissions are + (permissions from @everyone role) + + (permissions from any other role the member has) + + This will give ALL_PERMISSIONS if base permissions + has the Administrator bit set. + """ + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM guilds + WHERE id = $1 + """, guild_id) + + if owner_id == member_id: + return ALL_PERMISSIONS + + # get permissions for @everyone + everyone_perms = await app.db.fetchval(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 + """, guild_id) + + permissions = everyone_perms + + role_perms = await app.db.fetch(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + for perm_num in role_perms: + permissions.binary |= perm_num + + if permissions.bits.administrator: + return ALL_PERMISSIONS + + return permissions + + +def _mix(perms: Permissions, overwrite: dict) -> Permissions: + # we make a copy of the binary representation + # so we don't modify the old perms in-place + # which could be an unwanted side-effect + result = perms.binary + + # negate the permissions that are denied + result &= ~overwrite['deny'] + + # combine the permissions that are allowed + result |= overwrite['allow'] + + return Permissions(result) + + +def _overwrite_mix(perms: Permissions, overwrites: dict, + target_id: int) -> Permissions: + overwrite = overwrites.get(target_id) + + if overwrite: + # only mix if overwrite found + return _mix(perms, overwrite) + + return perms + + +async def compute_overwrites(base_perms, user_id, channel_id: int, + guild_id: int = None): + """Compute the permissions in the context of a channel.""" + + if base_perms.bits.administrator: + return ALL_PERMISSIONS + + perms = base_perms + + # list of overwrites + overwrites = await app.storage.chan_overwrites(channel_id) + + if not guild_id: + guild_id = await app.storage.guild_from_channel(channel_id) + + # make it a map for better usage + overwrites = {o['id']: o for o in overwrites} + + perms = _overwrite_mix(perms, overwrites, guild_id) + + # apply role specific overwrites + allow, deny = 0, 0 + + # fetch roles from user and convert to int + role_ids = await app.storage.get_member_role_ids(guild_id, user_id) + role_ids = map(int, role_ids) + + # make the allow and deny binaries + for role_id in role_ids: + overwrite = overwrites.get(role_id) + if overwrite: + allow |= overwrite['allow'] + deny |= overwrite['deny'] + + # final step for roles: mix + perms = _mix(perms, { + 'allow': allow, + 'deny': deny + }) + + # apply member specific overwrites + perms = _overwrite_mix(perms, overwrites, user_id) + + return perms + + +async def get_permissions(member_id, channel_id): + """Get all the permissions for a user in a channel.""" + guild_id = await app.storage.guild_from_channel(channel_id) + base_perms = await base_permissions(member_id, guild_id) + + return await compute_overwrites(base_perms, member_id, + channel_id, guild_id) diff --git a/litecord/storage.py b/litecord/storage.py index 313b4f2..8cb5f8a 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -166,7 +166,9 @@ class Storage: WHERE guild_id = $1 and user_id = $2 """, guild_id, member_id) - async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: + async def get_member_role_ids(self, guild_id: int, + member_id: int) -> List[int]: + """Get a list of role IDs that are on a member.""" roles = await self.db.fetch(""" SELECT role_id::text FROM member_roles @@ -186,6 +188,10 @@ class Storage: VALUES ($1, $2, $3) """, member_id, guild_id, guild_id) + return list(map(str, roles)) + + async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: + roles = await self.get_member_role_ids(guild_id, member_id) return { 'user': await self.get_user(member_id), 'nick': row['nickname'], @@ -309,7 +315,7 @@ class Storage: WHERE channels.id = $1 """, channel_id) - async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: + async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: overwrite_rows = await self.db.fetch(""" SELECT target_type, target_role, target_user, allow, deny FROM channel_overwrites @@ -355,8 +361,8 @@ class Storage: dbase['type'] = chan_type res = await self._channels_extra(dbase) - res['permission_overwrites'] = \ - list(await self._chan_overwrites(channel_id)) + res['permission_overwrites'] = await self.chan_overwrites( + channel_id) res['id'] = str(res['id']) return res @@ -421,8 +427,8 @@ class Storage: res = await self._channels_extra(drow) - res['permission_overwrites'] = \ - list(await self._chan_overwrites(row['id'])) + res['permission_overwrites'] = await self.chan_overwrites( + row['id']) # Making sure. res['id'] = str(res['id']) From da8b0491742ad35e7130e58362bf870bd7f3f0c0 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 19:55:21 -0300 Subject: [PATCH 45/69] channel.messages: add permission checks this commit only adds permission checking to most parts of the message endpoints. - channel.messages: fix extract_limit's default param - channel.messages: check send_messages, mention_everyone, send_tts_messages - channel.messages: check manage_messages - blueprints.checks: add guild_perm_check, channel_perm_check - errors: add error_code property, change some inheritance - permissions: fix base_permissions - storage: fix get_reactions - storage: remove print-debug - run: use error_code property when given --- litecord/blueprints/channel/messages.py | 34 ++++++++++++++++++------- litecord/blueprints/checks.py | 31 ++++++++++++++++++++-- litecord/errors.py | 16 +++++++----- litecord/permissions.py | 27 +++++++++++++++++--- litecord/storage.py | 6 ++--- run.py | 6 ++++- 6 files changed, 94 insertions(+), 26 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 1e54fdf..4e5c7b2 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -4,7 +4,7 @@ from logbook import Logger from litecord.blueprints.auth import token_check -from litecord.blueprints.checks import channel_check +from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.dms import try_dm_state from litecord.errors import MessageNotFound, Forbidden, BadRequest from litecord.enums import MessageType, ChannelType, GUILD_CHANS @@ -18,7 +18,7 @@ bp = Blueprint('channel_messages', __name__) def extract_limit(request, default: int = 50): try: - limit = int(request.args.get('limit', 50)) + limit = int(request.args.get('limit', default)) if limit not in range(0, 100): raise ValueError() @@ -142,12 +142,25 @@ async def create_message(channel_id): user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) + if ctype in GUILD_CHANS: + await channel_perm_check(user_id, channel_id, 'send_messages') + j = validate(await request.get_json(), MESSAGE_CREATE) message_id = get_snowflake() - # TODO: check SEND_MESSAGES permission # TODO: check connection to the gateway + mentions_everyone = ('@everyone' in j['content'] and + await channel_perm_check( + user_id, channel_id, 'mention_everyone', False + ) + ) + + is_tts = (j.get('tts', False) and + await channel_perm_check( + user_id, channel_id, 'send_tts_messages', False + )) + await app.db.execute( """ INSERT INTO messages (id, channel_id, author_id, content, tts, @@ -159,11 +172,9 @@ async def create_message(channel_id): user_id, j['content'], - # TODO: check SEND_TTS_MESSAGES - j.get('tts', False), + is_tts, + mentions_everyone, - # TODO: check MENTION_EVERYONE permissions - '@everyone' in j['content'], int(j.get('nonce', 0)), MessageType.DEFAULT.value ) @@ -238,8 +249,13 @@ async def delete_message(channel_id, message_id): WHERE messages.id = $1 """, message_id) - # TODO: MANAGE_MESSAGES permission check - if author_id != user_id: + by_perm = await channel_perm_check( + user_id, channel_id, 'manage_messages', False + ) + + by_ownership = author_id == user_id + + if not by_perm and not by_ownership: raise Forbidden('You can not delete this message') await app.db.execute(""" diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 17bd337..e051c11 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -1,7 +1,10 @@ from quart import current_app as app -from ..enums import ChannelType, GUILD_CHANS -from ..errors import GuildNotFound, ChannelNotFound, Forbidden +from litecord.enums import ChannelType, GUILD_CHANS +from litecord.errors import ( + GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions +) +from litecord.permissions import base_permissions, get_permissions async def guild_check(user_id: int, guild_id: int): @@ -54,3 +57,27 @@ async def channel_check(user_id, channel_id): if ctype == ChannelType.DM: peer_id = await app.storage.get_dm_peer(channel_id, user_id) return ctype, peer_id + + +async def guild_perm_check(user_id, guild_id, permission: str): + """Check guild permissions for a user.""" + base_perms = await base_permissions(user_id, guild_id) + hasperm = getattr(base_perms.bits, permission) + + if not hasperm: + raise MissingPermissions('Missing permissions.') + + +async def channel_perm_check(user_id, channel_id, + permission: str, raise_err=True): + """Check channel permissions for a user.""" + base_perms = await get_permissions(user_id, channel_id) + hasperm = getattr(base_perms.bits, permission) + + print(base_perms) + print(base_perms.binary) + + if not hasperm and raise_err: + raise MissingPermissions('Missing permissions.') + + return hasperm diff --git a/litecord/errors.py b/litecord/errors.py index e1fe1ac..70afcf9 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -29,22 +29,26 @@ class NotFound(LitecordError): status_code = 404 -class GuildNotFound(LitecordError): - status_code = 404 +class GuildNotFound(NotFound): + error_code = 10004 -class ChannelNotFound(LitecordError): - status_code = 404 +class ChannelNotFound(NotFound): + error_code = 10003 -class MessageNotFound(LitecordError): - status_code = 404 +class MessageNotFound(NotFound): + error_code = 10008 class Ratelimited(LitecordError): status_code = 429 +class MissingPermissions(Forbidden): + error_code = 50013 + + class WebsocketClose(Exception): @property def code(self): diff --git a/litecord/permissions.py b/litecord/permissions.py index 850759f..f68e259 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -52,6 +52,9 @@ class Permissions(ctypes.Union): def __init__(self, val: int): self.binary = val + def __repr__(self): + return f'' + def __int__(self): return self.binary @@ -88,14 +91,25 @@ async def base_permissions(member_id, guild_id) -> Permissions: WHERE guild_id = $1 """, guild_id) - permissions = everyone_perms + permissions = Permissions(everyone_perms) - role_perms = await app.db.fetch(""" - SELECT permissions - FROM roles + role_ids = await app.db.fetch(""" + SELECT role_id + FROM member_roles WHERE guild_id = $1 AND user_id = $2 """, guild_id, member_id) + role_perms = [] + + for row in role_ids: + rperm = await app.db.fetchval(""" + SELECT permissions + FROM roles + WHERE id = $1 + """, row['role_id']) + + role_perms.append(rperm) + for perm_num in role_perms: permissions.binary |= perm_num @@ -180,6 +194,11 @@ async def compute_overwrites(base_perms, user_id, channel_id: int, async def get_permissions(member_id, channel_id): """Get all the permissions for a user in a channel.""" guild_id = await app.storage.guild_from_channel(channel_id) + + # for non guild channels + if not guild_id: + return ALL_PERMISSIONS + base_perms = await base_permissions(member_id, guild_id) return await compute_overwrites(base_perms, member_id, diff --git a/litecord/storage.py b/litecord/storage.py index 8cb5f8a..ab1f226 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -567,8 +567,9 @@ class Storage: reactions = await self.db.fetch(""" SELECT user_id, emoji_type, emoji_id, emoji_text FROM message_reactions + WHERE message_id = $1 ORDER BY react_ts - """) + """, message_id) # ordered list of emoji emoji = [] @@ -616,15 +617,12 @@ class Storage: stats = react_stats[main_emoji] stats['count'] += 1 - print(row['user_id'], user_id) if row['user_id'] == user_id: stats['me'] = True # after processing reaction counts, # we get them in the same order # they were defined in the first loop. - print(emoji) - print(react_stats) return list(map(react_stats.get, emoji)) async def get_message(self, message_id: int, user_id=None) -> Dict: diff --git a/run.py b/run.py index a8059c6..bbdf701 100644 --- a/run.py +++ b/run.py @@ -202,9 +202,13 @@ async def handle_litecord_err(err): except IndexError: ejson = {} + try: + ejson['code'] = err.error_code + except AttributeError: + pass + return jsonify({ 'error': True, - # 'code': err.code, 'status': err.status_code, 'message': err.message, **ejson From 7ce59398c45b11178a388647fcd25e30297500e6 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 23:41:52 -0300 Subject: [PATCH 46/69] channels: use UPSERT on channel_ack --- litecord/blueprints/channels.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2259b3e..c3cb7dc 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -134,6 +134,7 @@ async def guild_cleanup(channel_id): @bp.route('/', methods=['DELETE']) async def close_channel(channel_id): + """Close or delete a channel.""" user_id = await token_check() chan_type = await app.storage.get_chan_type(channel_id) @@ -210,7 +211,7 @@ async def close_channel(channel_id): # TODO: group dm pass - return '', 404 + raise ChannelNotFound() @bp.route('//pins', methods=['GET']) @@ -320,21 +321,16 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): if not message_id: message_id = await app.storage.chan_last_message(channel_id) - res = await app.db.execute(""" - UPDATE user_read_state - - SET last_message_id = $1, - mention_count = 0 - - WHERE user_id = $2 AND channel_id = $3 - """, message_id, user_id, channel_id) - - if res == 'UPDATE 0': - await app.db.execute(""" - INSERT INTO user_read_state - (user_id, channel_id, last_message_id, mention_count) - VALUES ($1, $2, $3, $4) - """, user_id, channel_id, message_id, 0) + await app.db.execute(""" + INSERT INTO user_read_state + (user_id, channel_id, last_message_id, mention_count) + VALUES + ($1, $2, $3, 0) + ON CONFLICT + DO UPDATE user_read_state + SET last_message_id = $3, mention_count = 0 + WHERE user_id = $1 AND channel_id = $2 + """, user_id, channel_id, message_id) if guild_id: await app.dispatcher.dispatch_user_guild( @@ -353,6 +349,7 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): @bp.route('//messages//ack', methods=['POST']) async def ack_channel(channel_id, message_id): + """Acknowledge a channel.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) @@ -371,6 +368,7 @@ async def ack_channel(channel_id, message_id): @bp.route('//messages/ack', methods=['DELETE']) async def delete_read_state(channel_id): + """Delete the read state of a channel.""" user_id = await token_check() await channel_check(user_id, channel_id) From 03e42d9a4359b5490fce81145badfd4b57345b06 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 23:46:18 -0300 Subject: [PATCH 47/69] users: use UPSERT when setting chan overrides - channels: fix UPSERT --- litecord/blueprints/channels.py | 3 +-- litecord/blueprints/users.py | 25 +++++++++++-------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index c3cb7dc..dea2411 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -326,8 +326,7 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): (user_id, channel_id, last_message_id, mention_count) VALUES ($1, $2, $3, 0) - ON CONFLICT - DO UPDATE user_read_state + ON CONFLICT DO UPDATE SET last_message_id = $3, mention_count = 0 WHERE user_id = $1 AND channel_id = $2 """, user_id, channel_id, message_id) diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 49d8cb6..32bd8f0 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -446,20 +446,17 @@ async def patch_guild_settings(guild_id: int): continue for field in chan_overrides: - res = await app.db.execute(f""" - UPDATE guild_settings_channel_overrides - SET {field} = $1 - WHERE user_id = $2 - AND guild_id = $3 - AND channel_id = $4 - """, chan_overrides[field], user_id, guild_id, chan_id) - - if res == 'UPDATE 0': - await app.db.execute(f""" - INSERT INTO guild_settings_channel_overrides - (user_id, guild_id, channel_id, {field}) - VALUES ($1, $2, $3, $4) - """, user_id, guild_id, chan_id, chan_overrides[field]) + await app.db.execute(f""" + INSERT INTO guild_settings_channel_overrides + (user_id, guild_id, channel_id, {field}) + VALUES + ($1, $2, $3, $4) + ON CONFLICT DO UPDATE + SET {field} = $4 + WHERE user_id = $1 + AND guild_id = $2 + AND channel_id = $3 + """, user_id, guild_id, chan_id, chan_overrides[field]) settings = await app.storage.get_guild_settings_one(user_id, guild_id) From 22fe0f07c62005cb49e38845aa2b88d296e4771c Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Mon, 5 Nov 2018 00:15:59 -0300 Subject: [PATCH 48/69] channel: create channel.pins blueprint - schemas: add CHAN_UPDATE and CHAN_OVERWRITE --- litecord/blueprints/channel/__init__.py | 1 + litecord/blueprints/channel/pins.py | 93 +++++++++++++++++++++++++ litecord/blueprints/channels.py | 85 ---------------------- litecord/schemas.py | 51 ++++++++++++++ run.py | 3 +- 5 files changed, 147 insertions(+), 86 deletions(-) create mode 100644 litecord/blueprints/channel/pins.py diff --git a/litecord/blueprints/channel/__init__.py b/litecord/blueprints/channel/__init__.py index de5839a..4337684 100644 --- a/litecord/blueprints/channel/__init__.py +++ b/litecord/blueprints/channel/__init__.py @@ -1,2 +1,3 @@ from .messages import bp as channel_messages from .reactions import bp as channel_reactions +from .pins import bp as channel_pins diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py new file mode 100644 index 0000000..7d5b42b --- /dev/null +++ b/litecord/blueprints/channel/pins.py @@ -0,0 +1,93 @@ +from quart import Blueprint, current_app as app, request, jsonify + +from litecord.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.snowflake import snowflake_datetime + +bp = Blueprint('channel_pins', __name__) + + +@bp.route('//pins', methods=['GET']) +async def get_pins(channel_id): + """Get the pins for a channel""" + user_id = await token_check() + await channel_check(user_id, channel_id) + + ids = await app.db.fetch(""" + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + """, channel_id) + + ids = [r['message_id'] for r in ids] + res = [] + + for message_id in ids: + message = await app.storage.get_message(message_id) + if message is not None: + res.append(message) + + return jsonify(res) + + +@bp.route('//pins/', methods=['PUT']) +async def add_pin(channel_id, message_id): + """Add a pin to a channel""" + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + # TODO: check MANAGE_MESSAGES permission + + await app.db.execute(""" + INSERT INTO channel_pins (channel_id, message_id) + VALUES ($1, $2) + """, channel_id, message_id) + + row = await app.db.fetchrow(""" + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + LIMIT 1 + """, channel_id) + + timestamp = snowflake_datetime(row['message_id']) + + await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', { + 'channel_id': str(channel_id), + 'last_pin_timestamp': timestamp.isoformat() + }) + + return '', 204 + + +@bp.route('//pins/', methods=['DELETE']) +async def delete_pin(channel_id, message_id): + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + # TODO: check MANAGE_MESSAGES permission + + await app.db.execute(""" + DELETE FROM channel_pins + WHERE channel_id = $1 AND message_id = $2 + """, channel_id, message_id) + + row = await app.db.fetchrow(""" + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + LIMIT 1 + """, channel_id) + + timestamp = snowflake_datetime(row['message_id']) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'CHANNEL_PINS_UPDATE', { + 'channel_id': str(channel_id), + 'last_pin_timestamp': timestamp.isoformat() + }) + + return '', 204 diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index dea2411..85282ec 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -4,7 +4,6 @@ from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger from ..auth import token_check -from ..snowflake import snowflake_datetime from ..enums import ChannelType, GUILD_CHANS from ..errors import ChannelNotFound @@ -214,90 +213,6 @@ async def close_channel(channel_id): raise ChannelNotFound() -@bp.route('//pins', methods=['GET']) -async def get_pins(channel_id): - user_id = await token_check() - await channel_check(user_id, channel_id) - - ids = await app.db.fetch(""" - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - """, channel_id) - - ids = [r['message_id'] for r in ids] - res = [] - - for message_id in ids: - message = await app.storage.get_message(message_id) - if message is not None: - res.append(message) - - return jsonify(res) - - -@bp.route('//pins/', methods=['PUT']) -async def add_pin(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - # TODO: check MANAGE_MESSAGES permission - - await app.db.execute(""" - INSERT INTO channel_pins (channel_id, message_id) - VALUES ($1, $2) - """, channel_id, message_id) - - row = await app.db.fetchrow(""" - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - LIMIT 1 - """, channel_id) - - timestamp = snowflake_datetime(row['message_id']) - - await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', { - 'channel_id': str(channel_id), - 'last_pin_timestamp': timestamp.isoformat() - }) - - return '', 204 - - -@bp.route('//pins/', methods=['DELETE']) -async def delete_pin(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - # TODO: check MANAGE_MESSAGES permission - - await app.db.execute(""" - DELETE FROM channel_pins - WHERE channel_id = $1 AND message_id = $2 - """, channel_id, message_id) - - row = await app.db.fetchrow(""" - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - LIMIT 1 - """, channel_id) - - timestamp = snowflake_datetime(row['message_id']) - - await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_PINS_UPDATE', { - 'channel_id': str(channel_id), - 'last_pin_timestamp': timestamp.isoformat() - }) - - return '', 204 - - @bp.route('//typing', methods=['POST']) async def trigger_typing(channel_id): user_id = await token_check() diff --git a/litecord/schemas.py b/litecord/schemas.py index 182420c..901961f 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -251,6 +251,57 @@ GUILD_UPDATE = { } +CHAN_OVERWRITE = { + 'type': 'dict', + 'schema': { + 'id': {'coerce': int}, + 'type': {'type': 'string', 'allowed': ['role', 'member']}, + 'allow': {'coerce': Permissions}, + 'deny': {'coerce': Permissions} + } +} + + +CHAN_UPDATE = { + 'name': { + 'type': 'string', 'minlength': 2, + 'maxlength': 100, 'required': False}, + + 'position': {'coerce': int, 'required': False}, + + 'topic': { + 'type': 'string', 'minlength': 0, + 'maxlength': 1024, 'required': False}, + + 'nsfw': {'type': 'boolean', 'required': False}, + 'rate_limit_per_user': { + 'coerce': int, 'min': 0, + 'max': 120, 'required': False}, + + 'bitrate': { + 'coerce': int, 'min': 8000, + + # NOTE: 'max' is 96000 for non-vip guilds + 'max': 128000, 'required': False}, + + 'user_limit': { + # user_limit being 0 means infinite. + 'coerce': int, 'min': 0, + 'max': 99, 'required': False + }, + + 'permission_overwrites': { + 'type': 'list', + 'schema': CHAN_OVERWRITE, + 'required': False + }, + + 'parent_id': {'coerce': int, 'required': False, 'nullable': True} + + +} + + ROLE_CREATE = { 'name': {'type': 'string', 'default': 'new role'}, 'permissions': {'coerce': Permissions, 'nullable': True}, diff --git a/run.py b/run.py index bbdf701..48a630c 100644 --- a/run.py +++ b/run.py @@ -25,7 +25,7 @@ from litecord.blueprints.guild import ( ) from litecord.blueprints.channel import ( - channel_messages, channel_reactions + channel_messages, channel_reactions, channel_pins ) from litecord.ratelimits.handler import ratelimit_handler @@ -79,6 +79,7 @@ bps = { channels: '/channels', channel_messages: '/channels', channel_reactions: '/channels', + channel_pins: '/channels', webhooks: None, science: None, From c200ab870716595a710ad286ca82296b50a7b7d8 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Mon, 5 Nov 2018 03:12:09 -0300 Subject: [PATCH 49/69] channels: add untested basic channel update - channels: fix upserts - users: fix upserts --- litecord/blueprints/channels.py | 119 ++++++++++++++++++++++++++++++-- litecord/blueprints/users.py | 18 +++-- 2 files changed, 124 insertions(+), 13 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 85282ec..13bbe97 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -3,11 +3,14 @@ import time from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger -from ..auth import token_check -from ..enums import ChannelType, GUILD_CHANS -from ..errors import ChannelNotFound +from litecord.auth import token_check +from litecord.enums import ChannelType, GUILD_CHANS +from litecord.errors import ChannelNotFound +from litecord.schemas import ( + validate, CHAN_UPDATE +) -from .checks import channel_check +from litecord.blueprints.checks import channel_check, channel_perm_check log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -213,6 +216,107 @@ async def close_channel(channel_id): raise ChannelNotFound() +async def _update_pos(channel_id, pos: int): + await app.db.execute(""" + UPDATE guild_channels + SET position = $1 + WHERE id = $2 + """, pos, channel_id) + + +async def _mass_chan_update(guild_id, channel_ids: int): + for channel_id in channel_ids: + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch( + 'guild', guild_id, 'CHANNEL_UPDATE', chan) + + +async def _update_channel_common(channel_id, guild_id: int, j: dict): + if 'name' in j: + await app.db.execute(""" + UPDATE guild_channels + SET name = $1 + WHERE id = $2 + """, j['name'], channel_id) + + if 'position' in j: + channel_data = await app.storage.get_channel_data(guild_id) + + chans = [None * len(channel_data)] + for chandata in channel_data: + chans.insert(chandata['position'], int(chandata['id'])) + + # are we changing to the left or to the right? + + # left: [channel1, channel2, ..., channelN-1, channelN] + # becomes + # [channel1, channelN-1, channel2, ..., channelN] + # so we can say that the "main change" is + # channelN-1 going to the position channel2 + # was occupying. + current_pos = chans.index(channel_id) + new_pos = j['position'] + + # if the new position is bigger than the current one, + # we're making a left shift of all the channels that are + # beyond the current one, to make space + left_shift = new_pos > current_pos + + # find all channels that we'll have to shift + shift_block = (chans[current_pos:new_pos] + if left_shift else + chans[new_pos:current_pos] + ) + + shift = -1 if left_shift else 1 + + # do the shift (to the left or to the right) + await app.db.executemany(""" + UPDATE guild_channels + SET position = position + $1 + WHERE id = $2 + """, [(shift, chan_id) for chan_id in shift_block]) + + await _mass_chan_update(guild_id, shift_block) + + # since theres now an empty slot, move current channel to it + await _update_pos(channel_id, new_pos) + + +async def _update_text_channel(channel_id: int, j: dict): + pass + + +async def _update_voice_channel(channel_id: int, j: dict): + pass + + +@bp.route('/', methods=['PUT', 'PATCH']) +async def update_channel(channel_id): + """Update a channel's information""" + user_id = await token_check() + ctype, guild_id = await channel_check(user_id, channel_id) + + if ctype not in GUILD_CHANS: + raise ChannelNotFound('Can not edit non-guild channels.') + + await channel_perm_check(user_id, channel_id, 'manage_channels') + j = validate(await request.get_json(), CHAN_UPDATE) + + # TODO: categories? + update_handler = { + ChannelType.GUILD_TEXT: _update_text_channel, + ChannelType.GUILD_VOICE: _update_voice_channel, + }[ctype] + + await _update_channel_common(channel_id, guild_id, j) + await update_handler(channel_id, j) + + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan) + return jsonify(chan) + + @bp.route('//typing', methods=['POST']) async def trigger_typing(channel_id): user_id = await token_check() @@ -241,9 +345,12 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): (user_id, channel_id, last_message_id, mention_count) VALUES ($1, $2, $3, 0) - ON CONFLICT DO UPDATE + ON CONFLICT ON CONSTRAINT user_read_state_pkey + DO + UPDATE SET last_message_id = $3, mention_count = 0 - WHERE user_id = $1 AND channel_id = $2 + WHERE user_read_state.user_id = $1 + AND user_read_state.channel_id = $2 """, user_id, channel_id, message_id) if guild_id: diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 32bd8f0..0f75891 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -281,10 +281,11 @@ async def put_note(target_id: int): INSERT INTO notes (user_id, target_id, note) VALUES ($1, $2, $3) - ON CONFLICT DO UPDATE SET + ON CONFLICT ON CONSTRAINT notes_pkey + DO UPDATE SET note = $3 - WHERE - user_id = $1 AND target_id = $2 + WHERE notes.user_id = $1 + AND notes.target_id = $2 """, user_id, target_id, note) await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', { @@ -451,11 +452,14 @@ async def patch_guild_settings(guild_id: int): (user_id, guild_id, channel_id, {field}) VALUES ($1, $2, $3, $4) - ON CONFLICT DO UPDATE + ON CONFLICT + ON CONSTRAINT guild_settings_channel_overrides_pkey + DO + UPDATE SET {field} = $4 - WHERE user_id = $1 - AND guild_id = $2 - AND channel_id = $3 + WHERE guild_settings_channel_overrides.user_id = $1 + AND guild_settings_channel_overrides.guild_id = $2 + AND guild_settings_channel_overrides.channel_id = $3 """, user_id, guild_id, chan_id, chan_overrides[field]) settings = await app.storage.get_guild_settings_one(user_id, guild_id) From 8ffa14d934a6ae46b1468cd636817c3e59c60727 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Mon, 5 Nov 2018 03:15:46 -0300 Subject: [PATCH 50/69] ratelimits.handler: add check for when rule is none (404s) --- litecord/ratelimits/handler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py index 715a283..9a8987e 100644 --- a/litecord/ratelimits/handler.py +++ b/litecord/ratelimits/handler.py @@ -58,6 +58,11 @@ async def ratelimit_handler(): """ rule = request.url_rule + if rule is None: + return await _handle_global( + app.ratelimiter.global_bucket + ) + # rule.endpoint is composed of '.' # and so we can use that to make routes with different # methods have different ratelimits @@ -73,5 +78,6 @@ async def ratelimit_handler(): ratelimit = app.ratelimiter.get_ratelimit(rule_path) await _handle_specific(ratelimit) except KeyError: - ratelimit = app.ratelimiter.global_bucket - await _handle_global(ratelimit) + await _handle_global( + app.ratelimiter.global_bucket + ) From 7c274f0f70571f184964a905f532b492fbd1c509 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Mon, 5 Nov 2018 04:12:16 -0300 Subject: [PATCH 51/69] channels: add PUT /api/v6/:chan_id/permissions/:overwrite_id This finishes the basics on channel overwrites. SQL for instances: ```sql DROP TABLE channel_overwrites; ``` Then run `schema.sql`. - channels: finish implementations for update_{text,voice}_channel - storage: fix _overwrite_convert - schema.sql: use unique constraint instead of primary key in channel_overwrites --- litecord/blueprints/channels.py | 97 ++++++++++++++++++++++++++++++++- litecord/schemas.py | 13 ++--- litecord/storage.py | 8 ++- schema.sql | 10 +++- 4 files changed, 111 insertions(+), 17 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 13bbe97..ed5e62e 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -7,7 +7,7 @@ from litecord.auth import token_check from litecord.enums import ChannelType, GUILD_CHANS from litecord.errors import ChannelNotFound from litecord.schemas import ( - validate, CHAN_UPDATE + validate, CHAN_UPDATE, CHAN_OVERWRITE ) from litecord.blueprints.checks import channel_check, channel_perm_check @@ -231,6 +231,63 @@ async def _mass_chan_update(guild_id, channel_ids: int): 'guild', guild_id, 'CHANNEL_UPDATE', chan) +async def _process_overwrites(channel_id: int, overwrites: list): + for overwrite in overwrites: + + # 0 for user overwrite, 1 for role overwrite + target_type = 0 if overwrite['type'] == 'user' else 1 + target_role = None if target_type == 0 else overwrite['id'] + target_user = overwrite['id'] if target_type == 0 else None + + await app.db.execute( + """ + INSERT INTO channel_overwrites + (channel_id, target_type, target_role, + target_user, allow, deny) + VALUES + ($1, $2, $3, $4, $5, $6) + ON CONFLICT ON CONSTRAINT channel_overwrites_uniq + DO + UPDATE + SET allow = $5, deny = $6 + WHERE channel_overwrites.channel_id = $1 + AND channel_overwrites.target_type = $2 + AND channel_overwrites.target_role = $3 + AND channel_overwrites.target_user = $4 + """, + channel_id, target_type, + target_role, target_user, + overwrite['allow'], overwrite['deny']) + + +@bp.route('//permissions/', methods=['PUT']) +async def put_channel_overwrite(channel_id: int, overwrite_id: int): + """Insert or modify a channel overwrite.""" + user_id = await token_check() + ctype, guild_id = await channel_check(user_id, channel_id) + + if ctype not in GUILD_CHANS: + raise ChannelNotFound('Only usable for guild channels.') + + await channel_perm_check(user_id, guild_id, 'manage_roles') + + j = validate( + # inserting a fake id on the payload so validation passes through + {**await request.get_json(), **{'id': -1}}, + CHAN_OVERWRITE + ) + + await _process_overwrites(channel_id, [{ + 'allow': j['allow'], + 'deny': j['deny'], + 'type': j['type'], + 'id': overwrite_id + }]) + + await _mass_chan_update(guild_id, [channel_id]) + return '', 204 + + async def _update_channel_common(channel_id, guild_id: int, j: dict): if 'name' in j: await app.db.execute(""" @@ -282,13 +339,47 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict): # since theres now an empty slot, move current channel to it await _update_pos(channel_id, new_pos) + if 'channel_overwrites' in j: + overwrites = j['channel_overwrites'] + await _process_overwrites(channel_id, overwrites) + + +async def _common_guild_chan(channel_id, j: dict): + # common updates to the guild_channels table + for field in [field for field in j.keys() + if field in ('nsfw', 'parent_id')]: + await app.db.execute(f""" + UPDATE guild_channels + SET {field} = $1 + WHERE id = $2 + """, j[field], channel_id) + async def _update_text_channel(channel_id: int, j: dict): - pass + # first do the specific ones related to guild_text_channels + for field in [field for field in j.keys() + if field in ('topic', 'rate_limit_per_user')]: + await app.db.execute(f""" + UPDATE guild_text_channels + SET {field} = $1 + WHERE id = $2 + """, j[field], channel_id) + + await _common_guild_chan(channel_id, j) async def _update_voice_channel(channel_id: int, j: dict): - pass + # first do the specific ones in guild_voice_channels + for field in [field for field in j.keys() + if field in ('bitrate', 'user_limit')]: + await app.db.execute(f""" + UPDATE guild_voice_channels + SET {field} = $1 + WHERE id = $2 + """, j[field], channel_id) + + # yes, i'm letting voice channels have nsfw, you cant stop me + await _common_guild_chan(channel_id, j) @bp.route('/', methods=['PUT', 'PATCH']) diff --git a/litecord/schemas.py b/litecord/schemas.py index 901961f..710ccd0 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -252,13 +252,10 @@ GUILD_UPDATE = { CHAN_OVERWRITE = { - 'type': 'dict', - 'schema': { - 'id': {'coerce': int}, - 'type': {'type': 'string', 'allowed': ['role', 'member']}, - 'allow': {'coerce': Permissions}, - 'deny': {'coerce': Permissions} - } + 'id': {'coerce': int}, + 'type': {'type': 'string', 'allowed': ['role', 'member']}, + 'allow': {'coerce': Permissions}, + 'deny': {'coerce': Permissions} } @@ -292,7 +289,7 @@ CHAN_UPDATE = { 'permission_overwrites': { 'type': 'list', - 'schema': CHAN_OVERWRITE, + 'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE}, 'required': False }, diff --git a/litecord/storage.py b/litecord/storage.py index ab1f226..b78c2b0 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -324,18 +324,20 @@ class Storage: def _overwrite_convert(row): drow = dict(row) - drow['type'] = drow['target_type'] + + target_type = drow['target_type'] + drow['type'] = 'user' if target_type == 0 else 'role' # if type is 0, the overwrite is for a user # if type is 1, the overwrite is for a role drow['id'] = { 0: drow['target_user'], 1: drow['target_role'], - }[drow['type']] + }[target_type] drow['id'] = str(drow['id']) - drow.pop('overwrite_type') + drow.pop('target_type') drow.pop('target_user') drow.pop('target_role') diff --git a/schema.sql b/schema.sql index 3a654c3..e2823ec 100644 --- a/schema.sql +++ b/schema.sql @@ -347,11 +347,15 @@ CREATE TABLE IF NOT EXISTS channel_overwrites ( -- they're bigints (64bits), discord, -- for now, only needs 53. allow bigint DEFAULT 0, - deny bigint DEFAULT 0, - - PRIMARY KEY (channel_id, target_role, target_user) + deny bigint DEFAULT 0 ); +-- columns in private keys can't have NULL values, +-- so instead we use a custom constraint with UNIQUE + +ALTER TABLE channel_overwrites ADD CONSTRAINT channel_overwrites_uniq + UNIQUE (channel_id, target_role, target_user); + CREATE TABLE IF NOT EXISTS features ( id serial PRIMARY KEY, From 04d89a221401cdf1a0f4620d9f1c0d32c01b6cf3 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Mon, 5 Nov 2018 22:04:48 -0300 Subject: [PATCH 52/69] pubsub.lazy_guilds: major refactor This makes the whole process of generating a member list easier to understand and modify (from my point of view). The actual event dispatching functionality is not on this commit. - permissions: add optional storage kwarg --- litecord/permissions.py | 49 +++-- litecord/pubsub/lazy_guild.py | 337 +++++++++++++++++++++------------- 2 files changed, 237 insertions(+), 149 deletions(-) diff --git a/litecord/permissions.py b/litecord/permissions.py index f68e259..f5fa63c 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -65,7 +65,7 @@ class Permissions(ctypes.Union): ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) -async def base_permissions(member_id, guild_id) -> Permissions: +async def base_permissions(member_id, guild_id, storage=None) -> Permissions: """Compute the base permissions for a given user. Base permissions are @@ -75,7 +75,11 @@ async def base_permissions(member_id, guild_id) -> Permissions: This will give ALL_PERMISSIONS if base permissions has the Administrator bit set. """ - owner_id = await app.db.fetchval(""" + + if not storage: + storage = app.storage + + owner_id = await storage.db.fetchval(""" SELECT owner_id FROM guilds WHERE id = $1 @@ -85,7 +89,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: return ALL_PERMISSIONS # get permissions for @everyone - everyone_perms = await app.db.fetchval(""" + everyone_perms = await storage.db.fetchval(""" SELECT permissions FROM roles WHERE guild_id = $1 @@ -93,7 +97,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: permissions = Permissions(everyone_perms) - role_ids = await app.db.fetch(""" + role_ids = await storage.db.fetch(""" SELECT role_id FROM member_roles WHERE guild_id = $1 AND user_id = $2 @@ -102,7 +106,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: role_perms = [] for row in role_ids: - rperm = await app.db.fetchval(""" + rperm = await storage.db.fetchval(""" SELECT permissions FROM roles WHERE id = $1 @@ -119,7 +123,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: return permissions -def _mix(perms: Permissions, overwrite: dict) -> Permissions: +def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions: # we make a copy of the binary representation # so we don't modify the old perms in-place # which could be an unwanted side-effect @@ -134,20 +138,22 @@ def _mix(perms: Permissions, overwrite: dict) -> Permissions: return Permissions(result) -def _overwrite_mix(perms: Permissions, overwrites: dict, - target_id: int) -> Permissions: +def overwrite_find_mix(perms: Permissions, overwrites: dict, + target_id: int) -> Permissions: overwrite = overwrites.get(target_id) if overwrite: # only mix if overwrite found - return _mix(perms, overwrite) + return overwrite_mix(perms, overwrite) return perms async def compute_overwrites(base_perms, user_id, channel_id: int, - guild_id: int = None): + guild_id: int = None, storage=None): """Compute the permissions in the context of a channel.""" + if not storage: + storage = app.storage if base_perms.bits.administrator: return ALL_PERMISSIONS @@ -155,21 +161,21 @@ async def compute_overwrites(base_perms, user_id, channel_id: int, perms = base_perms # list of overwrites - overwrites = await app.storage.chan_overwrites(channel_id) + overwrites = await storage.chan_overwrites(channel_id) if not guild_id: - guild_id = await app.storage.guild_from_channel(channel_id) + guild_id = await storage.guild_from_channel(channel_id) # make it a map for better usage overwrites = {o['id']: o for o in overwrites} - perms = _overwrite_mix(perms, overwrites, guild_id) + perms = overwrite_find_mix(perms, overwrites, guild_id) # apply role specific overwrites allow, deny = 0, 0 # fetch roles from user and convert to int - role_ids = await app.storage.get_member_role_ids(guild_id, user_id) + role_ids = await storage.get_member_role_ids(guild_id, user_id) role_ids = map(int, role_ids) # make the allow and deny binaries @@ -180,26 +186,29 @@ async def compute_overwrites(base_perms, user_id, channel_id: int, deny |= overwrite['deny'] # final step for roles: mix - perms = _mix(perms, { + perms = overwrite_mix(perms, { 'allow': allow, 'deny': deny }) # apply member specific overwrites - perms = _overwrite_mix(perms, overwrites, user_id) + perms = overwrite_find_mix(perms, overwrites, user_id) return perms -async def get_permissions(member_id, channel_id): +async def get_permissions(member_id, channel_id, *, storage=None): """Get all the permissions for a user in a channel.""" - guild_id = await app.storage.guild_from_channel(channel_id) + if not storage: + storage = app.storage + + guild_id = await storage.guild_from_channel(channel_id) # for non guild channels if not guild_id: return ALL_PERMISSIONS - base_perms = await base_permissions(member_id, guild_id) + base_perms = await base_permissions(member_id, guild_id, storage) return await compute_overwrites(base_perms, member_id, - channel_id, guild_id) + channel_id, guild_id, storage) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index d010f78..2202234 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -2,15 +2,57 @@ Main code for Lazy Guild implementation in litecord. """ import pprint +from dataclasses import dataclass, asdict from collections import defaultdict -from typing import Any, List, Dict +from typing import Any, List, Dict, Union from logbook import Logger -from .dispatcher import Dispatcher +from litecord.pubsub.dispatcher import Dispatcher +from litecord.permissions import ( + Permissions, overwrite_find_mix, get_permissions +) log = Logger(__name__) +GroupID = Union[int, str] +Presence = Dict[str, Any] + + +@dataclass +class GroupInfo: + gid: GroupID + name: str + position: int + permissions: Permissions + + +@dataclass +class MemberList: + groups: List[GroupInfo] = None + group_info: Dict[GroupID, GroupInfo] = None + data: Dict[GroupID, Presence] = None + overwrites: Dict[int, Dict[str, Any]] = None + + def __bool__(self): + """Return if the current member list is fully initialized.""" + list_dict = asdict(self) + return all(v is not None for v in list_dict.values()) + + def __iter__(self): + """Iterate over all groups in the correct order. + + Yields a tuple containing :class:`GroupInfo` and + the List[Presence] for the group. + """ + for group in self.groups: + yield group, self.data[group.gid] + + +def _to_simple_group(presence: dict) -> str: + """Return a simple group (not a role), given a presence.""" + return 'offline' if presence['status'] == 'offline' else 'online' + class GuildMemberList: """This class stores the current member list information @@ -41,7 +83,6 @@ class GuildMemberList: """ def __init__(self, guild_id: int, channel_id: int, main_lg): - self.main_lg = main_lg self.guild_id = guild_id self.channel_id = channel_id @@ -52,167 +93,205 @@ class GuildMemberList: self.presence = main.app.presence self.state_man = main.app.state_manager - self.member_list = None - self.items = None + self.list = MemberList(None, None, None, None) #: holds the state of subscribed shards # to this channels' member list self.state = set() + def _set_empty_list(self): + self.list = MemberList(None, None, None, None) + async def _init_check(self): """Check if the member list is initialized before messing with it.""" - if self.member_list is None: + if not self.list: await self._init_member_list() - async def get_roles(self) -> List[Dict[str, Any]]: + async def _fetch_overwrites(self): + overwrites = await self.storage.chan_overwrites(self.channel_id) + overwrites = {int(ov['id']): ov for ov in overwrites} + self.list.overwrites = overwrites + + def _calc_member_group(self, roles: List[int], status: str): + """Calculate the best fitting group for a member, + given their roles and their current status.""" + try: + # the first group in the list + # that the member is entitled to is + # the selected group for the member. + group_id = next(g.gid for g in self.list.groups + if g.gid in roles) + except StopIteration: + # no group was found, so we fallback + # to simple group" + group_id = _to_simple_group({'status': status}) + + return group_id + + async def get_roles(self) -> List[GroupInfo]: """Get role information, but only: - the ID - the name - the position - - of all HOISTED roles.""" - # TODO: write own query for this - # TODO: calculate channel overrides - roles = await self.storage.get_role_data(self.guild_id) + - the permissions - return [{ - 'id': role['id'], - 'name': role['name'], - 'position': role['position'] - } for role in roles if role['hoist']] + of all HOISTED roles AND roles that + have permissions to read the channel + being referred to this :class:`GuildMemberList` + instance. + + The list is sorted by position. + """ + roledata = await self.storage.db.fetch(""" + SELECT id, name, hoist, position, permissions + FROM roles + WHERE guild_id = $1 + """, self.guild_id) + + hoisted = [ + GroupInfo(row['id'], row['name'], + row['position'], row['permissions']) + for row in roledata if row['hoist'] + ] + + # sort role list by position + hoisted = sorted(hoisted, key=lambda group: group.position) + + # we need to store them for later on + # for members + await self._fetch_overwrites() + + def _can_read_chan(group: GroupInfo): + # get the base role perms + role_perms = group.permissions + + # then the final perms for that role if + # any overwrite exists in the channel + final_perms = overwrite_find_mix( + role_perms, self.list.overwrites, group.gid) + + # update the group's permissions + # with the mixed ones + group.permissions = final_perms + + # if the role can read messages, then its + # part of the group. + return final_perms.bits.read_messages + + return list(filter(_can_read_chan, hoisted)) + + async def set_groups(self): + """Get the groups for the member list.""" + role_groups = await self.get_roles() + role_ids = [g.gid for g in role_groups] + + self.list.groups = role_ids + ['online', 'offline'] + + # inject default groups 'online' and 'offline' + self.list.groups = role_ids + [ + GroupInfo('online', 'online', -1, -1), + GroupInfo('offline', 'offline', -1, -1) + ] + self.list.group_info = {g.gid: g for g in role_groups} + + async def _pass_1(self, guild_presences: List[Presence]): + """First pass on generating the member list. + + This assigns all presences a single group. + """ + for presence in guild_presences: + member_id = int(presence['user']['id']) + + # list of roles for the member + member_roles = list(map(int, presence['roles'])) + + # get the member's permissions relative to the channel + # (accounting for channel overwrites) + member_perms = await get_permissions( + member_id, self.channel_id, storage=self.storage) + + if not member_perms.bits.read_messages: + continue + + # if the member is offline, we + # default give them the offline group. + status = presence['status'] + group_id = ('offline' if status == 'offline' + else self._calc_member_group(member_roles, status)) + + self.list.data[group_id].append(presence) + + async def _sort_groups(self): + members = await self.storage.get_member_data(self.guild_id) + + # make a dictionary of member ids to nicknames + # so we don't need to keep querying the db on + # every loop iteration + member_nicks = {m['user']['id']: m.get('nick') + for m in members} + + for group_members in self.list.data.values(): + def display_name(presence: Presence) -> str: + uid = presence['user']['id'] + + uname = presence['user']['username'] + nick = member_nicks.get(uid) + + return nick or uname + + # this should update the list in-place + group_members.sort(key=display_name) async def _init_member_list(self): - """Fill in :attr:`GuildMemberList.member_list` - with information about the guilds' members.""" + """Generate the main member list with groups.""" member_ids = await self.storage.get_member_ids(self.guild_id) guild_presences = await self.presence.guild_presences( member_ids, self.guild_id) - guild_roles = await self.get_roles() - - # sort by position - guild_roles.sort(key=lambda role: role['position']) - roleids = [r['id'] for r in guild_roles] - - # groups are: - # - roles that are hoisted - # - "online" and "offline", with "online" - # being for people without any roles. - - groups = roleids + ['online', 'offline'] + await self.set_groups() log.debug('{} presences, {} groups', - len(guild_presences), len(groups)) + len(guild_presences), + len(self.list.groups)) - group_data = {group: [] for group in groups} + self.list.data = {group.gid: [] for group in self.list.groups} - print('group data', group_data) + # first pass: set which presences + # go to which groups + await self._pass_1(guild_presences) - def _try_hier(role_id: str, roleids: list): - """Try to fetch a role's position in the hierarchy""" - try: - return roleids.index(role_id) - except ValueError: - # the given role isn't on a group - # so it doesn't count for our purposes. - return 0 + # second pass: sort each group's members + # by the display name + await self._sort_groups() - for presence in guild_presences: - # simple group (online or offline) - # we'll decide on the best group for the presence later on - simple_group = ('offline' - if presence['status'] == 'offline' - else 'online') + @property + def items(self) -> list: + """Main items list.""" - # get the best possible role - roles = sorted( - presence['roles'], - key=lambda role_id: _try_hier(role_id, roleids) - ) + # TODO: maybe make this stored in the list + # so we don't need to keep regenning? - try: - best_role_id = roles[0] - except IndexError: - # no hoisted roles exist in the guild, assign - # the @everyone role - best_role_id = str(self.guild_id) - - print('best role', best_role_id, str(self.guild_id)) - print('simple group assign', simple_group) - - # if the best role is literally the @everyone role, - # this user has no hoisted roles - if best_role_id == str(self.guild_id): - # this user has no roles, put it on online/offline - group_data[simple_group].append(presence) - continue - - # this user has a best_role that isn't the - # @everyone role, so we'll put them in the respective group - try: - group_data[best_role_id].append(presence) - except KeyError: - group_data[simple_group].append(presence) - - # go through each group and sort the resulting members by display name - - members = await self.storage.get_member_data(self.guild_id) - member_nicks = {member['user']['id']: member.get('nick') - for member in members} - - # now we'll sort each group by their display name - # (can be their current nickname OR their username - # if no nickname is set) - print('pre-sorted group data') - pprint.pprint(group_data) - - for _, group_list in group_data.items(): - def display_name(presence: dict) -> str: - uid = presence['user']['id'] - - uname = presence['user']['username'] - nick = member_nicks[uid] - - return nick or uname - - group_list.sort(key=display_name) - - pprint.pprint(group_data) - - self.member_list = { - 'groups': groups, - 'data': group_data - } - - def get_items(self) -> list: - """Generate the main items list,""" - if self.member_list is None: + if not self.list: return [] - if self.items: - return self.items - - groups = self.member_list['groups'] - res = [] - for group in groups: - members = self.member_list['data'][group] + # NOTE: maybe use map()? + for group, presences in self.list: res.append({ 'group': { - 'id': group, - 'count': len(members), + 'id': group.gid, + 'count': len(presences), } }) - for member in members: + for presence in presences: res.append({ - 'member': member + 'member': presence }) - self.items = res return res async def sub(self, session_id: str): @@ -230,7 +309,7 @@ class GuildMemberList: # uninitialized) for a future subscriber. if not self.state: - self.member_list = None + self._set_empty_list() async def shard_query(self, session_id: str, ranges: list): """Send a GUILD_MEMBER_LIST_UPDATE event @@ -259,7 +338,7 @@ class GuildMemberList: # subscribe the state to this list await self.sub(session_id) - # TODO: subscribe shard to 'everyone' + # TODO: subscribe shard to the 'everyone' member list # and forward the query to that list reply = { @@ -273,9 +352,9 @@ class GuildMemberList: 'groups': [ { - 'count': len(self.member_list['data'][group]), - 'id': group - } for group in self.member_list['groups'] + 'count': len(presences), + 'id': group.gid + } for group, presences in self.list ], 'ops': [], @@ -288,12 +367,12 @@ class GuildMemberList: if itemcount < 0: continue - items = self.get_items() + # TODO: subscribe user to the slice reply['ops'].append({ 'op': 'SYNC', 'range': [start, end], - 'items': items[start:end], + 'items': self.items[start:end], }) # the first GUILD_MEMBER_LIST_UPDATE for a shard From a1a914dc874177f922057e4595b1b10e016a64fd Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Tue, 6 Nov 2018 17:42:27 -0300 Subject: [PATCH 53/69] add manage.py --- litecord/pubsub/lazy_guild.py | 2 ++ manage.py | 12 ++++++++++++ manage/__init__.py | 0 manage/main.py | 30 ++++++++++++++++++++++++++++++ run.py | 21 +++++++++++++++------ 5 files changed, 59 insertions(+), 6 deletions(-) create mode 100755 manage.py create mode 100644 manage/__init__.py create mode 100644 manage/main.py diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 2202234..ba19276 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -21,6 +21,7 @@ Presence = Dict[str, Any] @dataclass class GroupInfo: + """Store information about a specific group.""" gid: GroupID name: str position: int @@ -29,6 +30,7 @@ class GroupInfo: @dataclass class MemberList: + """Total information on the guild's member list.""" groups: List[GroupInfo] = None group_info: Dict[GroupID, GroupInfo] = None data: Dict[GroupID, Presence] = None diff --git a/manage.py b/manage.py new file mode 100755 index 0000000..a65b9d6 --- /dev/null +++ b/manage.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +import logging +import sys + +from manage.main import main + +import config + +logging.basicConfig(level=logging.DEBUG) + +if __name__ == '__main__': + sys.exit(main(config)) diff --git a/manage/__init__.py b/manage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/manage/main.py b/manage/main.py new file mode 100644 index 0000000..b05d6db --- /dev/null +++ b/manage/main.py @@ -0,0 +1,30 @@ +import asyncio +from dataclasses import dataclass + + +from run import init_app_managers, init_app_db + + +@dataclass +class FakeApp: + """Fake app instance.""" + config: dict + db = None + loop: asyncio.BaseEventLoop = None + ratelimiter = None + state_manager = None + storage = None + dispatcher = None + presence = None + + +def main(config): + """Start the script""" + loop = asyncio.get_event_loop() + cfg = getattr(config, config.MODE) + app = FakeApp(cfg.__dict__) + + loop.run_until_complete(init_app_db(app)) + init_app_managers(app) + + print(app) diff --git a/run.py b/run.py index 48a630c..c1af61f 100644 --- a/run.py +++ b/run.py @@ -143,16 +143,14 @@ async def app_set_ratelimit_headers(resp): return resp -@app.before_serving -async def app_before_serving(): - log.info('opening db') +async def init_app_db(app): + """Connect to databases""" app.db = await asyncpg.create_pool(**app.config['POSTGRES']) - g.app = app +def init_app_managers(app): + """Initialize singleton classes.""" app.loop = asyncio.get_event_loop() - g.loop = asyncio.get_event_loop() - app.ratelimiter = RatelimitManager() app.state_manager = StateManager() app.storage = Storage(app.db) @@ -162,6 +160,17 @@ async def app_before_serving(): app.state_manager, app.dispatcher) app.storage.presence = app.presence + +@app.before_serving +async def app_before_serving(): + log.info('opening db') + await init_app_db(app) + + g.app = app + g.loop = asyncio.get_event_loop() + + init_app_managers(app) + # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT'] log.info(f'starting websocket at {host} {port}') From c3210bf5b013f77f8be8f0c2c9b6f832d418aa1b Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Tue, 6 Nov 2018 18:43:36 -0300 Subject: [PATCH 54/69] manage: add dummy migration command --- manage/cmd/migration/__init__.py | 1 + manage/cmd/migration/command.py | 17 +++++++++++++++++ manage/main.py | 30 +++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 manage/cmd/migration/__init__.py create mode 100644 manage/cmd/migration/command.py diff --git a/manage/cmd/migration/__init__.py b/manage/cmd/migration/__init__.py new file mode 100644 index 0000000..3a9fa59 --- /dev/null +++ b/manage/cmd/migration/__init__.py @@ -0,0 +1 @@ +from .command import setup as migration diff --git a/manage/cmd/migration/command.py b/manage/cmd/migration/command.py new file mode 100644 index 0000000..0296df6 --- /dev/null +++ b/manage/cmd/migration/command.py @@ -0,0 +1,17 @@ +async def migrate_cmd(app, args): + """Main migration command. + + This makes sure the database + is updated. + """ + print('not implemented yet') + + +def setup(subparser): + migrate_parser = subparser.add_parser( + 'migrate', + help='Run migration tasks', + description=migrate_cmd.__doc__ + ) + + migrate_parser.set_defaults(func=migrate_cmd) diff --git a/manage/main.py b/manage/main.py index b05d6db..7cb0f99 100644 --- a/manage/main.py +++ b/manage/main.py @@ -1,8 +1,14 @@ import asyncio +import argparse +from sys import argv from dataclasses import dataclass +from logbook import Logger from run import init_app_managers, init_app_db +from manage.cmd.migration import migration + +log = Logger(__name__) @dataclass @@ -18,6 +24,15 @@ class FakeApp: presence = None +def init_parser(): + parser = argparse.ArgumentParser() + subparser = parser.add_subparsers(help='operations') + + migration(subparser) + + return parser + + def main(config): """Start the script""" loop = asyncio.get_event_loop() @@ -27,4 +42,17 @@ def main(config): loop.run_until_complete(init_app_db(app)) init_app_managers(app) - print(app) + # initialize argparser + parser = init_parser() + + try: + if len(argv) < 2: + parser.print_help() + return + + args = parser.parse_args() + loop.run_until_complete(args.func(app, args)) + except Exception: + log.exception('error while running command') + finally: + loop.run_until_complete(app.db.close()) From db9fd783f56990410a69d83a26b9c3da11ad1fa1 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Tue, 6 Nov 2018 20:46:17 -0300 Subject: [PATCH 55/69] manage.cmd.migration: add rudimentary implementation Also add table changes for future message embeds. --- manage/cmd/migration/command.py | 128 +++++++++++++++++- .../scripts/1_message_embed_type.sql | 6 + schema.sql | 13 +- 3 files changed, 134 insertions(+), 13 deletions(-) create mode 100644 manage/cmd/migration/scripts/1_message_embed_type.sql diff --git a/manage/cmd/migration/command.py b/manage/cmd/migration/command.py index 0296df6..fd237fb 100644 --- a/manage/cmd/migration/command.py +++ b/manage/cmd/migration/command.py @@ -1,10 +1,134 @@ -async def migrate_cmd(app, args): +import inspect +from pathlib import Path +from dataclasses import dataclass +from collections import namedtuple +from typing import Dict + +import asyncpg +from logbook import Logger + +log = Logger(__name__) + + +Migration = namedtuple('Migration', 'id name path') + + +@dataclass +class MigrationContext: + """Hold information about migration.""" + migration_folder: Path + scripts: Dict[int, Migration] + + @property + def latest(self): + """Return the latest migration ID.""" + return max(self.scripts.keys()) + + +def make_migration_ctx() -> MigrationContext: + """Create the MigrationContext instance.""" + # taken from https://stackoverflow.com/a/6628348 + script_path = inspect.stack()[0][1] + script_folder = '/'.join(script_path.split('/')[:-1]) + script_folder = Path(script_folder) + + migration_folder = script_folder / 'scripts' + + mctx = MigrationContext(migration_folder, {}) + + for mig_path in migration_folder.glob('*.sql'): + mig_path_str = str(mig_path) + + # extract migration script id and name + mig_filename = mig_path_str.split('/')[-1].split('.')[0] + name_fragments = mig_filename.split('_') + + mig_id = int(name_fragments[0]) + mig_name = '_'.join(name_fragments[1:]) + + mctx.scripts[mig_id] = Migration( + mig_id, mig_name, mig_path) + + return mctx + + +async def _ensure_changelog(app, ctx): + # make sure we have the migration table up + + try: + await app.db.execute(""" + CREATE TABLE migration_log ( + change_num bigint NOT NULL, + + apply_ts timestamp without time zone default + (now() at time zone 'utc'), + + description text, + + PRIMARY KEY (change_num) + ); + """) + + # if we were able to create the + # migration_log table, insert that we are + # on the latest version. + await app.db.execute(""" + INSERT INTO migration_log (change_num, description) + VALUES ($1, $2) + """, ctx.latest, 'migration setup') + except asyncpg.DuplicateTableError: + log.debug('existing migration table') + + +async def apply_migration(app, migration: Migration): + """Apply a single migration.""" + migration_sql = migration.path.read_text(encoding='utf-8') + + try: + await app.db.execute(""" + INSERT INTO migration_log (change_num, description) + VALUES ($1, $2) + """, migration.id, f'migration: {migration.name}') + except asyncpg.UniqueViolationError: + log.warning('already applied {}', migration.id) + return + + await app.db.execute(migration_sql) + log.info('applied {}', migration.id) + + +async def migrate_cmd(app, _args): """Main migration command. This makes sure the database is updated. """ - print('not implemented yet') + + ctx = make_migration_ctx() + + await _ensure_changelog(app, ctx) + + # local point in the changelog + local_change = await app.db.fetchval(""" + SELECT max(change_num) + FROM migration_log + """) + + local_change = local_change or 0 + latest_change = ctx.latest + + if local_change == latest_change: + print('no changes to do, exiting') + return + + # we do local_change + 1 so we start from the + # next migration to do, end in latest_change + 1 + # because of how range() works. + for idx in range(local_change + 1, latest_change + 1): + migration = ctx.scripts.get(idx) + + print('applying', migration.id, migration.name) + await apply_migration(app, migration) def setup(subparser): diff --git a/manage/cmd/migration/scripts/1_message_embed_type.sql b/manage/cmd/migration/scripts/1_message_embed_type.sql new file mode 100644 index 0000000..8650558 --- /dev/null +++ b/manage/cmd/migration/scripts/1_message_embed_type.sql @@ -0,0 +1,6 @@ +-- unused tables +DROP TABLE message_embeds; +DROP TABLE embeds; + +ALTER TABLE messages + ADD COLUMN embeds jsonb DEFAULT '[]' diff --git a/schema.sql b/schema.sql index e2823ec..9f9bda9 100644 --- a/schema.sql +++ b/schema.sql @@ -486,11 +486,6 @@ CREATE TABLE IF NOT EXISTS bans ( ); -CREATE TABLE IF NOT EXISTS embeds ( - -- TODO: this table - id bigint PRIMARY KEY -); - CREATE TABLE IF NOT EXISTS messages ( id bigint PRIMARY KEY, channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, @@ -511,6 +506,8 @@ CREATE TABLE IF NOT EXISTS messages ( tts bool default false, mention_everyone bool default false, + embeds jsonb DEFAULT '[]', + nonce bigint default 0, message_type int NOT NULL @@ -522,12 +519,6 @@ CREATE TABLE IF NOT EXISTS message_attachments ( PRIMARY KEY (message_id, attachment) ); -CREATE TABLE IF NOT EXISTS message_embeds ( - message_id bigint REFERENCES messages (id) UNIQUE, - embed_id bigint REFERENCES embeds (id), - PRIMARY KEY (message_id, embed_id) -); - CREATE TABLE IF NOT EXISTS message_reactions ( message_id bigint REFERENCES messages (id), user_id bigint REFERENCES users (id), From 86923cc6e39bf1eb6b48a50341bb709539504fbc Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Tue, 6 Nov 2018 20:49:43 -0300 Subject: [PATCH 56/69] README: add migrate instructions --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 5f5c9c4..3b3756f 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,10 @@ $ psql -f schema.sql litecord # edit config.py as you wish $ cp config.example.py config.py +# run database migrations (this is a +# required step in setup) +$ pipenv run ./manage.py migrate + # Install all packages: $ pipenv install --dev ``` @@ -50,3 +54,10 @@ Use `--access-log -` to output access logs to stdout. ```sh $ pipenv run hypercorn run:app ``` + +## Updating + +```sh +$ git pull +$ pipenv run ./manage.py migrate +``` From 55f89196899eac08f8877545b0beedc0e5d2b1cb Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 04:11:17 -0300 Subject: [PATCH 57/69] pubsub.lazy_guilds: add support for 'everyone'-type GMLs if the everyone role can read the channel, then the member list will be equivalent to any other list where the everyone role can read the channel. with this in mind we can generate a "global" member list directed only for that usecase. - permissions: add role_permissions --- litecord/permissions.py | 42 +++++++++++++++--- litecord/pubsub/lazy_guild.py | 81 ++++++++++++++++++++--------------- schema.sql | 2 +- 3 files changed, 82 insertions(+), 43 deletions(-) diff --git a/litecord/permissions.py b/litecord/permissions.py index f5fa63c..b63fa98 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -65,6 +65,20 @@ class Permissions(ctypes.Union): ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) +async def get_role_perms(guild_id, role_id, storage=None) -> Permissions: + """Get the raw :class:`Permissions` object for a role.""" + if not storage: + storage = app.storage + + perms = await storage.db.fetchval(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 AND id = $2 + """, guild_id, role_id) + + return Permissions(perms) + + async def base_permissions(member_id, guild_id, storage=None) -> Permissions: """Compute the base permissions for a given user. @@ -89,13 +103,7 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions: return ALL_PERMISSIONS # get permissions for @everyone - everyone_perms = await storage.db.fetchval(""" - SELECT permissions - FROM roles - WHERE guild_id = $1 - """, guild_id) - - permissions = Permissions(everyone_perms) + permissions = await get_role_perms(guild_id, guild_id, storage) role_ids = await storage.db.fetch(""" SELECT role_id @@ -149,6 +157,26 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict, return perms +async def role_permissions(guild_id: int, role_id: int, + channel_id: int, storage=None) -> Permissions: + """Get the permissions for a role, in relation to a channel""" + if not storage: + storage = app.storage + + perms = await get_role_perms(guild_id, role_id, storage) + + overwrite = await storage.db.fetchrow(""" + SELECT allow, deny + FROM channel_overwrites + WHERE channel_id = $1 AND target_type = $2 AND target_role = $3 + """, channel_id, 1, role_id) + + if overwrite: + perms = overwrite_mix(perms, overwrite) + + return perms + + async def compute_overwrites(base_perms, user_id, channel_id: int, guild_id: int = None, storage=None): """Compute the permissions in the context of a channel.""" diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index ba19276..5d865e4 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -10,7 +10,7 @@ from logbook import Logger from litecord.pubsub.dispatcher import Dispatcher from litecord.permissions import ( - Permissions, overwrite_find_mix, get_permissions + Permissions, overwrite_find_mix, get_permissions, role_permissions ) log = Logger(__name__) @@ -90,16 +90,15 @@ class GuildMemberList: # a really long chain of classes to get # to the storage instance... - main = main_lg.main_dispatcher - self.storage = main.app.storage - self.presence = main.app.presence - self.state_man = main.app.state_manager + self.main = main_lg + self.storage = self.main.app.storage + self.presence = self.main.app.presence + self.state_man = self.main.app.state_manager self.list = MemberList(None, None, None, None) - #: holds the state of subscribed shards - # to this channels' member list - self.state = set() + #: {session_id: set[list]} + self.state = defaultdict(set) def _set_empty_list(self): self.list = MemberList(None, None, None, None) @@ -296,14 +295,16 @@ class GuildMemberList: return res - async def sub(self, session_id: str): + async def sub(self, _session_id: str): """Subscribe a shard to the member list.""" await self._init_check() - self.state.add(session_id) async def unsub(self, session_id: str): """Unsubscribe a shard from the member list""" - self.state.discard(session_id) + try: + self.state.pop(session_id) + except KeyError: + pass # once we reach 0 subscribers, # we drop the current member list we have (for memory) @@ -327,6 +328,29 @@ class GuildMemberList: ranges of the list that we want. """ + # a guild list with a channel id of the guild + # represents the 'everyone' global list. + list_id = ('everyone' + if self.channel_id == self.guild_id + else str(self.channel_id)) + + # if everyone can read the channel, + # we direct the request to the 'everyone' gml instance + # instead of the current one. + everyone_perms = await role_permissions( + self.guild_id, + self.guild_id, + self.channel_id, + storage=self.storage + ) + + if everyone_perms.bits.read_messages and list_id != 'everyone': + everyone_gml = await self.main.get_gml(self.guild_id) + + return await everyone_gml.shard_query( + session_id, ranges + ) + await self._init_check() # make sure this is a sane state @@ -335,22 +359,9 @@ class GuildMemberList: await self.unsub(session_id) return - # since this is a sane state AND - # trying to query, we automatically - # subscribe the state to this list - await self.sub(session_id) - - # TODO: subscribe shard to the 'everyone' member list - # and forward the query to that list - reply = { 'guild_id': str(self.guild_id), - - # TODO: everyone for channels without overrides - # channel_id for channels WITH overrides. - - 'id': 'everyone', - # 'id': str(self.channel_id), + 'id': list_id, 'groups': [ { @@ -386,22 +397,17 @@ class GuildMemberList: return list(self.state) async def dispatch(self, event: str, data: Any): - """The dispatch() method here, instead of being - about dispatching a single event to the subscribed - users and forgetting about it, is about storing - the actual member list information so that we - can generate the respective events to the users. + """Modify the member list and dispatch the respective + events to subscribed shards. GuildMemberList stores the current guilds' list - in its :attr:`GuildMemberList.member_list` attribute, + in its :attr:`GuildMemberList.list` attribute, with that attribute being modified via different calls to :meth:`GuildMemberList.dispatch` """ - if self.member_list is None: - # if the list is currently uninitialized, - # no subscribers actually happened, so - # we can safely drop the incoming event. + # if no subscribers, drop event + if not self.list: return @@ -436,6 +442,11 @@ class LazyGuildDispatcher(Dispatcher): channel_id ) + # if we don't find a guild, we just + # set it the same as the channel. + if not guild_id: + guild_id = channel_id + gml = GuildMemberList(guild_id, channel_id, self) self.state[channel_id] = gml self.guild_map[guild_id].append(channel_id) diff --git a/schema.sql b/schema.sql index 9f9bda9..11b2caf 100644 --- a/schema.sql +++ b/schema.sql @@ -331,7 +331,7 @@ CREATE TABLE IF NOT EXISTS channel_overwrites ( channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, -- target_type = 0 -> use target_user - -- target_type = 1 -> user target_role + -- target_type = 1 -> use target_role -- discord already has overwrite.type = 'role' | 'member' -- so this allows us to be more compliant with the API target_type integer default null, From 3b2f6062fc20437b8671ce13871c1f123155ca5c Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 16:41:22 -0300 Subject: [PATCH 58/69] pubsub.lazy_guild: subscribe user to the given range --- litecord/pubsub/lazy_guild.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 5d865e4..c4cf6ef 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -380,7 +380,7 @@ class GuildMemberList: if itemcount < 0: continue - # TODO: subscribe user to the slice + self.state[session_id].add((start, end)) reply['ops'].append({ 'op': 'SYNC', From 7bcd08ef7aabd8d7408db957003b4390568911b9 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 17:13:30 -0300 Subject: [PATCH 59/69] pubsub.lazy_guild: remove LazyGuildDispatcher.unsub states subscribe via shard_query only. --- litecord/pubsub/lazy_guild.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index c4cf6ef..e8f7b30 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -459,10 +459,6 @@ class LazyGuildDispatcher(Dispatcher): self.guild_map[guild_id] )) - async def sub(self, chan_id, session_id): - gml = await self.get_gml(chan_id) - await gml.sub(session_id) - async def unsub(self, chan_id, session_id): gml = await self.get_gml(chan_id) await gml.unsub(session_id) From c212cbd39281f56cfd925862dc38c42442b796de Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 17:53:31 -0300 Subject: [PATCH 60/69] pubsub.lazy_guild: change some instance vars to properties - utils: add index_by_func --- litecord/pubsub/lazy_guild.py | 26 +++++++++++++++++++------- litecord/utils.py | 10 ++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index e8f7b30..8063383 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -12,6 +12,7 @@ from litecord.pubsub.dispatcher import Dispatcher from litecord.permissions import ( Permissions, overwrite_find_mix, get_permissions, role_permissions ) +from litecord.utils import index_by_func log = Logger(__name__) @@ -88,19 +89,30 @@ class GuildMemberList: self.guild_id = guild_id self.channel_id = channel_id - # a really long chain of classes to get - # to the storage instance... self.main = main_lg - self.storage = self.main.app.storage - self.presence = self.main.app.presence - self.state_man = self.main.app.state_manager - self.list = MemberList(None, None, None, None) - #: {session_id: set[list]} + #: store the states that are subscribed to the list + # type is{session_id: set[list]} self.state = defaultdict(set) + @property + def storage(self): + """Get the global :class:`Storage` instance.""" + return self.main.app.storage + + @property + def presence(self): + """Get the global :class:`PresenceManager` instance.""" + return self.main.app.presence + + @property + def state_man(self): + """Get the global :class:`StateManager` instance.""" + return self.main.app.state_manager + def _set_empty_list(self): + """Set the member list as being empty.""" self.list = MemberList(None, None, None, None) async def _init_check(self): diff --git a/litecord/utils.py b/litecord/utils.py index 3949194..1a2b676 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -27,3 +27,13 @@ async def task_wrapper(name: str, coro): def dict_get(mapping, key, default): """Return `default` even when mapping[key] is None.""" return mapping.get(key) or default + + +def index_by_func(function, indexable: iter) -> int: + """Search in an idexable and return the index number + for an iterm that has func(item) = True.""" + for index, item in indexable: + if function(item): + return index + + return None From bd9c4cb26cfd8ca009488f5371403c58250f6ed3 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 18:52:50 -0300 Subject: [PATCH 61/69] pubsub.lazy_guild: add implementation for pres_update - utils: fix index_by_func --- litecord/gateway/websocket.py | 6 +- litecord/presence.py | 8 +- litecord/pubsub/lazy_guild.py | 165 +++++++++++++++++++++++++++------- litecord/utils.py | 2 +- 4 files changed, 140 insertions(+), 41 deletions(-) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 6226fe6..8001a28 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -710,7 +710,7 @@ class GatewayWebsocket: list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE' list_data = { - 'id': "everyone" // ?? + 'id': channel_id | 'everyone', 'guild_id': guild_id, 'ops': [ @@ -723,10 +723,10 @@ class GatewayWebsocket: // exists if op = 'SYNC' 'items': sync_item[], - // exists if op = 'INSERT' or 'DELETE' + // exists if op == 'INSERT' | 'DELETE' | 'UPDATE' 'index': num, - // exists if op = 'INSERT' + // exists if op == 'INSERT' | 'UPDATE' 'item': sync_item, } ], diff --git a/litecord/presence.py b/litecord/presence.py index 48ad80f..67f8edd 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -113,9 +113,11 @@ class PresenceManager: for member_list in lists: session_ids = await member_list.pres_update( int(member['user']['id']), - member['roles'], - state['status'], - game + { + 'roles': member['roles'], + 'status': state['status'], + 'game': game + } ) log.debug('Lazy Dispatch to {}', diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 8063383..6467d1c 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -48,10 +48,40 @@ class MemberList: Yields a tuple containing :class:`GroupInfo` and the List[Presence] for the group. """ + if not self.groups: + return + for group in self.groups: yield group, self.data[group.gid] +@dataclass +class Operation: + """Represents a member list operation.""" + list_op: str + params: Dict[str, Any] + + @property + def to_dict(self) -> dict: + res = { + 'op': self.list_op + } + + if self.list_op == 'SYNC': + res['items'] = self.params['items'] + + if self.list_op in ('SYNC', 'INVALIDATE'): + res['range'] = self.params['range'] + + if self.list_op in ('INSERT', 'DELETE', 'UPDATE'): + res['index'] = self.params['index'] + + if self.list_op in ('INSERT', 'UPDATE'): + res['item'] = self.params['item'] + + return res + + def _to_simple_group(presence: dict) -> str: """Return a simple group (not a role), given a presence.""" return 'offline' if presence['status'] == 'offline' else 'online' @@ -111,6 +141,13 @@ class GuildMemberList: """Get the global :class:`StateManager` instance.""" return self.main.app.state_manager + @property + def list_id(self): + """get the id of the member list.""" + return ('everyone' + if self.channel_id == self.guild_id + else str(self.channel_id)) + def _set_empty_list(self): """Set the member list as being empty.""" self.list = MemberList(None, None, None, None) @@ -311,7 +348,7 @@ class GuildMemberList: """Subscribe a shard to the member list.""" await self._init_check() - async def unsub(self, session_id: str): + def unsub(self, session_id: str): """Unsubscribe a shard from the member list""" try: self.state.pop(session_id) @@ -326,6 +363,41 @@ class GuildMemberList: if not self.state: self._set_empty_list() + def get_state(self, session_id: str): + state = self.state_man.fetch_raw(session_id) + + if not state: + self.unsub(session_id) + return + + return state + + async def _dispatch_sess(self, session_ids: List[str], + operations: List[Operation]): + + # construct the payload to dispatch + payload = { + 'id': self.list_id, + 'guild_id': str(self.guild_id), + + 'groups': [ + { + 'count': len(presences), + 'id': group.gid + } for group, presences in self.list + ], + + 'ops': [ + operation.to_dict + for operation in operations + ] + } + + states = map(self.get_state, session_ids) + for state in (s for s in states if s is not None): + await state.ws.dispatch( + 'GUILD_MEMBER_LIST_UPDATE', payload) + async def shard_query(self, session_id: str, ranges: list): """Send a GUILD_MEMBER_LIST_UPDATE event for a shard that is querying about the member list. @@ -342,9 +414,7 @@ class GuildMemberList: # a guild list with a channel id of the guild # represents the 'everyone' global list. - list_id = ('everyone' - if self.channel_id == self.guild_id - else str(self.channel_id)) + list_id = self.list_id # if everyone can read the channel, # we direct the request to the 'everyone' gml instance @@ -365,25 +435,7 @@ class GuildMemberList: await self._init_check() - # make sure this is a sane state - state = self.state_man.fetch_raw(session_id) - if not state: - await self.unsub(session_id) - return - - reply = { - 'guild_id': str(self.guild_id), - 'id': list_id, - - 'groups': [ - { - 'count': len(presences), - 'id': group.gid - } for group, presences in self.list - ], - - 'ops': [], - } + ops = [] for start, end in ranges: itemcount = end - start @@ -394,19 +446,64 @@ class GuildMemberList: self.state[session_id].add((start, end)) - reply['ops'].append({ - 'op': 'SYNC', + ops.append(Operation('SYNC', { 'range': [start, end], - 'items': self.items[start:end], - }) + 'items': self.items[start:end] + })) - # the first GUILD_MEMBER_LIST_UPDATE for a shard - # is dispatched here. - await state.ws.dispatch('GUILD_MEMBER_LIST_UPDATE', reply) + await self._dispatch_sess([session_id], ops) - async def pres_update(self, user_id: int, roles: List[str], - status: str, game: dict) -> List[str]: - return list(self.state) + async def pres_update(self, user_id: int, + partial_presence: Dict[str, Any]): + """Update a presence inside the member listlist.""" + await self._init_check() + + for _group, presences in self.list: + p_idx = index_by_func( + lambda p: p['user']['id'] == str(user_id), + presences) + + if not p_idx: + continue + + presences[p_idx].update(partial_presence) + + item_index = index_by_func( + lambda p: p.get('user', {}).get('id') == str(user_id), + self.items + ) + + pprint.pprint(self.items) + + if not item_index: + log.warning('lazy guild got invalid pres update uid={}', + user_id) + return [] + + item = self.items[item_index] + + def _is_in(sess_id): + ranges = self.state[sess_id] + + for range_start, range_end in ranges: + if range_start <= item_index <= range_end: + return True + + return False + + session_ids = filter(_is_in, self.state.keys()) + + await self._dispatch_sess( + session_ids, + [ + Operation('UPDATE', { + 'index': item_index, + 'item': item, + }) + ] + ) + + return list(session_ids) async def dispatch(self, event: str, data: Any): """Modify the member list and dispatch the respective @@ -473,4 +570,4 @@ class LazyGuildDispatcher(Dispatcher): async def unsub(self, chan_id, session_id): gml = await self.get_gml(chan_id) - await gml.unsub(session_id) + gml.unsub(session_id) diff --git a/litecord/utils.py b/litecord/utils.py index 1a2b676..2fda9d5 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -32,7 +32,7 @@ def dict_get(mapping, key, default): def index_by_func(function, indexable: iter) -> int: """Search in an idexable and return the index number for an iterm that has func(item) = True.""" - for index, item in indexable: + for index, item in enumerate(indexable): if function(item): return index From 773ab8fd18472ba17d95732486c082c267d92352 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 19:08:31 -0300 Subject: [PATCH 62/69] pubsub.lazy_guild: fix fetching user id from item - pubsub.lazy_guild: fix get_state on unknown session id --- litecord/pubsub/lazy_guild.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 6467d1c..253115c 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -364,14 +364,13 @@ class GuildMemberList: self._set_empty_list() def get_state(self, session_id: str): - state = self.state_man.fetch_raw(session_id) - - if not state: + try: + state = self.state_man.fetch_raw(session_id) + return state + except KeyError: self.unsub(session_id) return - return state - async def _dispatch_sess(self, session_ids: List[str], operations: List[Operation]): @@ -394,10 +393,16 @@ class GuildMemberList: } states = map(self.get_state, session_ids) + dispatched = [] + for state in (s for s in states if s is not None): await state.ws.dispatch( 'GUILD_MEMBER_LIST_UPDATE', payload) + dispatched.append(state.session_id) + + return dispatched + async def shard_query(self, session_id: str, ranges: list): """Send a GUILD_MEMBER_LIST_UPDATE event for a shard that is querying about the member list. @@ -468,13 +473,14 @@ class GuildMemberList: presences[p_idx].update(partial_presence) + def _get_id(p): + return p.get('member', {}).get('user', {}).get('id') + item_index = index_by_func( - lambda p: p.get('user', {}).get('id') == str(user_id), + lambda p: _get_id(p) == str(user_id), self.items ) - pprint.pprint(self.items) - if not item_index: log.warning('lazy guild got invalid pres update uid={}', user_id) @@ -493,7 +499,7 @@ class GuildMemberList: session_ids = filter(_is_in, self.state.keys()) - await self._dispatch_sess( + return await self._dispatch_sess( session_ids, [ Operation('UPDATE', { @@ -503,8 +509,6 @@ class GuildMemberList: ] ) - return list(session_ids) - async def dispatch(self, event: str, data: Any): """Modify the member list and dispatch the respective events to subscribed shards. From e1a946eb876fe011a8c1c0dfe5e844fd65481cbc Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 19:10:27 -0300 Subject: [PATCH 63/69] pubsub.lazy_guild: use filter instead of genexpr --- litecord/pubsub/lazy_guild.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 253115c..a7f3248 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -393,9 +393,11 @@ class GuildMemberList: } states = map(self.get_state, session_ids) + states = filter(lambda state: state is not None, states) + dispatched = [] - for state in (s for s in states if s is not None): + for state in states: await state.ws.dispatch( 'GUILD_MEMBER_LIST_UPDATE', payload) From a394c02477db7ac6b7b02c4e63a3fa262d0b7759 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 19:15:07 -0300 Subject: [PATCH 64/69] schemas: add USER_SETTINGS.status --- litecord/schemas.py | 2 ++ schema.sql | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/litecord/schemas.py b/litecord/schemas.py index 710ccd0..e74dda0 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -496,6 +496,8 @@ USER_SETTINGS = { 'show_current_game': {'type': 'boolean', 'required': False}, 'timezone_offset': {'type': 'number', 'required': False}, + + 'status': {'type': 'status_external', 'required': False} } RELATIONSHIP = { diff --git a/schema.sql b/schema.sql index 11b2caf..2d3d5c6 100644 --- a/schema.sql +++ b/schema.sql @@ -134,6 +134,10 @@ CREATE TABLE IF NOT EXISTS user_settings ( -- appearance message_display_compact bool DEFAULT false, + + -- for now we store status but don't + -- actively use it, since the official client + -- sends its own presence on IDENTIFY status text DEFAULT 'online' NOT NULL, theme text DEFAULT 'dark' NOT NULL, developer_mode bool DEFAULT true, From 7be9d30f5db917407f66045d000be17a6a7fd086 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 19:16:44 -0300 Subject: [PATCH 65/69] cmd.migration: add debug log --- manage/cmd/migration/command.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/manage/cmd/migration/command.py b/manage/cmd/migration/command.py index fd237fb..a1e87c6 100644 --- a/manage/cmd/migration/command.py +++ b/manage/cmd/migration/command.py @@ -117,6 +117,8 @@ async def migrate_cmd(app, _args): local_change = local_change or 0 latest_change = ctx.latest + log.debug('local: {}, latest: {}', local_change, latest_change) + if local_change == latest_change: print('no changes to do, exiting') return From 8b093d3d1610ad431ed67ad8208228b071a932c4 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 21:07:31 -0300 Subject: [PATCH 66/69] pubsub.lazy_guild: better algorithm for presence updates this gives the separation between "complex" and "simple" presence updates we can generalize on. --- litecord/pubsub/lazy_guild.py | 130 +++++++++++++++++++++++++--------- 1 file changed, 97 insertions(+), 33 deletions(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index a7f3248..3370a5d 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -245,6 +245,27 @@ class GuildMemberList: ] self.list.group_info = {g.gid: g for g in role_groups} + async def get_group(self, member_id: int, + roles: List[Union[str, int]], + status: str) -> int: + """Return a fitting group ID for the user.""" + member_roles = list(map(int, roles)) + + # get the member's permissions relative to the channel + # (accounting for channel overwrites) + member_perms = await get_permissions( + member_id, self.channel_id, storage=self.storage) + + if not member_perms.bits.read_messages: + return None + + # if the member is offline, we + # default give them the offline group. + group_id = ('offline' if status == 'offline' + else self._calc_member_group(member_roles, status)) + + return group_id + async def _pass_1(self, guild_presences: List[Presence]): """First pass on generating the member list. @@ -253,22 +274,9 @@ class GuildMemberList: for presence in guild_presences: member_id = int(presence['user']['id']) - # list of roles for the member - member_roles = list(map(int, presence['roles'])) - - # get the member's permissions relative to the channel - # (accounting for channel overwrites) - member_perms = await get_permissions( - member_id, self.channel_id, storage=self.storage) - - if not member_perms.bits.read_messages: - continue - - # if the member is offline, we - # default give them the offline group. - status = presence['status'] - group_id = ('offline' if status == 'offline' - else self._calc_member_group(member_roles, status)) + group_id = await self.get_group( + member_id, presence['roles'], presence['status'] + ) self.list.data[group_id].append(presence) @@ -460,24 +468,12 @@ class GuildMemberList: await self._dispatch_sess([session_id], ops) - async def pres_update(self, user_id: int, - partial_presence: Dict[str, Any]): - """Update a presence inside the member listlist.""" - await self._init_check() - - for _group, presences in self.list: - p_idx = index_by_func( - lambda p: p['user']['id'] == str(user_id), - presences) - - if not p_idx: - continue - - presences[p_idx].update(partial_presence) - - def _get_id(p): - return p.get('member', {}).get('user', {}).get('id') + async def _pres_update_simple(self, user_id: int): + def _get_id(item): + # item can be a group item or a member item + return item.get('member', {}).get('user', {}).get('id') + # get the updated item's index item_index = index_by_func( lambda p: _get_id(p) == str(user_id), self.items @@ -490,6 +486,8 @@ class GuildMemberList: item = self.items[item_index] + # only dispatch to sessions + # that are subscribed to the given item's index def _is_in(sess_id): ranges = self.state[sess_id] @@ -501,6 +499,8 @@ class GuildMemberList: session_ids = filter(_is_in, self.state.keys()) + # simple update means we just give an UPDATE + # operation return await self._dispatch_sess( session_ids, [ @@ -511,6 +511,70 @@ class GuildMemberList: ] ) + async def _pres_update_complex(self, user_id: int, + old_group: str, new_group: str): + raise NotImplementedError + + async def pres_update(self, user_id: int, + partial_presence: Dict[str, Any]): + """Update a presence inside the member list. + + There are 4 types of updates that can happen for a user in a group: + - from 'offline' to any + - from any to 'offline' + - from any to any + - from G to G (with G being any group) + + any: 'online' | role_id + + All first, second, and third updates are 'complex' updates, + which means we'll have to change the group the user is on + to account for them. + + The fourth is a 'simple' change, since we're not changing + the group a user is on, and so there's less overhead + involved. + """ + await self._init_check() + + old_group, old_presence = None, None + + for group, presences in self.list: + p_idx = index_by_func( + lambda p: p['user']['id'] == str(user_id), + presences) + + if not p_idx: + continue + + # make a copy since we're modifying in-place + old_group = group.gid + old_presence = dict(presences[p_idx]) + + # be ready if it is a simple update + presences[p_idx].update(partial_presence) + break + + if not old_group: + log.warning('pres update with unknown old group uid={}', + user_id) + return [] + + roles = partial_presence.get('roles', old_presence['roles']) + new_status = partial_presence.get('status', old_presence['status']) + + new_group = await self.get_group(user_id, roles, new_status) + + log.debug('pres update: gid={} cid={} old_g={} new_g={}', + self.guild_id, self.channel_id, old_group, new_group) + + # if we're going to the same group, + # treat this as a simple update + if old_group == new_group: + return await self._pres_update_simple(user_id) + + return await self._pres_update_complex(user_id, old_group, new_group) + async def dispatch(self, event: str, data: Any): """Modify the member list and dispatch the respective events to subscribed shards. From 134cc0eec857c7eb6c21a118482ea37f71e56692 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 22:51:16 -0300 Subject: [PATCH 67/69] pubsub.lazy_guild: add draft impl for complex updates --- litecord/pubsub/lazy_guild.py | 142 +++++++++++++++++++++++++++------- 1 file changed, 114 insertions(+), 28 deletions(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 3370a5d..b2e4b59 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -87,6 +87,19 @@ def _to_simple_group(presence: dict) -> str: return 'offline' if presence['status'] == 'offline' else 'online' +def display_name(member_nicks: Dict[str, str], presence: Presence) -> str: + """Return the display name of a presence. + + Used to sort groups. + """ + uid = presence['user']['id'] + + uname = presence['user']['username'] + nick = member_nicks.get(uid) + + return nick or uname + + class GuildMemberList: """This class stores the current member list information for a guild (by channel). @@ -280,7 +293,8 @@ class GuildMemberList: self.list.data[group_id].append(presence) - async def _sort_groups(self): + async def get_member_nicks_dict(self) -> dict: + """Get a dictionary with nickname information.""" members = await self.storage.get_member_data(self.guild_id) # make a dictionary of member ids to nicknames @@ -289,17 +303,16 @@ class GuildMemberList: member_nicks = {m['user']['id']: m.get('nick') for m in members} + return member_nicks + + async def _sort_groups(self): + member_nicks = await self.get_member_nicks_dict() + for group_members in self.list.data.values(): - def display_name(presence: Presence) -> str: - uid = presence['user']['id'] - - uname = presence['user']['username'] - nick = member_nicks.get(uid) - - return nick or uname # this should update the list in-place - group_members.sort(key=display_name) + group_members.sort( + key=lambda p: display_name(member_nicks, p)) async def _init_member_list(self): """Generate the main member list with groups.""" @@ -468,36 +481,46 @@ class GuildMemberList: await self._dispatch_sess([session_id], ops) - async def _pres_update_simple(self, user_id: int): + def get_item_index(self, user_id: Union[str, int]): def _get_id(item): # item can be a group item or a member item return item.get('member', {}).get('user', {}).get('id') # get the updated item's index - item_index = index_by_func( + return index_by_func( lambda p: _get_id(p) == str(user_id), self.items ) + def state_is_subbed(self, item_index, session_id: str) -> bool: + """Return if a state's ranges include the given + item index.""" + + ranges = self.state[sess_id] + + for range_start, range_end in ranges: + if range_start <= item_index <= range_end: + return True + + return False + + def get_subs(self, item_index: int) -> filter: + """Get the list of subscribed states to a given item.""" + return filter( + lambda sess_id: self.state_is_subbed(item_index, sess_id), + self.state.keys() + ) + + async def _pres_update_simple(self, user_id: int): + item_index = self.get_item_index(user_id) + if not item_index: log.warning('lazy guild got invalid pres update uid={}', user_id) return [] item = self.items[item_index] - - # only dispatch to sessions - # that are subscribed to the given item's index - def _is_in(sess_id): - ranges = self.state[sess_id] - - for range_start, range_end in ranges: - if range_start <= item_index <= range_end: - return True - - return False - - session_ids = filter(_is_in, self.state.keys()) + session_ids = self.get_subs(item_index) # simple update means we just give an UPDATE # operation @@ -512,8 +535,69 @@ class GuildMemberList: ) async def _pres_update_complex(self, user_id: int, - old_group: str, new_group: str): - raise NotImplementedError + old_group: str, old_index: int, + new_group: str): + """Move a member between groups.""" + log.debug('complex update: uid={} old={} old_idx={} new={}', + user_id, old_group, old_index, new_group) + old_group_presences = self.list.data[old_group] + old_item_index = self.get_item_index(user_id) + + # make a copy of current presence to insert in the new group + current_presence = dict(old_group_presences[old_index]) + + # step 1: remove the old presence (old_index is relative + # to the group, and not the items list) + del old_group_presences[old_index] + + # we need to insert current_presence to the new group + # but we also need to calculate its index to insert on. + presences = self.list.data[new_group] + + best_index = 0 + member_nicks = await self.get_member_nicks_dict() + current_name = display_name(member_nicks, current_presence) + + # go through each one until we find the best placement + for presence in presences: + name = display_name(member_nicks, presence) + + print(name, current_name, name < current_name) + + # TODO: check if this works + if name < current_name: + break + + best_index += 1 + + # insert the presence at the index + presences.insert(best_index + 1, current_presence) + + new_item_index = self.get_item_index(user_id) + + log.debug('assigned new item index {} to uid {}', + new_item_index, user_id) + + session_ids_old = self.get_subs(old_item_index) + session_ids_new = self.get_subs(new_item_index) + + # dispatch events to both the old states and + # new states. + return await self._dispatch_sess( + session_ids_old + session_ids_new, + [ + Operation('DELETE', { + 'index': old_item_index, + }), + + Operation('INSERT', { + 'index': new_item_index, + 'item': { + 'member': current_presence + } + }) + ] + ) async def pres_update(self, user_id: int, partial_presence: Dict[str, Any]): @@ -537,7 +621,7 @@ class GuildMemberList: """ await self._init_check() - old_group, old_presence = None, None + old_group, old_index, old_presence = None, None, None for group, presences in self.list: p_idx = index_by_func( @@ -549,6 +633,7 @@ class GuildMemberList: # make a copy since we're modifying in-place old_group = group.gid + old_index = p_idx old_presence = dict(presences[p_idx]) # be ready if it is a simple update @@ -573,7 +658,8 @@ class GuildMemberList: if old_group == new_group: return await self._pres_update_simple(user_id) - return await self._pres_update_complex(user_id, old_group, new_group) + return await self._pres_update_complex( + user_id, old_group, old_index, new_group) async def dispatch(self, event: str, data: Any): """Modify the member list and dispatch the respective From 3b532fa8b0471fe7f1b6d381ecf7e7fbf87ff656 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 22:52:26 -0300 Subject: [PATCH 68/69] blueprints.users: fix settings being updated for everyone --- litecord/blueprints/users.py | 3 ++- litecord/pubsub/lazy_guild.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 0f75891..cd47013 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -318,7 +318,8 @@ async def patch_current_settings(): await app.db.execute(f""" UPDATE user_settings SET {key}=$1 - """, j[key]) + WHERE id = $2 + """, j[key], user_id) settings = await app.storage.get_user_settings(user_id) await app.dispatcher.dispatch_user( diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index b2e4b59..7658ca5 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -496,7 +496,7 @@ class GuildMemberList: """Return if a state's ranges include the given item index.""" - ranges = self.state[sess_id] + ranges = self.state[session_id] for range_start, range_end in ranges: if range_start <= item_index <= range_end: From 748eacf11244c5f9a8e21155d0bc0d769fe2c732 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 7 Nov 2018 23:03:48 -0300 Subject: [PATCH 69/69] pubsub.lazy_guild: fix bugs around p_idx calculation --- litecord/pubsub/lazy_guild.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 7658ca5..d94a0bf 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -482,6 +482,7 @@ class GuildMemberList: await self._dispatch_sess([session_id], ops) def get_item_index(self, user_id: Union[str, int]): + """Get the item index a user is on.""" def _get_id(item): # item can be a group item or a member item return item.get('member', {}).get('user', {}).get('id') @@ -514,7 +515,7 @@ class GuildMemberList: async def _pres_update_simple(self, user_id: int): item_index = self.get_item_index(user_id) - if not item_index: + if item_index is None: log.warning('lazy guild got invalid pres update uid={}', user_id) return [] @@ -584,7 +585,9 @@ class GuildMemberList: # dispatch events to both the old states and # new states. return await self._dispatch_sess( - session_ids_old + session_ids_new, + # inefficient, but necessary since we + # want to merge both session ids. + list(session_ids_old) + list(session_ids_new), [ Operation('DELETE', { 'index': old_item_index, @@ -628,7 +631,11 @@ class GuildMemberList: lambda p: p['user']['id'] == str(user_id), presences) - if not p_idx: + log.debug('p_idx for group {!r} = {}', + group.gid, p_idx) + + if p_idx is None: + log.debug('skipping group {}', group) continue # make a copy since we're modifying in-place