From b0eb3247fdfaa4e795f8a9107e87deac8392fe31 Mon Sep 17 00:00:00 2001 From: Luna <508270-luna@users.noreply.gitlab.com> Date: Sun, 9 Feb 2020 21:20:08 +0000 Subject: [PATCH] remove code from dispatcher leftovers are TBD. - constrict Dispatcher.dispatch() to arity 3 - add helper methods to Dispatcher - add EventType to Dispatcher While fixing things, it was discovered that many of the things inside LazyGuildDispatcher were just interfaces to GuildMemberList, in a very weird way, just so it could be fitted inside the main Dispatcher. it was decided to remove those unecessary interfaces, clients shall use the manager directly. --- Pipfile | 2 +- Pipfile.lock | 84 +++++------ litecord/blueprints/admin_api/features.py | 2 +- litecord/blueprints/admin_api/guilds.py | 4 +- litecord/blueprints/auth.py | 5 +- litecord/blueprints/channel/messages.py | 28 ++-- litecord/blueprints/channel/pins.py | 24 ++- litecord/blueprints/channel/reactions.py | 13 +- litecord/blueprints/channels.py | 120 +++++++++++---- litecord/blueprints/dm_channels.py | 39 +++-- litecord/blueprints/guild/channels.py | 16 +- litecord/blueprints/guild/emoji.py | 15 +- litecord/blueprints/guild/members.py | 18 +-- litecord/blueprints/guild/mod.py | 16 +- litecord/blueprints/guild/roles.py | 9 +- litecord/blueprints/guilds.py | 7 +- litecord/blueprints/invites.py | 2 +- litecord/blueprints/relationships.py | 89 ++++++----- litecord/blueprints/user/settings.py | 9 +- litecord/blueprints/webhooks.py | 11 +- litecord/common/channels.py | 20 ++- litecord/common/guilds.py | 104 ++++++++----- litecord/common/users.py | 45 +++--- litecord/dispatcher.py | 175 +--------------------- litecord/embed/messages.py | 4 +- litecord/gateway/gateway.py | 16 +- litecord/gateway/state_manager.py | 14 ++ litecord/gateway/websocket.py | 75 ++++++---- litecord/presence.py | 24 ++- litecord/pubsub/__init__.py | 10 +- litecord/pubsub/channel.py | 97 ++++++------ litecord/pubsub/dispatcher.py | 171 ++++++++++----------- litecord/pubsub/friend.py | 25 ++-- litecord/pubsub/guild.py | 142 ++++++------------ litecord/pubsub/lazy_guild.py | 73 +++------ litecord/pubsub/member.py | 36 ++--- litecord/pubsub/user.py | 34 +++-- litecord/pubsub/utils.py | 41 +++++ litecord/schemas.py | 3 + litecord/storage.py | 4 +- litecord/system_messages.py | 5 +- litecord/utils.py | 2 +- litecord/voice/lvsp_manager.py | 13 +- litecord/voice/manager.py | 3 +- manage/main.py | 2 +- run.py | 8 +- 46 files changed, 788 insertions(+), 871 deletions(-) create mode 100644 litecord/pubsub/utils.py diff --git a/Pipfile b/Pipfile index d2c026e..db17fc3 100644 --- a/Pipfile +++ b/Pipfile @@ -18,7 +18,7 @@ zstandard = "*" winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"} [dev-packages] -pytest = "==5.1.2" +pytest = "==5.3.2" pytest-asyncio = "==0.10.0" mypy = "*" black = "*" diff --git a/Pipfile.lock b/Pipfile.lock index 6230024..6772234 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3" + "sha256": "dedc41184df539a608717e68c108ccfb4d529acb1d1702d83de223840c7cc754" }, "pipfile-spec": 6, "requires": { @@ -123,41 +123,36 @@ }, "cffi": { "hashes": [ - "sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42", - "sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04", - "sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5", - "sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54", - "sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba", - "sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57", - "sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396", - "sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12", - "sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97", - "sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43", - "sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db", - "sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3", - "sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b", - "sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579", - "sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346", - "sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159", - "sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652", - "sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e", - "sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a", - "sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506", - "sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f", - "sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d", - "sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c", - "sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20", - "sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858", - "sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc", - "sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a", - "sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3", - "sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e", - "sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410", - "sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25", - "sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b", - "sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d" + "sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff", + "sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b", + "sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac", + "sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0", + "sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384", + "sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26", + "sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6", + "sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b", + "sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e", + "sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd", + "sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2", + "sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66", + "sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc", + "sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8", + "sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55", + "sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4", + "sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5", + "sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d", + "sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78", + "sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa", + "sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793", + "sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f", + "sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a", + "sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f", + "sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30", + "sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f", + "sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3", + "sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c" ], - "version": "==1.13.2" + "version": "==1.14.0" }, "chardet": { "hashes": [ @@ -195,10 +190,10 @@ }, "h2": { "hashes": [ - "sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e", - "sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4" + "sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5", + "sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14" ], - "version": "==3.1.1" + "version": "==3.2.0" }, "hpack": { "hashes": [ @@ -510,13 +505,6 @@ ], "version": "==1.4.3" }, - "atomicwrites": { - "hashes": [ - "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", - "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6" - ], - "version": "==1.3.0" - }, "attrs": { "hashes": [ "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", @@ -647,11 +635,11 @@ }, "pytest": { "hashes": [ - "sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210", - "sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865" + "sha256:6b571215b5a790f9b41f19f3531c53a45cf6bb8ef2988bc1ff9afb38270b25fa", + "sha256:e41d489ff43948babd0fad7ad5e49b8735d5d55e26628a58673c39ff61d95de4" ], "index": "pypi", - "version": "==5.1.2" + "version": "==5.3.2" }, "pytest-asyncio": { "hashes": [ diff --git a/litecord/blueprints/admin_api/features.py b/litecord/blueprints/admin_api/features.py index 314bb44..2d9d655 100644 --- a/litecord/blueprints/admin_api/features.py +++ b/litecord/blueprints/admin_api/features.py @@ -68,7 +68,7 @@ async def _update_features(guild_id: int, features: list): ) guild = await app.storage.get_guild_full(guild_id) - await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild)) @bp.route("//features", methods=["PATCH"]) diff --git a/litecord/blueprints/admin_api/guilds.py b/litecord/blueprints/admin_api/guilds.py index 15f2647..153bdff 100644 --- a/litecord/blueprints/admin_api/guilds.py +++ b/litecord/blueprints/admin_api/guilds.py @@ -63,10 +63,10 @@ async def update_guild(guild_id: int): if old_unavailable and not new_unavailable: # guild became available - await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild)) else: # guild became unavailable - await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_DELETE", guild)) return jsonify(guild) diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 497306d..aae2143 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -24,12 +24,13 @@ import itsdangerous import bcrypt from quart import Blueprint, jsonify, request, current_app as app from logbook import Logger +from winter import get_snowflake from litecord.auth import token_check from litecord.common.users import create_user from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE from litecord.errors import BadRequest -from winter import get_snowflake +from litecord.pubsub.user import dispatch_user from .invites import use_invite log = Logger(__name__) @@ -172,7 +173,7 @@ async def verify_user(): ) new_user = await app.storage.get_user(user_id, True) - await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user) + await dispatch_user(user_id, ("USER_UPDATE", new_user)) return "", 204 diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 4a20706..d54f986 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -42,6 +42,7 @@ from litecord.common.messages import ( msg_add_attachment, msg_guild_text_mentions, ) +from litecord.pubsub.user import dispatch_user log = Logger(__name__) @@ -136,10 +137,10 @@ async def _dm_pre_dispatch(channel_id, peer_id): # dispatch CHANNEL_CREATE so the client knows which # channel the future event is about - await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan) + await dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan) # subscribe the peer to the channel - await app.dispatcher.sub("channel", channel_id, peer_id) + await app.dispatcher.channel.sub(channel_id, peer_id) # insert it on dm_channel_state so the client # is subscribed on the future @@ -242,7 +243,7 @@ async def _create_message(channel_id): await _dm_pre_dispatch(channel_id, user_id) await _dm_pre_dispatch(channel_id, guild_id) - await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload)) # spawn url processor for embedding of images perms = await get_permissions(user_id, channel_id) @@ -350,7 +351,7 @@ async def edit_message(channel_id, message_id): # only dispatch MESSAGE_UPDATE if any update # actually happened if updated: - await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message)) return jsonify(message) @@ -427,16 +428,17 @@ async def delete_message(channel_id, message_id): message_id, ) - await app.dispatcher.dispatch( - "channel", + await app.dispatcher.channel.dispatch( channel_id, - "MESSAGE_DELETE", - { - "id": str(message_id), - "channel_id": str(channel_id), - # for lazy guilds - "guild_id": str(guild_id), - }, + ( + "MESSAGE_DELETE", + { + "id": str(message_id), + "channel_id": str(channel_id), + # for lazy guilds + "guild_id": str(guild_id), + }, + ), ) return "", 204 diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py index 9153ffc..0f11812 100644 --- a/litecord/blueprints/channel/pins.py +++ b/litecord/blueprints/channel/pins.py @@ -106,11 +106,15 @@ async def add_pin(channel_id, message_id): timestamp = snowflake_datetime(row["message_id"]) - await app.dispatcher.dispatch( - "channel", + await app.dispatcher.channel.dispatch( channel_id, - "CHANNEL_PINS_UPDATE", - {"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)}, + ( + "CHANNEL_PINS_UPDATE", + { + "channel_id": str(channel_id), + "last_pin_timestamp": timestamp_(timestamp), + }, + ), ) await send_sys_message( @@ -149,11 +153,15 @@ async def delete_pin(channel_id, message_id): timestamp = snowflake_datetime(row["message_id"]) - await app.dispatcher.dispatch( - "channel", + await app.dispatcher.channel.dispatch( channel_id, - "CHANNEL_PINS_UPDATE", - {"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()}, + ( + "CHANNEL_PINS_UPDATE", + { + "channel_id": str(channel_id), + "last_pin_timestamp": timestamp.isoformat(), + }, + ), ) return "", 204 diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index 3531b0e..3fa18a6 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -141,10 +141,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): if ctype in GUILD_CHANS: payload["guild_id"] = str(guild_id) - await app.dispatcher.dispatch( - "channel", channel_id, "MESSAGE_REACTION_ADD", payload - ) - + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_ADD", payload)) return "", 204 @@ -206,8 +203,8 @@ async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji if ctype in GUILD_CHANS: payload["guild_id"] = str(guild_id) - await app.dispatcher.dispatch( - "channel", channel_id, "MESSAGE_REACTION_REMOVE", payload + await app.dispatcher.channel.dispatch( + channel_id, ("MESSAGE_REACTION_REMOVE", payload) ) @@ -290,6 +287,6 @@ async def remove_all_reactions(channel_id, message_id): if ctype in GUILD_CHANS: payload["guild_id"] = str(guild_id) - await app.dispatcher.dispatch( - "channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload + await app.dispatcher.channel.dispatch( + channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload) ) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 386f333..d1f93db 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -20,9 +20,11 @@ along with this program. If not, see . import time import datetime from typing import List, Optional +from dataclasses import dataclass from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger +from winter import snowflake_datetime from litecord.auth import token_check from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags @@ -41,8 +43,9 @@ from litecord.system_messages import send_sys_message from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy from litecord.utils import search_result_from_list from litecord.embed.messages import process_url_embed, msg_update_embeds -from winter import snowflake_datetime from litecord.common.channels import channel_ack +from litecord.pubsub.user import dispatch_user +from litecord.permissions import get_permissions, Permissions log = Logger(__name__) bp = Blueprint("channels", __name__) @@ -80,9 +83,7 @@ async def __guild_chan_sql(guild_id, channel_id, field: str) -> str: async def _update_guild_chan_text(guild_id: int, channel_id: int): res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id") - res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id") - res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id") # if none of them were actually updated, @@ -93,7 +94,7 @@ async def _update_guild_chan_text(guild_id: int, channel_id: int): # at least one of the fields were updated, # dispatch GUILD_UPDATE guild = await app.storage.get_guild(guild_id) - await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild)) async def _update_guild_chan_voice(guild_id: int, channel_id: int): @@ -104,7 +105,7 @@ async def _update_guild_chan_voice(guild_id: int, channel_id: int): return guild = await app.storage.get_guild(guild_id) - await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) + await app.dispatcher.dispatch(guild_id, ("GUILD_UPDATE", guild)) async def _update_guild_chan_cat(guild_id: int, channel_id: int): @@ -134,7 +135,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int): # tell all people in the guild of the category removal for child_id in childs: child = await app.storage.get_channel(child_id) - await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child) + await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", child)) async def _delete_messages(channel_id): @@ -249,12 +250,10 @@ async def close_channel(channel_id): ) # clean its member list representation - lazy_guilds = app.dispatcher.backends["lazy_guild"] - lazy_guilds.remove_channel(channel_id) + app.lazy_guild.remove_channel(channel_id) - await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan) - - await app.dispatcher.remove("channel", channel_id) + await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_DELETE", chan)) + await app.dispatcher.channel.drop(channel_id) return jsonify(chan) if ctype == ChannelType.DM: @@ -273,11 +272,9 @@ async def close_channel(channel_id): channel_id, ) - # unsubscribe - await app.dispatcher.unsub("channel", channel_id, user_id) - # nothing happens to the other party of the dm channel - await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan) + await app.dispatcher.channel.unsub(channel_id, user_id) + await dispatch_user(user_id, ("CHANNEL_DELETE", chan)) return jsonify(chan) @@ -318,10 +315,54 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]): continue chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan) + await app.dispatcher.guild.dispatch(guild_id, "CHANNEL_UPDATE", chan) -async def _process_overwrites(channel_id: int, overwrites: list): +@dataclass +class Target: + type: int + user_id: Optional[int] + role_id: Optional[int] + + @property + def is_user(self): + return self.type == 0 + + @property + def is_role(self): + return self.type == 1 + + +async def _dispatch_action(guild_id: int, channel_id: int, user_id: int, perms) -> None: + """Apply an action of sub/unsub to all states of a user.""" + states = app.state_manager.fetch_states(user_id, guild_id) + for state in states: + if perms.read_messages: + await app.dispatcher.channel.sub(channel_id, state.session_id) + else: + await app.dispatcher.channel.unsub(channel_id, state.session_id) + + +async def _process_overwrites(guild_id: int, channel_id: int, overwrites: list) -> None: + # user_ids serves as a "prospect" user id list. + # for each overwrite we apply, we fill this list with user ids we + # want to check later to subscribe/unsubscribe from the channel. + # (users without read_messages are automatically unsubbed since we + # don't want to leak messages to them when they dont have perms anymore) + + # the expensiveness of large overwrite/role chains shines here. + # since each user id we fill in implies an entire get_permissions call + # (because we don't have the answer if a user is to be subbed/unsubbed + # with only overwrites, an overwrite for a user allowing them might be + # overwritten by a role overwrite denying them if they have the role), + # we get a lot of tension on that, causing channel updates to lag a bit. + + # there may be some good optimizations to do here, such as doing some + # precalculations like fetching get_permissions for everyone first, then + # applying the new overwrites one by one, then subbing/unsubbing at the + # end, but it would be very memory intensive. + user_ids: List[int] = [] + for overwrite in overwrites: # 0 for member overwrite, 1 for role overwrite @@ -329,7 +370,9 @@ async def _process_overwrites(channel_id: int, overwrites: list): target_role = None if target_type == 0 else overwrite["id"] target_user = overwrite["id"] if target_type == 0 else None - col_name = "target_user" if target_type == 0 else "target_role" + target = Target(target_type, target_role, target_user) + + col_name = "target_user" if target.is_user else "target_role" constraint_name = f"channel_overwrites_{col_name}_uniq" await app.db.execute( @@ -352,6 +395,17 @@ async def _process_overwrites(channel_id: int, overwrites: list): overwrite["deny"], ) + if target.is_user: + perms = Permissions(overwrite["allow"] & ~overwrite["deny"]) + await _dispatch_action(guild_id, channel_id, target.user_id, perms) + + elif target.is_role: + user_ids.extend(await app.storage.get_role_members(target.role_id)) + + for user_id in user_ids: + perms = await get_permissions(user_id, channel_id) + await _dispatch_action(guild_id, channel_id, target.user_id, perms) + @bp.route("//permissions/", methods=["PUT"]) async def put_channel_overwrite(channel_id: int, overwrite_id: int): @@ -371,6 +425,7 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int): ) await _process_overwrites( + guild_id, channel_id, [ { @@ -447,7 +502,7 @@ async def _update_channel_common(channel_id: int, guild_id: int, j: dict): if "channel_overwrites" in j: overwrites = j["channel_overwrites"] - await _process_overwrites(channel_id, overwrites) + await _process_overwrites(guild_id, channel_id, overwrites) async def _common_guild_chan(channel_id, j: dict): @@ -566,9 +621,9 @@ async def update_channel(channel_id: int): chan = await app.storage.get_channel(channel_id) if is_guild: - await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan) + await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan)) else: - await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan) + await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan)) return jsonify(chan) @@ -578,17 +633,18 @@ async def trigger_typing(channel_id): user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - await app.dispatcher.dispatch( - "channel", + await app.dispatcher.channel.dispatch( channel_id, - "TYPING_START", - { - "channel_id": str(channel_id), - "user_id": str(user_id), - "timestamp": int(time.time()), - # guild_id for lazy guilds - "guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None, - }, + ( + "TYPING_START", + { + "channel_id": str(channel_id), + "user_id": str(user_id), + "timestamp": int(time.time()), + # guild_id for lazy guilds + "guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None, + }, + ), ) return "", 204 @@ -816,5 +872,5 @@ async def bulk_delete(channel_id: int): if res == "DELETE 0": raise BadRequest("No messages were removed") - await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_DELETE_BULK", payload)) return "", 204 diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py index 4530d68..f5e25c6 100644 --- a/litecord/blueprints/dm_channels.py +++ b/litecord/blueprints/dm_channels.py @@ -27,6 +27,7 @@ from litecord.errors import BadRequest, Forbidden from winter import get_snowflake from litecord.system_messages import send_sys_message from litecord.pubsub.channel import gdm_recipient_view +from litecord.pubsub.user import dispatch_user log = Logger(__name__) bp = Blueprint("dm_channels", __name__) @@ -82,11 +83,11 @@ async def gdm_create(user_id, peer_id) -> int: await _raw_gdm_add(channel_id, user_id) await _raw_gdm_add(channel_id, peer_id) - await app.dispatcher.sub("channel", channel_id, user_id) - await app.dispatcher.sub("channel", channel_id, peer_id) + await app.dispatcher.channel.sub(channel_id, user_id) + await app.dispatcher.channel.sub(channel_id, peer_id) chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan) + await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_CREATE", chan)) return channel_id @@ -104,13 +105,10 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None): chan = await app.storage.get_channel(channel_id) # the reasoning behind gdm_recipient_view is in its docstring. - await app.dispatcher.dispatch( - "user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id) - ) + await dispatch_user(peer_id, ("CHANNEL_CREATE", gdm_recipient_view(chan, peer_id))) - await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan) - - await app.dispatcher.sub("channel", peer_id) + await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan)) + await app.dispatcher.channel.sub(peer_id) if user_id: await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id) @@ -128,17 +126,19 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None): await _raw_gdm_remove(channel_id, peer_id) chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch( - "user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id) - ) + await dispatch_user(peer_id, ("CHANNEL_DELETE", gdm_recipient_view(chan, user_id))) - await app.dispatcher.unsub("channel", peer_id) + await app.dispatcher.channel.unsub(peer_id) - await app.dispatcher.dispatch( - "channel", + await app.dispatcher.channel.dispatch( channel_id, - "CHANNEL_RECIPIENT_REMOVE", - {"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)}, + ( + "CHANNEL_RECIPIENT_REMOVE", + { + "channel_id": str(channel_id), + "user": await app.storage.get_user(peer_id), + }, + ), ) author_id = peer_id if user_id is None else user_id @@ -174,9 +174,8 @@ async def gdm_destroy(channel_id): channel_id, ) - await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan) - - await app.dispatcher.remove("channel", channel_id) + await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_DELETE", chan)) + await app.dispatcher.channel.drop(channel_id) async def gdm_is_member(channel_id: int, user_id: int) -> bool: diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py index 4f00e10..c05b3ed 100644 --- a/litecord/blueprints/guild/channels.py +++ b/litecord/blueprints/guild/channels.py @@ -59,19 +59,9 @@ async def create_channel(guild_id): new_channel_id = get_snowflake() await create_guild_channel(guild_id, new_channel_id, channel_type, **j) - # TODO: do a better method - # subscribe the currently subscribed users to the new channel - # by getting all user ids and subscribing each one by one. - - # since GuildDispatcher calls Storage.get_channel_ids, - # it will subscribe all users to the newly created channel. - guild_pubsub = app.dispatcher.backends["guild"] - user_ids = guild_pubsub.state[guild_id] - for uid in user_ids: - await app.dispatcher.sub("guild", guild_id, uid) - chan = await app.storage.get_channel(new_channel_id) - await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan) + await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_CREATE", chan)) + return jsonify(chan) @@ -79,7 +69,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int): """Fetch new information about the channel and dispatch a single CHANNEL_UPDATE event to the guild.""" chan = await app.storage.get_channel(channel_id) - await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan) + await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan)) async def _do_single_swap(guild_id: int, pair: tuple): diff --git a/litecord/blueprints/guild/emoji.py b/litecord/blueprints/guild/emoji.py index 31e458c..db53c07 100644 --- a/litecord/blueprints/guild/emoji.py +++ b/litecord/blueprints/guild/emoji.py @@ -32,14 +32,15 @@ bp = Blueprint("guild.emoji", __name__) async def _dispatch_emojis(guild_id): """Dispatch a Guild Emojis Update payload to a guild.""" - await app.dispatcher.dispatch( - "guild", + await app.dispatcher.guild.dispatch( guild_id, - "GUILD_EMOJIS_UPDATE", - { - "guild_id": str(guild_id), - "emojis": await app.storage.get_guild_emojis(guild_id), - }, + ( + "GUILD_EMOJIS_UPDATE", + { + "guild_id": str(guild_id), + "emojis": await app.storage.get_guild_emojis(guild_id), + }, + ), ) diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index e051f51..dcf8722 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -192,12 +192,9 @@ async def modify_guild_member(guild_id, member_id): if nick_flag: partial["nick"] = j["nick"] - await app.dispatcher.dispatch( - "lazy_guild", guild_id, "pres_update", user_id, partial - ) - - await app.dispatcher.dispatch_guild( - guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member} + await app.lazy_guild.pres_update(guild_id, user_id, partial) + await app.dispatcher.guild.dispatch( + guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}) ) return "", 204 @@ -228,12 +225,9 @@ async def update_nickname(guild_id): member.pop("joined_at") # call pres_update for nick changes, etc. - await app.dispatcher.dispatch( - "lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]} - ) - - await app.dispatcher.dispatch_guild( - guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member} + await app.lazy_guild.pres_update(guild_id, user_id, {"nick": j["nick"]}) + await app.dispatcher.guild.dispatch( + guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}) ) return j["nick"] diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index 4032bb0..e294b61 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -86,10 +86,12 @@ async def create_ban(guild_id, member_id): await remove_member(guild_id, member_id) - await app.dispatcher.dispatch_guild( + await app.dispatcher.guild.dispatch( guild_id, - "GUILD_BAN_ADD", - {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ( + "GUILD_BAN_ADD", + {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ), ) return "", 204 @@ -115,10 +117,12 @@ async def remove_ban(guild_id, banned_id): if res == "DELETE 0": return "", 204 - await app.dispatcher.dispatch_guild( + await app.dispatcher.guild.dispatch( guild_id, - "GUILD_BAN_REMOVE", - {"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)}, + ( + "GUILD_BAN_REMOVE", + {"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)}, + ), ) return "", 204 diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py index c790fcb..1709756 100644 --- a/litecord/blueprints/guild/roles.py +++ b/litecord/blueprints/guild/roles.py @@ -67,8 +67,8 @@ async def _role_update_dispatch(role_id: int, guild_id: int): await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role) - await app.dispatcher.dispatch_guild( - guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role} + await app.dispatcher.guild.dispatch( + guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}) ) return role @@ -304,10 +304,9 @@ async def delete_guild_role(guild_id, role_id): await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True) - await app.dispatcher.dispatch_guild( + await app.dispatcher.guild.dispatch( guild_id, - "GUILD_ROLE_DELETE", - {"guild_id": str(guild_id), "role_id": str(role_id)}, + ("GUILD_ROLE_DELETE", {"guild_id": str(guild_id), "role_id": str(role_id)},), ) return "", 204 diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index 5d69ffb..546fe01 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -191,8 +191,9 @@ async def create_guild(): guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) - await app.dispatcher.sub("guild", guild_id, user_id) - await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total) + await app.dispatcher.guild.sub_user(guild_id, user_id) + + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild_total)) return jsonify(guild_total) @@ -350,7 +351,7 @@ async def _update_guild(guild_id): ) guild = await app.storage.get_guild_full(guild_id, user_id) - await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild)) return jsonify(guild) diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 01f4c67..aa7dc99 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -116,7 +116,7 @@ async def _inv_check_age(inv: dict): await delete_invite(inv["code"]) raise InvalidInvite("Invite is expired") - if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]: + if inv["max_uses"] != -1 and inv["uses"] > inv["max_uses"]: await delete_invite(inv["code"]) raise InvalidInvite("Too many uses") diff --git a/litecord/blueprints/relationships.py b/litecord/blueprints/relationships.py index 1fa9d01..85e14b8 100644 --- a/litecord/blueprints/relationships.py +++ b/litecord/blueprints/relationships.py @@ -24,6 +24,7 @@ from ..auth import token_check from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND from ..enums import RelationshipType from litecord.errors import BadRequest +from litecord.pubsub.user import dispatch_user bp = Blueprint("relationship", __name__) @@ -36,17 +37,17 @@ async def get_me_relationships(): async def _dispatch_single_pres(user_id, presence: dict): - await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence) + await dispatch_user(user_id, ("PRESENCE_UPDATE", presence)) async def _unsub_friend(user_id, peer_id): - await app.dispatcher.unsub("friend", user_id, peer_id) - await app.dispatcher.unsub("friend", peer_id, user_id) + await app.dispatcher.friend.unsub(user_id, peer_id) + await app.dispatcher.friend.unsub(peer_id, user_id) async def _sub_friend(user_id, peer_id): - await app.dispatcher.sub("friend", user_id, peer_id) - await app.dispatcher.sub("friend", peer_id, user_id) + await app.dispatcher.friend.sub(user_id, peer_id) + await app.dispatcher.friend.sub(peer_id, user_id) # dispatch presence update to the user and peer about # eachother's presence. @@ -107,8 +108,8 @@ async def make_friend( _friend, ) - await app.dispatcher.dispatch_user( - peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)} + await dispatch_user( + peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}) ) await _unsub_friend(user_id, peer_id) @@ -130,35 +131,41 @@ async def make_friend( _friend, ) - _dispatch = app.dispatcher.dispatch_user + _dispatch = dispatch_user if existing: # accepted a friend request, dispatch respective # relationship events await _dispatch( user_id, - "RELATIONSHIP_REMOVE", - {"type": RelationshipType.INCOMING.value, "id": str(peer_id)}, + ( + "RELATIONSHIP_REMOVE", + {"type": RelationshipType.INCOMING.value, "id": str(peer_id)}, + ), ) await _dispatch( user_id, - "RELATIONSHIP_ADD", - { - "type": _friend, - "id": str(peer_id), - "user": await app.storage.get_user(peer_id), - }, + ( + "RELATIONSHIP_ADD", + { + "type": _friend, + "id": str(peer_id), + "user": await app.storage.get_user(peer_id), + }, + ), ) await _dispatch( peer_id, "RELATIONSHIP_ADD", - { - "type": _friend, - "id": str(user_id), - "user": await app.storage.get_user(user_id), - }, + ( + { + "type": _friend, + "id": str(user_id), + "user": await app.storage.get_user(user_id), + }, + ), ) await _sub_friend(user_id, peer_id) @@ -169,22 +176,26 @@ async def make_friend( if rel_type == _friend: await _dispatch( user_id, - "RELATIONSHIP_ADD", - { - "id": str(peer_id), - "type": RelationshipType.OUTGOING.value, - "user": await app.storage.get_user(peer_id), - }, + ( + "RELATIONSHIP_ADD", + { + "id": str(peer_id), + "type": RelationshipType.OUTGOING.value, + "user": await app.storage.get_user(peer_id), + }, + ), ) await _dispatch( peer_id, - "RELATIONSHIP_ADD", - { - "id": str(user_id), - "type": RelationshipType.INCOMING.value, - "user": await app.storage.get_user(user_id), - }, + ( + "RELATIONSHIP_ADD", + { + "id": str(user_id), + "type": RelationshipType.INCOMING.value, + "user": await app.storage.get_user(user_id), + }, + ), ) # we don't make the pubsub link @@ -240,7 +251,7 @@ async def add_relationship(peer_id: int): # make_friend did not succeed, so we # assume it is a block and dispatch # the respective RELATIONSHIP_ADD. - await app.dispatcher.dispatch_user( + await dispatch_user( user_id, "RELATIONSHIP_ADD", { @@ -261,7 +272,7 @@ async def remove_relationship(peer_id: int): user_id = await token_check() _friend = RelationshipType.FRIEND.value _block = RelationshipType.BLOCK.value - _dispatch = app.dispatcher.dispatch_user + _dispatch = dispatch_user rel_type = await app.db.fetchval( """ @@ -307,7 +318,8 @@ async def remove_relationship(peer_id: int): ) await _dispatch( - user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type} + user_id, + ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}), ) peer_del_type = ( @@ -315,7 +327,8 @@ async def remove_relationship(peer_id: int): ) await _dispatch( - peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type} + peer_id, + ("RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}), ) await _unsub_friend(user_id, peer_id) @@ -334,7 +347,7 @@ async def remove_relationship(peer_id: int): ) await _dispatch( - user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block} + user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}) ) await _unsub_friend(user_id, peer_id) diff --git a/litecord/blueprints/user/settings.py b/litecord/blueprints/user/settings.py index e64e27e..0d81f9d 100644 --- a/litecord/blueprints/user/settings.py +++ b/litecord/blueprints/user/settings.py @@ -22,6 +22,7 @@ from quart import Blueprint, jsonify, request, current_app as app from litecord.auth import token_check from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS from litecord.blueprints.checks import guild_check +from litecord.pubsub.user import dispatch_user bp = Blueprint("users_settings", __name__) @@ -58,7 +59,7 @@ async def patch_current_settings(): ) settings = await app.user_storage.get_user_settings(user_id) - await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings) + await dispatch_user(user_id, ("USER_SETTINGS_UPDATE", settings)) return jsonify(settings) @@ -123,7 +124,7 @@ async def patch_guild_settings(guild_id: int): settings = await app.user_storage.get_guild_settings_one(user_id, guild_id) - await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings) + await dispatch_user(user_id, ("USER_GUILD_SETTINGS_UPDATE", settings)) return jsonify(settings) @@ -157,8 +158,8 @@ async def put_note(target_id: int): note, ) - await app.dispatcher.dispatch_user( - user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note} + await dispatch_user( + user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note}) ) return "", 204 diff --git a/litecord/blueprints/webhooks.py b/litecord/blueprints/webhooks.py index a495c39..10e8254 100644 --- a/litecord/blueprints/webhooks.py +++ b/litecord/blueprints/webhooks.py @@ -156,11 +156,12 @@ async def webhook_token_check(webhook_id: int, webhook_token: str): async def _dispatch_webhook_update(guild_id: int, channel_id): - await app.dispatcher.dispatch( - "guild", + await app.dispatcher.guild.dispatch( guild_id, - "WEBHOOKS_UPDATE", - {"guild_id": str(guild_id), "channel_id": str(channel_id)}, + ( + "WEBHOOKS_UPDATE", + {"guild_id": str(guild_id), "channel_id": str(channel_id)}, + ), ) @@ -503,7 +504,7 @@ async def execute_webhook(webhook_id: int, webhook_token): payload = await app.storage.get_message(message_id) - await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload)) # spawn embedder in the background, even when we're on a webhook. app.sched.spawn(process_url_embed(payload)) diff --git a/litecord/common/channels.py b/litecord/common/channels.py index 05e2aca..71f0c3d 100644 --- a/litecord/common/channels.py +++ b/litecord/common/channels.py @@ -22,6 +22,8 @@ from quart import current_app as app from litecord.errors import ForbiddenDM from litecord.enums import RelationshipType +from litecord.pubsub.member import dispatch_member +from litecord.pubsub.user import dispatch_user async def channel_ack( @@ -54,20 +56,24 @@ async def channel_ack( ) if guild_id: - await app.dispatcher.dispatch_user_guild( - user_id, + await dispatch_member( guild_id, - "MESSAGE_ACK", - {"message_id": str(message_id), "channel_id": str(channel_id)}, + user_id, + ( + "MESSAGE_ACK", + {"message_id": str(message_id), "channel_id": str(channel_id)}, + ), ) else: # we don't use ChannelDispatcher here because since # guild_id is None, all user devices are already subscribed # to the given channel (a dm or a group dm) - await app.dispatcher.dispatch_user( + await dispatch_user( user_id, - "MESSAGE_ACK", - {"message_id": str(message_id), "channel_id": str(channel_id)}, + ( + "MESSAGE_ACK", + {"message_id": str(message_id), "channel_id": str(channel_id)}, + ), ) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index 2d11e61..25fe555 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -17,13 +17,15 @@ along with this program. If not, see . """ +from typing import List from logbook import Logger from quart import current_app as app from ..snowflake import get_snowflake -from ..permissions import get_role_perms +from ..permissions import get_role_perms, get_permissions from ..utils import dict_get, maybe_lazy_guild_dispatch from ..enums import ChannelType +from litecord.pubsub.member import dispatch_member log = Logger(__name__) @@ -41,21 +43,20 @@ async def remove_member(guild_id: int, member_id: int): member_id, ) - await app.dispatcher.dispatch_user_guild( - member_id, + await dispatch_member( guild_id, - "GUILD_DELETE", - {"guild_id": str(guild_id), "unavailable": False}, + member_id, + ("GUILD_DELETE", {"guild_id": str(guild_id), "unavailable": False}), ) - await app.dispatcher.unsub("guild", guild_id, member_id) - - await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id) - - await app.dispatcher.dispatch_guild( + await app.dispatcher.guild.unsub(guild_id, member_id) + await app.lazy_guild.remove_member(member_id) + await app.dispatcher.guild.dispatch( guild_id, - "GUILD_MEMBER_REMOVE", - {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ( + "GUILD_MEMBER_REMOVE", + {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, + ), ) @@ -108,8 +109,8 @@ async def create_role(guild_id, name: str, **kwargs): # we need to update the lazy guild handlers for the newly created group await maybe_lazy_guild_dispatch(guild_id, "new_role", role) - await app.dispatcher.dispatch_guild( - guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role} + await app.dispatcher.guild.dispatch( + guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}) ) return role @@ -137,6 +138,39 @@ async def _specific_chan_create(channel_id, ctype, **kwargs): ) +async def _subscribe_users_new_channel(guild_id: int, channel_id: int) -> None: + + # for each state currently subscribed to guild, we check on the database + # which states can also subscribe to the new channel at its creation. + + # the list of users that can subscribe are then used again for a pass + # over the states and states that have user ids in that list become + # subscribers of the new channel. + users_to_sub: List[str] = [] + + for session_id in app.dispatcher.guild.state[guild_id]: + try: + state = app.state_manager.fetch_raw(session_id) + except KeyError: + continue + + if state.user_id in users_to_sub: + continue + + perms = await get_permissions(state.user_id, channel_id) + if perms.read_messages: + users_to_sub.append(state.user_id) + + for session_id in app.dispatcher.guild.state[guild_id]: + try: + state = app.state_manager.fetch_raw(session_id) + except KeyError: + continue + + if state.user_id in users_to_sub: + await app.dispatcher.channel.sub(channel_id, session_id) + + async def create_guild_channel( guild_id: int, channel_id: int, ctype: ChannelType, **kwargs ): @@ -180,6 +214,8 @@ async def create_guild_channel( # so we use this function. await _specific_chan_create(channel_id, ctype, **kwargs) + await _subscribe_users_new_channel(guild_id, channel_id) + async def _del_from_table(table: str, user_id: int): """Delete a row from a table.""" @@ -206,21 +242,22 @@ async def delete_guild(guild_id: int): ) # Discord's client expects IDs being string - await app.dispatcher.dispatch( - "guild", + await app.dispatcher.guild.dispatch( guild_id, - "GUILD_DELETE", - { - "guild_id": str(guild_id), - "id": str(guild_id), - # 'unavailable': False, - }, + ( + "GUILD_DELETE", + { + "guild_id": str(guild_id), + "id": str(guild_id), + # 'unavailable': False, + }, + ), ) # remove from the dispatcher so nobody # becomes the little memer that tries to fuck up with # everybody's gateway - await app.dispatcher.remove("guild", guild_id) + await app.dispatcher.guild.drop(guild_id) async def create_guild_settings(guild_id: int, user_id: int): @@ -285,18 +322,17 @@ async def add_member(guild_id: int, user_id: int, *, basic=False): # tell current members a new member came up member = await app.storage.get_member_data_one(guild_id, user_id) - await app.dispatcher.dispatch_guild( - guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}} + await app.dispatcher.guild.dispatch( + guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}) ) - # update member lists for the new member - await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id) + # pubsub changes for new member + await app.lazy_guild.new_member(guild_id, user_id) + states = await app.dispatcher.guild.sub_user(guild_id, user_id) - # subscribe new member to guild, so they get events n stuff - await app.dispatcher.sub("guild", guild_id, user_id) - - # tell the new member that theres the guild it just joined. - # we use dispatch_user_guild so that we send the GUILD_CREATE - # just to the shards that are actually tied to it. guild = await app.storage.get_guild_full(guild_id, user_id, 250) - await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild) + for state in states: + try: + await state.ws.dispatch("GUILD_CREATE", guild) + except Exception: + log.exception("failed to dispatch to session_id={!r}", state.session_id) diff --git a/litecord/common/users.py b/litecord/common/users.py index 00614d2..fa9b825 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -29,41 +29,48 @@ from ..snowflake import get_snowflake from ..errors import BadRequest from ..auth import hash_data from ..utils import rand_hex +from ..pubsub.user import dispatch_user log = Logger(__name__) -async def mass_user_update(user_id: int): - """Dispatch USER_UPDATE in a mass way.""" - # by using dispatch_with_filter - # we're guaranteeing all shards will get - # a USER_UPDATE once and not any others. +async def mass_user_update(user_id: int) -> Tuple[dict, dict]: + """Dispatch a USER_UPDATE to everyone that is subscribed to the user. + This function guarantees all states will get one USER_UPDATE for simple + cases. Lazy guild users might get updates N times depending of how many + lists are they subscribed to. + """ session_ids: List[str] = [] public_user = await app.storage.get_user(user_id) private_user = await app.storage.get_user(user_id, secure=True) - session_ids.extend( - await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user) - ) + session_ids.extend(await dispatch_user(user_id, ("USER_UPDATE", private_user))) - guild_ids = await app.user_storage.get_user_guilds(user_id) - friend_ids = await app.user_storage.get_friend_ids(user_id) + guild_ids: List[int] = await app.user_storage.get_user_guilds(user_id) + friend_ids: List[int] = await app.user_storage.get_friend_ids(user_id) - session_ids.extend( - await app.dispatcher.dispatch_many_filter_list( - "guild", guild_ids, session_ids, "USER_UPDATE", public_user + for guild_id in guild_ids: + session_ids.extend( + await app.dispatcher.guild.dispatch_filter( + guild_id, + lambda sess_id: sess_id not in session_ids, + ("USER_UPDATE", public_user), + ) ) - ) - session_ids.extend( - await app.dispatcher.dispatch_many_filter_list( - "friend", friend_ids, session_ids, "USER_UPDATE", public_user + for friend_id in friend_ids: + session_ids.extend( + await app.dispatcher.friend.dispatch_filter( + friend_id, + lambda sess_id: sess_id not in session_ids, + ("USER_UPDATE", public_user), + ) ) - ) - await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id) + for guild_id in guild_ids: + await app.lazy_guild.update_user(guild_id, user_id) return public_user, private_user diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 4028757..21e3dfe 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -17,17 +17,12 @@ along with this program. If not, see . """ -from typing import List, Any, Dict - from logbook import Logger from .pubsub import ( GuildDispatcher, - MemberDispatcher, - UserDispatcher, ChannelDispatcher, FriendDispatcher, - LazyGuildDispatcher, ) log = Logger(__name__) @@ -50,169 +45,7 @@ class EventDispatcher: its subscriber ids. """ - def __init__(self, app): - self.state_manager = app.state_manager - self.app = app - - self.backends = { - "guild": GuildDispatcher(self), - "member": MemberDispatcher(self), - "channel": ChannelDispatcher(self), - "user": UserDispatcher(self), - "friend": FriendDispatcher(self), - "lazy_guild": LazyGuildDispatcher(self), - } - - async def action(self, backend_str: str, action: str, key, identifier, *args): - """Send an action regarding a key/identifier pair to a backend. - - Action is usually "sub" or "unsub". - """ - backend = self.backends[backend_str] - method = getattr(backend, action) - - # convert keys to the types the backend wants - key = backend.KEY_TYPE(key) - identifier = backend.VAL_TYPE(identifier) - - return await method(key, identifier, *args) - - async def subscribe( - self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None - ): - """Subscribe a single element to the given backend.""" - flags = flags or {} - - log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend) - - # this is a hacky solution for backwards compatibility between backends - # that implement flags and backends that don't. - - # passing flags to backends that don't implement flags will - # cause errors as expected. - if flags: - return await self.action(backend, "sub", key, identifier, flags) - - return await self.action(backend, "sub", key, identifier) - - async def unsubscribe(self, backend: str, key: Any, identifier: Any): - """Unsubscribe an element from the given backend.""" - log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend) - - return await self.action(backend, "unsub", key, identifier) - - async def sub(self, backend, key, identifier): - """Alias to subscribe().""" - return await self.subscribe(backend, key, identifier) - - async def unsub(self, backend, key, identifier): - """Alias to unsubscribe().""" - return await self.unsubscribe(backend, key, identifier) - - async def sub_many( - self, - backend_str: str, - identifier: Any, - keys: list, - flags: Dict[str, Any] = None, - ): - """Subscribe to multiple channels (all in a single backend) - at a time. - - Usually used when connecting to the gateway and the client - needs to subscribe to all their guids. - """ - flags = flags or {} - for key in keys: - await self.subscribe(backend_str, key, identifier, flags) - - async def mass_sub(self, identifier: Any, backends: List[tuple]): - """Mass subscribe to many backends at once.""" - for bcall in backends: - backend_str, keys = bcall[0], bcall[1] - - if len(bcall) == 2: - flags = {} - elif len(bcall) == 3: - # we have flags - flags = bcall[2] - - log.debug( - "subscribing {} to {} keys in backend {}, flags: {}", - identifier, - len(keys), - backend_str, - flags, - ) - - await self.sub_many(backend_str, identifier, keys, flags) - - async def dispatch(self, backend_str: str, key: Any, *args, **kwargs): - """Dispatch an event to the backend. - - The backend is responsible for everything regarding the - actual dispatch. - """ - backend = self.backends[backend_str] - - # convert types - key = backend.KEY_TYPE(key) - return await backend.dispatch(key, *args, **kwargs) - - async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs): - """Dispatch to multiple keys in a single backend.""" - log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys)) - - for key in keys: - await self.dispatch(backend_str, key, *args, **kwargs) - - async def dispatch_filter(self, backend_str: str, key: Any, func, *args): - """Dispatch to a backend that only accepts - (event, data) arguments with an optional filter - function.""" - backend = self.backends[backend_str] - key = backend.KEY_TYPE(key) - return await backend.dispatch_filter(key, func, *args) - - async def dispatch_many_filter_list( - self, backend_str: str, keys: List[Any], sess_list: List[str], *args - ): - """Make a "unique" dispatch given a list of session ids. - - This only works for backends that have a dispatch_filter - handler and return session id lists in their dispatch - results. - """ - for key in keys: - sess_list.extend( - await self.dispatch_filter( - backend_str, key, lambda sess_id: sess_id not in sess_list, *args - ) - ) - - return sess_list - - async def reset(self, backend_str: str, key: Any): - """Reset the bucket in the given backend.""" - backend = self.backends[backend_str] - key = backend.KEY_TYPE(key) - return await backend.reset(key) - - async def remove(self, backend_str: str, key: Any): - """Remove a key from the backend. This - might be a different operation than resetting.""" - backend = self.backends[backend_str] - key = backend.KEY_TYPE(key) - return await backend.remove(key) - - async def dispatch_guild(self, guild_id, event, data): - """Backwards compatibility with old EventDispatcher.""" - return await self.dispatch("guild", guild_id, event, data) - - async def dispatch_user_guild(self, user_id, guild_id, event, data): - """Backwards compatibility with old EventDispatcher.""" - return await self.dispatch("member", (guild_id, user_id), event, data) - - async def dispatch_user(self, user_id, event, data): - """Backwards compatibility with old EventDispatcher.""" - return await self.dispatch("user", user_id, event, data) + def __init__(self): + self.guild: GuildDispatcher = GuildDispatcher() + self.channel = ChannelDispatcher() + self.friend = FriendDispatcher() diff --git a/litecord/embed/messages.py b/litecord/embed/messages.py index b61a0f1..69b1576 100644 --- a/litecord/embed/messages.py +++ b/litecord/embed/messages.py @@ -87,8 +87,8 @@ async def msg_update_embeds(payload, new_embeds): if "flags" in payload: update_payload["flags"] = payload["flags"] - await app.dispatcher.dispatch( - "channel", channel_id, "MESSAGE_UPDATE", update_payload + await app.dispatcher.channel.dispatch( + channel_id, ("MESSAGE_UPDATE", update_payload) ) diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index 2540269..5d36922 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -18,6 +18,7 @@ along with this program. If not, see . """ import urllib.parse +from typing import Optional from litecord.gateway.websocket import GatewayWebsocket @@ -44,17 +45,18 @@ async def websocket_handler(app, ws, url): return await ws.close(1000, "Invalid gateway encoding") try: - gw_compress = args["compress"][0] + gw_compress: Optional[str] = args["compress"][0] except (KeyError, IndexError): gw_compress = None if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"): return await ws.close(1000, "Invalid gateway compress") - gws = GatewayWebsocket( - ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress - ) + async with app.app_context(): + gws = GatewayWebsocket( + ws, v=gw_version, encoding=gw_encoding, compress=gw_compress + ) - # this can be run with a single await since this whole coroutine - # is already running in the background. - await gws.run() + # this can be run with a single await since this whole coroutine + # is already running in the background. + await gws.run() diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 756e551..6b92eb9 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -22,6 +22,7 @@ import asyncio from typing import List from collections import defaultdict +from quart import current_app as app from websockets.exceptions import ConnectionClosed from logbook import Logger @@ -225,3 +226,16 @@ class StateManager: def close(self): """Close the state manager.""" self.closed = True + + async def fetch_user_states_for_channel( + self, channel_id: int, user_id: int + ) -> List[GatewayState]: + """Get a list of gateway states for a user that can receive events on a certain channel.""" + # TODO optimize this with an in-memory store + guild_id = await app.storage.guild_from_channel(channel_id) + + if guild_id: + return self.fetch_states(user_id, guild_id) + + # DMs and GDMs use all user states + return self.user_states(user_id) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index a4f10e8..8be9a8f 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -27,6 +27,7 @@ from random import randint import websockets import zstandard as zstd from logbook import Logger +from quart import current_app as app from litecord.auth import raw_token_check from litecord.enums import RelationshipType, ChannelType @@ -43,7 +44,6 @@ from litecord.presence import BasePresence from litecord.gateway.opcodes import OP from litecord.gateway.state import GatewayState - from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest from litecord.gateway.errors import ( DecodeError, @@ -52,8 +52,9 @@ from litecord.gateway.errors import ( ShardingRequired, ) from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf - from litecord.gateway.utils import WebsocketFileHandler +from litecord.pubsub.guild import GuildFlags +from litecord.pubsub.channel import ChannelFlags from litecord.storage import int_ @@ -67,7 +68,7 @@ WebsocketProperties = collections.namedtuple( class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, ws, app, **kwargs): + def __init__(self, ws, **kwargs): self.app = app self.storage = app.storage self.user_storage = app.user_storage @@ -230,7 +231,7 @@ class GatewayWebsocket: if task: task.cancel() - self.wsp.tasks["heartbeat"] = self.app.loop.create_task( + self.wsp.tasks["heartbeat"] = app.sched.spawn( task_wrapper("hb wait", self._hb_wait(interval)) ) @@ -247,6 +248,7 @@ class GatewayWebsocket: async def dispatch(self, event: str, data: Any): """Dispatch an event to the websocket.""" + assert self.state is not None self.state.seq += 1 payload = { @@ -282,6 +284,7 @@ class GatewayWebsocket: async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]): """Dispatch GUILD_CREATE information.""" + assert self.state is not None # Users don't get asynchronous guild dispatching. if not self.state.bot: @@ -360,9 +363,7 @@ class GatewayWebsocket: } await self.dispatch("READY", {**base_ready, **user_ready}) - - # async dispatch of guilds - self.app.loop.create_task(self._guild_dispatch(guilds)) + app.sched.spawn(self._guild_dispatch(guilds)) async def _check_shards(self, shard, user_id): """Check if the given `shard` value in IDENTIFY has good enough values. @@ -412,6 +413,7 @@ class GatewayWebsocket: Note: subscribing to channels is already handled by GuildDispatcher.sub """ + assert self.state is not None user_id = self.state.user_id guild_ids = await self._guild_ids() @@ -434,26 +436,50 @@ class GatewayWebsocket: # (presence and typing events) # we enable processing of guild_subscriptions by adding flags - # when subscribing to the given backend. those are optional. - channels_to_sub = [ - ( - "guild", - guild_ids, - {"presence": guild_subscriptions, "typing": guild_subscriptions}, - ), - ("channel", dm_ids), - ("channel", gdm_ids), - ] + # when subscribing to the given backend. + session_id = self.state.session_id + channel_ids: List[int] = [] - await self.app.dispatcher.mass_sub(user_id, channels_to_sub) + for guild_id in guild_ids: + await app.dispatcher.guild.sub_with_flags( + guild_id, + session_id, + GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions), + ) + # instead of calculating which channels to subscribe to + # inside guild dispatcher, we calculate them in here, so that + # we remove complexity of the dispatcher. + + guild_chan_ids = await app.storage.get_channel_ids(guild_id) + for channel_id in guild_chan_ids: + perms = await get_permissions( + self.state.user_id, channel_id, storage=self.storage + ) + + if perms.bits.read_messages: + channel_ids.append(channel_id) + + log.info("subscribing to {} guild channels", len(channel_ids)) + for channel_id in channel_ids: + await app.dispatcher.channel.sub_with_flags( + channel_id, session_id, ChannelFlags(typing=guild_subscriptions) + ) + + for dm_id in dm_ids: + await app.dispatcher.channel.sub(dm_id, session_id) + + for gdm_id in gdm_ids: + await app.dispatcher.channel.sub(gdm_id, session_id) + + # subscribe to all friends + # (their friends will also subscribe back + # when they come online) if not self.state.bot: - # subscribe to all friends - # (their friends will also subscribe back - # when they come online) friend_ids = await self.user_storage.get_friend_ids(user_id) log.info("subscribing to {} friends", len(friend_ids)) - await self.app.dispatcher.sub_many("friend", user_id, friend_ids) + for friend_id in friend_ids: + await app.dispatcher.friend.sub(user_id, friend_id) async def update_status(self, incoming_status: dict): """Update the status of the current websocket connection.""" @@ -921,6 +947,7 @@ class GatewayWebsocket: ] } """ + assert self.state is not None data = payload["d"] gids = await self.user_storage.get_user_guilds(self.state.user_id) @@ -933,11 +960,9 @@ class GatewayWebsocket: log.debug("lazy request: members: {}", data.get("members", [])) # make shard query - lazy_guilds = self.app.dispatcher.backends["lazy_guild"] - for chan_id, ranges in data.get("channels", {}).items(): chan_id = int(chan_id) - member_list = await lazy_guilds.get_gml(chan_id) + member_list = await app.lazy_guild.get_gml(chan_id) perms = await get_permissions( self.state.user_id, chan_id, storage=self.storage diff --git a/litecord/presence.py b/litecord/presence.py index e6c1189..d3f052c 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -93,7 +93,6 @@ class PresenceManager: self.storage = app.storage self.user_storage = app.user_storage self.state_manager = app.state_manager - self.dispatcher = app.dispatcher async def guild_presences( self, member_ids: List[int], guild_id: int @@ -127,8 +126,7 @@ class PresenceManager: member = await self.storage.get_member_data_one(guild_id, user_id) - lazy_guild_store = self.dispatcher.backends["lazy_guild"] - lists = lazy_guild_store.get_gml_guild(guild_id) + lists = app.lazy_guild.get_gml_guild(guild_id) # shards that are in lazy guilds with 'everyone' # enabled @@ -163,20 +161,21 @@ class PresenceManager: # given a session id, return if the session id actually connects to # a given user, and if the state has not been dispatched via lazy guild. def _session_check(session_id): - state = self.state_manager.fetch_raw(session_id) - uid = int(member["user"]["id"]) - - if not state: + try: + state = self.state_manager.fetch_raw(session_id) + except KeyError: return False + uid = int(member["user"]["id"]) + # we don't want to send a presence update # to the same user return state.user_id != uid and session_id not in in_lazy # everyone not in lazy guild mode # gets a PRESENCE_UPDATE - await self.dispatcher.dispatch_filter( - "guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload + await app.dispatcher.guild.dispatch_filter( + guild_id, _session_check, ("PRESENCE_UPDATE", event_payload) ) return in_lazy @@ -193,11 +192,8 @@ class PresenceManager: # dispatch to all friends that are subscribed to them user = await self.storage.get_user(user_id) - await self.dispatcher.dispatch( - "friend", - user_id, - "PRESENCE_UPDATE", - {**presence.partial_dict, **{"user": user}}, + await app.dispatcher.friend.dispatch( + user_id, ("PRESENCE_UPDATE", {**presence.partial_dict, **{"user": user}}), ) def fetch_friend_presence(self, friend_id: int) -> BasePresence: diff --git a/litecord/pubsub/__init__.py b/litecord/pubsub/__init__.py index 3840695..a038822 100644 --- a/litecord/pubsub/__init__.py +++ b/litecord/pubsub/__init__.py @@ -18,17 +18,15 @@ along with this program. If not, see . """ from .guild import GuildDispatcher -from .member import MemberDispatcher -from .user import UserDispatcher +from .member import dispatch_member +from .user import dispatch_user from .channel import ChannelDispatcher from .friend import FriendDispatcher -from .lazy_guild import LazyGuildDispatcher __all__ = [ "GuildDispatcher", - "MemberDispatcher", - "UserDispatcher", + "dispatch_member", + "dispatch_user", "ChannelDispatcher", "FriendDispatcher", - "LazyGuildDispatcher", ] diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index cc0404f..483b94f 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -17,13 +17,15 @@ along with this program. If not, see . """ -from typing import Any, List +from typing import List +from dataclasses import dataclass +from quart import current_app as app from logbook import Logger -from .dispatcher import DispatcherWithFlags from litecord.enums import ChannelType from litecord.utils import index_by_func +from .dispatcher import DispatcherWithFlags, GatewayEvent log = Logger(__name__) @@ -37,75 +39,66 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict: """ # make a copy or the original channel object data = dict(orig) - idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"]) - data["recipients"].pop(idx) - return data -class ChannelDispatcher(DispatcherWithFlags): - """Main channel Pub/Sub logic.""" +@dataclass +class ChannelFlags: + typing: bool - KEY_TYPE = int - VAL_TYPE = int - async def dispatch(self, channel_id, event: str, data: Any) -> List[str]: +class ChannelDispatcher( + DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags] +): + """Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels.""" + + async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]: """Dispatch an event to a channel.""" - # get everyone who is subscribed - # and store the number of states we dispatched the event to - user_ids = self.state[channel_id] - dispatched = 0 + session_ids = set(self.state[channel_id]) sessions: List[str] = [] - # making a copy of user_ids since - # we'll modify it later on. - for user_id in set(user_ids): - guild_id = await self.app.storage.guild_from_channel(channel_id) + event_type, event_data = event + assert isinstance(event_data, dict) - # if we are dispatching to a guild channel, - # we should only dispatch to the states / shards - # that are connected to the guild (via their shard id). - - # if we aren't, we just get all states tied to the user. - # TODO: make a fetch_states that fetches shards - # - with id 0 (count any) OR - # - single shards (id=0, count=1) - states = ( - self.sm.fetch_states(user_id, guild_id) - if guild_id - else self.sm.user_states(user_id) - ) - - # unsub people who don't have any states tied to the channel. - if not states: - await self.unsub(channel_id, user_id) + for session_id in session_ids: + try: + state = app.state_manager.fetch_raw(session_id) + except KeyError: + await self.unsub(channel_id, session_id) continue - # skip typing events for users that don't want it - if event.startswith("TYPING_") and not self.flags_get( - channel_id, user_id, "typing", True - ): + try: + flags = self.get_flags(channel_id, session_id) + except KeyError: + log.warning("no flags for {!r}, ignoring", session_id) + flags = ChannelFlags(typing=True) + + if event_type.lower().startswith("typing_") and not flags.typing: continue - cur_sess: List[str] = [] - + correct_event = event + # for cases where we are talking about group dms, we create an edited + # event data so that it doesn't show the user we're dispatching + # to in data.recipients (clients already assume they are recipients) if ( - event in ("CHANNEL_CREATE", "CHANNEL_UPDATE") - and data.get("type") == ChannelType.GROUP_DM.value + event_type in ("CHANNEL_CREATE", "CHANNEL_UPDATE") + and event_data.get("type") == ChannelType.GROUP_DM.value ): - # we edit the channel payload so it doesn't show - # the user as a recipient + new_data = gdm_recipient_view(event_data, state.user_id) + correct_event = (event_type, new_data) - new_data = gdm_recipient_view(data, user_id) - cur_sess = await self._dispatch_states(states, event, new_data) - else: - cur_sess = await self._dispatch_states(states, event, data) + try: + await state.ws.dispatch(*correct_event) + except Exception: + log.exception("error while dispatching to {}", state.session_id) + continue - sessions.extend(cur_sess) - dispatched += len(cur_sess) + sessions.append(session_id) - log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched) + log.info( + "Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions) + ) return sessions diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index 747ad63..5ab3e6b 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -17,7 +17,18 @@ along with this program. If not, see . """ -from typing import List +from typing import ( + List, + Generic, + TypeVar, + Any, + Callable, + Dict, + Set, + Mapping, + Iterable, + Tuple, +) from collections import defaultdict from logbook import Logger @@ -25,79 +36,63 @@ from logbook import Logger log = Logger(__name__) -def _identity(_self, x): - return x +K = TypeVar("K") +V = TypeVar("V") +F = TypeVar("F") +EventType = TypeVar("EventType") +DispatchType = TypeVar("DispatchType") +F_Map = Mapping[V, F] + +GatewayEvent = Tuple[str, Any] + +__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"] -class Dispatcher: +class Dispatcher(Generic[K, V, EventType, DispatchType]): """Pub/Sub backend dispatcher. - This just declares functions all Dispatcher subclasses - can implement. This does not mean all Dispatcher - subclasses have them implemented. + Classes must implement this protocol. """ - KEY_TYPE = _identity - VAL_TYPE = _identity + async def sub(self, key: K, identifier: V) -> None: + """Subscribe a given identifier to a given key.""" + ... - def __init__(self, main): - #: main EventDispatcher - self.main_dispatcher = main + async def sub_many(self, key: K, identifier_list: Iterable[V]) -> None: + for identifier in identifier_list: + await self.sub(key, identifier) - #: gateway state storage - self.sm = main.state_manager + async def unsub(self, key: K, identifier: V) -> None: + """Unsubscribe a given identifier to a given key.""" + ... - self.app = main.app + async def dispatch(self, key: K, event: EventType) -> DispatchType: + ... - async def sub(self, _key, _id): - """Subscribe an element to the channel/key.""" - raise NotImplementedError + async def dispatch_many(self, keys: List[K], *args: Any, **kwargs: Any) -> None: + log.info("MULTI DISPATCH in {!r}, {} keys", self, len(keys)) + for key in keys: + await self.dispatch(key, *args, **kwargs) - async def unsub(self, _key, _id): - """Unsubscribe an elemtnt from the channel/key.""" - raise NotImplementedError + async def drop(self, key: K) -> None: + """Drop a key.""" + ... - async def dispatch_filter(self, _key, _func, *_args): - """Selectively dispatch to the list of subscribed users. + async def clear(self, key: K) -> None: + """Clear a key from the backend.""" + ... - The selection logic is completly arbitraty and up to the - Pub/Sub backend. + async def dispatch_filter( + self, key: K, filter_function: Callable[[K], bool], event: EventType + ) -> List[str]: + """Selectively dispatch to the list of subscribers. + + Function must return a list of separate identifiers for composability. """ - raise NotImplementedError - - async def dispatch(self, _key, *_args): - """Dispatch an event to the given channel/key.""" - raise NotImplementedError - - async def reset(self, _key): - """Reset a key from the backend.""" - raise NotImplementedError - - async def remove(self, _key): - """Remove a key from the backend. - - The meaning from reset() and remove() - is different, reset() is to clear all - subscribers from the given key, - remove() is to remove the key as well. - """ - raise NotImplementedError - - async def _dispatch_states(self, states: list, event: str, data) -> List[str]: - """Dispatch an event to a list of states.""" - res = [] - - for state in states: - try: - await state.ws.dispatch(event, data) - res.append(state.session_id) - except Exception: - log.exception("error while dispatching") - - return res + ... -class DispatcherWithState(Dispatcher): +class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]): """Pub/Sub backend with a state dictionary. This class was made to decrease the amount @@ -105,58 +100,58 @@ class DispatcherWithState(Dispatcher): that have that dictionary. """ - def __init__(self, main): - super().__init__(main) + def __init__(self): + super().__init__() #: the default dict is to a set # so we make sure someone calling sub() # twice won't get 2x the events for the # same channel. - self.state = defaultdict(set) + self.state: Dict[K, Set[V]] = defaultdict(set) - async def sub(self, key, identifier): + async def sub(self, key: K, identifier: V): self.state[key].add(identifier) - async def unsub(self, key, identifier): + async def unsub(self, key: K, identifier: V): self.state[key].discard(identifier) - async def reset(self, key): + async def reset(self, key: K): self.state[key] = set() - async def remove(self, key): + async def drop(self, key: K): try: self.state.pop(key) except KeyError: pass - async def dispatch(self, key, *args): - raise NotImplementedError - -class DispatcherWithFlags(DispatcherWithState): +class DispatcherWithFlags( + DispatcherWithState, Generic[K, V, EventType, DispatchType, F], +): """Pub/Sub backend with both a state and a flags store.""" - def __init__(self, main): - super().__init__(main) + def __init__(self): + super().__init__() + self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict) - #: keep flags for subscribers, so for example - # a subscriber could drop all presence events at the - # pubsub level. see gateway's guild_subscriptions field for more - self.flags = defaultdict(dict) + def set_flags(self, key: K, identifier: V, flags: F): + """Set flags for the given identifier.""" + self.flags[key][identifier] = flags - async def sub(self, key, identifier, flags=None): - """Subscribe a user to the guild.""" - await super().sub(key, identifier) - self.flags[key][identifier] = flags or {} - - async def unsub(self, key, identifier): - """Unsubscribe a user from the guild.""" - await super().unsub(key, identifier) + def remove_flags(self, key: K, identifier: V): + """Set flags for the given identifier.""" self.flags[key].pop(identifier) - def flags_get(self, key, identifier, field: str, default): + def get_flags(self, key: K, identifier: V): """Get a single field from the flags store.""" - # yes, i know its simply an indirection from the main flags store, - # but i'd rather have this than change every call if i ever change - # the structure of the flags store. - return self.flags[key][identifier].get(field, default) + return self.flags[key][identifier] + + async def sub_with_flags(self, key: K, identifier: V, flags: F): + """Subscribe a user to the guild.""" + await super().sub(key, identifier) + self.set_flags(key, identifier, flags) + + async def unsub(self, key: K, identifier: V): + """Unsubscribe a user from the guild.""" + await super().unsub(key, identifier) + self.remove_flags(key, identifier) diff --git a/litecord/pubsub/friend.py b/litecord/pubsub/friend.py index 4bf727a..c129ed3 100644 --- a/litecord/pubsub/friend.py +++ b/litecord/pubsub/friend.py @@ -17,15 +17,16 @@ along with this program. If not, see . """ -from typing import List +from typing import List, Set from logbook import Logger -from .dispatcher import DispatcherWithState +from .dispatcher import DispatcherWithState, GatewayEvent +from .user import dispatch_user_filter log = Logger(__name__) -class FriendDispatcher(DispatcherWithState): +class FriendDispatcher(DispatcherWithState[int, int, GatewayEvent, List[str]]): """Friend Pub/Sub logic. When connecting, a client will subscribe to all their friends @@ -33,26 +34,18 @@ class FriendDispatcher(DispatcherWithState): broadcasted through that channel to basically all their friends. """ - KEY_TYPE = int - VAL_TYPE = int - - async def dispatch_filter(self, user_id: int, func, event, data): + async def dispatch_filter(self, user_id: int, filter_function, event: GatewayEvent): """Dispatch an event to all of a users' friends.""" - peer_ids = self.state[user_id] + peer_ids: Set[int] = self.state[user_id] sessions: List[str] = [] for peer_id in peer_ids: # dispatch to the user instead of the "shards tied to a guild" # since relationships broadcast to all shards. - sessions.extend( - await self.main_dispatcher.dispatch_filter( - "user", peer_id, func, event, data - ) - ) + sessions.extend(await dispatch_user_filter(peer_id, filter_function, event)) log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions)) - return sessions - async def dispatch(self, user_id, event, data): - return await self.dispatch_filter(user_id, lambda sess_id: True, event, data) + async def dispatch(self, user_id: int, event: GatewayEvent): + return await self.dispatch_filter(user_id, lambda sess_id: True, event) diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 21d5143..4382c9c 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -17,123 +17,73 @@ along with this program. If not, see . """ -from typing import Any +from typing import List +from dataclasses import dataclass +from quart import current_app as app from logbook import Logger -from .dispatcher import DispatcherWithFlags -from litecord.permissions import get_permissions +from .dispatcher import DispatcherWithFlags, GatewayEvent +from .channel import ChannelFlags +from litecord.gateway.state import GatewayState log = Logger(__name__) -class GuildDispatcher(DispatcherWithFlags): - """Guild backend for Pub/Sub""" +@dataclass +class GuildFlags(ChannelFlags): + presence: bool - KEY_TYPE = int - VAL_TYPE = int - async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None): - """Send an action to all channels of the guild.""" - flags = flags or {} - chan_ids = await self.app.storage.get_channel_ids(guild_id) +class GuildDispatcher( + DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags] +): + """Guild backend for Pub/Sub.""" - for chan_id in chan_ids: + async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]: + states = app.state_manager.fetch_states(user_id, guild_id) + for state in states: + await self.sub(guild_id, state.session_id) - # only do an action for users that can - # actually read the channel to start with. - chan_perms = await get_permissions( - user_id, chan_id, storage=self.main_dispatcher.app.storage - ) + return states - if not chan_perms.bits.read_messages: - log.debug("skipping cid={}, no read messages", chan_id) + async def dispatch_filter( + self, guild_id: int, filter_function, event: GatewayEvent + ): + session_ids = self.state[guild_id] + sessions: List[str] = [] + event_type, _ = event + + for session_id in set(session_ids): + if not filter_function(session_id): continue - log.debug("sending raw action {!r} to chan={}", action, chan_id) - - # for now, only sub() has support for flags. - # it is an idea to have flags support for other actions - args = [] - if action == "sub": - chanflags = dict(flags) - - # channels don't need presence flags - try: - chanflags.pop("presence") - except KeyError: - pass - - args.append(chanflags) - - await self.main_dispatcher.action( - "channel", action, chan_id, user_id, *args - ) - - async def _chan_call(self, meth: str, guild_id: int, *args): - """Call a method on the ChannelDispatcher, for all channels - in the guild.""" - chan_ids = await self.app.storage.get_channel_ids(guild_id) - - chan_dispatcher = self.main_dispatcher.backends["channel"] - method = getattr(chan_dispatcher, meth) - - for chan_id in chan_ids: - log.debug("calling {} to chan={}", meth, chan_id) - await method(chan_id, *args) - - async def sub(self, guild_id: int, user_id: int, flags=None): - """Subscribe a user to the guild.""" - await super().sub(guild_id, user_id, flags) - await self._chan_action("sub", guild_id, user_id, flags) - - async def unsub(self, guild_id: int, user_id: int): - """Unsubscribe a user from the guild.""" - await super().unsub(guild_id, user_id) - await self._chan_action("unsub", guild_id, user_id) - - async def dispatch_filter(self, guild_id: int, func, event: str, data: Any): - """Selectively dispatch to session ids that have - func(session_id) true.""" - user_ids = self.state[guild_id] - dispatched = 0 - sessions = [] - - # acquire a copy since we may be modifying - # the original user_ids - for user_id in set(user_ids): - - # fetch all states / shards that are tied to the guild. - states = self.sm.fetch_states(user_id, guild_id) - - if not states: - # user is actually disconnected, - # so we should just unsub them - await self.unsub(guild_id, user_id) + try: + state = app.state_manager.fetch_raw(session_id) + except KeyError: + await self.unsub(guild_id, session_id) continue - # skip the given subscriber if event starts with PRESENCE_ - # and the flags say they don't want it. + try: + flags = self.get_flags(guild_id, session_id) + except KeyError: + log.warning("no flags for {!r}, ignoring", session_id) + flags = GuildFlags(presence=True, typing=True) - # note that this does not equate to any unsubscription - # of the channel. - if event.startswith("PRESENCE_") and not self.flags_get( - guild_id, user_id, "presence", True - ): + if event_type.lower().startswith("presence_") and not flags.presence: continue - # filter the ones that matter - states = list(filter(lambda state: func(state.session_id), states)) + try: + await state.ws.dispatch(*event) + except Exception: + log.exception("error while dispatching to {}", state.session_id) + continue - cur_sess = await self._dispatch_states(states, event, data) - - sessions.extend(cur_sess) - dispatched += len(cur_sess) - - log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched) + sessions.append(session_id) + log.info("Dispatched {} {!r} to {} states", guild_id, event[0], len(sessions)) return sessions - async def dispatch(self, guild_id: int, event: str, data: Any): + async def dispatch(self, guild_id: int, event): """Dispatch an event to all subscribers of the guild.""" - return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data) + return await self.dispatch_filter(guild_id, lambda sess_id: True, event) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 235b69f..4b85dc4 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -30,10 +30,10 @@ import asyncio from collections import defaultdict from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set from dataclasses import dataclass, asdict, field +from quart import current_app as app from logbook import Logger -from litecord.pubsub.dispatcher import Dispatcher from litecord.permissions import ( Permissions, overwrite_find_mix, @@ -239,9 +239,6 @@ class GuildMemberList: Attributes ---------- - main_lg: LazyGuildDispatcher - Main instance of :class:`LazyGuildDispatcher`, - so that we're able to use things such as :class:`Storage`. guild_id: int The Guild ID this instance is referring to. channel_id: int @@ -257,11 +254,10 @@ class GuildMemberList: for example, can still rely on PRESENCE_UPDATEs. """ - def __init__(self, guild_id: int, channel_id: int, main_lg): + def __init__(self, guild_id: int, channel_id: int): self.guild_id = guild_id self.channel_id = channel_id - self.main = main_lg self.list = MemberList() #: store the states that are subscribed to the list. @@ -273,22 +269,22 @@ class GuildMemberList: @property def loop(self): """Get the main asyncio loop instance.""" - return self.main.app.loop + return app.loop @property def storage(self): """Get the global :class:`Storage` instance.""" - return self.main.app.storage + return app.storage @property def presence(self): """Get the global :class:`PresenceManager` instance.""" - return self.main.app.presence + return app.presence @property def state_man(self): """Get the global :class:`StateManager` instance.""" - return self.main.app.state_manager + return app.state_manager @property def list_id(self): @@ -572,8 +568,7 @@ class GuildMemberList: Wrapper for :meth:`StateManager.fetch_raw` """ try: - state = self.state_man.fetch_raw(session_id) - return state + return self.state_man.fetch_raw(session_id) except KeyError: return None @@ -643,7 +638,7 @@ class GuildMemberList: # do resync-ing in the background result.append(session_id) - self.loop.create_task(self.shard_query(session_id, [role_range])) + app.sched.spawn(self.shard_query(session_id, [role_range])) return result @@ -683,8 +678,7 @@ class GuildMemberList: ) if everyone_perms.bits.read_messages and list_id != "everyone": - everyone_gml = await self.main.get_gml(self.guild_id) - + everyone_gml = await app.lazy_guild.get_gml(self.guild_id) return await everyone_gml.shard_query(session_id, ranges) await self._init_check() @@ -1372,47 +1366,36 @@ class GuildMemberList: self.guild_id = 0 self.channel_id = 0 - self.main = None self._set_empty_list() self.state = {} -class LazyGuildDispatcher(Dispatcher): +class LazyGuildManager: """Main class holding the member lists for lazy guilds.""" - # channel ids - KEY_TYPE = int - - # the session ids subscribing to channels - VAL_TYPE = str - - def __init__(self, main): - super().__init__(main) - - self.storage = main.app.storage - + def __init__(self): # {chan_id: gml, ...} - self.state = {} + self.state: Dict[int, GuildMemberList] = {} #: store which guilds have their # respective GMLs # {guild_id: [chan_id, ...], ...} self.guild_map: Dict[int, List[int]] = defaultdict(list) - async def get_gml(self, channel_id: int): + async def get_gml(self, channel_id: int) -> GuildMemberList: """Get a guild list for a channel ID, generating it if it doesn't exist.""" try: return self.state[channel_id] except KeyError: - guild_id = await self.storage.guild_from_channel(channel_id) + guild_id = await app.storage.guild_from_channel(channel_id) # if we don't find a guild, we just # set it the same as the channel. if not guild_id: guild_id = channel_id - gml = GuildMemberList(guild_id, channel_id, self) + gml = GuildMemberList(guild_id, channel_id) self.state[channel_id] = gml self.guild_map[guild_id].append(channel_id) return gml @@ -1437,16 +1420,6 @@ class LazyGuildDispatcher(Dispatcher): gml = await self.get_gml(chan_id) gml.unsub(session_id) - async def dispatch(self, guild_id, event: str, *args, **kwargs): - """Call a function specialized in handling the given event""" - try: - handler = getattr(self, f"_handle_{event.lower()}") - except AttributeError: - log.warning("unknown event: {}", event) - return - - await handler(guild_id, *args, **kwargs) - def remove_channel(self, channel_id: int): """Remove a channel from the manager.""" try: @@ -1474,29 +1447,29 @@ class LazyGuildDispatcher(Dispatcher): method = getattr(lazy_list, method_str) await method(*args) - async def _handle_new_role(self, guild_id: int, new_role: dict): + async def new_role(self, guild_id: int, new_role: dict): """Handle the addition of a new group by dispatching it to the member lists.""" await self._call_all_lists(guild_id, "new_role", new_role) - async def _handle_role_pos_upd(self, guild_id, role: dict): + async def role_position_update(self, guild_id, role: dict): await self._call_all_lists(guild_id, "role_pos_update", role) - async def _handle_role_update(self, guild_id, role: dict): + async def role_update(self, guild_id, role: dict): # handle name and hoist changes await self._call_all_lists(guild_id, "role_update", role) - async def _handle_role_delete(self, guild_id, role_id: int): + async def role_delete(self, guild_id, role_id: int): await self._call_all_lists(guild_id, "role_delete", role_id) - async def _handle_pres_update(self, guild_id, user_id: int, partial: dict): + async def pres_update(self, guild_id, user_id: int, partial: dict): await self._call_all_lists(guild_id, "pres_update", user_id, partial) - async def _handle_new_member(self, guild_id, user_id: int): + async def new_member(self, guild_id, user_id: int): await self._call_all_lists(guild_id, "new_member", user_id) - async def _handle_remove_member(self, guild_id, user_id: int): + async def remove_member(self, guild_id, user_id: int): await self._call_all_lists(guild_id, "remove_member", user_id) - async def _handle_update_user(self, guild_id, user_id: int): + async def update_user(self, guild_id, user_id: int): await self._call_all_lists(guild_id, "update_user", user_id) diff --git a/litecord/pubsub/member.py b/litecord/pubsub/member.py index c5a389e..d026ef9 100644 --- a/litecord/pubsub/member.py +++ b/litecord/pubsub/member.py @@ -17,30 +17,20 @@ along with this program. If not, see . """ -from .dispatcher import Dispatcher +from typing import List +from quart import current_app as app +from .dispatcher import GatewayEvent +from .utils import send_event_to_states -class MemberDispatcher(Dispatcher): - """Member backend for Pub/Sub.""" +async def dispatch_member( + guild_id: int, user_id: int, event: GatewayEvent +) -> List[str]: + states = app.state_manager.fetch_states(user_id, guild_id) - KEY_TYPE = tuple + # if no states were found, we should unsub the user from the guild + if not states: + await app.dispatcher.guild.unsub(guild_id, user_id) + return [] - async def dispatch(self, key, event, data): - """Dispatch a single event to a member. - - This is shard-aware. - """ - # we don't keep any state on this dispatcher, so the key - # is just (guild_id, user_id) - guild_id, user_id = key - - # fetch shards - states = self.sm.fetch_states(user_id, guild_id) - - # if no states were found, we should - # unsub the user from the GUILD channel - if not states: - await self.main_dispatcher.unsub("guild", guild_id, user_id) - return - - return await self._dispatch_states(states, event, data) + return await send_event_to_states(states, event) diff --git a/litecord/pubsub/user.py b/litecord/pubsub/user.py index bc53094..178d6b8 100644 --- a/litecord/pubsub/user.py +++ b/litecord/pubsub/user.py @@ -17,23 +17,27 @@ along with this program. If not, see . """ -from .dispatcher import Dispatcher +from typing import Callable, List + +from quart import current_app as app +from .dispatcher import GatewayEvent +from .utils import send_event_to_states -class UserDispatcher(Dispatcher): - """User backend for Pub/Sub.""" - - KEY_TYPE = int - - async def dispatch_filter(self, user_id: int, func, event, data): - """Dispatch an event to all shards of a user.""" - - # filter only states where func() gives true - states = list( - filter(lambda state: func(state.session_id), self.sm.user_states(user_id)) +async def dispatch_user_filter( + user_id: int, filter_func: Callable[[str], bool], event_data: GatewayEvent +) -> List[str]: + """Dispatch to a given user's states, but only for states + where filter_func returns true.""" + states = list( + filter( + lambda state: filter_func(state.session_id), + app.state_manager.user_states(user_id), ) + ) - return await self._dispatch_states(states, event, data) + return await send_event_to_states(states, event_data) - async def dispatch(self, user_id: int, event, data): - return await self.dispatch_filter(user_id, lambda sess_id: True, event, data) + +async def dispatch_user(user_id: int, event_data: GatewayEvent) -> List[str]: + return await dispatch_user_filter(user_id, lambda sess_id: True, event_data) diff --git a/litecord/pubsub/utils.py b/litecord/pubsub/utils.py new file mode 100644 index 0000000..636ba36 --- /dev/null +++ b/litecord/pubsub/utils.py @@ -0,0 +1,41 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +import logging +from typing import List, Tuple, Any +from ..gateway.state import GatewayState + +log = logging.getLogger(__name__) + + +async def send_event_to_states( + states: List[GatewayState], event_data: Tuple[str, Any] +) -> List[str]: + """Dispatch an event to a list of states.""" + res = [] + + for state in states: + try: + event, data = event_data + await state.ws.dispatch(event, data) + res.append(state.session_id) + except Exception: + log.exception("error while dispatching") + + return res diff --git a/litecord/schemas.py b/litecord/schemas.py index 0410dc1..3962934 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -482,6 +482,9 @@ INVITE = { "required": False, "nullable": True, }, # discord client sends invite code there + # sent by official client, unknown purpose + "target_user_id": {"type": "snowflake", "required": False, "nullable": True}, + "target_user_type": {"type": "number", "required": False, "nullable": True}, } USER_SETTINGS = { diff --git a/litecord/storage.py b/litecord/storage.py index b01bd04..4c22980 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -709,7 +709,9 @@ class Storage: return res - async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict: + async def get_guild_extra( + self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None + ) -> Dict: """Get extra information about a guild.""" res = {} diff --git a/litecord/system_messages.py b/litecord/system_messages.py index 2d548db..184f5e5 100644 --- a/litecord/system_messages.py +++ b/litecord/system_messages.py @@ -181,9 +181,6 @@ async def send_sys_message( raise ValueError("Invalid system message type") message_id = await handler(channel_id, *args, **kwargs) - message = await app.storage.get_message(message_id) - - await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message) - + await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", message)) return message_id diff --git a/litecord/utils.py b/litecord/utils.py index fe3c2d4..7fbc43f 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -252,7 +252,7 @@ async def maybe_lazy_guild_dispatch( if isinstance(role, dict) and not role["hoist"] and not force: return - await app.dispatcher.dispatch("lazy_guild", guild_id, event, role) + await (getattr(app.lazy_guild, event))(guild_id, role) def extract_limit(request_, default: int = 50, max_val: int = 100): diff --git a/litecord/voice/lvsp_manager.py b/litecord/voice/lvsp_manager.py index a75a544..23ab10a 100644 --- a/litecord/voice/lvsp_manager.py +++ b/litecord/voice/lvsp_manager.py @@ -17,11 +17,12 @@ along with this program. If not, see . """ -from typing import Optional +from typing import Optional, Dict, List from collections import defaultdict from dataclasses import dataclass from logbook import Logger +from quart import current_app as app from litecord.voice.lvsp_conn import LVSPConnection @@ -42,15 +43,15 @@ class LVSPManager: Spawns :class:`LVSPConnection` as needed, etc. """ - def __init__(self, app, voice): - self.app = app + def __init__(self, app_, voice): + self.app = app_ self.voice = voice # map servers to LVSPConnection - self.conns = {} + self.conns: Dict[str, LVSPConnection] = {} # maps regions to server hostnames - self.servers = defaultdict(list) + self.servers: Dict[str, List[str]] = defaultdict(list) # maps Union[GuildID, DMId, GroupDMId] to server hostnames self.assign = {} @@ -84,7 +85,7 @@ class LVSPManager: continue self.regions[region.id] = region - self.app.loop.create_task(self._spawn_region(region)) + app.sched.spawn(self._spawn_region(region)) async def _spawn_region(self, region: Region): """Spawn a region. Involves fetching all the hostnames diff --git a/litecord/voice/manager.py b/litecord/voice/manager.py index 008695a..5499363 100644 --- a/litecord/voice/manager.py +++ b/litecord/voice/manager.py @@ -20,6 +20,7 @@ along with this program. If not, see . from typing import Tuple, Dict, List from collections import defaultdict from dataclasses import fields +from quart import current_app as app from logbook import Logger @@ -286,6 +287,6 @@ class VoiceManager: # slow, but it be like that, also copied from other users... for guild_id in guild_ids: guild = await self.app.storage.get_guild_full(guild_id, None) - await self.app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) + await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild)) # TODO propagate the channel deprecation to LVSP connections diff --git a/manage/main.py b/manage/main.py index 22c447c..2008237 100644 --- a/manage/main.py +++ b/manage/main.py @@ -104,7 +104,7 @@ def main(config): # as the managers require it # and the migrate command also sets the db up if argv[1] != "migrate": - init_app_managers(app, voice=False) + init_app_managers(app, init_voice=False) args = parser.parse_args() loop.run_until_complete(_ctx_wrapper(app, args)) diff --git a/run.py b/run.py index 55aeb37..75cc031 100644 --- a/run.py +++ b/run.py @@ -95,6 +95,7 @@ from litecord.images import IconManager from litecord.jobs import JobManager from litecord.voice.manager import VoiceManager from litecord.guild_memory_store import GuildMemoryStore +from litecord.pubsub.lazy_guild import LazyGuildManager from litecord.gateway.gateway import websocket_handler @@ -254,7 +255,7 @@ async def init_app_db(app_): app_.sched = JobManager() -def init_app_managers(app_, *, voice=True): +def init_app_managers(app_: Quart, *, init_voice=True): """Initialize singleton classes.""" app_.loop = asyncio.get_event_loop() app_.ratelimiter = RatelimitManager(app_.config.get("_testing")) @@ -265,7 +266,7 @@ def init_app_managers(app_, *, voice=True): app_.icons = IconManager(app_) - app_.dispatcher = EventDispatcher(app_) + app_.dispatcher = EventDispatcher() app_.presence = PresenceManager(app_) app_.storage.presence = app_.presence @@ -274,10 +275,11 @@ def init_app_managers(app_, *, voice=True): # we do this because of a bug on ./manage.py where it # cancels the LVSPManager's spawn regions task. we don't # need to start it on manage time. - if voice: + if init_voice: app_.voice = VoiceManager(app_) app_.guild_store = GuildMemoryStore() + app_.lazy_guild = LazyGuildManager() async def api_index(app_):