From 11d4b54f879f1929d7dd4989ddf8f9d2cc69dce1 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 17 Nov 2018 02:14:10 -0300 Subject: [PATCH] split Storage into UserStorage this should help with the amount of methods being tossed in the Storage class. --- litecord/blueprints/dms.py | 2 +- litecord/blueprints/relationships.py | 7 +- litecord/blueprints/user/settings.py | 5 +- litecord/gateway/websocket.py | 13 +- litecord/storage.py | 272 +------------------------- litecord/user_storage.py | 280 +++++++++++++++++++++++++++ run.py | 4 + 7 files changed, 301 insertions(+), 282 deletions(-) create mode 100644 litecord/user_storage.py diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index 57a4556..011b652 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -19,7 +19,7 @@ bp = Blueprint('dms', __name__) async def get_dms(): """Get the open DMs for the user.""" user_id = await token_check() - dms = await app.storage.get_dms(user_id) + dms = await app.user_storage.get_dms(user_id) return jsonify(dms) diff --git a/litecord/blueprints/relationships.py b/litecord/blueprints/relationships.py index 43788d9..f24a238 100644 --- a/litecord/blueprints/relationships.py +++ b/litecord/blueprints/relationships.py @@ -12,7 +12,8 @@ bp = Blueprint('relationship', __name__) @bp.route('/@me/relationships', methods=['GET']) async def get_me_relationships(): user_id = await token_check() - return jsonify(await app.storage.get_relationships(user_id)) + return jsonify( + await app.user_storage.get_relationships(user_id)) async def _unsub_friend(user_id, peer_id): @@ -254,8 +255,8 @@ async def get_mutual_friends(peer_id: int): # NOTE: maybe this could be better with pure SQL calculations # but it would be beyond my current SQL knowledge, so... - user_rels = await app.storage.get_relationships(user_id) - peer_rels = await app.storage.get_relationships(peer_id) + user_rels = await app.user_storage.get_relationships(user_id) + peer_rels = await app.user_storage.get_relationships(peer_id) user_friends = {rel['user']['id'] for rel in user_rels if rel['type'] == _friend} diff --git a/litecord/blueprints/user/settings.py b/litecord/blueprints/user/settings.py index ae003a4..d812827 100644 --- a/litecord/blueprints/user/settings.py +++ b/litecord/blueprints/user/settings.py @@ -51,7 +51,7 @@ async def patch_guild_settings(guild_id: int): # querying the guild settings information before modifying # will make sure they exist in the table. - await app.storage.get_guild_settings_one(user_id, guild_id) + await app.user_storage.get_guild_settings_one(user_id, guild_id) for field in (k for k in j.keys() if k != 'channel_overrides'): await app.db.execute(f""" @@ -86,7 +86,8 @@ async def patch_guild_settings(guild_id: int): AND guild_settings_channel_overrides.channel_id = $3 """, user_id, guild_id, chan_id, chan_overrides[field]) - settings = await app.storage.get_guild_settings_one(user_id, guild_id) + settings = await app.user_storage.get_guild_settings_one( + user_id, guild_id) await app.dispatcher.dispatch_user( user_id, 'USER_GUILD_SETTINGS_UPDATE', settings) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 2cd25cd..32531b1 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -247,20 +247,21 @@ class GatewayWebsocket: user_id = self.state.user_id - relationships = await self.storage.get_relationships(user_id) + relationships = await self.user_storage.get_relationships(user_id) friend_ids = [int(r['user']['id']) for r in relationships if r['type'] == RelationshipType.FRIEND.value] friend_presences = await self.ext.presence.friend_presences(friend_ids) + settings = await self.user_storage.get_user_settings(user_id) return { - 'user_settings': await self.storage.get_user_settings(user_id), - 'notes': await self.storage.fetch_notes(user_id), + 'user_settings': settings, + 'notes': await self.user_storage.fetch_notes(user_id), 'relationships': relationships, 'presences': friend_presences, - 'read_state': await self.storage.get_read_state(user_id), - 'user_guild_settings': await self.storage.get_guild_settings( + 'read_state': await self.user_storage.get_read_state(user_id), + 'user_guild_settings': await self.user_storage.get_guild_settings( user_id), 'friend_suggestion_count': 0, @@ -288,7 +289,7 @@ class GatewayWebsocket: 'v': 6, 'user': user, - 'private_channels': await self.storage.get_dms(user_id), + 'private_channels': await self.user_storage.get_dms(user_id), 'guilds': guilds, 'session_id': self.state.session_id, diff --git a/litecord/storage.py b/litecord/storage.py index 76a11ed..27912ec 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -3,8 +3,8 @@ from typing import List, Dict, Any from logbook import Logger -from .enums import ChannelType, RelationshipType -from .schemas import USER_MENTION, ROLE_MENTION +from litecord.enums import ChannelType +from litecord.schemas import USER_MENTION, ROLE_MENTION from litecord.blueprints.channel.reactions import ( EmojiType, emoji_sql, partial_emoji ) @@ -715,17 +715,6 @@ class Storage: return res - async def fetch_notes(self, user_id: int) -> dict: - """Fetch a users' notes""" - note_rows = await self.db.fetch(""" - SELECT target_id, note - FROM notes - WHERE user_id = $1 - """, user_id) - - return {str(row['target_id']): row['note'] - for row in note_rows} - async def get_invite(self, invite_code: str) -> dict: """Fetch invite information given its code.""" invite = await self.db.fetchrow(""" @@ -802,131 +791,6 @@ class Storage: return dinv - async def get_user_settings(self, user_id: int) -> Dict[str, Any]: - """Get current user settings.""" - row = await self._fetchrow_with_json(""" - SELECT * - FROM user_settings - WHERE id = $1 - """, user_id) - - if not row: - log.info('Generating user settings for {}', user_id) - - await self.db.execute(""" - INSERT INTO user_settings (id) - VALUES ($1) - """, user_id) - - # recalling get_user_settings - # should work after adding - return await self.get_user_settings(user_id) - - drow = dict(row) - drow.pop('id') - return drow - - async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]: - """Get all relationships for a user.""" - # first, fetch all friendships outgoing - # from the user - _friend = RelationshipType.FRIEND.value - _block = RelationshipType.BLOCK.value - _incoming = RelationshipType.INCOMING.value - _outgoing = RelationshipType.OUTGOING.value - - # check all outgoing friends - friends = await self.db.fetch(""" - SELECT user_id, peer_id, rel_type - FROM relationships - WHERE user_id = $1 AND rel_type = $2 - """, user_id, _friend) - friends = list(map(dict, friends)) - - # mutuals is a list of ints - # of people who are actually friends - # and accepted the friend request - mutuals = [] - - # for each outgoing, find if theres an outgoing from them - for row in friends: - is_friend = await self.db.fetchrow( - """ - SELECT user_id, peer_id - FROM relationships - WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 - """, row['peer_id'], row['user_id'], - _friend) - - if is_friend is not None: - mutuals.append(row['peer_id']) - - # fetch friend requests directed at us - incoming_friends = await self.db.fetch(""" - SELECT user_id, peer_id - FROM relationships - WHERE peer_id = $1 AND rel_type = $2 - """, user_id, _friend) - - # only need their ids - incoming_friends = [r['user_id'] for r in incoming_friends - if r['user_id'] not in mutuals] - - # only fetch blocks we did, - # not fetching the ones people did to us - blocks = await self.db.fetch(""" - SELECT user_id, peer_id, rel_type - FROM relationships - WHERE user_id = $1 AND rel_type = $2 - """, user_id, _block) - blocks = list(map(dict, blocks)) - - res = [] - - for drow in friends: - drow['type'] = drow['rel_type'] - drow['id'] = str(drow['peer_id']) - drow.pop('rel_type') - - # check if the receiver is a mutual - # if it isnt, its still on a friend request stage - if drow['peer_id'] not in mutuals: - drow['type'] = _outgoing - - drow['user'] = await self.get_user(drow['peer_id']) - - drow.pop('user_id') - drow.pop('peer_id') - res.append(drow) - - for peer_id in incoming_friends: - res.append({ - 'id': str(peer_id), - 'user': await self.get_user(peer_id), - 'type': _incoming, - }) - - for drow in blocks: - drow['type'] = drow['rel_type'] - drow.pop('rel_type') - - drow['id'] = str(drow['peer_id']) - drow['user'] = await self.get_user(drow['peer_id']) - - drow.pop('user_id') - drow.pop('peer_id') - res.append(drow) - - return res - - async def get_friend_ids(self, user_id: int) -> List[int]: - """Get all friend IDs for a user.""" - rels = await self.get_relationships(user_id) - - return [int(r['user']['id']) - for r in rels - if r['type'] == RelationshipType.FRIEND.value] - async def get_dm(self, dm_id: int, user_id: int = None): dm_chan = await self.get_channel(dm_id) @@ -937,50 +801,6 @@ class Storage: return dm_chan - async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: - """Get all DM channels for a user, including group DMs. - - This will only fetch channels the user has in their state, - which is different than the whole list of DM channels. - """ - dm_ids = await self.db.fetch(""" - SELECT dm_id - FROM dm_channel_state - WHERE user_id = $1 - """, user_id) - - dm_ids = [r['dm_id'] for r in dm_ids] - - res = [] - - for dm_id in dm_ids: - dm_chan = await self.get_dm(dm_id, user_id) - res.append(dm_chan) - - return res - - async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: - """Get the read state for a user.""" - rows = await self.db.fetch(""" - SELECT channel_id, last_message_id, mention_count - FROM user_read_state - WHERE user_id = $1 - """, user_id) - - res = [] - - for row in rows: - drow = dict(row) - - drow['id'] = str(drow['channel_id']) - drow.pop('channel_id') - - drow['last_message_id'] = str(drow['last_message_id']) - - res.append(drow) - - return res - async def guild_from_channel(self, channel_id: int): """Get the guild id coming from a channel id.""" return await self.db.fetchval(""" @@ -1003,91 +823,3 @@ class Storage: parties.remove(user_id) return parties[0] - - async def get_guild_settings_one(self, user_id: int, - guild_id: int) -> dict: - """Get guild settings information for a single guild.""" - row = await self.db.fetchrow(""" - SELECT guild_id::text, suppress_everyone, muted, - message_notifications, mobile_push - FROM guild_settings - WHERE user_id = $1 AND guild_id = $2 - """, user_id, guild_id) - - if not row: - await self.db.execute(""" - INSERT INTO guild_settings (user_id, guild_id) - VALUES ($1, $2) - """, user_id, guild_id) - - return await self.get_guild_settings_one(user_id, guild_id) - - gid = int(row['guild_id']) - drow = dict(row) - - chan_overrides = {} - - overrides = await self.db.fetch(""" - SELECT channel_id::text, muted, message_notifications - FROM guild_settings_channel_overrides - WHERE - user_id = $1 - AND guild_id = $2 - """, user_id, gid) - - for chan_row in overrides: - dcrow = dict(chan_row) - - chan_id = dcrow['channel_id'] - dcrow.pop('channel_id') - - chan_overrides[chan_id] = dcrow - - return {**drow, **{ - 'channel_overrides': chan_overrides - }} - - async def get_guild_settings(self, user_id: int): - """Get the specific User Guild Settings, - for all guilds a user is on.""" - - res = [] - - settings = await self.db.fetch(""" - SELECT guild_id::text, suppress_everyone, muted, - message_notifications, mobile_push - FROM guild_settings - WHERE user_id = $1 - """, user_id) - - for row in settings: - gid = int(row['guild_id']) - drow = dict(row) - - chan_overrides = {} - - overrides = await self.db.fetch(""" - SELECT channel_id::text, muted, message_notifications - FROM guild_settings_channel_overrides - WHERE - user_id = $1 - AND guild_id = $2 - """, user_id, gid) - - for chan_row in overrides: - dcrow = dict(chan_row) - - # channel_id isn't on the value of the dict - # so we query it (for the key) then pop - # from the value - chan_id = dcrow['channel_id'] - dcrow.pop('channel_id') - - chan_overrides[chan_id] = dcrow - - res.append({**drow, **{ - 'channel_overrides': chan_overrides - }}) - - return res - diff --git a/litecord/user_storage.py b/litecord/user_storage.py new file mode 100644 index 0000000..f02c15c --- /dev/null +++ b/litecord/user_storage.py @@ -0,0 +1,280 @@ +from typing import List, Dict, Any + +from logbook import Logger +from litecord.enums import RelationshipType + +log = Logger(__name__) + + +class UserStorage: + """Storage functions related to a single user.""" + def __init__(self, storage): + self.storage = storage + self.db = storage.db + + async def fetch_notes(self, user_id: int) -> dict: + """Fetch a users' notes""" + note_rows = await self.db.fetch(""" + SELECT target_id, note + FROM notes + WHERE user_id = $1 + """, user_id) + + return {str(row['target_id']): row['note'] + for row in note_rows} + + async def get_user_settings(self, user_id: int) -> Dict[str, Any]: + """Get current user settings.""" + row = await self._fetchrow_with_json(""" + SELECT * + FROM user_settings + WHERE id = $1 + """, user_id) + + if not row: + log.info('Generating user settings for {}', user_id) + + await self.db.execute(""" + INSERT INTO user_settings (id) + VALUES ($1) + """, user_id) + + # recalling get_user_settings + # should work after adding + return await self.get_user_settings(user_id) + + drow = dict(row) + drow.pop('id') + return drow + + async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]: + """Get all relationships for a user.""" + # first, fetch all friendships outgoing + # from the user + _friend = RelationshipType.FRIEND.value + _block = RelationshipType.BLOCK.value + _incoming = RelationshipType.INCOMING.value + _outgoing = RelationshipType.OUTGOING.value + + # check all outgoing friends + friends = await self.db.fetch(""" + SELECT user_id, peer_id, rel_type + FROM relationships + WHERE user_id = $1 AND rel_type = $2 + """, user_id, _friend) + friends = list(map(dict, friends)) + + # mutuals is a list of ints + # of people who are actually friends + # and accepted the friend request + mutuals = [] + + # for each outgoing, find if theres an outgoing from them + for row in friends: + is_friend = await self.db.fetchrow( + """ + SELECT user_id, peer_id + FROM relationships + WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 + """, row['peer_id'], row['user_id'], + _friend) + + if is_friend is not None: + mutuals.append(row['peer_id']) + + # fetch friend requests directed at us + incoming_friends = await self.db.fetch(""" + SELECT user_id, peer_id + FROM relationships + WHERE peer_id = $1 AND rel_type = $2 + """, user_id, _friend) + + # only need their ids + incoming_friends = [r['user_id'] for r in incoming_friends + if r['user_id'] not in mutuals] + + # only fetch blocks we did, + # not fetching the ones people did to us + blocks = await self.db.fetch(""" + SELECT user_id, peer_id, rel_type + FROM relationships + WHERE user_id = $1 AND rel_type = $2 + """, user_id, _block) + blocks = list(map(dict, blocks)) + + res = [] + + for drow in friends: + drow['type'] = drow['rel_type'] + drow['id'] = str(drow['peer_id']) + drow.pop('rel_type') + + # check if the receiver is a mutual + # if it isnt, its still on a friend request stage + if drow['peer_id'] not in mutuals: + drow['type'] = _outgoing + + drow['user'] = await self.get_user(drow['peer_id']) + + drow.pop('user_id') + drow.pop('peer_id') + res.append(drow) + + for peer_id in incoming_friends: + res.append({ + 'id': str(peer_id), + 'user': await self.storage.get_user(peer_id), + 'type': _incoming, + }) + + for drow in blocks: + drow['type'] = drow['rel_type'] + drow.pop('rel_type') + + drow['id'] = str(drow['peer_id']) + drow['user'] = await self.storage.get_user(drow['peer_id']) + + drow.pop('user_id') + drow.pop('peer_id') + res.append(drow) + + return res + + async def get_friend_ids(self, user_id: int) -> List[int]: + """Get all friend IDs for a user.""" + rels = await self.get_relationships(user_id) + + return [int(r['user']['id']) + for r in rels + if r['type'] == RelationshipType.FRIEND.value] + + async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: + """Get all DM channels for a user, including group DMs. + + This will only fetch channels the user has in their state, + which is different than the whole list of DM channels. + """ + dm_ids = await self.db.fetch(""" + SELECT dm_id + FROM dm_channel_state + WHERE user_id = $1 + """, user_id) + + dm_ids = [r['dm_id'] for r in dm_ids] + + res = [] + + for dm_id in dm_ids: + dm_chan = await self.storage.get_dm(dm_id, user_id) + res.append(dm_chan) + + return res + + async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: + """Get the read state for a user.""" + rows = await self.db.fetch(""" + SELECT channel_id, last_message_id, mention_count + FROM user_read_state + WHERE user_id = $1 + """, user_id) + + res = [] + + for row in rows: + drow = dict(row) + + drow['id'] = str(drow['channel_id']) + drow.pop('channel_id') + + drow['last_message_id'] = str(drow['last_message_id']) + + res.append(drow) + + return res + + async def get_guild_settings_one(self, user_id: int, + guild_id: int) -> dict: + """Get guild settings information for a single guild.""" + row = await self.db.fetchrow(""" + SELECT guild_id::text, suppress_everyone, muted, + message_notifications, mobile_push + FROM guild_settings + WHERE user_id = $1 AND guild_id = $2 + """, user_id, guild_id) + + if not row: + await self.db.execute(""" + INSERT INTO guild_settings (user_id, guild_id) + VALUES ($1, $2) + """, user_id, guild_id) + + return await self.get_guild_settings_one(user_id, guild_id) + + gid = int(row['guild_id']) + drow = dict(row) + + chan_overrides = {} + + overrides = await self.db.fetch(""" + SELECT channel_id::text, muted, message_notifications + FROM guild_settings_channel_overrides + WHERE + user_id = $1 + AND guild_id = $2 + """, user_id, gid) + + for chan_row in overrides: + dcrow = dict(chan_row) + + chan_id = dcrow['channel_id'] + dcrow.pop('channel_id') + + chan_overrides[chan_id] = dcrow + + return {**drow, **{ + 'channel_overrides': chan_overrides + }} + + async def get_guild_settings(self, user_id: int): + """Get the specific User Guild Settings, + for all guilds a user is on.""" + + res = [] + + settings = await self.db.fetch(""" + SELECT guild_id::text, suppress_everyone, muted, + message_notifications, mobile_push + FROM guild_settings + WHERE user_id = $1 + """, user_id) + + for row in settings: + gid = int(row['guild_id']) + drow = dict(row) + + chan_overrides = {} + + overrides = await self.db.fetch(""" + SELECT channel_id::text, muted, message_notifications + FROM guild_settings_channel_overrides + WHERE + user_id = $1 + AND guild_id = $2 + """, user_id, gid) + + for chan_row in overrides: + dcrow = dict(chan_row) + + # channel_id isn't on the value of the dict + # so we query it (for the key) then pop + # from the value + chan_id = dcrow['channel_id'] + dcrow.pop('channel_id') + + chan_overrides[chan_id] = dcrow + + res.append({**drow, **{ + 'channel_overrides': chan_overrides + }}) + + return res diff --git a/run.py b/run.py index e0a35d3..f897cd3 100644 --- a/run.py +++ b/run.py @@ -43,6 +43,7 @@ from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager from litecord.storage import Storage +from litecord.user_storage import UserStorage from litecord.dispatcher import EventDispatcher from litecord.presence import PresenceManager from litecord.images import IconManager @@ -184,7 +185,10 @@ def init_app_managers(app): app.loop = asyncio.get_event_loop() app.ratelimiter = RatelimitManager() app.state_manager = StateManager() + app.storage = Storage(app.db) + app.user_storage = UserStorage(app.storage) + app.icons = IconManager(app) app.dispatcher = EventDispatcher(app)