remove code from dispatcher

leftovers are TBD.

 - constrict Dispatcher.dispatch() to arity 3
 - add helper methods to Dispatcher
 - add EventType to Dispatcher

While fixing things, it was discovered that many of the things inside
LazyGuildDispatcher were just interfaces to GuildMemberList, in a very
weird way, just so it could be fitted inside the main Dispatcher. it was
decided to remove those unecessary interfaces, clients shall use the
manager directly.
This commit is contained in:
Luna 2020-02-09 21:20:08 +00:00
parent 39e8a1ad7e
commit b0eb3247fd
46 changed files with 788 additions and 871 deletions

View File

@ -18,7 +18,7 @@ zstandard = "*"
winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"} winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"}
[dev-packages] [dev-packages]
pytest = "==5.1.2" pytest = "==5.3.2"
pytest-asyncio = "==0.10.0" pytest-asyncio = "==0.10.0"
mypy = "*" mypy = "*"
black = "*" black = "*"

84
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3" "sha256": "dedc41184df539a608717e68c108ccfb4d529acb1d1702d83de223840c7cc754"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -123,41 +123,36 @@
}, },
"cffi": { "cffi": {
"hashes": [ "hashes": [
"sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42", "sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff",
"sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04", "sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b",
"sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5", "sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac",
"sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54", "sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0",
"sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba", "sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384",
"sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57", "sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26",
"sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396", "sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6",
"sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12", "sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b",
"sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97", "sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e",
"sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43", "sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd",
"sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db", "sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2",
"sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3", "sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66",
"sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b", "sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc",
"sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579", "sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8",
"sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346", "sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55",
"sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159", "sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4",
"sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652", "sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5",
"sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e", "sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d",
"sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a", "sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78",
"sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506", "sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa",
"sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f", "sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793",
"sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d", "sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f",
"sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c", "sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a",
"sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20", "sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f",
"sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858", "sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30",
"sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc", "sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f",
"sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a", "sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3",
"sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3", "sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c"
"sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e",
"sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410",
"sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25",
"sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b",
"sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d"
], ],
"version": "==1.13.2" "version": "==1.14.0"
}, },
"chardet": { "chardet": {
"hashes": [ "hashes": [
@ -195,10 +190,10 @@
}, },
"h2": { "h2": {
"hashes": [ "hashes": [
"sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e", "sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5",
"sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4" "sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14"
], ],
"version": "==3.1.1" "version": "==3.2.0"
}, },
"hpack": { "hpack": {
"hashes": [ "hashes": [
@ -510,13 +505,6 @@
], ],
"version": "==1.4.3" "version": "==1.4.3"
}, },
"atomicwrites": {
"hashes": [
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
"sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"
],
"version": "==1.3.0"
},
"attrs": { "attrs": {
"hashes": [ "hashes": [
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
@ -647,11 +635,11 @@
}, },
"pytest": { "pytest": {
"hashes": [ "hashes": [
"sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210", "sha256:6b571215b5a790f9b41f19f3531c53a45cf6bb8ef2988bc1ff9afb38270b25fa",
"sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865" "sha256:e41d489ff43948babd0fad7ad5e49b8735d5d55e26628a58673c39ff61d95de4"
], ],
"index": "pypi", "index": "pypi",
"version": "==5.1.2" "version": "==5.3.2"
}, },
"pytest-asyncio": { "pytest-asyncio": {
"hashes": [ "hashes": [

View File

@ -68,7 +68,7 @@ async def _update_features(guild_id: int, features: list):
) )
guild = await app.storage.get_guild_full(guild_id) guild = await app.storage.get_guild_full(guild_id)
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild) await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
@bp.route("/<int:guild_id>/features", methods=["PATCH"]) @bp.route("/<int:guild_id>/features", methods=["PATCH"])

View File

@ -63,10 +63,10 @@ async def update_guild(guild_id: int):
if old_unavailable and not new_unavailable: if old_unavailable and not new_unavailable:
# guild became available # guild became available
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild) await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild))
else: else:
# guild became unavailable # guild became unavailable
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild) await app.dispatcher.guild.dispatch(guild_id, ("GUILD_DELETE", guild))
return jsonify(guild) return jsonify(guild)

View File

@ -24,12 +24,13 @@ import itsdangerous
import bcrypt import bcrypt
from quart import Blueprint, jsonify, request, current_app as app from quart import Blueprint, jsonify, request, current_app as app
from logbook import Logger from logbook import Logger
from winter import get_snowflake
from litecord.auth import token_check from litecord.auth import token_check
from litecord.common.users import create_user from litecord.common.users import create_user
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
from litecord.errors import BadRequest from litecord.errors import BadRequest
from winter import get_snowflake from litecord.pubsub.user import dispatch_user
from .invites import use_invite from .invites import use_invite
log = Logger(__name__) log = Logger(__name__)
@ -172,7 +173,7 @@ async def verify_user():
) )
new_user = await app.storage.get_user(user_id, True) new_user = await app.storage.get_user(user_id, True)
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user) await dispatch_user(user_id, ("USER_UPDATE", new_user))
return "", 204 return "", 204

View File

@ -42,6 +42,7 @@ from litecord.common.messages import (
msg_add_attachment, msg_add_attachment,
msg_guild_text_mentions, msg_guild_text_mentions,
) )
from litecord.pubsub.user import dispatch_user
log = Logger(__name__) log = Logger(__name__)
@ -136,10 +137,10 @@ async def _dm_pre_dispatch(channel_id, peer_id):
# dispatch CHANNEL_CREATE so the client knows which # dispatch CHANNEL_CREATE so the client knows which
# channel the future event is about # channel the future event is about
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan) await dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
# subscribe the peer to the channel # subscribe the peer to the channel
await app.dispatcher.sub("channel", channel_id, peer_id) await app.dispatcher.channel.sub(channel_id, peer_id)
# insert it on dm_channel_state so the client # insert it on dm_channel_state so the client
# is subscribed on the future # is subscribed on the future
@ -242,7 +243,7 @@ async def _create_message(channel_id):
await _dm_pre_dispatch(channel_id, user_id) await _dm_pre_dispatch(channel_id, user_id)
await _dm_pre_dispatch(channel_id, guild_id) await _dm_pre_dispatch(channel_id, guild_id)
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
# spawn url processor for embedding of images # spawn url processor for embedding of images
perms = await get_permissions(user_id, channel_id) perms = await get_permissions(user_id, channel_id)
@ -350,7 +351,7 @@ async def edit_message(channel_id, message_id):
# only dispatch MESSAGE_UPDATE if any update # only dispatch MESSAGE_UPDATE if any update
# actually happened # actually happened
if updated: if updated:
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message) await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message))
return jsonify(message) return jsonify(message)
@ -427,9 +428,9 @@ async def delete_message(channel_id, message_id):
message_id, message_id,
) )
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel",
channel_id, channel_id,
(
"MESSAGE_DELETE", "MESSAGE_DELETE",
{ {
"id": str(message_id), "id": str(message_id),
@ -437,6 +438,7 @@ async def delete_message(channel_id, message_id):
# for lazy guilds # for lazy guilds
"guild_id": str(guild_id), "guild_id": str(guild_id),
}, },
),
) )
return "", 204 return "", 204

View File

@ -106,11 +106,15 @@ async def add_pin(channel_id, message_id):
timestamp = snowflake_datetime(row["message_id"]) timestamp = snowflake_datetime(row["message_id"])
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel",
channel_id, channel_id,
(
"CHANNEL_PINS_UPDATE", "CHANNEL_PINS_UPDATE",
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)}, {
"channel_id": str(channel_id),
"last_pin_timestamp": timestamp_(timestamp),
},
),
) )
await send_sys_message( await send_sys_message(
@ -149,11 +153,15 @@ async def delete_pin(channel_id, message_id):
timestamp = snowflake_datetime(row["message_id"]) timestamp = snowflake_datetime(row["message_id"])
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel",
channel_id, channel_id,
(
"CHANNEL_PINS_UPDATE", "CHANNEL_PINS_UPDATE",
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()}, {
"channel_id": str(channel_id),
"last_pin_timestamp": timestamp.isoformat(),
},
),
) )
return "", 204 return "", 204

View File

@ -141,10 +141,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
payload["guild_id"] = str(guild_id) payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_ADD", payload))
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
)
return "", 204 return "", 204
@ -206,8 +203,8 @@ async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
payload["guild_id"] = str(guild_id) payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload channel_id, ("MESSAGE_REACTION_REMOVE", payload)
) )
@ -290,6 +287,6 @@ async def remove_all_reactions(channel_id, message_id):
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
payload["guild_id"] = str(guild_id) payload["guild_id"] = str(guild_id)
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload)
) )

View File

@ -20,9 +20,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import time import time
import datetime import datetime
from typing import List, Optional from typing import List, Optional
from dataclasses import dataclass
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
from winter import snowflake_datetime
from litecord.auth import token_check from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
@ -41,8 +43,9 @@ from litecord.system_messages import send_sys_message
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
from litecord.utils import search_result_from_list from litecord.utils import search_result_from_list
from litecord.embed.messages import process_url_embed, msg_update_embeds from litecord.embed.messages import process_url_embed, msg_update_embeds
from winter import snowflake_datetime
from litecord.common.channels import channel_ack from litecord.common.channels import channel_ack
from litecord.pubsub.user import dispatch_user
from litecord.permissions import get_permissions, Permissions
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint("channels", __name__) bp = Blueprint("channels", __name__)
@ -80,9 +83,7 @@ async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
async def _update_guild_chan_text(guild_id: int, channel_id: int): async def _update_guild_chan_text(guild_id: int, channel_id: int):
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id") res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id") res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id") res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
# if none of them were actually updated, # if none of them were actually updated,
@ -93,7 +94,7 @@ async def _update_guild_chan_text(guild_id: int, channel_id: int):
# at least one of the fields were updated, # at least one of the fields were updated,
# dispatch GUILD_UPDATE # dispatch GUILD_UPDATE
guild = await app.storage.get_guild(guild_id) guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
async def _update_guild_chan_voice(guild_id: int, channel_id: int): async def _update_guild_chan_voice(guild_id: int, channel_id: int):
@ -104,7 +105,7 @@ async def _update_guild_chan_voice(guild_id: int, channel_id: int):
return return
guild = await app.storage.get_guild(guild_id) guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) await app.dispatcher.dispatch(guild_id, ("GUILD_UPDATE", guild))
async def _update_guild_chan_cat(guild_id: int, channel_id: int): async def _update_guild_chan_cat(guild_id: int, channel_id: int):
@ -134,7 +135,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
# tell all people in the guild of the category removal # tell all people in the guild of the category removal
for child_id in childs: for child_id in childs:
child = await app.storage.get_channel(child_id) child = await app.storage.get_channel(child_id)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child) await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", child))
async def _delete_messages(channel_id): async def _delete_messages(channel_id):
@ -249,12 +250,10 @@ async def close_channel(channel_id):
) )
# clean its member list representation # clean its member list representation
lazy_guilds = app.dispatcher.backends["lazy_guild"] app.lazy_guild.remove_channel(channel_id)
lazy_guilds.remove_channel(channel_id)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan) await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_DELETE", chan))
await app.dispatcher.channel.drop(channel_id)
await app.dispatcher.remove("channel", channel_id)
return jsonify(chan) return jsonify(chan)
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
@ -273,11 +272,9 @@ async def close_channel(channel_id):
channel_id, channel_id,
) )
# unsubscribe
await app.dispatcher.unsub("channel", channel_id, user_id)
# nothing happens to the other party of the dm channel # nothing happens to the other party of the dm channel
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan) await app.dispatcher.channel.unsub(channel_id, user_id)
await dispatch_user(user_id, ("CHANNEL_DELETE", chan))
return jsonify(chan) return jsonify(chan)
@ -318,10 +315,54 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
continue continue
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan) await app.dispatcher.guild.dispatch(guild_id, "CHANNEL_UPDATE", chan)
async def _process_overwrites(channel_id: int, overwrites: list): @dataclass
class Target:
type: int
user_id: Optional[int]
role_id: Optional[int]
@property
def is_user(self):
return self.type == 0
@property
def is_role(self):
return self.type == 1
async def _dispatch_action(guild_id: int, channel_id: int, user_id: int, perms) -> None:
"""Apply an action of sub/unsub to all states of a user."""
states = app.state_manager.fetch_states(user_id, guild_id)
for state in states:
if perms.read_messages:
await app.dispatcher.channel.sub(channel_id, state.session_id)
else:
await app.dispatcher.channel.unsub(channel_id, state.session_id)
async def _process_overwrites(guild_id: int, channel_id: int, overwrites: list) -> None:
# user_ids serves as a "prospect" user id list.
# for each overwrite we apply, we fill this list with user ids we
# want to check later to subscribe/unsubscribe from the channel.
# (users without read_messages are automatically unsubbed since we
# don't want to leak messages to them when they dont have perms anymore)
# the expensiveness of large overwrite/role chains shines here.
# since each user id we fill in implies an entire get_permissions call
# (because we don't have the answer if a user is to be subbed/unsubbed
# with only overwrites, an overwrite for a user allowing them might be
# overwritten by a role overwrite denying them if they have the role),
# we get a lot of tension on that, causing channel updates to lag a bit.
# there may be some good optimizations to do here, such as doing some
# precalculations like fetching get_permissions for everyone first, then
# applying the new overwrites one by one, then subbing/unsubbing at the
# end, but it would be very memory intensive.
user_ids: List[int] = []
for overwrite in overwrites: for overwrite in overwrites:
# 0 for member overwrite, 1 for role overwrite # 0 for member overwrite, 1 for role overwrite
@ -329,7 +370,9 @@ async def _process_overwrites(channel_id: int, overwrites: list):
target_role = None if target_type == 0 else overwrite["id"] target_role = None if target_type == 0 else overwrite["id"]
target_user = overwrite["id"] if target_type == 0 else None target_user = overwrite["id"] if target_type == 0 else None
col_name = "target_user" if target_type == 0 else "target_role" target = Target(target_type, target_role, target_user)
col_name = "target_user" if target.is_user else "target_role"
constraint_name = f"channel_overwrites_{col_name}_uniq" constraint_name = f"channel_overwrites_{col_name}_uniq"
await app.db.execute( await app.db.execute(
@ -352,6 +395,17 @@ async def _process_overwrites(channel_id: int, overwrites: list):
overwrite["deny"], overwrite["deny"],
) )
if target.is_user:
perms = Permissions(overwrite["allow"] & ~overwrite["deny"])
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
elif target.is_role:
user_ids.extend(await app.storage.get_role_members(target.role_id))
for user_id in user_ids:
perms = await get_permissions(user_id, channel_id)
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"]) @bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
async def put_channel_overwrite(channel_id: int, overwrite_id: int): async def put_channel_overwrite(channel_id: int, overwrite_id: int):
@ -371,6 +425,7 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
) )
await _process_overwrites( await _process_overwrites(
guild_id,
channel_id, channel_id,
[ [
{ {
@ -447,7 +502,7 @@ async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
if "channel_overwrites" in j: if "channel_overwrites" in j:
overwrites = j["channel_overwrites"] overwrites = j["channel_overwrites"]
await _process_overwrites(channel_id, overwrites) await _process_overwrites(guild_id, channel_id, overwrites)
async def _common_guild_chan(channel_id, j: dict): async def _common_guild_chan(channel_id, j: dict):
@ -566,9 +621,9 @@ async def update_channel(channel_id: int):
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
if is_guild: if is_guild:
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan) await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
else: else:
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan) await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
return jsonify(chan) return jsonify(chan)
@ -578,9 +633,9 @@ async def trigger_typing(channel_id):
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel",
channel_id, channel_id,
(
"TYPING_START", "TYPING_START",
{ {
"channel_id": str(channel_id), "channel_id": str(channel_id),
@ -589,6 +644,7 @@ async def trigger_typing(channel_id):
# guild_id for lazy guilds # guild_id for lazy guilds
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None, "guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
}, },
),
) )
return "", 204 return "", 204
@ -816,5 +872,5 @@ async def bulk_delete(channel_id: int):
if res == "DELETE 0": if res == "DELETE 0":
raise BadRequest("No messages were removed") raise BadRequest("No messages were removed")
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload) await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_DELETE_BULK", payload))
return "", 204 return "", 204

View File

@ -27,6 +27,7 @@ from litecord.errors import BadRequest, Forbidden
from winter import get_snowflake from winter import get_snowflake
from litecord.system_messages import send_sys_message from litecord.system_messages import send_sys_message
from litecord.pubsub.channel import gdm_recipient_view from litecord.pubsub.channel import gdm_recipient_view
from litecord.pubsub.user import dispatch_user
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint("dm_channels", __name__) bp = Blueprint("dm_channels", __name__)
@ -82,11 +83,11 @@ async def gdm_create(user_id, peer_id) -> int:
await _raw_gdm_add(channel_id, user_id) await _raw_gdm_add(channel_id, user_id)
await _raw_gdm_add(channel_id, peer_id) await _raw_gdm_add(channel_id, peer_id)
await app.dispatcher.sub("channel", channel_id, user_id) await app.dispatcher.channel.sub(channel_id, user_id)
await app.dispatcher.sub("channel", channel_id, peer_id) await app.dispatcher.channel.sub(channel_id, peer_id)
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan) await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_CREATE", chan))
return channel_id return channel_id
@ -104,13 +105,10 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
# the reasoning behind gdm_recipient_view is in its docstring. # the reasoning behind gdm_recipient_view is in its docstring.
await app.dispatcher.dispatch( await dispatch_user(peer_id, ("CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)))
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
)
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan) await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
await app.dispatcher.channel.sub(peer_id)
await app.dispatcher.sub("channel", peer_id)
if user_id: if user_id:
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id) await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
@ -128,17 +126,19 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
await _raw_gdm_remove(channel_id, peer_id) await _raw_gdm_remove(channel_id, peer_id)
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch( await dispatch_user(peer_id, ("CHANNEL_DELETE", gdm_recipient_view(chan, user_id)))
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
)
await app.dispatcher.unsub("channel", peer_id) await app.dispatcher.channel.unsub(peer_id)
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel",
channel_id, channel_id,
(
"CHANNEL_RECIPIENT_REMOVE", "CHANNEL_RECIPIENT_REMOVE",
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)}, {
"channel_id": str(channel_id),
"user": await app.storage.get_user(peer_id),
},
),
) )
author_id = peer_id if user_id is None else user_id author_id = peer_id if user_id is None else user_id
@ -174,9 +174,8 @@ async def gdm_destroy(channel_id):
channel_id, channel_id,
) )
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan) await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_DELETE", chan))
await app.dispatcher.channel.drop(channel_id)
await app.dispatcher.remove("channel", channel_id)
async def gdm_is_member(channel_id: int, user_id: int) -> bool: async def gdm_is_member(channel_id: int, user_id: int) -> bool:

View File

@ -59,19 +59,9 @@ async def create_channel(guild_id):
new_channel_id = get_snowflake() new_channel_id = get_snowflake()
await create_guild_channel(guild_id, new_channel_id, channel_type, **j) await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
# TODO: do a better method
# subscribe the currently subscribed users to the new channel
# by getting all user ids and subscribing each one by one.
# since GuildDispatcher calls Storage.get_channel_ids,
# it will subscribe all users to the newly created channel.
guild_pubsub = app.dispatcher.backends["guild"]
user_ids = guild_pubsub.state[guild_id]
for uid in user_ids:
await app.dispatcher.sub("guild", guild_id, uid)
chan = await app.storage.get_channel(new_channel_id) chan = await app.storage.get_channel(new_channel_id)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan) await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_CREATE", chan))
return jsonify(chan) return jsonify(chan)
@ -79,7 +69,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
"""Fetch new information about the channel and dispatch """Fetch new information about the channel and dispatch
a single CHANNEL_UPDATE event to the guild.""" a single CHANNEL_UPDATE event to the guild."""
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan) await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
async def _do_single_swap(guild_id: int, pair: tuple): async def _do_single_swap(guild_id: int, pair: tuple):

View File

@ -32,14 +32,15 @@ bp = Blueprint("guild.emoji", __name__)
async def _dispatch_emojis(guild_id): async def _dispatch_emojis(guild_id):
"""Dispatch a Guild Emojis Update payload to a guild.""" """Dispatch a Guild Emojis Update payload to a guild."""
await app.dispatcher.dispatch( await app.dispatcher.guild.dispatch(
"guild",
guild_id, guild_id,
(
"GUILD_EMOJIS_UPDATE", "GUILD_EMOJIS_UPDATE",
{ {
"guild_id": str(guild_id), "guild_id": str(guild_id),
"emojis": await app.storage.get_guild_emojis(guild_id), "emojis": await app.storage.get_guild_emojis(guild_id),
}, },
),
) )

View File

@ -192,12 +192,9 @@ async def modify_guild_member(guild_id, member_id):
if nick_flag: if nick_flag:
partial["nick"] = j["nick"] partial["nick"] = j["nick"]
await app.dispatcher.dispatch( await app.lazy_guild.pres_update(guild_id, user_id, partial)
"lazy_guild", guild_id, "pres_update", user_id, partial await app.dispatcher.guild.dispatch(
) guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
) )
return "", 204 return "", 204
@ -228,12 +225,9 @@ async def update_nickname(guild_id):
member.pop("joined_at") member.pop("joined_at")
# call pres_update for nick changes, etc. # call pres_update for nick changes, etc.
await app.dispatcher.dispatch( await app.lazy_guild.pres_update(guild_id, user_id, {"nick": j["nick"]})
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]} await app.dispatcher.guild.dispatch(
) guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
await app.dispatcher.dispatch_guild(
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
) )
return j["nick"] return j["nick"]

View File

@ -86,10 +86,12 @@ async def create_ban(guild_id, member_id):
await remove_member(guild_id, member_id) await remove_member(guild_id, member_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.guild.dispatch(
guild_id, guild_id,
(
"GUILD_BAN_ADD", "GUILD_BAN_ADD",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
),
) )
return "", 204 return "", 204
@ -115,10 +117,12 @@ async def remove_ban(guild_id, banned_id):
if res == "DELETE 0": if res == "DELETE 0":
return "", 204 return "", 204
await app.dispatcher.dispatch_guild( await app.dispatcher.guild.dispatch(
guild_id, guild_id,
(
"GUILD_BAN_REMOVE", "GUILD_BAN_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)}, {"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
),
) )
return "", 204 return "", 204

View File

@ -67,8 +67,8 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role) await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
await app.dispatcher.dispatch_guild( await app.dispatcher.guild.dispatch(
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role} guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role})
) )
return role return role
@ -304,10 +304,9 @@ async def delete_guild_role(guild_id, role_id):
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True) await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
await app.dispatcher.dispatch_guild( await app.dispatcher.guild.dispatch(
guild_id, guild_id,
"GUILD_ROLE_DELETE", ("GUILD_ROLE_DELETE", {"guild_id": str(guild_id), "role_id": str(role_id)},),
{"guild_id": str(guild_id), "role_id": str(role_id)},
) )
return "", 204 return "", 204

View File

@ -191,8 +191,9 @@ async def create_guild():
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
await app.dispatcher.sub("guild", guild_id, user_id) await app.dispatcher.guild.sub_user(guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild_total))
return jsonify(guild_total) return jsonify(guild_total)
@ -350,7 +351,7 @@ async def _update_guild(guild_id):
) )
guild = await app.storage.get_guild_full(guild_id, user_id) guild = await app.storage.get_guild_full(guild_id, user_id)
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
return jsonify(guild) return jsonify(guild)

View File

@ -116,7 +116,7 @@ async def _inv_check_age(inv: dict):
await delete_invite(inv["code"]) await delete_invite(inv["code"])
raise InvalidInvite("Invite is expired") raise InvalidInvite("Invite is expired")
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]: if inv["max_uses"] != -1 and inv["uses"] > inv["max_uses"]:
await delete_invite(inv["code"]) await delete_invite(inv["code"])
raise InvalidInvite("Too many uses") raise InvalidInvite("Too many uses")

View File

@ -24,6 +24,7 @@ from ..auth import token_check
from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND
from ..enums import RelationshipType from ..enums import RelationshipType
from litecord.errors import BadRequest from litecord.errors import BadRequest
from litecord.pubsub.user import dispatch_user
bp = Blueprint("relationship", __name__) bp = Blueprint("relationship", __name__)
@ -36,17 +37,17 @@ async def get_me_relationships():
async def _dispatch_single_pres(user_id, presence: dict): async def _dispatch_single_pres(user_id, presence: dict):
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence) await dispatch_user(user_id, ("PRESENCE_UPDATE", presence))
async def _unsub_friend(user_id, peer_id): async def _unsub_friend(user_id, peer_id):
await app.dispatcher.unsub("friend", user_id, peer_id) await app.dispatcher.friend.unsub(user_id, peer_id)
await app.dispatcher.unsub("friend", peer_id, user_id) await app.dispatcher.friend.unsub(peer_id, user_id)
async def _sub_friend(user_id, peer_id): async def _sub_friend(user_id, peer_id):
await app.dispatcher.sub("friend", user_id, peer_id) await app.dispatcher.friend.sub(user_id, peer_id)
await app.dispatcher.sub("friend", peer_id, user_id) await app.dispatcher.friend.sub(peer_id, user_id)
# dispatch presence update to the user and peer about # dispatch presence update to the user and peer about
# eachother's presence. # eachother's presence.
@ -107,8 +108,8 @@ async def make_friend(
_friend, _friend,
) )
await app.dispatcher.dispatch_user( await dispatch_user(
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)} peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)})
) )
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)
@ -130,35 +131,41 @@ async def make_friend(
_friend, _friend,
) )
_dispatch = app.dispatcher.dispatch_user _dispatch = dispatch_user
if existing: if existing:
# accepted a friend request, dispatch respective # accepted a friend request, dispatch respective
# relationship events # relationship events
await _dispatch( await _dispatch(
user_id, user_id,
(
"RELATIONSHIP_REMOVE", "RELATIONSHIP_REMOVE",
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)}, {"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
),
) )
await _dispatch( await _dispatch(
user_id, user_id,
(
"RELATIONSHIP_ADD", "RELATIONSHIP_ADD",
{ {
"type": _friend, "type": _friend,
"id": str(peer_id), "id": str(peer_id),
"user": await app.storage.get_user(peer_id), "user": await app.storage.get_user(peer_id),
}, },
),
) )
await _dispatch( await _dispatch(
peer_id, peer_id,
"RELATIONSHIP_ADD", "RELATIONSHIP_ADD",
(
{ {
"type": _friend, "type": _friend,
"id": str(user_id), "id": str(user_id),
"user": await app.storage.get_user(user_id), "user": await app.storage.get_user(user_id),
}, },
),
) )
await _sub_friend(user_id, peer_id) await _sub_friend(user_id, peer_id)
@ -169,22 +176,26 @@ async def make_friend(
if rel_type == _friend: if rel_type == _friend:
await _dispatch( await _dispatch(
user_id, user_id,
(
"RELATIONSHIP_ADD", "RELATIONSHIP_ADD",
{ {
"id": str(peer_id), "id": str(peer_id),
"type": RelationshipType.OUTGOING.value, "type": RelationshipType.OUTGOING.value,
"user": await app.storage.get_user(peer_id), "user": await app.storage.get_user(peer_id),
}, },
),
) )
await _dispatch( await _dispatch(
peer_id, peer_id,
(
"RELATIONSHIP_ADD", "RELATIONSHIP_ADD",
{ {
"id": str(user_id), "id": str(user_id),
"type": RelationshipType.INCOMING.value, "type": RelationshipType.INCOMING.value,
"user": await app.storage.get_user(user_id), "user": await app.storage.get_user(user_id),
}, },
),
) )
# we don't make the pubsub link # we don't make the pubsub link
@ -240,7 +251,7 @@ async def add_relationship(peer_id: int):
# make_friend did not succeed, so we # make_friend did not succeed, so we
# assume it is a block and dispatch # assume it is a block and dispatch
# the respective RELATIONSHIP_ADD. # the respective RELATIONSHIP_ADD.
await app.dispatcher.dispatch_user( await dispatch_user(
user_id, user_id,
"RELATIONSHIP_ADD", "RELATIONSHIP_ADD",
{ {
@ -261,7 +272,7 @@ async def remove_relationship(peer_id: int):
user_id = await token_check() user_id = await token_check()
_friend = RelationshipType.FRIEND.value _friend = RelationshipType.FRIEND.value
_block = RelationshipType.BLOCK.value _block = RelationshipType.BLOCK.value
_dispatch = app.dispatcher.dispatch_user _dispatch = dispatch_user
rel_type = await app.db.fetchval( rel_type = await app.db.fetchval(
""" """
@ -307,7 +318,8 @@ async def remove_relationship(peer_id: int):
) )
await _dispatch( await _dispatch(
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type} user_id,
("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}),
) )
peer_del_type = ( peer_del_type = (
@ -315,7 +327,8 @@ async def remove_relationship(peer_id: int):
) )
await _dispatch( await _dispatch(
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type} peer_id,
("RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}),
) )
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)
@ -334,7 +347,7 @@ async def remove_relationship(peer_id: int):
) )
await _dispatch( await _dispatch(
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block} user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block})
) )
await _unsub_friend(user_id, peer_id) await _unsub_friend(user_id, peer_id)

View File

@ -22,6 +22,7 @@ from quart import Blueprint, jsonify, request, current_app as app
from litecord.auth import token_check from litecord.auth import token_check
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
from litecord.blueprints.checks import guild_check from litecord.blueprints.checks import guild_check
from litecord.pubsub.user import dispatch_user
bp = Blueprint("users_settings", __name__) bp = Blueprint("users_settings", __name__)
@ -58,7 +59,7 @@ async def patch_current_settings():
) )
settings = await app.user_storage.get_user_settings(user_id) settings = await app.user_storage.get_user_settings(user_id)
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings) await dispatch_user(user_id, ("USER_SETTINGS_UPDATE", settings))
return jsonify(settings) return jsonify(settings)
@ -123,7 +124,7 @@ async def patch_guild_settings(guild_id: int):
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id) settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings) await dispatch_user(user_id, ("USER_GUILD_SETTINGS_UPDATE", settings))
return jsonify(settings) return jsonify(settings)
@ -157,8 +158,8 @@ async def put_note(target_id: int):
note, note,
) )
await app.dispatcher.dispatch_user( await dispatch_user(
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note} user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note})
) )
return "", 204 return "", 204

View File

@ -156,11 +156,12 @@ async def webhook_token_check(webhook_id: int, webhook_token: str):
async def _dispatch_webhook_update(guild_id: int, channel_id): async def _dispatch_webhook_update(guild_id: int, channel_id):
await app.dispatcher.dispatch( await app.dispatcher.guild.dispatch(
"guild",
guild_id, guild_id,
(
"WEBHOOKS_UPDATE", "WEBHOOKS_UPDATE",
{"guild_id": str(guild_id), "channel_id": str(channel_id)}, {"guild_id": str(guild_id), "channel_id": str(channel_id)},
),
) )
@ -503,7 +504,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
payload = await app.storage.get_message(message_id) payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload) await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
# spawn embedder in the background, even when we're on a webhook. # spawn embedder in the background, even when we're on a webhook.
app.sched.spawn(process_url_embed(payload)) app.sched.spawn(process_url_embed(payload))

View File

@ -22,6 +22,8 @@ from quart import current_app as app
from litecord.errors import ForbiddenDM from litecord.errors import ForbiddenDM
from litecord.enums import RelationshipType from litecord.enums import RelationshipType
from litecord.pubsub.member import dispatch_member
from litecord.pubsub.user import dispatch_user
async def channel_ack( async def channel_ack(
@ -54,20 +56,24 @@ async def channel_ack(
) )
if guild_id: if guild_id:
await app.dispatcher.dispatch_user_guild( await dispatch_member(
user_id,
guild_id, guild_id,
user_id,
(
"MESSAGE_ACK", "MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)}, {"message_id": str(message_id), "channel_id": str(channel_id)},
),
) )
else: else:
# we don't use ChannelDispatcher here because since # we don't use ChannelDispatcher here because since
# guild_id is None, all user devices are already subscribed # guild_id is None, all user devices are already subscribed
# to the given channel (a dm or a group dm) # to the given channel (a dm or a group dm)
await app.dispatcher.dispatch_user( await dispatch_user(
user_id, user_id,
(
"MESSAGE_ACK", "MESSAGE_ACK",
{"message_id": str(message_id), "channel_id": str(channel_id)}, {"message_id": str(message_id), "channel_id": str(channel_id)},
),
) )

View File

@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List
from logbook import Logger from logbook import Logger
from quart import current_app as app from quart import current_app as app
from ..snowflake import get_snowflake from ..snowflake import get_snowflake
from ..permissions import get_role_perms from ..permissions import get_role_perms, get_permissions
from ..utils import dict_get, maybe_lazy_guild_dispatch from ..utils import dict_get, maybe_lazy_guild_dispatch
from ..enums import ChannelType from ..enums import ChannelType
from litecord.pubsub.member import dispatch_member
log = Logger(__name__) log = Logger(__name__)
@ -41,21 +43,20 @@ async def remove_member(guild_id: int, member_id: int):
member_id, member_id,
) )
await app.dispatcher.dispatch_user_guild( await dispatch_member(
member_id,
guild_id, guild_id,
"GUILD_DELETE", member_id,
{"guild_id": str(guild_id), "unavailable": False}, ("GUILD_DELETE", {"guild_id": str(guild_id), "unavailable": False}),
) )
await app.dispatcher.unsub("guild", guild_id, member_id) await app.dispatcher.guild.unsub(guild_id, member_id)
await app.lazy_guild.remove_member(member_id)
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id) await app.dispatcher.guild.dispatch(
await app.dispatcher.dispatch_guild(
guild_id, guild_id,
(
"GUILD_MEMBER_REMOVE", "GUILD_MEMBER_REMOVE",
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)}, {"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
),
) )
@ -108,8 +109,8 @@ async def create_role(guild_id, name: str, **kwargs):
# we need to update the lazy guild handlers for the newly created group # we need to update the lazy guild handlers for the newly created group
await maybe_lazy_guild_dispatch(guild_id, "new_role", role) await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
await app.dispatcher.dispatch_guild( await app.dispatcher.guild.dispatch(
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role} guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role})
) )
return role return role
@ -137,6 +138,39 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
) )
async def _subscribe_users_new_channel(guild_id: int, channel_id: int) -> None:
# for each state currently subscribed to guild, we check on the database
# which states can also subscribe to the new channel at its creation.
# the list of users that can subscribe are then used again for a pass
# over the states and states that have user ids in that list become
# subscribers of the new channel.
users_to_sub: List[str] = []
for session_id in app.dispatcher.guild.state[guild_id]:
try:
state = app.state_manager.fetch_raw(session_id)
except KeyError:
continue
if state.user_id in users_to_sub:
continue
perms = await get_permissions(state.user_id, channel_id)
if perms.read_messages:
users_to_sub.append(state.user_id)
for session_id in app.dispatcher.guild.state[guild_id]:
try:
state = app.state_manager.fetch_raw(session_id)
except KeyError:
continue
if state.user_id in users_to_sub:
await app.dispatcher.channel.sub(channel_id, session_id)
async def create_guild_channel( async def create_guild_channel(
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
): ):
@ -180,6 +214,8 @@ async def create_guild_channel(
# so we use this function. # so we use this function.
await _specific_chan_create(channel_id, ctype, **kwargs) await _specific_chan_create(channel_id, ctype, **kwargs)
await _subscribe_users_new_channel(guild_id, channel_id)
async def _del_from_table(table: str, user_id: int): async def _del_from_table(table: str, user_id: int):
"""Delete a row from a table.""" """Delete a row from a table."""
@ -206,21 +242,22 @@ async def delete_guild(guild_id: int):
) )
# Discord's client expects IDs being string # Discord's client expects IDs being string
await app.dispatcher.dispatch( await app.dispatcher.guild.dispatch(
"guild",
guild_id, guild_id,
(
"GUILD_DELETE", "GUILD_DELETE",
{ {
"guild_id": str(guild_id), "guild_id": str(guild_id),
"id": str(guild_id), "id": str(guild_id),
# 'unavailable': False, # 'unavailable': False,
}, },
),
) )
# remove from the dispatcher so nobody # remove from the dispatcher so nobody
# becomes the little memer that tries to fuck up with # becomes the little memer that tries to fuck up with
# everybody's gateway # everybody's gateway
await app.dispatcher.remove("guild", guild_id) await app.dispatcher.guild.drop(guild_id)
async def create_guild_settings(guild_id: int, user_id: int): async def create_guild_settings(guild_id: int, user_id: int):
@ -285,18 +322,17 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
# tell current members a new member came up # tell current members a new member came up
member = await app.storage.get_member_data_one(guild_id, user_id) member = await app.storage.get_member_data_one(guild_id, user_id)
await app.dispatcher.dispatch_guild( await app.dispatcher.guild.dispatch(
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}} guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}})
) )
# update member lists for the new member # pubsub changes for new member
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id) await app.lazy_guild.new_member(guild_id, user_id)
states = await app.dispatcher.guild.sub_user(guild_id, user_id)
# subscribe new member to guild, so they get events n stuff
await app.dispatcher.sub("guild", guild_id, user_id)
# tell the new member that theres the guild it just joined.
# we use dispatch_user_guild so that we send the GUILD_CREATE
# just to the shards that are actually tied to it.
guild = await app.storage.get_guild_full(guild_id, user_id, 250) guild = await app.storage.get_guild_full(guild_id, user_id, 250)
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild) for state in states:
try:
await state.ws.dispatch("GUILD_CREATE", guild)
except Exception:
log.exception("failed to dispatch to session_id={!r}", state.session_id)

View File

@ -29,41 +29,48 @@ from ..snowflake import get_snowflake
from ..errors import BadRequest from ..errors import BadRequest
from ..auth import hash_data from ..auth import hash_data
from ..utils import rand_hex from ..utils import rand_hex
from ..pubsub.user import dispatch_user
log = Logger(__name__) log = Logger(__name__)
async def mass_user_update(user_id: int): async def mass_user_update(user_id: int) -> Tuple[dict, dict]:
"""Dispatch USER_UPDATE in a mass way.""" """Dispatch a USER_UPDATE to everyone that is subscribed to the user.
# by using dispatch_with_filter
# we're guaranteeing all shards will get
# a USER_UPDATE once and not any others.
This function guarantees all states will get one USER_UPDATE for simple
cases. Lazy guild users might get updates N times depending of how many
lists are they subscribed to.
"""
session_ids: List[str] = [] session_ids: List[str] = []
public_user = await app.storage.get_user(user_id) public_user = await app.storage.get_user(user_id)
private_user = await app.storage.get_user(user_id, secure=True) private_user = await app.storage.get_user(user_id, secure=True)
session_ids.extend(await dispatch_user(user_id, ("USER_UPDATE", private_user)))
guild_ids: List[int] = await app.user_storage.get_user_guilds(user_id)
friend_ids: List[int] = await app.user_storage.get_friend_ids(user_id)
for guild_id in guild_ids:
session_ids.extend( session_ids.extend(
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user) await app.dispatcher.guild.dispatch_filter(
guild_id,
lambda sess_id: sess_id not in session_ids,
("USER_UPDATE", public_user),
)
) )
guild_ids = await app.user_storage.get_user_guilds(user_id) for friend_id in friend_ids:
friend_ids = await app.user_storage.get_friend_ids(user_id)
session_ids.extend( session_ids.extend(
await app.dispatcher.dispatch_many_filter_list( await app.dispatcher.friend.dispatch_filter(
"guild", guild_ids, session_ids, "USER_UPDATE", public_user friend_id,
lambda sess_id: sess_id not in session_ids,
("USER_UPDATE", public_user),
) )
) )
session_ids.extend( for guild_id in guild_ids:
await app.dispatcher.dispatch_many_filter_list( await app.lazy_guild.update_user(guild_id, user_id)
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
)
)
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
return public_user, private_user return public_user, private_user

View File

@ -17,17 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Any, Dict
from logbook import Logger from logbook import Logger
from .pubsub import ( from .pubsub import (
GuildDispatcher, GuildDispatcher,
MemberDispatcher,
UserDispatcher,
ChannelDispatcher, ChannelDispatcher,
FriendDispatcher, FriendDispatcher,
LazyGuildDispatcher,
) )
log = Logger(__name__) log = Logger(__name__)
@ -50,169 +45,7 @@ class EventDispatcher:
its subscriber ids. its subscriber ids.
""" """
def __init__(self, app): def __init__(self):
self.state_manager = app.state_manager self.guild: GuildDispatcher = GuildDispatcher()
self.app = app self.channel = ChannelDispatcher()
self.friend = FriendDispatcher()
self.backends = {
"guild": GuildDispatcher(self),
"member": MemberDispatcher(self),
"channel": ChannelDispatcher(self),
"user": UserDispatcher(self),
"friend": FriendDispatcher(self),
"lazy_guild": LazyGuildDispatcher(self),
}
async def action(self, backend_str: str, action: str, key, identifier, *args):
"""Send an action regarding a key/identifier pair to a backend.
Action is usually "sub" or "unsub".
"""
backend = self.backends[backend_str]
method = getattr(backend, action)
# convert keys to the types the backend wants
key = backend.KEY_TYPE(key)
identifier = backend.VAL_TYPE(identifier)
return await method(key, identifier, *args)
async def subscribe(
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
):
"""Subscribe a single element to the given backend."""
flags = flags or {}
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
# this is a hacky solution for backwards compatibility between backends
# that implement flags and backends that don't.
# passing flags to backends that don't implement flags will
# cause errors as expected.
if flags:
return await self.action(backend, "sub", key, identifier, flags)
return await self.action(backend, "sub", key, identifier)
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
"""Unsubscribe an element from the given backend."""
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
return await self.action(backend, "unsub", key, identifier)
async def sub(self, backend, key, identifier):
"""Alias to subscribe()."""
return await self.subscribe(backend, key, identifier)
async def unsub(self, backend, key, identifier):
"""Alias to unsubscribe()."""
return await self.unsubscribe(backend, key, identifier)
async def sub_many(
self,
backend_str: str,
identifier: Any,
keys: list,
flags: Dict[str, Any] = None,
):
"""Subscribe to multiple channels (all in a single backend)
at a time.
Usually used when connecting to the gateway and the client
needs to subscribe to all their guids.
"""
flags = flags or {}
for key in keys:
await self.subscribe(backend_str, key, identifier, flags)
async def mass_sub(self, identifier: Any, backends: List[tuple]):
"""Mass subscribe to many backends at once."""
for bcall in backends:
backend_str, keys = bcall[0], bcall[1]
if len(bcall) == 2:
flags = {}
elif len(bcall) == 3:
# we have flags
flags = bcall[2]
log.debug(
"subscribing {} to {} keys in backend {}, flags: {}",
identifier,
len(keys),
backend_str,
flags,
)
await self.sub_many(backend_str, identifier, keys, flags)
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
"""Dispatch an event to the backend.
The backend is responsible for everything regarding the
actual dispatch.
"""
backend = self.backends[backend_str]
# convert types
key = backend.KEY_TYPE(key)
return await backend.dispatch(key, *args, **kwargs)
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
"""Dispatch to multiple keys in a single backend."""
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
for key in keys:
await self.dispatch(backend_str, key, *args, **kwargs)
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
"""Dispatch to a backend that only accepts
(event, data) arguments with an optional filter
function."""
backend = self.backends[backend_str]
key = backend.KEY_TYPE(key)
return await backend.dispatch_filter(key, func, *args)
async def dispatch_many_filter_list(
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
):
"""Make a "unique" dispatch given a list of session ids.
This only works for backends that have a dispatch_filter
handler and return session id lists in their dispatch
results.
"""
for key in keys:
sess_list.extend(
await self.dispatch_filter(
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
)
)
return sess_list
async def reset(self, backend_str: str, key: Any):
"""Reset the bucket in the given backend."""
backend = self.backends[backend_str]
key = backend.KEY_TYPE(key)
return await backend.reset(key)
async def remove(self, backend_str: str, key: Any):
"""Remove a key from the backend. This
might be a different operation than resetting."""
backend = self.backends[backend_str]
key = backend.KEY_TYPE(key)
return await backend.remove(key)
async def dispatch_guild(self, guild_id, event, data):
"""Backwards compatibility with old EventDispatcher."""
return await self.dispatch("guild", guild_id, event, data)
async def dispatch_user_guild(self, user_id, guild_id, event, data):
"""Backwards compatibility with old EventDispatcher."""
return await self.dispatch("member", (guild_id, user_id), event, data)
async def dispatch_user(self, user_id, event, data):
"""Backwards compatibility with old EventDispatcher."""
return await self.dispatch("user", user_id, event, data)

View File

@ -87,8 +87,8 @@ async def msg_update_embeds(payload, new_embeds):
if "flags" in payload: if "flags" in payload:
update_payload["flags"] = payload["flags"] update_payload["flags"] = payload["flags"]
await app.dispatcher.dispatch( await app.dispatcher.channel.dispatch(
"channel", channel_id, "MESSAGE_UPDATE", update_payload channel_id, ("MESSAGE_UPDATE", update_payload)
) )

View File

@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import urllib.parse import urllib.parse
from typing import Optional
from litecord.gateway.websocket import GatewayWebsocket from litecord.gateway.websocket import GatewayWebsocket
@ -44,15 +45,16 @@ async def websocket_handler(app, ws, url):
return await ws.close(1000, "Invalid gateway encoding") return await ws.close(1000, "Invalid gateway encoding")
try: try:
gw_compress = args["compress"][0] gw_compress: Optional[str] = args["compress"][0]
except (KeyError, IndexError): except (KeyError, IndexError):
gw_compress = None gw_compress = None
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"): if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
return await ws.close(1000, "Invalid gateway compress") return await ws.close(1000, "Invalid gateway compress")
async with app.app_context():
gws = GatewayWebsocket( gws = GatewayWebsocket(
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress ws, v=gw_version, encoding=gw_encoding, compress=gw_compress
) )
# this can be run with a single await since this whole coroutine # this can be run with a single await since this whole coroutine

View File

@ -22,6 +22,7 @@ import asyncio
from typing import List from typing import List
from collections import defaultdict from collections import defaultdict
from quart import current_app as app
from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosed
from logbook import Logger from logbook import Logger
@ -225,3 +226,16 @@ class StateManager:
def close(self): def close(self):
"""Close the state manager.""" """Close the state manager."""
self.closed = True self.closed = True
async def fetch_user_states_for_channel(
self, channel_id: int, user_id: int
) -> List[GatewayState]:
"""Get a list of gateway states for a user that can receive events on a certain channel."""
# TODO optimize this with an in-memory store
guild_id = await app.storage.guild_from_channel(channel_id)
if guild_id:
return self.fetch_states(user_id, guild_id)
# DMs and GDMs use all user states
return self.user_states(user_id)

View File

@ -27,6 +27,7 @@ from random import randint
import websockets import websockets
import zstandard as zstd import zstandard as zstd
from logbook import Logger from logbook import Logger
from quart import current_app as app
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType from litecord.enums import RelationshipType, ChannelType
@ -43,7 +44,6 @@ from litecord.presence import BasePresence
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.gateway.state import GatewayState from litecord.gateway.state import GatewayState
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
from litecord.gateway.errors import ( from litecord.gateway.errors import (
DecodeError, DecodeError,
@ -52,8 +52,9 @@ from litecord.gateway.errors import (
ShardingRequired, ShardingRequired,
) )
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
from litecord.gateway.utils import WebsocketFileHandler from litecord.gateway.utils import WebsocketFileHandler
from litecord.pubsub.guild import GuildFlags
from litecord.pubsub.channel import ChannelFlags
from litecord.storage import int_ from litecord.storage import int_
@ -67,7 +68,7 @@ WebsocketProperties = collections.namedtuple(
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
def __init__(self, ws, app, **kwargs): def __init__(self, ws, **kwargs):
self.app = app self.app = app
self.storage = app.storage self.storage = app.storage
self.user_storage = app.user_storage self.user_storage = app.user_storage
@ -230,7 +231,7 @@ class GatewayWebsocket:
if task: if task:
task.cancel() task.cancel()
self.wsp.tasks["heartbeat"] = self.app.loop.create_task( self.wsp.tasks["heartbeat"] = app.sched.spawn(
task_wrapper("hb wait", self._hb_wait(interval)) task_wrapper("hb wait", self._hb_wait(interval))
) )
@ -247,6 +248,7 @@ class GatewayWebsocket:
async def dispatch(self, event: str, data: Any): async def dispatch(self, event: str, data: Any):
"""Dispatch an event to the websocket.""" """Dispatch an event to the websocket."""
assert self.state is not None
self.state.seq += 1 self.state.seq += 1
payload = { payload = {
@ -282,6 +284,7 @@ class GatewayWebsocket:
async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]): async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
"""Dispatch GUILD_CREATE information.""" """Dispatch GUILD_CREATE information."""
assert self.state is not None
# Users don't get asynchronous guild dispatching. # Users don't get asynchronous guild dispatching.
if not self.state.bot: if not self.state.bot:
@ -360,9 +363,7 @@ class GatewayWebsocket:
} }
await self.dispatch("READY", {**base_ready, **user_ready}) await self.dispatch("READY", {**base_ready, **user_ready})
app.sched.spawn(self._guild_dispatch(guilds))
# async dispatch of guilds
self.app.loop.create_task(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id): async def _check_shards(self, shard, user_id):
"""Check if the given `shard` value in IDENTIFY has good enough values. """Check if the given `shard` value in IDENTIFY has good enough values.
@ -412,6 +413,7 @@ class GatewayWebsocket:
Note: subscribing to channels is already handled Note: subscribing to channels is already handled
by GuildDispatcher.sub by GuildDispatcher.sub
""" """
assert self.state is not None
user_id = self.state.user_id user_id = self.state.user_id
guild_ids = await self._guild_ids() guild_ids = await self._guild_ids()
@ -434,26 +436,50 @@ class GatewayWebsocket:
# (presence and typing events) # (presence and typing events)
# we enable processing of guild_subscriptions by adding flags # we enable processing of guild_subscriptions by adding flags
# when subscribing to the given backend. those are optional. # when subscribing to the given backend.
channels_to_sub = [ session_id = self.state.session_id
( channel_ids: List[int] = []
"guild",
guild_ids,
{"presence": guild_subscriptions, "typing": guild_subscriptions},
),
("channel", dm_ids),
("channel", gdm_ids),
]
await self.app.dispatcher.mass_sub(user_id, channels_to_sub) for guild_id in guild_ids:
await app.dispatcher.guild.sub_with_flags(
guild_id,
session_id,
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
)
# instead of calculating which channels to subscribe to
# inside guild dispatcher, we calculate them in here, so that
# we remove complexity of the dispatcher.
guild_chan_ids = await app.storage.get_channel_ids(guild_id)
for channel_id in guild_chan_ids:
perms = await get_permissions(
self.state.user_id, channel_id, storage=self.storage
)
if perms.bits.read_messages:
channel_ids.append(channel_id)
log.info("subscribing to {} guild channels", len(channel_ids))
for channel_id in channel_ids:
await app.dispatcher.channel.sub_with_flags(
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
)
for dm_id in dm_ids:
await app.dispatcher.channel.sub(dm_id, session_id)
for gdm_id in gdm_ids:
await app.dispatcher.channel.sub(gdm_id, session_id)
if not self.state.bot:
# subscribe to all friends # subscribe to all friends
# (their friends will also subscribe back # (their friends will also subscribe back
# when they come online) # when they come online)
if not self.state.bot:
friend_ids = await self.user_storage.get_friend_ids(user_id) friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info("subscribing to {} friends", len(friend_ids)) log.info("subscribing to {} friends", len(friend_ids))
await self.app.dispatcher.sub_many("friend", user_id, friend_ids) for friend_id in friend_ids:
await app.dispatcher.friend.sub(user_id, friend_id)
async def update_status(self, incoming_status: dict): async def update_status(self, incoming_status: dict):
"""Update the status of the current websocket connection.""" """Update the status of the current websocket connection."""
@ -921,6 +947,7 @@ class GatewayWebsocket:
] ]
} }
""" """
assert self.state is not None
data = payload["d"] data = payload["d"]
gids = await self.user_storage.get_user_guilds(self.state.user_id) gids = await self.user_storage.get_user_guilds(self.state.user_id)
@ -933,11 +960,9 @@ class GatewayWebsocket:
log.debug("lazy request: members: {}", data.get("members", [])) log.debug("lazy request: members: {}", data.get("members", []))
# make shard query # make shard query
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get("channels", {}).items(): for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id) chan_id = int(chan_id)
member_list = await lazy_guilds.get_gml(chan_id) member_list = await app.lazy_guild.get_gml(chan_id)
perms = await get_permissions( perms = await get_permissions(
self.state.user_id, chan_id, storage=self.storage self.state.user_id, chan_id, storage=self.storage

View File

@ -93,7 +93,6 @@ class PresenceManager:
self.storage = app.storage self.storage = app.storage
self.user_storage = app.user_storage self.user_storage = app.user_storage
self.state_manager = app.state_manager self.state_manager = app.state_manager
self.dispatcher = app.dispatcher
async def guild_presences( async def guild_presences(
self, member_ids: List[int], guild_id: int self, member_ids: List[int], guild_id: int
@ -127,8 +126,7 @@ class PresenceManager:
member = await self.storage.get_member_data_one(guild_id, user_id) member = await self.storage.get_member_data_one(guild_id, user_id)
lazy_guild_store = self.dispatcher.backends["lazy_guild"] lists = app.lazy_guild.get_gml_guild(guild_id)
lists = lazy_guild_store.get_gml_guild(guild_id)
# shards that are in lazy guilds with 'everyone' # shards that are in lazy guilds with 'everyone'
# enabled # enabled
@ -163,20 +161,21 @@ class PresenceManager:
# given a session id, return if the session id actually connects to # given a session id, return if the session id actually connects to
# a given user, and if the state has not been dispatched via lazy guild. # a given user, and if the state has not been dispatched via lazy guild.
def _session_check(session_id): def _session_check(session_id):
try:
state = self.state_manager.fetch_raw(session_id) state = self.state_manager.fetch_raw(session_id)
uid = int(member["user"]["id"]) except KeyError:
if not state:
return False return False
uid = int(member["user"]["id"])
# we don't want to send a presence update # we don't want to send a presence update
# to the same user # to the same user
return state.user_id != uid and session_id not in in_lazy return state.user_id != uid and session_id not in in_lazy
# everyone not in lazy guild mode # everyone not in lazy guild mode
# gets a PRESENCE_UPDATE # gets a PRESENCE_UPDATE
await self.dispatcher.dispatch_filter( await app.dispatcher.guild.dispatch_filter(
"guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload guild_id, _session_check, ("PRESENCE_UPDATE", event_payload)
) )
return in_lazy return in_lazy
@ -193,11 +192,8 @@ class PresenceManager:
# dispatch to all friends that are subscribed to them # dispatch to all friends that are subscribed to them
user = await self.storage.get_user(user_id) user = await self.storage.get_user(user_id)
await self.dispatcher.dispatch( await app.dispatcher.friend.dispatch(
"friend", user_id, ("PRESENCE_UPDATE", {**presence.partial_dict, **{"user": user}}),
user_id,
"PRESENCE_UPDATE",
{**presence.partial_dict, **{"user": user}},
) )
def fetch_friend_presence(self, friend_id: int) -> BasePresence: def fetch_friend_presence(self, friend_id: int) -> BasePresence:

View File

@ -18,17 +18,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from .guild import GuildDispatcher from .guild import GuildDispatcher
from .member import MemberDispatcher from .member import dispatch_member
from .user import UserDispatcher from .user import dispatch_user
from .channel import ChannelDispatcher from .channel import ChannelDispatcher
from .friend import FriendDispatcher from .friend import FriendDispatcher
from .lazy_guild import LazyGuildDispatcher
__all__ = [ __all__ = [
"GuildDispatcher", "GuildDispatcher",
"MemberDispatcher", "dispatch_member",
"UserDispatcher", "dispatch_user",
"ChannelDispatcher", "ChannelDispatcher",
"FriendDispatcher", "FriendDispatcher",
"LazyGuildDispatcher",
] ]

View File

@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Any, List from typing import List
from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithFlags
from litecord.enums import ChannelType from litecord.enums import ChannelType
from litecord.utils import index_by_func from litecord.utils import index_by_func
from .dispatcher import DispatcherWithFlags, GatewayEvent
log = Logger(__name__) log = Logger(__name__)
@ -37,75 +39,66 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
""" """
# make a copy or the original channel object # make a copy or the original channel object
data = dict(orig) data = dict(orig)
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"]) idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
data["recipients"].pop(idx) data["recipients"].pop(idx)
return data return data
class ChannelDispatcher(DispatcherWithFlags): @dataclass
"""Main channel Pub/Sub logic.""" class ChannelFlags:
typing: bool
KEY_TYPE = int
VAL_TYPE = int
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]: class ChannelDispatcher(
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
):
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
"""Dispatch an event to a channel.""" """Dispatch an event to a channel."""
# get everyone who is subscribed session_ids = set(self.state[channel_id])
# and store the number of states we dispatched the event to
user_ids = self.state[channel_id]
dispatched = 0
sessions: List[str] = [] sessions: List[str] = []
# making a copy of user_ids since event_type, event_data = event
# we'll modify it later on. assert isinstance(event_data, dict)
for user_id in set(user_ids):
guild_id = await self.app.storage.guild_from_channel(channel_id)
# if we are dispatching to a guild channel, for session_id in session_ids:
# we should only dispatch to the states / shards try:
# that are connected to the guild (via their shard id). state = app.state_manager.fetch_raw(session_id)
except KeyError:
await self.unsub(channel_id, session_id)
continue
# if we aren't, we just get all states tied to the user. try:
# TODO: make a fetch_states that fetches shards flags = self.get_flags(channel_id, session_id)
# - with id 0 (count any) OR except KeyError:
# - single shards (id=0, count=1) log.warning("no flags for {!r}, ignoring", session_id)
states = ( flags = ChannelFlags(typing=True)
self.sm.fetch_states(user_id, guild_id)
if guild_id if event_type.lower().startswith("typing_") and not flags.typing:
else self.sm.user_states(user_id) continue
correct_event = event
# for cases where we are talking about group dms, we create an edited
# event data so that it doesn't show the user we're dispatching
# to in data.recipients (clients already assume they are recipients)
if (
event_type in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
and event_data.get("type") == ChannelType.GROUP_DM.value
):
new_data = gdm_recipient_view(event_data, state.user_id)
correct_event = (event_type, new_data)
try:
await state.ws.dispatch(*correct_event)
except Exception:
log.exception("error while dispatching to {}", state.session_id)
continue
sessions.append(session_id)
log.info(
"Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions)
) )
# unsub people who don't have any states tied to the channel.
if not states:
await self.unsub(channel_id, user_id)
continue
# skip typing events for users that don't want it
if event.startswith("TYPING_") and not self.flags_get(
channel_id, user_id, "typing", True
):
continue
cur_sess: List[str] = []
if (
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
and data.get("type") == ChannelType.GROUP_DM.value
):
# we edit the channel payload so it doesn't show
# the user as a recipient
new_data = gdm_recipient_view(data, user_id)
cur_sess = await self._dispatch_states(states, event, new_data)
else:
cur_sess = await self._dispatch_states(states, event, data)
sessions.extend(cur_sess)
dispatched += len(cur_sess)
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
return sessions return sessions

View File

@ -17,7 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List from typing import (
List,
Generic,
TypeVar,
Any,
Callable,
Dict,
Set,
Mapping,
Iterable,
Tuple,
)
from collections import defaultdict from collections import defaultdict
from logbook import Logger from logbook import Logger
@ -25,79 +36,63 @@ from logbook import Logger
log = Logger(__name__) log = Logger(__name__)
def _identity(_self, x): K = TypeVar("K")
return x V = TypeVar("V")
F = TypeVar("F")
EventType = TypeVar("EventType")
DispatchType = TypeVar("DispatchType")
F_Map = Mapping[V, F]
GatewayEvent = Tuple[str, Any]
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
class Dispatcher: class Dispatcher(Generic[K, V, EventType, DispatchType]):
"""Pub/Sub backend dispatcher. """Pub/Sub backend dispatcher.
This just declares functions all Dispatcher subclasses Classes must implement this protocol.
can implement. This does not mean all Dispatcher
subclasses have them implemented.
""" """
KEY_TYPE = _identity async def sub(self, key: K, identifier: V) -> None:
VAL_TYPE = _identity """Subscribe a given identifier to a given key."""
...
def __init__(self, main): async def sub_many(self, key: K, identifier_list: Iterable[V]) -> None:
#: main EventDispatcher for identifier in identifier_list:
self.main_dispatcher = main await self.sub(key, identifier)
#: gateway state storage async def unsub(self, key: K, identifier: V) -> None:
self.sm = main.state_manager """Unsubscribe a given identifier to a given key."""
...
self.app = main.app async def dispatch(self, key: K, event: EventType) -> DispatchType:
...
async def sub(self, _key, _id): async def dispatch_many(self, keys: List[K], *args: Any, **kwargs: Any) -> None:
"""Subscribe an element to the channel/key.""" log.info("MULTI DISPATCH in {!r}, {} keys", self, len(keys))
raise NotImplementedError for key in keys:
await self.dispatch(key, *args, **kwargs)
async def unsub(self, _key, _id): async def drop(self, key: K) -> None:
"""Unsubscribe an elemtnt from the channel/key.""" """Drop a key."""
raise NotImplementedError ...
async def dispatch_filter(self, _key, _func, *_args): async def clear(self, key: K) -> None:
"""Selectively dispatch to the list of subscribed users. """Clear a key from the backend."""
...
The selection logic is completly arbitraty and up to the async def dispatch_filter(
Pub/Sub backend. self, key: K, filter_function: Callable[[K], bool], event: EventType
) -> List[str]:
"""Selectively dispatch to the list of subscribers.
Function must return a list of separate identifiers for composability.
""" """
raise NotImplementedError ...
async def dispatch(self, _key, *_args):
"""Dispatch an event to the given channel/key."""
raise NotImplementedError
async def reset(self, _key):
"""Reset a key from the backend."""
raise NotImplementedError
async def remove(self, _key):
"""Remove a key from the backend.
The meaning from reset() and remove()
is different, reset() is to clear all
subscribers from the given key,
remove() is to remove the key as well.
"""
raise NotImplementedError
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
"""Dispatch an event to a list of states."""
res = []
for state in states:
try:
await state.ws.dispatch(event, data)
res.append(state.session_id)
except Exception:
log.exception("error while dispatching")
return res
class DispatcherWithState(Dispatcher): class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
"""Pub/Sub backend with a state dictionary. """Pub/Sub backend with a state dictionary.
This class was made to decrease the amount This class was made to decrease the amount
@ -105,58 +100,58 @@ class DispatcherWithState(Dispatcher):
that have that dictionary. that have that dictionary.
""" """
def __init__(self, main): def __init__(self):
super().__init__(main) super().__init__()
#: the default dict is to a set #: the default dict is to a set
# so we make sure someone calling sub() # so we make sure someone calling sub()
# twice won't get 2x the events for the # twice won't get 2x the events for the
# same channel. # same channel.
self.state = defaultdict(set) self.state: Dict[K, Set[V]] = defaultdict(set)
async def sub(self, key, identifier): async def sub(self, key: K, identifier: V):
self.state[key].add(identifier) self.state[key].add(identifier)
async def unsub(self, key, identifier): async def unsub(self, key: K, identifier: V):
self.state[key].discard(identifier) self.state[key].discard(identifier)
async def reset(self, key): async def reset(self, key: K):
self.state[key] = set() self.state[key] = set()
async def remove(self, key): async def drop(self, key: K):
try: try:
self.state.pop(key) self.state.pop(key)
except KeyError: except KeyError:
pass pass
async def dispatch(self, key, *args):
raise NotImplementedError
class DispatcherWithFlags(
class DispatcherWithFlags(DispatcherWithState): DispatcherWithState, Generic[K, V, EventType, DispatchType, F],
):
"""Pub/Sub backend with both a state and a flags store.""" """Pub/Sub backend with both a state and a flags store."""
def __init__(self, main): def __init__(self):
super().__init__(main) super().__init__()
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
#: keep flags for subscribers, so for example def set_flags(self, key: K, identifier: V, flags: F):
# a subscriber could drop all presence events at the """Set flags for the given identifier."""
# pubsub level. see gateway's guild_subscriptions field for more self.flags[key][identifier] = flags
self.flags = defaultdict(dict)
async def sub(self, key, identifier, flags=None): def remove_flags(self, key: K, identifier: V):
"""Subscribe a user to the guild.""" """Set flags for the given identifier."""
await super().sub(key, identifier)
self.flags[key][identifier] = flags or {}
async def unsub(self, key, identifier):
"""Unsubscribe a user from the guild."""
await super().unsub(key, identifier)
self.flags[key].pop(identifier) self.flags[key].pop(identifier)
def flags_get(self, key, identifier, field: str, default): def get_flags(self, key: K, identifier: V):
"""Get a single field from the flags store.""" """Get a single field from the flags store."""
# yes, i know its simply an indirection from the main flags store, return self.flags[key][identifier]
# but i'd rather have this than change every call if i ever change
# the structure of the flags store. async def sub_with_flags(self, key: K, identifier: V, flags: F):
return self.flags[key][identifier].get(field, default) """Subscribe a user to the guild."""
await super().sub(key, identifier)
self.set_flags(key, identifier, flags)
async def unsub(self, key: K, identifier: V):
"""Unsubscribe a user from the guild."""
await super().unsub(key, identifier)
self.remove_flags(key, identifier)

View File

@ -17,15 +17,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List from typing import List, Set
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithState from .dispatcher import DispatcherWithState, GatewayEvent
from .user import dispatch_user_filter
log = Logger(__name__) log = Logger(__name__)
class FriendDispatcher(DispatcherWithState): class FriendDispatcher(DispatcherWithState[int, int, GatewayEvent, List[str]]):
"""Friend Pub/Sub logic. """Friend Pub/Sub logic.
When connecting, a client will subscribe to all their friends When connecting, a client will subscribe to all their friends
@ -33,26 +34,18 @@ class FriendDispatcher(DispatcherWithState):
broadcasted through that channel to basically all their friends. broadcasted through that channel to basically all their friends.
""" """
KEY_TYPE = int async def dispatch_filter(self, user_id: int, filter_function, event: GatewayEvent):
VAL_TYPE = int
async def dispatch_filter(self, user_id: int, func, event, data):
"""Dispatch an event to all of a users' friends.""" """Dispatch an event to all of a users' friends."""
peer_ids = self.state[user_id] peer_ids: Set[int] = self.state[user_id]
sessions: List[str] = [] sessions: List[str] = []
for peer_id in peer_ids: for peer_id in peer_ids:
# dispatch to the user instead of the "shards tied to a guild" # dispatch to the user instead of the "shards tied to a guild"
# since relationships broadcast to all shards. # since relationships broadcast to all shards.
sessions.extend( sessions.extend(await dispatch_user_filter(peer_id, filter_function, event))
await self.main_dispatcher.dispatch_filter(
"user", peer_id, func, event, data
)
)
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions)) log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
return sessions return sessions
async def dispatch(self, user_id, event, data): async def dispatch(self, user_id: int, event: GatewayEvent):
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data) return await self.dispatch_filter(user_id, lambda sess_id: True, event)

View File

@ -17,123 +17,73 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Any from typing import List
from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithFlags from .dispatcher import DispatcherWithFlags, GatewayEvent
from litecord.permissions import get_permissions from .channel import ChannelFlags
from litecord.gateway.state import GatewayState
log = Logger(__name__) log = Logger(__name__)
class GuildDispatcher(DispatcherWithFlags): @dataclass
"""Guild backend for Pub/Sub""" class GuildFlags(ChannelFlags):
presence: bool
KEY_TYPE = int
VAL_TYPE = int
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None): class GuildDispatcher(
"""Send an action to all channels of the guild.""" DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
flags = flags or {} ):
chan_ids = await self.app.storage.get_channel_ids(guild_id) """Guild backend for Pub/Sub."""
for chan_id in chan_ids: async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
states = app.state_manager.fetch_states(user_id, guild_id)
for state in states:
await self.sub(guild_id, state.session_id)
# only do an action for users that can return states
# actually read the channel to start with.
chan_perms = await get_permissions(
user_id, chan_id, storage=self.main_dispatcher.app.storage
)
if not chan_perms.bits.read_messages: async def dispatch_filter(
log.debug("skipping cid={}, no read messages", chan_id) self, guild_id: int, filter_function, event: GatewayEvent
continue
log.debug("sending raw action {!r} to chan={}", action, chan_id)
# for now, only sub() has support for flags.
# it is an idea to have flags support for other actions
args = []
if action == "sub":
chanflags = dict(flags)
# channels don't need presence flags
try:
chanflags.pop("presence")
except KeyError:
pass
args.append(chanflags)
await self.main_dispatcher.action(
"channel", action, chan_id, user_id, *args
)
async def _chan_call(self, meth: str, guild_id: int, *args):
"""Call a method on the ChannelDispatcher, for all channels
in the guild."""
chan_ids = await self.app.storage.get_channel_ids(guild_id)
chan_dispatcher = self.main_dispatcher.backends["channel"]
method = getattr(chan_dispatcher, meth)
for chan_id in chan_ids:
log.debug("calling {} to chan={}", meth, chan_id)
await method(chan_id, *args)
async def sub(self, guild_id: int, user_id: int, flags=None):
"""Subscribe a user to the guild."""
await super().sub(guild_id, user_id, flags)
await self._chan_action("sub", guild_id, user_id, flags)
async def unsub(self, guild_id: int, user_id: int):
"""Unsubscribe a user from the guild."""
await super().unsub(guild_id, user_id)
await self._chan_action("unsub", guild_id, user_id)
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
"""Selectively dispatch to session ids that have
func(session_id) true."""
user_ids = self.state[guild_id]
dispatched = 0
sessions = []
# acquire a copy since we may be modifying
# the original user_ids
for user_id in set(user_ids):
# fetch all states / shards that are tied to the guild.
states = self.sm.fetch_states(user_id, guild_id)
if not states:
# user is actually disconnected,
# so we should just unsub them
await self.unsub(guild_id, user_id)
continue
# skip the given subscriber if event starts with PRESENCE_
# and the flags say they don't want it.
# note that this does not equate to any unsubscription
# of the channel.
if event.startswith("PRESENCE_") and not self.flags_get(
guild_id, user_id, "presence", True
): ):
session_ids = self.state[guild_id]
sessions: List[str] = []
event_type, _ = event
for session_id in set(session_ids):
if not filter_function(session_id):
continue continue
# filter the ones that matter try:
states = list(filter(lambda state: func(state.session_id), states)) state = app.state_manager.fetch_raw(session_id)
except KeyError:
await self.unsub(guild_id, session_id)
continue
cur_sess = await self._dispatch_states(states, event, data) try:
flags = self.get_flags(guild_id, session_id)
except KeyError:
log.warning("no flags for {!r}, ignoring", session_id)
flags = GuildFlags(presence=True, typing=True)
sessions.extend(cur_sess) if event_type.lower().startswith("presence_") and not flags.presence:
dispatched += len(cur_sess) continue
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched) try:
await state.ws.dispatch(*event)
except Exception:
log.exception("error while dispatching to {}", state.session_id)
continue
sessions.append(session_id)
log.info("Dispatched {} {!r} to {} states", guild_id, event[0], len(sessions))
return sessions return sessions
async def dispatch(self, guild_id: int, event: str, data: Any): async def dispatch(self, guild_id: int, event):
"""Dispatch an event to all subscribers of the guild.""" """Dispatch an event to all subscribers of the guild."""
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data) return await self.dispatch_filter(guild_id, lambda sess_id: True, event)

View File

@ -30,10 +30,10 @@ import asyncio
from collections import defaultdict from collections import defaultdict
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from quart import current_app as app
from logbook import Logger from logbook import Logger
from litecord.pubsub.dispatcher import Dispatcher
from litecord.permissions import ( from litecord.permissions import (
Permissions, Permissions,
overwrite_find_mix, overwrite_find_mix,
@ -239,9 +239,6 @@ class GuildMemberList:
Attributes Attributes
---------- ----------
main_lg: LazyGuildDispatcher
Main instance of :class:`LazyGuildDispatcher`,
so that we're able to use things such as :class:`Storage`.
guild_id: int guild_id: int
The Guild ID this instance is referring to. The Guild ID this instance is referring to.
channel_id: int channel_id: int
@ -257,11 +254,10 @@ class GuildMemberList:
for example, can still rely on PRESENCE_UPDATEs. for example, can still rely on PRESENCE_UPDATEs.
""" """
def __init__(self, guild_id: int, channel_id: int, main_lg): def __init__(self, guild_id: int, channel_id: int):
self.guild_id = guild_id self.guild_id = guild_id
self.channel_id = channel_id self.channel_id = channel_id
self.main = main_lg
self.list = MemberList() self.list = MemberList()
#: store the states that are subscribed to the list. #: store the states that are subscribed to the list.
@ -273,22 +269,22 @@ class GuildMemberList:
@property @property
def loop(self): def loop(self):
"""Get the main asyncio loop instance.""" """Get the main asyncio loop instance."""
return self.main.app.loop return app.loop
@property @property
def storage(self): def storage(self):
"""Get the global :class:`Storage` instance.""" """Get the global :class:`Storage` instance."""
return self.main.app.storage return app.storage
@property @property
def presence(self): def presence(self):
"""Get the global :class:`PresenceManager` instance.""" """Get the global :class:`PresenceManager` instance."""
return self.main.app.presence return app.presence
@property @property
def state_man(self): def state_man(self):
"""Get the global :class:`StateManager` instance.""" """Get the global :class:`StateManager` instance."""
return self.main.app.state_manager return app.state_manager
@property @property
def list_id(self): def list_id(self):
@ -572,8 +568,7 @@ class GuildMemberList:
Wrapper for :meth:`StateManager.fetch_raw` Wrapper for :meth:`StateManager.fetch_raw`
""" """
try: try:
state = self.state_man.fetch_raw(session_id) return self.state_man.fetch_raw(session_id)
return state
except KeyError: except KeyError:
return None return None
@ -643,7 +638,7 @@ class GuildMemberList:
# do resync-ing in the background # do resync-ing in the background
result.append(session_id) result.append(session_id)
self.loop.create_task(self.shard_query(session_id, [role_range])) app.sched.spawn(self.shard_query(session_id, [role_range]))
return result return result
@ -683,8 +678,7 @@ class GuildMemberList:
) )
if everyone_perms.bits.read_messages and list_id != "everyone": if everyone_perms.bits.read_messages and list_id != "everyone":
everyone_gml = await self.main.get_gml(self.guild_id) everyone_gml = await app.lazy_guild.get_gml(self.guild_id)
return await everyone_gml.shard_query(session_id, ranges) return await everyone_gml.shard_query(session_id, ranges)
await self._init_check() await self._init_check()
@ -1372,47 +1366,36 @@ class GuildMemberList:
self.guild_id = 0 self.guild_id = 0
self.channel_id = 0 self.channel_id = 0
self.main = None
self._set_empty_list() self._set_empty_list()
self.state = {} self.state = {}
class LazyGuildDispatcher(Dispatcher): class LazyGuildManager:
"""Main class holding the member lists for lazy guilds.""" """Main class holding the member lists for lazy guilds."""
# channel ids def __init__(self):
KEY_TYPE = int
# the session ids subscribing to channels
VAL_TYPE = str
def __init__(self, main):
super().__init__(main)
self.storage = main.app.storage
# {chan_id: gml, ...} # {chan_id: gml, ...}
self.state = {} self.state: Dict[int, GuildMemberList] = {}
#: store which guilds have their #: store which guilds have their
# respective GMLs # respective GMLs
# {guild_id: [chan_id, ...], ...} # {guild_id: [chan_id, ...], ...}
self.guild_map: Dict[int, List[int]] = defaultdict(list) self.guild_map: Dict[int, List[int]] = defaultdict(list)
async def get_gml(self, channel_id: int): async def get_gml(self, channel_id: int) -> GuildMemberList:
"""Get a guild list for a channel ID, """Get a guild list for a channel ID,
generating it if it doesn't exist.""" generating it if it doesn't exist."""
try: try:
return self.state[channel_id] return self.state[channel_id]
except KeyError: except KeyError:
guild_id = await self.storage.guild_from_channel(channel_id) guild_id = await app.storage.guild_from_channel(channel_id)
# if we don't find a guild, we just # if we don't find a guild, we just
# set it the same as the channel. # set it the same as the channel.
if not guild_id: if not guild_id:
guild_id = channel_id guild_id = channel_id
gml = GuildMemberList(guild_id, channel_id, self) gml = GuildMemberList(guild_id, channel_id)
self.state[channel_id] = gml self.state[channel_id] = gml
self.guild_map[guild_id].append(channel_id) self.guild_map[guild_id].append(channel_id)
return gml return gml
@ -1437,16 +1420,6 @@ class LazyGuildDispatcher(Dispatcher):
gml = await self.get_gml(chan_id) gml = await self.get_gml(chan_id)
gml.unsub(session_id) gml.unsub(session_id)
async def dispatch(self, guild_id, event: str, *args, **kwargs):
"""Call a function specialized in handling the given event"""
try:
handler = getattr(self, f"_handle_{event.lower()}")
except AttributeError:
log.warning("unknown event: {}", event)
return
await handler(guild_id, *args, **kwargs)
def remove_channel(self, channel_id: int): def remove_channel(self, channel_id: int):
"""Remove a channel from the manager.""" """Remove a channel from the manager."""
try: try:
@ -1474,29 +1447,29 @@ class LazyGuildDispatcher(Dispatcher):
method = getattr(lazy_list, method_str) method = getattr(lazy_list, method_str)
await method(*args) await method(*args)
async def _handle_new_role(self, guild_id: int, new_role: dict): async def new_role(self, guild_id: int, new_role: dict):
"""Handle the addition of a new group by dispatching it to """Handle the addition of a new group by dispatching it to
the member lists.""" the member lists."""
await self._call_all_lists(guild_id, "new_role", new_role) await self._call_all_lists(guild_id, "new_role", new_role)
async def _handle_role_pos_upd(self, guild_id, role: dict): async def role_position_update(self, guild_id, role: dict):
await self._call_all_lists(guild_id, "role_pos_update", role) await self._call_all_lists(guild_id, "role_pos_update", role)
async def _handle_role_update(self, guild_id, role: dict): async def role_update(self, guild_id, role: dict):
# handle name and hoist changes # handle name and hoist changes
await self._call_all_lists(guild_id, "role_update", role) await self._call_all_lists(guild_id, "role_update", role)
async def _handle_role_delete(self, guild_id, role_id: int): async def role_delete(self, guild_id, role_id: int):
await self._call_all_lists(guild_id, "role_delete", role_id) await self._call_all_lists(guild_id, "role_delete", role_id)
async def _handle_pres_update(self, guild_id, user_id: int, partial: dict): async def pres_update(self, guild_id, user_id: int, partial: dict):
await self._call_all_lists(guild_id, "pres_update", user_id, partial) await self._call_all_lists(guild_id, "pres_update", user_id, partial)
async def _handle_new_member(self, guild_id, user_id: int): async def new_member(self, guild_id, user_id: int):
await self._call_all_lists(guild_id, "new_member", user_id) await self._call_all_lists(guild_id, "new_member", user_id)
async def _handle_remove_member(self, guild_id, user_id: int): async def remove_member(self, guild_id, user_id: int):
await self._call_all_lists(guild_id, "remove_member", user_id) await self._call_all_lists(guild_id, "remove_member", user_id)
async def _handle_update_user(self, guild_id, user_id: int): async def update_user(self, guild_id, user_id: int):
await self._call_all_lists(guild_id, "update_user", user_id) await self._call_all_lists(guild_id, "update_user", user_id)

View File

@ -17,30 +17,20 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from .dispatcher import Dispatcher from typing import List
from quart import current_app as app
from .dispatcher import GatewayEvent
from .utils import send_event_to_states
class MemberDispatcher(Dispatcher): async def dispatch_member(
"""Member backend for Pub/Sub.""" guild_id: int, user_id: int, event: GatewayEvent
) -> List[str]:
states = app.state_manager.fetch_states(user_id, guild_id)
KEY_TYPE = tuple # if no states were found, we should unsub the user from the guild
async def dispatch(self, key, event, data):
"""Dispatch a single event to a member.
This is shard-aware.
"""
# we don't keep any state on this dispatcher, so the key
# is just (guild_id, user_id)
guild_id, user_id = key
# fetch shards
states = self.sm.fetch_states(user_id, guild_id)
# if no states were found, we should
# unsub the user from the GUILD channel
if not states: if not states:
await self.main_dispatcher.unsub("guild", guild_id, user_id) await app.dispatcher.guild.unsub(guild_id, user_id)
return return []
return await self._dispatch_states(states, event, data) return await send_event_to_states(states, event)

View File

@ -17,23 +17,27 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from .dispatcher import Dispatcher from typing import Callable, List
from quart import current_app as app
from .dispatcher import GatewayEvent
from .utils import send_event_to_states
class UserDispatcher(Dispatcher): async def dispatch_user_filter(
"""User backend for Pub/Sub.""" user_id: int, filter_func: Callable[[str], bool], event_data: GatewayEvent
) -> List[str]:
KEY_TYPE = int """Dispatch to a given user's states, but only for states
where filter_func returns true."""
async def dispatch_filter(self, user_id: int, func, event, data):
"""Dispatch an event to all shards of a user."""
# filter only states where func() gives true
states = list( states = list(
filter(lambda state: func(state.session_id), self.sm.user_states(user_id)) filter(
lambda state: filter_func(state.session_id),
app.state_manager.user_states(user_id),
)
) )
return await self._dispatch_states(states, event, data) return await send_event_to_states(states, event_data)
async def dispatch(self, user_id: int, event, data):
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data) async def dispatch_user(user_id: int, event_data: GatewayEvent) -> List[str]:
return await dispatch_user_filter(user_id, lambda sess_id: True, event_data)

41
litecord/pubsub/utils.py Normal file
View File

@ -0,0 +1,41 @@
"""
Litecord
Copyright (C) 2018-2019 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import logging
from typing import List, Tuple, Any
from ..gateway.state import GatewayState
log = logging.getLogger(__name__)
async def send_event_to_states(
states: List[GatewayState], event_data: Tuple[str, Any]
) -> List[str]:
"""Dispatch an event to a list of states."""
res = []
for state in states:
try:
event, data = event_data
await state.ws.dispatch(event, data)
res.append(state.session_id)
except Exception:
log.exception("error while dispatching")
return res

View File

@ -482,6 +482,9 @@ INVITE = {
"required": False, "required": False,
"nullable": True, "nullable": True,
}, # discord client sends invite code there }, # discord client sends invite code there
# sent by official client, unknown purpose
"target_user_id": {"type": "snowflake", "required": False, "nullable": True},
"target_user_type": {"type": "number", "required": False, "nullable": True},
} }
USER_SETTINGS = { USER_SETTINGS = {

View File

@ -709,7 +709,9 @@ class Storage:
return res return res
async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict: async def get_guild_extra(
self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None
) -> Dict:
"""Get extra information about a guild.""" """Get extra information about a guild."""
res = {} res = {}

View File

@ -181,9 +181,6 @@ async def send_sys_message(
raise ValueError("Invalid system message type") raise ValueError("Invalid system message type")
message_id = await handler(channel_id, *args, **kwargs) message_id = await handler(channel_id, *args, **kwargs)
message = await app.storage.get_message(message_id) message = await app.storage.get_message(message_id)
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", message))
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
return message_id return message_id

View File

@ -252,7 +252,7 @@ async def maybe_lazy_guild_dispatch(
if isinstance(role, dict) and not role["hoist"] and not force: if isinstance(role, dict) and not role["hoist"] and not force:
return return
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role) await (getattr(app.lazy_guild, event))(guild_id, role)
def extract_limit(request_, default: int = 50, max_val: int = 100): def extract_limit(request_, default: int = 50, max_val: int = 100):

View File

@ -17,11 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Optional from typing import Optional, Dict, List
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from logbook import Logger from logbook import Logger
from quart import current_app as app
from litecord.voice.lvsp_conn import LVSPConnection from litecord.voice.lvsp_conn import LVSPConnection
@ -42,15 +43,15 @@ class LVSPManager:
Spawns :class:`LVSPConnection` as needed, etc. Spawns :class:`LVSPConnection` as needed, etc.
""" """
def __init__(self, app, voice): def __init__(self, app_, voice):
self.app = app self.app = app_
self.voice = voice self.voice = voice
# map servers to LVSPConnection # map servers to LVSPConnection
self.conns = {} self.conns: Dict[str, LVSPConnection] = {}
# maps regions to server hostnames # maps regions to server hostnames
self.servers = defaultdict(list) self.servers: Dict[str, List[str]] = defaultdict(list)
# maps Union[GuildID, DMId, GroupDMId] to server hostnames # maps Union[GuildID, DMId, GroupDMId] to server hostnames
self.assign = {} self.assign = {}
@ -84,7 +85,7 @@ class LVSPManager:
continue continue
self.regions[region.id] = region self.regions[region.id] = region
self.app.loop.create_task(self._spawn_region(region)) app.sched.spawn(self._spawn_region(region))
async def _spawn_region(self, region: Region): async def _spawn_region(self, region: Region):
"""Spawn a region. Involves fetching all the hostnames """Spawn a region. Involves fetching all the hostnames

View File

@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from typing import Tuple, Dict, List from typing import Tuple, Dict, List
from collections import defaultdict from collections import defaultdict
from dataclasses import fields from dataclasses import fields
from quart import current_app as app
from logbook import Logger from logbook import Logger
@ -286,6 +287,6 @@ class VoiceManager:
# slow, but it be like that, also copied from other users... # slow, but it be like that, also copied from other users...
for guild_id in guild_ids: for guild_id in guild_ids:
guild = await self.app.storage.get_guild_full(guild_id, None) guild = await self.app.storage.get_guild_full(guild_id, None)
await self.app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild) await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
# TODO propagate the channel deprecation to LVSP connections # TODO propagate the channel deprecation to LVSP connections

View File

@ -104,7 +104,7 @@ def main(config):
# as the managers require it # as the managers require it
# and the migrate command also sets the db up # and the migrate command also sets the db up
if argv[1] != "migrate": if argv[1] != "migrate":
init_app_managers(app, voice=False) init_app_managers(app, init_voice=False)
args = parser.parse_args() args = parser.parse_args()
loop.run_until_complete(_ctx_wrapper(app, args)) loop.run_until_complete(_ctx_wrapper(app, args))

8
run.py
View File

@ -95,6 +95,7 @@ from litecord.images import IconManager
from litecord.jobs import JobManager from litecord.jobs import JobManager
from litecord.voice.manager import VoiceManager from litecord.voice.manager import VoiceManager
from litecord.guild_memory_store import GuildMemoryStore from litecord.guild_memory_store import GuildMemoryStore
from litecord.pubsub.lazy_guild import LazyGuildManager
from litecord.gateway.gateway import websocket_handler from litecord.gateway.gateway import websocket_handler
@ -254,7 +255,7 @@ async def init_app_db(app_):
app_.sched = JobManager() app_.sched = JobManager()
def init_app_managers(app_, *, voice=True): def init_app_managers(app_: Quart, *, init_voice=True):
"""Initialize singleton classes.""" """Initialize singleton classes."""
app_.loop = asyncio.get_event_loop() app_.loop = asyncio.get_event_loop()
app_.ratelimiter = RatelimitManager(app_.config.get("_testing")) app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
@ -265,7 +266,7 @@ def init_app_managers(app_, *, voice=True):
app_.icons = IconManager(app_) app_.icons = IconManager(app_)
app_.dispatcher = EventDispatcher(app_) app_.dispatcher = EventDispatcher()
app_.presence = PresenceManager(app_) app_.presence = PresenceManager(app_)
app_.storage.presence = app_.presence app_.storage.presence = app_.presence
@ -274,10 +275,11 @@ def init_app_managers(app_, *, voice=True):
# we do this because of a bug on ./manage.py where it # we do this because of a bug on ./manage.py where it
# cancels the LVSPManager's spawn regions task. we don't # cancels the LVSPManager's spawn regions task. we don't
# need to start it on manage time. # need to start it on manage time.
if voice: if init_voice:
app_.voice = VoiceManager(app_) app_.voice = VoiceManager(app_)
app_.guild_store = GuildMemoryStore() app_.guild_store = GuildMemoryStore()
app_.lazy_guild = LazyGuildManager()
async def api_index(app_): async def api_index(app_):