""" Litecord Copyright (C) 2018-2019 Luna Mendes This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 3 of the License. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ from typing import List, Dict, Any from logbook import Logger from litecord.enums import RelationshipType log = Logger(__name__) class UserStorage: """Storage functions related to a single user.""" def __init__(self, storage): self.storage = storage self.db = storage.db async def fetch_notes(self, user_id: int) -> dict: """Fetch a users' notes""" note_rows = await self.db.fetch( """ SELECT target_id, note FROM notes WHERE user_id = $1 """, user_id, ) return {str(row["target_id"]): row["note"] for row in note_rows} async def get_user_settings(self, user_id: int) -> Dict[str, Any]: """Get current user settings.""" row = await self.storage.fetchrow_with_json( """ SELECT * FROM user_settings WHERE id = $1 """, user_id, ) if row is None: log.info("Generating user settings for {}", user_id) await self.db.execute( """ INSERT INTO user_settings (id) VALUES ($1) """, user_id, ) # recalling get_user_settings # should work after adding return await self.get_user_settings(user_id) drow = dict(row) drow.pop("id") return drow async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]: """Get all relationships for a user.""" # first, fetch all friendships outgoing # from the user _friend = RelationshipType.FRIEND.value _block = RelationshipType.BLOCK.value _incoming = RelationshipType.INCOMING.value _outgoing = RelationshipType.OUTGOING.value # check all outgoing friends friends = await self.db.fetch( """ SELECT user_id, peer_id, rel_type FROM relationships WHERE user_id = $1 AND rel_type = $2 """, user_id, _friend, ) friends = list(map(dict, friends)) # mutuals is a list of ints # of people who are actually friends # and accepted the friend request mutuals = [] # for each outgoing, find if theres an outgoing from them for row in friends: is_friend = await self.db.fetchrow( """ SELECT user_id, peer_id FROM relationships WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3 """, row["peer_id"], row["user_id"], _friend, ) if is_friend is not None: mutuals.append(row["peer_id"]) # fetch friend requests directed at us incoming_friends = await self.db.fetch( """ SELECT user_id, peer_id FROM relationships WHERE peer_id = $1 AND rel_type = $2 """, user_id, _friend, ) # only need their ids incoming_friends = [ r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals ] # only fetch blocks we did, # not fetching the ones people did to us blocks = await self.db.fetch( """ SELECT user_id, peer_id, rel_type FROM relationships WHERE user_id = $1 AND rel_type = $2 """, user_id, _block, ) blocks = list(map(dict, blocks)) res = [] for drow in friends: drow["type"] = drow["rel_type"] drow["id"] = str(drow["peer_id"]) drow.pop("rel_type") # check if the receiver is a mutual # if it isnt, its still on a friend request stage if drow["peer_id"] not in mutuals: drow["type"] = _outgoing drow["user"] = await self.storage.get_user(drow["peer_id"]) drow.pop("user_id") drow.pop("peer_id") res.append(drow) for peer_id in incoming_friends: res.append( { "id": str(peer_id), "user": await self.storage.get_user(peer_id), "type": _incoming, } ) for drow in blocks: drow["type"] = drow["rel_type"] drow.pop("rel_type") drow["id"] = str(drow["peer_id"]) drow["user"] = await self.storage.get_user(drow["peer_id"]) drow.pop("user_id") drow.pop("peer_id") res.append(drow) return res async def get_friend_ids(self, user_id: int) -> List[int]: """Get all friend IDs for a user.""" rels = await self.get_relationships(user_id) return [ int(r["user"]["id"]) for r in rels if r["type"] == RelationshipType.FRIEND.value ] async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: """Get all DM channels for a user, including group DMs. This will only fetch channels the user has in their state, which is different than the whole list of DM channels. """ dm_ids = await self.db.fetch( """ SELECT dm_id FROM dm_channel_state WHERE user_id = $1 """, user_id, ) dm_ids = [r["dm_id"] for r in dm_ids] res = [] for dm_id in dm_ids: dm_chan = await self.storage.get_dm(dm_id, user_id) res.append(dm_chan) return res async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]: """Get the read state for a user.""" rows = await self.db.fetch( """ SELECT channel_id, last_message_id, mention_count FROM user_read_state WHERE user_id = $1 """, user_id, ) res = [] for row in rows: drow = dict(row) drow["id"] = str(drow["channel_id"]) drow.pop("channel_id") drow["last_message_id"] = str(drow["last_message_id"]) res.append(drow) return res async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List: chan_overrides = [] overrides = await self.db.fetch( """ SELECT channel_id::text, muted, message_notifications FROM guild_settings_channel_overrides WHERE user_id = $1 AND guild_id = $2 """, user_id, guild_id, ) for chan_row in overrides: dcrow = dict(chan_row) chan_overrides.append(dcrow) return chan_overrides async def get_guild_settings_one(self, user_id: int, guild_id: int) -> dict: """Get guild settings information for a single guild.""" row = await self.db.fetchrow( """ SELECT guild_id::text, suppress_everyone, muted, message_notifications, mobile_push FROM guild_settings WHERE user_id = $1 AND guild_id = $2 """, user_id, guild_id, ) if not row: await self.db.execute( """ INSERT INTO guild_settings (user_id, guild_id) VALUES ($1, $2) """, user_id, guild_id, ) return await self.get_guild_settings_one(user_id, guild_id) gid = int(row["guild_id"]) drow = dict(row) chan_overrides = await self._get_chan_overrides(user_id, gid) return {**drow, **{"channel_overrides": chan_overrides}} async def get_guild_settings(self, user_id: int): """Get the specific User Guild Settings, for all guilds a user is on.""" res = [] settings = await self.db.fetch( """ SELECT guild_id::text, suppress_everyone, muted, message_notifications, mobile_push FROM guild_settings WHERE user_id = $1 """, user_id, ) for row in settings: gid = int(row["guild_id"]) drow = dict(row) chan_overrides = await self._get_chan_overrides(user_id, gid) res.append({**drow, **{"channel_overrides": chan_overrides}}) return res async def get_user_guilds(self, user_id: int) -> List[int]: """Get all guild IDs a user is on.""" guild_ids = await self.db.fetch( """ SELECT guild_id FROM members WHERE user_id = $1 """, user_id, ) return [row["guild_id"] for row in guild_ids] async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]: """Get a list of guilds two separate users have in common.""" if user_id == peer_id: # if we are trying to query the mutual guilds with ourselves, we # only need to give the list of guilds we are on. # doing the INTERSECT has some edge-cases that can fuck up testing, # such as a user querying its own profile card while they are # not in any guilds. return await self.get_user_guilds(user_id) or [0] mutual_guilds = await self.db.fetch( """ SELECT guild_id FROM members WHERE user_id = $1 INTERSECT SELECT guild_id FROM members WHERE user_id = $2 """, user_id, peer_id, ) mutual_guilds = [r["guild_id"] for r in mutual_guilds] return mutual_guilds async def are_friends_with(self, user_id: int, peer_id: int) -> bool: """Return if two people are friends. This returns false even if there is a friend request. """ return await self.db.fetchval( """ SELECT ( SELECT EXISTS( SELECT rel_type FROM relationships WHERE user_id = $1 AND peer_id = $2 AND rel_type = 1 ) ) AND ( SELECT EXISTS( SELECT rel_type FROM relationships WHERE user_id = $2 AND peer_id = $1 AND rel_type = 1 ) ) """, user_id, peer_id, ) async def get_gdms_internal(self, user_id) -> List[int]: """Return a list of Group DM IDs the user is a member of.""" rows = await self.db.fetch( """ SELECT id FROM group_dm_members WHERE member_id = $1 """, user_id, ) return [r["id"] for r in rows] async def get_gdms(self, user_id) -> List[Dict[str, Any]]: """Get list of group DMs a user is in.""" gdm_ids = await self.get_gdms_internal(user_id) res = [] for gdm_id in gdm_ids: res.append(await self.storage.get_channel(gdm_id, user_id=user_id)) return res