mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: remove WebsocketObjects
we can just use the app object directly.
This commit is contained in:
parent
7c878515e9
commit
7278c15d9c
|
|
@ -21,7 +21,7 @@ import collections
|
||||||
import asyncio
|
import asyncio
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Iterable
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple(
|
||||||
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
|
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
|
||||||
"WebsocketObjects",
|
|
||||||
(
|
|
||||||
"db",
|
|
||||||
"state_manager",
|
|
||||||
"storage",
|
|
||||||
"loop",
|
|
||||||
"dispatcher",
|
|
||||||
"presence",
|
|
||||||
"ratelimiter",
|
|
||||||
"user_storage",
|
|
||||||
"voice",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
|
|
||||||
def __init__(self, ws, app, **kwargs):
|
def __init__(self, ws, app, **kwargs):
|
||||||
self.ext = WebsocketObjects(
|
self.app = app
|
||||||
app.db,
|
self.storage = app.storage
|
||||||
app.state_manager,
|
self.user_storage = app.user_storage
|
||||||
app.storage,
|
self.presence = app.presence
|
||||||
app.loop,
|
|
||||||
app.dispatcher,
|
|
||||||
app.presence,
|
|
||||||
app.ratelimiter,
|
|
||||||
app.user_storage,
|
|
||||||
app.voice,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.storage = self.ext.storage
|
|
||||||
self.user_storage = self.ext.user_storage
|
|
||||||
self.presence = self.ext.presence
|
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
self.wsp = WebsocketProperties(
|
self.wsp = WebsocketProperties(
|
||||||
|
|
@ -225,7 +199,7 @@ class GatewayWebsocket:
|
||||||
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||||
|
|
||||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
def _check_ratelimit(self, key: str, ratelimit_key):
|
||||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
|
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
|
||||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
return bucket.update_rate_limit()
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
|
|
@ -245,7 +219,7 @@ class GatewayWebsocket:
|
||||||
if task:
|
if task:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
|
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
|
||||||
task_wrapper("hb wait", self._hb_wait(interval))
|
task_wrapper("hb wait", self._hb_wait(interval))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -330,7 +304,7 @@ class GatewayWebsocket:
|
||||||
if r["type"] == RelationshipType.FRIEND.value
|
if r["type"] == RelationshipType.FRIEND.value
|
||||||
]
|
]
|
||||||
|
|
||||||
friend_presences = await self.ext.presence.friend_presences(friend_ids)
|
friend_presences = await self.app.presence.friend_presences(friend_ids)
|
||||||
settings = await self.user_storage.get_user_settings(user_id)
|
settings = await self.user_storage.get_user_settings(user_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -377,14 +351,14 @@ class GatewayWebsocket:
|
||||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||||
|
|
||||||
# async dispatch of guilds
|
# async dispatch of guilds
|
||||||
self.ext.loop.create_task(self._guild_dispatch(guilds))
|
self.app.loop.create_task(self._guild_dispatch(guilds))
|
||||||
|
|
||||||
async def _check_shards(self, shard, user_id):
|
async def _check_shards(self, shard, user_id):
|
||||||
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
||||||
"""
|
"""
|
||||||
current_shard, shard_count = shard
|
current_shard, shard_count = shard
|
||||||
|
|
||||||
guilds = await self.ext.db.fetchval(
|
guilds = await self.app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM members
|
FROM members
|
||||||
|
|
@ -460,7 +434,7 @@ class GatewayWebsocket:
|
||||||
("channel", gdm_ids),
|
("channel", gdm_ids),
|
||||||
]
|
]
|
||||||
|
|
||||||
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||||
|
|
||||||
if not self.state.bot:
|
if not self.state.bot:
|
||||||
# subscribe to all friends
|
# subscribe to all friends
|
||||||
|
|
@ -468,7 +442,7 @@ class GatewayWebsocket:
|
||||||
# when they come online)
|
# when they come online)
|
||||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||||
log.info("subscribing to {} friends", len(friend_ids))
|
log.info("subscribing to {} friends", len(friend_ids))
|
||||||
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
|
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
|
|
@ -520,7 +494,7 @@ class GatewayWebsocket:
|
||||||
f'Updating presence status={status["status"]} for '
|
f'Updating presence status={status["status"]} for '
|
||||||
f"uid={self.state.user_id}"
|
f"uid={self.state.user_id}"
|
||||||
)
|
)
|
||||||
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
||||||
|
|
||||||
async def handle_1(self, payload: Dict[str, Any]):
|
async def handle_1(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 1 Heartbeat packets."""
|
"""Handle OP 1 Heartbeat packets."""
|
||||||
|
|
@ -558,13 +532,13 @@ class GatewayWebsocket:
|
||||||
presence = data.get("presence")
|
presence = data.get("presence")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.app.db)
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, "Authentication failed")
|
raise WebsocketClose(4004, "Authentication failed")
|
||||||
|
|
||||||
await self._connect_ratelimit(user_id)
|
await self._connect_ratelimit(user_id)
|
||||||
|
|
||||||
bot = await self.ext.db.fetchval(
|
bot = await self.app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT bot FROM users
|
SELECT bot FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -587,7 +561,7 @@ class GatewayWebsocket:
|
||||||
)
|
)
|
||||||
|
|
||||||
# link the state to the user
|
# link the state to the user
|
||||||
self.ext.state_manager.insert(self.state)
|
self.app.state_manager.insert(self.state)
|
||||||
|
|
||||||
await self.update_status(presence)
|
await self.update_status(presence)
|
||||||
await self.subscribe_all(data.get("guild_subscriptions", True))
|
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||||
|
|
@ -631,12 +605,12 @@ class GatewayWebsocket:
|
||||||
# if its null and null, disconnect the user from any voice
|
# if its null and null, disconnect the user from any voice
|
||||||
# TODO: maybe just leave from DMs? idk...
|
# TODO: maybe just leave from DMs? idk...
|
||||||
if channel_id is None and guild_id is None:
|
if channel_id is None and guild_id is None:
|
||||||
return await self.ext.voice.leave_all(self.state.user_id)
|
return await self.app.voice.leave_all(self.state.user_id)
|
||||||
|
|
||||||
# if guild is not none but channel is, we are leaving
|
# if guild is not none but channel is, we are leaving
|
||||||
# a guild's channel
|
# a guild's channel
|
||||||
if channel_id is None:
|
if channel_id is None:
|
||||||
return await self.ext.voice.leave(guild_id, self.state.user_id)
|
return await self.app.voice.leave(guild_id, self.state.user_id)
|
||||||
|
|
||||||
# fetch an existing state given user and guild OR user and channel
|
# fetch an existing state given user and guild OR user and channel
|
||||||
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
|
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
|
||||||
|
|
@ -659,10 +633,10 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
# this state id format takes care of that.
|
# this state id format takes care of that.
|
||||||
voice_key = (self.state.user_id, state_id2)
|
voice_key = (self.state.user_id, state_id2)
|
||||||
voice_state = await self.ext.voice.get_state(voice_key)
|
voice_state = await self.app.voice.get_state(voice_key)
|
||||||
|
|
||||||
if voice_state is None:
|
if voice_state is None:
|
||||||
return await self.ext.voice.create_state(voice_key, data)
|
return await self.app.voice.create_state(voice_key, data)
|
||||||
|
|
||||||
same_guild = guild_id == voice_state.guild_id
|
same_guild = guild_id == voice_state.guild_id
|
||||||
same_channel = channel_id == voice_state.channel_id
|
same_channel = channel_id == voice_state.channel_id
|
||||||
|
|
@ -670,10 +644,10 @@ class GatewayWebsocket:
|
||||||
prop = await self._vsu_get_prop(voice_state, data)
|
prop = await self._vsu_get_prop(voice_state, data)
|
||||||
|
|
||||||
if same_guild and same_channel:
|
if same_guild and same_channel:
|
||||||
return await self.ext.voice.update_state(voice_state, prop)
|
return await self.app.voice.update_state(voice_state, prop)
|
||||||
|
|
||||||
if same_guild and not same_channel:
|
if same_guild and not same_channel:
|
||||||
return await self.ext.voice.move_state(voice_state, channel_id)
|
return await self.app.voice.move_state(voice_state, channel_id)
|
||||||
|
|
||||||
async def _handle_5(self, payload: Dict[str, Any]):
|
async def _handle_5(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 5 Voice Server Ping.
|
"""Handle OP 5 Voice Server Ping.
|
||||||
|
|
@ -698,9 +672,9 @@ class GatewayWebsocket:
|
||||||
# since the state will be removed from
|
# since the state will be removed from
|
||||||
# the manager, it will become unreachable
|
# the manager, it will become unreachable
|
||||||
# when trying to resume.
|
# when trying to resume.
|
||||||
self.ext.state_manager.remove(self.state)
|
self.app.state_manager.remove(self.state)
|
||||||
|
|
||||||
async def _resume(self, replay_seqs: iter):
|
async def _resume(self, replay_seqs: Iterable):
|
||||||
presences = []
|
presences = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -740,12 +714,12 @@ class GatewayWebsocket:
|
||||||
raise DecodeError("Invalid resume payload")
|
raise DecodeError("Invalid resume payload")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.app.db)
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, "Invalid token")
|
raise WebsocketClose(4004, "Invalid token")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state = self.ext.state_manager.fetch(user_id, sess_id)
|
state = self.app.state_manager.fetch(user_id, sess_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return await self.invalidate_session(False)
|
return await self.invalidate_session(False)
|
||||||
|
|
||||||
|
|
@ -948,7 +922,7 @@ class GatewayWebsocket:
|
||||||
log.debug("lazy request: members: {}", data.get("members", []))
|
log.debug("lazy request: members: {}", data.get("members", []))
|
||||||
|
|
||||||
# make shard query
|
# make shard query
|
||||||
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
|
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
|
||||||
|
|
||||||
for chan_id, ranges in data.get("channels", {}).items():
|
for chan_id, ranges in data.get("channels", {}).items():
|
||||||
chan_id = int(chan_id)
|
chan_id = int(chan_id)
|
||||||
|
|
@ -992,10 +966,10 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
# close anyone trying to login while the
|
# close anyone trying to login while the
|
||||||
# server is shutting down
|
# server is shutting down
|
||||||
if self.ext.state_manager.closed:
|
if self.app.state_manager.closed:
|
||||||
raise WebsocketClose(4000, "state manager closed")
|
raise WebsocketClose(4000, "state manager closed")
|
||||||
|
|
||||||
if not self.ext.state_manager.accept_new:
|
if not self.app.state_manager.accept_new:
|
||||||
raise WebsocketClose(4000, "state manager closed for new")
|
raise WebsocketClose(4000, "state manager closed for new")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -1016,7 +990,7 @@ class GatewayWebsocket:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
self.ext.state_manager.remove(self.state)
|
self.app.state_manager.remove(self.state)
|
||||||
self.state.ws = None
|
self.state.ws = None
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
|
|
@ -1031,14 +1005,14 @@ class GatewayWebsocket:
|
||||||
# TODO: account for sharding
|
# TODO: account for sharding
|
||||||
# this only updates status to offline once
|
# this only updates status to offline once
|
||||||
# ALL shards have come offline
|
# ALL shards have come offline
|
||||||
states = self.ext.state_manager.user_states(user_id)
|
states = self.app.state_manager.user_states(user_id)
|
||||||
with_ws = [s for s in states if s.ws]
|
with_ws = [s for s in states if s.ws]
|
||||||
|
|
||||||
# there arent any other states with websocket
|
# there arent any other states with websocket
|
||||||
if not with_ws:
|
if not with_ws:
|
||||||
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||||
|
|
||||||
await self.ext.presence.dispatch_pres(user_id, offline)
|
await self.app.presence.dispatch_pres(user_id, offline)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Wrap :meth:`listen_messages` inside
|
"""Wrap :meth:`listen_messages` inside
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue