mirror of https://gitlab.com/litecord/litecord.git
Add some lazy loading handling
Mostly adding guild_id to some events. It isn't complete support for
them, but its some of the way there.
- storage: give guild_id on get_message
- gateway.websocket: decrease logging for some stuff
- a debug log for the whole packet is still there for development
reasons, maybe i'll put it on a config option.
- gateway.websocket: dispatch an offline presence when the user has no
connections tied to them anymore
This commit is contained in:
parent
b06c07c097
commit
02f2ee6943
|
|
@ -153,7 +153,7 @@ async def edit_message(channel_id, message_id):
|
|||
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
# only dispatch MESSAGE_CREATE if we actually had any update to start with
|
||||
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
|
||||
if updated:
|
||||
await app.dispatcher.dispatch_guild(guild_id,
|
||||
'MESSAGE_UPDATE', message)
|
||||
|
|
@ -182,7 +182,10 @@ async def delete_message(channel_id, message_id):
|
|||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_DELETE', {
|
||||
'id': str(message_id),
|
||||
'channel_id': str(channel_id)
|
||||
'channel_id': str(channel_id),
|
||||
|
||||
# for lazy guilds
|
||||
'guild_id': str(guild_id),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
|
@ -280,6 +283,9 @@ async def trigger_typing(channel_id):
|
|||
'channel_id': str(channel_id),
|
||||
'user_id': str(user_id),
|
||||
'timestamp': int(time.time()),
|
||||
|
||||
# guild_id for lazy guilds
|
||||
'guild_id': str(guild_id),
|
||||
})
|
||||
|
||||
return '', 204
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class GatewayState:
|
|||
self.user_id = kwargs.get('user_id')
|
||||
self.bot = kwargs.get('bot', False)
|
||||
self.presence = {}
|
||||
self.ws = None
|
||||
self.store = PayloadStore()
|
||||
|
||||
for key in kwargs:
|
||||
|
|
@ -59,5 +60,5 @@ class GatewayState:
|
|||
self.__dict__[key] = value
|
||||
|
||||
def __repr__(self):
|
||||
return (f'GatewayState<session={self.session_id} seq={self.seq} '
|
||||
return (f'GatewayState<seq={self.seq} '
|
||||
f'shard={self.shard} uid={self.user_id}>')
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class StateManager:
|
|||
"""Insert a new state object."""
|
||||
user_states = self.states[state.user_id]
|
||||
|
||||
log.info('inserting state: {!r}', state)
|
||||
log.debug('inserting state: {!r}', state)
|
||||
user_states[state.session_id] = state
|
||||
|
||||
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||
|
|
@ -32,7 +32,7 @@ class StateManager:
|
|||
return
|
||||
|
||||
try:
|
||||
log.info('removing state: {!r}', state)
|
||||
log.debug('removing state: {!r}', state)
|
||||
self.states[state.user_id].pop(state.session_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -83,9 +83,17 @@ class GatewayWebsocket:
|
|||
This function accounts for the zlib-stream
|
||||
transport method used by Discord.
|
||||
"""
|
||||
log.debug('sending {}', pprint.pformat(payload))
|
||||
encoded = self.encoder(payload)
|
||||
|
||||
if len(encoded) < 1024:
|
||||
log.debug('sending {}', pprint.pformat(payload))
|
||||
else:
|
||||
log.debug('sending {}', pprint.pformat(payload))
|
||||
log.debug('sending op={} s={} t={} (too big)',
|
||||
payload.get('op'),
|
||||
payload.get('s'),
|
||||
payload.get('t'))
|
||||
|
||||
if not isinstance(encoded, bytes):
|
||||
encoded = encoded.encode()
|
||||
|
||||
|
|
@ -100,9 +108,13 @@ class GatewayWebsocket:
|
|||
|
||||
async def _hb_wait(self, interval: int):
|
||||
"""Wait heartbeat"""
|
||||
# if the client heartbeats in time,
|
||||
# this task will be cancelled.
|
||||
await asyncio.sleep(interval / 1000)
|
||||
await self.ws.close(4000, 'Heartbeat expired')
|
||||
|
||||
self._cleanup()
|
||||
|
||||
def _hb_start(self, interval: int):
|
||||
# always refresh the heartbeat task
|
||||
# when possible
|
||||
|
|
@ -362,9 +374,9 @@ class GatewayWebsocket:
|
|||
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
"""Handle the OP 2 Identify packet."""
|
||||
data = payload['d']
|
||||
try:
|
||||
token, properties = data['token'], data['properties']
|
||||
data = payload['d']
|
||||
token = data['token']
|
||||
except KeyError:
|
||||
raise DecodeError('Invalid identify parameters')
|
||||
|
||||
|
|
@ -387,7 +399,6 @@ class GatewayWebsocket:
|
|||
self.state = GatewayState(
|
||||
user_id=user_id,
|
||||
bot=bot,
|
||||
properties=properties,
|
||||
compress=compress,
|
||||
large=large,
|
||||
shard=shard,
|
||||
|
|
@ -413,6 +424,15 @@ class GatewayWebsocket:
|
|||
|
||||
async def handle_4(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 4 Voice Status Update."""
|
||||
data = payload['d']
|
||||
log.debug('got VSU cid={} gid={} deaf={} mute={} video={}',
|
||||
data.get('channel_id'),
|
||||
data.get('guild_id'),
|
||||
data.get('self_deaf'),
|
||||
data.get('self_mute'),
|
||||
data.get('self_video'))
|
||||
|
||||
# for now, do nothing
|
||||
pass
|
||||
|
||||
async def _handle_5(self, payload: Dict[str, Any]):
|
||||
|
|
@ -629,6 +649,7 @@ class GatewayWebsocket:
|
|||
return
|
||||
|
||||
member_ids = await self.storage.get_member_ids(guild_id)
|
||||
log.debug('lazy: loading {} members', len(member_ids))
|
||||
|
||||
# the current implementation is rudimentary and only
|
||||
# generates two groups: online and offline, using
|
||||
|
|
@ -639,8 +660,6 @@ class GatewayWebsocket:
|
|||
guild_presences = await self.presence.guild_presences(member_ids,
|
||||
guild_id)
|
||||
|
||||
log.info('loading {} presences for guild', len(guild_presences))
|
||||
|
||||
online = [{'member': p}
|
||||
for p in guild_presences
|
||||
if p['status'] == 'online']
|
||||
|
|
@ -648,6 +667,11 @@ class GatewayWebsocket:
|
|||
for p in guild_presences
|
||||
if p['status'] == 'offline']
|
||||
|
||||
log.debug('lazy: {} presences, online={}, offline={}',
|
||||
len(guild_presences),
|
||||
len(online),
|
||||
len(offline))
|
||||
|
||||
# construct items in the WORST WAY POSSIBLE.
|
||||
items = [{
|
||||
'group': {
|
||||
|
|
@ -706,12 +730,42 @@ class GatewayWebsocket:
|
|||
raise DecodeError('Payload length exceeded')
|
||||
|
||||
payload = self.decoder(message)
|
||||
|
||||
log.debug('received message: {}',
|
||||
pprint.pformat(payload))
|
||||
|
||||
await self.process_message(payload)
|
||||
|
||||
def _cleanup(self):
|
||||
if self.state:
|
||||
self.ext.state_manager.remove(self.state)
|
||||
self.state.ws = None
|
||||
self.state = None
|
||||
|
||||
async def _check_conns(self, user_id):
|
||||
"""Check if there are any existing connections.
|
||||
|
||||
If there aren't, dispatch a presence for offline.
|
||||
"""
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
# TODO: account for sharding
|
||||
# this only updates status to offline once
|
||||
# ALL shards have come offline
|
||||
states = self.ext.state_manager.user_states(user_id)
|
||||
with_ws = [s for s in states if s.ws]
|
||||
|
||||
# there arent any other states with websocket
|
||||
if not with_ws:
|
||||
offline = {
|
||||
'afk': False,
|
||||
'status': 'offline',
|
||||
'game': None,
|
||||
'since': 0,
|
||||
}
|
||||
|
||||
await self.ext.presence.dispatch_pres(
|
||||
user_id,
|
||||
offline
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
"""Wrap listen_messages inside
|
||||
a try/except block for WebsocketClose handling."""
|
||||
|
|
@ -728,10 +782,6 @@ class GatewayWebsocket:
|
|||
log.exception('An exception has occoured. state={}', self.state)
|
||||
await self.ws.close(code=4000, reason=repr(err))
|
||||
finally:
|
||||
# TODO: move this to a heartbeat checker
|
||||
# instead of websocket cleanup
|
||||
self.ext.state_manager.remove(self.state)
|
||||
|
||||
# disconnect the state from the websocket
|
||||
if self.state:
|
||||
self.state.ws = None
|
||||
user_id = self.state.user_id if self.state else None
|
||||
self._cleanup()
|
||||
await self._check_conns(user_id)
|
||||
|
|
|
|||
|
|
@ -400,6 +400,18 @@ class Storage:
|
|||
# TODO: res['pinned']
|
||||
res['pinned'] = False
|
||||
|
||||
# this is specifically for lazy guilds.
|
||||
guild_id = await self.db.fetchval("""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
WHERE guild_channels.id = $1
|
||||
""", res['channel_id'])
|
||||
|
||||
# only insert when the channel
|
||||
# is actually from a guild.
|
||||
if guild_id:
|
||||
res['guild_id'] = guild_id
|
||||
|
||||
return res
|
||||
|
||||
async def fetch_notes(self, user_id: int) -> dict:
|
||||
|
|
|
|||
Loading…
Reference in New Issue