From 7c878515e950ef1babed4be374ebccbb8aaa23b0 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 09:38:48 -0300 Subject: [PATCH 01/20] make JobManager.spawn copy current app context --- litecord/jobs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/litecord/jobs.py b/litecord/jobs.py index 4ad3852..4f08271 100644 --- a/litecord/jobs.py +++ b/litecord/jobs.py @@ -18,7 +18,9 @@ along with this program. If not, see . """ import asyncio +from typing import Any +from quart.ctx import copy_current_app_context from logbook import Logger log = Logger(__name__) @@ -47,9 +49,14 @@ class JobManager: def spawn(self, coro): """Spawn a given future or coroutine in the background.""" - task = self.loop.create_task(self._wrapper(coro)) + @copy_current_app_context + async def _ctx_wrapper_bg() -> Any: + return await coro + + task = self.loop.create_task(self._wrapper(_ctx_wrapper_bg())) self.jobs.append(task) + return task def close(self): """Close the job manager, cancelling all existing jobs. From 7278c15d9ce500c29d40f9d53398d5adb0117d0f Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 09:41:52 -0300 Subject: [PATCH 02/20] gateway.websocket: remove WebsocketObjects we can just use the app object directly. --- litecord/gateway/websocket.py | 90 +++++++++++++---------------------- 1 file changed, 32 insertions(+), 58 deletions(-) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 6579982..f4c23c7 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -21,7 +21,7 @@ import collections import asyncio import pprint import zlib -from typing import List, Dict, Any +from typing import List, Dict, Any, Iterable from random import randint import websockets @@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple( "WebsocketProperties", "v encoding compress zctx zsctx tasks" ) -WebsocketObjects = collections.namedtuple( - "WebsocketObjects", - ( - "db", - "state_manager", - "storage", - "loop", - "dispatcher", - "presence", - "ratelimiter", - "user_storage", - "voice", - ), -) - class GatewayWebsocket: """Main gateway websocket logic.""" def __init__(self, ws, app, **kwargs): - self.ext = WebsocketObjects( - app.db, - app.state_manager, - app.storage, - app.loop, - app.dispatcher, - app.presence, - app.ratelimiter, - app.user_storage, - app.voice, - ) - - self.storage = self.ext.storage - self.user_storage = self.ext.user_storage - self.presence = self.ext.presence + self.app = app + self.storage = app.storage + self.user_storage = app.user_storage + self.presence = app.presence self.ws = ws self.wsp = WebsocketProperties( @@ -225,7 +199,7 @@ class GatewayWebsocket: await self.send({"op": op_code, "d": data, "t": None, "s": None}) def _check_ratelimit(self, key: str, ratelimit_key): - ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}") + ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}") bucket = ratelimit.get_bucket(ratelimit_key) return bucket.update_rate_limit() @@ -245,7 +219,7 @@ class GatewayWebsocket: if task: task.cancel() - self.wsp.tasks["heartbeat"] = self.ext.loop.create_task( + self.wsp.tasks["heartbeat"] = self.app.loop.create_task( task_wrapper("hb wait", self._hb_wait(interval)) ) @@ -330,7 +304,7 @@ class GatewayWebsocket: if r["type"] == RelationshipType.FRIEND.value ] - friend_presences = await self.ext.presence.friend_presences(friend_ids) + friend_presences = await self.app.presence.friend_presences(friend_ids) settings = await self.user_storage.get_user_settings(user_id) return { @@ -377,14 +351,14 @@ class GatewayWebsocket: await self.dispatch("READY", {**base_ready, **user_ready}) # async dispatch of guilds - self.ext.loop.create_task(self._guild_dispatch(guilds)) + self.app.loop.create_task(self._guild_dispatch(guilds)) async def _check_shards(self, shard, user_id): """Check if the given `shard` value in IDENTIFY has good enough values. """ current_shard, shard_count = shard - guilds = await self.ext.db.fetchval( + guilds = await self.app.db.fetchval( """ SELECT COUNT(*) FROM members @@ -460,7 +434,7 @@ class GatewayWebsocket: ("channel", gdm_ids), ] - await self.ext.dispatcher.mass_sub(user_id, channels_to_sub) + await self.app.dispatcher.mass_sub(user_id, channels_to_sub) if not self.state.bot: # subscribe to all friends @@ -468,7 +442,7 @@ class GatewayWebsocket: # when they come online) friend_ids = await self.user_storage.get_friend_ids(user_id) log.info("subscribing to {} friends", len(friend_ids)) - await self.ext.dispatcher.sub_many("friend", user_id, friend_ids) + await self.app.dispatcher.sub_many("friend", user_id, friend_ids) async def update_status(self, status: dict): """Update the status of the current websocket connection.""" @@ -520,7 +494,7 @@ class GatewayWebsocket: f'Updating presence status={status["status"]} for ' f"uid={self.state.user_id}" ) - await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence) + await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence) async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" @@ -558,13 +532,13 @@ class GatewayWebsocket: presence = data.get("presence") try: - user_id = await raw_token_check(token, self.ext.db) + user_id = await raw_token_check(token, self.app.db) except (Unauthorized, Forbidden): raise WebsocketClose(4004, "Authentication failed") await self._connect_ratelimit(user_id) - bot = await self.ext.db.fetchval( + bot = await self.app.db.fetchval( """ SELECT bot FROM users WHERE id = $1 @@ -587,7 +561,7 @@ class GatewayWebsocket: ) # link the state to the user - self.ext.state_manager.insert(self.state) + self.app.state_manager.insert(self.state) await self.update_status(presence) await self.subscribe_all(data.get("guild_subscriptions", True)) @@ -631,12 +605,12 @@ class GatewayWebsocket: # if its null and null, disconnect the user from any voice # TODO: maybe just leave from DMs? idk... if channel_id is None and guild_id is None: - return await self.ext.voice.leave_all(self.state.user_id) + return await self.app.voice.leave_all(self.state.user_id) # if guild is not none but channel is, we are leaving # a guild's channel if channel_id is None: - return await self.ext.voice.leave(guild_id, self.state.user_id) + return await self.app.voice.leave(guild_id, self.state.user_id) # fetch an existing state given user and guild OR user and channel chan_type = ChannelType(await self.storage.get_chan_type(channel_id)) @@ -659,10 +633,10 @@ class GatewayWebsocket: # this state id format takes care of that. voice_key = (self.state.user_id, state_id2) - voice_state = await self.ext.voice.get_state(voice_key) + voice_state = await self.app.voice.get_state(voice_key) if voice_state is None: - return await self.ext.voice.create_state(voice_key, data) + return await self.app.voice.create_state(voice_key, data) same_guild = guild_id == voice_state.guild_id same_channel = channel_id == voice_state.channel_id @@ -670,10 +644,10 @@ class GatewayWebsocket: prop = await self._vsu_get_prop(voice_state, data) if same_guild and same_channel: - return await self.ext.voice.update_state(voice_state, prop) + return await self.app.voice.update_state(voice_state, prop) if same_guild and not same_channel: - return await self.ext.voice.move_state(voice_state, channel_id) + return await self.app.voice.move_state(voice_state, channel_id) async def _handle_5(self, payload: Dict[str, Any]): """Handle OP 5 Voice Server Ping. @@ -698,9 +672,9 @@ class GatewayWebsocket: # since the state will be removed from # the manager, it will become unreachable # when trying to resume. - self.ext.state_manager.remove(self.state) + self.app.state_manager.remove(self.state) - async def _resume(self, replay_seqs: iter): + async def _resume(self, replay_seqs: Iterable): presences = [] try: @@ -740,12 +714,12 @@ class GatewayWebsocket: raise DecodeError("Invalid resume payload") try: - user_id = await raw_token_check(token, self.ext.db) + user_id = await raw_token_check(token, self.app.db) except (Unauthorized, Forbidden): raise WebsocketClose(4004, "Invalid token") try: - state = self.ext.state_manager.fetch(user_id, sess_id) + state = self.app.state_manager.fetch(user_id, sess_id) except KeyError: return await self.invalidate_session(False) @@ -948,7 +922,7 @@ class GatewayWebsocket: log.debug("lazy request: members: {}", data.get("members", [])) # make shard query - lazy_guilds = self.ext.dispatcher.backends["lazy_guild"] + lazy_guilds = self.app.dispatcher.backends["lazy_guild"] for chan_id, ranges in data.get("channels", {}).items(): chan_id = int(chan_id) @@ -992,10 +966,10 @@ class GatewayWebsocket: # close anyone trying to login while the # server is shutting down - if self.ext.state_manager.closed: + if self.app.state_manager.closed: raise WebsocketClose(4000, "state manager closed") - if not self.ext.state_manager.accept_new: + if not self.app.state_manager.accept_new: raise WebsocketClose(4000, "state manager closed for new") while True: @@ -1016,7 +990,7 @@ class GatewayWebsocket: task.cancel() if self.state: - self.ext.state_manager.remove(self.state) + self.app.state_manager.remove(self.state) self.state.ws = None self.state = None @@ -1031,14 +1005,14 @@ class GatewayWebsocket: # TODO: account for sharding # this only updates status to offline once # ALL shards have come offline - states = self.ext.state_manager.user_states(user_id) + states = self.app.state_manager.user_states(user_id) with_ws = [s for s in states if s.ws] # there arent any other states with websocket if not with_ws: offline = {"afk": False, "status": "offline", "game": None, "since": 0} - await self.ext.presence.dispatch_pres(user_id, offline) + await self.app.presence.dispatch_pres(user_id, offline) async def run(self): """Wrap :meth:`listen_messages` inside From 420646e76f2590c781ae8818bd41dac5f8633e90 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 10:00:35 -0300 Subject: [PATCH 03/20] revamp how Flags works instead of pulling a hack and injecting from_int() in `__init_subclass__`, we just make an actual function, and cache the wanted attributes in the subclass at its creation this fixes statis analyzers claiming from_int() doesn't exist on subclasses, for good reason, as they'd be turing-complete to do so, lol - auth: fix some mypy issues about reusing same variable --- litecord/auth.py | 30 +++++++++++++++--------------- litecord/enums.py | 30 ++++++++++++------------------ 2 files changed, 27 insertions(+), 33 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index 078a02f..b96e475 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -56,20 +56,20 @@ async def raw_token_check(token: str, db=None) -> int: # just try by fragments instead of # unpacking fragments = token.split(".") - user_id = fragments[0] + user_id_str = fragments[0] try: - user_id = base64.b64decode(user_id.encode()) - user_id = int(user_id) + user_id_decoded = base64.b64decode(user_id_str.encode()) + user_id = int(user_id_decoded) except (ValueError, binascii.Error): raise Unauthorized("Invalid user ID type") pwd_hash = await db.fetchval( """ - SELECT password_hash - FROM users - WHERE id = $1 - """, + SELECT password_hash + FROM users + WHERE id = $1 + """, user_id, ) @@ -88,10 +88,10 @@ async def raw_token_check(token: str, db=None) -> int: # with people leaving their clients open forever) await db.execute( """ - UPDATE users - SET last_session = (now() at time zone 'utc') - WHERE id = $1 - """, + UPDATE users + SET last_session = (now() at time zone 'utc') + WHERE id = $1 + """, user_id, ) @@ -128,10 +128,10 @@ async def admin_check() -> int: flags = await app.db.fetchval( """ - SELECT flags - FROM users - WHERE id = $1 - """, + SELECT flags + FROM users + WHERE id = $1 + """, user_id, ) diff --git a/litecord/enums.py b/litecord/enums.py index ce68e18..7c5b143 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -54,27 +54,21 @@ class Flags: """ def __init_subclass__(cls, **_kwargs): - attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x)) + # get only the members that represent a field + cls._attrs = inspect.getmembers(cls, lambda x: isinstance(x, int)) - def _make_int(value): - res = Flags() + @classmethod + def from_int(cls, value: int): + """Create a Flags from a given int value.""" + res = Flags() + setattr(res, "value", value) - setattr(res, "value", value) + for attr, val in cls._attrs: + has_attr = (value & val) == val + # set attributes dynamically + setattr(res, f"is_{attr}", has_attr) - for attr, val in attrs: - # get only the ones that represent a field in the - # number's bits - if not isinstance(val, int): - continue - - has_attr = (value & val) == val - - # set each attribute - setattr(res, f"is_{attr}", has_attr) - - return res - - cls.from_int = _make_int + return res class ChannelType(EasyEnum): From 2780ca41750504de926c4b7c55d57d65e714b274 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 10:24:20 -0300 Subject: [PATCH 04/20] guilds: remove app_ param from delete_guild() --- litecord/blueprints/guilds.py | 16 +++++++--------- tests/test_admin_api/test_guilds.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 7db4224..ce2b78b 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -393,20 +393,18 @@ async def _update_guild(guild_id): return jsonify(guild) -async def delete_guild(guild_id: int, *, app_=None): +async def delete_guild(guild_id: int): """Delete a single guild.""" - app_ = app_ or app - - await app_.db.execute( + await app.db.execute( """ - DELETE FROM guilds - WHERE guilds.id = $1 - """, + DELETE FROM guilds + WHERE guilds.id = $1 + """, guild_id, ) # Discord's client expects IDs being string - await app_.dispatcher.dispatch( + await app.dispatcher.dispatch( "guild", guild_id, "GUILD_DELETE", @@ -420,7 +418,7 @@ async def delete_guild(guild_id: int, *, app_=None): # remove from the dispatcher so nobody # becomes the little memer that tries to fuck up with # everybody's gateway - await app_.dispatcher.remove("guild", guild_id) + await app.dispatcher.remove("guild", guild_id) @bp.route("/", methods=["DELETE"]) diff --git a/tests/test_admin_api/test_guilds.py b/tests/test_admin_api/test_guilds.py index b6619e7..6ca61cb 100644 --- a/tests/test_admin_api/test_guilds.py +++ b/tests/test_admin_api/test_guilds.py @@ -54,6 +54,11 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False): return rjson +async def _delete_guild(test_cli, guild_id: int): + async with test_cli.app.app_context(): + await delete_guild(int(guild_id)) + + @pytest.mark.asyncio async def test_guild_fetch(test_cli_staff): """Test the creation and fetching of a guild via the Admin API.""" @@ -63,7 +68,7 @@ async def test_guild_fetch(test_cli_staff): try: await _fetch_guild(test_cli_staff, guild_id) finally: - await delete_guild(int(guild_id), app_=test_cli_staff.app) + await _delete_guild(test_cli_staff, int(guild_id)) @pytest.mark.asyncio @@ -91,7 +96,7 @@ async def test_guild_update(test_cli_staff): rjson = await _fetch_guild(test_cli_staff, guild_id) assert rjson["unavailable"] finally: - await delete_guild(int(guild_id), app_=test_cli_staff.app) + await _delete_guild(test_cli_staff, int(guild_id)) @pytest.mark.asyncio @@ -113,4 +118,4 @@ async def test_guild_delete(test_cli_staff): assert rjson["error"] assert rjson["code"] == GuildNotFound.error_code finally: - await delete_guild(int(guild_id), app_=test_cli_staff.app) + await _delete_guild(test_cli_staff, int(guild_id)) From 2024c4bdf8811154fb4e09c1ccfee5ee76970780 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 10:40:26 -0300 Subject: [PATCH 05/20] remove app parameters from embed functions --- litecord/blueprints/channel/messages.py | 11 ++--- litecord/embed/messages.py | 37 +++++++-------- litecord/embed/sanitizer.py | 60 +++++++------------------ 3 files changed, 39 insertions(+), 69 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 551508b..1a3d6b5 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -383,12 +383,8 @@ async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) return attachment_id -async def _spawn_embed(app_, payload, **kwargs): - app_.sched.spawn( - process_url_embed( - app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs - ) - ) +async def _spawn_embed(payload, **kwargs): + app.sched.spawn(process_url_embed(payload, **kwargs)) @bp.route("//messages", methods=["POST"]) @@ -458,7 +454,7 @@ async def _create_message(channel_id): # spawn url processor for embedding of images perms = await get_permissions(user_id, channel_id) if perms.bits.embed_links: - await _spawn_embed(app, payload) + await _spawn_embed(payload) # update read state for the author await app.db.execute( @@ -536,7 +532,6 @@ async def edit_message(channel_id, message_id): perms = await get_permissions(user_id, channel_id) if perms.bits.embed_links: await _spawn_embed( - app, { "id": message_id, "channel_id": channel_id, diff --git a/litecord/embed/messages.py b/litecord/embed/messages.py index ce23ea4..bc240ea 100644 --- a/litecord/embed/messages.py +++ b/litecord/embed/messages.py @@ -22,6 +22,7 @@ import asyncio import urllib.parse from pathlib import Path +from quart import current_app as app from logbook import Logger from litecord.embed.sanitizer import proxify, fetch_metadata, fetch_embed @@ -33,10 +34,10 @@ log = Logger(__name__) MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm") -async def insert_media_meta(url, config, session): +async def insert_media_meta(url): """Insert media metadata as an embed.""" - img_proxy_url = proxify(url, config=config) - meta = await fetch_metadata(url, config=config, session=session) + img_proxy_url = proxify(url) + meta = await fetch_metadata(url) if meta is None: return @@ -56,19 +57,19 @@ async def insert_media_meta(url, config, session): } -async def msg_update_embeds(payload, new_embeds, storage, dispatcher): +async def msg_update_embeds(payload, new_embeds): """Update the message with the given embeds and dispatch a MESSAGE_UPDATE to users.""" message_id = int(payload["id"]) channel_id = int(payload["channel_id"]) - await storage.execute_with_json( + await app.storage.execute_with_json( """ - UPDATE messages - SET embeds = $1 - WHERE messages.id = $2 - """, + UPDATE messages + SET embeds = $1 + WHERE messages.id = $2 + """, new_embeds, message_id, ) @@ -85,7 +86,9 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher): if "flags" in payload: update_payload["flags"] = payload["flags"] - await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload) + await app.dispatcher.dispatch( + "channel", channel_id, "MESSAGE_UPDATE", update_payload + ) def is_media_url(url) -> bool: @@ -102,15 +105,13 @@ def is_media_url(url) -> bool: return extension in MEDIA_EXTENSIONS -async def insert_mp_embed(parsed, config, session): +async def insert_mp_embed(parsed): """Insert mediaproxy embed.""" - embed = await fetch_embed(parsed, config=config, session=session) + embed = await fetch_embed(parsed) return embed -async def process_url_embed( - config, storage, dispatcher, session, payload: dict, *, delay=0 -): +async def process_url_embed(payload: dict, *, delay=0): """Process URLs in a message and generate embeds based on that.""" await asyncio.sleep(delay) @@ -145,9 +146,9 @@ async def process_url_embed( url = EmbedURL(url) if is_media_url(url): - embed = await insert_media_meta(url, config, session) + embed = await insert_media_meta(url) else: - embed = await insert_mp_embed(url, config, session) + embed = await insert_mp_embed(url) if not embed: continue @@ -160,4 +161,4 @@ async def process_url_embed( log.debug("made {} embeds for mid {}", len(new_embeds), message_id) - await msg_update_embeds(payload, new_embeds, storage, dispatcher) + await msg_update_embeds(payload, new_embeds) diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py index b14e436..14e8977 100644 --- a/litecord/embed/sanitizer.py +++ b/litecord/embed/sanitizer.py @@ -75,35 +75,24 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]): return False -def _mk_cfg_sess(config, session) -> tuple: - """Return a tuple of (config, session).""" - if config is None: - config = app.config - - if session is None: - session = app.session - - return config, session - - -def _md_base(config) -> Optional[tuple]: +def _md_base() -> Optional[tuple]: """Return the protocol and base url for the mediaproxy.""" - md_base_url = config["MEDIA_PROXY"] + md_base_url = app.config["MEDIA_PROXY"] if md_base_url is None: return None - proto = "https" if config["IS_SSL"] else "http" + proto = "https" if app.config["IS_SSL"] else "http" return proto, md_base_url -def make_md_req_url(config, scope: str, url): - """Make a mediaproxy request URL given the config, scope, and the url +def make_md_req_url(scope: str, url): + """Make a mediaproxy request URL given the scope and the url to be proxied. When MEDIA_PROXY is None, however, returns the original URL. """ - base = _md_base(config) + base = _md_base() if base is None: return url.url if isinstance(url, EmbedURL) else url @@ -111,38 +100,25 @@ def make_md_req_url(config, scope: str, url): return f"{proto}://{base_url}/{scope}/{url.to_md_path}" -def proxify(url, *, config=None) -> str: +def proxify(url) -> str: """Return a mediaproxy url for the given EmbedURL. Returns an /img/ scope.""" - config, _sess = _mk_cfg_sess(config, False) - if isinstance(url, str): url = EmbedURL(url) - return make_md_req_url(config, "img", url) + return make_md_req_url("img", url) async def _md_client_req( - config, session, scope: str, url, *, ret_resp=False + scope: str, url, *, ret_resp=False ) -> Optional[Union[Tuple, Dict]]: """Makes a request to the mediaproxy. This has common code between all the main mediaproxy request functions to decrease code repetition. - Note that config and session exist because there are cases where the app - isn't retrievable (as those functions usually run in background tasks, - not in the app itself). - Parameters ---------- - config: dict-like - the app configuration, if None, this will get the global one from the - app instance. - session: aiohttp client session - the aiohttp ClientSession instance to use, if None, this will get - the global one from the app. - scope: str the scope of your request. one of 'meta', 'img', or 'embed' are available for the mediaproxy's API. @@ -155,14 +131,12 @@ async def _md_client_req( the raw bytes of the response, but by the time this function is returned, the response object is invalid and the socket is closed """ - config, session = _mk_cfg_sess(config, session) - if not isinstance(url, EmbedURL): url = EmbedURL(url) - request_url = make_md_req_url(config, scope, url) + request_url = make_md_req_url(scope, url) - async with session.get(request_url) as resp: + async with app.session.get(request_url) as resp: if resp.status == 200: if ret_resp: return resp, await resp.read() @@ -174,18 +148,18 @@ async def _md_client_req( return None -async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]: +async def fetch_metadata(url) -> Optional[Dict]: """Fetch metadata for a url (image width, mime, etc).""" - return await _md_client_req(config, session, "meta", url) + return await _md_client_req("meta", url) -async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]: +async def fetch_raw_img(url) -> Optional[tuple]: """Fetch raw data for a url (the bytes given off, used to proxy images). Returns a tuple containing the response object and the raw bytes given by the website. """ - tup = await _md_client_req(config, session, "img", url, ret_resp=True) + tup = await _md_client_req("img", url, ret_resp=True) if not tup: return None @@ -193,13 +167,13 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]: return tup -async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]: +async def fetch_embed(url) -> Dict[str, Any]: """Fetch an embed for a given webpage (an automatically generated embed by the mediaproxy, look over the project on how it generates embeds). Returns a discord embed object. """ - return await _md_client_req(config, session, "embed", url) + return await _md_client_req("embed", url) async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]: From e0a849fa6a37123881968072da525f9de9fa4505 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 10:49:35 -0300 Subject: [PATCH 06/20] fix callers of embed functions due to param changes --- litecord/blueprints/channels.py | 2 +- litecord/blueprints/icons.py | 2 +- litecord/blueprints/webhooks.py | 4 +--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index a4b6c45..bcaeeb2 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -799,7 +799,7 @@ async def suppress_embeds(channel_id: int, message_id: int): message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds - await msg_update_embeds(message, [], app.storage, app.dispatcher) + await msg_update_embeds(message, []) elif not suppress and not url_embeds: # spawn process_url_embed to restore the embeds, if any await _msg_unset_flags(message_id, MessageFlags.suppress_embeds) diff --git a/litecord/blueprints/icons.py b/litecord/blueprints/icons.py index 9b0f378..a01509e 100644 --- a/litecord/blueprints/icons.py +++ b/litecord/blueprints/icons.py @@ -64,7 +64,7 @@ async def _get_default_user_avatar(default_id: int): async def _handle_webhook_avatar(md_url_redir: str): - md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir)) + md_url = make_md_req_url("img", EmbedURL(md_url_redir)) return redirect(md_url) diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py index 9a0be5f..47a1336 100644 --- a/litecord/blueprints/webhooks.py +++ b/litecord/blueprints/webhooks.py @@ -499,9 +499,7 @@ async def execute_webhook(webhook_id: int, webhook_token): await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) # spawn embedder in the background, even when we're on a webhook. - app.sched.spawn( - process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload) - ) + app.sched.spawn(process_url_embed(payload)) # we can assume its a guild text channel, so just call it await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here) From ce04ac5c5f35e8a6e0ba84233310eda231f06ee3 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 10:52:46 -0300 Subject: [PATCH 07/20] remove app param from guild_region_check() --- litecord/blueprints/admin_api/voice.py | 24 ++++++++++++------------ run.py | 2 +- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/litecord/blueprints/admin_api/voice.py b/litecord/blueprints/admin_api/voice.py index 334bcfe..e2b87d5 100644 --- a/litecord/blueprints/admin_api/voice.py +++ b/litecord/blueprints/admin_api/voice.py @@ -118,7 +118,7 @@ async def deprecate_region(region): return "", 204 -async def guild_region_check(app_): +async def guild_region_check(): """Check all guilds for voice region inconsistencies. Since the voice migration caused all guilds.region columns @@ -126,23 +126,23 @@ async def guild_region_check(app_): than one region setup. """ - regions = await app_.storage.all_voice_regions() + regions = await app.storage.all_voice_regions() if not regions: log.info("region check: no regions to move guilds to") return - res = await app_.db.execute( + res = await app.db.execute( """ - UPDATE guilds - SET region = ( - SELECT id - FROM voice_regions - OFFSET floor(random()*$1) - LIMIT 1 - ) - WHERE region = NULL - """, + UPDATE guilds + SET region = ( + SELECT id + FROM voice_regions + OFFSET floor(random()*$1) + LIMIT 1 + ) + WHERE region = NULL + """, len(regions), ) diff --git a/run.py b/run.py index d15dbc1..eaaa9b2 100644 --- a/run.py +++ b/run.py @@ -339,7 +339,7 @@ async def post_app_start(app_): # we'll need to start a billing job app_.sched.spawn(payment_job(app_)) app_.sched.spawn(api_index(app_)) - app_.sched.spawn(guild_region_check(app_)) + app_.sched.spawn(guild_region_check()) def start_websocket(host, port, ws_handler) -> asyncio.Future: From f6f50a1cff83298806aeb6dce9d29bac4a373c5c Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 11:23:40 -0300 Subject: [PATCH 08/20] remove app param from billing functions --- litecord/blueprints/user/billing.py | 134 ++++++++++-------------- litecord/blueprints/user/billing_job.py | 36 ++++--- run.py | 2 +- 3 files changed, 78 insertions(+), 94 deletions(-) diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index a1a4148..2ac1972 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -122,16 +122,13 @@ async def get_payment_source_ids(user_id: int) -> list: return [r["id"] for r in rows] -async def get_payment_ids(user_id: int, db=None) -> list: - if not db: - db = app.db - - rows = await db.fetch( +async def get_payment_ids(user_id: int) -> list: + rows = await app.db.fetch( """ - SELECT id - FROM user_payments - WHERE user_id = $1 - """, + SELECT id + FROM user_payments + WHERE user_id = $1 + """, user_id, ) @@ -151,18 +148,14 @@ async def get_subscription_ids(user_id: int) -> list: return [r["id"] for r in rows] -async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: +async def get_payment_source(user_id: int, source_id: int) -> dict: """Get a payment source's information.""" - - if not db: - db = app.db - - source_type = await db.fetchval( + source_type = await app.db.fetchval( """ - SELECT source_type - FROM user_payment_sources - WHERE id = $1 AND user_id = $2 - """, + SELECT source_type + FROM user_payment_sources + WHERE id = $1 AND user_id = $2 + """, source_id, user_id, ) @@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: fields = ",".join(specific_fields) - extras_row = await db.fetchrow( + extras_row = await app.db.fetchrow( f""" SELECT {fields}, billing_address, default_, id::text FROM user_payment_sources @@ -199,22 +192,19 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: return {**source, **derow} -async def get_subscription(subscription_id: int, db=None): +async def get_subscription(subscription_id: int): """Get a subscription's information.""" - if not db: - db = app.db - - row = await db.fetchrow( + row = await app.db.fetchrow( """ - SELECT id::text, source_id::text AS payment_source_id, - user_id, - payment_gateway, payment_gateway_plan_id, - period_start AS current_period_start, - period_end AS current_period_end, - canceled_at, s_type, status - FROM user_subscriptions - WHERE id = $1 - """, + SELECT id::text, source_id::text AS payment_source_id, + user_id, + payment_gateway, payment_gateway_plan_id, + period_start AS current_period_start, + period_end AS current_period_end, + canceled_at, s_type, status + FROM user_subscriptions + WHERE id = $1 + """, subscription_id, ) @@ -231,19 +221,16 @@ async def get_subscription(subscription_id: int, db=None): return drow -async def get_payment(payment_id: int, db=None): +async def get_payment(payment_id: int): """Get a single payment's information.""" - if not db: - db = app.db - - row = await db.fetchrow( + row = await app.db.fetchrow( """ - SELECT id::text, source_id, subscription_id, user_id, - amount, amount_refunded, currency, - description, status, tax, tax_inclusive - FROM user_payments - WHERE id = $1 - """, + SELECT id::text, source_id, subscription_id, user_id, + amount, amount_refunded, currency, + description, status, tax, tax_inclusive + FROM user_payments + WHERE id = $1 + """, payment_id, ) @@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None): drow["created_at"] = snowflake_datetime(int(drow["id"])) - drow["payment_source"] = await get_payment_source( - row["user_id"], row["source_id"], db - ) + drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"]) - drow["subscription"] = await get_subscription(row["subscription_id"], db) + drow["subscription"] = await get_subscription(row["subscription_id"]) return drow -async def create_payment(subscription_id, db=None): +async def create_payment(subscription_id): """Create a payment.""" - if not db: - db = app.db - - sub = await get_subscription(subscription_id, db) + sub = await get_subscription(subscription_id) new_id = get_snowflake() amount = AMOUNTS[sub["payment_gateway_plan_id"]] - await db.execute( + await app.db.execute( """ INSERT INTO user_payments ( id, source_id, subscription_id, user_id, @@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None): return new_id -async def process_subscription(app, subscription_id: int): +async def process_subscription(subscription_id: int): """Process a single subscription.""" - sub = await get_subscription(subscription_id, app.db) + sub = await get_subscription(subscription_id) user_id = int(sub["user_id"]) @@ -313,10 +295,10 @@ async def process_subscription(app, subscription_id: int): # payments), then we should update premium status first_payment_id = await app.db.fetchval( """ - SELECT MIN(id) - FROM user_payments - WHERE subscription_id = $1 - """, + SELECT MIN(id) + FROM user_payments + WHERE subscription_id = $1 + """, subscription_id, ) @@ -324,10 +306,10 @@ async def process_subscription(app, subscription_id: int): premium_since = await app.db.fetchval( """ - SELECT premium_since - FROM users - WHERE id = $1 - """, + SELECT premium_since + FROM users + WHERE id = $1 + """, user_id, ) @@ -343,10 +325,10 @@ async def process_subscription(app, subscription_id: int): old_flags = await app.db.fetchval( """ - SELECT flags - FROM users - WHERE id = $1 - """, + SELECT flags + FROM users + WHERE id = $1 + """, user_id, ) @@ -355,17 +337,17 @@ async def process_subscription(app, subscription_id: int): await app.db.execute( """ - UPDATE users - SET premium_since = $1, flags = $2 - WHERE id = $3 - """, + UPDATE users + SET premium_since = $1, flags = $2 + WHERE id = $3 + """, first_payment_ts, new_flags, user_id, ) # dispatch updated user to all possible clients - await mass_user_update(user_id, app) + await mass_user_update(user_id) @bp.route("/@me/billing/payment-sources", methods=["GET"]) @@ -474,11 +456,11 @@ async def _create_subscription(): 1, ) - await create_payment(new_id, app.db) + await create_payment(new_id) # make sure we update the user's premium status # and dispatch respective user updates to other people. - await process_subscription(app, new_id) + await process_subscription(new_id) return jsonify(await get_subscription(new_id)) diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py index 4148415..ee50c33 100644 --- a/litecord/blueprints/user/billing_job.py +++ b/litecord/blueprints/user/billing_job.py @@ -21,6 +21,8 @@ along with this program. If not, see . this file only serves the periodic payment job code. """ import datetime + +from quart import current_app as app from asyncio import sleep, CancelledError from logbook import Logger @@ -47,14 +49,14 @@ THRESHOLDS = { } -async def _resched(app): +async def _resched(): log.debug("waiting 30 minutes for job.") await sleep(30 * MINUTES) - app.sched.spawn(payment_job(app)) + app.sched.spawn(payment_job()) -async def _process_user_payments(app, user_id: int): - payments = await get_payment_ids(user_id, app.db) +async def _process_user_payments(user_id: int): + payments = await get_payment_ids(user_id) if not payments: log.debug("no payments for uid {}, skipping", user_id) @@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int): latest_payment = max(payments) - payment_data = await get_payment(latest_payment, app.db) + payment_data = await get_payment(latest_payment) # calculate the difference between this payment # and now. @@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int): delta = now - payment_tstamp sub_id = int(payment_data["subscription"]["id"]) - subscription = await get_subscription(sub_id, app.db) + subscription = await get_subscription(sub_id) # if the max payment is X days old, we create another. # X is 30 for monthly subscriptions of nitro, @@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int): # create_payment does not call any Stripe # or BrainTree APIs at all, since we'll just # give it as free. - await create_payment(sub_id, app.db) + await create_payment(sub_id) else: log.debug("sid={}, missing {} days", sub_id, threshold - delta.days) -async def payment_job(app): +async def payment_job(): """Main payment job function. This function will check through users' payments @@ -104,9 +106,9 @@ async def payment_job(app): user_ids = await app.db.fetch( """ - SELECT DISTINCT user_id - FROM user_payments - """ + SELECT DISTINCT user_id + FROM user_payments + """ ) log.debug("working {} users", len(user_ids)) @@ -115,24 +117,24 @@ async def payment_job(app): for row in user_ids: user_id = row["user_id"] try: - await _process_user_payments(app, user_id) + await _process_user_payments(user_id) except Exception: log.exception("error while processing user payments") subscribers = await app.db.fetch( """ - SELECT id - FROM user_subscriptions - """ + SELECT id + FROM user_subscriptions + """ ) for row in subscribers: try: - await process_subscription(app, row["id"]) + await process_subscription(row["id"]) except Exception: log.exception("error while processing subscription") log.debug("rescheduling..") try: - await _resched(app) + await _resched() except CancelledError: log.info("cancelled while waiting for resched") diff --git a/run.py b/run.py index eaaa9b2..55aeb37 100644 --- a/run.py +++ b/run.py @@ -337,7 +337,7 @@ async def api_index(app_): async def post_app_start(app_): # we'll need to start a billing job - app_.sched.spawn(payment_job(app_)) + app_.sched.spawn(payment_job()) app_.sched.spawn(api_index(app_)) app_.sched.spawn(guild_region_check()) From 2bc7bb39248f8b464235014ab94c4661a209d6f6 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 11:31:12 -0300 Subject: [PATCH 09/20] fix some calls not needing app --- litecord/blueprints/channels.py | 6 +----- litecord/blueprints/users.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index bcaeeb2..98ec291 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -809,11 +809,7 @@ async def suppress_embeds(channel_id: int, message_id: int): except KeyError: pass - app.sched.spawn( - process_url_embed( - app.config, app.storage, app.dispatcher, app.session, message - ) - ) + app.sched.spawn(process_url_embed(message)) return "", 204 diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index e45a1a3..763f18e 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -276,7 +276,7 @@ async def patch_me(): user.pop("password_hash") - _, private_user = await mass_user_update(user_id, app) + _, private_user = await mass_user_update(user_id) return jsonify(private_user) From ab89b70ddcfa28a5ae795d21f6dd98301c28fbc5 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 11:33:19 -0300 Subject: [PATCH 10/20] fix create_user call --- litecord/blueprints/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 81a25d1..ba0a2a5 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -120,7 +120,7 @@ async def _register_with_invite(): ) user_id, pwd_hash = await create_user( - data["username"], data["email"], data["password"], app.db + data["username"], data["email"], data["password"] ) return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)}) From 1efc65511cf8994e31041f91b198b07d44597606 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:31:47 -0300 Subject: [PATCH 11/20] create litecord.common --- litecord/common/__init__.py | 0 .../dm_checks.py => common/channels.py} | 78 ++++++- litecord/common/guilds.py | 206 ++++++++++++++++++ litecord/common/messages.py | 189 ++++++++++++++++ 4 files changed, 463 insertions(+), 10 deletions(-) create mode 100644 litecord/common/__init__.py rename litecord/{blueprints/channel/dm_checks.py => common/channels.py} (53%) create mode 100644 litecord/common/guilds.py create mode 100644 litecord/common/messages.py diff --git a/litecord/common/__init__.py b/litecord/common/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/litecord/blueprints/channel/dm_checks.py b/litecord/common/channels.py similarity index 53% rename from litecord/blueprints/channel/dm_checks.py rename to litecord/common/channels.py index e2cb195..dc85ef4 100644 --- a/litecord/blueprints/channel/dm_checks.py +++ b/litecord/common/channels.py @@ -16,15 +16,55 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - from quart import current_app as app -from litecord.errors import Forbidden + +from litecord.errors import ForbiddenDM from litecord.enums import RelationshipType -class ForbiddenDM(Forbidden): - error_code = 50007 +async def channel_ack( + user_id: int, guild_id: int, channel_id: int, message_id: int = None +): + """ACK a channel.""" + + if not message_id: + message_id = await app.storage.chan_last_message(channel_id) + + await app.db.execute( + """ + INSERT INTO user_read_state + (user_id, channel_id, last_message_id, mention_count) + VALUES + ($1, $2, $3, 0) + ON CONFLICT ON CONSTRAINT user_read_state_pkey + DO + UPDATE + SET last_message_id = $3, mention_count = 0 + WHERE user_read_state.user_id = $1 + AND user_read_state.channel_id = $2 + """, + user_id, + channel_id, + message_id, + ) + + if guild_id: + await app.dispatcher.dispatch_user_guild( + user_id, + guild_id, + "MESSAGE_ACK", + {"message_id": str(message_id), "channel_id": str(channel_id)}, + ) + else: + # we don't use ChannelDispatcher here because since + # guild_id is None, all user devices are already subscribed + # to the given channel (a dm or a group dm) + await app.dispatcher.dispatch_user( + user_id, + "MESSAGE_ACK", + {"message_id": str(message_id), "channel_id": str(channel_id)}, + ) async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): @@ -32,12 +72,12 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): # first step is checking if there is a block in any direction blockrow = await app.db.fetchrow( """ - SELECT rel_type - FROM relationships - WHERE rel_type = $3 - AND user_id IN ($1, $2) - AND peer_id IN ($1, $2) - """, + SELECT rel_type + FROM relationships + WHERE rel_type = $3 + AND user_id IN ($1, $2) + AND peer_id IN ($1, $2) + """, user_id, peer_id, RelationshipType.BLOCK.value, @@ -75,3 +115,21 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int): # if after this filtering we don't have any more guilds, error if not mutual_guilds: raise ForbiddenDM() + + +async def try_dm_state(user_id: int, dm_id: int): + """Try inserting the user into the dm state + for the given DM. + + Does not do anything if the user is already + in the dm state. + """ + await app.db.execute( + """ + INSERT INTO dm_channel_state (user_id, dm_id) + VALUES ($1, $2) + ON CONFLICT DO NOTHING + """, + user_id, + dm_id, + ) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py new file mode 100644 index 0000000..f1915e0 --- /dev/null +++ b/litecord/common/guilds.py @@ -0,0 +1,206 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +from quart import current_app as app + +from ..snowflake import get_snowflake +from ..permissions import get_role_perms +from ..utils import dict_get, maybe_lazy_guild_dispatch +from ..enums import ChannelType + + +async def remove_member(guild_id: int, member_id: int): + """Do common tasks related to deleting a member from the guild, + such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" + + await app.db.execute( + """ + DELETE FROM members + WHERE guild_id = $1 AND user_id = $2 + """, + guild_id, + member_id, + ) + + await app.dispatcher.dispatch_user_guild( + member_id, + guild_id, + "GUILD_DELETE", + {"guild_id": str(guild_id), "unavailable": False}, + ) + + await app.dispatcher.unsub("guild", guild_id, member_id) + + await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id) + + await app.dispatcher.dispatch_guild( + guild_id, + "GUILD_MEMBER_REMOVE", + {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ) + + +async def remove_member_multi(guild_id: int, members: list): + """Remove multiple members.""" + for member_id in members: + await remove_member(guild_id, member_id) + + +async def create_role(guild_id, name: str, **kwargs): + """Create a role in a guild.""" + new_role_id = get_snowflake() + + everyone_perms = await get_role_perms(guild_id, guild_id) + default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary) + + # update all roles so that we have space for pos 1, but without + # sending GUILD_ROLE_UPDATE for everyone + await app.db.execute( + """ + UPDATE roles + SET + position = position + 1 + WHERE guild_id = $1 + AND NOT (position = 0) + """, + guild_id, + ) + + await app.db.execute( + """ + INSERT INTO roles (id, guild_id, name, color, + hoist, position, permissions, managed, mentionable) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + new_role_id, + guild_id, + name, + dict_get(kwargs, "color", 0), + dict_get(kwargs, "hoist", False), + # always set ourselves on position 1 + 1, + int(dict_get(kwargs, "permissions", default_perms)), + False, + dict_get(kwargs, "mentionable", False), + ) + + role = await app.storage.get_role(new_role_id, guild_id) + + # we need to update the lazy guild handlers for the newly created group + await maybe_lazy_guild_dispatch(guild_id, "new_role", role) + + await app.dispatcher.dispatch_guild( + guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role} + ) + + return role + + +async def _specific_chan_create(channel_id, ctype, **kwargs): + if ctype == ChannelType.GUILD_TEXT: + await app.db.execute( + """ + INSERT INTO guild_text_channels (id, topic) + VALUES ($1, $2) + """, + channel_id, + kwargs.get("topic", ""), + ) + elif ctype == ChannelType.GUILD_VOICE: + await app.db.execute( + """ + INSERT INTO guild_voice_channels (id, bitrate, user_limit) + VALUES ($1, $2, $3) + """, + channel_id, + kwargs.get("bitrate", 64), + kwargs.get("user_limit", 0), + ) + + +async def create_guild_channel( + guild_id: int, channel_id: int, ctype: ChannelType, **kwargs +): + """Create a channel in a guild.""" + await app.db.execute( + """ + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, + channel_id, + ctype.value, + ) + + # calc new pos + max_pos = await app.db.fetchval( + """ + SELECT MAX(position) + FROM guild_channels + WHERE guild_id = $1 + """, + guild_id, + ) + + # account for the first channel in a guild too + max_pos = max_pos or 0 + + # all channels go to guild_channels + await app.db.execute( + """ + INSERT INTO guild_channels (id, guild_id, name, position) + VALUES ($1, $2, $3, $4) + """, + channel_id, + guild_id, + kwargs["name"], + max_pos + 1, + ) + + # the rest of sql magic is dependant on the channel + # we're creating (a text or voice or category), + # so we use this function. + await _specific_chan_create(channel_id, ctype, **kwargs) + + +async def delete_guild(guild_id: int): + """Delete a single guild.""" + await app.db.execute( + """ + DELETE FROM guilds + WHERE guilds.id = $1 + """, + guild_id, + ) + + # Discord's client expects IDs being string + await app.dispatcher.dispatch( + "guild", + guild_id, + "GUILD_DELETE", + { + "guild_id": str(guild_id), + "id": str(guild_id), + # 'unavailable': False, + }, + ) + + # remove from the dispatcher so nobody + # becomes the little memer that tries to fuck up with + # everybody's gateway + await app.dispatcher.remove("guild", guild_id) diff --git a/litecord/common/messages.py b/litecord/common/messages.py new file mode 100644 index 0000000..ab43767 --- /dev/null +++ b/litecord/common/messages.py @@ -0,0 +1,189 @@ +import json +import logging + +from PIL import Image +from quart import request, current_app as app + +from litecord.errors import BadRequest +from ..snowflake import get_snowflake + +log = logging.getLogger(__name__) + + +async def msg_create_request() -> tuple: + """Extract the json input and any file information + the client gave to us in the request. + + This only applies to create message route. + """ + form = await request.form + request_json = await request.get_json() or {} + + # NOTE: embed isn't set on form data + json_from_form = { + "content": form.get("content", ""), + "nonce": form.get("nonce", "0"), + "tts": json.loads(form.get("tts", "false")), + } + + payload_json = json.loads(form.get("payload_json", "{}")) + + json_from_form.update(request_json) + json_from_form.update(payload_json) + + files = await request.files + + # we don't really care about the given fields on the files dict, so + # we only extract the values + return json_from_form, [v for k, v in files.items()] + + +def msg_create_check_content(payload: dict, files: list, *, use_embeds=False): + """Check if there is actually any content being sent to us.""" + has_content = bool(payload.get("content", "")) + has_files = len(files) > 0 + + embed_field = "embeds" if use_embeds else "embed" + has_embed = embed_field in payload and payload.get(embed_field) is not None + + has_total_content = has_content or has_embed or has_files + + if not has_total_content: + raise BadRequest("No content has been provided.") + + +async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int: + """Add an attachment to a message. + + Parameters + ---------- + message_id: int + The ID of the message getting the attachment. + channel_id: int + The ID of the channel the message belongs to. + + Exists because the attachment URL scheme contains + a channel id. The purpose is unknown, but we are + implementing Discord's behavior. + attachment_file: quart.FileStorage + quart FileStorage instance of the file. + """ + + attachment_id = get_snowflake() + filename = attachment_file.filename + + # understand file info + mime = attachment_file.mimetype + is_image = mime.startswith("image/") + + img_width, img_height = None, None + + # extract file size + # TODO: this is probably inneficient + file_size = attachment_file.stream.getbuffer().nbytes + + if is_image: + # open with pillow, extract image size + image = Image.open(attachment_file.stream) + img_width, img_height = image.size + + # NOTE: DO NOT close the image, as closing the image will + # also close the stream. + + # reset it to 0 for later usage + attachment_file.stream.seek(0) + + await app.db.execute( + """ + INSERT INTO attachments + (id, channel_id, message_id, + filename, filesize, + image, width, height) + VALUES + ($1, $2, $3, $4, $5, $6, $7, $8) + """, + attachment_id, + channel_id, + message_id, + filename, + file_size, + is_image, + img_width, + img_height, + ) + + ext = filename.split(".")[-1] + + with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file: + attach_file.write(attachment_file.stream.read()) + + log.debug("written {} bytes for attachment id {}", file_size, attachment_id) + + return attachment_id + + +async def msg_guild_text_mentions( + payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool +): + """Calculates mention data side-effects.""" + channel_id = int(payload["channel_id"]) + + # calculate the user ids we'll bump the mention count for + uids = set() + + # first is extracting user mentions + for mention in payload["mentions"]: + uids.add(int(mention["id"])) + + # then role mentions + for role_mention in payload["mention_roles"]: + role_id = int(role_mention) + member_ids = await app.storage.get_role_members(role_id) + + for member_id in member_ids: + uids.add(member_id) + + # at-here only updates the state + # for the users that have a state + # in the channel. + if mentions_here: + uids = set() + + await app.db.execute( + """ + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE channel_id = $1 + """, + channel_id, + ) + + # at-here updates the read state + # for all users, including the ones + # that might not have read permissions + # to the channel. + if mentions_everyone: + uids = set() + + member_ids = await app.storage.get_member_ids(guild_id) + + await app.db.executemany( + """ + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE channel_id = $1 AND user_id = $2 + """, + [(channel_id, uid) for uid in member_ids], + ) + + for user_id in uids: + await app.db.execute( + """ + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE user_id = $1 + AND channel_id = $2 + """, + user_id, + channel_id, + ) From a67b6580ba10717b862d46d024227b7e1ec2d4cd Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:33:52 -0300 Subject: [PATCH 12/20] make other blueprints use common, etc --- litecord/blueprints/admin_api/guilds.py | 2 +- litecord/blueprints/channel/reactions.py | 11 ++-- litecord/blueprints/channels.py | 51 ++-------------- litecord/blueprints/dms.py | 19 +----- litecord/blueprints/guild/channels.py | 68 +-------------------- litecord/blueprints/guild/mod.py | 41 +------------ litecord/blueprints/guild/roles.py | 77 ++---------------------- litecord/blueprints/users.py | 3 +- litecord/blueprints/webhooks.py | 2 +- litecord/errors.py | 4 ++ litecord/utils.py | 53 +++++++++++++++- 11 files changed, 78 insertions(+), 253 deletions(-) diff --git a/litecord/blueprints/admin_api/guilds.py b/litecord/blueprints/admin_api/guilds.py index 1cf792b..15f2647 100644 --- a/litecord/blueprints/admin_api/guilds.py +++ b/litecord/blueprints/admin_api/guilds.py @@ -22,7 +22,7 @@ from quart import Blueprint, jsonify, current_app as app, request from litecord.auth import admin_check from litecord.schemas import validate from litecord.admin_schemas import GUILD_UPDATE -from litecord.blueprints.guilds import delete_guild +from litecord.common.guilds import delete_guild from litecord.errors import GuildNotFound bp = Blueprint("guilds_admin", __name__) diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index 4c8dabd..ac1440f 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -23,10 +23,9 @@ from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger -from litecord.utils import async_map +from litecord.utils import async_map, query_tuple_from_args, extract_limit from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check, channel_perm_check -from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit from litecord.enums import GUILD_CHANS @@ -165,7 +164,8 @@ def _emoji_sql_simple(emoji: str, param=4): return emoji_sql(emoji_type, emoji_id, emoji_name, param) -async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str): +async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str): + """Remove given reaction from a message.""" ctype, guild_id = await channel_check(user_id, channel_id) emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) @@ -201,8 +201,7 @@ async def remove_own_reaction(channel_id, message_id, emoji): """Remove a reaction.""" user_id = await token_check() - await remove_reaction(channel_id, message_id, user_id, emoji) - + await _remove_reaction(channel_id, message_id, user_id, emoji) return "", 204 @@ -212,7 +211,7 @@ async def remove_user_reaction(channel_id, message_id, emoji, other_id): user_id = await token_check() await channel_perm_check(user_id, channel_id, "manage_messages") - await remove_reaction(channel_id, message_id, other_id, emoji) + await _remove_reaction(channel_id, message_id, other_id, emoji) return "", 204 diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 98ec291..701c141 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -42,6 +42,7 @@ from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy from litecord.utils import search_result_from_list from litecord.embed.messages import process_url_embed, msg_update_embeds from litecord.snowflake import snowflake_datetime +from litecord.common.channels import channel_ack log = Logger(__name__) bp = Blueprint("channels", __name__) @@ -136,7 +137,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int): await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child) -async def delete_messages(channel_id): +async def _delete_messages(channel_id): await app.db.execute( """ DELETE FROM channel_pins @@ -162,7 +163,7 @@ async def delete_messages(channel_id): ) -async def guild_cleanup(channel_id): +async def _guild_cleanup(channel_id): await app.db.execute( """ DELETE FROM channel_overwrites @@ -220,8 +221,8 @@ async def close_channel(channel_id): # didn't work on my setup, so I delete # everything before moving to the main # channel table deletes - await delete_messages(channel_id) - await guild_cleanup(channel_id) + await _delete_messages(channel_id) + await _guild_cleanup(channel_id) await app.db.execute( f""" @@ -595,48 +596,6 @@ async def trigger_typing(channel_id): return "", 204 -async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): - """ACK a channel.""" - - if not message_id: - message_id = await app.storage.chan_last_message(channel_id) - - await app.db.execute( - """ - INSERT INTO user_read_state - (user_id, channel_id, last_message_id, mention_count) - VALUES - ($1, $2, $3, 0) - ON CONFLICT ON CONSTRAINT user_read_state_pkey - DO - UPDATE - SET last_message_id = $3, mention_count = 0 - WHERE user_read_state.user_id = $1 - AND user_read_state.channel_id = $2 - """, - user_id, - channel_id, - message_id, - ) - - if guild_id: - await app.dispatcher.dispatch_user_guild( - user_id, - guild_id, - "MESSAGE_ACK", - {"message_id": str(message_id), "channel_id": str(channel_id)}, - ) - else: - # we don't use ChannelDispatcher here because since - # guild_id is None, all user devices are already subscribed - # to the given channel (a dm or a group dm) - await app.dispatcher.dispatch_user( - user_id, - "MESSAGE_ACK", - {"message_id": str(message_id), "channel_id": str(channel_id)}, - ) - - @bp.route("//messages//ack", methods=["POST"]) async def ack_channel(channel_id, message_id): """Acknowledge a channel.""" diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index 975d7ea..2a50837 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -31,6 +31,7 @@ from ..snowflake import get_snowflake from .auth import token_check from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient +from litecord.common.channels import try_dm_state log = Logger(__name__) bp = Blueprint("dms", __name__) @@ -44,24 +45,6 @@ async def get_dms(): return jsonify(dms) -async def try_dm_state(user_id: int, dm_id: int): - """Try inserting the user into the dm state - for the given DM. - - Does not do anything if the user is already - in the dm state. - """ - await app.db.execute( - """ - INSERT INTO dm_channel_state (user_id, dm_id) - VALUES ($1, $2) - ON CONFLICT DO NOTHING - """, - user_id, - dm_id, - ) - - async def jsonify_dm(dm_id: int, user_id: int): dm_chan = await app.storage.get_dm(dm_id, user_id) return jsonify(dm_chan) diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index c8a2d2e..de20b95 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -27,77 +27,11 @@ from litecord.blueprints.guild.roles import gen_pairs from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check - +from litecord.common.guilds import create_guild_channel bp = Blueprint("guild_channels", __name__) -async def _specific_chan_create(channel_id, ctype, **kwargs): - if ctype == ChannelType.GUILD_TEXT: - await app.db.execute( - """ - INSERT INTO guild_text_channels (id, topic) - VALUES ($1, $2) - """, - channel_id, - kwargs.get("topic", ""), - ) - elif ctype == ChannelType.GUILD_VOICE: - await app.db.execute( - """ - INSERT INTO guild_voice_channels (id, bitrate, user_limit) - VALUES ($1, $2, $3) - """, - channel_id, - kwargs.get("bitrate", 64), - kwargs.get("user_limit", 0), - ) - - -async def create_guild_channel( - guild_id: int, channel_id: int, ctype: ChannelType, **kwargs -): - """Create a channel in a guild.""" - await app.db.execute( - """ - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, - channel_id, - ctype.value, - ) - - # calc new pos - max_pos = await app.db.fetchval( - """ - SELECT MAX(position) - FROM guild_channels - WHERE guild_id = $1 - """, - guild_id, - ) - - # account for the first channel in a guild too - max_pos = max_pos or 0 - - # all channels go to guild_channels - await app.db.execute( - """ - INSERT INTO guild_channels (id, guild_id, name, position) - VALUES ($1, $2, $3, $4) - """, - channel_id, - guild_id, - kwargs["name"], - max_pos + 1, - ) - - # the rest of sql magic is dependant on the channel - # we're creating (a text or voice or category), - # so we use this function. - await _specific_chan_create(channel_id, ctype, **kwargs) - - @bp.route("//channels", methods=["GET"]) async def get_guild_channels(guild_id): """Get the list of channels in a guild.""" diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 5949fd0..4032bb0 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -23,47 +23,11 @@ from litecord.blueprints.auth import token_check from litecord.blueprints.checks import guild_perm_check from litecord.schemas import validate, GUILD_PRUNE +from litecord.common.guilds import remove_member, remove_member_multi bp = Blueprint("guild_moderation", __name__) -async def remove_member(guild_id: int, member_id: int): - """Do common tasks related to deleting a member from the guild, - such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" - - await app.db.execute( - """ - DELETE FROM members - WHERE guild_id = $1 AND user_id = $2 - """, - guild_id, - member_id, - ) - - await app.dispatcher.dispatch_user_guild( - member_id, - guild_id, - "GUILD_DELETE", - {"guild_id": str(guild_id), "unavailable": False}, - ) - - await app.dispatcher.unsub("guild", guild_id, member_id) - - await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id) - - await app.dispatcher.dispatch_guild( - guild_id, - "GUILD_MEMBER_REMOVE", - {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, - ) - - -async def remove_member_multi(guild_id: int, members: list): - """Remove multiple members.""" - for member_id in members: - await remove_member(guild_id, member_id) - - @bp.route("//members/", methods=["DELETE"]) async def kick_guild_member(guild_id, member_id): """Remove a member from a guild.""" @@ -221,6 +185,5 @@ async def begin_guild_prune(guild_id): days = j["days"] member_ids = await get_prune(guild_id, days) - app.loop.create_task(remove_member_multi(guild_id, member_ids)) - + app.sched.spawn(remove_member_multi(guild_id, member_ids)) return jsonify({"pruned": len(member_ids)}) diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index 9516aa4..98440fa 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -27,11 +27,9 @@ from litecord.auth import token_check from litecord.blueprints.checks import guild_check, guild_perm_check from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION -from litecord.snowflake import get_snowflake -from litecord.utils import dict_get -from litecord.permissions import get_role_perms +from litecord.utils import maybe_lazy_guild_dispatch +from litecord.common.guilds import create_role -DEFAULT_EVERYONE_PERMS = 104324161 log = Logger(__name__) bp = Blueprint("guild_roles", __name__) @@ -45,71 +43,6 @@ async def get_guild_roles(guild_id): return jsonify(await app.storage.get_role_data(guild_id)) -async def _maybe_lg(guild_id: int, event: str, role, force: bool = False): - # sometimes we want to dispatch an event - # even if the role isn't hoisted - - # an example of such a case is when a role loses - # its hoist status. - - # check if is a dict first because role_delete - # only receives the role id. - if isinstance(role, dict) and not role["hoist"] and not force: - return - - await app.dispatcher.dispatch("lazy_guild", guild_id, event, role) - - -async def create_role(guild_id, name: str, **kwargs): - """Create a role in a guild.""" - new_role_id = get_snowflake() - - everyone_perms = await get_role_perms(guild_id, guild_id) - default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary) - - # update all roles so that we have space for pos 1, but without - # sending GUILD_ROLE_UPDATE for everyone - await app.db.execute( - """ - UPDATE roles - SET - position = position + 1 - WHERE guild_id = $1 - AND NOT (position = 0) - """, - guild_id, - ) - - await app.db.execute( - """ - INSERT INTO roles (id, guild_id, name, color, - hoist, position, permissions, managed, mentionable) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - new_role_id, - guild_id, - name, - dict_get(kwargs, "color", 0), - dict_get(kwargs, "hoist", False), - # always set ourselves on position 1 - 1, - int(dict_get(kwargs, "permissions", default_perms)), - False, - dict_get(kwargs, "mentionable", False), - ) - - role = await app.storage.get_role(new_role_id, guild_id) - - # we need to update the lazy guild handlers for the newly created group - await _maybe_lg(guild_id, "new_role", role) - - await app.dispatcher.dispatch_guild( - guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role} - ) - - return role - - @bp.route("//roles", methods=["POST"]) async def create_guild_role(guild_id: int): """Add a role to a guild""" @@ -132,7 +65,7 @@ async def _role_update_dispatch(role_id: int, guild_id: int): """Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" role = await app.storage.get_role(role_id, guild_id) - await _maybe_lg(guild_id, "role_pos_upd", role) + await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role) await app.dispatcher.dispatch_guild( guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role} @@ -343,7 +276,7 @@ async def update_guild_role(guild_id, role_id): ) role = await _role_update_dispatch(role_id, guild_id) - await _maybe_lg(guild_id, "role_update", role, True) + await maybe_lazy_guild_dispatch(guild_id, "role_update", role, True) return jsonify(role) @@ -369,7 +302,7 @@ async def delete_guild_role(guild_id, role_id): if res == "DELETE 0": return "", 204 - await _maybe_lg(guild_id, "role_delete", role_id, True) + await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True) await app.dispatcher.dispatch_guild( guild_id, diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 763f18e..667e143 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -28,7 +28,7 @@ from ..schemas import validate, USER_UPDATE, GET_MENTIONS from .guilds import guild_check from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim -from litecord.blueprints.guild.mod import remove_member +from litecord.common.guilds import remove_member from litecord.enums import PremiumType from litecord.images import parse_data_uri @@ -319,7 +319,6 @@ async def leave_guild(guild_id: int): await guild_check(user_id, guild_id) await remove_member(guild_id, user_id) - return "", 204 diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py index 47a1336..bb3668b 100644 --- a/litecord/blueprints/webhooks.py +++ b/litecord/blueprints/webhooks.py @@ -43,7 +43,7 @@ from litecord.snowflake import get_snowflake from litecord.utils import async_map from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest -from litecord.blueprints.channel.messages import ( +from litecord.common.messages import ( msg_create_request, msg_create_check_content, msg_add_attachment, diff --git a/litecord/errors.py b/litecord/errors.py index d2a72c2..1789a5e 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -116,6 +116,10 @@ class Forbidden(LitecordError): status_code = 403 +class ForbiddenDM(Forbidden): + error_code = 50007 + + class NotFound(LitecordError): status_code = 404 diff --git a/litecord/utils.py b/litecord/utils.py index a0f5587..15b89d2 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -23,7 +23,9 @@ from typing import Any, Iterable, Optional, Sequence, List, Dict, Union from logbook import Logger from quart.json import JSONEncoder -from quart import current_app as app +from quart import current_app as app, request + +from .errors import BadRequest log = Logger(__name__) @@ -233,3 +235,52 @@ def maybe_int(val: Any) -> Union[int, Any]: return int(val) except (ValueError, TypeError): return val + + +async def maybe_lazy_guild_dispatch( + guild_id: int, event: str, role, force: bool = False +): + # sometimes we want to dispatch an event + # even if the role isn't hoisted + + # an example of such a case is when a role loses + # its hoist status. + + # check if is a dict first because role_delete + # only receives the role id. + if isinstance(role, dict) and not role["hoist"] and not force: + return + + await app.dispatcher.dispatch("lazy_guild", guild_id, event, role) + + +def extract_limit(request_, default: int = 50, max_val: int = 100): + """Extract a limit kwarg.""" + try: + limit = int(request_.args.get("limit", default)) + + if limit not in range(0, max_val + 1): + raise ValueError() + except (TypeError, ValueError): + raise BadRequest("limit not int") + + return limit + + +def query_tuple_from_args(args: dict, limit: int) -> tuple: + """Extract a 2-tuple out of request arguments.""" + before, after = None, None + + if "around" in request.args: + average = int(limit / 2) + around = int(args["around"]) + + after = around - average + before = around + average + + elif "before" in args: + before = int(args["before"]) + elif "after" in args: + before = int(args["after"]) + + return before, after From 2ebb94f4768c6241e1e0318aad60de6aecb4bfc7 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:36:02 -0300 Subject: [PATCH 13/20] messages: use common/utils functions --- litecord/blueprints/channel/messages.py | 225 +----------------------- 1 file changed, 9 insertions(+), 216 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 1a3d6b5..8e375b1 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -17,65 +17,37 @@ along with this program. If not, see . """ -import json from pathlib import Path -from PIL import Image from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.dms import try_dm_state -from litecord.errors import MessageNotFound, Forbidden, BadRequest +from litecord.errors import MessageNotFound, Forbidden from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE -from litecord.utils import pg_set_json +from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit from litecord.permissions import get_permissions from litecord.embed.sanitizer import fill_embed from litecord.embed.messages import process_url_embed -from litecord.blueprints.channel.dm_checks import dm_pre_check +from litecord.common.channels import dm_pre_check, try_dm_state from litecord.images import try_unlink +from litecord.common.messages import ( + msg_create_request, + msg_create_check_content, + msg_add_attachment, + msg_guild_text_mentions, +) log = Logger(__name__) bp = Blueprint("channel_messages", __name__) -def extract_limit(request_, default: int = 50, max_val: int = 100): - """Extract a limit kwarg.""" - try: - limit = int(request_.args.get("limit", default)) - - if limit not in range(0, max_val + 1): - raise ValueError() - except (TypeError, ValueError): - raise BadRequest("limit not int") - - return limit - - -def query_tuple_from_args(args: dict, limit: int) -> tuple: - """Extract a 2-tuple out of request arguments.""" - before, after = None, None - - if "around" in request.args: - average = int(limit / 2) - around = int(args["around"]) - - after = around - average - before = around + average - - elif "before" in args: - before = int(args["before"]) - elif "after" in args: - before = int(args["after"]) - - return before, after - - @bp.route("//messages", methods=["GET"]) async def get_messages(channel_id): user_id = await token_check() @@ -204,185 +176,6 @@ async def create_message( return message_id -async def msg_guild_text_mentions( - payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool -): - """Calculates mention data side-effects.""" - channel_id = int(payload["channel_id"]) - - # calculate the user ids we'll bump the mention count for - uids = set() - - # first is extracting user mentions - for mention in payload["mentions"]: - uids.add(int(mention["id"])) - - # then role mentions - for role_mention in payload["mention_roles"]: - role_id = int(role_mention) - member_ids = await app.storage.get_role_members(role_id) - - for member_id in member_ids: - uids.add(member_id) - - # at-here only updates the state - # for the users that have a state - # in the channel. - if mentions_here: - uids = set() - - await app.db.execute( - """ - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE channel_id = $1 - """, - channel_id, - ) - - # at-here updates the read state - # for all users, including the ones - # that might not have read permissions - # to the channel. - if mentions_everyone: - uids = set() - - member_ids = await app.storage.get_member_ids(guild_id) - - await app.db.executemany( - """ - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE channel_id = $1 AND user_id = $2 - """, - [(channel_id, uid) for uid in member_ids], - ) - - for user_id in uids: - await app.db.execute( - """ - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE user_id = $1 - AND channel_id = $2 - """, - user_id, - channel_id, - ) - - -async def msg_create_request() -> tuple: - """Extract the json input and any file information - the client gave to us in the request. - - This only applies to create message route. - """ - form = await request.form - request_json = await request.get_json() or {} - - # NOTE: embed isn't set on form data - json_from_form = { - "content": form.get("content", ""), - "nonce": form.get("nonce", "0"), - "tts": json.loads(form.get("tts", "false")), - } - - payload_json = json.loads(form.get("payload_json", "{}")) - - json_from_form.update(request_json) - json_from_form.update(payload_json) - - files = await request.files - - # we don't really care about the given fields on the files dict, so - # we only extract the values - return json_from_form, [v for k, v in files.items()] - - -def msg_create_check_content(payload: dict, files: list, *, use_embeds=False): - """Check if there is actually any content being sent to us.""" - has_content = bool(payload.get("content", "")) - has_files = len(files) > 0 - - embed_field = "embeds" if use_embeds else "embed" - has_embed = embed_field in payload and payload.get(embed_field) is not None - - has_total_content = has_content or has_embed or has_files - - if not has_total_content: - raise BadRequest("No content has been provided.") - - -async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int: - """Add an attachment to a message. - - Parameters - ---------- - message_id: int - The ID of the message getting the attachment. - channel_id: int - The ID of the channel the message belongs to. - - Exists because the attachment URL scheme contains - a channel id. The purpose is unknown, but we are - implementing Discord's behavior. - attachment_file: quart.FileStorage - quart FileStorage instance of the file. - """ - - attachment_id = get_snowflake() - filename = attachment_file.filename - - # understand file info - mime = attachment_file.mimetype - is_image = mime.startswith("image/") - - img_width, img_height = None, None - - # extract file size - # TODO: this is probably inneficient - file_size = attachment_file.stream.getbuffer().nbytes - - if is_image: - # open with pillow, extract image size - image = Image.open(attachment_file.stream) - img_width, img_height = image.size - - # NOTE: DO NOT close the image, as closing the image will - # also close the stream. - - # reset it to 0 for later usage - attachment_file.stream.seek(0) - - await app.db.execute( - """ - INSERT INTO attachments - (id, channel_id, message_id, - filename, filesize, - image, width, height) - VALUES - ($1, $2, $3, $4, $5, $6, $7, $8) - """, - attachment_id, - channel_id, - message_id, - filename, - file_size, - is_image, - img_width, - img_height, - ) - - ext = filename.split(".")[-1] - - with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file: - attach_file.write(attachment_file.stream.read()) - - log.debug("written {} bytes for attachment id {}", file_size, attachment_id) - - return attachment_id - - async def _spawn_embed(payload, **kwargs): app.sched.spawn(process_url_embed(payload, **kwargs)) From b41d3b6f36f6c88872e8601f3b4892702f843693 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:39:22 -0300 Subject: [PATCH 14/20] guilds: use common, remove delete_guild() from it --- litecord/blueprints/guilds.py | 34 +++------------------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index ce2b78b..d5b9375 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -21,8 +21,7 @@ from typing import Optional, List from quart import Blueprint, request, current_app as app, jsonify -from litecord.blueprints.guild.channels import create_guild_channel -from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS +from litecord.common.guilds import create_role, create_guild_channel, delete_guild from ..auth import token_check from ..snowflake import get_snowflake @@ -34,12 +33,13 @@ from ..schemas import ( SEARCH_CHANNEL, VANITY_URL_PATCH, ) -from .channels import channel_ack from .checks import guild_check, guild_owner_check, guild_perm_check +from ..common.channels import channel_ack from litecord.utils import to_update, search_result_from_list from litecord.errors import BadRequest from litecord.permissions import get_permissions +DEFAULT_EVERYONE_PERMS = 104324161 bp = Blueprint("guilds", __name__) @@ -393,34 +393,6 @@ async def _update_guild(guild_id): return jsonify(guild) -async def delete_guild(guild_id: int): - """Delete a single guild.""" - await app.db.execute( - """ - DELETE FROM guilds - WHERE guilds.id = $1 - """, - guild_id, - ) - - # Discord's client expects IDs being string - await app.dispatcher.dispatch( - "guild", - guild_id, - "GUILD_DELETE", - { - "guild_id": str(guild_id), - "id": str(guild_id), - # 'unavailable': False, - }, - ) - - # remove from the dispatcher so nobody - # becomes the little memer that tries to fuck up with - # everybody's gateway - await app.dispatcher.remove("guild", guild_id) - - @bp.route("/", methods=["DELETE"]) # this endpoint is not documented, but used by the official client. @bp.route("//delete", methods=["POST"]) From ffa244173c89ba2b9bb8ce1d10dda26ab2d5254d Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:41:11 -0300 Subject: [PATCH 15/20] messages: remove double import --- litecord/blueprints/channel/messages.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 8e375b1..f46bae8 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -24,7 +24,6 @@ from logbook import Logger from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check, channel_perm_check -from litecord.blueprints.dms import try_dm_state from litecord.errors import MessageNotFound, Forbidden from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.snowflake import get_snowflake From f1e6baffd29cd2137b5b13d00a67dbf210631b1c Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:48:22 -0300 Subject: [PATCH 16/20] move user functions to common.users --- litecord/auth.py | 107 +-------------- litecord/blueprints/users.py | 143 +------------------- litecord/common/users.py | 250 +++++++++++++++++++++++++++++++++++ manage/cmd/users.py | 3 +- 4 files changed, 254 insertions(+), 249 deletions(-) create mode 100644 litecord/common/users.py diff --git a/litecord/auth.py b/litecord/auth.py index b96e475..0341fb9 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -20,15 +20,14 @@ along with this program. If not, see . import base64 import binascii from random import randint -from typing import Tuple, Optional +from typing import Tuple import bcrypt -from asyncpg import UniqueViolationError from itsdangerous import TimestampSigner, BadSignature from logbook import Logger from quart import request, current_app as app -from litecord.errors import Forbidden, Unauthorized, BadRequest +from litecord.errors import Forbidden, Unauthorized from litecord.snowflake import get_snowflake from litecord.enums import UserFlags @@ -150,105 +149,3 @@ async def hash_data(data: str, loop=None) -> str: hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14)) return hashed.decode() - - -async def check_username_usage(username: str): - """Raise an error if too many people are with the same username.""" - same_username = await app.db.fetchval( - """ - SELECT COUNT(*) - FROM users - WHERE username = $1 - """, - username, - ) - - if same_username > 9000: - raise BadRequest( - "Too many people.", - { - "username": "Too many people used the same username. " - "Please choose another" - }, - ) - - -def _raw_discrim() -> str: - discrim_number = randint(1, 9999) - return "%04d" % discrim_number - - -async def roll_discrim(username: str) -> Optional[str]: - """Roll a discriminator for a DiscordTag. - - Tries to generate one 10 times. - - Calls check_username_usage. - """ - - # we shouldn't roll discrims for usernames - # that have been used too much. - await check_username_usage(username) - - # max 10 times for a reroll - for _ in range(10): - # generate random discrim - discrim = _raw_discrim() - - # check if anyone is with it - res = await app.db.fetchval( - """ - SELECT id - FROM users - WHERE username = $1 AND discriminator = $2 - """, - username, - discrim, - ) - - # if no user is found with the (username, discrim) - # pair, then this is unique! return it. - if res is None: - return discrim - - return None - - -async def create_user(username: str, email: str, password: str) -> Tuple[int, str]: - """Create a single user. - - Generates a distriminator and other information. You can fetch the user - data back with :meth:`Storage.get_user`. - """ - db = app.db - loop = app.loop - - new_id = get_snowflake() - new_discrim = await roll_discrim(username) - - if new_discrim is None: - raise BadRequest( - "Unable to register.", - {"username": "Too many people are with this username."}, - ) - - pwd_hash = await hash_data(password, loop) - - try: - await db.execute( - """ - INSERT INTO users - (id, email, username, discriminator, password_hash) - VALUES - ($1, $2, $3, $4, $5) - """, - new_id, - email, - username, - new_discrim, - pwd_hash, - ) - except UniqueViolationError: - raise BadRequest("Email already used.") - - return new_id, pwd_hash diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 667e143..7009bd5 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -36,46 +36,12 @@ from litecord.permissions import base_permissions from litecord.blueprints.auth import check_password from litecord.utils import to_update +from litecord.common.users import mass_user_update, delete_user bp = Blueprint("user", __name__) log = Logger(__name__) -async def mass_user_update(user_id): - """Dispatch USER_UPDATE in a mass way.""" - # by using dispatch_with_filter - # we're guaranteeing all shards will get - # a USER_UPDATE once and not any others. - - session_ids = [] - - public_user = await app.storage.get_user(user_id) - private_user = await app.storage.get_user(user_id, secure=True) - - session_ids.extend( - await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user) - ) - - guild_ids = await app.user_storage.get_user_guilds(user_id) - friend_ids = await app.user_storage.get_friend_ids(user_id) - - session_ids.extend( - await app.dispatcher.dispatch_many_filter_list( - "guild", guild_ids, session_ids, "USER_UPDATE", public_user - ) - ) - - session_ids.extend( - await app.dispatcher.dispatch_many_filter_list( - "friend", friend_ids, session_ids, "USER_UPDATE", public_user - ) - ) - - await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id) - - return public_user, private_user - - @bp.route("/@me", methods=["GET"]) async def get_me(): """Get the current user's information.""" @@ -472,113 +438,6 @@ def rand_hex(length: int = 8) -> str: return urandom(length).hex()[:length] -async def _del_from_table(db, table: str, user_id: int): - """Delete a row from a table.""" - column = { - "channel_overwrites": "target_user", - "user_settings": "id", - "group_dm_members": "member_id", - }.get(table, "user_id") - - res = await db.execute( - f""" - DELETE FROM {table} - WHERE {column} = $1 - """, - user_id, - ) - - log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res) - - -async def delete_user(user_id, *, mass_update: bool = True): - """Delete a user. Does not disconnect the user.""" - db = app.db - - new_username = f"Deleted User {rand_hex()}" - - # by using a random hex in password_hash - # we break attempts at using the default '123' password hash - # to issue valid tokens for deleted users. - - await db.execute( - """ - UPDATE users - SET - username = $1, - email = NULL, - mfa_enabled = false, - verified = false, - avatar = NULL, - flags = 0, - premium_since = NULL, - phone = '', - password_hash = $2 - WHERE - id = $3 - """, - new_username, - rand_hex(32), - user_id, - ) - - # remove the user from various tables - await _del_from_table(db, "user_settings", user_id) - await _del_from_table(db, "user_payment_sources", user_id) - await _del_from_table(db, "user_subscriptions", user_id) - await _del_from_table(db, "user_payments", user_id) - await _del_from_table(db, "user_read_state", user_id) - await _del_from_table(db, "guild_settings", user_id) - await _del_from_table(db, "guild_settings_channel_overrides", user_id) - - await db.execute( - """ - DELETE FROM relationships - WHERE user_id = $1 OR peer_id = $1 - """, - user_id, - ) - - # DMs are still maintained, but not the state. - await _del_from_table(db, "dm_channel_state", user_id) - - # NOTE: we don't delete the group dms the user is an owner of... - # TODO: group dm owner reassign when the owner leaves a gdm - await _del_from_table(db, "group_dm_members", user_id) - - await _del_from_table(db, "members", user_id) - await _del_from_table(db, "member_roles", user_id) - await _del_from_table(db, "channel_overwrites", user_id) - - # after updating the user, we send USER_UPDATE so that all the other - # clients can refresh their caches on the now-deleted user - if mass_update: - await mass_user_update(user_id) - - -async def user_disconnect(user_id: int): - """Disconnects the given user's devices.""" - # after removing the user from all tables, we need to force - # all known user states to reconnect, causing the user to not - # be online anymore. - user_states = app.state_manager.user_states(user_id) - - for state in user_states: - # make it unable to resume - app.state_manager.remove(state) - - if not state.ws: - continue - - # force a close, 4000 should make the client reconnect. - await state.ws.ws.close(4000) - - # force everyone to see the user as offline - await app.presence.dispatch_pres( - user_id, {"afk": False, "status": "offline", "game": None, "since": 0} - ) - - @bp.route("/@me/delete", methods=["POST"]) async def delete_account(): """Delete own account. diff --git a/litecord/common/users.py b/litecord/common/users.py new file mode 100644 index 0000000..bd67dd8 --- /dev/null +++ b/litecord/common/users.py @@ -0,0 +1,250 @@ +from random import randint +from typing import Tuple, Optional + +from quart import current_app as app +from asyncpg import UniqueViolationError + +from ..snowflake import get_snowflake +from ..errors import BadRequest +from ..auth import hash_data + + +async def mass_user_update(user_id): + """Dispatch USER_UPDATE in a mass way.""" + # by using dispatch_with_filter + # we're guaranteeing all shards will get + # a USER_UPDATE once and not any others. + + session_ids = [] + + public_user = await app.storage.get_user(user_id) + private_user = await app.storage.get_user(user_id, secure=True) + + session_ids.extend( + await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user) + ) + + guild_ids = await app.user_storage.get_user_guilds(user_id) + friend_ids = await app.user_storage.get_friend_ids(user_id) + + session_ids.extend( + await app.dispatcher.dispatch_many_filter_list( + "guild", guild_ids, session_ids, "USER_UPDATE", public_user + ) + ) + + session_ids.extend( + await app.dispatcher.dispatch_many_filter_list( + "friend", friend_ids, session_ids, "USER_UPDATE", public_user + ) + ) + + await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id) + + return public_user, private_user + + +async def check_username_usage(username: str): + """Raise an error if too many people are with the same username.""" + same_username = await app.db.fetchval( + """ + SELECT COUNT(*) + FROM users + WHERE username = $1 + """, + username, + ) + + if same_username > 9000: + raise BadRequest( + "Too many people.", + { + "username": "Too many people used the same username. " + "Please choose another" + }, + ) + + +def _raw_discrim() -> str: + discrim_number = randint(1, 9999) + return "%04d" % discrim_number + + +async def roll_discrim(username: str) -> Optional[str]: + """Roll a discriminator for a DiscordTag. + + Tries to generate one 10 times. + + Calls check_username_usage. + """ + + # we shouldn't roll discrims for usernames + # that have been used too much. + await check_username_usage(username) + + # max 10 times for a reroll + for _ in range(10): + # generate random discrim + discrim = _raw_discrim() + + # check if anyone is with it + res = await app.db.fetchval( + """ + SELECT id + FROM users + WHERE username = $1 AND discriminator = $2 + """, + username, + discrim, + ) + + # if no user is found with the (username, discrim) + # pair, then this is unique! return it. + if res is None: + return discrim + + return None + + +async def create_user(username: str, email: str, password: str) -> Tuple[int, str]: + """Create a single user. + + Generates a distriminator and other information. You can fetch the user + data back with :meth:`Storage.get_user`. + """ + new_id = get_snowflake() + new_discrim = await roll_discrim(username) + + if new_discrim is None: + raise BadRequest( + "Unable to register.", + {"username": "Too many people are with this username."}, + ) + + pwd_hash = await hash_data(password) + + try: + await app.db.execute( + """ + INSERT INTO users + (id, email, username, discriminator, password_hash) + VALUES + ($1, $2, $3, $4, $5) + """, + new_id, + email, + username, + new_discrim, + pwd_hash, + ) + except UniqueViolationError: + raise BadRequest("Email already used.") + + return new_id, pwd_hash + + +async def _del_from_table(db, table: str, user_id: int): + """Delete a row from a table.""" + column = { + "channel_overwrites": "target_user", + "user_settings": "id", + "group_dm_members": "member_id", + }.get(table, "user_id") + + res = await db.execute( + f""" + DELETE FROM {table} + WHERE {column} = $1 + """, + user_id, + ) + + log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res) + + +async def delete_user(user_id, *, mass_update: bool = True): + """Delete a user. Does not disconnect the user.""" + db = app.db + + new_username = f"Deleted User {rand_hex()}" + + # by using a random hex in password_hash + # we break attempts at using the default '123' password hash + # to issue valid tokens for deleted users. + + await db.execute( + """ + UPDATE users + SET + username = $1, + email = NULL, + mfa_enabled = false, + verified = false, + avatar = NULL, + flags = 0, + premium_since = NULL, + phone = '', + password_hash = $2 + WHERE + id = $3 + """, + new_username, + rand_hex(32), + user_id, + ) + + # remove the user from various tables + await _del_from_table(db, "user_settings", user_id) + await _del_from_table(db, "user_payment_sources", user_id) + await _del_from_table(db, "user_subscriptions", user_id) + await _del_from_table(db, "user_payments", user_id) + await _del_from_table(db, "user_read_state", user_id) + await _del_from_table(db, "guild_settings", user_id) + await _del_from_table(db, "guild_settings_channel_overrides", user_id) + + await db.execute( + """ + DELETE FROM relationships + WHERE user_id = $1 OR peer_id = $1 + """, + user_id, + ) + + # DMs are still maintained, but not the state. + await _del_from_table(db, "dm_channel_state", user_id) + + # NOTE: we don't delete the group dms the user is an owner of... + # TODO: group dm owner reassign when the owner leaves a gdm + await _del_from_table(db, "group_dm_members", user_id) + + await _del_from_table(db, "members", user_id) + await _del_from_table(db, "member_roles", user_id) + await _del_from_table(db, "channel_overwrites", user_id) + + # after updating the user, we send USER_UPDATE so that all the other + # clients can refresh their caches on the now-deleted user + if mass_update: + await mass_user_update(user_id) + + +async def user_disconnect(user_id: int): + """Disconnects the given user's devices.""" + # after removing the user from all tables, we need to force + # all known user states to reconnect, causing the user to not + # be online anymore. + user_states = app.state_manager.user_states(user_id) + + for state in user_states: + # make it unable to resume + app.state_manager.remove(state) + + if not state.ws: + continue + + # force a close, 4000 should make the client reconnect. + await state.ws.ws.close(4000) + + # force everyone to see the user as offline + await app.presence.dispatch_pres( + user_id, {"afk": False, "status": "offline", "game": None, "since": 0} + ) diff --git a/manage/cmd/users.py b/manage/cmd/users.py index c20e181..50818e8 100644 --- a/manage/cmd/users.py +++ b/manage/cmd/users.py @@ -17,9 +17,8 @@ along with this program. If not, see . """ -from litecord.auth import create_user +from litecord.common.users import create_user, delete_user from litecord.blueprints.auth import make_token -from litecord.blueprints.users import delete_user from litecord.enums import UserFlags From 91d70d2c41d8a2c62a0da706c9c6f961dff5db81 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 13:52:33 -0300 Subject: [PATCH 17/20] fix imports to common.users --- litecord/blueprints/admin_api/users.py | 8 ++++++-- litecord/blueprints/auth.py | 3 ++- litecord/blueprints/users.py | 9 +++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/litecord/blueprints/admin_api/users.py b/litecord/blueprints/admin_api/users.py index b8b0acd..b1bbe98 100644 --- a/litecord/blueprints/admin_api/users.py +++ b/litecord/blueprints/admin_api/users.py @@ -20,13 +20,17 @@ along with this program. If not, see . from quart import Blueprint, jsonify, current_app as app, request from litecord.auth import admin_check -from litecord.blueprints.auth import create_user from litecord.schemas import validate from litecord.admin_schemas import USER_CREATE, USER_UPDATE from litecord.errors import BadRequest, Forbidden from litecord.utils import async_map -from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update from litecord.enums import UserFlags +from litecord.common.users import ( + create_user, + delete_user, + user_disconnect, + mass_user_update, +) bp = Blueprint("users_admin", __name__) diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index ba0a2a5..4b60a69 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -25,7 +25,8 @@ import bcrypt from quart import Blueprint, jsonify, request, current_app as app from logbook import Logger -from litecord.auth import token_check, create_user +from litecord.auth import token_check +from litecord.common.users import create_user from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE from litecord.errors import BadRequest from litecord.snowflake import get_snowflake diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 7009bd5..cf49d9f 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -27,7 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized from ..schemas import validate, USER_UPDATE, GET_MENTIONS from .guilds import guild_check -from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim +from litecord.auth import token_check, hash_data from litecord.common.guilds import remove_member from litecord.enums import PremiumType @@ -36,7 +36,12 @@ from litecord.permissions import base_permissions from litecord.blueprints.auth import check_password from litecord.utils import to_update -from litecord.common.users import mass_user_update, delete_user +from litecord.common.users import ( + mass_user_update, + delete_user, + check_username_usage, + roll_discrim, +) bp = Blueprint("user", __name__) log = Logger(__name__) From c765cd7fe08408581737bc2e56201e85b32b4190 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 16:01:48 -0300 Subject: [PATCH 18/20] fix imports --- litecord/auth.py | 3 --- litecord/blueprints/users.py | 7 +------ litecord/common/users.py | 23 +++++++++++++++++++++++ litecord/utils.py | 6 ++++++ tests/conftest.py | 3 +-- 5 files changed, 31 insertions(+), 11 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index 0341fb9..849a45a 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -19,8 +19,6 @@ along with this program. If not, see . import base64 import binascii -from random import randint -from typing import Tuple import bcrypt from itsdangerous import TimestampSigner, BadSignature @@ -28,7 +26,6 @@ from logbook import Logger from quart import request, current_app as app from litecord.errors import Forbidden, Unauthorized -from litecord.snowflake import get_snowflake from litecord.enums import UserFlags diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index cf49d9f..9223c61 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -17,7 +17,6 @@ along with this program. If not, see . """ -from os import urandom from asyncpg import UniqueViolationError from quart import Blueprint, jsonify, request, current_app as app @@ -41,6 +40,7 @@ from litecord.common.users import ( delete_user, check_username_usage, roll_discrim, + user_disconnect, ) bp = Blueprint("user", __name__) @@ -438,11 +438,6 @@ async def _get_mentions(): return jsonify(res) -def rand_hex(length: int = 8) -> str: - """Generate random hex characters.""" - return urandom(length).hex()[:length] - - @bp.route("/@me/delete", methods=["POST"]) async def delete_account(): """Delete own account. diff --git a/litecord/common/users.py b/litecord/common/users.py index bd67dd8..2b36082 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -1,3 +1,23 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +import logging from random import randint from typing import Tuple, Optional @@ -7,6 +27,9 @@ from asyncpg import UniqueViolationError from ..snowflake import get_snowflake from ..errors import BadRequest from ..auth import hash_data +from ..utils import rand_hex + +log = logging.getLogger(__name__) async def mass_user_update(user_id): diff --git a/litecord/utils.py b/litecord/utils.py index 15b89d2..91e7d8a 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -19,6 +19,7 @@ along with this program. If not, see . import asyncio import json +import secrets from typing import Any, Iterable, Optional, Sequence, List, Dict, Union from logbook import Logger @@ -284,3 +285,8 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple: before = int(args["after"]) return before, after + + +def rand_hex(length: int = 8) -> str: + """Generate random hex characters.""" + return secrets.token_hex(length)[:length] diff --git a/tests/conftest.py b/tests/conftest.py index 2a48d5c..f0444c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,10 +30,9 @@ from tests.common import email, TestClient from run import app as main_app, set_blueprints -from litecord.auth import create_user +from litecord.common.users import create_user, delete_user from litecord.enums import UserFlags from litecord.blueprints.auth import make_token -from litecord.blueprints.users import delete_user @pytest.fixture(name="app") From 7c3e9ec2c71b0b89ecb0026d47e04605c44df027 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 16:19:18 -0300 Subject: [PATCH 19/20] move create_guild_settings to common.guilds --- litecord/blueprints/guilds.py | 35 ++++++---------------------------- litecord/blueprints/invites.py | 2 +- litecord/common/guilds.py | 28 +++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index d5b9375..9317429 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -21,7 +21,12 @@ from typing import Optional, List from quart import Blueprint, request, current_app as app, jsonify -from litecord.common.guilds import create_role, create_guild_channel, delete_guild +from litecord.common.guilds import ( + create_role, + create_guild_channel, + delete_guild, + create_guild_settings, +) from ..auth import token_check from ..snowflake import get_snowflake @@ -44,34 +49,6 @@ DEFAULT_EVERYONE_PERMS = 104324161 bp = Blueprint("guilds", __name__) -async def create_guild_settings(guild_id: int, user_id: int): - """Create guild settings for the user - joining the guild.""" - - # new guild_settings are based off the currently - # set guild settings (for the guild) - m_notifs = await app.db.fetchval( - """ - SELECT default_message_notifications - FROM guilds - WHERE id = $1 - """, - guild_id, - ) - - await app.db.execute( - """ - INSERT INTO guild_settings - (user_id, guild_id, message_notifications) - VALUES - ($1, $2, $3) - """, - user_id, - guild_id, - m_notifs, - ) - - async def add_member(guild_id: int, user_id: int): """Add a user to a guild.""" await app.db.execute( diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 02b7610..0a9f7d7 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -28,7 +28,6 @@ from ..auth import token_check from ..schemas import validate, INVITE from ..enums import ChannelType from ..errors import BadRequest, Forbidden -from .guilds import create_guild_settings from ..utils import async_map from litecord.blueprints.checks import ( @@ -39,6 +38,7 @@ from litecord.blueprints.checks import ( ) from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient +from litecord.common.guilds import create_guild_settings log = Logger(__name__) bp = Blueprint("invites", __name__) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index f1915e0..d7d1e25 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -204,3 +204,31 @@ async def delete_guild(guild_id: int): # becomes the little memer that tries to fuck up with # everybody's gateway await app.dispatcher.remove("guild", guild_id) + + +async def create_guild_settings(guild_id: int, user_id: int): + """Create guild settings for the user + joining the guild.""" + + # new guild_settings are based off the currently + # set guild settings (for the guild) + m_notifs = await app.db.fetchval( + """ + SELECT default_message_notifications + FROM guilds + WHERE id = $1 + """, + guild_id, + ) + + await app.db.execute( + """ + INSERT INTO guild_settings + (user_id, guild_id, message_notifications) + VALUES + ($1, $2, $3) + """, + user_id, + guild_id, + m_notifs, + ) From ca7f1eae6bc898ccb7f929ff553aef124cbd5d9c Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 16:44:39 -0300 Subject: [PATCH 20/20] user.billing: fix import --- litecord/blueprints/user/billing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index 2ac1972..3bf770d 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -30,7 +30,7 @@ from litecord.snowflake import snowflake_datetime, get_snowflake from litecord.errors import BadRequest from litecord.types import timestamp_, HOURS from litecord.enums import UserFlags, PremiumType -from litecord.blueprints.users import mass_user_update +from litecord.common.users import mass_user_update log = Logger(__name__) bp = Blueprint("users_billing", __name__)