mirror of https://gitlab.com/litecord/litecord.git
remove code from dispatcher
leftovers are TBD. - constrict Dispatcher.dispatch() to arity 3 - add helper methods to Dispatcher - add EventType to Dispatcher While fixing things, it was discovered that many of the things inside LazyGuildDispatcher were just interfaces to GuildMemberList, in a very weird way, just so it could be fitted inside the main Dispatcher. it was decided to remove those unecessary interfaces, clients shall use the manager directly.
This commit is contained in:
parent
39e8a1ad7e
commit
b0eb3247fd
2
Pipfile
2
Pipfile
|
|
@ -18,7 +18,7 @@ zstandard = "*"
|
|||
winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"}
|
||||
|
||||
[dev-packages]
|
||||
pytest = "==5.1.2"
|
||||
pytest = "==5.3.2"
|
||||
pytest-asyncio = "==0.10.0"
|
||||
mypy = "*"
|
||||
black = "*"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3"
|
||||
"sha256": "dedc41184df539a608717e68c108ccfb4d529acb1d1702d83de223840c7cc754"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
|
|
@ -123,41 +123,36 @@
|
|||
},
|
||||
"cffi": {
|
||||
"hashes": [
|
||||
"sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42",
|
||||
"sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04",
|
||||
"sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5",
|
||||
"sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54",
|
||||
"sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba",
|
||||
"sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57",
|
||||
"sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396",
|
||||
"sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12",
|
||||
"sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97",
|
||||
"sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43",
|
||||
"sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db",
|
||||
"sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3",
|
||||
"sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b",
|
||||
"sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579",
|
||||
"sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346",
|
||||
"sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159",
|
||||
"sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652",
|
||||
"sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e",
|
||||
"sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a",
|
||||
"sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506",
|
||||
"sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f",
|
||||
"sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d",
|
||||
"sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c",
|
||||
"sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20",
|
||||
"sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858",
|
||||
"sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc",
|
||||
"sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a",
|
||||
"sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3",
|
||||
"sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e",
|
||||
"sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410",
|
||||
"sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25",
|
||||
"sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b",
|
||||
"sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d"
|
||||
"sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff",
|
||||
"sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b",
|
||||
"sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac",
|
||||
"sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0",
|
||||
"sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384",
|
||||
"sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26",
|
||||
"sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6",
|
||||
"sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b",
|
||||
"sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e",
|
||||
"sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd",
|
||||
"sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2",
|
||||
"sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66",
|
||||
"sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc",
|
||||
"sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8",
|
||||
"sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55",
|
||||
"sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4",
|
||||
"sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5",
|
||||
"sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d",
|
||||
"sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78",
|
||||
"sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa",
|
||||
"sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793",
|
||||
"sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f",
|
||||
"sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a",
|
||||
"sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f",
|
||||
"sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30",
|
||||
"sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f",
|
||||
"sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3",
|
||||
"sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c"
|
||||
],
|
||||
"version": "==1.13.2"
|
||||
"version": "==1.14.0"
|
||||
},
|
||||
"chardet": {
|
||||
"hashes": [
|
||||
|
|
@ -195,10 +190,10 @@
|
|||
},
|
||||
"h2": {
|
||||
"hashes": [
|
||||
"sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e",
|
||||
"sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4"
|
||||
"sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5",
|
||||
"sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14"
|
||||
],
|
||||
"version": "==3.1.1"
|
||||
"version": "==3.2.0"
|
||||
},
|
||||
"hpack": {
|
||||
"hashes": [
|
||||
|
|
@ -510,13 +505,6 @@
|
|||
],
|
||||
"version": "==1.4.3"
|
||||
},
|
||||
"atomicwrites": {
|
||||
"hashes": [
|
||||
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
||||
"sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"
|
||||
],
|
||||
"version": "==1.3.0"
|
||||
},
|
||||
"attrs": {
|
||||
"hashes": [
|
||||
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
|
||||
|
|
@ -647,11 +635,11 @@
|
|||
},
|
||||
"pytest": {
|
||||
"hashes": [
|
||||
"sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210",
|
||||
"sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865"
|
||||
"sha256:6b571215b5a790f9b41f19f3531c53a45cf6bb8ef2988bc1ff9afb38270b25fa",
|
||||
"sha256:e41d489ff43948babd0fad7ad5e49b8735d5d55e26628a58673c39ff61d95de4"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==5.1.2"
|
||||
"version": "==5.3.2"
|
||||
},
|
||||
"pytest-asyncio": {
|
||||
"hashes": [
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ async def _update_features(guild_id: int, features: list):
|
|||
)
|
||||
|
||||
guild = await app.storage.get_guild_full(guild_id)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
|
||||
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
|
||||
|
|
|
|||
|
|
@ -63,10 +63,10 @@ async def update_guild(guild_id: int):
|
|||
|
||||
if old_unavailable and not new_unavailable:
|
||||
# guild became available
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild))
|
||||
else:
|
||||
# guild became unavailable
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_DELETE", guild))
|
||||
|
||||
return jsonify(guild)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,13 @@ import itsdangerous
|
|||
import bcrypt
|
||||
from quart import Blueprint, jsonify, request, current_app as app
|
||||
from logbook import Logger
|
||||
from winter import get_snowflake
|
||||
|
||||
from litecord.auth import token_check
|
||||
from litecord.common.users import create_user
|
||||
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
|
||||
from litecord.errors import BadRequest
|
||||
from winter import get_snowflake
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
from .invites import use_invite
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -172,7 +173,7 @@ async def verify_user():
|
|||
)
|
||||
|
||||
new_user = await app.storage.get_user(user_id, True)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
|
||||
await dispatch_user(user_id, ("USER_UPDATE", new_user))
|
||||
|
||||
return "", 204
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from litecord.common.messages import (
|
|||
msg_add_attachment,
|
||||
msg_guild_text_mentions,
|
||||
)
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -136,10 +137,10 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
|||
|
||||
# dispatch CHANNEL_CREATE so the client knows which
|
||||
# channel the future event is about
|
||||
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||
await dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||
|
||||
# subscribe the peer to the channel
|
||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||
await app.dispatcher.channel.sub(channel_id, peer_id)
|
||||
|
||||
# insert it on dm_channel_state so the client
|
||||
# is subscribed on the future
|
||||
|
|
@ -242,7 +243,7 @@ async def _create_message(channel_id):
|
|||
await _dm_pre_dispatch(channel_id, user_id)
|
||||
await _dm_pre_dispatch(channel_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
|
||||
|
||||
# spawn url processor for embedding of images
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
|
|
@ -350,7 +351,7 @@ async def edit_message(channel_id, message_id):
|
|||
# only dispatch MESSAGE_UPDATE if any update
|
||||
# actually happened
|
||||
if updated:
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message))
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
|
@ -427,9 +428,9 @@ async def delete_message(channel_id, message_id):
|
|||
message_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
(
|
||||
"MESSAGE_DELETE",
|
||||
{
|
||||
"id": str(message_id),
|
||||
|
|
@ -437,6 +438,7 @@ async def delete_message(channel_id, message_id):
|
|||
# for lazy guilds
|
||||
"guild_id": str(guild_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -106,11 +106,15 @@ async def add_pin(channel_id, message_id):
|
|||
|
||||
timestamp = snowflake_datetime(row["message_id"])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
(
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"last_pin_timestamp": timestamp_(timestamp),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
|
|
@ -149,11 +153,15 @@ async def delete_pin(channel_id, message_id):
|
|||
|
||||
timestamp = snowflake_datetime(row["message_id"])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
(
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"last_pin_timestamp": timestamp.isoformat(),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -141,10 +141,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
|||
if ctype in GUILD_CHANS:
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
|
||||
)
|
||||
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_ADD", payload))
|
||||
return "", 204
|
||||
|
||||
|
||||
|
|
@ -206,8 +203,8 @@ async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji
|
|||
if ctype in GUILD_CHANS:
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id, ("MESSAGE_REACTION_REMOVE", payload)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -290,6 +287,6 @@ async def remove_all_reactions(channel_id, message_id):
|
|||
if ctype in GUILD_CHANS:
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
import time
|
||||
import datetime
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
from winter import snowflake_datetime
|
||||
|
||||
from litecord.auth import token_check
|
||||
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
||||
|
|
@ -41,8 +43,9 @@ from litecord.system_messages import send_sys_message
|
|||
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
||||
from litecord.utils import search_result_from_list
|
||||
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
||||
from winter import snowflake_datetime
|
||||
from litecord.common.channels import channel_ack
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
from litecord.permissions import get_permissions, Permissions
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint("channels", __name__)
|
||||
|
|
@ -80,9 +83,7 @@ async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
|||
|
||||
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
|
||||
|
||||
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
|
||||
|
||||
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
|
||||
|
||||
# if none of them were actually updated,
|
||||
|
|
@ -93,7 +94,7 @@ async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
|||
# at least one of the fields were updated,
|
||||
# dispatch GUILD_UPDATE
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
|
||||
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||
|
|
@ -104,7 +105,7 @@ async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
|||
return
|
||||
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
|
||||
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||
|
|
@ -134,7 +135,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
|||
# tell all people in the guild of the category removal
|
||||
for child_id in childs:
|
||||
child = await app.storage.get_channel(child_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", child))
|
||||
|
||||
|
||||
async def _delete_messages(channel_id):
|
||||
|
|
@ -249,12 +250,10 @@ async def close_channel(channel_id):
|
|||
)
|
||||
|
||||
# clean its member list representation
|
||||
lazy_guilds = app.dispatcher.backends["lazy_guild"]
|
||||
lazy_guilds.remove_channel(channel_id)
|
||||
app.lazy_guild.remove_channel(channel_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
await app.dispatcher.remove("channel", channel_id)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_DELETE", chan))
|
||||
await app.dispatcher.channel.drop(channel_id)
|
||||
return jsonify(chan)
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
|
|
@ -273,11 +272,9 @@ async def close_channel(channel_id):
|
|||
channel_id,
|
||||
)
|
||||
|
||||
# unsubscribe
|
||||
await app.dispatcher.unsub("channel", channel_id, user_id)
|
||||
|
||||
# nothing happens to the other party of the dm channel
|
||||
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
|
||||
await app.dispatcher.channel.unsub(channel_id, user_id)
|
||||
await dispatch_user(user_id, ("CHANNEL_DELETE", chan))
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
|
@ -318,10 +315,54 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
|||
continue
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
|
||||
async def _process_overwrites(channel_id: int, overwrites: list):
|
||||
@dataclass
|
||||
class Target:
|
||||
type: int
|
||||
user_id: Optional[int]
|
||||
role_id: Optional[int]
|
||||
|
||||
@property
|
||||
def is_user(self):
|
||||
return self.type == 0
|
||||
|
||||
@property
|
||||
def is_role(self):
|
||||
return self.type == 1
|
||||
|
||||
|
||||
async def _dispatch_action(guild_id: int, channel_id: int, user_id: int, perms) -> None:
|
||||
"""Apply an action of sub/unsub to all states of a user."""
|
||||
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||
for state in states:
|
||||
if perms.read_messages:
|
||||
await app.dispatcher.channel.sub(channel_id, state.session_id)
|
||||
else:
|
||||
await app.dispatcher.channel.unsub(channel_id, state.session_id)
|
||||
|
||||
|
||||
async def _process_overwrites(guild_id: int, channel_id: int, overwrites: list) -> None:
|
||||
# user_ids serves as a "prospect" user id list.
|
||||
# for each overwrite we apply, we fill this list with user ids we
|
||||
# want to check later to subscribe/unsubscribe from the channel.
|
||||
# (users without read_messages are automatically unsubbed since we
|
||||
# don't want to leak messages to them when they dont have perms anymore)
|
||||
|
||||
# the expensiveness of large overwrite/role chains shines here.
|
||||
# since each user id we fill in implies an entire get_permissions call
|
||||
# (because we don't have the answer if a user is to be subbed/unsubbed
|
||||
# with only overwrites, an overwrite for a user allowing them might be
|
||||
# overwritten by a role overwrite denying them if they have the role),
|
||||
# we get a lot of tension on that, causing channel updates to lag a bit.
|
||||
|
||||
# there may be some good optimizations to do here, such as doing some
|
||||
# precalculations like fetching get_permissions for everyone first, then
|
||||
# applying the new overwrites one by one, then subbing/unsubbing at the
|
||||
# end, but it would be very memory intensive.
|
||||
user_ids: List[int] = []
|
||||
|
||||
for overwrite in overwrites:
|
||||
|
||||
# 0 for member overwrite, 1 for role overwrite
|
||||
|
|
@ -329,7 +370,9 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
|||
target_role = None if target_type == 0 else overwrite["id"]
|
||||
target_user = overwrite["id"] if target_type == 0 else None
|
||||
|
||||
col_name = "target_user" if target_type == 0 else "target_role"
|
||||
target = Target(target_type, target_role, target_user)
|
||||
|
||||
col_name = "target_user" if target.is_user else "target_role"
|
||||
constraint_name = f"channel_overwrites_{col_name}_uniq"
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -352,6 +395,17 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
|||
overwrite["deny"],
|
||||
)
|
||||
|
||||
if target.is_user:
|
||||
perms = Permissions(overwrite["allow"] & ~overwrite["deny"])
|
||||
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
|
||||
|
||||
elif target.is_role:
|
||||
user_ids.extend(await app.storage.get_role_members(target.role_id))
|
||||
|
||||
for user_id in user_ids:
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
|
||||
|
||||
|
||||
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
|
||||
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||
|
|
@ -371,6 +425,7 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
|||
)
|
||||
|
||||
await _process_overwrites(
|
||||
guild_id,
|
||||
channel_id,
|
||||
[
|
||||
{
|
||||
|
|
@ -447,7 +502,7 @@ async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
|
|||
|
||||
if "channel_overwrites" in j:
|
||||
overwrites = j["channel_overwrites"]
|
||||
await _process_overwrites(channel_id, overwrites)
|
||||
await _process_overwrites(guild_id, channel_id, overwrites)
|
||||
|
||||
|
||||
async def _common_guild_chan(channel_id, j: dict):
|
||||
|
|
@ -566,9 +621,9 @@ async def update_channel(channel_id: int):
|
|||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
if is_guild:
|
||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
|
||||
else:
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
|
@ -578,9 +633,9 @@ async def trigger_typing(channel_id):
|
|||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
(
|
||||
"TYPING_START",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
|
|
@ -589,6 +644,7 @@ async def trigger_typing(channel_id):
|
|||
# guild_id for lazy guilds
|
||||
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
@ -816,5 +872,5 @@ async def bulk_delete(channel_id: int):
|
|||
if res == "DELETE 0":
|
||||
raise BadRequest("No messages were removed")
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_DELETE_BULK", payload))
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from litecord.errors import BadRequest, Forbidden
|
|||
from winter import get_snowflake
|
||||
from litecord.system_messages import send_sys_message
|
||||
from litecord.pubsub.channel import gdm_recipient_view
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint("dm_channels", __name__)
|
||||
|
|
@ -82,11 +83,11 @@ async def gdm_create(user_id, peer_id) -> int:
|
|||
await _raw_gdm_add(channel_id, user_id)
|
||||
await _raw_gdm_add(channel_id, peer_id)
|
||||
|
||||
await app.dispatcher.sub("channel", channel_id, user_id)
|
||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||
await app.dispatcher.channel.sub(channel_id, user_id)
|
||||
await app.dispatcher.channel.sub(channel_id, peer_id)
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_CREATE", chan))
|
||||
|
||||
return channel_id
|
||||
|
||||
|
|
@ -104,13 +105,10 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
# the reasoning behind gdm_recipient_view is in its docstring.
|
||||
await app.dispatcher.dispatch(
|
||||
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
|
||||
)
|
||||
await dispatch_user(peer_id, ("CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)))
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
await app.dispatcher.sub("channel", peer_id)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
|
||||
await app.dispatcher.channel.sub(peer_id)
|
||||
|
||||
if user_id:
|
||||
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
|
||||
|
|
@ -128,17 +126,19 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
await _raw_gdm_remove(channel_id, peer_id)
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch(
|
||||
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
|
||||
)
|
||||
await dispatch_user(peer_id, ("CHANNEL_DELETE", gdm_recipient_view(chan, user_id)))
|
||||
|
||||
await app.dispatcher.unsub("channel", peer_id)
|
||||
await app.dispatcher.channel.unsub(peer_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
(
|
||||
"CHANNEL_RECIPIENT_REMOVE",
|
||||
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
author_id = peer_id if user_id is None else user_id
|
||||
|
|
@ -174,9 +174,8 @@ async def gdm_destroy(channel_id):
|
|||
channel_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
await app.dispatcher.remove("channel", channel_id)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_DELETE", chan))
|
||||
await app.dispatcher.channel.drop(channel_id)
|
||||
|
||||
|
||||
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
||||
|
|
|
|||
|
|
@ -59,19 +59,9 @@ async def create_channel(guild_id):
|
|||
new_channel_id = get_snowflake()
|
||||
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
|
||||
|
||||
# TODO: do a better method
|
||||
# subscribe the currently subscribed users to the new channel
|
||||
# by getting all user ids and subscribing each one by one.
|
||||
|
||||
# since GuildDispatcher calls Storage.get_channel_ids,
|
||||
# it will subscribe all users to the newly created channel.
|
||||
guild_pubsub = app.dispatcher.backends["guild"]
|
||||
user_ids = guild_pubsub.state[guild_id]
|
||||
for uid in user_ids:
|
||||
await app.dispatcher.sub("guild", guild_id, uid)
|
||||
|
||||
chan = await app.storage.get_channel(new_channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_CREATE", chan))
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
|
|
@ -79,7 +69,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
|||
"""Fetch new information about the channel and dispatch
|
||||
a single CHANNEL_UPDATE event to the guild."""
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
|
||||
|
||||
|
||||
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||
|
|
|
|||
|
|
@ -32,14 +32,15 @@ bp = Blueprint("guild.emoji", __name__)
|
|||
|
||||
async def _dispatch_emojis(guild_id):
|
||||
"""Dispatch a Guild Emojis Update payload to a guild."""
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
(
|
||||
"GUILD_EMOJIS_UPDATE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"emojis": await app.storage.get_guild_emojis(guild_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -192,12 +192,9 @@ async def modify_guild_member(guild_id, member_id):
|
|||
if nick_flag:
|
||||
partial["nick"] = j["nick"]
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"lazy_guild", guild_id, "pres_update", user_id, partial
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||
await app.lazy_guild.pres_update(guild_id, user_id, partial)
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
@ -228,12 +225,9 @@ async def update_nickname(guild_id):
|
|||
member.pop("joined_at")
|
||||
|
||||
# call pres_update for nick changes, etc.
|
||||
await app.dispatcher.dispatch(
|
||||
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||
await app.lazy_guild.pres_update(guild_id, user_id, {"nick": j["nick"]})
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
|
||||
)
|
||||
|
||||
return j["nick"]
|
||||
|
|
|
|||
|
|
@ -86,10 +86,12 @@ async def create_ban(guild_id, member_id):
|
|||
|
||||
await remove_member(guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
(
|
||||
"GUILD_BAN_ADD",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
@ -115,10 +117,12 @@ async def remove_ban(guild_id, banned_id):
|
|||
if res == "DELETE 0":
|
||||
return "", 204
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
(
|
||||
"GUILD_BAN_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -67,8 +67,8 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
|
|||
|
||||
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role})
|
||||
)
|
||||
|
||||
return role
|
||||
|
|
@ -304,10 +304,9 @@ async def delete_guild_role(guild_id, role_id):
|
|||
|
||||
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_ROLE_DELETE",
|
||||
{"guild_id": str(guild_id), "role_id": str(role_id)},
|
||||
("GUILD_ROLE_DELETE", {"guild_id": str(guild_id), "role_id": str(role_id)},),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -191,8 +191,9 @@ async def create_guild():
|
|||
|
||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
|
||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
|
||||
await app.dispatcher.guild.sub_user(guild_id, user_id)
|
||||
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild_total))
|
||||
return jsonify(guild_total)
|
||||
|
||||
|
||||
|
|
@ -350,7 +351,7 @@ async def _update_guild(guild_id):
|
|||
)
|
||||
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
return jsonify(guild)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ async def _inv_check_age(inv: dict):
|
|||
await delete_invite(inv["code"])
|
||||
raise InvalidInvite("Invite is expired")
|
||||
|
||||
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
|
||||
if inv["max_uses"] != -1 and inv["uses"] > inv["max_uses"]:
|
||||
await delete_invite(inv["code"])
|
||||
raise InvalidInvite("Too many uses")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from ..auth import token_check
|
|||
from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND
|
||||
from ..enums import RelationshipType
|
||||
from litecord.errors import BadRequest
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
|
||||
bp = Blueprint("relationship", __name__)
|
||||
|
|
@ -36,17 +37,17 @@ async def get_me_relationships():
|
|||
|
||||
|
||||
async def _dispatch_single_pres(user_id, presence: dict):
|
||||
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
|
||||
await dispatch_user(user_id, ("PRESENCE_UPDATE", presence))
|
||||
|
||||
|
||||
async def _unsub_friend(user_id, peer_id):
|
||||
await app.dispatcher.unsub("friend", user_id, peer_id)
|
||||
await app.dispatcher.unsub("friend", peer_id, user_id)
|
||||
await app.dispatcher.friend.unsub(user_id, peer_id)
|
||||
await app.dispatcher.friend.unsub(peer_id, user_id)
|
||||
|
||||
|
||||
async def _sub_friend(user_id, peer_id):
|
||||
await app.dispatcher.sub("friend", user_id, peer_id)
|
||||
await app.dispatcher.sub("friend", peer_id, user_id)
|
||||
await app.dispatcher.friend.sub(user_id, peer_id)
|
||||
await app.dispatcher.friend.sub(peer_id, user_id)
|
||||
|
||||
# dispatch presence update to the user and peer about
|
||||
# eachother's presence.
|
||||
|
|
@ -107,8 +108,8 @@ async def make_friend(
|
|||
_friend,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user(
|
||||
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
|
||||
await dispatch_user(
|
||||
peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)})
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
@ -130,35 +131,41 @@ async def make_friend(
|
|||
_friend,
|
||||
)
|
||||
|
||||
_dispatch = app.dispatcher.dispatch_user
|
||||
_dispatch = dispatch_user
|
||||
|
||||
if existing:
|
||||
# accepted a friend request, dispatch respective
|
||||
# relationship events
|
||||
await _dispatch(
|
||||
user_id,
|
||||
(
|
||||
"RELATIONSHIP_REMOVE",
|
||||
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
||||
),
|
||||
)
|
||||
|
||||
await _dispatch(
|
||||
user_id,
|
||||
(
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(peer_id),
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await _dispatch(
|
||||
peer_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
(
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(user_id),
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await _sub_friend(user_id, peer_id)
|
||||
|
|
@ -169,22 +176,26 @@ async def make_friend(
|
|||
if rel_type == _friend:
|
||||
await _dispatch(
|
||||
user_id,
|
||||
(
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(peer_id),
|
||||
"type": RelationshipType.OUTGOING.value,
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await _dispatch(
|
||||
peer_id,
|
||||
(
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(user_id),
|
||||
"type": RelationshipType.INCOMING.value,
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# we don't make the pubsub link
|
||||
|
|
@ -240,7 +251,7 @@ async def add_relationship(peer_id: int):
|
|||
# make_friend did not succeed, so we
|
||||
# assume it is a block and dispatch
|
||||
# the respective RELATIONSHIP_ADD.
|
||||
await app.dispatcher.dispatch_user(
|
||||
await dispatch_user(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
|
|
@ -261,7 +272,7 @@ async def remove_relationship(peer_id: int):
|
|||
user_id = await token_check()
|
||||
_friend = RelationshipType.FRIEND.value
|
||||
_block = RelationshipType.BLOCK.value
|
||||
_dispatch = app.dispatcher.dispatch_user
|
||||
_dispatch = dispatch_user
|
||||
|
||||
rel_type = await app.db.fetchval(
|
||||
"""
|
||||
|
|
@ -307,7 +318,8 @@ async def remove_relationship(peer_id: int):
|
|||
)
|
||||
|
||||
await _dispatch(
|
||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
|
||||
user_id,
|
||||
("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}),
|
||||
)
|
||||
|
||||
peer_del_type = (
|
||||
|
|
@ -315,7 +327,8 @@ async def remove_relationship(peer_id: int):
|
|||
)
|
||||
|
||||
await _dispatch(
|
||||
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
|
||||
peer_id,
|
||||
("RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}),
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
@ -334,7 +347,7 @@ async def remove_relationship(peer_id: int):
|
|||
)
|
||||
|
||||
await _dispatch(
|
||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
|
||||
user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block})
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from quart import Blueprint, jsonify, request, current_app as app
|
|||
from litecord.auth import token_check
|
||||
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
||||
from litecord.blueprints.checks import guild_check
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
bp = Blueprint("users_settings", __name__)
|
||||
|
||||
|
|
@ -58,7 +59,7 @@ async def patch_current_settings():
|
|||
)
|
||||
|
||||
settings = await app.user_storage.get_user_settings(user_id)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
|
||||
await dispatch_user(user_id, ("USER_SETTINGS_UPDATE", settings))
|
||||
return jsonify(settings)
|
||||
|
||||
|
||||
|
|
@ -123,7 +124,7 @@ async def patch_guild_settings(guild_id: int):
|
|||
|
||||
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
|
||||
await dispatch_user(user_id, ("USER_GUILD_SETTINGS_UPDATE", settings))
|
||||
|
||||
return jsonify(settings)
|
||||
|
||||
|
|
@ -157,8 +158,8 @@ async def put_note(target_id: int):
|
|||
note,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
|
||||
await dispatch_user(
|
||||
user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note})
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -156,11 +156,12 @@ async def webhook_token_check(webhook_id: int, webhook_token: str):
|
|||
|
||||
|
||||
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
(
|
||||
"WEBHOOKS_UPDATE",
|
||||
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -503,7 +504,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
|||
|
||||
payload = await app.storage.get_message(message_id)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
|
||||
|
||||
# spawn embedder in the background, even when we're on a webhook.
|
||||
app.sched.spawn(process_url_embed(payload))
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from quart import current_app as app
|
|||
|
||||
from litecord.errors import ForbiddenDM
|
||||
from litecord.enums import RelationshipType
|
||||
from litecord.pubsub.member import dispatch_member
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
|
||||
async def channel_ack(
|
||||
|
|
@ -54,20 +56,24 @@ async def channel_ack(
|
|||
)
|
||||
|
||||
if guild_id:
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
user_id,
|
||||
await dispatch_member(
|
||||
guild_id,
|
||||
user_id,
|
||||
(
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
),
|
||||
)
|
||||
else:
|
||||
# we don't use ChannelDispatcher here because since
|
||||
# guild_id is None, all user devices are already subscribed
|
||||
# to the given channel (a dm or a group dm)
|
||||
await app.dispatcher.dispatch_user(
|
||||
await dispatch_user(
|
||||
user_id,
|
||||
(
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from ..snowflake import get_snowflake
|
||||
from ..permissions import get_role_perms
|
||||
from ..permissions import get_role_perms, get_permissions
|
||||
from ..utils import dict_get, maybe_lazy_guild_dispatch
|
||||
from ..enums import ChannelType
|
||||
from litecord.pubsub.member import dispatch_member
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -41,21 +43,20 @@ async def remove_member(guild_id: int, member_id: int):
|
|||
member_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
member_id,
|
||||
await dispatch_member(
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{"guild_id": str(guild_id), "unavailable": False},
|
||||
member_id,
|
||||
("GUILD_DELETE", {"guild_id": str(guild_id), "unavailable": False}),
|
||||
)
|
||||
|
||||
await app.dispatcher.unsub("guild", guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.unsub(guild_id, member_id)
|
||||
await app.lazy_guild.remove_member(member_id)
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
(
|
||||
"GUILD_MEMBER_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -108,8 +109,8 @@ async def create_role(guild_id, name: str, **kwargs):
|
|||
# we need to update the lazy guild handlers for the newly created group
|
||||
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role})
|
||||
)
|
||||
|
||||
return role
|
||||
|
|
@ -137,6 +138,39 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
|
|||
)
|
||||
|
||||
|
||||
async def _subscribe_users_new_channel(guild_id: int, channel_id: int) -> None:
|
||||
|
||||
# for each state currently subscribed to guild, we check on the database
|
||||
# which states can also subscribe to the new channel at its creation.
|
||||
|
||||
# the list of users that can subscribe are then used again for a pass
|
||||
# over the states and states that have user ids in that list become
|
||||
# subscribers of the new channel.
|
||||
users_to_sub: List[str] = []
|
||||
|
||||
for session_id in app.dispatcher.guild.state[guild_id]:
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if state.user_id in users_to_sub:
|
||||
continue
|
||||
|
||||
perms = await get_permissions(state.user_id, channel_id)
|
||||
if perms.read_messages:
|
||||
users_to_sub.append(state.user_id)
|
||||
|
||||
for session_id in app.dispatcher.guild.state[guild_id]:
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if state.user_id in users_to_sub:
|
||||
await app.dispatcher.channel.sub(channel_id, session_id)
|
||||
|
||||
|
||||
async def create_guild_channel(
|
||||
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||
):
|
||||
|
|
@ -180,6 +214,8 @@ async def create_guild_channel(
|
|||
# so we use this function.
|
||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||
|
||||
await _subscribe_users_new_channel(guild_id, channel_id)
|
||||
|
||||
|
||||
async def _del_from_table(table: str, user_id: int):
|
||||
"""Delete a row from a table."""
|
||||
|
|
@ -206,21 +242,22 @@ async def delete_guild(guild_id: int):
|
|||
)
|
||||
|
||||
# Discord's client expects IDs being string
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
(
|
||||
"GUILD_DELETE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"id": str(guild_id),
|
||||
# 'unavailable': False,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# remove from the dispatcher so nobody
|
||||
# becomes the little memer that tries to fuck up with
|
||||
# everybody's gateway
|
||||
await app.dispatcher.remove("guild", guild_id)
|
||||
await app.dispatcher.guild.drop(guild_id)
|
||||
|
||||
|
||||
async def create_guild_settings(guild_id: int, user_id: int):
|
||||
|
|
@ -285,18 +322,17 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
|
|||
|
||||
# tell current members a new member came up
|
||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}})
|
||||
)
|
||||
|
||||
# update member lists for the new member
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
|
||||
# pubsub changes for new member
|
||||
await app.lazy_guild.new_member(guild_id, user_id)
|
||||
states = await app.dispatcher.guild.sub_user(guild_id, user_id)
|
||||
|
||||
# subscribe new member to guild, so they get events n stuff
|
||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||
|
||||
# tell the new member that theres the guild it just joined.
|
||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
||||
# just to the shards that are actually tied to it.
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
|
||||
for state in states:
|
||||
try:
|
||||
await state.ws.dispatch("GUILD_CREATE", guild)
|
||||
except Exception:
|
||||
log.exception("failed to dispatch to session_id={!r}", state.session_id)
|
||||
|
|
|
|||
|
|
@ -29,41 +29,48 @@ from ..snowflake import get_snowflake
|
|||
from ..errors import BadRequest
|
||||
from ..auth import hash_data
|
||||
from ..utils import rand_hex
|
||||
from ..pubsub.user import dispatch_user
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def mass_user_update(user_id: int):
|
||||
"""Dispatch USER_UPDATE in a mass way."""
|
||||
# by using dispatch_with_filter
|
||||
# we're guaranteeing all shards will get
|
||||
# a USER_UPDATE once and not any others.
|
||||
async def mass_user_update(user_id: int) -> Tuple[dict, dict]:
|
||||
"""Dispatch a USER_UPDATE to everyone that is subscribed to the user.
|
||||
|
||||
This function guarantees all states will get one USER_UPDATE for simple
|
||||
cases. Lazy guild users might get updates N times depending of how many
|
||||
lists are they subscribed to.
|
||||
"""
|
||||
session_ids: List[str] = []
|
||||
|
||||
public_user = await app.storage.get_user(user_id)
|
||||
private_user = await app.storage.get_user(user_id, secure=True)
|
||||
|
||||
session_ids.extend(await dispatch_user(user_id, ("USER_UPDATE", private_user)))
|
||||
|
||||
guild_ids: List[int] = await app.user_storage.get_user_guilds(user_id)
|
||||
friend_ids: List[int] = await app.user_storage.get_friend_ids(user_id)
|
||||
|
||||
for guild_id in guild_ids:
|
||||
session_ids.extend(
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
||||
await app.dispatcher.guild.dispatch_filter(
|
||||
guild_id,
|
||||
lambda sess_id: sess_id not in session_ids,
|
||||
("USER_UPDATE", public_user),
|
||||
)
|
||||
)
|
||||
|
||||
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
||||
friend_ids = await app.user_storage.get_friend_ids(user_id)
|
||||
|
||||
for friend_id in friend_ids:
|
||||
session_ids.extend(
|
||||
await app.dispatcher.dispatch_many_filter_list(
|
||||
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
||||
await app.dispatcher.friend.dispatch_filter(
|
||||
friend_id,
|
||||
lambda sess_id: sess_id not in session_ids,
|
||||
("USER_UPDATE", public_user),
|
||||
)
|
||||
)
|
||||
|
||||
session_ids.extend(
|
||||
await app.dispatcher.dispatch_many_filter_list(
|
||||
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
||||
)
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
||||
for guild_id in guild_ids:
|
||||
await app.lazy_guild.update_user(guild_id, user_id)
|
||||
|
||||
return public_user, private_user
|
||||
|
||||
|
|
|
|||
|
|
@ -17,17 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from .pubsub import (
|
||||
GuildDispatcher,
|
||||
MemberDispatcher,
|
||||
UserDispatcher,
|
||||
ChannelDispatcher,
|
||||
FriendDispatcher,
|
||||
LazyGuildDispatcher,
|
||||
)
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -50,169 +45,7 @@ class EventDispatcher:
|
|||
its subscriber ids.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.state_manager = app.state_manager
|
||||
self.app = app
|
||||
|
||||
self.backends = {
|
||||
"guild": GuildDispatcher(self),
|
||||
"member": MemberDispatcher(self),
|
||||
"channel": ChannelDispatcher(self),
|
||||
"user": UserDispatcher(self),
|
||||
"friend": FriendDispatcher(self),
|
||||
"lazy_guild": LazyGuildDispatcher(self),
|
||||
}
|
||||
|
||||
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
||||
"""Send an action regarding a key/identifier pair to a backend.
|
||||
|
||||
Action is usually "sub" or "unsub".
|
||||
"""
|
||||
backend = self.backends[backend_str]
|
||||
method = getattr(backend, action)
|
||||
|
||||
# convert keys to the types the backend wants
|
||||
key = backend.KEY_TYPE(key)
|
||||
identifier = backend.VAL_TYPE(identifier)
|
||||
|
||||
return await method(key, identifier, *args)
|
||||
|
||||
async def subscribe(
|
||||
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
|
||||
):
|
||||
"""Subscribe a single element to the given backend."""
|
||||
flags = flags or {}
|
||||
|
||||
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
|
||||
|
||||
# this is a hacky solution for backwards compatibility between backends
|
||||
# that implement flags and backends that don't.
|
||||
|
||||
# passing flags to backends that don't implement flags will
|
||||
# cause errors as expected.
|
||||
if flags:
|
||||
return await self.action(backend, "sub", key, identifier, flags)
|
||||
|
||||
return await self.action(backend, "sub", key, identifier)
|
||||
|
||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||
"""Unsubscribe an element from the given backend."""
|
||||
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
|
||||
|
||||
return await self.action(backend, "unsub", key, identifier)
|
||||
|
||||
async def sub(self, backend, key, identifier):
|
||||
"""Alias to subscribe()."""
|
||||
return await self.subscribe(backend, key, identifier)
|
||||
|
||||
async def unsub(self, backend, key, identifier):
|
||||
"""Alias to unsubscribe()."""
|
||||
return await self.unsubscribe(backend, key, identifier)
|
||||
|
||||
async def sub_many(
|
||||
self,
|
||||
backend_str: str,
|
||||
identifier: Any,
|
||||
keys: list,
|
||||
flags: Dict[str, Any] = None,
|
||||
):
|
||||
"""Subscribe to multiple channels (all in a single backend)
|
||||
at a time.
|
||||
|
||||
Usually used when connecting to the gateway and the client
|
||||
needs to subscribe to all their guids.
|
||||
"""
|
||||
flags = flags or {}
|
||||
for key in keys:
|
||||
await self.subscribe(backend_str, key, identifier, flags)
|
||||
|
||||
async def mass_sub(self, identifier: Any, backends: List[tuple]):
|
||||
"""Mass subscribe to many backends at once."""
|
||||
for bcall in backends:
|
||||
backend_str, keys = bcall[0], bcall[1]
|
||||
|
||||
if len(bcall) == 2:
|
||||
flags = {}
|
||||
elif len(bcall) == 3:
|
||||
# we have flags
|
||||
flags = bcall[2]
|
||||
|
||||
log.debug(
|
||||
"subscribing {} to {} keys in backend {}, flags: {}",
|
||||
identifier,
|
||||
len(keys),
|
||||
backend_str,
|
||||
flags,
|
||||
)
|
||||
|
||||
await self.sub_many(backend_str, identifier, keys, flags)
|
||||
|
||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
||||
"""Dispatch an event to the backend.
|
||||
|
||||
The backend is responsible for everything regarding the
|
||||
actual dispatch.
|
||||
"""
|
||||
backend = self.backends[backend_str]
|
||||
|
||||
# convert types
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch(key, *args, **kwargs)
|
||||
|
||||
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
|
||||
"""Dispatch to multiple keys in a single backend."""
|
||||
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
|
||||
|
||||
for key in keys:
|
||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
||||
|
||||
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
|
||||
"""Dispatch to a backend that only accepts
|
||||
(event, data) arguments with an optional filter
|
||||
function."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch_filter(key, func, *args)
|
||||
|
||||
async def dispatch_many_filter_list(
|
||||
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
|
||||
):
|
||||
"""Make a "unique" dispatch given a list of session ids.
|
||||
|
||||
This only works for backends that have a dispatch_filter
|
||||
handler and return session id lists in their dispatch
|
||||
results.
|
||||
"""
|
||||
for key in keys:
|
||||
sess_list.extend(
|
||||
await self.dispatch_filter(
|
||||
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
|
||||
)
|
||||
)
|
||||
|
||||
return sess_list
|
||||
|
||||
async def reset(self, backend_str: str, key: Any):
|
||||
"""Reset the bucket in the given backend."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.reset(key)
|
||||
|
||||
async def remove(self, backend_str: str, key: Any):
|
||||
"""Remove a key from the backend. This
|
||||
might be a different operation than resetting."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.remove(key)
|
||||
|
||||
async def dispatch_guild(self, guild_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch("guild", guild_id, event, data)
|
||||
|
||||
async def dispatch_user_guild(self, user_id, guild_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch("member", (guild_id, user_id), event, data)
|
||||
|
||||
async def dispatch_user(self, user_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch("user", user_id, event, data)
|
||||
def __init__(self):
|
||||
self.guild: GuildDispatcher = GuildDispatcher()
|
||||
self.channel = ChannelDispatcher()
|
||||
self.friend = FriendDispatcher()
|
||||
|
|
|
|||
|
|
@ -87,8 +87,8 @@ async def msg_update_embeds(payload, new_embeds):
|
|||
if "flags" in payload:
|
||||
update_payload["flags"] = payload["flags"]
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_UPDATE", update_payload
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id, ("MESSAGE_UPDATE", update_payload)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from litecord.gateway.websocket import GatewayWebsocket
|
||||
|
||||
|
||||
|
|
@ -44,15 +45,16 @@ async def websocket_handler(app, ws, url):
|
|||
return await ws.close(1000, "Invalid gateway encoding")
|
||||
|
||||
try:
|
||||
gw_compress = args["compress"][0]
|
||||
gw_compress: Optional[str] = args["compress"][0]
|
||||
except (KeyError, IndexError):
|
||||
gw_compress = None
|
||||
|
||||
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
|
||||
return await ws.close(1000, "Invalid gateway compress")
|
||||
|
||||
async with app.app_context():
|
||||
gws = GatewayWebsocket(
|
||||
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||
ws, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||
)
|
||||
|
||||
# this can be run with a single await since this whole coroutine
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import asyncio
|
|||
from typing import List
|
||||
from collections import defaultdict
|
||||
|
||||
from quart import current_app as app
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -225,3 +226,16 @@ class StateManager:
|
|||
def close(self):
|
||||
"""Close the state manager."""
|
||||
self.closed = True
|
||||
|
||||
async def fetch_user_states_for_channel(
|
||||
self, channel_id: int, user_id: int
|
||||
) -> List[GatewayState]:
|
||||
"""Get a list of gateway states for a user that can receive events on a certain channel."""
|
||||
# TODO optimize this with an in-memory store
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
|
||||
if guild_id:
|
||||
return self.fetch_states(user_id, guild_id)
|
||||
|
||||
# DMs and GDMs use all user states
|
||||
return self.user_states(user_id)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from random import randint
|
|||
import websockets
|
||||
import zstandard as zstd
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType
|
||||
|
|
@ -43,7 +44,6 @@ from litecord.presence import BasePresence
|
|||
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
|
||||
from litecord.gateway.errors import (
|
||||
DecodeError,
|
||||
|
|
@ -52,8 +52,9 @@ from litecord.gateway.errors import (
|
|||
ShardingRequired,
|
||||
)
|
||||
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
||||
|
||||
from litecord.gateway.utils import WebsocketFileHandler
|
||||
from litecord.pubsub.guild import GuildFlags
|
||||
from litecord.pubsub.channel import ChannelFlags
|
||||
|
||||
from litecord.storage import int_
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ WebsocketProperties = collections.namedtuple(
|
|||
class GatewayWebsocket:
|
||||
"""Main gateway websocket logic."""
|
||||
|
||||
def __init__(self, ws, app, **kwargs):
|
||||
def __init__(self, ws, **kwargs):
|
||||
self.app = app
|
||||
self.storage = app.storage
|
||||
self.user_storage = app.user_storage
|
||||
|
|
@ -230,7 +231,7 @@ class GatewayWebsocket:
|
|||
if task:
|
||||
task.cancel()
|
||||
|
||||
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
|
||||
self.wsp.tasks["heartbeat"] = app.sched.spawn(
|
||||
task_wrapper("hb wait", self._hb_wait(interval))
|
||||
)
|
||||
|
||||
|
|
@ -247,6 +248,7 @@ class GatewayWebsocket:
|
|||
|
||||
async def dispatch(self, event: str, data: Any):
|
||||
"""Dispatch an event to the websocket."""
|
||||
assert self.state is not None
|
||||
self.state.seq += 1
|
||||
|
||||
payload = {
|
||||
|
|
@ -282,6 +284,7 @@ class GatewayWebsocket:
|
|||
|
||||
async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
|
||||
"""Dispatch GUILD_CREATE information."""
|
||||
assert self.state is not None
|
||||
|
||||
# Users don't get asynchronous guild dispatching.
|
||||
if not self.state.bot:
|
||||
|
|
@ -360,9 +363,7 @@ class GatewayWebsocket:
|
|||
}
|
||||
|
||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||
|
||||
# async dispatch of guilds
|
||||
self.app.loop.create_task(self._guild_dispatch(guilds))
|
||||
app.sched.spawn(self._guild_dispatch(guilds))
|
||||
|
||||
async def _check_shards(self, shard, user_id):
|
||||
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
||||
|
|
@ -412,6 +413,7 @@ class GatewayWebsocket:
|
|||
Note: subscribing to channels is already handled
|
||||
by GuildDispatcher.sub
|
||||
"""
|
||||
assert self.state is not None
|
||||
user_id = self.state.user_id
|
||||
guild_ids = await self._guild_ids()
|
||||
|
||||
|
|
@ -434,26 +436,50 @@ class GatewayWebsocket:
|
|||
# (presence and typing events)
|
||||
|
||||
# we enable processing of guild_subscriptions by adding flags
|
||||
# when subscribing to the given backend. those are optional.
|
||||
channels_to_sub = [
|
||||
(
|
||||
"guild",
|
||||
guild_ids,
|
||||
{"presence": guild_subscriptions, "typing": guild_subscriptions},
|
||||
),
|
||||
("channel", dm_ids),
|
||||
("channel", gdm_ids),
|
||||
]
|
||||
# when subscribing to the given backend.
|
||||
session_id = self.state.session_id
|
||||
channel_ids: List[int] = []
|
||||
|
||||
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||
for guild_id in guild_ids:
|
||||
await app.dispatcher.guild.sub_with_flags(
|
||||
guild_id,
|
||||
session_id,
|
||||
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
|
||||
)
|
||||
|
||||
# instead of calculating which channels to subscribe to
|
||||
# inside guild dispatcher, we calculate them in here, so that
|
||||
# we remove complexity of the dispatcher.
|
||||
|
||||
guild_chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||
for channel_id in guild_chan_ids:
|
||||
perms = await get_permissions(
|
||||
self.state.user_id, channel_id, storage=self.storage
|
||||
)
|
||||
|
||||
if perms.bits.read_messages:
|
||||
channel_ids.append(channel_id)
|
||||
|
||||
log.info("subscribing to {} guild channels", len(channel_ids))
|
||||
for channel_id in channel_ids:
|
||||
await app.dispatcher.channel.sub_with_flags(
|
||||
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
|
||||
)
|
||||
|
||||
for dm_id in dm_ids:
|
||||
await app.dispatcher.channel.sub(dm_id, session_id)
|
||||
|
||||
for gdm_id in gdm_ids:
|
||||
await app.dispatcher.channel.sub(gdm_id, session_id)
|
||||
|
||||
if not self.state.bot:
|
||||
# subscribe to all friends
|
||||
# (their friends will also subscribe back
|
||||
# when they come online)
|
||||
if not self.state.bot:
|
||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||
log.info("subscribing to {} friends", len(friend_ids))
|
||||
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||
for friend_id in friend_ids:
|
||||
await app.dispatcher.friend.sub(user_id, friend_id)
|
||||
|
||||
async def update_status(self, incoming_status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
|
|
@ -921,6 +947,7 @@ class GatewayWebsocket:
|
|||
]
|
||||
}
|
||||
"""
|
||||
assert self.state is not None
|
||||
data = payload["d"]
|
||||
|
||||
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||
|
|
@ -933,11 +960,9 @@ class GatewayWebsocket:
|
|||
log.debug("lazy request: members: {}", data.get("members", []))
|
||||
|
||||
# make shard query
|
||||
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
|
||||
|
||||
for chan_id, ranges in data.get("channels", {}).items():
|
||||
chan_id = int(chan_id)
|
||||
member_list = await lazy_guilds.get_gml(chan_id)
|
||||
member_list = await app.lazy_guild.get_gml(chan_id)
|
||||
|
||||
perms = await get_permissions(
|
||||
self.state.user_id, chan_id, storage=self.storage
|
||||
|
|
|
|||
|
|
@ -93,7 +93,6 @@ class PresenceManager:
|
|||
self.storage = app.storage
|
||||
self.user_storage = app.user_storage
|
||||
self.state_manager = app.state_manager
|
||||
self.dispatcher = app.dispatcher
|
||||
|
||||
async def guild_presences(
|
||||
self, member_ids: List[int], guild_id: int
|
||||
|
|
@ -127,8 +126,7 @@ class PresenceManager:
|
|||
|
||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||
|
||||
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||
lists = app.lazy_guild.get_gml_guild(guild_id)
|
||||
|
||||
# shards that are in lazy guilds with 'everyone'
|
||||
# enabled
|
||||
|
|
@ -163,20 +161,21 @@ class PresenceManager:
|
|||
# given a session id, return if the session id actually connects to
|
||||
# a given user, and if the state has not been dispatched via lazy guild.
|
||||
def _session_check(session_id):
|
||||
try:
|
||||
state = self.state_manager.fetch_raw(session_id)
|
||||
uid = int(member["user"]["id"])
|
||||
|
||||
if not state:
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
uid = int(member["user"]["id"])
|
||||
|
||||
# we don't want to send a presence update
|
||||
# to the same user
|
||||
return state.user_id != uid and session_id not in in_lazy
|
||||
|
||||
# everyone not in lazy guild mode
|
||||
# gets a PRESENCE_UPDATE
|
||||
await self.dispatcher.dispatch_filter(
|
||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
|
||||
await app.dispatcher.guild.dispatch_filter(
|
||||
guild_id, _session_check, ("PRESENCE_UPDATE", event_payload)
|
||||
)
|
||||
|
||||
return in_lazy
|
||||
|
|
@ -193,11 +192,8 @@ class PresenceManager:
|
|||
|
||||
# dispatch to all friends that are subscribed to them
|
||||
user = await self.storage.get_user(user_id)
|
||||
await self.dispatcher.dispatch(
|
||||
"friend",
|
||||
user_id,
|
||||
"PRESENCE_UPDATE",
|
||||
{**presence.partial_dict, **{"user": user}},
|
||||
await app.dispatcher.friend.dispatch(
|
||||
user_id, ("PRESENCE_UPDATE", {**presence.partial_dict, **{"user": user}}),
|
||||
)
|
||||
|
||||
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
|
||||
|
|
|
|||
|
|
@ -18,17 +18,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from .guild import GuildDispatcher
|
||||
from .member import MemberDispatcher
|
||||
from .user import UserDispatcher
|
||||
from .member import dispatch_member
|
||||
from .user import dispatch_user
|
||||
from .channel import ChannelDispatcher
|
||||
from .friend import FriendDispatcher
|
||||
from .lazy_guild import LazyGuildDispatcher
|
||||
|
||||
__all__ = [
|
||||
"GuildDispatcher",
|
||||
"MemberDispatcher",
|
||||
"UserDispatcher",
|
||||
"dispatch_member",
|
||||
"dispatch_user",
|
||||
"ChannelDispatcher",
|
||||
"FriendDispatcher",
|
||||
"LazyGuildDispatcher",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithFlags
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.utils import index_by_func
|
||||
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -37,75 +39,66 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
|||
"""
|
||||
# make a copy or the original channel object
|
||||
data = dict(orig)
|
||||
|
||||
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
|
||||
|
||||
data["recipients"].pop(idx)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class ChannelDispatcher(DispatcherWithFlags):
|
||||
"""Main channel Pub/Sub logic."""
|
||||
@dataclass
|
||||
class ChannelFlags:
|
||||
typing: bool
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
|
||||
class ChannelDispatcher(
|
||||
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
|
||||
):
|
||||
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
|
||||
|
||||
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
|
||||
"""Dispatch an event to a channel."""
|
||||
# get everyone who is subscribed
|
||||
# and store the number of states we dispatched the event to
|
||||
user_ids = self.state[channel_id]
|
||||
dispatched = 0
|
||||
session_ids = set(self.state[channel_id])
|
||||
sessions: List[str] = []
|
||||
|
||||
# making a copy of user_ids since
|
||||
# we'll modify it later on.
|
||||
for user_id in set(user_ids):
|
||||
guild_id = await self.app.storage.guild_from_channel(channel_id)
|
||||
event_type, event_data = event
|
||||
assert isinstance(event_data, dict)
|
||||
|
||||
# if we are dispatching to a guild channel,
|
||||
# we should only dispatch to the states / shards
|
||||
# that are connected to the guild (via their shard id).
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
await self.unsub(channel_id, session_id)
|
||||
continue
|
||||
|
||||
# if we aren't, we just get all states tied to the user.
|
||||
# TODO: make a fetch_states that fetches shards
|
||||
# - with id 0 (count any) OR
|
||||
# - single shards (id=0, count=1)
|
||||
states = (
|
||||
self.sm.fetch_states(user_id, guild_id)
|
||||
if guild_id
|
||||
else self.sm.user_states(user_id)
|
||||
try:
|
||||
flags = self.get_flags(channel_id, session_id)
|
||||
except KeyError:
|
||||
log.warning("no flags for {!r}, ignoring", session_id)
|
||||
flags = ChannelFlags(typing=True)
|
||||
|
||||
if event_type.lower().startswith("typing_") and not flags.typing:
|
||||
continue
|
||||
|
||||
correct_event = event
|
||||
# for cases where we are talking about group dms, we create an edited
|
||||
# event data so that it doesn't show the user we're dispatching
|
||||
# to in data.recipients (clients already assume they are recipients)
|
||||
if (
|
||||
event_type in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||
and event_data.get("type") == ChannelType.GROUP_DM.value
|
||||
):
|
||||
new_data = gdm_recipient_view(event_data, state.user_id)
|
||||
correct_event = (event_type, new_data)
|
||||
|
||||
try:
|
||||
await state.ws.dispatch(*correct_event)
|
||||
except Exception:
|
||||
log.exception("error while dispatching to {}", state.session_id)
|
||||
continue
|
||||
|
||||
sessions.append(session_id)
|
||||
|
||||
log.info(
|
||||
"Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions)
|
||||
)
|
||||
|
||||
# unsub people who don't have any states tied to the channel.
|
||||
if not states:
|
||||
await self.unsub(channel_id, user_id)
|
||||
continue
|
||||
|
||||
# skip typing events for users that don't want it
|
||||
if event.startswith("TYPING_") and not self.flags_get(
|
||||
channel_id, user_id, "typing", True
|
||||
):
|
||||
continue
|
||||
|
||||
cur_sess: List[str] = []
|
||||
|
||||
if (
|
||||
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||
and data.get("type") == ChannelType.GROUP_DM.value
|
||||
):
|
||||
# we edit the channel payload so it doesn't show
|
||||
# the user as a recipient
|
||||
|
||||
new_data = gdm_recipient_view(data, user_id)
|
||||
cur_sess = await self._dispatch_states(states, event, new_data)
|
||||
else:
|
||||
cur_sess = await self._dispatch_states(states, event, data)
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
|
||||
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
|
||||
|
||||
return sessions
|
||||
|
|
|
|||
|
|
@ -17,7 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import (
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Set,
|
||||
Mapping,
|
||||
Iterable,
|
||||
Tuple,
|
||||
)
|
||||
from collections import defaultdict
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -25,79 +36,63 @@ from logbook import Logger
|
|||
log = Logger(__name__)
|
||||
|
||||
|
||||
def _identity(_self, x):
|
||||
return x
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
F = TypeVar("F")
|
||||
EventType = TypeVar("EventType")
|
||||
DispatchType = TypeVar("DispatchType")
|
||||
F_Map = Mapping[V, F]
|
||||
|
||||
GatewayEvent = Tuple[str, Any]
|
||||
|
||||
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
class Dispatcher(Generic[K, V, EventType, DispatchType]):
|
||||
"""Pub/Sub backend dispatcher.
|
||||
|
||||
This just declares functions all Dispatcher subclasses
|
||||
can implement. This does not mean all Dispatcher
|
||||
subclasses have them implemented.
|
||||
Classes must implement this protocol.
|
||||
"""
|
||||
|
||||
KEY_TYPE = _identity
|
||||
VAL_TYPE = _identity
|
||||
async def sub(self, key: K, identifier: V) -> None:
|
||||
"""Subscribe a given identifier to a given key."""
|
||||
...
|
||||
|
||||
def __init__(self, main):
|
||||
#: main EventDispatcher
|
||||
self.main_dispatcher = main
|
||||
async def sub_many(self, key: K, identifier_list: Iterable[V]) -> None:
|
||||
for identifier in identifier_list:
|
||||
await self.sub(key, identifier)
|
||||
|
||||
#: gateway state storage
|
||||
self.sm = main.state_manager
|
||||
async def unsub(self, key: K, identifier: V) -> None:
|
||||
"""Unsubscribe a given identifier to a given key."""
|
||||
...
|
||||
|
||||
self.app = main.app
|
||||
async def dispatch(self, key: K, event: EventType) -> DispatchType:
|
||||
...
|
||||
|
||||
async def sub(self, _key, _id):
|
||||
"""Subscribe an element to the channel/key."""
|
||||
raise NotImplementedError
|
||||
async def dispatch_many(self, keys: List[K], *args: Any, **kwargs: Any) -> None:
|
||||
log.info("MULTI DISPATCH in {!r}, {} keys", self, len(keys))
|
||||
for key in keys:
|
||||
await self.dispatch(key, *args, **kwargs)
|
||||
|
||||
async def unsub(self, _key, _id):
|
||||
"""Unsubscribe an elemtnt from the channel/key."""
|
||||
raise NotImplementedError
|
||||
async def drop(self, key: K) -> None:
|
||||
"""Drop a key."""
|
||||
...
|
||||
|
||||
async def dispatch_filter(self, _key, _func, *_args):
|
||||
"""Selectively dispatch to the list of subscribed users.
|
||||
async def clear(self, key: K) -> None:
|
||||
"""Clear a key from the backend."""
|
||||
...
|
||||
|
||||
The selection logic is completly arbitraty and up to the
|
||||
Pub/Sub backend.
|
||||
async def dispatch_filter(
|
||||
self, key: K, filter_function: Callable[[K], bool], event: EventType
|
||||
) -> List[str]:
|
||||
"""Selectively dispatch to the list of subscribers.
|
||||
|
||||
Function must return a list of separate identifiers for composability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def dispatch(self, _key, *_args):
|
||||
"""Dispatch an event to the given channel/key."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reset(self, _key):
|
||||
"""Reset a key from the backend."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def remove(self, _key):
|
||||
"""Remove a key from the backend.
|
||||
|
||||
The meaning from reset() and remove()
|
||||
is different, reset() is to clear all
|
||||
subscribers from the given key,
|
||||
remove() is to remove the key as well.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
|
||||
"""Dispatch an event to a list of states."""
|
||||
res = []
|
||||
|
||||
for state in states:
|
||||
try:
|
||||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except Exception:
|
||||
log.exception("error while dispatching")
|
||||
|
||||
return res
|
||||
...
|
||||
|
||||
|
||||
class DispatcherWithState(Dispatcher):
|
||||
class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
|
||||
"""Pub/Sub backend with a state dictionary.
|
||||
|
||||
This class was made to decrease the amount
|
||||
|
|
@ -105,58 +100,58 @@ class DispatcherWithState(Dispatcher):
|
|||
that have that dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
#: the default dict is to a set
|
||||
# so we make sure someone calling sub()
|
||||
# twice won't get 2x the events for the
|
||||
# same channel.
|
||||
self.state = defaultdict(set)
|
||||
self.state: Dict[K, Set[V]] = defaultdict(set)
|
||||
|
||||
async def sub(self, key, identifier):
|
||||
async def sub(self, key: K, identifier: V):
|
||||
self.state[key].add(identifier)
|
||||
|
||||
async def unsub(self, key, identifier):
|
||||
async def unsub(self, key: K, identifier: V):
|
||||
self.state[key].discard(identifier)
|
||||
|
||||
async def reset(self, key):
|
||||
async def reset(self, key: K):
|
||||
self.state[key] = set()
|
||||
|
||||
async def remove(self, key):
|
||||
async def drop(self, key: K):
|
||||
try:
|
||||
self.state.pop(key)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
async def dispatch(self, key, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DispatcherWithFlags(DispatcherWithState):
|
||||
class DispatcherWithFlags(
|
||||
DispatcherWithState, Generic[K, V, EventType, DispatchType, F],
|
||||
):
|
||||
"""Pub/Sub backend with both a state and a flags store."""
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
|
||||
|
||||
#: keep flags for subscribers, so for example
|
||||
# a subscriber could drop all presence events at the
|
||||
# pubsub level. see gateway's guild_subscriptions field for more
|
||||
self.flags = defaultdict(dict)
|
||||
def set_flags(self, key: K, identifier: V, flags: F):
|
||||
"""Set flags for the given identifier."""
|
||||
self.flags[key][identifier] = flags
|
||||
|
||||
async def sub(self, key, identifier, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(key, identifier)
|
||||
self.flags[key][identifier] = flags or {}
|
||||
|
||||
async def unsub(self, key, identifier):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(key, identifier)
|
||||
def remove_flags(self, key: K, identifier: V):
|
||||
"""Set flags for the given identifier."""
|
||||
self.flags[key].pop(identifier)
|
||||
|
||||
def flags_get(self, key, identifier, field: str, default):
|
||||
def get_flags(self, key: K, identifier: V):
|
||||
"""Get a single field from the flags store."""
|
||||
# yes, i know its simply an indirection from the main flags store,
|
||||
# but i'd rather have this than change every call if i ever change
|
||||
# the structure of the flags store.
|
||||
return self.flags[key][identifier].get(field, default)
|
||||
return self.flags[key][identifier]
|
||||
|
||||
async def sub_with_flags(self, key: K, identifier: V, flags: F):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(key, identifier)
|
||||
self.set_flags(key, identifier, flags)
|
||||
|
||||
async def unsub(self, key: K, identifier: V):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(key, identifier)
|
||||
self.remove_flags(key, identifier)
|
||||
|
|
|
|||
|
|
@ -17,15 +17,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Set
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithState
|
||||
from .dispatcher import DispatcherWithState, GatewayEvent
|
||||
from .user import dispatch_user_filter
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class FriendDispatcher(DispatcherWithState):
|
||||
class FriendDispatcher(DispatcherWithState[int, int, GatewayEvent, List[str]]):
|
||||
"""Friend Pub/Sub logic.
|
||||
|
||||
When connecting, a client will subscribe to all their friends
|
||||
|
|
@ -33,26 +34,18 @@ class FriendDispatcher(DispatcherWithState):
|
|||
broadcasted through that channel to basically all their friends.
|
||||
"""
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||
async def dispatch_filter(self, user_id: int, filter_function, event: GatewayEvent):
|
||||
"""Dispatch an event to all of a users' friends."""
|
||||
peer_ids = self.state[user_id]
|
||||
peer_ids: Set[int] = self.state[user_id]
|
||||
sessions: List[str] = []
|
||||
|
||||
for peer_id in peer_ids:
|
||||
# dispatch to the user instead of the "shards tied to a guild"
|
||||
# since relationships broadcast to all shards.
|
||||
sessions.extend(
|
||||
await self.main_dispatcher.dispatch_filter(
|
||||
"user", peer_id, func, event, data
|
||||
)
|
||||
)
|
||||
sessions.extend(await dispatch_user_filter(peer_id, filter_function, event))
|
||||
|
||||
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
|
||||
|
||||
return sessions
|
||||
|
||||
async def dispatch(self, user_id, event, data):
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||
async def dispatch(self, user_id: int, event: GatewayEvent):
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event)
|
||||
|
|
|
|||
|
|
@ -17,123 +17,73 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithFlags
|
||||
from litecord.permissions import get_permissions
|
||||
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||
from .channel import ChannelFlags
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class GuildDispatcher(DispatcherWithFlags):
|
||||
"""Guild backend for Pub/Sub"""
|
||||
@dataclass
|
||||
class GuildFlags(ChannelFlags):
|
||||
presence: bool
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
|
||||
"""Send an action to all channels of the guild."""
|
||||
flags = flags or {}
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
class GuildDispatcher(
|
||||
DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
|
||||
):
|
||||
"""Guild backend for Pub/Sub."""
|
||||
|
||||
for chan_id in chan_ids:
|
||||
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
|
||||
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||
for state in states:
|
||||
await self.sub(guild_id, state.session_id)
|
||||
|
||||
# only do an action for users that can
|
||||
# actually read the channel to start with.
|
||||
chan_perms = await get_permissions(
|
||||
user_id, chan_id, storage=self.main_dispatcher.app.storage
|
||||
)
|
||||
return states
|
||||
|
||||
if not chan_perms.bits.read_messages:
|
||||
log.debug("skipping cid={}, no read messages", chan_id)
|
||||
continue
|
||||
|
||||
log.debug("sending raw action {!r} to chan={}", action, chan_id)
|
||||
|
||||
# for now, only sub() has support for flags.
|
||||
# it is an idea to have flags support for other actions
|
||||
args = []
|
||||
if action == "sub":
|
||||
chanflags = dict(flags)
|
||||
|
||||
# channels don't need presence flags
|
||||
try:
|
||||
chanflags.pop("presence")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
args.append(chanflags)
|
||||
|
||||
await self.main_dispatcher.action(
|
||||
"channel", action, chan_id, user_id, *args
|
||||
)
|
||||
|
||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||
"""Call a method on the ChannelDispatcher, for all channels
|
||||
in the guild."""
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
|
||||
chan_dispatcher = self.main_dispatcher.backends["channel"]
|
||||
method = getattr(chan_dispatcher, meth)
|
||||
|
||||
for chan_id in chan_ids:
|
||||
log.debug("calling {} to chan={}", meth, chan_id)
|
||||
await method(chan_id, *args)
|
||||
|
||||
async def sub(self, guild_id: int, user_id: int, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(guild_id, user_id, flags)
|
||||
await self._chan_action("sub", guild_id, user_id, flags)
|
||||
|
||||
async def unsub(self, guild_id: int, user_id: int):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(guild_id, user_id)
|
||||
await self._chan_action("unsub", guild_id, user_id)
|
||||
|
||||
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
|
||||
"""Selectively dispatch to session ids that have
|
||||
func(session_id) true."""
|
||||
user_ids = self.state[guild_id]
|
||||
dispatched = 0
|
||||
sessions = []
|
||||
|
||||
# acquire a copy since we may be modifying
|
||||
# the original user_ids
|
||||
for user_id in set(user_ids):
|
||||
|
||||
# fetch all states / shards that are tied to the guild.
|
||||
states = self.sm.fetch_states(user_id, guild_id)
|
||||
|
||||
if not states:
|
||||
# user is actually disconnected,
|
||||
# so we should just unsub them
|
||||
await self.unsub(guild_id, user_id)
|
||||
continue
|
||||
|
||||
# skip the given subscriber if event starts with PRESENCE_
|
||||
# and the flags say they don't want it.
|
||||
|
||||
# note that this does not equate to any unsubscription
|
||||
# of the channel.
|
||||
if event.startswith("PRESENCE_") and not self.flags_get(
|
||||
guild_id, user_id, "presence", True
|
||||
async def dispatch_filter(
|
||||
self, guild_id: int, filter_function, event: GatewayEvent
|
||||
):
|
||||
session_ids = self.state[guild_id]
|
||||
sessions: List[str] = []
|
||||
event_type, _ = event
|
||||
|
||||
for session_id in set(session_ids):
|
||||
if not filter_function(session_id):
|
||||
continue
|
||||
|
||||
# filter the ones that matter
|
||||
states = list(filter(lambda state: func(state.session_id), states))
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
await self.unsub(guild_id, session_id)
|
||||
continue
|
||||
|
||||
cur_sess = await self._dispatch_states(states, event, data)
|
||||
try:
|
||||
flags = self.get_flags(guild_id, session_id)
|
||||
except KeyError:
|
||||
log.warning("no flags for {!r}, ignoring", session_id)
|
||||
flags = GuildFlags(presence=True, typing=True)
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
if event_type.lower().startswith("presence_") and not flags.presence:
|
||||
continue
|
||||
|
||||
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
|
||||
try:
|
||||
await state.ws.dispatch(*event)
|
||||
except Exception:
|
||||
log.exception("error while dispatching to {}", state.session_id)
|
||||
continue
|
||||
|
||||
sessions.append(session_id)
|
||||
|
||||
log.info("Dispatched {} {!r} to {} states", guild_id, event[0], len(sessions))
|
||||
return sessions
|
||||
|
||||
async def dispatch(self, guild_id: int, event: str, data: Any):
|
||||
async def dispatch(self, guild_id: int, event):
|
||||
"""Dispatch an event to all subscribers of the guild."""
|
||||
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)
|
||||
return await self.dispatch_filter(guild_id, lambda sess_id: True, event)
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ import asyncio
|
|||
from collections import defaultdict
|
||||
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from quart import current_app as app
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.pubsub.dispatcher import Dispatcher
|
||||
from litecord.permissions import (
|
||||
Permissions,
|
||||
overwrite_find_mix,
|
||||
|
|
@ -239,9 +239,6 @@ class GuildMemberList:
|
|||
|
||||
Attributes
|
||||
----------
|
||||
main_lg: LazyGuildDispatcher
|
||||
Main instance of :class:`LazyGuildDispatcher`,
|
||||
so that we're able to use things such as :class:`Storage`.
|
||||
guild_id: int
|
||||
The Guild ID this instance is referring to.
|
||||
channel_id: int
|
||||
|
|
@ -257,11 +254,10 @@ class GuildMemberList:
|
|||
for example, can still rely on PRESENCE_UPDATEs.
|
||||
"""
|
||||
|
||||
def __init__(self, guild_id: int, channel_id: int, main_lg):
|
||||
def __init__(self, guild_id: int, channel_id: int):
|
||||
self.guild_id = guild_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
self.main = main_lg
|
||||
self.list = MemberList()
|
||||
|
||||
#: store the states that are subscribed to the list.
|
||||
|
|
@ -273,22 +269,22 @@ class GuildMemberList:
|
|||
@property
|
||||
def loop(self):
|
||||
"""Get the main asyncio loop instance."""
|
||||
return self.main.app.loop
|
||||
return app.loop
|
||||
|
||||
@property
|
||||
def storage(self):
|
||||
"""Get the global :class:`Storage` instance."""
|
||||
return self.main.app.storage
|
||||
return app.storage
|
||||
|
||||
@property
|
||||
def presence(self):
|
||||
"""Get the global :class:`PresenceManager` instance."""
|
||||
return self.main.app.presence
|
||||
return app.presence
|
||||
|
||||
@property
|
||||
def state_man(self):
|
||||
"""Get the global :class:`StateManager` instance."""
|
||||
return self.main.app.state_manager
|
||||
return app.state_manager
|
||||
|
||||
@property
|
||||
def list_id(self):
|
||||
|
|
@ -572,8 +568,7 @@ class GuildMemberList:
|
|||
Wrapper for :meth:`StateManager.fetch_raw`
|
||||
"""
|
||||
try:
|
||||
state = self.state_man.fetch_raw(session_id)
|
||||
return state
|
||||
return self.state_man.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
|
@ -643,7 +638,7 @@ class GuildMemberList:
|
|||
|
||||
# do resync-ing in the background
|
||||
result.append(session_id)
|
||||
self.loop.create_task(self.shard_query(session_id, [role_range]))
|
||||
app.sched.spawn(self.shard_query(session_id, [role_range]))
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -683,8 +678,7 @@ class GuildMemberList:
|
|||
)
|
||||
|
||||
if everyone_perms.bits.read_messages and list_id != "everyone":
|
||||
everyone_gml = await self.main.get_gml(self.guild_id)
|
||||
|
||||
everyone_gml = await app.lazy_guild.get_gml(self.guild_id)
|
||||
return await everyone_gml.shard_query(session_id, ranges)
|
||||
|
||||
await self._init_check()
|
||||
|
|
@ -1372,47 +1366,36 @@ class GuildMemberList:
|
|||
|
||||
self.guild_id = 0
|
||||
self.channel_id = 0
|
||||
self.main = None
|
||||
self._set_empty_list()
|
||||
self.state = {}
|
||||
|
||||
|
||||
class LazyGuildDispatcher(Dispatcher):
|
||||
class LazyGuildManager:
|
||||
"""Main class holding the member lists for lazy guilds."""
|
||||
|
||||
# channel ids
|
||||
KEY_TYPE = int
|
||||
|
||||
# the session ids subscribing to channels
|
||||
VAL_TYPE = str
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
|
||||
self.storage = main.app.storage
|
||||
|
||||
def __init__(self):
|
||||
# {chan_id: gml, ...}
|
||||
self.state = {}
|
||||
self.state: Dict[int, GuildMemberList] = {}
|
||||
|
||||
#: store which guilds have their
|
||||
# respective GMLs
|
||||
# {guild_id: [chan_id, ...], ...}
|
||||
self.guild_map: Dict[int, List[int]] = defaultdict(list)
|
||||
|
||||
async def get_gml(self, channel_id: int):
|
||||
async def get_gml(self, channel_id: int) -> GuildMemberList:
|
||||
"""Get a guild list for a channel ID,
|
||||
generating it if it doesn't exist."""
|
||||
try:
|
||||
return self.state[channel_id]
|
||||
except KeyError:
|
||||
guild_id = await self.storage.guild_from_channel(channel_id)
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
|
||||
# if we don't find a guild, we just
|
||||
# set it the same as the channel.
|
||||
if not guild_id:
|
||||
guild_id = channel_id
|
||||
|
||||
gml = GuildMemberList(guild_id, channel_id, self)
|
||||
gml = GuildMemberList(guild_id, channel_id)
|
||||
self.state[channel_id] = gml
|
||||
self.guild_map[guild_id].append(channel_id)
|
||||
return gml
|
||||
|
|
@ -1437,16 +1420,6 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
gml = await self.get_gml(chan_id)
|
||||
gml.unsub(session_id)
|
||||
|
||||
async def dispatch(self, guild_id, event: str, *args, **kwargs):
|
||||
"""Call a function specialized in handling the given event"""
|
||||
try:
|
||||
handler = getattr(self, f"_handle_{event.lower()}")
|
||||
except AttributeError:
|
||||
log.warning("unknown event: {}", event)
|
||||
return
|
||||
|
||||
await handler(guild_id, *args, **kwargs)
|
||||
|
||||
def remove_channel(self, channel_id: int):
|
||||
"""Remove a channel from the manager."""
|
||||
try:
|
||||
|
|
@ -1474,29 +1447,29 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
method = getattr(lazy_list, method_str)
|
||||
await method(*args)
|
||||
|
||||
async def _handle_new_role(self, guild_id: int, new_role: dict):
|
||||
async def new_role(self, guild_id: int, new_role: dict):
|
||||
"""Handle the addition of a new group by dispatching it to
|
||||
the member lists."""
|
||||
await self._call_all_lists(guild_id, "new_role", new_role)
|
||||
|
||||
async def _handle_role_pos_upd(self, guild_id, role: dict):
|
||||
async def role_position_update(self, guild_id, role: dict):
|
||||
await self._call_all_lists(guild_id, "role_pos_update", role)
|
||||
|
||||
async def _handle_role_update(self, guild_id, role: dict):
|
||||
async def role_update(self, guild_id, role: dict):
|
||||
# handle name and hoist changes
|
||||
await self._call_all_lists(guild_id, "role_update", role)
|
||||
|
||||
async def _handle_role_delete(self, guild_id, role_id: int):
|
||||
async def role_delete(self, guild_id, role_id: int):
|
||||
await self._call_all_lists(guild_id, "role_delete", role_id)
|
||||
|
||||
async def _handle_pres_update(self, guild_id, user_id: int, partial: dict):
|
||||
async def pres_update(self, guild_id, user_id: int, partial: dict):
|
||||
await self._call_all_lists(guild_id, "pres_update", user_id, partial)
|
||||
|
||||
async def _handle_new_member(self, guild_id, user_id: int):
|
||||
async def new_member(self, guild_id, user_id: int):
|
||||
await self._call_all_lists(guild_id, "new_member", user_id)
|
||||
|
||||
async def _handle_remove_member(self, guild_id, user_id: int):
|
||||
async def remove_member(self, guild_id, user_id: int):
|
||||
await self._call_all_lists(guild_id, "remove_member", user_id)
|
||||
|
||||
async def _handle_update_user(self, guild_id, user_id: int):
|
||||
async def update_user(self, guild_id, user_id: int):
|
||||
await self._call_all_lists(guild_id, "update_user", user_id)
|
||||
|
|
|
|||
|
|
@ -17,30 +17,20 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from typing import List
|
||||
from quart import current_app as app
|
||||
from .dispatcher import GatewayEvent
|
||||
from .utils import send_event_to_states
|
||||
|
||||
|
||||
class MemberDispatcher(Dispatcher):
|
||||
"""Member backend for Pub/Sub."""
|
||||
async def dispatch_member(
|
||||
guild_id: int, user_id: int, event: GatewayEvent
|
||||
) -> List[str]:
|
||||
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||
|
||||
KEY_TYPE = tuple
|
||||
|
||||
async def dispatch(self, key, event, data):
|
||||
"""Dispatch a single event to a member.
|
||||
|
||||
This is shard-aware.
|
||||
"""
|
||||
# we don't keep any state on this dispatcher, so the key
|
||||
# is just (guild_id, user_id)
|
||||
guild_id, user_id = key
|
||||
|
||||
# fetch shards
|
||||
states = self.sm.fetch_states(user_id, guild_id)
|
||||
|
||||
# if no states were found, we should
|
||||
# unsub the user from the GUILD channel
|
||||
# if no states were found, we should unsub the user from the guild
|
||||
if not states:
|
||||
await self.main_dispatcher.unsub("guild", guild_id, user_id)
|
||||
return
|
||||
await app.dispatcher.guild.unsub(guild_id, user_id)
|
||||
return []
|
||||
|
||||
return await self._dispatch_states(states, event, data)
|
||||
return await send_event_to_states(states, event)
|
||||
|
|
|
|||
|
|
@ -17,23 +17,27 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from typing import Callable, List
|
||||
|
||||
from quart import current_app as app
|
||||
from .dispatcher import GatewayEvent
|
||||
from .utils import send_event_to_states
|
||||
|
||||
|
||||
class UserDispatcher(Dispatcher):
|
||||
"""User backend for Pub/Sub."""
|
||||
|
||||
KEY_TYPE = int
|
||||
|
||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||
"""Dispatch an event to all shards of a user."""
|
||||
|
||||
# filter only states where func() gives true
|
||||
async def dispatch_user_filter(
|
||||
user_id: int, filter_func: Callable[[str], bool], event_data: GatewayEvent
|
||||
) -> List[str]:
|
||||
"""Dispatch to a given user's states, but only for states
|
||||
where filter_func returns true."""
|
||||
states = list(
|
||||
filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
|
||||
filter(
|
||||
lambda state: filter_func(state.session_id),
|
||||
app.state_manager.user_states(user_id),
|
||||
)
|
||||
)
|
||||
|
||||
return await self._dispatch_states(states, event, data)
|
||||
return await send_event_to_states(states, event_data)
|
||||
|
||||
async def dispatch(self, user_id: int, event, data):
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||
|
||||
async def dispatch_user(user_id: int, event_data: GatewayEvent) -> List[str]:
|
||||
return await dispatch_user_filter(user_id, lambda sess_id: True, event_data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
|
||||
Litecord
|
||||
Copyright (C) 2018-2019 Luna Mendes
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3 of the License.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple, Any
|
||||
from ..gateway.state import GatewayState
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_event_to_states(
|
||||
states: List[GatewayState], event_data: Tuple[str, Any]
|
||||
) -> List[str]:
|
||||
"""Dispatch an event to a list of states."""
|
||||
res = []
|
||||
|
||||
for state in states:
|
||||
try:
|
||||
event, data = event_data
|
||||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except Exception:
|
||||
log.exception("error while dispatching")
|
||||
|
||||
return res
|
||||
|
|
@ -482,6 +482,9 @@ INVITE = {
|
|||
"required": False,
|
||||
"nullable": True,
|
||||
}, # discord client sends invite code there
|
||||
# sent by official client, unknown purpose
|
||||
"target_user_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||
"target_user_type": {"type": "number", "required": False, "nullable": True},
|
||||
}
|
||||
|
||||
USER_SETTINGS = {
|
||||
|
|
|
|||
|
|
@ -709,7 +709,9 @@ class Storage:
|
|||
|
||||
return res
|
||||
|
||||
async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict:
|
||||
async def get_guild_extra(
|
||||
self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None
|
||||
) -> Dict:
|
||||
"""Get extra information about a guild."""
|
||||
res = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -181,9 +181,6 @@ async def send_sys_message(
|
|||
raise ValueError("Invalid system message type")
|
||||
|
||||
message_id = await handler(channel_id, *args, **kwargs)
|
||||
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
|
||||
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", message))
|
||||
return message_id
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ async def maybe_lazy_guild_dispatch(
|
|||
if isinstance(role, dict) and not role["hoist"] and not force:
|
||||
return
|
||||
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
||||
await (getattr(app.lazy_guild, event))(guild_id, role)
|
||||
|
||||
|
||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||
|
|
|
|||
|
|
@ -17,11 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, List
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from litecord.voice.lvsp_conn import LVSPConnection
|
||||
|
||||
|
|
@ -42,15 +43,15 @@ class LVSPManager:
|
|||
Spawns :class:`LVSPConnection` as needed, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, app, voice):
|
||||
self.app = app
|
||||
def __init__(self, app_, voice):
|
||||
self.app = app_
|
||||
self.voice = voice
|
||||
|
||||
# map servers to LVSPConnection
|
||||
self.conns = {}
|
||||
self.conns: Dict[str, LVSPConnection] = {}
|
||||
|
||||
# maps regions to server hostnames
|
||||
self.servers = defaultdict(list)
|
||||
self.servers: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
# maps Union[GuildID, DMId, GroupDMId] to server hostnames
|
||||
self.assign = {}
|
||||
|
|
@ -84,7 +85,7 @@ class LVSPManager:
|
|||
continue
|
||||
|
||||
self.regions[region.id] = region
|
||||
self.app.loop.create_task(self._spawn_region(region))
|
||||
app.sched.spawn(self._spawn_region(region))
|
||||
|
||||
async def _spawn_region(self, region: Region):
|
||||
"""Spawn a region. Involves fetching all the hostnames
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
from typing import Tuple, Dict, List
|
||||
from collections import defaultdict
|
||||
from dataclasses import fields
|
||||
from quart import current_app as app
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -286,6 +287,6 @@ class VoiceManager:
|
|||
# slow, but it be like that, also copied from other users...
|
||||
for guild_id in guild_ids:
|
||||
guild = await self.app.storage.get_guild_full(guild_id, None)
|
||||
await self.app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
# TODO propagate the channel deprecation to LVSP connections
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def main(config):
|
|||
# as the managers require it
|
||||
# and the migrate command also sets the db up
|
||||
if argv[1] != "migrate":
|
||||
init_app_managers(app, voice=False)
|
||||
init_app_managers(app, init_voice=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
loop.run_until_complete(_ctx_wrapper(app, args))
|
||||
|
|
|
|||
8
run.py
8
run.py
|
|
@ -95,6 +95,7 @@ from litecord.images import IconManager
|
|||
from litecord.jobs import JobManager
|
||||
from litecord.voice.manager import VoiceManager
|
||||
from litecord.guild_memory_store import GuildMemoryStore
|
||||
from litecord.pubsub.lazy_guild import LazyGuildManager
|
||||
|
||||
from litecord.gateway.gateway import websocket_handler
|
||||
|
||||
|
|
@ -254,7 +255,7 @@ async def init_app_db(app_):
|
|||
app_.sched = JobManager()
|
||||
|
||||
|
||||
def init_app_managers(app_, *, voice=True):
|
||||
def init_app_managers(app_: Quart, *, init_voice=True):
|
||||
"""Initialize singleton classes."""
|
||||
app_.loop = asyncio.get_event_loop()
|
||||
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
|
||||
|
|
@ -265,7 +266,7 @@ def init_app_managers(app_, *, voice=True):
|
|||
|
||||
app_.icons = IconManager(app_)
|
||||
|
||||
app_.dispatcher = EventDispatcher(app_)
|
||||
app_.dispatcher = EventDispatcher()
|
||||
app_.presence = PresenceManager(app_)
|
||||
|
||||
app_.storage.presence = app_.presence
|
||||
|
|
@ -274,10 +275,11 @@ def init_app_managers(app_, *, voice=True):
|
|||
# we do this because of a bug on ./manage.py where it
|
||||
# cancels the LVSPManager's spawn regions task. we don't
|
||||
# need to start it on manage time.
|
||||
if voice:
|
||||
if init_voice:
|
||||
app_.voice = VoiceManager(app_)
|
||||
|
||||
app_.guild_store = GuildMemoryStore()
|
||||
app_.lazy_guild = LazyGuildManager()
|
||||
|
||||
|
||||
async def api_index(app_):
|
||||
|
|
|
|||
Loading…
Reference in New Issue