mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'feature/rewrite-dispatcher' into 'master'
dispatcher refactor Closes #84 See merge request litecord/litecord!60
This commit is contained in:
commit
f0f5570dfa
2
Pipfile
2
Pipfile
|
|
@ -18,7 +18,7 @@ zstandard = "*"
|
||||||
winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"}
|
winter = {editable = true,git = "https://gitlab.com/elixire/winter.git"}
|
||||||
|
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
pytest = "==5.1.2"
|
pytest = "==5.3.2"
|
||||||
pytest-asyncio = "==0.10.0"
|
pytest-asyncio = "==0.10.0"
|
||||||
mypy = "*"
|
mypy = "*"
|
||||||
black = "*"
|
black = "*"
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "ee24bd04c2d9b93bce1e8595379c652a31540b9da54f6ba7ef01182164be68e3"
|
"sha256": "dedc41184df539a608717e68c108ccfb4d529acb1d1702d83de223840c7cc754"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {
|
"requires": {
|
||||||
|
|
@ -123,41 +123,36 @@
|
||||||
},
|
},
|
||||||
"cffi": {
|
"cffi": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:0b49274afc941c626b605fb59b59c3485c17dc776dc3cc7cc14aca74cc19cc42",
|
"sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff",
|
||||||
"sha256:0e3ea92942cb1168e38c05c1d56b0527ce31f1a370f6117f1d490b8dcd6b3a04",
|
"sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b",
|
||||||
"sha256:135f69aecbf4517d5b3d6429207b2dff49c876be724ac0c8bf8e1ea99df3d7e5",
|
"sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac",
|
||||||
"sha256:19db0cdd6e516f13329cba4903368bff9bb5a9331d3410b1b448daaadc495e54",
|
"sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0",
|
||||||
"sha256:2781e9ad0e9d47173c0093321bb5435a9dfae0ed6a762aabafa13108f5f7b2ba",
|
"sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384",
|
||||||
"sha256:291f7c42e21d72144bb1c1b2e825ec60f46d0a7468f5346841860454c7aa8f57",
|
"sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26",
|
||||||
"sha256:2c5e309ec482556397cb21ede0350c5e82f0eb2621de04b2633588d118da4396",
|
"sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6",
|
||||||
"sha256:2e9c80a8c3344a92cb04661115898a9129c074f7ab82011ef4b612f645939f12",
|
"sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b",
|
||||||
"sha256:32a262e2b90ffcfdd97c7a5e24a6012a43c61f1f5a57789ad80af1d26c6acd97",
|
"sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e",
|
||||||
"sha256:3c9fff570f13480b201e9ab69453108f6d98244a7f495e91b6c654a47486ba43",
|
"sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd",
|
||||||
"sha256:415bdc7ca8c1c634a6d7163d43fb0ea885a07e9618a64bda407e04b04333b7db",
|
"sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2",
|
||||||
"sha256:42194f54c11abc8583417a7cf4eaff544ce0de8187abaf5d29029c91b1725ad3",
|
"sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66",
|
||||||
"sha256:4424e42199e86b21fc4db83bd76909a6fc2a2aefb352cb5414833c030f6ed71b",
|
"sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc",
|
||||||
"sha256:4a43c91840bda5f55249413037b7a9b79c90b1184ed504883b72c4df70778579",
|
"sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8",
|
||||||
"sha256:599a1e8ff057ac530c9ad1778293c665cb81a791421f46922d80a86473c13346",
|
"sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55",
|
||||||
"sha256:5c4fae4e9cdd18c82ba3a134be256e98dc0596af1e7285a3d2602c97dcfa5159",
|
"sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4",
|
||||||
"sha256:5ecfa867dea6fabe2a58f03ac9186ea64da1386af2159196da51c4904e11d652",
|
"sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5",
|
||||||
"sha256:62f2578358d3a92e4ab2d830cd1c2049c9c0d0e6d3c58322993cc341bdeac22e",
|
"sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d",
|
||||||
"sha256:6471a82d5abea994e38d2c2abc77164b4f7fbaaf80261cb98394d5793f11b12a",
|
"sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78",
|
||||||
"sha256:6d4f18483d040e18546108eb13b1dfa1000a089bcf8529e30346116ea6240506",
|
"sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa",
|
||||||
"sha256:71a608532ab3bd26223c8d841dde43f3516aa5d2bf37b50ac410bb5e99053e8f",
|
"sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793",
|
||||||
"sha256:74a1d8c85fb6ff0b30fbfa8ad0ac23cd601a138f7509dc617ebc65ef305bb98d",
|
"sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f",
|
||||||
"sha256:7b93a885bb13073afb0aa73ad82059a4c41f4b7d8eb8368980448b52d4c7dc2c",
|
"sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a",
|
||||||
"sha256:7d4751da932caaec419d514eaa4215eaf14b612cff66398dd51129ac22680b20",
|
"sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f",
|
||||||
"sha256:7f627141a26b551bdebbc4855c1157feeef18241b4b8366ed22a5c7d672ef858",
|
"sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30",
|
||||||
"sha256:8169cf44dd8f9071b2b9248c35fc35e8677451c52f795daa2bb4643f32a540bc",
|
"sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f",
|
||||||
"sha256:aa00d66c0fab27373ae44ae26a66a9e43ff2a678bf63a9c7c1a9a4d61172827a",
|
"sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3",
|
||||||
"sha256:ccb032fda0873254380aa2bfad2582aedc2959186cce61e3a17abc1a55ff89c3",
|
"sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c"
|
||||||
"sha256:d754f39e0d1603b5b24a7f8484b22d2904fa551fe865fd0d4c3332f078d20d4e",
|
|
||||||
"sha256:d75c461e20e29afc0aee7172a0950157c704ff0dd51613506bd7d82b718e7410",
|
|
||||||
"sha256:dcd65317dd15bc0451f3e01c80da2216a31916bdcffd6221ca1202d96584aa25",
|
|
||||||
"sha256:e570d3ab32e2c2861c4ebe6ffcad6a8abf9347432a37608fe1fbd157b3f0036b",
|
|
||||||
"sha256:fd43a88e045cf992ed09fa724b5315b790525f2676883a6ea64e3263bae6549d"
|
|
||||||
],
|
],
|
||||||
"version": "==1.13.2"
|
"version": "==1.14.0"
|
||||||
},
|
},
|
||||||
"chardet": {
|
"chardet": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
|
|
@ -195,10 +190,10 @@
|
||||||
},
|
},
|
||||||
"h2": {
|
"h2": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e",
|
"sha256:61e0f6601fa709f35cdb730863b4e5ec7ad449792add80d1410d4174ed139af5",
|
||||||
"sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4"
|
"sha256:875f41ebd6f2c44781259005b157faed1a5031df3ae5aa7bcb4628a6c0782f14"
|
||||||
],
|
],
|
||||||
"version": "==3.1.1"
|
"version": "==3.2.0"
|
||||||
},
|
},
|
||||||
"hpack": {
|
"hpack": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
|
|
@ -510,13 +505,6 @@
|
||||||
],
|
],
|
||||||
"version": "==1.4.3"
|
"version": "==1.4.3"
|
||||||
},
|
},
|
||||||
"atomicwrites": {
|
|
||||||
"hashes": [
|
|
||||||
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
|
|
||||||
"sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"
|
|
||||||
],
|
|
||||||
"version": "==1.3.0"
|
|
||||||
},
|
|
||||||
"attrs": {
|
"attrs": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
|
"sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c",
|
||||||
|
|
@ -647,11 +635,11 @@
|
||||||
},
|
},
|
||||||
"pytest": {
|
"pytest": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:95d13143cc14174ca1a01ec68e84d76ba5d9d493ac02716fd9706c949a505210",
|
"sha256:6b571215b5a790f9b41f19f3531c53a45cf6bb8ef2988bc1ff9afb38270b25fa",
|
||||||
"sha256:b78fe2881323bd44fd9bd76e5317173d4316577e7b1cddebae9136a4495ec865"
|
"sha256:e41d489ff43948babd0fad7ad5e49b8735d5d55e26628a58673c39ff61d95de4"
|
||||||
],
|
],
|
||||||
"index": "pypi",
|
"index": "pypi",
|
||||||
"version": "==5.1.2"
|
"version": "==5.3.2"
|
||||||
},
|
},
|
||||||
"pytest-asyncio": {
|
"pytest-asyncio": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ async def _update_features(guild_id: int, features: list):
|
||||||
)
|
)
|
||||||
|
|
||||||
guild = await app.storage.get_guild_full(guild_id)
|
guild = await app.storage.get_guild_full(guild_id)
|
||||||
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
|
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
|
||||||
|
|
|
||||||
|
|
@ -63,10 +63,10 @@ async def update_guild(guild_id: int):
|
||||||
|
|
||||||
if old_unavailable and not new_unavailable:
|
if old_unavailable and not new_unavailable:
|
||||||
# guild became available
|
# guild became available
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild))
|
||||||
else:
|
else:
|
||||||
# guild became unavailable
|
# guild became unavailable
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_DELETE", guild))
|
||||||
|
|
||||||
return jsonify(guild)
|
return jsonify(guild)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,13 @@ import itsdangerous
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from quart import Blueprint, jsonify, request, current_app as app
|
from quart import Blueprint, jsonify, request, current_app as app
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
from winter import get_snowflake
|
||||||
|
|
||||||
from litecord.auth import token_check
|
from litecord.auth import token_check
|
||||||
from litecord.common.users import create_user
|
from litecord.common.users import create_user
|
||||||
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
|
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
from winter import get_snowflake
|
from litecord.pubsub.user import dispatch_user
|
||||||
from .invites import use_invite
|
from .invites import use_invite
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -172,7 +173,7 @@ async def verify_user():
|
||||||
)
|
)
|
||||||
|
|
||||||
new_user = await app.storage.get_user(user_id, True)
|
new_user = await app.storage.get_user(user_id, True)
|
||||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
|
await dispatch_user(user_id, ("USER_UPDATE", new_user))
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ from litecord.common.messages import (
|
||||||
msg_add_attachment,
|
msg_add_attachment,
|
||||||
msg_guild_text_mentions,
|
msg_guild_text_mentions,
|
||||||
)
|
)
|
||||||
|
from litecord.pubsub.user import dispatch_user
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -136,10 +137,10 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
||||||
|
|
||||||
# dispatch CHANNEL_CREATE so the client knows which
|
# dispatch CHANNEL_CREATE so the client knows which
|
||||||
# channel the future event is about
|
# channel the future event is about
|
||||||
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
await dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||||
|
|
||||||
# subscribe the peer to the channel
|
# subscribe the peer to the channel
|
||||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
await app.dispatcher.channel.sub(channel_id, peer_id)
|
||||||
|
|
||||||
# insert it on dm_channel_state so the client
|
# insert it on dm_channel_state so the client
|
||||||
# is subscribed on the future
|
# is subscribed on the future
|
||||||
|
|
@ -242,7 +243,7 @@ async def _create_message(channel_id):
|
||||||
await _dm_pre_dispatch(channel_id, user_id)
|
await _dm_pre_dispatch(channel_id, user_id)
|
||||||
await _dm_pre_dispatch(channel_id, guild_id)
|
await _dm_pre_dispatch(channel_id, guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
|
||||||
|
|
||||||
# spawn url processor for embedding of images
|
# spawn url processor for embedding of images
|
||||||
perms = await get_permissions(user_id, channel_id)
|
perms = await get_permissions(user_id, channel_id)
|
||||||
|
|
@ -350,7 +351,7 @@ async def edit_message(channel_id, message_id):
|
||||||
# only dispatch MESSAGE_UPDATE if any update
|
# only dispatch MESSAGE_UPDATE if any update
|
||||||
# actually happened
|
# actually happened
|
||||||
if updated:
|
if updated:
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
|
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_UPDATE", message))
|
||||||
|
|
||||||
return jsonify(message)
|
return jsonify(message)
|
||||||
|
|
||||||
|
|
@ -427,9 +428,9 @@ async def delete_message(channel_id, message_id):
|
||||||
message_id,
|
message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel",
|
|
||||||
channel_id,
|
channel_id,
|
||||||
|
(
|
||||||
"MESSAGE_DELETE",
|
"MESSAGE_DELETE",
|
||||||
{
|
{
|
||||||
"id": str(message_id),
|
"id": str(message_id),
|
||||||
|
|
@ -437,6 +438,7 @@ async def delete_message(channel_id, message_id):
|
||||||
# for lazy guilds
|
# for lazy guilds
|
||||||
"guild_id": str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -106,11 +106,15 @@ async def add_pin(channel_id, message_id):
|
||||||
|
|
||||||
timestamp = snowflake_datetime(row["message_id"])
|
timestamp = snowflake_datetime(row["message_id"])
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel",
|
|
||||||
channel_id,
|
channel_id,
|
||||||
|
(
|
||||||
"CHANNEL_PINS_UPDATE",
|
"CHANNEL_PINS_UPDATE",
|
||||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
|
{
|
||||||
|
"channel_id": str(channel_id),
|
||||||
|
"last_pin_timestamp": timestamp_(timestamp),
|
||||||
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await send_sys_message(
|
await send_sys_message(
|
||||||
|
|
@ -149,11 +153,15 @@ async def delete_pin(channel_id, message_id):
|
||||||
|
|
||||||
timestamp = snowflake_datetime(row["message_id"])
|
timestamp = snowflake_datetime(row["message_id"])
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel",
|
|
||||||
channel_id,
|
channel_id,
|
||||||
|
(
|
||||||
"CHANNEL_PINS_UPDATE",
|
"CHANNEL_PINS_UPDATE",
|
||||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
|
{
|
||||||
|
"channel_id": str(channel_id),
|
||||||
|
"last_pin_timestamp": timestamp.isoformat(),
|
||||||
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -141,10 +141,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
payload["guild_id"] = str(guild_id)
|
payload["guild_id"] = str(guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_REACTION_ADD", payload))
|
||||||
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
|
|
||||||
)
|
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -206,8 +203,8 @@ async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
payload["guild_id"] = str(guild_id)
|
payload["guild_id"] = str(guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
|
channel_id, ("MESSAGE_REACTION_REMOVE", payload)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -290,6 +287,6 @@ async def remove_all_reactions(channel_id, message_id):
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
payload["guild_id"] = str(guild_id)
|
payload["guild_id"] = str(guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
|
channel_id, ("MESSAGE_REACTION_REMOVE_ALL", payload)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,9 +20,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
from winter import snowflake_datetime
|
||||||
|
|
||||||
from litecord.auth import token_check
|
from litecord.auth import token_check
|
||||||
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
||||||
|
|
@ -41,8 +43,9 @@ from litecord.system_messages import send_sys_message
|
||||||
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
||||||
from litecord.utils import search_result_from_list
|
from litecord.utils import search_result_from_list
|
||||||
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
||||||
from winter import snowflake_datetime
|
|
||||||
from litecord.common.channels import channel_ack
|
from litecord.common.channels import channel_ack
|
||||||
|
from litecord.pubsub.user import dispatch_user
|
||||||
|
from litecord.permissions import get_permissions, Permissions
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("channels", __name__)
|
bp = Blueprint("channels", __name__)
|
||||||
|
|
@ -80,9 +83,7 @@ async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
||||||
|
|
||||||
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||||
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
|
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
|
||||||
|
|
||||||
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
|
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
|
||||||
|
|
||||||
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
|
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
|
||||||
|
|
||||||
# if none of them were actually updated,
|
# if none of them were actually updated,
|
||||||
|
|
@ -93,7 +94,7 @@ async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||||
# at least one of the fields were updated,
|
# at least one of the fields were updated,
|
||||||
# dispatch GUILD_UPDATE
|
# dispatch GUILD_UPDATE
|
||||||
guild = await app.storage.get_guild(guild_id)
|
guild = await app.storage.get_guild(guild_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||||
|
|
||||||
|
|
||||||
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||||
|
|
@ -104,7 +105,7 @@ async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||||
return
|
return
|
||||||
|
|
||||||
guild = await app.storage.get_guild(guild_id)
|
guild = await app.storage.get_guild(guild_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
await app.dispatcher.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||||
|
|
||||||
|
|
||||||
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||||
|
|
@ -134,7 +135,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||||
# tell all people in the guild of the category removal
|
# tell all people in the guild of the category removal
|
||||||
for child_id in childs:
|
for child_id in childs:
|
||||||
child = await app.storage.get_channel(child_id)
|
child = await app.storage.get_channel(child_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", child))
|
||||||
|
|
||||||
|
|
||||||
async def _delete_messages(channel_id):
|
async def _delete_messages(channel_id):
|
||||||
|
|
@ -249,12 +250,10 @@ async def close_channel(channel_id):
|
||||||
)
|
)
|
||||||
|
|
||||||
# clean its member list representation
|
# clean its member list representation
|
||||||
lazy_guilds = app.dispatcher.backends["lazy_guild"]
|
app.lazy_guild.remove_channel(channel_id)
|
||||||
lazy_guilds.remove_channel(channel_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
|
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_DELETE", chan))
|
||||||
|
await app.dispatcher.channel.drop(channel_id)
|
||||||
await app.dispatcher.remove("channel", channel_id)
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
|
|
@ -273,11 +272,9 @@ async def close_channel(channel_id):
|
||||||
channel_id,
|
channel_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# unsubscribe
|
|
||||||
await app.dispatcher.unsub("channel", channel_id, user_id)
|
|
||||||
|
|
||||||
# nothing happens to the other party of the dm channel
|
# nothing happens to the other party of the dm channel
|
||||||
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
|
await app.dispatcher.channel.unsub(channel_id, user_id)
|
||||||
|
await dispatch_user(user_id, ("CHANNEL_DELETE", chan))
|
||||||
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
@ -318,10 +315,54 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
await app.dispatcher.guild.dispatch(guild_id, "CHANNEL_UPDATE", chan)
|
||||||
|
|
||||||
|
|
||||||
async def _process_overwrites(channel_id: int, overwrites: list):
|
@dataclass
|
||||||
|
class Target:
|
||||||
|
type: int
|
||||||
|
user_id: Optional[int]
|
||||||
|
role_id: Optional[int]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_user(self):
|
||||||
|
return self.type == 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_role(self):
|
||||||
|
return self.type == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch_action(guild_id: int, channel_id: int, user_id: int, perms) -> None:
|
||||||
|
"""Apply an action of sub/unsub to all states of a user."""
|
||||||
|
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||||
|
for state in states:
|
||||||
|
if perms.read_messages:
|
||||||
|
await app.dispatcher.channel.sub(channel_id, state.session_id)
|
||||||
|
else:
|
||||||
|
await app.dispatcher.channel.unsub(channel_id, state.session_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_overwrites(guild_id: int, channel_id: int, overwrites: list) -> None:
|
||||||
|
# user_ids serves as a "prospect" user id list.
|
||||||
|
# for each overwrite we apply, we fill this list with user ids we
|
||||||
|
# want to check later to subscribe/unsubscribe from the channel.
|
||||||
|
# (users without read_messages are automatically unsubbed since we
|
||||||
|
# don't want to leak messages to them when they dont have perms anymore)
|
||||||
|
|
||||||
|
# the expensiveness of large overwrite/role chains shines here.
|
||||||
|
# since each user id we fill in implies an entire get_permissions call
|
||||||
|
# (because we don't have the answer if a user is to be subbed/unsubbed
|
||||||
|
# with only overwrites, an overwrite for a user allowing them might be
|
||||||
|
# overwritten by a role overwrite denying them if they have the role),
|
||||||
|
# we get a lot of tension on that, causing channel updates to lag a bit.
|
||||||
|
|
||||||
|
# there may be some good optimizations to do here, such as doing some
|
||||||
|
# precalculations like fetching get_permissions for everyone first, then
|
||||||
|
# applying the new overwrites one by one, then subbing/unsubbing at the
|
||||||
|
# end, but it would be very memory intensive.
|
||||||
|
user_ids: List[int] = []
|
||||||
|
|
||||||
for overwrite in overwrites:
|
for overwrite in overwrites:
|
||||||
|
|
||||||
# 0 for member overwrite, 1 for role overwrite
|
# 0 for member overwrite, 1 for role overwrite
|
||||||
|
|
@ -329,7 +370,9 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
||||||
target_role = None if target_type == 0 else overwrite["id"]
|
target_role = None if target_type == 0 else overwrite["id"]
|
||||||
target_user = overwrite["id"] if target_type == 0 else None
|
target_user = overwrite["id"] if target_type == 0 else None
|
||||||
|
|
||||||
col_name = "target_user" if target_type == 0 else "target_role"
|
target = Target(target_type, target_role, target_user)
|
||||||
|
|
||||||
|
col_name = "target_user" if target.is_user else "target_role"
|
||||||
constraint_name = f"channel_overwrites_{col_name}_uniq"
|
constraint_name = f"channel_overwrites_{col_name}_uniq"
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
|
|
@ -352,6 +395,17 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
||||||
overwrite["deny"],
|
overwrite["deny"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if target.is_user:
|
||||||
|
perms = Permissions(overwrite["allow"] & ~overwrite["deny"])
|
||||||
|
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
|
||||||
|
|
||||||
|
elif target.is_role:
|
||||||
|
user_ids.extend(await app.storage.get_role_members(target.role_id))
|
||||||
|
|
||||||
|
for user_id in user_ids:
|
||||||
|
perms = await get_permissions(user_id, channel_id)
|
||||||
|
await _dispatch_action(guild_id, channel_id, target.user_id, perms)
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
|
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
|
||||||
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||||
|
|
@ -371,6 +425,7 @@ async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
await _process_overwrites(
|
await _process_overwrites(
|
||||||
|
guild_id,
|
||||||
channel_id,
|
channel_id,
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
|
|
@ -447,7 +502,7 @@ async def _update_channel_common(channel_id: int, guild_id: int, j: dict):
|
||||||
|
|
||||||
if "channel_overwrites" in j:
|
if "channel_overwrites" in j:
|
||||||
overwrites = j["channel_overwrites"]
|
overwrites = j["channel_overwrites"]
|
||||||
await _process_overwrites(channel_id, overwrites)
|
await _process_overwrites(guild_id, channel_id, overwrites)
|
||||||
|
|
||||||
|
|
||||||
async def _common_guild_chan(channel_id, j: dict):
|
async def _common_guild_chan(channel_id, j: dict):
|
||||||
|
|
@ -566,9 +621,9 @@ async def update_channel(channel_id: int):
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
if is_guild:
|
if is_guild:
|
||||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
|
||||||
else:
|
else:
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
|
||||||
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
@ -578,9 +633,9 @@ async def trigger_typing(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel",
|
|
||||||
channel_id,
|
channel_id,
|
||||||
|
(
|
||||||
"TYPING_START",
|
"TYPING_START",
|
||||||
{
|
{
|
||||||
"channel_id": str(channel_id),
|
"channel_id": str(channel_id),
|
||||||
|
|
@ -589,6 +644,7 @@ async def trigger_typing(channel_id):
|
||||||
# guild_id for lazy guilds
|
# guild_id for lazy guilds
|
||||||
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
@ -816,5 +872,5 @@ async def bulk_delete(channel_id: int):
|
||||||
if res == "DELETE 0":
|
if res == "DELETE 0":
|
||||||
raise BadRequest("No messages were removed")
|
raise BadRequest("No messages were removed")
|
||||||
|
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
|
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_DELETE_BULK", payload))
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from litecord.errors import BadRequest, Forbidden
|
||||||
from winter import get_snowflake
|
from winter import get_snowflake
|
||||||
from litecord.system_messages import send_sys_message
|
from litecord.system_messages import send_sys_message
|
||||||
from litecord.pubsub.channel import gdm_recipient_view
|
from litecord.pubsub.channel import gdm_recipient_view
|
||||||
|
from litecord.pubsub.user import dispatch_user
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("dm_channels", __name__)
|
bp = Blueprint("dm_channels", __name__)
|
||||||
|
|
@ -82,11 +83,11 @@ async def gdm_create(user_id, peer_id) -> int:
|
||||||
await _raw_gdm_add(channel_id, user_id)
|
await _raw_gdm_add(channel_id, user_id)
|
||||||
await _raw_gdm_add(channel_id, peer_id)
|
await _raw_gdm_add(channel_id, peer_id)
|
||||||
|
|
||||||
await app.dispatcher.sub("channel", channel_id, user_id)
|
await app.dispatcher.channel.sub(channel_id, user_id)
|
||||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
await app.dispatcher.channel.sub(channel_id, peer_id)
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
|
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_CREATE", chan))
|
||||||
|
|
||||||
return channel_id
|
return channel_id
|
||||||
|
|
||||||
|
|
@ -104,13 +105,10 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
# the reasoning behind gdm_recipient_view is in its docstring.
|
# the reasoning behind gdm_recipient_view is in its docstring.
|
||||||
await app.dispatcher.dispatch(
|
await dispatch_user(peer_id, ("CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)))
|
||||||
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_UPDATE", chan))
|
||||||
|
await app.dispatcher.channel.sub(peer_id)
|
||||||
await app.dispatcher.sub("channel", peer_id)
|
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
|
await send_sys_message(channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id)
|
||||||
|
|
@ -128,17 +126,19 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||||
await _raw_gdm_remove(channel_id, peer_id)
|
await _raw_gdm_remove(channel_id, peer_id)
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch(
|
await dispatch_user(peer_id, ("CHANNEL_DELETE", gdm_recipient_view(chan, user_id)))
|
||||||
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.dispatcher.unsub("channel", peer_id)
|
await app.dispatcher.channel.unsub(peer_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel",
|
|
||||||
channel_id,
|
channel_id,
|
||||||
|
(
|
||||||
"CHANNEL_RECIPIENT_REMOVE",
|
"CHANNEL_RECIPIENT_REMOVE",
|
||||||
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
|
{
|
||||||
|
"channel_id": str(channel_id),
|
||||||
|
"user": await app.storage.get_user(peer_id),
|
||||||
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
author_id = peer_id if user_id is None else user_id
|
author_id = peer_id if user_id is None else user_id
|
||||||
|
|
@ -174,9 +174,8 @@ async def gdm_destroy(channel_id):
|
||||||
channel_id,
|
channel_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
|
await app.dispatcher.channel.dispatch(channel_id, ("CHANNEL_DELETE", chan))
|
||||||
|
await app.dispatcher.channel.drop(channel_id)
|
||||||
await app.dispatcher.remove("channel", channel_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -59,19 +59,9 @@ async def create_channel(guild_id):
|
||||||
new_channel_id = get_snowflake()
|
new_channel_id = get_snowflake()
|
||||||
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
|
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
|
||||||
|
|
||||||
# TODO: do a better method
|
|
||||||
# subscribe the currently subscribed users to the new channel
|
|
||||||
# by getting all user ids and subscribing each one by one.
|
|
||||||
|
|
||||||
# since GuildDispatcher calls Storage.get_channel_ids,
|
|
||||||
# it will subscribe all users to the newly created channel.
|
|
||||||
guild_pubsub = app.dispatcher.backends["guild"]
|
|
||||||
user_ids = guild_pubsub.state[guild_id]
|
|
||||||
for uid in user_ids:
|
|
||||||
await app.dispatcher.sub("guild", guild_id, uid)
|
|
||||||
|
|
||||||
chan = await app.storage.get_channel(new_channel_id)
|
chan = await app.storage.get_channel(new_channel_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
|
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_CREATE", chan))
|
||||||
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -79,7 +69,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
||||||
"""Fetch new information about the channel and dispatch
|
"""Fetch new information about the channel and dispatch
|
||||||
a single CHANNEL_UPDATE event to the guild."""
|
a single CHANNEL_UPDATE event to the guild."""
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
|
await app.dispatcher.guild.dispatch(guild_id, ("CHANNEL_UPDATE", chan))
|
||||||
|
|
||||||
|
|
||||||
async def _do_single_swap(guild_id: int, pair: tuple):
|
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||||
|
|
|
||||||
|
|
@ -32,14 +32,15 @@ bp = Blueprint("guild.emoji", __name__)
|
||||||
|
|
||||||
async def _dispatch_emojis(guild_id):
|
async def _dispatch_emojis(guild_id):
|
||||||
"""Dispatch a Guild Emojis Update payload to a guild."""
|
"""Dispatch a Guild Emojis Update payload to a guild."""
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.guild.dispatch(
|
||||||
"guild",
|
|
||||||
guild_id,
|
guild_id,
|
||||||
|
(
|
||||||
"GUILD_EMOJIS_UPDATE",
|
"GUILD_EMOJIS_UPDATE",
|
||||||
{
|
{
|
||||||
"guild_id": str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
"emojis": await app.storage.get_guild_emojis(guild_id),
|
"emojis": await app.storage.get_guild_emojis(guild_id),
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -192,12 +192,9 @@ async def modify_guild_member(guild_id, member_id):
|
||||||
if nick_flag:
|
if nick_flag:
|
||||||
partial["nick"] = j["nick"]
|
partial["nick"] = j["nick"]
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.lazy_guild.pres_update(guild_id, user_id, partial)
|
||||||
"lazy_guild", guild_id, "pres_update", user_id, partial
|
await app.dispatcher.guild.dispatch(
|
||||||
)
|
guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
|
||||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
@ -228,12 +225,9 @@ async def update_nickname(guild_id):
|
||||||
member.pop("joined_at")
|
member.pop("joined_at")
|
||||||
|
|
||||||
# call pres_update for nick changes, etc.
|
# call pres_update for nick changes, etc.
|
||||||
await app.dispatcher.dispatch(
|
await app.lazy_guild.pres_update(guild_id, user_id, {"nick": j["nick"]})
|
||||||
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
|
await app.dispatcher.guild.dispatch(
|
||||||
)
|
guild_id, ("GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member})
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
|
||||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return j["nick"]
|
return j["nick"]
|
||||||
|
|
|
||||||
|
|
@ -86,10 +86,12 @@ async def create_ban(guild_id, member_id):
|
||||||
|
|
||||||
await remove_member(guild_id, member_id)
|
await remove_member(guild_id, member_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.guild.dispatch(
|
||||||
guild_id,
|
guild_id,
|
||||||
|
(
|
||||||
"GUILD_BAN_ADD",
|
"GUILD_BAN_ADD",
|
||||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
@ -115,10 +117,12 @@ async def remove_ban(guild_id, banned_id):
|
||||||
if res == "DELETE 0":
|
if res == "DELETE 0":
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.guild.dispatch(
|
||||||
guild_id,
|
guild_id,
|
||||||
|
(
|
||||||
"GUILD_BAN_REMOVE",
|
"GUILD_BAN_REMOVE",
|
||||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -67,8 +67,8 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
|
||||||
|
|
||||||
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
|
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.guild.dispatch(
|
||||||
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
guild_id, ("GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role})
|
||||||
)
|
)
|
||||||
|
|
||||||
return role
|
return role
|
||||||
|
|
@ -304,10 +304,9 @@ async def delete_guild_role(guild_id, role_id):
|
||||||
|
|
||||||
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
|
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.guild.dispatch(
|
||||||
guild_id,
|
guild_id,
|
||||||
"GUILD_ROLE_DELETE",
|
("GUILD_ROLE_DELETE", {"guild_id": str(guild_id), "role_id": str(role_id)},),
|
||||||
{"guild_id": str(guild_id), "role_id": str(role_id)},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -191,8 +191,9 @@ async def create_guild():
|
||||||
|
|
||||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||||
|
|
||||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
await app.dispatcher.guild.sub_user(guild_id, user_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
|
|
||||||
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_CREATE", guild_total))
|
||||||
return jsonify(guild_total)
|
return jsonify(guild_total)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -350,7 +351,7 @@ async def _update_guild(guild_id):
|
||||||
)
|
)
|
||||||
|
|
||||||
guild = await app.storage.get_guild_full(guild_id, user_id)
|
guild = await app.storage.get_guild_full(guild_id, user_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||||
return jsonify(guild)
|
return jsonify(guild)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -116,7 +116,7 @@ async def _inv_check_age(inv: dict):
|
||||||
await delete_invite(inv["code"])
|
await delete_invite(inv["code"])
|
||||||
raise InvalidInvite("Invite is expired")
|
raise InvalidInvite("Invite is expired")
|
||||||
|
|
||||||
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
|
if inv["max_uses"] != -1 and inv["uses"] > inv["max_uses"]:
|
||||||
await delete_invite(inv["code"])
|
await delete_invite(inv["code"])
|
||||||
raise InvalidInvite("Too many uses")
|
raise InvalidInvite("Too many uses")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from ..auth import token_check
|
||||||
from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND
|
from ..schemas import validate, RELATIONSHIP, SPECIFIC_FRIEND
|
||||||
from ..enums import RelationshipType
|
from ..enums import RelationshipType
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
|
from litecord.pubsub.user import dispatch_user
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint("relationship", __name__)
|
bp = Blueprint("relationship", __name__)
|
||||||
|
|
@ -36,17 +37,17 @@ async def get_me_relationships():
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch_single_pres(user_id, presence: dict):
|
async def _dispatch_single_pres(user_id, presence: dict):
|
||||||
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
|
await dispatch_user(user_id, ("PRESENCE_UPDATE", presence))
|
||||||
|
|
||||||
|
|
||||||
async def _unsub_friend(user_id, peer_id):
|
async def _unsub_friend(user_id, peer_id):
|
||||||
await app.dispatcher.unsub("friend", user_id, peer_id)
|
await app.dispatcher.friend.unsub(user_id, peer_id)
|
||||||
await app.dispatcher.unsub("friend", peer_id, user_id)
|
await app.dispatcher.friend.unsub(peer_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
async def _sub_friend(user_id, peer_id):
|
async def _sub_friend(user_id, peer_id):
|
||||||
await app.dispatcher.sub("friend", user_id, peer_id)
|
await app.dispatcher.friend.sub(user_id, peer_id)
|
||||||
await app.dispatcher.sub("friend", peer_id, user_id)
|
await app.dispatcher.friend.sub(peer_id, user_id)
|
||||||
|
|
||||||
# dispatch presence update to the user and peer about
|
# dispatch presence update to the user and peer about
|
||||||
# eachother's presence.
|
# eachother's presence.
|
||||||
|
|
@ -107,8 +108,8 @@ async def make_friend(
|
||||||
_friend,
|
_friend,
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(
|
await dispatch_user(
|
||||||
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
|
peer_id, ("RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)})
|
||||||
)
|
)
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
@ -130,35 +131,41 @@ async def make_friend(
|
||||||
_friend,
|
_friend,
|
||||||
)
|
)
|
||||||
|
|
||||||
_dispatch = app.dispatcher.dispatch_user
|
_dispatch = dispatch_user
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# accepted a friend request, dispatch respective
|
# accepted a friend request, dispatch respective
|
||||||
# relationship events
|
# relationship events
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
user_id,
|
user_id,
|
||||||
|
(
|
||||||
"RELATIONSHIP_REMOVE",
|
"RELATIONSHIP_REMOVE",
|
||||||
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
user_id,
|
user_id,
|
||||||
|
(
|
||||||
"RELATIONSHIP_ADD",
|
"RELATIONSHIP_ADD",
|
||||||
{
|
{
|
||||||
"type": _friend,
|
"type": _friend,
|
||||||
"id": str(peer_id),
|
"id": str(peer_id),
|
||||||
"user": await app.storage.get_user(peer_id),
|
"user": await app.storage.get_user(peer_id),
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
peer_id,
|
peer_id,
|
||||||
"RELATIONSHIP_ADD",
|
"RELATIONSHIP_ADD",
|
||||||
|
(
|
||||||
{
|
{
|
||||||
"type": _friend,
|
"type": _friend,
|
||||||
"id": str(user_id),
|
"id": str(user_id),
|
||||||
"user": await app.storage.get_user(user_id),
|
"user": await app.storage.get_user(user_id),
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await _sub_friend(user_id, peer_id)
|
await _sub_friend(user_id, peer_id)
|
||||||
|
|
@ -169,22 +176,26 @@ async def make_friend(
|
||||||
if rel_type == _friend:
|
if rel_type == _friend:
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
user_id,
|
user_id,
|
||||||
|
(
|
||||||
"RELATIONSHIP_ADD",
|
"RELATIONSHIP_ADD",
|
||||||
{
|
{
|
||||||
"id": str(peer_id),
|
"id": str(peer_id),
|
||||||
"type": RelationshipType.OUTGOING.value,
|
"type": RelationshipType.OUTGOING.value,
|
||||||
"user": await app.storage.get_user(peer_id),
|
"user": await app.storage.get_user(peer_id),
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
peer_id,
|
peer_id,
|
||||||
|
(
|
||||||
"RELATIONSHIP_ADD",
|
"RELATIONSHIP_ADD",
|
||||||
{
|
{
|
||||||
"id": str(user_id),
|
"id": str(user_id),
|
||||||
"type": RelationshipType.INCOMING.value,
|
"type": RelationshipType.INCOMING.value,
|
||||||
"user": await app.storage.get_user(user_id),
|
"user": await app.storage.get_user(user_id),
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# we don't make the pubsub link
|
# we don't make the pubsub link
|
||||||
|
|
@ -240,7 +251,7 @@ async def add_relationship(peer_id: int):
|
||||||
# make_friend did not succeed, so we
|
# make_friend did not succeed, so we
|
||||||
# assume it is a block and dispatch
|
# assume it is a block and dispatch
|
||||||
# the respective RELATIONSHIP_ADD.
|
# the respective RELATIONSHIP_ADD.
|
||||||
await app.dispatcher.dispatch_user(
|
await dispatch_user(
|
||||||
user_id,
|
user_id,
|
||||||
"RELATIONSHIP_ADD",
|
"RELATIONSHIP_ADD",
|
||||||
{
|
{
|
||||||
|
|
@ -261,7 +272,7 @@ async def remove_relationship(peer_id: int):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_friend = RelationshipType.FRIEND.value
|
_friend = RelationshipType.FRIEND.value
|
||||||
_block = RelationshipType.BLOCK.value
|
_block = RelationshipType.BLOCK.value
|
||||||
_dispatch = app.dispatcher.dispatch_user
|
_dispatch = dispatch_user
|
||||||
|
|
||||||
rel_type = await app.db.fetchval(
|
rel_type = await app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
|
|
@ -307,7 +318,8 @@ async def remove_relationship(peer_id: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
|
user_id,
|
||||||
|
("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}),
|
||||||
)
|
)
|
||||||
|
|
||||||
peer_del_type = (
|
peer_del_type = (
|
||||||
|
|
@ -315,7 +327,8 @@ async def remove_relationship(peer_id: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
|
peer_id,
|
||||||
|
("RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}),
|
||||||
)
|
)
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
@ -334,7 +347,7 @@ async def remove_relationship(peer_id: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch(
|
await _dispatch(
|
||||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
|
user_id, ("RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block})
|
||||||
)
|
)
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from quart import Blueprint, jsonify, request, current_app as app
|
||||||
from litecord.auth import token_check
|
from litecord.auth import token_check
|
||||||
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
||||||
from litecord.blueprints.checks import guild_check
|
from litecord.blueprints.checks import guild_check
|
||||||
|
from litecord.pubsub.user import dispatch_user
|
||||||
|
|
||||||
bp = Blueprint("users_settings", __name__)
|
bp = Blueprint("users_settings", __name__)
|
||||||
|
|
||||||
|
|
@ -58,7 +59,7 @@ async def patch_current_settings():
|
||||||
)
|
)
|
||||||
|
|
||||||
settings = await app.user_storage.get_user_settings(user_id)
|
settings = await app.user_storage.get_user_settings(user_id)
|
||||||
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
|
await dispatch_user(user_id, ("USER_SETTINGS_UPDATE", settings))
|
||||||
return jsonify(settings)
|
return jsonify(settings)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -123,7 +124,7 @@ async def patch_guild_settings(guild_id: int):
|
||||||
|
|
||||||
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
|
await dispatch_user(user_id, ("USER_GUILD_SETTINGS_UPDATE", settings))
|
||||||
|
|
||||||
return jsonify(settings)
|
return jsonify(settings)
|
||||||
|
|
||||||
|
|
@ -157,8 +158,8 @@ async def put_note(target_id: int):
|
||||||
note,
|
note,
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(
|
await dispatch_user(
|
||||||
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
|
user_id, ("USER_NOTE_UPDATE", {"id": str(target_id), "note": note})
|
||||||
)
|
)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -156,11 +156,12 @@ async def webhook_token_check(webhook_id: int, webhook_token: str):
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.guild.dispatch(
|
||||||
"guild",
|
|
||||||
guild_id,
|
guild_id,
|
||||||
|
(
|
||||||
"WEBHOOKS_UPDATE",
|
"WEBHOOKS_UPDATE",
|
||||||
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -503,7 +504,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
||||||
|
|
||||||
payload = await app.storage.get_message(message_id)
|
payload = await app.storage.get_message(message_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", payload))
|
||||||
|
|
||||||
# spawn embedder in the background, even when we're on a webhook.
|
# spawn embedder in the background, even when we're on a webhook.
|
||||||
app.sched.spawn(process_url_embed(payload))
|
app.sched.spawn(process_url_embed(payload))
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,8 @@ from quart import current_app as app
|
||||||
|
|
||||||
from litecord.errors import ForbiddenDM
|
from litecord.errors import ForbiddenDM
|
||||||
from litecord.enums import RelationshipType
|
from litecord.enums import RelationshipType
|
||||||
|
from litecord.pubsub.member import dispatch_member
|
||||||
|
from litecord.pubsub.user import dispatch_user
|
||||||
|
|
||||||
|
|
||||||
async def channel_ack(
|
async def channel_ack(
|
||||||
|
|
@ -54,20 +56,24 @@ async def channel_ack(
|
||||||
)
|
)
|
||||||
|
|
||||||
if guild_id:
|
if guild_id:
|
||||||
await app.dispatcher.dispatch_user_guild(
|
await dispatch_member(
|
||||||
user_id,
|
|
||||||
guild_id,
|
guild_id,
|
||||||
|
user_id,
|
||||||
|
(
|
||||||
"MESSAGE_ACK",
|
"MESSAGE_ACK",
|
||||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# we don't use ChannelDispatcher here because since
|
# we don't use ChannelDispatcher here because since
|
||||||
# guild_id is None, all user devices are already subscribed
|
# guild_id is None, all user devices are already subscribed
|
||||||
# to the given channel (a dm or a group dm)
|
# to the given channel (a dm or a group dm)
|
||||||
await app.dispatcher.dispatch_user(
|
await dispatch_user(
|
||||||
user_id,
|
user_id,
|
||||||
|
(
|
||||||
"MESSAGE_ACK",
|
"MESSAGE_ACK",
|
||||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from ..snowflake import get_snowflake
|
from ..snowflake import get_snowflake
|
||||||
from ..permissions import get_role_perms
|
from ..permissions import get_role_perms, get_permissions
|
||||||
from ..utils import dict_get, maybe_lazy_guild_dispatch
|
from ..utils import dict_get, maybe_lazy_guild_dispatch
|
||||||
from ..enums import ChannelType
|
from ..enums import ChannelType
|
||||||
|
from litecord.pubsub.member import dispatch_member
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -41,21 +43,20 @@ async def remove_member(guild_id: int, member_id: int):
|
||||||
member_id,
|
member_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user_guild(
|
await dispatch_member(
|
||||||
member_id,
|
|
||||||
guild_id,
|
guild_id,
|
||||||
"GUILD_DELETE",
|
member_id,
|
||||||
{"guild_id": str(guild_id), "unavailable": False},
|
("GUILD_DELETE", {"guild_id": str(guild_id), "unavailable": False}),
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.unsub("guild", guild_id, member_id)
|
await app.dispatcher.guild.unsub(guild_id, member_id)
|
||||||
|
await app.lazy_guild.remove_member(member_id)
|
||||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
await app.dispatcher.guild.dispatch(
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
|
||||||
guild_id,
|
guild_id,
|
||||||
|
(
|
||||||
"GUILD_MEMBER_REMOVE",
|
"GUILD_MEMBER_REMOVE",
|
||||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -108,8 +109,8 @@ async def create_role(guild_id, name: str, **kwargs):
|
||||||
# we need to update the lazy guild handlers for the newly created group
|
# we need to update the lazy guild handlers for the newly created group
|
||||||
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
|
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.guild.dispatch(
|
||||||
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
guild_id, ("GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role})
|
||||||
)
|
)
|
||||||
|
|
||||||
return role
|
return role
|
||||||
|
|
@ -137,6 +138,39 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _subscribe_users_new_channel(guild_id: int, channel_id: int) -> None:
|
||||||
|
|
||||||
|
# for each state currently subscribed to guild, we check on the database
|
||||||
|
# which states can also subscribe to the new channel at its creation.
|
||||||
|
|
||||||
|
# the list of users that can subscribe are then used again for a pass
|
||||||
|
# over the states and states that have user ids in that list become
|
||||||
|
# subscribers of the new channel.
|
||||||
|
users_to_sub: List[str] = []
|
||||||
|
|
||||||
|
for session_id in app.dispatcher.guild.state[guild_id]:
|
||||||
|
try:
|
||||||
|
state = app.state_manager.fetch_raw(session_id)
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if state.user_id in users_to_sub:
|
||||||
|
continue
|
||||||
|
|
||||||
|
perms = await get_permissions(state.user_id, channel_id)
|
||||||
|
if perms.read_messages:
|
||||||
|
users_to_sub.append(state.user_id)
|
||||||
|
|
||||||
|
for session_id in app.dispatcher.guild.state[guild_id]:
|
||||||
|
try:
|
||||||
|
state = app.state_manager.fetch_raw(session_id)
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if state.user_id in users_to_sub:
|
||||||
|
await app.dispatcher.channel.sub(channel_id, session_id)
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_channel(
|
async def create_guild_channel(
|
||||||
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||||
):
|
):
|
||||||
|
|
@ -180,6 +214,8 @@ async def create_guild_channel(
|
||||||
# so we use this function.
|
# so we use this function.
|
||||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||||
|
|
||||||
|
await _subscribe_users_new_channel(guild_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
async def _del_from_table(table: str, user_id: int):
|
async def _del_from_table(table: str, user_id: int):
|
||||||
"""Delete a row from a table."""
|
"""Delete a row from a table."""
|
||||||
|
|
@ -206,21 +242,22 @@ async def delete_guild(guild_id: int):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Discord's client expects IDs being string
|
# Discord's client expects IDs being string
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.guild.dispatch(
|
||||||
"guild",
|
|
||||||
guild_id,
|
guild_id,
|
||||||
|
(
|
||||||
"GUILD_DELETE",
|
"GUILD_DELETE",
|
||||||
{
|
{
|
||||||
"guild_id": str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
"id": str(guild_id),
|
"id": str(guild_id),
|
||||||
# 'unavailable': False,
|
# 'unavailable': False,
|
||||||
},
|
},
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# remove from the dispatcher so nobody
|
# remove from the dispatcher so nobody
|
||||||
# becomes the little memer that tries to fuck up with
|
# becomes the little memer that tries to fuck up with
|
||||||
# everybody's gateway
|
# everybody's gateway
|
||||||
await app.dispatcher.remove("guild", guild_id)
|
await app.dispatcher.guild.drop(guild_id)
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_settings(guild_id: int, user_id: int):
|
async def create_guild_settings(guild_id: int, user_id: int):
|
||||||
|
|
@ -285,18 +322,17 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
|
||||||
|
|
||||||
# tell current members a new member came up
|
# tell current members a new member came up
|
||||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.guild.dispatch(
|
||||||
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
|
guild_id, ("GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}})
|
||||||
)
|
)
|
||||||
|
|
||||||
# update member lists for the new member
|
# pubsub changes for new member
|
||||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
|
await app.lazy_guild.new_member(guild_id, user_id)
|
||||||
|
states = await app.dispatcher.guild.sub_user(guild_id, user_id)
|
||||||
|
|
||||||
# subscribe new member to guild, so they get events n stuff
|
|
||||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
|
||||||
|
|
||||||
# tell the new member that theres the guild it just joined.
|
|
||||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
|
||||||
# just to the shards that are actually tied to it.
|
|
||||||
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||||
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
|
for state in states:
|
||||||
|
try:
|
||||||
|
await state.ws.dispatch("GUILD_CREATE", guild)
|
||||||
|
except Exception:
|
||||||
|
log.exception("failed to dispatch to session_id={!r}", state.session_id)
|
||||||
|
|
|
||||||
|
|
@ -29,41 +29,48 @@ from ..snowflake import get_snowflake
|
||||||
from ..errors import BadRequest
|
from ..errors import BadRequest
|
||||||
from ..auth import hash_data
|
from ..auth import hash_data
|
||||||
from ..utils import rand_hex
|
from ..utils import rand_hex
|
||||||
|
from ..pubsub.user import dispatch_user
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def mass_user_update(user_id: int):
|
async def mass_user_update(user_id: int) -> Tuple[dict, dict]:
|
||||||
"""Dispatch USER_UPDATE in a mass way."""
|
"""Dispatch a USER_UPDATE to everyone that is subscribed to the user.
|
||||||
# by using dispatch_with_filter
|
|
||||||
# we're guaranteeing all shards will get
|
|
||||||
# a USER_UPDATE once and not any others.
|
|
||||||
|
|
||||||
|
This function guarantees all states will get one USER_UPDATE for simple
|
||||||
|
cases. Lazy guild users might get updates N times depending of how many
|
||||||
|
lists are they subscribed to.
|
||||||
|
"""
|
||||||
session_ids: List[str] = []
|
session_ids: List[str] = []
|
||||||
|
|
||||||
public_user = await app.storage.get_user(user_id)
|
public_user = await app.storage.get_user(user_id)
|
||||||
private_user = await app.storage.get_user(user_id, secure=True)
|
private_user = await app.storage.get_user(user_id, secure=True)
|
||||||
|
|
||||||
|
session_ids.extend(await dispatch_user(user_id, ("USER_UPDATE", private_user)))
|
||||||
|
|
||||||
|
guild_ids: List[int] = await app.user_storage.get_user_guilds(user_id)
|
||||||
|
friend_ids: List[int] = await app.user_storage.get_friend_ids(user_id)
|
||||||
|
|
||||||
|
for guild_id in guild_ids:
|
||||||
session_ids.extend(
|
session_ids.extend(
|
||||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
await app.dispatcher.guild.dispatch_filter(
|
||||||
|
guild_id,
|
||||||
|
lambda sess_id: sess_id not in session_ids,
|
||||||
|
("USER_UPDATE", public_user),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
for friend_id in friend_ids:
|
||||||
friend_ids = await app.user_storage.get_friend_ids(user_id)
|
|
||||||
|
|
||||||
session_ids.extend(
|
session_ids.extend(
|
||||||
await app.dispatcher.dispatch_many_filter_list(
|
await app.dispatcher.friend.dispatch_filter(
|
||||||
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
friend_id,
|
||||||
|
lambda sess_id: sess_id not in session_ids,
|
||||||
|
("USER_UPDATE", public_user),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
session_ids.extend(
|
for guild_id in guild_ids:
|
||||||
await app.dispatcher.dispatch_many_filter_list(
|
await app.lazy_guild.update_user(guild_id, user_id)
|
||||||
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
|
||||||
|
|
||||||
return public_user, private_user
|
return public_user, private_user
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,17 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Any, Dict
|
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .pubsub import (
|
from .pubsub import (
|
||||||
GuildDispatcher,
|
GuildDispatcher,
|
||||||
MemberDispatcher,
|
|
||||||
UserDispatcher,
|
|
||||||
ChannelDispatcher,
|
ChannelDispatcher,
|
||||||
FriendDispatcher,
|
FriendDispatcher,
|
||||||
LazyGuildDispatcher,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -50,169 +45,7 @@ class EventDispatcher:
|
||||||
its subscriber ids.
|
its subscriber ids.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self):
|
||||||
self.state_manager = app.state_manager
|
self.guild: GuildDispatcher = GuildDispatcher()
|
||||||
self.app = app
|
self.channel = ChannelDispatcher()
|
||||||
|
self.friend = FriendDispatcher()
|
||||||
self.backends = {
|
|
||||||
"guild": GuildDispatcher(self),
|
|
||||||
"member": MemberDispatcher(self),
|
|
||||||
"channel": ChannelDispatcher(self),
|
|
||||||
"user": UserDispatcher(self),
|
|
||||||
"friend": FriendDispatcher(self),
|
|
||||||
"lazy_guild": LazyGuildDispatcher(self),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
|
||||||
"""Send an action regarding a key/identifier pair to a backend.
|
|
||||||
|
|
||||||
Action is usually "sub" or "unsub".
|
|
||||||
"""
|
|
||||||
backend = self.backends[backend_str]
|
|
||||||
method = getattr(backend, action)
|
|
||||||
|
|
||||||
# convert keys to the types the backend wants
|
|
||||||
key = backend.KEY_TYPE(key)
|
|
||||||
identifier = backend.VAL_TYPE(identifier)
|
|
||||||
|
|
||||||
return await method(key, identifier, *args)
|
|
||||||
|
|
||||||
async def subscribe(
|
|
||||||
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
|
|
||||||
):
|
|
||||||
"""Subscribe a single element to the given backend."""
|
|
||||||
flags = flags or {}
|
|
||||||
|
|
||||||
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
|
|
||||||
|
|
||||||
# this is a hacky solution for backwards compatibility between backends
|
|
||||||
# that implement flags and backends that don't.
|
|
||||||
|
|
||||||
# passing flags to backends that don't implement flags will
|
|
||||||
# cause errors as expected.
|
|
||||||
if flags:
|
|
||||||
return await self.action(backend, "sub", key, identifier, flags)
|
|
||||||
|
|
||||||
return await self.action(backend, "sub", key, identifier)
|
|
||||||
|
|
||||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
|
||||||
"""Unsubscribe an element from the given backend."""
|
|
||||||
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
|
|
||||||
|
|
||||||
return await self.action(backend, "unsub", key, identifier)
|
|
||||||
|
|
||||||
async def sub(self, backend, key, identifier):
|
|
||||||
"""Alias to subscribe()."""
|
|
||||||
return await self.subscribe(backend, key, identifier)
|
|
||||||
|
|
||||||
async def unsub(self, backend, key, identifier):
|
|
||||||
"""Alias to unsubscribe()."""
|
|
||||||
return await self.unsubscribe(backend, key, identifier)
|
|
||||||
|
|
||||||
async def sub_many(
|
|
||||||
self,
|
|
||||||
backend_str: str,
|
|
||||||
identifier: Any,
|
|
||||||
keys: list,
|
|
||||||
flags: Dict[str, Any] = None,
|
|
||||||
):
|
|
||||||
"""Subscribe to multiple channels (all in a single backend)
|
|
||||||
at a time.
|
|
||||||
|
|
||||||
Usually used when connecting to the gateway and the client
|
|
||||||
needs to subscribe to all their guids.
|
|
||||||
"""
|
|
||||||
flags = flags or {}
|
|
||||||
for key in keys:
|
|
||||||
await self.subscribe(backend_str, key, identifier, flags)
|
|
||||||
|
|
||||||
async def mass_sub(self, identifier: Any, backends: List[tuple]):
|
|
||||||
"""Mass subscribe to many backends at once."""
|
|
||||||
for bcall in backends:
|
|
||||||
backend_str, keys = bcall[0], bcall[1]
|
|
||||||
|
|
||||||
if len(bcall) == 2:
|
|
||||||
flags = {}
|
|
||||||
elif len(bcall) == 3:
|
|
||||||
# we have flags
|
|
||||||
flags = bcall[2]
|
|
||||||
|
|
||||||
log.debug(
|
|
||||||
"subscribing {} to {} keys in backend {}, flags: {}",
|
|
||||||
identifier,
|
|
||||||
len(keys),
|
|
||||||
backend_str,
|
|
||||||
flags,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.sub_many(backend_str, identifier, keys, flags)
|
|
||||||
|
|
||||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
|
||||||
"""Dispatch an event to the backend.
|
|
||||||
|
|
||||||
The backend is responsible for everything regarding the
|
|
||||||
actual dispatch.
|
|
||||||
"""
|
|
||||||
backend = self.backends[backend_str]
|
|
||||||
|
|
||||||
# convert types
|
|
||||||
key = backend.KEY_TYPE(key)
|
|
||||||
return await backend.dispatch(key, *args, **kwargs)
|
|
||||||
|
|
||||||
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
|
|
||||||
"""Dispatch to multiple keys in a single backend."""
|
|
||||||
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
|
||||||
|
|
||||||
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
|
|
||||||
"""Dispatch to a backend that only accepts
|
|
||||||
(event, data) arguments with an optional filter
|
|
||||||
function."""
|
|
||||||
backend = self.backends[backend_str]
|
|
||||||
key = backend.KEY_TYPE(key)
|
|
||||||
return await backend.dispatch_filter(key, func, *args)
|
|
||||||
|
|
||||||
async def dispatch_many_filter_list(
|
|
||||||
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
|
|
||||||
):
|
|
||||||
"""Make a "unique" dispatch given a list of session ids.
|
|
||||||
|
|
||||||
This only works for backends that have a dispatch_filter
|
|
||||||
handler and return session id lists in their dispatch
|
|
||||||
results.
|
|
||||||
"""
|
|
||||||
for key in keys:
|
|
||||||
sess_list.extend(
|
|
||||||
await self.dispatch_filter(
|
|
||||||
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return sess_list
|
|
||||||
|
|
||||||
async def reset(self, backend_str: str, key: Any):
|
|
||||||
"""Reset the bucket in the given backend."""
|
|
||||||
backend = self.backends[backend_str]
|
|
||||||
key = backend.KEY_TYPE(key)
|
|
||||||
return await backend.reset(key)
|
|
||||||
|
|
||||||
async def remove(self, backend_str: str, key: Any):
|
|
||||||
"""Remove a key from the backend. This
|
|
||||||
might be a different operation than resetting."""
|
|
||||||
backend = self.backends[backend_str]
|
|
||||||
key = backend.KEY_TYPE(key)
|
|
||||||
return await backend.remove(key)
|
|
||||||
|
|
||||||
async def dispatch_guild(self, guild_id, event, data):
|
|
||||||
"""Backwards compatibility with old EventDispatcher."""
|
|
||||||
return await self.dispatch("guild", guild_id, event, data)
|
|
||||||
|
|
||||||
async def dispatch_user_guild(self, user_id, guild_id, event, data):
|
|
||||||
"""Backwards compatibility with old EventDispatcher."""
|
|
||||||
return await self.dispatch("member", (guild_id, user_id), event, data)
|
|
||||||
|
|
||||||
async def dispatch_user(self, user_id, event, data):
|
|
||||||
"""Backwards compatibility with old EventDispatcher."""
|
|
||||||
return await self.dispatch("user", user_id, event, data)
|
|
||||||
|
|
|
||||||
|
|
@ -87,8 +87,8 @@ async def msg_update_embeds(payload, new_embeds):
|
||||||
if "flags" in payload:
|
if "flags" in payload:
|
||||||
update_payload["flags"] = payload["flags"]
|
update_payload["flags"] = payload["flags"]
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.channel.dispatch(
|
||||||
"channel", channel_id, "MESSAGE_UPDATE", update_payload
|
channel_id, ("MESSAGE_UPDATE", update_payload)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from typing import Optional
|
||||||
from litecord.gateway.websocket import GatewayWebsocket
|
from litecord.gateway.websocket import GatewayWebsocket
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -44,15 +45,16 @@ async def websocket_handler(app, ws, url):
|
||||||
return await ws.close(1000, "Invalid gateway encoding")
|
return await ws.close(1000, "Invalid gateway encoding")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
gw_compress = args["compress"][0]
|
gw_compress: Optional[str] = args["compress"][0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
gw_compress = None
|
gw_compress = None
|
||||||
|
|
||||||
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
|
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
|
||||||
return await ws.close(1000, "Invalid gateway compress")
|
return await ws.close(1000, "Invalid gateway compress")
|
||||||
|
|
||||||
|
async with app.app_context():
|
||||||
gws = GatewayWebsocket(
|
gws = GatewayWebsocket(
|
||||||
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
ws, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||||
)
|
)
|
||||||
|
|
||||||
# this can be run with a single await since this whole coroutine
|
# this can be run with a single await since this whole coroutine
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import asyncio
|
||||||
from typing import List
|
from typing import List
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
from websockets.exceptions import ConnectionClosed
|
from websockets.exceptions import ConnectionClosed
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
@ -225,3 +226,16 @@ class StateManager:
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the state manager."""
|
"""Close the state manager."""
|
||||||
self.closed = True
|
self.closed = True
|
||||||
|
|
||||||
|
async def fetch_user_states_for_channel(
|
||||||
|
self, channel_id: int, user_id: int
|
||||||
|
) -> List[GatewayState]:
|
||||||
|
"""Get a list of gateway states for a user that can receive events on a certain channel."""
|
||||||
|
# TODO optimize this with an in-memory store
|
||||||
|
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
|
if guild_id:
|
||||||
|
return self.fetch_states(user_id, guild_id)
|
||||||
|
|
||||||
|
# DMs and GDMs use all user states
|
||||||
|
return self.user_states(user_id)
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from random import randint
|
||||||
import websockets
|
import websockets
|
||||||
import zstandard as zstd
|
import zstandard as zstd
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType, ChannelType
|
from litecord.enums import RelationshipType, ChannelType
|
||||||
|
|
@ -43,7 +44,6 @@ from litecord.presence import BasePresence
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.state import GatewayState
|
from litecord.gateway.state import GatewayState
|
||||||
|
|
||||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
|
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
|
||||||
from litecord.gateway.errors import (
|
from litecord.gateway.errors import (
|
||||||
DecodeError,
|
DecodeError,
|
||||||
|
|
@ -52,8 +52,9 @@ from litecord.gateway.errors import (
|
||||||
ShardingRequired,
|
ShardingRequired,
|
||||||
)
|
)
|
||||||
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
||||||
|
|
||||||
from litecord.gateway.utils import WebsocketFileHandler
|
from litecord.gateway.utils import WebsocketFileHandler
|
||||||
|
from litecord.pubsub.guild import GuildFlags
|
||||||
|
from litecord.pubsub.channel import ChannelFlags
|
||||||
|
|
||||||
from litecord.storage import int_
|
from litecord.storage import int_
|
||||||
|
|
||||||
|
|
@ -67,7 +68,7 @@ WebsocketProperties = collections.namedtuple(
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
|
|
||||||
def __init__(self, ws, app, **kwargs):
|
def __init__(self, ws, **kwargs):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.storage = app.storage
|
self.storage = app.storage
|
||||||
self.user_storage = app.user_storage
|
self.user_storage = app.user_storage
|
||||||
|
|
@ -230,7 +231,7 @@ class GatewayWebsocket:
|
||||||
if task:
|
if task:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
|
self.wsp.tasks["heartbeat"] = app.sched.spawn(
|
||||||
task_wrapper("hb wait", self._hb_wait(interval))
|
task_wrapper("hb wait", self._hb_wait(interval))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -247,6 +248,7 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def dispatch(self, event: str, data: Any):
|
async def dispatch(self, event: str, data: Any):
|
||||||
"""Dispatch an event to the websocket."""
|
"""Dispatch an event to the websocket."""
|
||||||
|
assert self.state is not None
|
||||||
self.state.seq += 1
|
self.state.seq += 1
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
|
|
@ -282,6 +284,7 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
|
async def _guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]):
|
||||||
"""Dispatch GUILD_CREATE information."""
|
"""Dispatch GUILD_CREATE information."""
|
||||||
|
assert self.state is not None
|
||||||
|
|
||||||
# Users don't get asynchronous guild dispatching.
|
# Users don't get asynchronous guild dispatching.
|
||||||
if not self.state.bot:
|
if not self.state.bot:
|
||||||
|
|
@ -360,9 +363,7 @@ class GatewayWebsocket:
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||||
|
app.sched.spawn(self._guild_dispatch(guilds))
|
||||||
# async dispatch of guilds
|
|
||||||
self.app.loop.create_task(self._guild_dispatch(guilds))
|
|
||||||
|
|
||||||
async def _check_shards(self, shard, user_id):
|
async def _check_shards(self, shard, user_id):
|
||||||
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
||||||
|
|
@ -412,6 +413,7 @@ class GatewayWebsocket:
|
||||||
Note: subscribing to channels is already handled
|
Note: subscribing to channels is already handled
|
||||||
by GuildDispatcher.sub
|
by GuildDispatcher.sub
|
||||||
"""
|
"""
|
||||||
|
assert self.state is not None
|
||||||
user_id = self.state.user_id
|
user_id = self.state.user_id
|
||||||
guild_ids = await self._guild_ids()
|
guild_ids = await self._guild_ids()
|
||||||
|
|
||||||
|
|
@ -434,26 +436,50 @@ class GatewayWebsocket:
|
||||||
# (presence and typing events)
|
# (presence and typing events)
|
||||||
|
|
||||||
# we enable processing of guild_subscriptions by adding flags
|
# we enable processing of guild_subscriptions by adding flags
|
||||||
# when subscribing to the given backend. those are optional.
|
# when subscribing to the given backend.
|
||||||
channels_to_sub = [
|
session_id = self.state.session_id
|
||||||
(
|
channel_ids: List[int] = []
|
||||||
"guild",
|
|
||||||
guild_ids,
|
|
||||||
{"presence": guild_subscriptions, "typing": guild_subscriptions},
|
|
||||||
),
|
|
||||||
("channel", dm_ids),
|
|
||||||
("channel", gdm_ids),
|
|
||||||
]
|
|
||||||
|
|
||||||
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
|
for guild_id in guild_ids:
|
||||||
|
await app.dispatcher.guild.sub_with_flags(
|
||||||
|
guild_id,
|
||||||
|
session_id,
|
||||||
|
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
|
||||||
|
)
|
||||||
|
|
||||||
|
# instead of calculating which channels to subscribe to
|
||||||
|
# inside guild dispatcher, we calculate them in here, so that
|
||||||
|
# we remove complexity of the dispatcher.
|
||||||
|
|
||||||
|
guild_chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||||
|
for channel_id in guild_chan_ids:
|
||||||
|
perms = await get_permissions(
|
||||||
|
self.state.user_id, channel_id, storage=self.storage
|
||||||
|
)
|
||||||
|
|
||||||
|
if perms.bits.read_messages:
|
||||||
|
channel_ids.append(channel_id)
|
||||||
|
|
||||||
|
log.info("subscribing to {} guild channels", len(channel_ids))
|
||||||
|
for channel_id in channel_ids:
|
||||||
|
await app.dispatcher.channel.sub_with_flags(
|
||||||
|
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
|
||||||
|
)
|
||||||
|
|
||||||
|
for dm_id in dm_ids:
|
||||||
|
await app.dispatcher.channel.sub(dm_id, session_id)
|
||||||
|
|
||||||
|
for gdm_id in gdm_ids:
|
||||||
|
await app.dispatcher.channel.sub(gdm_id, session_id)
|
||||||
|
|
||||||
if not self.state.bot:
|
|
||||||
# subscribe to all friends
|
# subscribe to all friends
|
||||||
# (their friends will also subscribe back
|
# (their friends will also subscribe back
|
||||||
# when they come online)
|
# when they come online)
|
||||||
|
if not self.state.bot:
|
||||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||||
log.info("subscribing to {} friends", len(friend_ids))
|
log.info("subscribing to {} friends", len(friend_ids))
|
||||||
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
for friend_id in friend_ids:
|
||||||
|
await app.dispatcher.friend.sub(user_id, friend_id)
|
||||||
|
|
||||||
async def update_status(self, incoming_status: dict):
|
async def update_status(self, incoming_status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
|
|
@ -921,6 +947,7 @@ class GatewayWebsocket:
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
assert self.state is not None
|
||||||
data = payload["d"]
|
data = payload["d"]
|
||||||
|
|
||||||
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||||
|
|
@ -933,11 +960,9 @@ class GatewayWebsocket:
|
||||||
log.debug("lazy request: members: {}", data.get("members", []))
|
log.debug("lazy request: members: {}", data.get("members", []))
|
||||||
|
|
||||||
# make shard query
|
# make shard query
|
||||||
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
|
|
||||||
|
|
||||||
for chan_id, ranges in data.get("channels", {}).items():
|
for chan_id, ranges in data.get("channels", {}).items():
|
||||||
chan_id = int(chan_id)
|
chan_id = int(chan_id)
|
||||||
member_list = await lazy_guilds.get_gml(chan_id)
|
member_list = await app.lazy_guild.get_gml(chan_id)
|
||||||
|
|
||||||
perms = await get_permissions(
|
perms = await get_permissions(
|
||||||
self.state.user_id, chan_id, storage=self.storage
|
self.state.user_id, chan_id, storage=self.storage
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,6 @@ class PresenceManager:
|
||||||
self.storage = app.storage
|
self.storage = app.storage
|
||||||
self.user_storage = app.user_storage
|
self.user_storage = app.user_storage
|
||||||
self.state_manager = app.state_manager
|
self.state_manager = app.state_manager
|
||||||
self.dispatcher = app.dispatcher
|
|
||||||
|
|
||||||
async def guild_presences(
|
async def guild_presences(
|
||||||
self, member_ids: List[int], guild_id: int
|
self, member_ids: List[int], guild_id: int
|
||||||
|
|
@ -127,8 +126,7 @@ class PresenceManager:
|
||||||
|
|
||||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||||
|
|
||||||
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
lists = app.lazy_guild.get_gml_guild(guild_id)
|
||||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
|
||||||
|
|
||||||
# shards that are in lazy guilds with 'everyone'
|
# shards that are in lazy guilds with 'everyone'
|
||||||
# enabled
|
# enabled
|
||||||
|
|
@ -163,20 +161,21 @@ class PresenceManager:
|
||||||
# given a session id, return if the session id actually connects to
|
# given a session id, return if the session id actually connects to
|
||||||
# a given user, and if the state has not been dispatched via lazy guild.
|
# a given user, and if the state has not been dispatched via lazy guild.
|
||||||
def _session_check(session_id):
|
def _session_check(session_id):
|
||||||
|
try:
|
||||||
state = self.state_manager.fetch_raw(session_id)
|
state = self.state_manager.fetch_raw(session_id)
|
||||||
uid = int(member["user"]["id"])
|
except KeyError:
|
||||||
|
|
||||||
if not state:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
uid = int(member["user"]["id"])
|
||||||
|
|
||||||
# we don't want to send a presence update
|
# we don't want to send a presence update
|
||||||
# to the same user
|
# to the same user
|
||||||
return state.user_id != uid and session_id not in in_lazy
|
return state.user_id != uid and session_id not in in_lazy
|
||||||
|
|
||||||
# everyone not in lazy guild mode
|
# everyone not in lazy guild mode
|
||||||
# gets a PRESENCE_UPDATE
|
# gets a PRESENCE_UPDATE
|
||||||
await self.dispatcher.dispatch_filter(
|
await app.dispatcher.guild.dispatch_filter(
|
||||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", event_payload
|
guild_id, _session_check, ("PRESENCE_UPDATE", event_payload)
|
||||||
)
|
)
|
||||||
|
|
||||||
return in_lazy
|
return in_lazy
|
||||||
|
|
@ -193,11 +192,8 @@ class PresenceManager:
|
||||||
|
|
||||||
# dispatch to all friends that are subscribed to them
|
# dispatch to all friends that are subscribed to them
|
||||||
user = await self.storage.get_user(user_id)
|
user = await self.storage.get_user(user_id)
|
||||||
await self.dispatcher.dispatch(
|
await app.dispatcher.friend.dispatch(
|
||||||
"friend",
|
user_id, ("PRESENCE_UPDATE", {**presence.partial_dict, **{"user": user}}),
|
||||||
user_id,
|
|
||||||
"PRESENCE_UPDATE",
|
|
||||||
{**presence.partial_dict, **{"user": user}},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
|
def fetch_friend_presence(self, friend_id: int) -> BasePresence:
|
||||||
|
|
|
||||||
|
|
@ -18,17 +18,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .guild import GuildDispatcher
|
from .guild import GuildDispatcher
|
||||||
from .member import MemberDispatcher
|
from .member import dispatch_member
|
||||||
from .user import UserDispatcher
|
from .user import dispatch_user
|
||||||
from .channel import ChannelDispatcher
|
from .channel import ChannelDispatcher
|
||||||
from .friend import FriendDispatcher
|
from .friend import FriendDispatcher
|
||||||
from .lazy_guild import LazyGuildDispatcher
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GuildDispatcher",
|
"GuildDispatcher",
|
||||||
"MemberDispatcher",
|
"dispatch_member",
|
||||||
"UserDispatcher",
|
"dispatch_user",
|
||||||
"ChannelDispatcher",
|
"ChannelDispatcher",
|
||||||
"FriendDispatcher",
|
"FriendDispatcher",
|
||||||
"LazyGuildDispatcher",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, List
|
from typing import List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .dispatcher import DispatcherWithFlags
|
|
||||||
from litecord.enums import ChannelType
|
from litecord.enums import ChannelType
|
||||||
from litecord.utils import index_by_func
|
from litecord.utils import index_by_func
|
||||||
|
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -37,75 +39,66 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
||||||
"""
|
"""
|
||||||
# make a copy or the original channel object
|
# make a copy or the original channel object
|
||||||
data = dict(orig)
|
data = dict(orig)
|
||||||
|
|
||||||
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
|
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
|
||||||
|
|
||||||
data["recipients"].pop(idx)
|
data["recipients"].pop(idx)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ChannelDispatcher(DispatcherWithFlags):
|
@dataclass
|
||||||
"""Main channel Pub/Sub logic."""
|
class ChannelFlags:
|
||||||
|
typing: bool
|
||||||
|
|
||||||
KEY_TYPE = int
|
|
||||||
VAL_TYPE = int
|
|
||||||
|
|
||||||
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
|
class ChannelDispatcher(
|
||||||
|
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
|
||||||
|
):
|
||||||
|
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
|
||||||
|
|
||||||
|
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
|
||||||
"""Dispatch an event to a channel."""
|
"""Dispatch an event to a channel."""
|
||||||
# get everyone who is subscribed
|
session_ids = set(self.state[channel_id])
|
||||||
# and store the number of states we dispatched the event to
|
|
||||||
user_ids = self.state[channel_id]
|
|
||||||
dispatched = 0
|
|
||||||
sessions: List[str] = []
|
sessions: List[str] = []
|
||||||
|
|
||||||
# making a copy of user_ids since
|
event_type, event_data = event
|
||||||
# we'll modify it later on.
|
assert isinstance(event_data, dict)
|
||||||
for user_id in set(user_ids):
|
|
||||||
guild_id = await self.app.storage.guild_from_channel(channel_id)
|
|
||||||
|
|
||||||
# if we are dispatching to a guild channel,
|
for session_id in session_ids:
|
||||||
# we should only dispatch to the states / shards
|
try:
|
||||||
# that are connected to the guild (via their shard id).
|
state = app.state_manager.fetch_raw(session_id)
|
||||||
|
except KeyError:
|
||||||
|
await self.unsub(channel_id, session_id)
|
||||||
|
continue
|
||||||
|
|
||||||
# if we aren't, we just get all states tied to the user.
|
try:
|
||||||
# TODO: make a fetch_states that fetches shards
|
flags = self.get_flags(channel_id, session_id)
|
||||||
# - with id 0 (count any) OR
|
except KeyError:
|
||||||
# - single shards (id=0, count=1)
|
log.warning("no flags for {!r}, ignoring", session_id)
|
||||||
states = (
|
flags = ChannelFlags(typing=True)
|
||||||
self.sm.fetch_states(user_id, guild_id)
|
|
||||||
if guild_id
|
if event_type.lower().startswith("typing_") and not flags.typing:
|
||||||
else self.sm.user_states(user_id)
|
continue
|
||||||
|
|
||||||
|
correct_event = event
|
||||||
|
# for cases where we are talking about group dms, we create an edited
|
||||||
|
# event data so that it doesn't show the user we're dispatching
|
||||||
|
# to in data.recipients (clients already assume they are recipients)
|
||||||
|
if (
|
||||||
|
event_type in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||||
|
and event_data.get("type") == ChannelType.GROUP_DM.value
|
||||||
|
):
|
||||||
|
new_data = gdm_recipient_view(event_data, state.user_id)
|
||||||
|
correct_event = (event_type, new_data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await state.ws.dispatch(*correct_event)
|
||||||
|
except Exception:
|
||||||
|
log.exception("error while dispatching to {}", state.session_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
sessions.append(session_id)
|
||||||
|
|
||||||
|
log.info(
|
||||||
|
"Dispatched chan={} {!r} to {} states", channel_id, event[0], len(sessions)
|
||||||
)
|
)
|
||||||
|
|
||||||
# unsub people who don't have any states tied to the channel.
|
|
||||||
if not states:
|
|
||||||
await self.unsub(channel_id, user_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# skip typing events for users that don't want it
|
|
||||||
if event.startswith("TYPING_") and not self.flags_get(
|
|
||||||
channel_id, user_id, "typing", True
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
cur_sess: List[str] = []
|
|
||||||
|
|
||||||
if (
|
|
||||||
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
|
||||||
and data.get("type") == ChannelType.GROUP_DM.value
|
|
||||||
):
|
|
||||||
# we edit the channel payload so it doesn't show
|
|
||||||
# the user as a recipient
|
|
||||||
|
|
||||||
new_data = gdm_recipient_view(data, user_id)
|
|
||||||
cur_sess = await self._dispatch_states(states, event, new_data)
|
|
||||||
else:
|
|
||||||
cur_sess = await self._dispatch_states(states, event, data)
|
|
||||||
|
|
||||||
sessions.extend(cur_sess)
|
|
||||||
dispatched += len(cur_sess)
|
|
||||||
|
|
||||||
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
|
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import (
|
||||||
|
List,
|
||||||
|
Generic,
|
||||||
|
TypeVar,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Set,
|
||||||
|
Mapping,
|
||||||
|
Iterable,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
@ -25,79 +36,63 @@ from logbook import Logger
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _identity(_self, x):
|
K = TypeVar("K")
|
||||||
return x
|
V = TypeVar("V")
|
||||||
|
F = TypeVar("F")
|
||||||
|
EventType = TypeVar("EventType")
|
||||||
|
DispatchType = TypeVar("DispatchType")
|
||||||
|
F_Map = Mapping[V, F]
|
||||||
|
|
||||||
|
GatewayEvent = Tuple[str, Any]
|
||||||
|
|
||||||
|
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
|
||||||
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher(Generic[K, V, EventType, DispatchType]):
|
||||||
"""Pub/Sub backend dispatcher.
|
"""Pub/Sub backend dispatcher.
|
||||||
|
|
||||||
This just declares functions all Dispatcher subclasses
|
Classes must implement this protocol.
|
||||||
can implement. This does not mean all Dispatcher
|
|
||||||
subclasses have them implemented.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
KEY_TYPE = _identity
|
async def sub(self, key: K, identifier: V) -> None:
|
||||||
VAL_TYPE = _identity
|
"""Subscribe a given identifier to a given key."""
|
||||||
|
...
|
||||||
|
|
||||||
def __init__(self, main):
|
async def sub_many(self, key: K, identifier_list: Iterable[V]) -> None:
|
||||||
#: main EventDispatcher
|
for identifier in identifier_list:
|
||||||
self.main_dispatcher = main
|
await self.sub(key, identifier)
|
||||||
|
|
||||||
#: gateway state storage
|
async def unsub(self, key: K, identifier: V) -> None:
|
||||||
self.sm = main.state_manager
|
"""Unsubscribe a given identifier to a given key."""
|
||||||
|
...
|
||||||
|
|
||||||
self.app = main.app
|
async def dispatch(self, key: K, event: EventType) -> DispatchType:
|
||||||
|
...
|
||||||
|
|
||||||
async def sub(self, _key, _id):
|
async def dispatch_many(self, keys: List[K], *args: Any, **kwargs: Any) -> None:
|
||||||
"""Subscribe an element to the channel/key."""
|
log.info("MULTI DISPATCH in {!r}, {} keys", self, len(keys))
|
||||||
raise NotImplementedError
|
for key in keys:
|
||||||
|
await self.dispatch(key, *args, **kwargs)
|
||||||
|
|
||||||
async def unsub(self, _key, _id):
|
async def drop(self, key: K) -> None:
|
||||||
"""Unsubscribe an elemtnt from the channel/key."""
|
"""Drop a key."""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
async def dispatch_filter(self, _key, _func, *_args):
|
async def clear(self, key: K) -> None:
|
||||||
"""Selectively dispatch to the list of subscribed users.
|
"""Clear a key from the backend."""
|
||||||
|
...
|
||||||
|
|
||||||
The selection logic is completly arbitraty and up to the
|
async def dispatch_filter(
|
||||||
Pub/Sub backend.
|
self, key: K, filter_function: Callable[[K], bool], event: EventType
|
||||||
|
) -> List[str]:
|
||||||
|
"""Selectively dispatch to the list of subscribers.
|
||||||
|
|
||||||
|
Function must return a list of separate identifiers for composability.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
...
|
||||||
|
|
||||||
async def dispatch(self, _key, *_args):
|
|
||||||
"""Dispatch an event to the given channel/key."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def reset(self, _key):
|
|
||||||
"""Reset a key from the backend."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def remove(self, _key):
|
|
||||||
"""Remove a key from the backend.
|
|
||||||
|
|
||||||
The meaning from reset() and remove()
|
|
||||||
is different, reset() is to clear all
|
|
||||||
subscribers from the given key,
|
|
||||||
remove() is to remove the key as well.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
|
|
||||||
"""Dispatch an event to a list of states."""
|
|
||||||
res = []
|
|
||||||
|
|
||||||
for state in states:
|
|
||||||
try:
|
|
||||||
await state.ws.dispatch(event, data)
|
|
||||||
res.append(state.session_id)
|
|
||||||
except Exception:
|
|
||||||
log.exception("error while dispatching")
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class DispatcherWithState(Dispatcher):
|
class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
|
||||||
"""Pub/Sub backend with a state dictionary.
|
"""Pub/Sub backend with a state dictionary.
|
||||||
|
|
||||||
This class was made to decrease the amount
|
This class was made to decrease the amount
|
||||||
|
|
@ -105,58 +100,58 @@ class DispatcherWithState(Dispatcher):
|
||||||
that have that dictionary.
|
that have that dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, main):
|
def __init__(self):
|
||||||
super().__init__(main)
|
super().__init__()
|
||||||
|
|
||||||
#: the default dict is to a set
|
#: the default dict is to a set
|
||||||
# so we make sure someone calling sub()
|
# so we make sure someone calling sub()
|
||||||
# twice won't get 2x the events for the
|
# twice won't get 2x the events for the
|
||||||
# same channel.
|
# same channel.
|
||||||
self.state = defaultdict(set)
|
self.state: Dict[K, Set[V]] = defaultdict(set)
|
||||||
|
|
||||||
async def sub(self, key, identifier):
|
async def sub(self, key: K, identifier: V):
|
||||||
self.state[key].add(identifier)
|
self.state[key].add(identifier)
|
||||||
|
|
||||||
async def unsub(self, key, identifier):
|
async def unsub(self, key: K, identifier: V):
|
||||||
self.state[key].discard(identifier)
|
self.state[key].discard(identifier)
|
||||||
|
|
||||||
async def reset(self, key):
|
async def reset(self, key: K):
|
||||||
self.state[key] = set()
|
self.state[key] = set()
|
||||||
|
|
||||||
async def remove(self, key):
|
async def drop(self, key: K):
|
||||||
try:
|
try:
|
||||||
self.state.pop(key)
|
self.state.pop(key)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def dispatch(self, key, *args):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
class DispatcherWithFlags(
|
||||||
class DispatcherWithFlags(DispatcherWithState):
|
DispatcherWithState, Generic[K, V, EventType, DispatchType, F],
|
||||||
|
):
|
||||||
"""Pub/Sub backend with both a state and a flags store."""
|
"""Pub/Sub backend with both a state and a flags store."""
|
||||||
|
|
||||||
def __init__(self, main):
|
def __init__(self):
|
||||||
super().__init__(main)
|
super().__init__()
|
||||||
|
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
|
||||||
|
|
||||||
#: keep flags for subscribers, so for example
|
def set_flags(self, key: K, identifier: V, flags: F):
|
||||||
# a subscriber could drop all presence events at the
|
"""Set flags for the given identifier."""
|
||||||
# pubsub level. see gateway's guild_subscriptions field for more
|
self.flags[key][identifier] = flags
|
||||||
self.flags = defaultdict(dict)
|
|
||||||
|
|
||||||
async def sub(self, key, identifier, flags=None):
|
def remove_flags(self, key: K, identifier: V):
|
||||||
"""Subscribe a user to the guild."""
|
"""Set flags for the given identifier."""
|
||||||
await super().sub(key, identifier)
|
|
||||||
self.flags[key][identifier] = flags or {}
|
|
||||||
|
|
||||||
async def unsub(self, key, identifier):
|
|
||||||
"""Unsubscribe a user from the guild."""
|
|
||||||
await super().unsub(key, identifier)
|
|
||||||
self.flags[key].pop(identifier)
|
self.flags[key].pop(identifier)
|
||||||
|
|
||||||
def flags_get(self, key, identifier, field: str, default):
|
def get_flags(self, key: K, identifier: V):
|
||||||
"""Get a single field from the flags store."""
|
"""Get a single field from the flags store."""
|
||||||
# yes, i know its simply an indirection from the main flags store,
|
return self.flags[key][identifier]
|
||||||
# but i'd rather have this than change every call if i ever change
|
|
||||||
# the structure of the flags store.
|
async def sub_with_flags(self, key: K, identifier: V, flags: F):
|
||||||
return self.flags[key][identifier].get(field, default)
|
"""Subscribe a user to the guild."""
|
||||||
|
await super().sub(key, identifier)
|
||||||
|
self.set_flags(key, identifier, flags)
|
||||||
|
|
||||||
|
async def unsub(self, key: K, identifier: V):
|
||||||
|
"""Unsubscribe a user from the guild."""
|
||||||
|
await super().unsub(key, identifier)
|
||||||
|
self.remove_flags(key, identifier)
|
||||||
|
|
|
||||||
|
|
@ -17,15 +17,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Set
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .dispatcher import DispatcherWithState
|
from .dispatcher import DispatcherWithState, GatewayEvent
|
||||||
|
from .user import dispatch_user_filter
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FriendDispatcher(DispatcherWithState):
|
class FriendDispatcher(DispatcherWithState[int, int, GatewayEvent, List[str]]):
|
||||||
"""Friend Pub/Sub logic.
|
"""Friend Pub/Sub logic.
|
||||||
|
|
||||||
When connecting, a client will subscribe to all their friends
|
When connecting, a client will subscribe to all their friends
|
||||||
|
|
@ -33,26 +34,18 @@ class FriendDispatcher(DispatcherWithState):
|
||||||
broadcasted through that channel to basically all their friends.
|
broadcasted through that channel to basically all their friends.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
KEY_TYPE = int
|
async def dispatch_filter(self, user_id: int, filter_function, event: GatewayEvent):
|
||||||
VAL_TYPE = int
|
|
||||||
|
|
||||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
|
||||||
"""Dispatch an event to all of a users' friends."""
|
"""Dispatch an event to all of a users' friends."""
|
||||||
peer_ids = self.state[user_id]
|
peer_ids: Set[int] = self.state[user_id]
|
||||||
sessions: List[str] = []
|
sessions: List[str] = []
|
||||||
|
|
||||||
for peer_id in peer_ids:
|
for peer_id in peer_ids:
|
||||||
# dispatch to the user instead of the "shards tied to a guild"
|
# dispatch to the user instead of the "shards tied to a guild"
|
||||||
# since relationships broadcast to all shards.
|
# since relationships broadcast to all shards.
|
||||||
sessions.extend(
|
sessions.extend(await dispatch_user_filter(peer_id, filter_function, event))
|
||||||
await self.main_dispatcher.dispatch_filter(
|
|
||||||
"user", peer_id, func, event, data
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
|
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
async def dispatch(self, user_id, event, data):
|
async def dispatch(self, user_id: int, event: GatewayEvent):
|
||||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
return await self.dispatch_filter(user_id, lambda sess_id: True, event)
|
||||||
|
|
|
||||||
|
|
@ -17,123 +17,73 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .dispatcher import DispatcherWithFlags
|
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||||
from litecord.permissions import get_permissions
|
from .channel import ChannelFlags
|
||||||
|
from litecord.gateway.state import GatewayState
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GuildDispatcher(DispatcherWithFlags):
|
@dataclass
|
||||||
"""Guild backend for Pub/Sub"""
|
class GuildFlags(ChannelFlags):
|
||||||
|
presence: bool
|
||||||
|
|
||||||
KEY_TYPE = int
|
|
||||||
VAL_TYPE = int
|
|
||||||
|
|
||||||
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
|
class GuildDispatcher(
|
||||||
"""Send an action to all channels of the guild."""
|
DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
|
||||||
flags = flags or {}
|
):
|
||||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
"""Guild backend for Pub/Sub."""
|
||||||
|
|
||||||
for chan_id in chan_ids:
|
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
|
||||||
|
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||||
|
for state in states:
|
||||||
|
await self.sub(guild_id, state.session_id)
|
||||||
|
|
||||||
# only do an action for users that can
|
return states
|
||||||
# actually read the channel to start with.
|
|
||||||
chan_perms = await get_permissions(
|
|
||||||
user_id, chan_id, storage=self.main_dispatcher.app.storage
|
|
||||||
)
|
|
||||||
|
|
||||||
if not chan_perms.bits.read_messages:
|
async def dispatch_filter(
|
||||||
log.debug("skipping cid={}, no read messages", chan_id)
|
self, guild_id: int, filter_function, event: GatewayEvent
|
||||||
continue
|
|
||||||
|
|
||||||
log.debug("sending raw action {!r} to chan={}", action, chan_id)
|
|
||||||
|
|
||||||
# for now, only sub() has support for flags.
|
|
||||||
# it is an idea to have flags support for other actions
|
|
||||||
args = []
|
|
||||||
if action == "sub":
|
|
||||||
chanflags = dict(flags)
|
|
||||||
|
|
||||||
# channels don't need presence flags
|
|
||||||
try:
|
|
||||||
chanflags.pop("presence")
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
args.append(chanflags)
|
|
||||||
|
|
||||||
await self.main_dispatcher.action(
|
|
||||||
"channel", action, chan_id, user_id, *args
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
|
||||||
"""Call a method on the ChannelDispatcher, for all channels
|
|
||||||
in the guild."""
|
|
||||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
|
||||||
|
|
||||||
chan_dispatcher = self.main_dispatcher.backends["channel"]
|
|
||||||
method = getattr(chan_dispatcher, meth)
|
|
||||||
|
|
||||||
for chan_id in chan_ids:
|
|
||||||
log.debug("calling {} to chan={}", meth, chan_id)
|
|
||||||
await method(chan_id, *args)
|
|
||||||
|
|
||||||
async def sub(self, guild_id: int, user_id: int, flags=None):
|
|
||||||
"""Subscribe a user to the guild."""
|
|
||||||
await super().sub(guild_id, user_id, flags)
|
|
||||||
await self._chan_action("sub", guild_id, user_id, flags)
|
|
||||||
|
|
||||||
async def unsub(self, guild_id: int, user_id: int):
|
|
||||||
"""Unsubscribe a user from the guild."""
|
|
||||||
await super().unsub(guild_id, user_id)
|
|
||||||
await self._chan_action("unsub", guild_id, user_id)
|
|
||||||
|
|
||||||
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
|
|
||||||
"""Selectively dispatch to session ids that have
|
|
||||||
func(session_id) true."""
|
|
||||||
user_ids = self.state[guild_id]
|
|
||||||
dispatched = 0
|
|
||||||
sessions = []
|
|
||||||
|
|
||||||
# acquire a copy since we may be modifying
|
|
||||||
# the original user_ids
|
|
||||||
for user_id in set(user_ids):
|
|
||||||
|
|
||||||
# fetch all states / shards that are tied to the guild.
|
|
||||||
states = self.sm.fetch_states(user_id, guild_id)
|
|
||||||
|
|
||||||
if not states:
|
|
||||||
# user is actually disconnected,
|
|
||||||
# so we should just unsub them
|
|
||||||
await self.unsub(guild_id, user_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# skip the given subscriber if event starts with PRESENCE_
|
|
||||||
# and the flags say they don't want it.
|
|
||||||
|
|
||||||
# note that this does not equate to any unsubscription
|
|
||||||
# of the channel.
|
|
||||||
if event.startswith("PRESENCE_") and not self.flags_get(
|
|
||||||
guild_id, user_id, "presence", True
|
|
||||||
):
|
):
|
||||||
|
session_ids = self.state[guild_id]
|
||||||
|
sessions: List[str] = []
|
||||||
|
event_type, _ = event
|
||||||
|
|
||||||
|
for session_id in set(session_ids):
|
||||||
|
if not filter_function(session_id):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# filter the ones that matter
|
try:
|
||||||
states = list(filter(lambda state: func(state.session_id), states))
|
state = app.state_manager.fetch_raw(session_id)
|
||||||
|
except KeyError:
|
||||||
|
await self.unsub(guild_id, session_id)
|
||||||
|
continue
|
||||||
|
|
||||||
cur_sess = await self._dispatch_states(states, event, data)
|
try:
|
||||||
|
flags = self.get_flags(guild_id, session_id)
|
||||||
|
except KeyError:
|
||||||
|
log.warning("no flags for {!r}, ignoring", session_id)
|
||||||
|
flags = GuildFlags(presence=True, typing=True)
|
||||||
|
|
||||||
sessions.extend(cur_sess)
|
if event_type.lower().startswith("presence_") and not flags.presence:
|
||||||
dispatched += len(cur_sess)
|
continue
|
||||||
|
|
||||||
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
|
try:
|
||||||
|
await state.ws.dispatch(*event)
|
||||||
|
except Exception:
|
||||||
|
log.exception("error while dispatching to {}", state.session_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
sessions.append(session_id)
|
||||||
|
|
||||||
|
log.info("Dispatched {} {!r} to {} states", guild_id, event[0], len(sessions))
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
async def dispatch(self, guild_id: int, event: str, data: Any):
|
async def dispatch(self, guild_id: int, event):
|
||||||
"""Dispatch an event to all subscribers of the guild."""
|
"""Dispatch an event to all subscribers of the guild."""
|
||||||
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)
|
return await self.dispatch_filter(guild_id, lambda sess_id: True, event)
|
||||||
|
|
|
||||||
|
|
@ -30,10 +30,10 @@ import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
|
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.pubsub.dispatcher import Dispatcher
|
|
||||||
from litecord.permissions import (
|
from litecord.permissions import (
|
||||||
Permissions,
|
Permissions,
|
||||||
overwrite_find_mix,
|
overwrite_find_mix,
|
||||||
|
|
@ -239,9 +239,6 @@ class GuildMemberList:
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
main_lg: LazyGuildDispatcher
|
|
||||||
Main instance of :class:`LazyGuildDispatcher`,
|
|
||||||
so that we're able to use things such as :class:`Storage`.
|
|
||||||
guild_id: int
|
guild_id: int
|
||||||
The Guild ID this instance is referring to.
|
The Guild ID this instance is referring to.
|
||||||
channel_id: int
|
channel_id: int
|
||||||
|
|
@ -257,11 +254,10 @@ class GuildMemberList:
|
||||||
for example, can still rely on PRESENCE_UPDATEs.
|
for example, can still rely on PRESENCE_UPDATEs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, guild_id: int, channel_id: int, main_lg):
|
def __init__(self, guild_id: int, channel_id: int):
|
||||||
self.guild_id = guild_id
|
self.guild_id = guild_id
|
||||||
self.channel_id = channel_id
|
self.channel_id = channel_id
|
||||||
|
|
||||||
self.main = main_lg
|
|
||||||
self.list = MemberList()
|
self.list = MemberList()
|
||||||
|
|
||||||
#: store the states that are subscribed to the list.
|
#: store the states that are subscribed to the list.
|
||||||
|
|
@ -273,22 +269,22 @@ class GuildMemberList:
|
||||||
@property
|
@property
|
||||||
def loop(self):
|
def loop(self):
|
||||||
"""Get the main asyncio loop instance."""
|
"""Get the main asyncio loop instance."""
|
||||||
return self.main.app.loop
|
return app.loop
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def storage(self):
|
def storage(self):
|
||||||
"""Get the global :class:`Storage` instance."""
|
"""Get the global :class:`Storage` instance."""
|
||||||
return self.main.app.storage
|
return app.storage
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def presence(self):
|
def presence(self):
|
||||||
"""Get the global :class:`PresenceManager` instance."""
|
"""Get the global :class:`PresenceManager` instance."""
|
||||||
return self.main.app.presence
|
return app.presence
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_man(self):
|
def state_man(self):
|
||||||
"""Get the global :class:`StateManager` instance."""
|
"""Get the global :class:`StateManager` instance."""
|
||||||
return self.main.app.state_manager
|
return app.state_manager
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def list_id(self):
|
def list_id(self):
|
||||||
|
|
@ -572,8 +568,7 @@ class GuildMemberList:
|
||||||
Wrapper for :meth:`StateManager.fetch_raw`
|
Wrapper for :meth:`StateManager.fetch_raw`
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
state = self.state_man.fetch_raw(session_id)
|
return self.state_man.fetch_raw(session_id)
|
||||||
return state
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -643,7 +638,7 @@ class GuildMemberList:
|
||||||
|
|
||||||
# do resync-ing in the background
|
# do resync-ing in the background
|
||||||
result.append(session_id)
|
result.append(session_id)
|
||||||
self.loop.create_task(self.shard_query(session_id, [role_range]))
|
app.sched.spawn(self.shard_query(session_id, [role_range]))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -683,8 +678,7 @@ class GuildMemberList:
|
||||||
)
|
)
|
||||||
|
|
||||||
if everyone_perms.bits.read_messages and list_id != "everyone":
|
if everyone_perms.bits.read_messages and list_id != "everyone":
|
||||||
everyone_gml = await self.main.get_gml(self.guild_id)
|
everyone_gml = await app.lazy_guild.get_gml(self.guild_id)
|
||||||
|
|
||||||
return await everyone_gml.shard_query(session_id, ranges)
|
return await everyone_gml.shard_query(session_id, ranges)
|
||||||
|
|
||||||
await self._init_check()
|
await self._init_check()
|
||||||
|
|
@ -1372,47 +1366,36 @@ class GuildMemberList:
|
||||||
|
|
||||||
self.guild_id = 0
|
self.guild_id = 0
|
||||||
self.channel_id = 0
|
self.channel_id = 0
|
||||||
self.main = None
|
|
||||||
self._set_empty_list()
|
self._set_empty_list()
|
||||||
self.state = {}
|
self.state = {}
|
||||||
|
|
||||||
|
|
||||||
class LazyGuildDispatcher(Dispatcher):
|
class LazyGuildManager:
|
||||||
"""Main class holding the member lists for lazy guilds."""
|
"""Main class holding the member lists for lazy guilds."""
|
||||||
|
|
||||||
# channel ids
|
def __init__(self):
|
||||||
KEY_TYPE = int
|
|
||||||
|
|
||||||
# the session ids subscribing to channels
|
|
||||||
VAL_TYPE = str
|
|
||||||
|
|
||||||
def __init__(self, main):
|
|
||||||
super().__init__(main)
|
|
||||||
|
|
||||||
self.storage = main.app.storage
|
|
||||||
|
|
||||||
# {chan_id: gml, ...}
|
# {chan_id: gml, ...}
|
||||||
self.state = {}
|
self.state: Dict[int, GuildMemberList] = {}
|
||||||
|
|
||||||
#: store which guilds have their
|
#: store which guilds have their
|
||||||
# respective GMLs
|
# respective GMLs
|
||||||
# {guild_id: [chan_id, ...], ...}
|
# {guild_id: [chan_id, ...], ...}
|
||||||
self.guild_map: Dict[int, List[int]] = defaultdict(list)
|
self.guild_map: Dict[int, List[int]] = defaultdict(list)
|
||||||
|
|
||||||
async def get_gml(self, channel_id: int):
|
async def get_gml(self, channel_id: int) -> GuildMemberList:
|
||||||
"""Get a guild list for a channel ID,
|
"""Get a guild list for a channel ID,
|
||||||
generating it if it doesn't exist."""
|
generating it if it doesn't exist."""
|
||||||
try:
|
try:
|
||||||
return self.state[channel_id]
|
return self.state[channel_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
guild_id = await self.storage.guild_from_channel(channel_id)
|
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
# if we don't find a guild, we just
|
# if we don't find a guild, we just
|
||||||
# set it the same as the channel.
|
# set it the same as the channel.
|
||||||
if not guild_id:
|
if not guild_id:
|
||||||
guild_id = channel_id
|
guild_id = channel_id
|
||||||
|
|
||||||
gml = GuildMemberList(guild_id, channel_id, self)
|
gml = GuildMemberList(guild_id, channel_id)
|
||||||
self.state[channel_id] = gml
|
self.state[channel_id] = gml
|
||||||
self.guild_map[guild_id].append(channel_id)
|
self.guild_map[guild_id].append(channel_id)
|
||||||
return gml
|
return gml
|
||||||
|
|
@ -1437,16 +1420,6 @@ class LazyGuildDispatcher(Dispatcher):
|
||||||
gml = await self.get_gml(chan_id)
|
gml = await self.get_gml(chan_id)
|
||||||
gml.unsub(session_id)
|
gml.unsub(session_id)
|
||||||
|
|
||||||
async def dispatch(self, guild_id, event: str, *args, **kwargs):
|
|
||||||
"""Call a function specialized in handling the given event"""
|
|
||||||
try:
|
|
||||||
handler = getattr(self, f"_handle_{event.lower()}")
|
|
||||||
except AttributeError:
|
|
||||||
log.warning("unknown event: {}", event)
|
|
||||||
return
|
|
||||||
|
|
||||||
await handler(guild_id, *args, **kwargs)
|
|
||||||
|
|
||||||
def remove_channel(self, channel_id: int):
|
def remove_channel(self, channel_id: int):
|
||||||
"""Remove a channel from the manager."""
|
"""Remove a channel from the manager."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -1474,29 +1447,29 @@ class LazyGuildDispatcher(Dispatcher):
|
||||||
method = getattr(lazy_list, method_str)
|
method = getattr(lazy_list, method_str)
|
||||||
await method(*args)
|
await method(*args)
|
||||||
|
|
||||||
async def _handle_new_role(self, guild_id: int, new_role: dict):
|
async def new_role(self, guild_id: int, new_role: dict):
|
||||||
"""Handle the addition of a new group by dispatching it to
|
"""Handle the addition of a new group by dispatching it to
|
||||||
the member lists."""
|
the member lists."""
|
||||||
await self._call_all_lists(guild_id, "new_role", new_role)
|
await self._call_all_lists(guild_id, "new_role", new_role)
|
||||||
|
|
||||||
async def _handle_role_pos_upd(self, guild_id, role: dict):
|
async def role_position_update(self, guild_id, role: dict):
|
||||||
await self._call_all_lists(guild_id, "role_pos_update", role)
|
await self._call_all_lists(guild_id, "role_pos_update", role)
|
||||||
|
|
||||||
async def _handle_role_update(self, guild_id, role: dict):
|
async def role_update(self, guild_id, role: dict):
|
||||||
# handle name and hoist changes
|
# handle name and hoist changes
|
||||||
await self._call_all_lists(guild_id, "role_update", role)
|
await self._call_all_lists(guild_id, "role_update", role)
|
||||||
|
|
||||||
async def _handle_role_delete(self, guild_id, role_id: int):
|
async def role_delete(self, guild_id, role_id: int):
|
||||||
await self._call_all_lists(guild_id, "role_delete", role_id)
|
await self._call_all_lists(guild_id, "role_delete", role_id)
|
||||||
|
|
||||||
async def _handle_pres_update(self, guild_id, user_id: int, partial: dict):
|
async def pres_update(self, guild_id, user_id: int, partial: dict):
|
||||||
await self._call_all_lists(guild_id, "pres_update", user_id, partial)
|
await self._call_all_lists(guild_id, "pres_update", user_id, partial)
|
||||||
|
|
||||||
async def _handle_new_member(self, guild_id, user_id: int):
|
async def new_member(self, guild_id, user_id: int):
|
||||||
await self._call_all_lists(guild_id, "new_member", user_id)
|
await self._call_all_lists(guild_id, "new_member", user_id)
|
||||||
|
|
||||||
async def _handle_remove_member(self, guild_id, user_id: int):
|
async def remove_member(self, guild_id, user_id: int):
|
||||||
await self._call_all_lists(guild_id, "remove_member", user_id)
|
await self._call_all_lists(guild_id, "remove_member", user_id)
|
||||||
|
|
||||||
async def _handle_update_user(self, guild_id, user_id: int):
|
async def update_user(self, guild_id, user_id: int):
|
||||||
await self._call_all_lists(guild_id, "update_user", user_id)
|
await self._call_all_lists(guild_id, "update_user", user_id)
|
||||||
|
|
|
||||||
|
|
@ -17,30 +17,20 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .dispatcher import Dispatcher
|
from typing import List
|
||||||
|
from quart import current_app as app
|
||||||
|
from .dispatcher import GatewayEvent
|
||||||
|
from .utils import send_event_to_states
|
||||||
|
|
||||||
|
|
||||||
class MemberDispatcher(Dispatcher):
|
async def dispatch_member(
|
||||||
"""Member backend for Pub/Sub."""
|
guild_id: int, user_id: int, event: GatewayEvent
|
||||||
|
) -> List[str]:
|
||||||
|
states = app.state_manager.fetch_states(user_id, guild_id)
|
||||||
|
|
||||||
KEY_TYPE = tuple
|
# if no states were found, we should unsub the user from the guild
|
||||||
|
|
||||||
async def dispatch(self, key, event, data):
|
|
||||||
"""Dispatch a single event to a member.
|
|
||||||
|
|
||||||
This is shard-aware.
|
|
||||||
"""
|
|
||||||
# we don't keep any state on this dispatcher, so the key
|
|
||||||
# is just (guild_id, user_id)
|
|
||||||
guild_id, user_id = key
|
|
||||||
|
|
||||||
# fetch shards
|
|
||||||
states = self.sm.fetch_states(user_id, guild_id)
|
|
||||||
|
|
||||||
# if no states were found, we should
|
|
||||||
# unsub the user from the GUILD channel
|
|
||||||
if not states:
|
if not states:
|
||||||
await self.main_dispatcher.unsub("guild", guild_id, user_id)
|
await app.dispatcher.guild.unsub(guild_id, user_id)
|
||||||
return
|
return []
|
||||||
|
|
||||||
return await self._dispatch_states(states, event, data)
|
return await send_event_to_states(states, event)
|
||||||
|
|
|
||||||
|
|
@ -17,23 +17,27 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .dispatcher import Dispatcher
|
from typing import Callable, List
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
|
from .dispatcher import GatewayEvent
|
||||||
|
from .utils import send_event_to_states
|
||||||
|
|
||||||
|
|
||||||
class UserDispatcher(Dispatcher):
|
async def dispatch_user_filter(
|
||||||
"""User backend for Pub/Sub."""
|
user_id: int, filter_func: Callable[[str], bool], event_data: GatewayEvent
|
||||||
|
) -> List[str]:
|
||||||
KEY_TYPE = int
|
"""Dispatch to a given user's states, but only for states
|
||||||
|
where filter_func returns true."""
|
||||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
|
||||||
"""Dispatch an event to all shards of a user."""
|
|
||||||
|
|
||||||
# filter only states where func() gives true
|
|
||||||
states = list(
|
states = list(
|
||||||
filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
|
filter(
|
||||||
|
lambda state: filter_func(state.session_id),
|
||||||
|
app.state_manager.user_states(user_id),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._dispatch_states(states, event, data)
|
return await send_event_to_states(states, event_data)
|
||||||
|
|
||||||
async def dispatch(self, user_id: int, event, data):
|
|
||||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
async def dispatch_user(user_id: int, event_data: GatewayEvent) -> List[str]:
|
||||||
|
return await dispatch_user_filter(user_id, lambda sess_id: True, event_data)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
Litecord
|
||||||
|
Copyright (C) 2018-2019 Luna Mendes
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, version 3 of the License.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import List, Tuple, Any
|
||||||
|
from ..gateway.state import GatewayState
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def send_event_to_states(
|
||||||
|
states: List[GatewayState], event_data: Tuple[str, Any]
|
||||||
|
) -> List[str]:
|
||||||
|
"""Dispatch an event to a list of states."""
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for state in states:
|
||||||
|
try:
|
||||||
|
event, data = event_data
|
||||||
|
await state.ws.dispatch(event, data)
|
||||||
|
res.append(state.session_id)
|
||||||
|
except Exception:
|
||||||
|
log.exception("error while dispatching")
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
@ -482,6 +482,9 @@ INVITE = {
|
||||||
"required": False,
|
"required": False,
|
||||||
"nullable": True,
|
"nullable": True,
|
||||||
}, # discord client sends invite code there
|
}, # discord client sends invite code there
|
||||||
|
# sent by official client, unknown purpose
|
||||||
|
"target_user_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||||
|
"target_user_type": {"type": "number", "required": False, "nullable": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
USER_SETTINGS = {
|
USER_SETTINGS = {
|
||||||
|
|
|
||||||
|
|
@ -709,7 +709,9 @@ class Storage:
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_guild_extra(self, guild_id: int, user_id=None, large=None) -> Dict:
|
async def get_guild_extra(
|
||||||
|
self, guild_id: int, user_id: Optional[int] = None, large: Optional[int] = None
|
||||||
|
) -> Dict:
|
||||||
"""Get extra information about a guild."""
|
"""Get extra information about a guild."""
|
||||||
res = {}
|
res = {}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -181,9 +181,6 @@ async def send_sys_message(
|
||||||
raise ValueError("Invalid system message type")
|
raise ValueError("Invalid system message type")
|
||||||
|
|
||||||
message_id = await handler(channel_id, *args, **kwargs)
|
message_id = await handler(channel_id, *args, **kwargs)
|
||||||
|
|
||||||
message = await app.storage.get_message(message_id)
|
message = await app.storage.get_message(message_id)
|
||||||
|
await app.dispatcher.channel.dispatch(channel_id, ("MESSAGE_CREATE", message))
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
|
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
|
||||||
|
|
@ -252,7 +252,7 @@ async def maybe_lazy_guild_dispatch(
|
||||||
if isinstance(role, dict) and not role["hoist"] and not force:
|
if isinstance(role, dict) and not role["hoist"] and not force:
|
||||||
return
|
return
|
||||||
|
|
||||||
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
await (getattr(app.lazy_guild, event))(guild_id, role)
|
||||||
|
|
||||||
|
|
||||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||||
|
|
|
||||||
|
|
@ -17,11 +17,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Dict, List
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
from litecord.voice.lvsp_conn import LVSPConnection
|
from litecord.voice.lvsp_conn import LVSPConnection
|
||||||
|
|
||||||
|
|
@ -42,15 +43,15 @@ class LVSPManager:
|
||||||
Spawns :class:`LVSPConnection` as needed, etc.
|
Spawns :class:`LVSPConnection` as needed, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app, voice):
|
def __init__(self, app_, voice):
|
||||||
self.app = app
|
self.app = app_
|
||||||
self.voice = voice
|
self.voice = voice
|
||||||
|
|
||||||
# map servers to LVSPConnection
|
# map servers to LVSPConnection
|
||||||
self.conns = {}
|
self.conns: Dict[str, LVSPConnection] = {}
|
||||||
|
|
||||||
# maps regions to server hostnames
|
# maps regions to server hostnames
|
||||||
self.servers = defaultdict(list)
|
self.servers: Dict[str, List[str]] = defaultdict(list)
|
||||||
|
|
||||||
# maps Union[GuildID, DMId, GroupDMId] to server hostnames
|
# maps Union[GuildID, DMId, GroupDMId] to server hostnames
|
||||||
self.assign = {}
|
self.assign = {}
|
||||||
|
|
@ -84,7 +85,7 @@ class LVSPManager:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.regions[region.id] = region
|
self.regions[region.id] = region
|
||||||
self.app.loop.create_task(self._spawn_region(region))
|
app.sched.spawn(self._spawn_region(region))
|
||||||
|
|
||||||
async def _spawn_region(self, region: Region):
|
async def _spawn_region(self, region: Region):
|
||||||
"""Spawn a region. Involves fetching all the hostnames
|
"""Spawn a region. Involves fetching all the hostnames
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
from typing import Tuple, Dict, List
|
from typing import Tuple, Dict, List
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
@ -286,6 +287,6 @@ class VoiceManager:
|
||||||
# slow, but it be like that, also copied from other users...
|
# slow, but it be like that, also copied from other users...
|
||||||
for guild_id in guild_ids:
|
for guild_id in guild_ids:
|
||||||
guild = await self.app.storage.get_guild_full(guild_id, None)
|
guild = await self.app.storage.get_guild_full(guild_id, None)
|
||||||
await self.app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
await app.dispatcher.guild.dispatch(guild_id, ("GUILD_UPDATE", guild))
|
||||||
|
|
||||||
# TODO propagate the channel deprecation to LVSP connections
|
# TODO propagate the channel deprecation to LVSP connections
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ def main(config):
|
||||||
# as the managers require it
|
# as the managers require it
|
||||||
# and the migrate command also sets the db up
|
# and the migrate command also sets the db up
|
||||||
if argv[1] != "migrate":
|
if argv[1] != "migrate":
|
||||||
init_app_managers(app, voice=False)
|
init_app_managers(app, init_voice=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
loop.run_until_complete(_ctx_wrapper(app, args))
|
loop.run_until_complete(_ctx_wrapper(app, args))
|
||||||
|
|
|
||||||
8
run.py
8
run.py
|
|
@ -95,6 +95,7 @@ from litecord.images import IconManager
|
||||||
from litecord.jobs import JobManager
|
from litecord.jobs import JobManager
|
||||||
from litecord.voice.manager import VoiceManager
|
from litecord.voice.manager import VoiceManager
|
||||||
from litecord.guild_memory_store import GuildMemoryStore
|
from litecord.guild_memory_store import GuildMemoryStore
|
||||||
|
from litecord.pubsub.lazy_guild import LazyGuildManager
|
||||||
|
|
||||||
from litecord.gateway.gateway import websocket_handler
|
from litecord.gateway.gateway import websocket_handler
|
||||||
|
|
||||||
|
|
@ -254,7 +255,7 @@ async def init_app_db(app_):
|
||||||
app_.sched = JobManager()
|
app_.sched = JobManager()
|
||||||
|
|
||||||
|
|
||||||
def init_app_managers(app_, *, voice=True):
|
def init_app_managers(app_: Quart, *, init_voice=True):
|
||||||
"""Initialize singleton classes."""
|
"""Initialize singleton classes."""
|
||||||
app_.loop = asyncio.get_event_loop()
|
app_.loop = asyncio.get_event_loop()
|
||||||
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
|
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
|
||||||
|
|
@ -265,7 +266,7 @@ def init_app_managers(app_, *, voice=True):
|
||||||
|
|
||||||
app_.icons = IconManager(app_)
|
app_.icons = IconManager(app_)
|
||||||
|
|
||||||
app_.dispatcher = EventDispatcher(app_)
|
app_.dispatcher = EventDispatcher()
|
||||||
app_.presence = PresenceManager(app_)
|
app_.presence = PresenceManager(app_)
|
||||||
|
|
||||||
app_.storage.presence = app_.presence
|
app_.storage.presence = app_.presence
|
||||||
|
|
@ -274,10 +275,11 @@ def init_app_managers(app_, *, voice=True):
|
||||||
# we do this because of a bug on ./manage.py where it
|
# we do this because of a bug on ./manage.py where it
|
||||||
# cancels the LVSPManager's spawn regions task. we don't
|
# cancels the LVSPManager's spawn regions task. we don't
|
||||||
# need to start it on manage time.
|
# need to start it on manage time.
|
||||||
if voice:
|
if init_voice:
|
||||||
app_.voice = VoiceManager(app_)
|
app_.voice = VoiceManager(app_)
|
||||||
|
|
||||||
app_.guild_store = GuildMemoryStore()
|
app_.guild_store = GuildMemoryStore()
|
||||||
|
app_.lazy_guild = LazyGuildManager()
|
||||||
|
|
||||||
|
|
||||||
async def api_index(app_):
|
async def api_index(app_):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue