mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'feature/robust-state-invalidation' into 'master'
Robust state invalidation See merge request litecord/litecord!62
This commit is contained in:
commit
fd82046f2c
|
|
@ -268,7 +268,7 @@ async def user_disconnect(user_id: int):
|
|||
|
||||
for state in user_states:
|
||||
# make it unable to resume
|
||||
app.state_manager.remove(state)
|
||||
app.state_manager.remove(state.session_id, user_id=user_id)
|
||||
|
||||
if not state.ws:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import asyncio
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from collections import defaultdict
|
||||
|
||||
from quart import current_app as app
|
||||
|
|
@ -96,6 +96,8 @@ class StateManager:
|
|||
#: raw mapping from session ids to GatewayState
|
||||
self.states_raw = StateDictWrapper(self, {})
|
||||
|
||||
self.tasks = {}
|
||||
|
||||
def insert(self, state: GatewayState):
|
||||
"""Insert a new state object."""
|
||||
user_states = self.states[state.user_id]
|
||||
|
|
@ -119,21 +121,20 @@ class StateManager:
|
|||
"""Fetch a single state given the Session ID."""
|
||||
return self.states_raw[session_id]
|
||||
|
||||
def remove(self, state):
|
||||
def remove(self, session_id: str, *, user_id: Optional[int] = None):
|
||||
"""Remove a state from the registry"""
|
||||
if not state:
|
||||
return
|
||||
|
||||
try:
|
||||
self.states_raw.pop(state.session_id)
|
||||
state = self.states_raw.pop(session_id)
|
||||
user_id = state.user_id
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
log.debug("removing state: {!r}", state)
|
||||
self.states[state.user_id].pop(state.session_id)
|
||||
except KeyError:
|
||||
pass
|
||||
if user_id is not None:
|
||||
try:
|
||||
log.debug("removing state: {!r}", state)
|
||||
self.states[state.user_id].pop(session_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]:
|
||||
"""Fetch all states that are tied to a guild."""
|
||||
|
|
@ -188,14 +189,14 @@ class StateManager:
|
|||
"""Send OP Reconnect to a single connection."""
|
||||
websocket = state.ws
|
||||
|
||||
await websocket.send({"op": OP.RECONNECT})
|
||||
|
||||
# wait 200ms
|
||||
# so that the client has time to process
|
||||
# our payload then close the connection
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
try:
|
||||
await websocket.send({"op": OP.RECONNECT})
|
||||
|
||||
# wait 200ms
|
||||
# so that the client has time to process
|
||||
# our payload then close the connection
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# try to close the connection ourselves
|
||||
await websocket.ws.close(code=4000, reason="litecord shutting down")
|
||||
except ConnectionClosed:
|
||||
|
|
@ -239,3 +240,21 @@ class StateManager:
|
|||
|
||||
# DMs and GDMs use all user states
|
||||
return self.user_states(user_id)
|
||||
|
||||
async def _future_cleanup(self, state: GatewayState):
|
||||
await asyncio.sleep(30)
|
||||
self.remove(state)
|
||||
state.ws.state = None
|
||||
state.ws = None
|
||||
|
||||
async def schedule_deletion(self, state: GatewayState):
|
||||
task = app.loop.create_task(self._future_cleanup(state))
|
||||
self.tasks[state.session_id] = task
|
||||
|
||||
async def unschedule_deletion(self, state: GatewayState):
|
||||
try:
|
||||
task = self.tasks.pop(state.session_id)
|
||||
except KeyError:
|
||||
return
|
||||
|
||||
task.cancel()
|
||||
|
|
|
|||
|
|
@ -267,9 +267,15 @@ class GatewayWebsocket:
|
|||
|
||||
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
|
||||
|
||||
await self.send(payload)
|
||||
try:
|
||||
await self.send(payload)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
log.warning(
|
||||
"Failed to dispatch {!r} to {}", event.upper, self.state.session_id
|
||||
)
|
||||
|
||||
async def _make_guild_list(self) -> List[Dict[str, Any]]:
|
||||
assert self.state is not None
|
||||
user_id = self.state.user_id
|
||||
|
||||
guild_ids = await self._guild_ids()
|
||||
|
|
@ -764,10 +770,11 @@ class GatewayWebsocket:
|
|||
# since the state will be removed from
|
||||
# the manager, it will become unreachable
|
||||
# when trying to resume.
|
||||
self.app.state_manager.remove(self.state)
|
||||
self.app.state_manager.remove(self.state.user_id)
|
||||
|
||||
async def _resume(self, replay_seqs: Iterable):
|
||||
presences = []
|
||||
assert self.state is not None
|
||||
presences: List[dict] = []
|
||||
|
||||
try:
|
||||
for seq in replay_seqs:
|
||||
|
|
@ -824,6 +831,7 @@ class GatewayWebsocket:
|
|||
return await self.invalidate_session(False)
|
||||
|
||||
# relink this connection
|
||||
await self.app.state_manager.unschedule_deletion(state)
|
||||
self.state = state
|
||||
state.ws = self
|
||||
|
||||
|
|
@ -1085,8 +1093,8 @@ class GatewayWebsocket:
|
|||
task.cancel()
|
||||
|
||||
if self.state:
|
||||
self.app.state_manager.remove(self.state)
|
||||
self.state.ws = None
|
||||
self.app.state_manager.schedule_deletion(self.state)
|
||||
self.state = None
|
||||
|
||||
async def _check_conns(self, user_id):
|
||||
|
|
|
|||
|
|
@ -176,3 +176,79 @@ async def test_etf(test_cli):
|
|||
assert hello["op"] == OP.HELLO
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume(test_cli_user):
|
||||
conn = await gw_start(test_cli_user.cli)
|
||||
|
||||
# get the hello frame but ignore it
|
||||
await _json(conn)
|
||||
|
||||
await _json_send(
|
||||
conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}
|
||||
)
|
||||
|
||||
try:
|
||||
ready = await _json(conn)
|
||||
assert isinstance(ready, dict)
|
||||
assert ready["op"] == OP.DISPATCH
|
||||
assert ready["t"] == "READY"
|
||||
|
||||
data = ready["d"]
|
||||
assert isinstance(data, dict)
|
||||
|
||||
assert isinstance(data["session_id"], str)
|
||||
sess_id: str = data["session_id"]
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
||||
# try to resume
|
||||
conn = await gw_start(test_cli_user.cli)
|
||||
_ = await _json(conn)
|
||||
|
||||
await _json_send(
|
||||
conn,
|
||||
{
|
||||
"op": OP.RESUME,
|
||||
"d": {
|
||||
"token": test_cli_user.user["token"],
|
||||
"session_id": sess_id,
|
||||
"seq": 0,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
msg = await _json(conn)
|
||||
assert isinstance(msg, dict)
|
||||
assert isinstance(msg["op"], int)
|
||||
assert msg["op"] == OP.DISPATCH
|
||||
assert isinstance(msg["t"], str)
|
||||
assert msg["t"] in ("RESUMED", "PRESENCE_REPLACE")
|
||||
|
||||
# close again, and retry again, but this time by removing the state
|
||||
# and asserting the session won't be resumed.
|
||||
await _close(conn)
|
||||
|
||||
conn = await gw_start(test_cli_user.cli)
|
||||
_ = await _json(conn)
|
||||
|
||||
async with test_cli_user.app.app_context():
|
||||
test_cli_user.app.state_manager.remove(sess_id)
|
||||
|
||||
await _json_send(
|
||||
conn,
|
||||
{
|
||||
"op": OP.RESUME,
|
||||
"d": {
|
||||
"token": test_cli_user.user["token"],
|
||||
"session_id": sess_id,
|
||||
"seq": 0,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
msg = await _json(conn)
|
||||
assert isinstance(msg, dict)
|
||||
assert isinstance(msg["op"], int)
|
||||
assert msg["op"] == OP.INVALID_SESSION
|
||||
|
|
|
|||
Loading…
Reference in New Issue