mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'feature/robust-state-invalidation' into 'master'
Robust state invalidation See merge request litecord/litecord!62
This commit is contained in:
commit
fd82046f2c
|
|
@ -268,7 +268,7 @@ async def user_disconnect(user_id: int):
|
||||||
|
|
||||||
for state in user_states:
|
for state in user_states:
|
||||||
# make it unable to resume
|
# make it unable to resume
|
||||||
app.state_manager.remove(state)
|
app.state_manager.remove(state.session_id, user_id=user_id)
|
||||||
|
|
||||||
if not state.ws:
|
if not state.ws:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
@ -96,6 +96,8 @@ class StateManager:
|
||||||
#: raw mapping from session ids to GatewayState
|
#: raw mapping from session ids to GatewayState
|
||||||
self.states_raw = StateDictWrapper(self, {})
|
self.states_raw = StateDictWrapper(self, {})
|
||||||
|
|
||||||
|
self.tasks = {}
|
||||||
|
|
||||||
def insert(self, state: GatewayState):
|
def insert(self, state: GatewayState):
|
||||||
"""Insert a new state object."""
|
"""Insert a new state object."""
|
||||||
user_states = self.states[state.user_id]
|
user_states = self.states[state.user_id]
|
||||||
|
|
@ -119,21 +121,20 @@ class StateManager:
|
||||||
"""Fetch a single state given the Session ID."""
|
"""Fetch a single state given the Session ID."""
|
||||||
return self.states_raw[session_id]
|
return self.states_raw[session_id]
|
||||||
|
|
||||||
def remove(self, state):
|
def remove(self, session_id: str, *, user_id: Optional[int] = None):
|
||||||
"""Remove a state from the registry"""
|
"""Remove a state from the registry"""
|
||||||
if not state:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.states_raw.pop(state.session_id)
|
state = self.states_raw.pop(session_id)
|
||||||
|
user_id = state.user_id
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
if user_id is not None:
|
||||||
log.debug("removing state: {!r}", state)
|
try:
|
||||||
self.states[state.user_id].pop(state.session_id)
|
log.debug("removing state: {!r}", state)
|
||||||
except KeyError:
|
self.states[state.user_id].pop(session_id)
|
||||||
pass
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]:
|
def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]:
|
||||||
"""Fetch all states that are tied to a guild."""
|
"""Fetch all states that are tied to a guild."""
|
||||||
|
|
@ -188,14 +189,14 @@ class StateManager:
|
||||||
"""Send OP Reconnect to a single connection."""
|
"""Send OP Reconnect to a single connection."""
|
||||||
websocket = state.ws
|
websocket = state.ws
|
||||||
|
|
||||||
await websocket.send({"op": OP.RECONNECT})
|
|
||||||
|
|
||||||
# wait 200ms
|
|
||||||
# so that the client has time to process
|
|
||||||
# our payload then close the connection
|
|
||||||
await asyncio.sleep(0.2)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
await websocket.send({"op": OP.RECONNECT})
|
||||||
|
|
||||||
|
# wait 200ms
|
||||||
|
# so that the client has time to process
|
||||||
|
# our payload then close the connection
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
# try to close the connection ourselves
|
# try to close the connection ourselves
|
||||||
await websocket.ws.close(code=4000, reason="litecord shutting down")
|
await websocket.ws.close(code=4000, reason="litecord shutting down")
|
||||||
except ConnectionClosed:
|
except ConnectionClosed:
|
||||||
|
|
@ -239,3 +240,21 @@ class StateManager:
|
||||||
|
|
||||||
# DMs and GDMs use all user states
|
# DMs and GDMs use all user states
|
||||||
return self.user_states(user_id)
|
return self.user_states(user_id)
|
||||||
|
|
||||||
|
async def _future_cleanup(self, state: GatewayState):
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
self.remove(state)
|
||||||
|
state.ws.state = None
|
||||||
|
state.ws = None
|
||||||
|
|
||||||
|
async def schedule_deletion(self, state: GatewayState):
|
||||||
|
task = app.loop.create_task(self._future_cleanup(state))
|
||||||
|
self.tasks[state.session_id] = task
|
||||||
|
|
||||||
|
async def unschedule_deletion(self, state: GatewayState):
|
||||||
|
try:
|
||||||
|
task = self.tasks.pop(state.session_id)
|
||||||
|
except KeyError:
|
||||||
|
return
|
||||||
|
|
||||||
|
task.cancel()
|
||||||
|
|
|
||||||
|
|
@ -267,9 +267,15 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
|
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
|
||||||
|
|
||||||
await self.send(payload)
|
try:
|
||||||
|
await self.send(payload)
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
log.warning(
|
||||||
|
"Failed to dispatch {!r} to {}", event.upper, self.state.session_id
|
||||||
|
)
|
||||||
|
|
||||||
async def _make_guild_list(self) -> List[Dict[str, Any]]:
|
async def _make_guild_list(self) -> List[Dict[str, Any]]:
|
||||||
|
assert self.state is not None
|
||||||
user_id = self.state.user_id
|
user_id = self.state.user_id
|
||||||
|
|
||||||
guild_ids = await self._guild_ids()
|
guild_ids = await self._guild_ids()
|
||||||
|
|
@ -764,10 +770,11 @@ class GatewayWebsocket:
|
||||||
# since the state will be removed from
|
# since the state will be removed from
|
||||||
# the manager, it will become unreachable
|
# the manager, it will become unreachable
|
||||||
# when trying to resume.
|
# when trying to resume.
|
||||||
self.app.state_manager.remove(self.state)
|
self.app.state_manager.remove(self.state.user_id)
|
||||||
|
|
||||||
async def _resume(self, replay_seqs: Iterable):
|
async def _resume(self, replay_seqs: Iterable):
|
||||||
presences = []
|
assert self.state is not None
|
||||||
|
presences: List[dict] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for seq in replay_seqs:
|
for seq in replay_seqs:
|
||||||
|
|
@ -824,6 +831,7 @@ class GatewayWebsocket:
|
||||||
return await self.invalidate_session(False)
|
return await self.invalidate_session(False)
|
||||||
|
|
||||||
# relink this connection
|
# relink this connection
|
||||||
|
await self.app.state_manager.unschedule_deletion(state)
|
||||||
self.state = state
|
self.state = state
|
||||||
state.ws = self
|
state.ws = self
|
||||||
|
|
||||||
|
|
@ -1085,8 +1093,8 @@ class GatewayWebsocket:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
self.app.state_manager.remove(self.state)
|
|
||||||
self.state.ws = None
|
self.state.ws = None
|
||||||
|
self.app.state_manager.schedule_deletion(self.state)
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
async def _check_conns(self, user_id):
|
async def _check_conns(self, user_id):
|
||||||
|
|
|
||||||
|
|
@ -176,3 +176,79 @@ async def test_etf(test_cli):
|
||||||
assert hello["op"] == OP.HELLO
|
assert hello["op"] == OP.HELLO
|
||||||
finally:
|
finally:
|
||||||
await _close(conn)
|
await _close(conn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume(test_cli_user):
|
||||||
|
conn = await gw_start(test_cli_user.cli)
|
||||||
|
|
||||||
|
# get the hello frame but ignore it
|
||||||
|
await _json(conn)
|
||||||
|
|
||||||
|
await _json_send(
|
||||||
|
conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ready = await _json(conn)
|
||||||
|
assert isinstance(ready, dict)
|
||||||
|
assert ready["op"] == OP.DISPATCH
|
||||||
|
assert ready["t"] == "READY"
|
||||||
|
|
||||||
|
data = ready["d"]
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
|
||||||
|
assert isinstance(data["session_id"], str)
|
||||||
|
sess_id: str = data["session_id"]
|
||||||
|
finally:
|
||||||
|
await _close(conn)
|
||||||
|
|
||||||
|
# try to resume
|
||||||
|
conn = await gw_start(test_cli_user.cli)
|
||||||
|
_ = await _json(conn)
|
||||||
|
|
||||||
|
await _json_send(
|
||||||
|
conn,
|
||||||
|
{
|
||||||
|
"op": OP.RESUME,
|
||||||
|
"d": {
|
||||||
|
"token": test_cli_user.user["token"],
|
||||||
|
"session_id": sess_id,
|
||||||
|
"seq": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await _json(conn)
|
||||||
|
assert isinstance(msg, dict)
|
||||||
|
assert isinstance(msg["op"], int)
|
||||||
|
assert msg["op"] == OP.DISPATCH
|
||||||
|
assert isinstance(msg["t"], str)
|
||||||
|
assert msg["t"] in ("RESUMED", "PRESENCE_REPLACE")
|
||||||
|
|
||||||
|
# close again, and retry again, but this time by removing the state
|
||||||
|
# and asserting the session won't be resumed.
|
||||||
|
await _close(conn)
|
||||||
|
|
||||||
|
conn = await gw_start(test_cli_user.cli)
|
||||||
|
_ = await _json(conn)
|
||||||
|
|
||||||
|
async with test_cli_user.app.app_context():
|
||||||
|
test_cli_user.app.state_manager.remove(sess_id)
|
||||||
|
|
||||||
|
await _json_send(
|
||||||
|
conn,
|
||||||
|
{
|
||||||
|
"op": OP.RESUME,
|
||||||
|
"d": {
|
||||||
|
"token": test_cli_user.user["token"],
|
||||||
|
"session_id": sess_id,
|
||||||
|
"seq": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await _json(conn)
|
||||||
|
assert isinstance(msg, dict)
|
||||||
|
assert isinstance(msg["op"], int)
|
||||||
|
assert msg["op"] == OP.INVALID_SESSION
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue