mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'feature/rewrite-dispatcher' into 'master'
dispatcher refactor Closes #84 See merge request litecord/litecord!60
This commit is contained in:
commit
f0f5570dfa
2
Pipfile
2
Pipfile
|
|
@ -18,7 +18,7 @@ zstandard = "*"
|
|||
winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"}
|
||||
|
||||
[dev-packages]
|
||||
pytest = "==5.1.2"
|
||||
pytest = "==5.3.2"
|
||||
pytest-asyncio = "==0.10.0"
|
||||
mypy = "*"
|
||||
black = "*"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3"
|
||||
"sha256": "dedc41184df539a608717e68c108ccfb4d529acb1d1702d83de223840c7cc754"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
|
|
@ -123,41 +123,36 @@
|
|||
},
|
||||
"cffi": {
|
||||
"hashes": [
|
||||
"sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42",
|
||||
"sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04",
|
||||
"sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5",
|
||||
"sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54",
|
||||
"sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba",
|
||||
"sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57",
|
||||
"sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396",
|
||||
"sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12",
|
||||
"sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97",
|
||||
"sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43",
|
||||
"sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db",
|
||||
"sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3",
|
||||
"sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b",
|
||||
"sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579",
|
||||
"sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346",
|
||||
"sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159",
|
||||
"sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652",
|
||||
"sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e",
|
||||
"sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a",
|
||||
"sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506",
|
||||
"sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f",
|
||||
"sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d",
|
||||
"sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c",
|
||||
"sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20",
|
||||
"sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858",
|
||||
"sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc",
|
||||
"sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a",
|
||||
"sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3",
|
||||
"sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e",
|
||||
"sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410",
|
||||
"sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25",
|
||||
"sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b",
|
||||
"sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d"
|
||||
"sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff",
|
||||
"sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b",
|
||||
"sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac",
|
||||
"sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0",
|
||||
"sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384",
|
||||
"sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26",
|
||||
"sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6",
|
||||
"sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b",
|
||||
"sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e",
|
||||
"sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd",
|
||||
"sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2",
|
||||
"sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66",
|
||||
"sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc",
|
||||
"sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8",
|
||||
"sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55",
|
||||
"sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4",
|
||||
"sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5",
|
||||
"sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d",
|
||||
"sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78",
|
||||
"sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa",
|
||||
"sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793",
|
||||
"sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f",
|
||||
"sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a",
|
||||
"sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f",
|
||||
"sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30",
|
||||
"sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f",
|
||||
"sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3",
|
||||
"sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c"
|
||||
],
|
||||
"version": "==1.13.2"
|
||||
"version": "==1.14.0"
|
||||
},
|
||||
"chardet": {
|
||||
"hashes": [
|
||||
|
|
@ -195,10 +190,10 @@
|
|||
},
|
||||
"h2": {
|
||||
"hashes": [
|
||||
"sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e",
|
||||
"sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4"
|
||||
"sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5",
|
||||
"sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14"
|
||||
],
|
||||
"version": "==3.1.1"
|
||||
"version": "==3.2.0"
|
||||
},
|
||||
"hpack": {
|
||||
"hashes": [
|
||||
|
|
@ -510,13 +505,6 @@
|
|||
],
|
||||
"version": "==1.4.3"
|
||||
},
|
||||
"atomicwrites": {
|
||||
"hashes": [
|
||||
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
||||
"sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"
|
||||
],
|
||||
"version": "==1.3.0"
|
||||
},
|
||||
"attrs": {
|
||||
"hashes": [
|
||||
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
|
||||
|
|
@ -647,11 +635,11 @@
|
|||
},
|
||||
"pytest": {
|
||||
"hashes": [
|
||||
"sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210",
|
||||
"sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865"
|
||||
"sha256:6b571215b5a790f9b41f19f3531c53a45cf6bb8ef2988bc1ff9afb38270b25fa",
|
||||
"sha256:e41d489ff43948babd0fad7ad5e49b8735d5d55e26628a58673c39ff61d95de4"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==5.1.2"
|
||||
"version": "==5.3.2"
|
||||
},
|
||||
"pytest-asyncio": {
|
||||
"hashes": [
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ async def _update_features(guild_id: int, features: list):
|
|||
)
|
||||
|
||||
guild = await app.storage.get_guild_full(guild_id)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
|
||||
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
|
||||
|
|
|
|||
|
|
@ -63,10 +63,10 @@ async def update_guild(guild_id: int):
|
|||
|
||||
if old_unavailable and not new_unavailable:
|
||||
# guild became available
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild))
|
||||
else:
|
||||
# guild became unavailable
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_DELETE", guild))
|
||||
|
||||
return jsonify(guild)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,13 @@ import itsdangerous
|
|||
import bcrypt
|
||||
from quart import Blueprint, jsonify, request, current_app as app
|
||||
from logbook import Logger
|
||||
from winter import get_snowflake
|
||||
|
||||
from litecord.auth import token_check
|
||||
from litecord.common.users import create_user
|
||||
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
|
||||
from litecord.errors import BadRequest
|
||||
from winter import get_snowflake
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
from .invites import use_invite
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -172,7 +173,7 @@ async def verify_user():
|
|||
)
|
||||
|
||||
new_user = await app.storage.get_user(user_id, True)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
|
||||
await dispatch_user(user_id, ("USER_UPDATE", new_user))
|
||||
|
||||
return "", 204
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ from litecord.common.messages import (
|
|||
msg_add_attachment,
|
||||
msg_guild_text_mentions,
|
||||
)
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -136,10 +137,10 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
|||
|
||||
# dispatch CHANNEL_CREATE so the client knows which
|
||||
# channel the future event is about
|
||||
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||
await dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||
|
||||
# subscribe the peer to the channel
|
||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||
await app.dispatcher.channel.sub(channel_id, peer_id)
|
||||
|
||||
# insert it on dm_channel_state so the client
|
||||
# is subscribed on the future
|
||||
|
|
@ -242,7 +243,7 @@ async def _create_message(channel_id):
|
|||
await _dm_pre_dispatch(channel_id, user_id)
|
||||
await _dm_pre_dispatch(channel_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
|
||||
|
||||
# spawn url processor for embedding of images
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
|
|
@ -350,7 +351,7 @@ async def edit_message(channel_id, message_id):
|
|||
# only dispatch MESSAGE_UPDATE if any update
|
||||
# actually happened
|
||||
if updated:
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message))
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
|
@ -427,16 +428,17 @@ async def delete_message(channel_id, message_id):
|
|||
message_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
"MESSAGE_DELETE",
|
||||
{
|
||||
"id": str(message_id),
|
||||
"channel_id": str(channel_id),
|
||||
# for lazy guilds
|
||||
"guild_id": str(guild_id),
|
||||
},
|
||||
(
|
||||
"MESSAGE_DELETE",
|
||||
{
|
||||
"id": str(message_id),
|
||||
"channel_id": str(channel_id),
|
||||
# for lazy guilds
|
||||
"guild_id": str(guild_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -106,11 +106,15 @@ async def add_pin(channel_id, message_id):
|
|||
|
||||
timestamp = snowflake_datetime(row["message_id"])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
|
||||
(
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"last_pin_timestamp": timestamp_(timestamp),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
|
|
@ -149,11 +153,15 @@ async def delete_pin(channel_id, message_id):
|
|||
|
||||
timestamp = snowflake_datetime(row["message_id"])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
|
||||
(
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"last_pin_timestamp": timestamp.isoformat(),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -141,10 +141,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
|||
if ctype in GUILD_CHANS:
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
|
||||
)
|
||||
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_ADD", payload))
|
||||
return "", 204
|
||||
|
||||
|
||||
|
|
@ -206,8 +203,8 @@ async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji
|
|||
if ctype in GUILD_CHANS:
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id, ("MESSAGE_REACTION_REMOVE", payload)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -290,6 +287,6 @@ async def remove_all_reactions(channel_id, message_id):
|
|||
if ctype in GUILD_CHANS:
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,9 +20,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
import time
|
||||
import datetime
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
from winter import snowflake_datetime
|
||||
|
||||
from litecord.auth import token_check
|
||||
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
||||
|
|
@ -41,8 +43,9 @@ from litecord.system_messages import send_sys_message
|
|||
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
||||
from litecord.utils import search_result_from_list
|
||||
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
||||
from winter import snowflake_datetime
|
||||
from litecord.common.channels import channel_ack
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
from litecord.permissions import get_permissions, Permissions
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint("channels", __name__)
|
||||
|
|
@ -80,9 +83,7 @@ async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
|||
|
||||
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
|
||||
|
||||
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
|
||||
|
||||
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
|
||||
|
||||
# if none of them were actually updated,
|
||||
|
|
@ -93,7 +94,7 @@ async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
|||
# at least one of the fields were updated,
|
||||
# dispatch GUILD_UPDATE
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
|
||||
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||
|
|
@ -104,7 +105,7 @@ async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
|||
return
|
||||
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
|
||||
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||
|
|
@ -134,7 +135,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
|||
# tell all people in the guild of the category removal
|
||||
for child_id in childs:
|
||||
child = await app.storage.get_channel(child_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", child))
|
||||
|
||||
|
||||
async def _delete_messages(channel_id):
|
||||
|
|
@ -249,12 +250,10 @@ async def close_channel(channel_id):
|
|||
)
|
||||
|
||||
# clean its member list representation
|
||||
lazy_guilds = app.dispatcher.backends["lazy_guild"]
|
||||
lazy_guilds.remove_channel(channel_id)
|
||||
app.lazy_guild.remove_channel(channel_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
await app.dispatcher.remove("channel", channel_id)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_DELETE", chan))
|
||||
await app.dispatcher.channel.drop(channel_id)
|
||||
return jsonify(chan)
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
|
|
@ -273,11 +272,9 @@ async def close_channel(channel_id):
|
|||
channel_id,
|
||||
)
|
||||
|
||||
# unsubscribe
|
||||
await app.dispatcher.unsub("channel", channel_id, user_id)
|
||||
|
||||
# nothing happens to the other party of the dm channel
|
||||
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
|
||||
await app.dispatcher.channel.unsub(channel_id, user_id)
|
||||
await dispatch_user(user_id, ("CHANNEL_DELETE", chan))
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
|
@ -318,10 +315,54 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
|||
continue
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
|
||||
async def _process_overwrites(channel_id: int, overwrites: list):
|
||||
@dataclass
|
||||
class Target:
|
||||
type: int
|
||||
user_id: Optional[int]
|
||||
role_id: Optional[int]
|
||||
|
||||
@property
|
||||
def is_user(self):
|
||||
return self.type == 0
|
||||
|
||||
@property
|
||||
def is_role(self):
|
||||
return self.type == 1
|
||||
|
||||
|
||||
async def _dispatch_action(guild_id: int, channel_id: int, user_id: int, perms) -> None:
|
||||
"""Apply an action of sub/unsub to all states of a user."""
|
||||
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||
for state in states:
|
||||
if perms.read_messages:
|
||||
await app.dispatcher.channel.sub(channel_id, state.session_id)
|
||||
else:
|
||||
await app.dispatcher.channel.unsub(channel_id, state.session_id)
|
||||
|
||||
|
||||
async def _process_overwrites(guild_id: int, channel_id: int, overwrites: list) -> None:
|
||||
# user_ids serves as a "prospect" user id list.
|
||||
# for each overwrite we apply, we fill this list with user ids we
|
||||
# want to check later to subscribe/unsubscribe from the channel.
|
||||
# (users without read_messages are automatically unsubbed since we
|
||||
# don't want to leak messages to them when they dont have perms anymore)
|
||||
|
||||
# the expensiveness of large overwrite/role chains shines here.
|
||||
# since each user id we fill in implies an entire get_permissions call
|
||||
# (because we don't have the answer if a user is to be subbed/unsubbed
|
||||
# with only overwrites, an overwrite for a user allowing them might be
|
||||
# overwritten by a role overwrite denying them if they have the role),
|
||||
# we get a lot of tension on that, causing channel updates to lag a bit.
|
||||
|
||||
# there may be some good optimizations to do here, such as doing some
|
||||
# precalculations like fetching get_permissions for everyone first, then
|
||||
# applying the new overwrites one by one, then subbing/unsubbing at the
|
||||
# end, but it would be very memory intensive.
|
||||
user_ids: List[int] = []
|
||||
|
||||
for overwrite in overwrites:
|
||||
|
||||
# 0 for member overwrite, 1 for role overwrite
|
||||
|
|
@ -329,7 +370,9 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
|||
target_role = None if target_type == 0 else overwrite["id"]
|
||||
target_user = overwrite["id"] if target_type == 0 else None
|
||||
|
||||
col_name = "target_user" if target_type == 0 else "target_role"
|
||||
target = Target(target_type, target_role, target_user)
|
||||
|
||||
col_name = "target_user" if target.is_user else "target_role"
|
||||
constraint_name = f"channel_overwrites_{col_name}_uniq"
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -352,6 +395,17 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
|||
overwrite["deny"],
|
||||
)
|
||||
|
||||
if target.is_user:
|
||||
perms = Permissions(overwrite["allow"] & ~overwrite["deny"])
|
||||
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
|
||||
|
||||
elif target.is_role:
|
||||
user_ids.extend(await app.storage.get_role_members(target.role_id))
|
||||
|
||||
for user_id in user_ids:
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
|
||||
|
||||
|
||||
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
|
||||
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||
|
|
@ -371,6 +425,7 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
|||
)
|
||||
|
||||
await _process_overwrites(
|
||||
guild_id,
|
||||
channel_id,
|
||||
[
|
||||
{
|
||||
|
|
@ -447,7 +502,7 @@ async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
|
|||
|
||||
if "channel_overwrites" in j:
|
||||
overwrites = j["channel_overwrites"]
|
||||
await _process_overwrites(channel_id, overwrites)
|
||||
await _process_overwrites(guild_id, channel_id, overwrites)
|
||||
|
||||
|
||||
async def _common_guild_chan(channel_id, j: dict):
|
||||
|
|
@ -566,9 +621,9 @@ async def update_channel(channel_id: int):
|
|||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
if is_guild:
|
||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
|
||||
else:
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
|
@ -578,17 +633,18 @@ async def trigger_typing(channel_id):
|
|||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
"TYPING_START",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"user_id": str(user_id),
|
||||
"timestamp": int(time.time()),
|
||||
# guild_id for lazy guilds
|
||||
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||
},
|
||||
(
|
||||
"TYPING_START",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"user_id": str(user_id),
|
||||
"timestamp": int(time.time()),
|
||||
# guild_id for lazy guilds
|
||||
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
@ -816,5 +872,5 @@ async def bulk_delete(channel_id: int):
|
|||
if res == "DELETE 0":
|
||||
raise BadRequest("No messages were removed")
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_DELETE_BULK", payload))
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from litecord.errors import BadRequest, Forbidden
|
|||
from winter import get_snowflake
|
||||
from litecord.system_messages import send_sys_message
|
||||
from litecord.pubsub.channel import gdm_recipient_view
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint("dm_channels", __name__)
|
||||
|
|
@ -82,11 +83,11 @@ async def gdm_create(user_id, peer_id) -> int:
|
|||
await _raw_gdm_add(channel_id, user_id)
|
||||
await _raw_gdm_add(channel_id, peer_id)
|
||||
|
||||
await app.dispatcher.sub("channel", channel_id, user_id)
|
||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||
await app.dispatcher.channel.sub(channel_id, user_id)
|
||||
await app.dispatcher.channel.sub(channel_id, peer_id)
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_CREATE", chan))
|
||||
|
||||
return channel_id
|
||||
|
||||
|
|
@ -104,13 +105,10 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
# the reasoning behind gdm_recipient_view is in its docstring.
|
||||
await app.dispatcher.dispatch(
|
||||
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
|
||||
)
|
||||
await dispatch_user(peer_id, ("CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)))
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
await app.dispatcher.sub("channel", peer_id)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
|
||||
await app.dispatcher.channel.sub(peer_id)
|
||||
|
||||
if user_id:
|
||||
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
|
||||
|
|
@ -128,17 +126,19 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
await _raw_gdm_remove(channel_id, peer_id)
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch(
|
||||
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
|
||||
)
|
||||
await dispatch_user(peer_id, ("CHANNEL_DELETE", gdm_recipient_view(chan, user_id)))
|
||||
|
||||
await app.dispatcher.unsub("channel", peer_id)
|
||||
await app.dispatcher.channel.unsub(peer_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id,
|
||||
"CHANNEL_RECIPIENT_REMOVE",
|
||||
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
|
||||
(
|
||||
"CHANNEL_RECIPIENT_REMOVE",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
author_id = peer_id if user_id is None else user_id
|
||||
|
|
@ -174,9 +174,8 @@ async def gdm_destroy(channel_id):
|
|||
channel_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
await app.dispatcher.remove("channel", channel_id)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_DELETE", chan))
|
||||
await app.dispatcher.channel.drop(channel_id)
|
||||
|
||||
|
||||
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
||||
|
|
|
|||
|
|
@ -59,19 +59,9 @@ async def create_channel(guild_id):
|
|||
new_channel_id = get_snowflake()
|
||||
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
|
||||
|
||||
# TODO: do a better method
|
||||
# subscribe the currently subscribed users to the new channel
|
||||
# by getting all user ids and subscribing each one by one.
|
||||
|
||||
# since GuildDispatcher calls Storage.get_channel_ids,
|
||||
# it will subscribe all users to the newly created channel.
|
||||
guild_pubsub = app.dispatcher.backends["guild"]
|
||||
user_ids = guild_pubsub.state[guild_id]
|
||||
for uid in user_ids:
|
||||
await app.dispatcher.sub("guild", guild_id, uid)
|
||||
|
||||
chan = await app.storage.get_channel(new_channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_CREATE", chan))
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
|
|
@ -79,7 +69,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
|||
"""Fetch new information about the channel and dispatch
|
||||
a single CHANNEL_UPDATE event to the guild."""
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
|
||||
|
||||
|
||||
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||
|
|
|
|||
|
|
@ -32,14 +32,15 @@ bp = Blueprint("guild.emoji", __name__)
|
|||
|
||||
async def _dispatch_emojis(guild_id):
|
||||
"""Dispatch a Guild Emojis Update payload to a guild."""
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_EMOJIS_UPDATE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"emojis": await app.storage.get_guild_emojis(guild_id),
|
||||
},
|
||||
(
|
||||
"GUILD_EMOJIS_UPDATE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"emojis": await app.storage.get_guild_emojis(guild_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -192,12 +192,9 @@ async def modify_guild_member(guild_id, member_id):
|
|||
if nick_flag:
|
||||
partial["nick"] = j["nick"]
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"lazy_guild", guild_id, "pres_update", user_id, partial
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||
await app.lazy_guild.pres_update(guild_id, user_id, partial)
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
@ -228,12 +225,9 @@ async def update_nickname(guild_id):
|
|||
member.pop("joined_at")
|
||||
|
||||
# call pres_update for nick changes, etc.
|
||||
await app.dispatcher.dispatch(
|
||||
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||
await app.lazy_guild.pres_update(guild_id, user_id, {"nick": j["nick"]})
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
|
||||
)
|
||||
|
||||
return j["nick"]
|
||||
|
|
|
|||
|
|
@ -86,10 +86,12 @@ async def create_ban(guild_id, member_id):
|
|||
|
||||
await remove_member(guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_BAN_ADD",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
(
|
||||
"GUILD_BAN_ADD",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
@ -115,10 +117,12 @@ async def remove_ban(guild_id, banned_id):
|
|||
if res == "DELETE 0":
|
||||
return "", 204
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_BAN_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
||||
(
|
||||
"GUILD_BAN_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
||||
),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -67,8 +67,8 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
|
|||
|
||||
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role})
|
||||
)
|
||||
|
||||
return role
|
||||
|
|
@ -304,10 +304,9 @@ async def delete_guild_role(guild_id, role_id):
|
|||
|
||||
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_ROLE_DELETE",
|
||||
{"guild_id": str(guild_id), "role_id": str(role_id)},
|
||||
("GUILD_ROLE_DELETE", {"guild_id": str(guild_id), "role_id": str(role_id)},),
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -191,8 +191,9 @@ async def create_guild():
|
|||
|
||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
|
||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
|
||||
await app.dispatcher.guild.sub_user(guild_id, user_id)
|
||||
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild_total))
|
||||
return jsonify(guild_total)
|
||||
|
||||
|
||||
|
|
@ -350,7 +351,7 @@ async def _update_guild(guild_id):
|
|||
)
|
||||
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
return jsonify(guild)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ async def _inv_check_age(inv: dict):
|
|||
await delete_invite(inv["code"])
|
||||
raise InvalidInvite("Invite is expired")
|
||||
|
||||
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
|
||||
if inv["max_uses"] != -1 and inv["uses"] > inv["max_uses"]:
|
||||
await delete_invite(inv["code"])
|
||||
raise InvalidInvite("Too many uses")
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ from ..auth import token_check
|
|||
from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND
|
||||
from ..enums import RelationshipType
|
||||
from litecord.errors import BadRequest
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
|
||||
bp = Blueprint("relationship", __name__)
|
||||
|
|
@ -36,17 +37,17 @@ async def get_me_relationships():
|
|||
|
||||
|
||||
async def _dispatch_single_pres(user_id, presence: dict):
|
||||
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
|
||||
await dispatch_user(user_id, ("PRESENCE_UPDATE", presence))
|
||||
|
||||
|
||||
async def _unsub_friend(user_id, peer_id):
|
||||
await app.dispatcher.unsub("friend", user_id, peer_id)
|
||||
await app.dispatcher.unsub("friend", peer_id, user_id)
|
||||
await app.dispatcher.friend.unsub(user_id, peer_id)
|
||||
await app.dispatcher.friend.unsub(peer_id, user_id)
|
||||
|
||||
|
||||
async def _sub_friend(user_id, peer_id):
|
||||
await app.dispatcher.sub("friend", user_id, peer_id)
|
||||
await app.dispatcher.sub("friend", peer_id, user_id)
|
||||
await app.dispatcher.friend.sub(user_id, peer_id)
|
||||
await app.dispatcher.friend.sub(peer_id, user_id)
|
||||
|
||||
# dispatch presence update to the user and peer about
|
||||
# eachother's presence.
|
||||
|
|
@ -107,8 +108,8 @@ async def make_friend(
|
|||
_friend,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user(
|
||||
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
|
||||
await dispatch_user(
|
||||
peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)})
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
@ -130,35 +131,41 @@ async def make_friend(
|
|||
_friend,
|
||||
)
|
||||
|
||||
_dispatch = app.dispatcher.dispatch_user
|
||||
_dispatch = dispatch_user
|
||||
|
||||
if existing:
|
||||
# accepted a friend request, dispatch respective
|
||||
# relationship events
|
||||
await _dispatch(
|
||||
user_id,
|
||||
"RELATIONSHIP_REMOVE",
|
||||
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
||||
(
|
||||
"RELATIONSHIP_REMOVE",
|
||||
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
||||
),
|
||||
)
|
||||
|
||||
await _dispatch(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(peer_id),
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
(
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(peer_id),
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await _dispatch(
|
||||
peer_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(user_id),
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
(
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(user_id),
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await _sub_friend(user_id, peer_id)
|
||||
|
|
@ -169,22 +176,26 @@ async def make_friend(
|
|||
if rel_type == _friend:
|
||||
await _dispatch(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(peer_id),
|
||||
"type": RelationshipType.OUTGOING.value,
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
(
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(peer_id),
|
||||
"type": RelationshipType.OUTGOING.value,
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
await _dispatch(
|
||||
peer_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(user_id),
|
||||
"type": RelationshipType.INCOMING.value,
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
(
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(user_id),
|
||||
"type": RelationshipType.INCOMING.value,
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# we don't make the pubsub link
|
||||
|
|
@ -240,7 +251,7 @@ async def add_relationship(peer_id: int):
|
|||
# make_friend did not succeed, so we
|
||||
# assume it is a block and dispatch
|
||||
# the respective RELATIONSHIP_ADD.
|
||||
await app.dispatcher.dispatch_user(
|
||||
await dispatch_user(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
|
|
@ -261,7 +272,7 @@ async def remove_relationship(peer_id: int):
|
|||
user_id = await token_check()
|
||||
_friend = RelationshipType.FRIEND.value
|
||||
_block = RelationshipType.BLOCK.value
|
||||
_dispatch = app.dispatcher.dispatch_user
|
||||
_dispatch = dispatch_user
|
||||
|
||||
rel_type = await app.db.fetchval(
|
||||
"""
|
||||
|
|
@ -307,7 +318,8 @@ async def remove_relationship(peer_id: int):
|
|||
)
|
||||
|
||||
await _dispatch(
|
||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
|
||||
user_id,
|
||||
("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}),
|
||||
)
|
||||
|
||||
peer_del_type = (
|
||||
|
|
@ -315,7 +327,8 @@ async def remove_relationship(peer_id: int):
|
|||
)
|
||||
|
||||
await _dispatch(
|
||||
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
|
||||
peer_id,
|
||||
("RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}),
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
@ -334,7 +347,7 @@ async def remove_relationship(peer_id: int):
|
|||
)
|
||||
|
||||
await _dispatch(
|
||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
|
||||
user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block})
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from quart import Blueprint, jsonify, request, current_app as app
|
|||
from litecord.auth import token_check
|
||||
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
||||
from litecord.blueprints.checks import guild_check
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
bp = Blueprint("users_settings", __name__)
|
||||
|
||||
|
|
@ -58,7 +59,7 @@ async def patch_current_settings():
|
|||
)
|
||||
|
||||
settings = await app.user_storage.get_user_settings(user_id)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
|
||||
await dispatch_user(user_id, ("USER_SETTINGS_UPDATE", settings))
|
||||
return jsonify(settings)
|
||||
|
||||
|
||||
|
|
@ -123,7 +124,7 @@ async def patch_guild_settings(guild_id: int):
|
|||
|
||||
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
|
||||
await dispatch_user(user_id, ("USER_GUILD_SETTINGS_UPDATE", settings))
|
||||
|
||||
return jsonify(settings)
|
||||
|
||||
|
|
@ -157,8 +158,8 @@ async def put_note(target_id: int):
|
|||
note,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
|
||||
await dispatch_user(
|
||||
user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note})
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -156,11 +156,12 @@ async def webhook_token_check(webhook_id: int, webhook_token: str):
|
|||
|
||||
|
||||
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"WEBHOOKS_UPDATE",
|
||||
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||
(
|
||||
"WEBHOOKS_UPDATE",
|
||||
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -503,7 +504,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
|||
|
||||
payload = await app.storage.get_message(message_id)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
|
||||
|
||||
# spawn embedder in the background, even when we're on a webhook.
|
||||
app.sched.spawn(process_url_embed(payload))
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ from quart import current_app as app
|
|||
|
||||
from litecord.errors import ForbiddenDM
|
||||
from litecord.enums import RelationshipType
|
||||
from litecord.pubsub.member import dispatch_member
|
||||
from litecord.pubsub.user import dispatch_user
|
||||
|
||||
|
||||
async def channel_ack(
|
||||
|
|
@ -54,20 +56,24 @@ async def channel_ack(
|
|||
)
|
||||
|
||||
if guild_id:
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
user_id,
|
||||
await dispatch_member(
|
||||
guild_id,
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
user_id,
|
||||
(
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
),
|
||||
)
|
||||
else:
|
||||
# we don't use ChannelDispatcher here because since
|
||||
# guild_id is None, all user devices are already subscribed
|
||||
# to the given channel (a dm or a group dm)
|
||||
await app.dispatcher.dispatch_user(
|
||||
await dispatch_user(
|
||||
user_id,
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
(
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from ..snowflake import get_snowflake
|
||||
from ..permissions import get_role_perms
|
||||
from ..permissions import get_role_perms, get_permissions
|
||||
from ..utils import dict_get, maybe_lazy_guild_dispatch
|
||||
from ..enums import ChannelType
|
||||
from litecord.pubsub.member import dispatch_member
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -41,21 +43,20 @@ async def remove_member(guild_id: int, member_id: int):
|
|||
member_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
member_id,
|
||||
await dispatch_member(
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{"guild_id": str(guild_id), "unavailable": False},
|
||||
member_id,
|
||||
("GUILD_DELETE", {"guild_id": str(guild_id), "unavailable": False}),
|
||||
)
|
||||
|
||||
await app.dispatcher.unsub("guild", guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
await app.dispatcher.guild.unsub(guild_id, member_id)
|
||||
await app.lazy_guild.remove_member(member_id)
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_MEMBER_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
(
|
||||
"GUILD_MEMBER_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -108,8 +109,8 @@ async def create_role(guild_id, name: str, **kwargs):
|
|||
# we need to update the lazy guild handlers for the newly created group
|
||||
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role})
|
||||
)
|
||||
|
||||
return role
|
||||
|
|
@ -137,6 +138,39 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
|
|||
)
|
||||
|
||||
|
||||
async def _subscribe_users_new_channel(guild_id: int, channel_id: int) -> None:
|
||||
|
||||
# for each state currently subscribed to guild, we check on the database
|
||||
# which states can also subscribe to the new channel at its creation.
|
||||
|
||||
# the list of users that can subscribe are then used again for a pass
|
||||
# over the states and states that have user ids in that list become
|
||||
# subscribers of the new channel.
|
||||
users_to_sub: List[str] = []
|
||||
|
||||
for session_id in app.dispatcher.guild.state[guild_id]:
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if state.user_id in users_to_sub:
|
||||
continue
|
||||
|
||||
perms = await get_permissions(state.user_id, channel_id)
|
||||
if perms.read_messages:
|
||||
users_to_sub.append(state.user_id)
|
||||
|
||||
for session_id in app.dispatcher.guild.state[guild_id]:
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if state.user_id in users_to_sub:
|
||||
await app.dispatcher.channel.sub(channel_id, session_id)
|
||||
|
||||
|
||||
async def create_guild_channel(
|
||||
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||
):
|
||||
|
|
@ -180,6 +214,8 @@ async def create_guild_channel(
|
|||
# so we use this function.
|
||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||
|
||||
await _subscribe_users_new_channel(guild_id, channel_id)
|
||||
|
||||
|
||||
async def _del_from_table(table: str, user_id: int):
|
||||
"""Delete a row from a table."""
|
||||
|
|
@ -206,21 +242,22 @@ async def delete_guild(guild_id: int):
|
|||
)
|
||||
|
||||
# Discord's client expects IDs being string
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"id": str(guild_id),
|
||||
# 'unavailable': False,
|
||||
},
|
||||
(
|
||||
"GUILD_DELETE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"id": str(guild_id),
|
||||
# 'unavailable': False,
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# remove from the dispatcher so nobody
|
||||
# becomes the little memer that tries to fuck up with
|
||||
# everybody's gateway
|
||||
await app.dispatcher.remove("guild", guild_id)
|
||||
await app.dispatcher.guild.drop(guild_id)
|
||||
|
||||
|
||||
async def create_guild_settings(guild_id: int, user_id: int):
|
||||
|
|
@ -285,18 +322,17 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
|
|||
|
||||
# tell current members a new member came up
|
||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
|
||||
await app.dispatcher.guild.dispatch(
|
||||
guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}})
|
||||
)
|
||||
|
||||
# update member lists for the new member
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
|
||||
# pubsub changes for new member
|
||||
await app.lazy_guild.new_member(guild_id, user_id)
|
||||
states = await app.dispatcher.guild.sub_user(guild_id, user_id)
|
||||
|
||||
# subscribe new member to guild, so they get events n stuff
|
||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||
|
||||
# tell the new member that theres the guild it just joined.
|
||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
||||
# just to the shards that are actually tied to it.
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
|
||||
for state in states:
|
||||
try:
|
||||
await state.ws.dispatch("GUILD_CREATE", guild)
|
||||
except Exception:
|
||||
log.exception("failed to dispatch to session_id={!r}", state.session_id)
|
||||
|
|
|
|||
|
|
@ -29,41 +29,48 @@ from ..snowflake import get_snowflake
|
|||
from ..errors import BadRequest
|
||||
from ..auth import hash_data
|
||||
from ..utils import rand_hex
|
||||
from ..pubsub.user import dispatch_user
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def mass_user_update(user_id: int):
|
||||
"""Dispatch USER_UPDATE in a mass way."""
|
||||
# by using dispatch_with_filter
|
||||
# we're guaranteeing all shards will get
|
||||
# a USER_UPDATE once and not any others.
|
||||
async def mass_user_update(user_id: int) -> Tuple[dict, dict]:
|
||||
"""Dispatch a USER_UPDATE to everyone that is subscribed to the user.
|
||||
|
||||
This function guarantees all states will get one USER_UPDATE for simple
|
||||
cases. Lazy guild users might get updates N times depending of how many
|
||||
lists are they subscribed to.
|
||||
"""
|
||||
session_ids: List[str] = []
|
||||
|
||||
public_user = await app.storage.get_user(user_id)
|
||||
private_user = await app.storage.get_user(user_id, secure=True)
|
||||
|
||||
session_ids.extend(
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
||||
)
|
||||
session_ids.extend(await dispatch_user(user_id, ("USER_UPDATE", private_user)))
|
||||
|
||||
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
||||
friend_ids = await app.user_storage.get_friend_ids(user_id)
|
||||
guild_ids: List[int] = await app.user_storage.get_user_guilds(user_id)
|
||||
friend_ids: List[int] = await app.user_storage.get_friend_ids(user_id)
|
||||
|
||||
session_ids.extend(
|
||||
await app.dispatcher.dispatch_many_filter_list(
|
||||
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
||||
for guild_id in guild_ids:
|
||||
session_ids.extend(
|
||||
await app.dispatcher.guild.dispatch_filter(
|
||||
guild_id,
|
||||
lambda sess_id: sess_id not in session_ids,
|
||||
("USER_UPDATE", public_user),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
session_ids.extend(
|
||||
await app.dispatcher.dispatch_many_filter_list(
|
||||
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
||||
for friend_id in friend_ids:
|
||||
session_ids.extend(
|
||||
await app.dispatcher.friend.dispatch_filter(
|
||||
friend_id,
|
||||
lambda sess_id: sess_id not in session_ids,
|
||||
("USER_UPDATE", public_user),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
||||
for guild_id in guild_ids:
|
||||
await app.lazy_guild.update_user(guild_id, user_id)
|
||||
|
||||
return public_user, private_user
|
||||
|
||||
|
|
|
|||
|
|
@ -17,17 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from .pubsub import (
|
||||
GuildDispatcher,
|
||||
MemberDispatcher,
|
||||
UserDispatcher,
|
||||
ChannelDispatcher,
|
||||
FriendDispatcher,
|
||||
LazyGuildDispatcher,
|
||||
)
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -50,169 +45,7 @@ class EventDispatcher:
|
|||
its subscriber ids.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.state_manager = app.state_manager
|
||||
self.app = app
|
||||
|
||||
self.backends = {
|
||||
"guild": GuildDispatcher(self),
|
||||
"member": MemberDispatcher(self),
|
||||
"channel": ChannelDispatcher(self),
|
||||
"user": UserDispatcher(self),
|
||||
"friend": FriendDispatcher(self),
|
||||
"lazy_guild": LazyGuildDispatcher(self),
|
||||
}
|
||||
|
||||
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
||||
"""Send an action regarding a key/identifier pair to a backend.
|
||||
|
||||
Action is usually "sub" or "unsub".
|
||||
"""
|
||||
backend = self.backends[backend_str]
|
||||
method = getattr(backend, action)
|
||||
|
||||
# convert keys to the types the backend wants
|
||||
key = backend.KEY_TYPE(key)
|
||||
identifier = backend.VAL_TYPE(identifier)
|
||||
|
||||
return await method(key, identifier, *args)
|
||||
|
||||
async def subscribe(
|
||||
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
|
||||
):
|
||||
"""Subscribe a single element to the given backend."""
|
||||
flags = flags or {}
|
||||
|
||||
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
|
||||
|
||||
# this is a hacky solution for backwards compatibility between backends
|
||||
# that implement flags and backends that don't.
|
||||
|
||||
# passing flags to backends that don't implement flags will
|
||||
# cause errors as expected.
|
||||
if flags:
|
||||
return await self.action(backend, "sub", key, identifier, flags)
|
||||
|
||||
return await self.action(backend, "sub", key, identifier)
|
||||
|
||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||
"""Unsubscribe an element from the given backend."""
|
||||
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
|
||||
|
||||
return await self.action(backend, "unsub", key, identifier)
|
||||
|
||||
async def sub(self, backend, key, identifier):
|
||||
"""Alias to subscribe()."""
|
||||
return await self.subscribe(backend, key, identifier)
|
||||
|
||||
async def unsub(self, backend, key, identifier):
|
||||
"""Alias to unsubscribe()."""
|
||||
return await self.unsubscribe(backend, key, identifier)
|
||||
|
||||
async def sub_many(
|
||||
self,
|
||||
backend_str: str,
|
||||
identifier: Any,
|
||||
keys: list,
|
||||
flags: Dict[str, Any] = None,
|
||||
):
|
||||
"""Subscribe to multiple channels (all in a single backend)
|
||||
at a time.
|
||||
|
||||
Usually used when connecting to the gateway and the client
|
||||
needs to subscribe to all their guids.
|
||||
"""
|
||||
flags = flags or {}
|
||||
for key in keys:
|
||||
await self.subscribe(backend_str, key, identifier, flags)
|
||||
|
||||
async def mass_sub(self, identifier: Any, backends: List[tuple]):
|
||||
"""Mass subscribe to many backends at once."""
|
||||
for bcall in backends:
|
||||
backend_str, keys = bcall[0], bcall[1]
|
||||
|
||||
if len(bcall) == 2:
|
||||
flags = {}
|
||||
elif len(bcall) == 3:
|
||||
# we have flags
|
||||
flags = bcall[2]
|
||||
|
||||
log.debug(
|
||||
"subscribing {} to {} keys in backend {}, flags: {}",
|
||||
identifier,
|
||||
len(keys),
|
||||
backend_str,
|
||||
flags,
|
||||
)
|
||||
|
||||
await self.sub_many(backend_str, identifier, keys, flags)
|
||||
|
||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
||||
"""Dispatch an event to the backend.
|
||||
|
||||
The backend is responsible for everything regarding the
|
||||
actual dispatch.
|
||||
"""
|
||||
backend = self.backends[backend_str]
|
||||
|
||||
# convert types
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch(key, *args, **kwargs)
|
||||
|
||||
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
|
||||
"""Dispatch to multiple keys in a single backend."""
|
||||
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
|
||||
|
||||
for key in keys:
|
||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
||||
|
||||
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
|
||||
"""Dispatch to a backend that only accepts
|
||||
(event, data) arguments with an optional filter
|
||||
function."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch_filter(key, func, *args)
|
||||
|
||||
async def dispatch_many_filter_list(
|
||||
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
|
||||
):
|
||||
"""Make a "unique" dispatch given a list of session ids.
|
||||
|
||||
This only works for backends that have a dispatch_filter
|
||||
handler and return session id lists in their dispatch
|
||||
results.
|
||||
"""
|
||||
for key in keys:
|
||||
sess_list.extend(
|
||||
await self.dispatch_filter(
|
||||
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
|
||||
)
|
||||
)
|
||||
|
||||
return sess_list
|
||||
|
||||
async def reset(self, backend_str: str, key: Any):
|
||||
"""Reset the bucket in the given backend."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.reset(key)
|
||||
|
||||
async def remove(self, backend_str: str, key: Any):
|
||||
"""Remove a key from the backend. This
|
||||
might be a different operation than resetting."""
|
||||
backend = self.backends[backend_str]
|
||||
key = backend.KEY_TYPE(key)
|
||||
return await backend.remove(key)
|
||||
|
||||
async def dispatch_guild(self, guild_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch("guild", guild_id, event, data)
|
||||
|
||||
async def dispatch_user_guild(self, user_id, guild_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch("member", (guild_id, user_id), event, data)
|
||||
|
||||
async def dispatch_user(self, user_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch("user", user_id, event, data)
|
||||
def __init__(self):
|
||||
self.guild: GuildDispatcher = GuildDispatcher()
|
||||
self.channel = ChannelDispatcher()
|
||||
self.friend = FriendDispatcher()
|
||||
|
|
|
|||
|
|
@ -87,8 +87,8 @@ async def msg_update_embeds(payload, new_embeds):
|
|||
if "flags" in payload:
|
||||
update_payload["flags"] = payload["flags"]
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
"channel", channel_id, "MESSAGE_UPDATE", update_payload
|
||||
await app.dispatcher.channel.dispatch(
|
||||
channel_id, ("MESSAGE_UPDATE", update_payload)
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from litecord.gateway.websocket import GatewayWebsocket
|
||||
|
||||
|
||||
|
|
@ -44,17 +45,18 @@ async def websocket_handler(app, ws, url):
|
|||
return await ws.close(1000, "Invalid gateway encoding")
|
||||
|
||||
try:
|
||||
gw_compress = args["compress"][0]
|
||||
gw_compress: Optional[str] = args["compress"][0]
|
||||
except (KeyError, IndexError):
|
||||
gw_compress = None
|
||||
|
||||
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
|
||||
return await ws.close(1000, "Invalid gateway compress")
|
||||
|
||||
gws = GatewayWebsocket(
|
||||
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||
)
|
||||
async with app.app_context():
|
||||
gws = GatewayWebsocket(
|
||||
ws, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||
)
|
||||
|
||||
# this can be run with a single await since this whole coroutine
|
||||
# is already running in the background.
|
||||
await gws.run()
|
||||
# this can be run with a single await since this whole coroutine
|
||||
# is already running in the background.
|
||||
await gws.run()
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import asyncio
|
|||
from typing import List
|
||||
from collections import defaultdict
|
||||
|
||||
from quart import current_app as app
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -225,3 +226,16 @@ class StateManager:
|
|||
def close(self):
|
||||
"""Close the state manager."""
|
||||
self.closed = True
|
||||
|
||||
async def fetch_user_states_for_channel(
|
||||
self, channel_id: int, user_id: int
|
||||
) -> List[GatewayState]:
|
||||
"""Get a list of gateway states for a user that can receive events on a certain channel."""
|
||||
# TODO optimize this with an in-memory store
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
|
||||
if guild_id:
|
||||
return self.fetch_states(user_id, guild_id)
|
||||
|
||||
# DMs and GDMs use all user states
|
||||
return self.user_states(user_id)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from random import randint
|
|||
import websockets
|
||||
import zstandard as zstd
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType
|
||||
|
|
@ -43,7 +44,6 @@ from litecord.presence import BasePresence
|
|||
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
|
||||
from litecord.gateway.errors import (
|
||||
DecodeError,
|
||||
|
|
@ -52,8 +52,9 @@ from litecord.gateway.errors import (
|
|||
ShardingRequired,
|
||||
)
|
||||
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
||||
|
||||
from litecord.gateway.utils import WebsocketFileHandler
|
||||
from litecord.pubsub.guild import GuildFlags
|
||||
from litecord.pubsub.channel import ChannelFlags
|
||||
|
||||
from litecord.storage import int_
|
||||
|
||||
|
|
@ -67,7 +68,7 @@ WebsocketProperties = collections.namedtuple(
|
|||
class GatewayWebsocket:
|
||||
"""Main gateway websocket logic."""
|
||||
|
||||
def __init__(self, ws, app, **kwargs):
|
||||
def __init__(self, ws, **kwargs):
|
||||
self.app = app
|
||||
self.storage = app.storage
|
||||
self.user_storage = app.user_storage
|
||||
|
|
@ -230,7 +231,7 @@ class GatewayWebsocket:
|
|||
if task:
|
||||
task.cancel()
|
||||
|
||||
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
|
||||
self.wsp.tasks["heartbeat"] = app.sched.spawn(
|
||||
task_wrapper("hb wait", self._hb_wait(interval))
|
||||
)
|
||||
|
||||
|
|
@ -247,6 +248,7 @@ class GatewayWebsocket:
|
|||
|
||||
async def dispatch(self, event: str, data: Any):
|
||||
"""Dispatch an event to the websocket."""
|
||||
assert self.state is not None
|
||||
self.state.seq += 1
|
||||
|
||||
payload = {
|
||||
|
|
@ -282,6 +284,7 @@ class GatewayWebsocket:
|
|||
|
||||
async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
|
||||
"""Dispatch GUILD_CREATE information."""
|
||||
assert self.state is not None
|
||||
|
||||
# Users don't get asynchronous guild dispatching.
|
||||
if not self.state.bot:
|
||||
|
|
@ -360,9 +363,7 @@ class GatewayWebsocket:
|
|||
}
|
||||
|
||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||
|
||||
# async dispatch of guilds
|
||||
self.app.loop.create_task(self._guild_dispatch(guilds))
|
||||
app.sched.spawn(self._guild_dispatch(guilds))
|
||||
|
||||
async def _check_shards(self, shard, user_id):
|
||||
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
||||
|
|
@ -412,6 +413,7 @@ class GatewayWebsocket:
|
|||
Note: subscribing to channels is already handled
|
||||
by GuildDispatcher.sub
|
||||
"""
|
||||
assert self.state is not None
|
||||
user_id = self.state.user_id
|
||||
guild_ids = await self._guild_ids()
|
||||
|
||||
|
|
@ -434,26 +436,50 @@ class GatewayWebsocket:
|
|||
# (presence and typing events)
|
||||
|
||||
# we enable processing of guild_subscriptions by adding flags
|
||||
# when subscribing to the given backend. those are optional.
|
||||
channels_to_sub = [
|
||||
(
|
||||
"guild",
|
||||
guild_ids,
|
||||
{"presence": guild_subscriptions, "typing": guild_subscriptions},
|
||||
),
|
||||
("channel", dm_ids),
|
||||
("channel", gdm_ids),
|
||||
]
|
||||
# when subscribing to the given backend.
|
||||
session_id = self.state.session_id
|
||||
channel_ids: List[int] = []
|
||||
|
||||
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||
for guild_id in guild_ids:
|
||||
await app.dispatcher.guild.sub_with_flags(
|
||||
guild_id,
|
||||
session_id,
|
||||
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
|
||||
)
|
||||
|
||||
# instead of calculating which channels to subscribe to
|
||||
# inside guild dispatcher, we calculate them in here, so that
|
||||
# we remove complexity of the dispatcher.
|
||||
|
||||
guild_chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||
for channel_id in guild_chan_ids:
|
||||
perms = await get_permissions(
|
||||
self.state.user_id, channel_id, storage=self.storage
|
||||
)
|
||||
|
||||
if perms.bits.read_messages:
|
||||
channel_ids.append(channel_id)
|
||||
|
||||
log.info("subscribing to {} guild channels", len(channel_ids))
|
||||
for channel_id in channel_ids:
|
||||
await app.dispatcher.channel.sub_with_flags(
|
||||
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
|
||||
)
|
||||
|
||||
for dm_id in dm_ids:
|
||||
await app.dispatcher.channel.sub(dm_id, session_id)
|
||||
|
||||
for gdm_id in gdm_ids:
|
||||
await app.dispatcher.channel.sub(gdm_id, session_id)
|
||||
|
||||
# subscribe to all friends
|
||||
# (their friends will also subscribe back
|
||||
# when they come online)
|
||||
if not self.state.bot:
|
||||
# subscribe to all friends
|
||||
# (their friends will also subscribe back
|
||||
# when they come online)
|
||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||
log.info("subscribing to {} friends", len(friend_ids))
|
||||
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||
for friend_id in friend_ids:
|
||||
await app.dispatcher.friend.sub(user_id, friend_id)
|
||||
|
||||
async def update_status(self, incoming_status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
|
|
@ -921,6 +947,7 @@ class GatewayWebsocket:
|
|||
]
|
||||
}
|
||||
"""
|
||||
assert self.state is not None
|
||||
data = payload["d"]
|
||||
|
||||
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||
|
|
@ -933,11 +960,9 @@ class GatewayWebsocket:
|
|||
log.debug("lazy request: members: {}", data.get("members", []))
|
||||
|
||||
# make shard query
|
||||
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
|
||||
|
||||
for chan_id, ranges in data.get("channels", {}).items():
|
||||
chan_id = int(chan_id)
|
||||
member_list = await lazy_guilds.get_gml(chan_id)
|
||||
member_list = await app.lazy_guild.get_gml(chan_id)
|
||||
|
||||
perms = await get_permissions(
|
||||
self.state.user_id, chan_id, storage=self.storage
|
||||
|
|
|
|||
|
|
@ -93,7 +93,6 @@ class PresenceManager:
|
|||
self.storage = app.storage
|
||||
self.user_storage = app.user_storage
|
||||
self.state_manager = app.state_manager
|
||||
self.dispatcher = app.dispatcher
|
||||
|
||||
async def guild_presences(
|
||||
self, member_ids: List[int], guild_id: int
|
||||
|
|
@ -127,8 +126,7 @@ class PresenceManager:
|
|||
|
||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||
|
||||
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||
lists = app.lazy_guild.get_gml_guild(guild_id)
|
||||
|
||||
# shards that are in lazy guilds with 'everyone'
|
||||
# enabled
|
||||
|
|
@ -163,20 +161,21 @@ class PresenceManager:
|
|||
# given a session id, return if the session id actually connects to
|
||||
# a given user, and if the state has not been dispatched via lazy guild.
|
||||
def _session_check(session_id):
|
||||
state = self.state_manager.fetch_raw(session_id)
|
||||
uid = int(member["user"]["id"])
|
||||
|
||||
if not state:
|
||||
try:
|
||||
state = self.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
uid = int(member["user"]["id"])
|
||||
|
||||
# we don't want to send a presence update
|
||||
# to the same user
|
||||
return state.user_id != uid and session_id not in in_lazy
|
||||
|
||||
# everyone not in lazy guild mode
|
||||
# gets a PRESENCE_UPDATE
|
||||
await self.dispatcher.dispatch_filter(
|
||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
|
||||
await app.dispatcher.guild.dispatch_filter(
|
||||
guild_id, _session_check, ("PRESENCE_UPDATE", event_payload)
|
||||
)
|
||||
|
||||
return in_lazy
|
||||
|
|
@ -193,11 +192,8 @@ class PresenceManager:
|
|||
|
||||
# dispatch to all friends that are subscribed to them
|
||||
user = await self.storage.get_user(user_id)
|
||||
await self.dispatcher.dispatch(
|
||||
"friend",
|
||||
user_id,
|
||||
"PRESENCE_UPDATE",
|
||||
{**presence.partial_dict, **{"user": user}},
|
||||
await app.dispatcher.friend.dispatch(
|
||||
user_id, ("PRESENCE_UPDATE", {**presence.partial_dict, **{"user": user}}),
|
||||
)
|
||||
|
||||
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
|
||||
|
|
|
|||
|
|
@ -18,17 +18,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from .guild import GuildDispatcher
|
||||
from .member import MemberDispatcher
|
||||
from .user import UserDispatcher
|
||||
from .member import dispatch_member
|
||||
from .user import dispatch_user
|
||||
from .channel import ChannelDispatcher
|
||||
from .friend import FriendDispatcher
|
||||
from .lazy_guild import LazyGuildDispatcher
|
||||
|
||||
__all__ = [
|
||||
"GuildDispatcher",
|
||||
"MemberDispatcher",
|
||||
"UserDispatcher",
|
||||
"dispatch_member",
|
||||
"dispatch_user",
|
||||
"ChannelDispatcher",
|
||||
"FriendDispatcher",
|
||||
"LazyGuildDispatcher",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithFlags
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.utils import index_by_func
|
||||
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -37,75 +39,66 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
|||
"""
|
||||
# make a copy or the original channel object
|
||||
data = dict(orig)
|
||||
|
||||
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
|
||||
|
||||
data["recipients"].pop(idx)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class ChannelDispatcher(DispatcherWithFlags):
|
||||
"""Main channel Pub/Sub logic."""
|
||||
@dataclass
|
||||
class ChannelFlags:
|
||||
typing: bool
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
|
||||
class ChannelDispatcher(
|
||||
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
|
||||
):
|
||||
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
|
||||
|
||||
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
|
||||
"""Dispatch an event to a channel."""
|
||||
# get everyone who is subscribed
|
||||
# and store the number of states we dispatched the event to
|
||||
user_ids = self.state[channel_id]
|
||||
dispatched = 0
|
||||
session_ids = set(self.state[channel_id])
|
||||
sessions: List[str] = []
|
||||
|
||||
# making a copy of user_ids since
|
||||
# we'll modify it later on.
|
||||
for user_id in set(user_ids):
|
||||
guild_id = await self.app.storage.guild_from_channel(channel_id)
|
||||
event_type, event_data = event
|
||||
assert isinstance(event_data, dict)
|
||||
|
||||
# if we are dispatching to a guild channel,
|
||||
# we should only dispatch to the states / shards
|
||||
# that are connected to the guild (via their shard id).
|
||||
|
||||
# if we aren't, we just get all states tied to the user.
|
||||
# TODO: make a fetch_states that fetches shards
|
||||
# - with id 0 (count any) OR
|
||||
# - single shards (id=0, count=1)
|
||||
states = (
|
||||
self.sm.fetch_states(user_id, guild_id)
|
||||
if guild_id
|
||||
else self.sm.user_states(user_id)
|
||||
)
|
||||
|
||||
# unsub people who don't have any states tied to the channel.
|
||||
if not states:
|
||||
await self.unsub(channel_id, user_id)
|
||||
for session_id in session_ids:
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
await self.unsub(channel_id, session_id)
|
||||
continue
|
||||
|
||||
# skip typing events for users that don't want it
|
||||
if event.startswith("TYPING_") and not self.flags_get(
|
||||
channel_id, user_id, "typing", True
|
||||
):
|
||||
try:
|
||||
flags = self.get_flags(channel_id, session_id)
|
||||
except KeyError:
|
||||
log.warning("no flags for {!r}, ignoring", session_id)
|
||||
flags = ChannelFlags(typing=True)
|
||||
|
||||
if event_type.lower().startswith("typing_") and not flags.typing:
|
||||
continue
|
||||
|
||||
cur_sess: List[str] = []
|
||||
|
||||
correct_event = event
|
||||
# for cases where we are talking about group dms, we create an edited
|
||||
# event data so that it doesn't show the user we're dispatching
|
||||
# to in data.recipients (clients already assume they are recipients)
|
||||
if (
|
||||
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||
and data.get("type") == ChannelType.GROUP_DM.value
|
||||
event_type in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||
and event_data.get("type") == ChannelType.GROUP_DM.value
|
||||
):
|
||||
# we edit the channel payload so it doesn't show
|
||||
# the user as a recipient
|
||||
new_data = gdm_recipient_view(event_data, state.user_id)
|
||||
correct_event = (event_type, new_data)
|
||||
|
||||
new_data = gdm_recipient_view(data, user_id)
|
||||
cur_sess = await self._dispatch_states(states, event, new_data)
|
||||
else:
|
||||
cur_sess = await self._dispatch_states(states, event, data)
|
||||
try:
|
||||
await state.ws.dispatch(*correct_event)
|
||||
except Exception:
|
||||
log.exception("error while dispatching to {}", state.session_id)
|
||||
continue
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
sessions.append(session_id)
|
||||
|
||||
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
|
||||
log.info(
|
||||
"Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions)
|
||||
)
|
||||
|
||||
return sessions
|
||||
|
|
|
|||
|
|
@ -17,7 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import (
|
||||
List,
|
||||
Generic,
|
||||
TypeVar,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Set,
|
||||
Mapping,
|
||||
Iterable,
|
||||
Tuple,
|
||||
)
|
||||
from collections import defaultdict
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -25,79 +36,63 @@ from logbook import Logger
|
|||
log = Logger(__name__)
|
||||
|
||||
|
||||
def _identity(_self, x):
|
||||
return x
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
F = TypeVar("F")
|
||||
EventType = TypeVar("EventType")
|
||||
DispatchType = TypeVar("DispatchType")
|
||||
F_Map = Mapping[V, F]
|
||||
|
||||
GatewayEvent = Tuple[str, Any]
|
||||
|
||||
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
|
||||
|
||||
|
||||
class Dispatcher:
|
||||
class Dispatcher(Generic[K, V, EventType, DispatchType]):
|
||||
"""Pub/Sub backend dispatcher.
|
||||
|
||||
This just declares functions all Dispatcher subclasses
|
||||
can implement. This does not mean all Dispatcher
|
||||
subclasses have them implemented.
|
||||
Classes must implement this protocol.
|
||||
"""
|
||||
|
||||
KEY_TYPE = _identity
|
||||
VAL_TYPE = _identity
|
||||
async def sub(self, key: K, identifier: V) -> None:
|
||||
"""Subscribe a given identifier to a given key."""
|
||||
...
|
||||
|
||||
def __init__(self, main):
|
||||
#: main EventDispatcher
|
||||
self.main_dispatcher = main
|
||||
async def sub_many(self, key: K, identifier_list: Iterable[V]) -> None:
|
||||
for identifier in identifier_list:
|
||||
await self.sub(key, identifier)
|
||||
|
||||
#: gateway state storage
|
||||
self.sm = main.state_manager
|
||||
async def unsub(self, key: K, identifier: V) -> None:
|
||||
"""Unsubscribe a given identifier to a given key."""
|
||||
...
|
||||
|
||||
self.app = main.app
|
||||
async def dispatch(self, key: K, event: EventType) -> DispatchType:
|
||||
...
|
||||
|
||||
async def sub(self, _key, _id):
|
||||
"""Subscribe an element to the channel/key."""
|
||||
raise NotImplementedError
|
||||
async def dispatch_many(self, keys: List[K], *args: Any, **kwargs: Any) -> None:
|
||||
log.info("MULTI DISPATCH in {!r}, {} keys", self, len(keys))
|
||||
for key in keys:
|
||||
await self.dispatch(key, *args, **kwargs)
|
||||
|
||||
async def unsub(self, _key, _id):
|
||||
"""Unsubscribe an elemtnt from the channel/key."""
|
||||
raise NotImplementedError
|
||||
async def drop(self, key: K) -> None:
|
||||
"""Drop a key."""
|
||||
...
|
||||
|
||||
async def dispatch_filter(self, _key, _func, *_args):
|
||||
"""Selectively dispatch to the list of subscribed users.
|
||||
async def clear(self, key: K) -> None:
|
||||
"""Clear a key from the backend."""
|
||||
...
|
||||
|
||||
The selection logic is completly arbitraty and up to the
|
||||
Pub/Sub backend.
|
||||
async def dispatch_filter(
|
||||
self, key: K, filter_function: Callable[[K], bool], event: EventType
|
||||
) -> List[str]:
|
||||
"""Selectively dispatch to the list of subscribers.
|
||||
|
||||
Function must return a list of separate identifiers for composability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def dispatch(self, _key, *_args):
|
||||
"""Dispatch an event to the given channel/key."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reset(self, _key):
|
||||
"""Reset a key from the backend."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def remove(self, _key):
|
||||
"""Remove a key from the backend.
|
||||
|
||||
The meaning from reset() and remove()
|
||||
is different, reset() is to clear all
|
||||
subscribers from the given key,
|
||||
remove() is to remove the key as well.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
|
||||
"""Dispatch an event to a list of states."""
|
||||
res = []
|
||||
|
||||
for state in states:
|
||||
try:
|
||||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except Exception:
|
||||
log.exception("error while dispatching")
|
||||
|
||||
return res
|
||||
...
|
||||
|
||||
|
||||
class DispatcherWithState(Dispatcher):
|
||||
class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
|
||||
"""Pub/Sub backend with a state dictionary.
|
||||
|
||||
This class was made to decrease the amount
|
||||
|
|
@ -105,58 +100,58 @@ class DispatcherWithState(Dispatcher):
|
|||
that have that dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
#: the default dict is to a set
|
||||
# so we make sure someone calling sub()
|
||||
# twice won't get 2x the events for the
|
||||
# same channel.
|
||||
self.state = defaultdict(set)
|
||||
self.state: Dict[K, Set[V]] = defaultdict(set)
|
||||
|
||||
async def sub(self, key, identifier):
|
||||
async def sub(self, key: K, identifier: V):
|
||||
self.state[key].add(identifier)
|
||||
|
||||
async def unsub(self, key, identifier):
|
||||
async def unsub(self, key: K, identifier: V):
|
||||
self.state[key].discard(identifier)
|
||||
|
||||
async def reset(self, key):
|
||||
async def reset(self, key: K):
|
||||
self.state[key] = set()
|
||||
|
||||
async def remove(self, key):
|
||||
async def drop(self, key: K):
|
||||
try:
|
||||
self.state.pop(key)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
async def dispatch(self, key, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DispatcherWithFlags(DispatcherWithState):
|
||||
class DispatcherWithFlags(
|
||||
DispatcherWithState, Generic[K, V, EventType, DispatchType, F],
|
||||
):
|
||||
"""Pub/Sub backend with both a state and a flags store."""
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
|
||||
|
||||
#: keep flags for subscribers, so for example
|
||||
# a subscriber could drop all presence events at the
|
||||
# pubsub level. see gateway's guild_subscriptions field for more
|
||||
self.flags = defaultdict(dict)
|
||||
def set_flags(self, key: K, identifier: V, flags: F):
|
||||
"""Set flags for the given identifier."""
|
||||
self.flags[key][identifier] = flags
|
||||
|
||||
async def sub(self, key, identifier, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(key, identifier)
|
||||
self.flags[key][identifier] = flags or {}
|
||||
|
||||
async def unsub(self, key, identifier):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(key, identifier)
|
||||
def remove_flags(self, key: K, identifier: V):
|
||||
"""Set flags for the given identifier."""
|
||||
self.flags[key].pop(identifier)
|
||||
|
||||
def flags_get(self, key, identifier, field: str, default):
|
||||
def get_flags(self, key: K, identifier: V):
|
||||
"""Get a single field from the flags store."""
|
||||
# yes, i know its simply an indirection from the main flags store,
|
||||
# but i'd rather have this than change every call if i ever change
|
||||
# the structure of the flags store.
|
||||
return self.flags[key][identifier].get(field, default)
|
||||
return self.flags[key][identifier]
|
||||
|
||||
async def sub_with_flags(self, key: K, identifier: V, flags: F):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(key, identifier)
|
||||
self.set_flags(key, identifier, flags)
|
||||
|
||||
async def unsub(self, key: K, identifier: V):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(key, identifier)
|
||||
self.remove_flags(key, identifier)
|
||||
|
|
|
|||
|
|
@ -17,15 +17,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import List, Set
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithState
|
||||
from .dispatcher import DispatcherWithState, GatewayEvent
|
||||
from .user import dispatch_user_filter
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class FriendDispatcher(DispatcherWithState):
|
||||
class FriendDispatcher(DispatcherWithState[int, int, GatewayEvent, List[str]]):
|
||||
"""Friend Pub/Sub logic.
|
||||
|
||||
When connecting, a client will subscribe to all their friends
|
||||
|
|
@ -33,26 +34,18 @@ class FriendDispatcher(DispatcherWithState):
|
|||
broadcasted through that channel to basically all their friends.
|
||||
"""
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||
async def dispatch_filter(self, user_id: int, filter_function, event: GatewayEvent):
|
||||
"""Dispatch an event to all of a users' friends."""
|
||||
peer_ids = self.state[user_id]
|
||||
peer_ids: Set[int] = self.state[user_id]
|
||||
sessions: List[str] = []
|
||||
|
||||
for peer_id in peer_ids:
|
||||
# dispatch to the user instead of the "shards tied to a guild"
|
||||
# since relationships broadcast to all shards.
|
||||
sessions.extend(
|
||||
await self.main_dispatcher.dispatch_filter(
|
||||
"user", peer_id, func, event, data
|
||||
)
|
||||
)
|
||||
sessions.extend(await dispatch_user_filter(peer_id, filter_function, event))
|
||||
|
||||
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
|
||||
|
||||
return sessions
|
||||
|
||||
async def dispatch(self, user_id, event, data):
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||
async def dispatch(self, user_id: int, event: GatewayEvent):
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event)
|
||||
|
|
|
|||
|
|
@ -17,123 +17,73 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithFlags
|
||||
from litecord.permissions import get_permissions
|
||||
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||
from .channel import ChannelFlags
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class GuildDispatcher(DispatcherWithFlags):
|
||||
"""Guild backend for Pub/Sub"""
|
||||
@dataclass
|
||||
class GuildFlags(ChannelFlags):
|
||||
presence: bool
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
|
||||
"""Send an action to all channels of the guild."""
|
||||
flags = flags or {}
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
class GuildDispatcher(
|
||||
DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
|
||||
):
|
||||
"""Guild backend for Pub/Sub."""
|
||||
|
||||
for chan_id in chan_ids:
|
||||
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
|
||||
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||
for state in states:
|
||||
await self.sub(guild_id, state.session_id)
|
||||
|
||||
# only do an action for users that can
|
||||
# actually read the channel to start with.
|
||||
chan_perms = await get_permissions(
|
||||
user_id, chan_id, storage=self.main_dispatcher.app.storage
|
||||
)
|
||||
return states
|
||||
|
||||
if not chan_perms.bits.read_messages:
|
||||
log.debug("skipping cid={}, no read messages", chan_id)
|
||||
async def dispatch_filter(
|
||||
self, guild_id: int, filter_function, event: GatewayEvent
|
||||
):
|
||||
session_ids = self.state[guild_id]
|
||||
sessions: List[str] = []
|
||||
event_type, _ = event
|
||||
|
||||
for session_id in set(session_ids):
|
||||
if not filter_function(session_id):
|
||||
continue
|
||||
|
||||
log.debug("sending raw action {!r} to chan={}", action, chan_id)
|
||||
|
||||
# for now, only sub() has support for flags.
|
||||
# it is an idea to have flags support for other actions
|
||||
args = []
|
||||
if action == "sub":
|
||||
chanflags = dict(flags)
|
||||
|
||||
# channels don't need presence flags
|
||||
try:
|
||||
chanflags.pop("presence")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
args.append(chanflags)
|
||||
|
||||
await self.main_dispatcher.action(
|
||||
"channel", action, chan_id, user_id, *args
|
||||
)
|
||||
|
||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||
"""Call a method on the ChannelDispatcher, for all channels
|
||||
in the guild."""
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
|
||||
chan_dispatcher = self.main_dispatcher.backends["channel"]
|
||||
method = getattr(chan_dispatcher, meth)
|
||||
|
||||
for chan_id in chan_ids:
|
||||
log.debug("calling {} to chan={}", meth, chan_id)
|
||||
await method(chan_id, *args)
|
||||
|
||||
async def sub(self, guild_id: int, user_id: int, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(guild_id, user_id, flags)
|
||||
await self._chan_action("sub", guild_id, user_id, flags)
|
||||
|
||||
async def unsub(self, guild_id: int, user_id: int):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(guild_id, user_id)
|
||||
await self._chan_action("unsub", guild_id, user_id)
|
||||
|
||||
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
|
||||
"""Selectively dispatch to session ids that have
|
||||
func(session_id) true."""
|
||||
user_ids = self.state[guild_id]
|
||||
dispatched = 0
|
||||
sessions = []
|
||||
|
||||
# acquire a copy since we may be modifying
|
||||
# the original user_ids
|
||||
for user_id in set(user_ids):
|
||||
|
||||
# fetch all states / shards that are tied to the guild.
|
||||
states = self.sm.fetch_states(user_id, guild_id)
|
||||
|
||||
if not states:
|
||||
# user is actually disconnected,
|
||||
# so we should just unsub them
|
||||
await self.unsub(guild_id, user_id)
|
||||
try:
|
||||
state = app.state_manager.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
await self.unsub(guild_id, session_id)
|
||||
continue
|
||||
|
||||
# skip the given subscriber if event starts with PRESENCE_
|
||||
# and the flags say they don't want it.
|
||||
try:
|
||||
flags = self.get_flags(guild_id, session_id)
|
||||
except KeyError:
|
||||
log.warning("no flags for {!r}, ignoring", session_id)
|
||||
flags = GuildFlags(presence=True, typing=True)
|
||||
|
||||
# note that this does not equate to any unsubscription
|
||||
# of the channel.
|
||||
if event.startswith("PRESENCE_") and not self.flags_get(
|
||||
guild_id, user_id, "presence", True
|
||||
):
|
||||
if event_type.lower().startswith("presence_") and not flags.presence:
|
||||
continue
|
||||
|
||||
# filter the ones that matter
|
||||
states = list(filter(lambda state: func(state.session_id), states))
|
||||
try:
|
||||
await state.ws.dispatch(*event)
|
||||
except Exception:
|
||||
log.exception("error while dispatching to {}", state.session_id)
|
||||
continue
|
||||
|
||||
cur_sess = await self._dispatch_states(states, event, data)
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
|
||||
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
|
||||
sessions.append(session_id)
|
||||
|
||||
log.info("Dispatched {} {!r} to {} states", guild_id, event[0], len(sessions))
|
||||
return sessions
|
||||
|
||||
async def dispatch(self, guild_id: int, event: str, data: Any):
|
||||
async def dispatch(self, guild_id: int, event):
|
||||
"""Dispatch an event to all subscribers of the guild."""
|
||||
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)
|
||||
return await self.dispatch_filter(guild_id, lambda sess_id: True, event)
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ import asyncio
|
|||
from collections import defaultdict
|
||||
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from quart import current_app as app
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.pubsub.dispatcher import Dispatcher
|
||||
from litecord.permissions import (
|
||||
Permissions,
|
||||
overwrite_find_mix,
|
||||
|
|
@ -239,9 +239,6 @@ class GuildMemberList:
|
|||
|
||||
Attributes
|
||||
----------
|
||||
main_lg: LazyGuildDispatcher
|
||||
Main instance of :class:`LazyGuildDispatcher`,
|
||||
so that we're able to use things such as :class:`Storage`.
|
||||
guild_id: int
|
||||
The Guild ID this instance is referring to.
|
||||
channel_id: int
|
||||
|
|
@ -257,11 +254,10 @@ class GuildMemberList:
|
|||
for example, can still rely on PRESENCE_UPDATEs.
|
||||
"""
|
||||
|
||||
def __init__(self, guild_id: int, channel_id: int, main_lg):
|
||||
def __init__(self, guild_id: int, channel_id: int):
|
||||
self.guild_id = guild_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
self.main = main_lg
|
||||
self.list = MemberList()
|
||||
|
||||
#: store the states that are subscribed to the list.
|
||||
|
|
@ -273,22 +269,22 @@ class GuildMemberList:
|
|||
@property
|
||||
def loop(self):
|
||||
"""Get the main asyncio loop instance."""
|
||||
return self.main.app.loop
|
||||
return app.loop
|
||||
|
||||
@property
|
||||
def storage(self):
|
||||
"""Get the global :class:`Storage` instance."""
|
||||
return self.main.app.storage
|
||||
return app.storage
|
||||
|
||||
@property
|
||||
def presence(self):
|
||||
"""Get the global :class:`PresenceManager` instance."""
|
||||
return self.main.app.presence
|
||||
return app.presence
|
||||
|
||||
@property
|
||||
def state_man(self):
|
||||
"""Get the global :class:`StateManager` instance."""
|
||||
return self.main.app.state_manager
|
||||
return app.state_manager
|
||||
|
||||
@property
|
||||
def list_id(self):
|
||||
|
|
@ -572,8 +568,7 @@ class GuildMemberList:
|
|||
Wrapper for :meth:`StateManager.fetch_raw`
|
||||
"""
|
||||
try:
|
||||
state = self.state_man.fetch_raw(session_id)
|
||||
return state
|
||||
return self.state_man.fetch_raw(session_id)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
|
@ -643,7 +638,7 @@ class GuildMemberList:
|
|||
|
||||
# do resync-ing in the background
|
||||
result.append(session_id)
|
||||
self.loop.create_task(self.shard_query(session_id, [role_range]))
|
||||
app.sched.spawn(self.shard_query(session_id, [role_range]))
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -683,8 +678,7 @@ class GuildMemberList:
|
|||
)
|
||||
|
||||
if everyone_perms.bits.read_messages and list_id != "everyone":
|
||||
everyone_gml = await self.main.get_gml(self.guild_id)
|
||||
|
||||
everyone_gml = await app.lazy_guild.get_gml(self.guild_id)
|
||||
return await everyone_gml.shard_query(session_id, ranges)
|
||||
|
||||
await self._init_check()
|
||||
|
|
@ -1372,47 +1366,36 @@ class GuildMemberList:
|
|||
|
||||
self.guild_id = 0
|
||||
self.channel_id = 0
|
||||
self.main = None
|
||||
self._set_empty_list()
|
||||
self.state = {}
|
||||
|
||||
|
||||
class LazyGuildDispatcher(Dispatcher):
|
||||
class LazyGuildManager:
|
||||
"""Main class holding the member lists for lazy guilds."""
|
||||
|
||||
# channel ids
|
||||
KEY_TYPE = int
|
||||
|
||||
# the session ids subscribing to channels
|
||||
VAL_TYPE = str
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
|
||||
self.storage = main.app.storage
|
||||
|
||||
def __init__(self):
|
||||
# {chan_id: gml, ...}
|
||||
self.state = {}
|
||||
self.state: Dict[int, GuildMemberList] = {}
|
||||
|
||||
#: store which guilds have their
|
||||
# respective GMLs
|
||||
# {guild_id: [chan_id, ...], ...}
|
||||
self.guild_map: Dict[int, List[int]] = defaultdict(list)
|
||||
|
||||
async def get_gml(self, channel_id: int):
|
||||
async def get_gml(self, channel_id: int) -> GuildMemberList:
|
||||
"""Get a guild list for a channel ID,
|
||||
generating it if it doesn't exist."""
|
||||
try:
|
||||
return self.state[channel_id]
|
||||
except KeyError:
|
||||
guild_id = await self.storage.guild_from_channel(channel_id)
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
|
||||
# if we don't find a guild, we just
|
||||
# set it the same as the channel.
|
||||
if not guild_id:
|
||||
guild_id = channel_id
|
||||
|
||||
gml = GuildMemberList(guild_id, channel_id, self)
|
||||
gml = GuildMemberList(guild_id, channel_id)
|
||||
self.state[channel_id] = gml
|
||||
self.guild_map[guild_id].append(channel_id)
|
||||
return gml
|
||||
|
|
@ -1437,16 +1420,6 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
gml = await self.get_gml(chan_id)
|
||||
gml.unsub(session_id)
|
||||
|
||||
async def dispatch(self, guild_id, event: str, *args, **kwargs):
|
||||
"""Call a function specialized in handling the given event"""
|
||||
try:
|
||||
handler = getattr(self, f"_handle_{event.lower()}")
|
||||
except AttributeError:
|
||||
log.warning("unknown event: {}", event)
|
||||
return
|
||||
|
||||
await handler(guild_id, *args, **kwargs)
|
||||
|
||||
def remove_channel(self, channel_id: int):
|
||||
"""Remove a channel from the manager."""
|
||||
try:
|
||||
|
|
@ -1474,29 +1447,29 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
method = getattr(lazy_list, method_str)
|
||||
await method(*args)
|
||||
|
||||
async def _handle_new_role(self, guild_id: int, new_role: dict):
|
||||
async def new_role(self, guild_id: int, new_role: dict):
|
||||
"""Handle the addition of a new group by dispatching it to
|
||||
the member lists."""
|
||||
await self._call_all_lists(guild_id, "new_role", new_role)
|
||||
|
||||
async def _handle_role_pos_upd(self, guild_id, role: dict):
|
||||
async def role_position_update(self, guild_id, role: dict):
|
||||
await self._call_all_lists(guild_id, "role_pos_update", role)
|
||||
|
||||
async def _handle_role_update(self, guild_id, role: dict):
|
||||
async def role_update(self, guild_id, role: dict):
|
||||
# handle name and hoist changes
|
||||
await self._call_all_lists(guild_id, "role_update", role)
|
||||
|
||||
async def _handle_role_delete(self, guild_id, role_id: int):
|
||||
async def role_delete(self, guild_id, role_id: int):
|
||||
await self._call_all_lists(guild_id, "role_delete", role_id)
|
||||
|
||||
async def _handle_pres_update(self, guild_id, user_id: int, partial: dict):
|
||||
async def pres_update(self, guild_id, user_id: int, partial: dict):
|
||||
await self._call_all_lists(guild_id, "pres_update", user_id, partial)
|
||||
|
||||
async def _handle_new_member(self, guild_id, user_id: int):
|
||||
async def new_member(self, guild_id, user_id: int):
|
||||
await self._call_all_lists(guild_id, "new_member", user_id)
|
||||
|
||||
async def _handle_remove_member(self, guild_id, user_id: int):
|
||||
async def remove_member(self, guild_id, user_id: int):
|
||||
await self._call_all_lists(guild_id, "remove_member", user_id)
|
||||
|
||||
async def _handle_update_user(self, guild_id, user_id: int):
|
||||
async def update_user(self, guild_id, user_id: int):
|
||||
await self._call_all_lists(guild_id, "update_user", user_id)
|
||||
|
|
|
|||
|
|
@ -17,30 +17,20 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from typing import List
|
||||
from quart import current_app as app
|
||||
from .dispatcher import GatewayEvent
|
||||
from .utils import send_event_to_states
|
||||
|
||||
|
||||
class MemberDispatcher(Dispatcher):
|
||||
"""Member backend for Pub/Sub."""
|
||||
async def dispatch_member(
|
||||
guild_id: int, user_id: int, event: GatewayEvent
|
||||
) -> List[str]:
|
||||
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||
|
||||
KEY_TYPE = tuple
|
||||
# if no states were found, we should unsub the user from the guild
|
||||
if not states:
|
||||
await app.dispatcher.guild.unsub(guild_id, user_id)
|
||||
return []
|
||||
|
||||
async def dispatch(self, key, event, data):
|
||||
"""Dispatch a single event to a member.
|
||||
|
||||
This is shard-aware.
|
||||
"""
|
||||
# we don't keep any state on this dispatcher, so the key
|
||||
# is just (guild_id, user_id)
|
||||
guild_id, user_id = key
|
||||
|
||||
# fetch shards
|
||||
states = self.sm.fetch_states(user_id, guild_id)
|
||||
|
||||
# if no states were found, we should
|
||||
# unsub the user from the GUILD channel
|
||||
if not states:
|
||||
await self.main_dispatcher.unsub("guild", guild_id, user_id)
|
||||
return
|
||||
|
||||
return await self._dispatch_states(states, event, data)
|
||||
return await send_event_to_states(states, event)
|
||||
|
|
|
|||
|
|
@ -17,23 +17,27 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from typing import Callable, List
|
||||
|
||||
from quart import current_app as app
|
||||
from .dispatcher import GatewayEvent
|
||||
from .utils import send_event_to_states
|
||||
|
||||
|
||||
class UserDispatcher(Dispatcher):
|
||||
"""User backend for Pub/Sub."""
|
||||
|
||||
KEY_TYPE = int
|
||||
|
||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||
"""Dispatch an event to all shards of a user."""
|
||||
|
||||
# filter only states where func() gives true
|
||||
states = list(
|
||||
filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
|
||||
async def dispatch_user_filter(
|
||||
user_id: int, filter_func: Callable[[str], bool], event_data: GatewayEvent
|
||||
) -> List[str]:
|
||||
"""Dispatch to a given user's states, but only for states
|
||||
where filter_func returns true."""
|
||||
states = list(
|
||||
filter(
|
||||
lambda state: filter_func(state.session_id),
|
||||
app.state_manager.user_states(user_id),
|
||||
)
|
||||
)
|
||||
|
||||
return await self._dispatch_states(states, event, data)
|
||||
return await send_event_to_states(states, event_data)
|
||||
|
||||
async def dispatch(self, user_id: int, event, data):
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||
|
||||
async def dispatch_user(user_id: int, event_data: GatewayEvent) -> List[str]:
|
||||
return await dispatch_user_filter(user_id, lambda sess_id: True, event_data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,41 @@
|
|||
"""
|
||||
|
||||
Litecord
|
||||
Copyright (C) 2018-2019 Luna Mendes
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU General Public License as published by
|
||||
the Free Software Foundation, version 3 of the License.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple, Any
|
||||
from ..gateway.state import GatewayState
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def send_event_to_states(
|
||||
states: List[GatewayState], event_data: Tuple[str, Any]
|
||||
) -> List[str]:
|
||||
"""Dispatch an event to a list of states."""
|
||||
res = []
|
||||
|
||||
for state in states:
|
||||
try:
|
||||
event, data = event_data
|
||||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except Exception:
|
||||
log.exception("error while dispatching")
|
||||
|
||||
return res
|
||||
|
|
@ -482,6 +482,9 @@ INVITE = {
|
|||
"required": False,
|
||||
"nullable": True,
|
||||
}, # discord client sends invite code there
|
||||
# sent by official client, unknown purpose
|
||||
"target_user_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||
"target_user_type": {"type": "number", "required": False, "nullable": True},
|
||||
}
|
||||
|
||||
USER_SETTINGS = {
|
||||
|
|
|
|||
|
|
@ -709,7 +709,9 @@ class Storage:
|
|||
|
||||
return res
|
||||
|
||||
async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict:
|
||||
async def get_guild_extra(
|
||||
self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None
|
||||
) -> Dict:
|
||||
"""Get extra information about a guild."""
|
||||
res = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -181,9 +181,6 @@ async def send_sys_message(
|
|||
raise ValueError("Invalid system message type")
|
||||
|
||||
message_id = await handler(channel_id, *args, **kwargs)
|
||||
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
|
||||
|
||||
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", message))
|
||||
return message_id
|
||||
|
|
|
|||
|
|
@ -252,7 +252,7 @@ async def maybe_lazy_guild_dispatch(
|
|||
if isinstance(role, dict) and not role["hoist"] and not force:
|
||||
return
|
||||
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
||||
await (getattr(app.lazy_guild, event))(guild_id, role)
|
||||
|
||||
|
||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||
|
|
|
|||
|
|
@ -17,11 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, List
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
||||
from litecord.voice.lvsp_conn import LVSPConnection
|
||||
|
||||
|
|
@ -42,15 +43,15 @@ class LVSPManager:
|
|||
Spawns :class:`LVSPConnection` as needed, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, app, voice):
|
||||
self.app = app
|
||||
def __init__(self, app_, voice):
|
||||
self.app = app_
|
||||
self.voice = voice
|
||||
|
||||
# map servers to LVSPConnection
|
||||
self.conns = {}
|
||||
self.conns: Dict[str, LVSPConnection] = {}
|
||||
|
||||
# maps regions to server hostnames
|
||||
self.servers = defaultdict(list)
|
||||
self.servers: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
# maps Union[GuildID, DMId, GroupDMId] to server hostnames
|
||||
self.assign = {}
|
||||
|
|
@ -84,7 +85,7 @@ class LVSPManager:
|
|||
continue
|
||||
|
||||
self.regions[region.id] = region
|
||||
self.app.loop.create_task(self._spawn_region(region))
|
||||
app.sched.spawn(self._spawn_region(region))
|
||||
|
||||
async def _spawn_region(self, region: Region):
|
||||
"""Spawn a region. Involves fetching all the hostnames
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
from typing import Tuple, Dict, List
|
||||
from collections import defaultdict
|
||||
from dataclasses import fields
|
||||
from quart import current_app as app
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -286,6 +287,6 @@ class VoiceManager:
|
|||
# slow, but it be like that, also copied from other users...
|
||||
for guild_id in guild_ids:
|
||||
guild = await self.app.storage.get_guild_full(guild_id, None)
|
||||
await self.app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||
|
||||
# TODO propagate the channel deprecation to LVSP connections
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ def main(config):
|
|||
# as the managers require it
|
||||
# and the migrate command also sets the db up
|
||||
if argv[1] != "migrate":
|
||||
init_app_managers(app, voice=False)
|
||||
init_app_managers(app, init_voice=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
loop.run_until_complete(_ctx_wrapper(app, args))
|
||||
|
|
|
|||
8
run.py
8
run.py
|
|
@ -95,6 +95,7 @@ from litecord.images import IconManager
|
|||
from litecord.jobs import JobManager
|
||||
from litecord.voice.manager import VoiceManager
|
||||
from litecord.guild_memory_store import GuildMemoryStore
|
||||
from litecord.pubsub.lazy_guild import LazyGuildManager
|
||||
|
||||
from litecord.gateway.gateway import websocket_handler
|
||||
|
||||
|
|
@ -254,7 +255,7 @@ async def init_app_db(app_):
|
|||
app_.sched = JobManager()
|
||||
|
||||
|
||||
def init_app_managers(app_, *, voice=True):
|
||||
def init_app_managers(app_: Quart, *, init_voice=True):
|
||||
"""Initialize singleton classes."""
|
||||
app_.loop = asyncio.get_event_loop()
|
||||
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
|
||||
|
|
@ -265,7 +266,7 @@ def init_app_managers(app_, *, voice=True):
|
|||
|
||||
app_.icons = IconManager(app_)
|
||||
|
||||
app_.dispatcher = EventDispatcher(app_)
|
||||
app_.dispatcher = EventDispatcher()
|
||||
app_.presence = PresenceManager(app_)
|
||||
|
||||
app_.storage.presence = app_.presence
|
||||
|
|
@ -274,10 +275,11 @@ def init_app_managers(app_, *, voice=True):
|
|||
# we do this because of a bug on ./manage.py where it
|
||||
# cancels the LVSPManager's spawn regions task. we don't
|
||||
# need to start it on manage time.
|
||||
if voice:
|
||||
if init_voice:
|
||||
app_.voice = VoiceManager(app_)
|
||||
|
||||
app_.guild_store = GuildMemoryStore()
|
||||
app_.lazy_guild = LazyGuildManager()
|
||||
|
||||
|
||||
async def api_index(app_):
|
||||
|
|
|
|||
Loading…
Reference in New Issue