gateway.websocket: remove WebsocketObjects

we can just use the app object directly.
This commit is contained in:
Luna 2019-10-25 09:41:52 -03:00
parent 7c878515e9
commit 7278c15d9c
1 changed files with 32 additions and 58 deletions

View File

@ -21,7 +21,7 @@ import collections
import asyncio import asyncio
import pprint import pprint
import zlib import zlib
from typing import List, Dict, Any from typing import List, Dict, Any, Iterable
from random import randint from random import randint
import websockets import websockets
@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple(
"WebsocketProperties", "v encoding compress zctx zsctx tasks" "WebsocketProperties", "v encoding compress zctx zsctx tasks"
) )
WebsocketObjects = collections.namedtuple(
"WebsocketObjects",
(
"db",
"state_manager",
"storage",
"loop",
"dispatcher",
"presence",
"ratelimiter",
"user_storage",
"voice",
),
)
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
def __init__(self, ws, app, **kwargs): def __init__(self, ws, app, **kwargs):
self.ext = WebsocketObjects( self.app = app
app.db, self.storage = app.storage
app.state_manager, self.user_storage = app.user_storage
app.storage, self.presence = app.presence
app.loop,
app.dispatcher,
app.presence,
app.ratelimiter,
app.user_storage,
app.voice,
)
self.storage = self.ext.storage
self.user_storage = self.ext.user_storage
self.presence = self.ext.presence
self.ws = ws self.ws = ws
self.wsp = WebsocketProperties( self.wsp = WebsocketProperties(
@ -225,7 +199,7 @@ class GatewayWebsocket:
await self.send({"op": op_code, "d": data, "t": None, "s": None}) await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key): def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}") ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key) bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit() return bucket.update_rate_limit()
@ -245,7 +219,7 @@ class GatewayWebsocket:
if task: if task:
task.cancel() task.cancel()
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task( self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
task_wrapper("hb wait", self._hb_wait(interval)) task_wrapper("hb wait", self._hb_wait(interval))
) )
@ -330,7 +304,7 @@ class GatewayWebsocket:
if r["type"] == RelationshipType.FRIEND.value if r["type"] == RelationshipType.FRIEND.value
] ]
friend_presences = await self.ext.presence.friend_presences(friend_ids) friend_presences = await self.app.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id) settings = await self.user_storage.get_user_settings(user_id)
return { return {
@ -377,14 +351,14 @@ class GatewayWebsocket:
await self.dispatch("READY", {**base_ready, **user_ready}) await self.dispatch("READY", {**base_ready, **user_ready})
# async dispatch of guilds # async dispatch of guilds
self.ext.loop.create_task(self._guild_dispatch(guilds)) self.app.loop.create_task(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id): async def _check_shards(self, shard, user_id):
"""Check if the given `shard` value in IDENTIFY has good enough values. """Check if the given `shard` value in IDENTIFY has good enough values.
""" """
current_shard, shard_count = shard current_shard, shard_count = shard
guilds = await self.ext.db.fetchval( guilds = await self.app.db.fetchval(
""" """
SELECT COUNT(*) SELECT COUNT(*)
FROM members FROM members
@ -460,7 +434,7 @@ class GatewayWebsocket:
("channel", gdm_ids), ("channel", gdm_ids),
] ]
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub) await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
if not self.state.bot: if not self.state.bot:
# subscribe to all friends # subscribe to all friends
@ -468,7 +442,7 @@ class GatewayWebsocket:
# when they come online) # when they come online)
friend_ids = await self.user_storage.get_friend_ids(user_id) friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info("subscribing to {} friends", len(friend_ids)) log.info("subscribing to {} friends", len(friend_ids))
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids) await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict): async def update_status(self, status: dict):
"""Update the status of the current websocket connection.""" """Update the status of the current websocket connection."""
@ -520,7 +494,7 @@ class GatewayWebsocket:
f'Updating presence status={status["status"]} for ' f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}" f"uid={self.state.user_id}"
) )
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence) await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def handle_1(self, payload: Dict[str, Any]): async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets.""" """Handle OP 1 Heartbeat packets."""
@ -558,13 +532,13 @@ class GatewayWebsocket:
presence = data.get("presence") presence = data.get("presence")
try: try:
user_id = await raw_token_check(token, self.ext.db) user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed") raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id) await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval( bot = await self.app.db.fetchval(
""" """
SELECT bot FROM users SELECT bot FROM users
WHERE id = $1 WHERE id = $1
@ -587,7 +561,7 @@ class GatewayWebsocket:
) )
# link the state to the user # link the state to the user
self.ext.state_manager.insert(self.state) self.app.state_manager.insert(self.state)
await self.update_status(presence) await self.update_status(presence)
await self.subscribe_all(data.get("guild_subscriptions", True)) await self.subscribe_all(data.get("guild_subscriptions", True))
@ -631,12 +605,12 @@ class GatewayWebsocket:
# if its null and null, disconnect the user from any voice # if its null and null, disconnect the user from any voice
# TODO: maybe just leave from DMs? idk... # TODO: maybe just leave from DMs? idk...
if channel_id is None and guild_id is None: if channel_id is None and guild_id is None:
return await self.ext.voice.leave_all(self.state.user_id) return await self.app.voice.leave_all(self.state.user_id)
# if guild is not none but channel is, we are leaving # if guild is not none but channel is, we are leaving
# a guild's channel # a guild's channel
if channel_id is None: if channel_id is None:
return await self.ext.voice.leave(guild_id, self.state.user_id) return await self.app.voice.leave(guild_id, self.state.user_id)
# fetch an existing state given user and guild OR user and channel # fetch an existing state given user and guild OR user and channel
chan_type = ChannelType(await self.storage.get_chan_type(channel_id)) chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
@ -659,10 +633,10 @@ class GatewayWebsocket:
# this state id format takes care of that. # this state id format takes care of that.
voice_key = (self.state.user_id, state_id2) voice_key = (self.state.user_id, state_id2)
voice_state = await self.ext.voice.get_state(voice_key) voice_state = await self.app.voice.get_state(voice_key)
if voice_state is None: if voice_state is None:
return await self.ext.voice.create_state(voice_key, data) return await self.app.voice.create_state(voice_key, data)
same_guild = guild_id == voice_state.guild_id same_guild = guild_id == voice_state.guild_id
same_channel = channel_id == voice_state.channel_id same_channel = channel_id == voice_state.channel_id
@ -670,10 +644,10 @@ class GatewayWebsocket:
prop = await self._vsu_get_prop(voice_state, data) prop = await self._vsu_get_prop(voice_state, data)
if same_guild and same_channel: if same_guild and same_channel:
return await self.ext.voice.update_state(voice_state, prop) return await self.app.voice.update_state(voice_state, prop)
if same_guild and not same_channel: if same_guild and not same_channel:
return await self.ext.voice.move_state(voice_state, channel_id) return await self.app.voice.move_state(voice_state, channel_id)
async def _handle_5(self, payload: Dict[str, Any]): async def _handle_5(self, payload: Dict[str, Any]):
"""Handle OP 5 Voice Server Ping. """Handle OP 5 Voice Server Ping.
@ -698,9 +672,9 @@ class GatewayWebsocket:
# since the state will be removed from # since the state will be removed from
# the manager, it will become unreachable # the manager, it will become unreachable
# when trying to resume. # when trying to resume.
self.ext.state_manager.remove(self.state) self.app.state_manager.remove(self.state)
async def _resume(self, replay_seqs: iter): async def _resume(self, replay_seqs: Iterable):
presences = [] presences = []
try: try:
@ -740,12 +714,12 @@ class GatewayWebsocket:
raise DecodeError("Invalid resume payload") raise DecodeError("Invalid resume payload")
try: try:
user_id = await raw_token_check(token, self.ext.db) user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Invalid token") raise WebsocketClose(4004, "Invalid token")
try: try:
state = self.ext.state_manager.fetch(user_id, sess_id) state = self.app.state_manager.fetch(user_id, sess_id)
except KeyError: except KeyError:
return await self.invalidate_session(False) return await self.invalidate_session(False)
@ -948,7 +922,7 @@ class GatewayWebsocket:
log.debug("lazy request: members: {}", data.get("members", [])) log.debug("lazy request: members: {}", data.get("members", []))
# make shard query # make shard query
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"] lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get("channels", {}).items(): for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id) chan_id = int(chan_id)
@ -992,10 +966,10 @@ class GatewayWebsocket:
# close anyone trying to login while the # close anyone trying to login while the
# server is shutting down # server is shutting down
if self.ext.state_manager.closed: if self.app.state_manager.closed:
raise WebsocketClose(4000, "state manager closed") raise WebsocketClose(4000, "state manager closed")
if not self.ext.state_manager.accept_new: if not self.app.state_manager.accept_new:
raise WebsocketClose(4000, "state manager closed for new") raise WebsocketClose(4000, "state manager closed for new")
while True: while True:
@ -1016,7 +990,7 @@ class GatewayWebsocket:
task.cancel() task.cancel()
if self.state: if self.state:
self.ext.state_manager.remove(self.state) self.app.state_manager.remove(self.state)
self.state.ws = None self.state.ws = None
self.state = None self.state = None
@ -1031,14 +1005,14 @@ class GatewayWebsocket:
# TODO: account for sharding # TODO: account for sharding
# this only updates status to offline once # this only updates status to offline once
# ALL shards have come offline # ALL shards have come offline
states = self.ext.state_manager.user_states(user_id) states = self.app.state_manager.user_states(user_id)
with_ws = [s for s in states if s.ws] with_ws = [s for s in states if s.ws]
# there arent any other states with websocket # there arent any other states with websocket
if not with_ws: if not with_ws:
offline = {"afk": False, "status": "offline", "game": None, "since": 0} offline = {"afk": False, "status": "offline", "game": None, "since": 0}
await self.ext.presence.dispatch_pres(user_id, offline) await self.app.presence.dispatch_pres(user_id, offline)
async def run(self): async def run(self):
"""Wrap :meth:`listen_messages` inside """Wrap :meth:`listen_messages` inside