gateway.websocket: remove WebsocketObjects

we can just use the app object directly.
This commit is contained in:
Luna 2019-10-25 09:41:52 -03:00
parent 7c878515e9
commit 7278c15d9c
1 changed files with 32 additions and 58 deletions

View File

@ -21,7 +21,7 @@ import collections
import asyncio
import pprint
import zlib
from typing import List, Dict, Any
from typing import List, Dict, Any, Iterable
from random import randint
import websockets
@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple(
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
)
WebsocketObjects = collections.namedtuple(
"WebsocketObjects",
(
"db",
"state_manager",
"storage",
"loop",
"dispatcher",
"presence",
"ratelimiter",
"user_storage",
"voice",
),
)
class GatewayWebsocket:
"""Main gateway websocket logic."""
def __init__(self, ws, app, **kwargs):
self.ext = WebsocketObjects(
app.db,
app.state_manager,
app.storage,
app.loop,
app.dispatcher,
app.presence,
app.ratelimiter,
app.user_storage,
app.voice,
)
self.storage = self.ext.storage
self.user_storage = self.ext.user_storage
self.presence = self.ext.presence
self.app = app
self.storage = app.storage
self.user_storage = app.user_storage
self.presence = app.presence
self.ws = ws
self.wsp = WebsocketProperties(
@ -225,7 +199,7 @@ class GatewayWebsocket:
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -245,7 +219,7 @@ class GatewayWebsocket:
if task:
task.cancel()
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
task_wrapper("hb wait", self._hb_wait(interval))
)
@ -330,7 +304,7 @@ class GatewayWebsocket:
if r["type"] == RelationshipType.FRIEND.value
]
friend_presences = await self.ext.presence.friend_presences(friend_ids)
friend_presences = await self.app.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id)
return {
@ -377,14 +351,14 @@ class GatewayWebsocket:
await self.dispatch("READY", {**base_ready, **user_ready})
# async dispatch of guilds
self.ext.loop.create_task(self._guild_dispatch(guilds))
self.app.loop.create_task(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id):
"""Check if the given `shard` value in IDENTIFY has good enough values.
"""
current_shard, shard_count = shard
guilds = await self.ext.db.fetchval(
guilds = await self.app.db.fetchval(
"""
SELECT COUNT(*)
FROM members
@ -460,7 +434,7 @@ class GatewayWebsocket:
("channel", gdm_ids),
]
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
if not self.state.bot:
# subscribe to all friends
@ -468,7 +442,7 @@ class GatewayWebsocket:
# when they come online)
friend_ids = await self.user_storage.get_friend_ids(user_id)
log.info("subscribing to {} friends", len(friend_ids))
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
async def update_status(self, status: dict):
"""Update the status of the current websocket connection."""
@ -520,7 +494,7 @@ class GatewayWebsocket:
f'Updating presence status={status["status"]} for '
f"uid={self.state.user_id}"
)
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
@ -558,13 +532,13 @@ class GatewayWebsocket:
presence = data.get("presence")
try:
user_id = await raw_token_check(token, self.ext.db)
user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval(
bot = await self.app.db.fetchval(
"""
SELECT bot FROM users
WHERE id = $1
@ -587,7 +561,7 @@ class GatewayWebsocket:
)
# link the state to the user
self.ext.state_manager.insert(self.state)
self.app.state_manager.insert(self.state)
await self.update_status(presence)
await self.subscribe_all(data.get("guild_subscriptions", True))
@ -631,12 +605,12 @@ class GatewayWebsocket:
# if its null and null, disconnect the user from any voice
# TODO: maybe just leave from DMs? idk...
if channel_id is None and guild_id is None:
return await self.ext.voice.leave_all(self.state.user_id)
return await self.app.voice.leave_all(self.state.user_id)
# if guild is not none but channel is, we are leaving
# a guild's channel
if channel_id is None:
return await self.ext.voice.leave(guild_id, self.state.user_id)
return await self.app.voice.leave(guild_id, self.state.user_id)
# fetch an existing state given user and guild OR user and channel
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
@ -659,10 +633,10 @@ class GatewayWebsocket:
# this state id format takes care of that.
voice_key = (self.state.user_id, state_id2)
voice_state = await self.ext.voice.get_state(voice_key)
voice_state = await self.app.voice.get_state(voice_key)
if voice_state is None:
return await self.ext.voice.create_state(voice_key, data)
return await self.app.voice.create_state(voice_key, data)
same_guild = guild_id == voice_state.guild_id
same_channel = channel_id == voice_state.channel_id
@ -670,10 +644,10 @@ class GatewayWebsocket:
prop = await self._vsu_get_prop(voice_state, data)
if same_guild and same_channel:
return await self.ext.voice.update_state(voice_state, prop)
return await self.app.voice.update_state(voice_state, prop)
if same_guild and not same_channel:
return await self.ext.voice.move_state(voice_state, channel_id)
return await self.app.voice.move_state(voice_state, channel_id)
async def _handle_5(self, payload: Dict[str, Any]):
"""Handle OP 5 Voice Server Ping.
@ -698,9 +672,9 @@ class GatewayWebsocket:
# since the state will be removed from
# the manager, it will become unreachable
# when trying to resume.
self.ext.state_manager.remove(self.state)
self.app.state_manager.remove(self.state)
async def _resume(self, replay_seqs: iter):
async def _resume(self, replay_seqs: Iterable):
presences = []
try:
@ -740,12 +714,12 @@ class GatewayWebsocket:
raise DecodeError("Invalid resume payload")
try:
user_id = await raw_token_check(token, self.ext.db)
user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Invalid token")
try:
state = self.ext.state_manager.fetch(user_id, sess_id)
state = self.app.state_manager.fetch(user_id, sess_id)
except KeyError:
return await self.invalidate_session(False)
@ -948,7 +922,7 @@ class GatewayWebsocket:
log.debug("lazy request: members: {}", data.get("members", []))
# make shard query
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
for chan_id, ranges in data.get("channels", {}).items():
chan_id = int(chan_id)
@ -992,10 +966,10 @@ class GatewayWebsocket:
# close anyone trying to login while the
# server is shutting down
if self.ext.state_manager.closed:
if self.app.state_manager.closed:
raise WebsocketClose(4000, "state manager closed")
if not self.ext.state_manager.accept_new:
if not self.app.state_manager.accept_new:
raise WebsocketClose(4000, "state manager closed for new")
while True:
@ -1016,7 +990,7 @@ class GatewayWebsocket:
task.cancel()
if self.state:
self.ext.state_manager.remove(self.state)
self.app.state_manager.remove(self.state)
self.state.ws = None
self.state = None
@ -1031,14 +1005,14 @@ class GatewayWebsocket:
# TODO: account for sharding
# this only updates status to offline once
# ALL shards have come offline
states = self.ext.state_manager.user_states(user_id)
states = self.app.state_manager.user_states(user_id)
with_ws = [s for s in states if s.ws]
# there arent any other states with websocket
if not with_ws:
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
await self.ext.presence.dispatch_pres(user_id, offline)
await self.app.presence.dispatch_pres(user_id, offline)
async def run(self):
"""Wrap :meth:`listen_messages` inside