Rough presence management

Lots of changes to get it working.

One day changes will be able to be small enough to be split across
commits.

 - enums: use enum.Enum, make EasyEnum subclass
 - enums: add ActivityType, ActivityFlags, StatusType
 - gateway.state: use 128 random bits instead of 256
 - gateway.state: add MAX_STORE_SIZE in PayloadStore and check it when
    adding a new payload
 - gateway.websocket: add GatewayWebsocket.update_status
 - presence: add PresenceManager.dispatch_guild_pres and
    PresenceManager.dispatch_pres
 - schema: add snowflake, activity_type, status_external types
 - schema: add GW_ACTIVITY, GW_STATUS_UPDATE
 - storage: fix _channels_extra and fixes for ChannelType as enum instead of
    class
This commit is contained in:
Luna Mendes 2018-09-10 01:09:09 -03:00
parent d39783e666
commit cd5dbc4886
7 changed files with 284 additions and 29 deletions

View File

@ -1,5 +1,15 @@
import ctypes
class ChannelType:
from enum import Enum
class EasyEnum(Enum):
@classmethod
def values(cls):
return [v.value for v in cls.__members__.values()]
class ChannelType(EasyEnum):
GUILD_TEXT = 0
DM = 1
GUILD_VOICE = 2
@ -7,7 +17,13 @@ class ChannelType:
GUILD_CATEGORY = 4
class MessageType:
class ActivityType(EasyEnum):
PLAYING = 0
STREAMING = 1
LISTENING = 2
class MessageType(EasyEnum):
DEFAULT = 0
RECIPIENT_ADD = 1
RECIPIENT_REMOVE = 2
@ -18,8 +34,40 @@ class MessageType:
GUILD_MEMBER_JOIN = 7
class MessageActivityType:
class MessageActivityType(EasyEnum):
JOIN = 1
SPECTATE = 2
LISTEN = 3
JOIN_REQUEST = 5
uint8 = ctypes.c_uint8
# use ctypes to interpret the bits in activity flags
class ActivityFlagsBits(ctypes.LittleEndianStructure):
_fields_ = [
('instance', uint8, 1),
('join', uint8, 1),
('spectate', uint8, 1),
('join_request', uint8, 1),
('sync', uint8, 1),
('play', uint8, 1),
]
class ActivityFlags(ctypes.Union):
_anonymous_ = ('bit',)
_fields_ = [
('bit', ActivityFlagsBits),
('as_byte', uint8),
]
class StatusType(EasyEnum):
ONLINE = 'online'
DND = 'dnd'
IDLE = 'idle'
INVISIBLE = 'invisible'
OFFLINE = 'offline'

View File

@ -4,11 +4,13 @@ import os
def gen_session_id() -> str:
"""Generate a random session ID."""
return hashlib.sha1(os.urandom(256)).hexdigest()
return hashlib.sha1(os.urandom(128)).hexdigest()
class PayloadStore:
"""Store manager for payloads."""
MAX_STORE_SIZE = 250
def __init__(self):
self.store = {}
@ -16,9 +18,26 @@ class PayloadStore:
return self.store[opcode]
def __setitem__(self, opcode: int, payload: dict):
if len(self.store) > 250:
# if more than 250, remove old keys until we get 250
opcodes = sorted(list(self.store.keys()))
to_remove = len(opcodes) - self.MAX_STORE_SIZE
for idx in range(to_remove):
opcode = opcodes[idx]
self.store.pop(opcode)
self.store[opcode] = payload
class Presence:
def __init__(self, raw: dict):
self.afk = raw.get('afk', False)
self.status = raw.get('status', 'online')
self.game = raw.get('game', None)
self.since = raw.get('since', 0)
class GatewayState:
"""Main websocket state.
@ -32,6 +51,7 @@ class GatewayState:
self.shard = kwargs.get('shard', [0, 1])
self.user_id = kwargs.get('user_id')
self.bot = kwargs.get('bot', False)
self.presence = {}
self.store = PayloadStore()
for key in kwargs:

View File

@ -15,6 +15,8 @@ from .errors import DecodeError, UnknownOPCode, \
from .opcodes import OP
from .state import GatewayState
from ..schemas import validate, GW_STATUS_UPDATE
log = Logger(__name__)
WebsocketProperties = collections.namedtuple(
@ -76,7 +78,7 @@ class GatewayWebsocket:
This function accounts for the zlib-stream
transport method used by Discord.
"""
log.debug('Sending {}', pprint.pformat(payload))
log.debug('sending {}', pprint.pformat(payload))
encoded = self.encoder(payload)
if not isinstance(encoded, bytes):
@ -162,15 +164,29 @@ class GatewayWebsocket:
"""
return {
# TODO
'relationships': [],
# TODO
'user_guild_settings': [],
# TODO
'notes': {},
'friend_suggestion_count': 0,
# TODO
'presences': [],
# TODO
'read_state': [],
'experiments': [],
'guild_experiments': [],
# TODO
'connected_accounts': [],
# TODO: make those changeable
'user_settings': {
'afk_timeout': 300,
'animate_emoji': True,
@ -198,6 +214,7 @@ class GatewayWebsocket:
'theme': 'dark',
'timezone_offset': 420,
},
'analytics_token': 'transbian',
}
@ -246,19 +263,49 @@ class GatewayWebsocket:
if current_shard > shard_count:
raise InvalidShard('Shard count > Total shards')
async def subscribe_guilds(self):
"""Subscribe to all available guilds"""
async def _guild_ids(self):
# TODO: account for sharding
guild_ids = await self.ext.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", self.state.user_id)
guild_ids = [r['guild_id'] for r in guild_ids]
return [r['guild_id'] for r in guild_ids]
async def subscribe_guilds(self):
"""Subscribe to all available guilds"""
guild_ids = await self._guild_ids()
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
async def update_status(self, status: dict):
if status is None:
status = {
'afk': False,
# TODO: fetch status from settings
'status': 'online',
'game': None,
# TODO: this
'since': 0,
}
self.state.presence = status
status = validate(status, GW_STATUS_UPDATE)
if not status:
# invalid status, must ignore
return
self.state.presence = status
await self.ext.presence.dispatch_pres(self.state.user_id,
self.state.presence)
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
# TODO: handling heartbeats
pass
async def handle_2(self, payload: Dict[str, Any]):
@ -294,7 +341,6 @@ class GatewayWebsocket:
shard=shard,
current_shard=shard[0],
shard_count=shard[1],
presence=presence,
ws=self
)
@ -304,6 +350,9 @@ class GatewayWebsocket:
await self.dispatch_ready()
await self.subscribe_guilds()
# dispatch presence only after subscribing
await self.update_status(presence)
async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update."""
pass
@ -426,8 +475,8 @@ class GatewayWebsocket:
payload = self.decoder(message)
pretty_printed = pprint.pformat(payload)
log.debug('received message: {}', pretty_printed)
log.debug('received message: {}',
pprint.pformat(payload))
await self.process_message(payload)
@ -438,9 +487,9 @@ class GatewayWebsocket:
await self.send_hello()
await self.listen_messages()
except websockets.exceptions.ConnectionClosed as err:
log.warning('Client closed, state={}, err={}', self.state, err)
log.warning('conn close, state={}, err={}', self.state, err)
except WebsocketClose as err:
log.warning('closed a client, state={} err={}', self.state, err)
log.warning('ws close, state={} err={}', self.state, err)
await self.ws.close(code=err.code, reason=err.reason)
except Exception as err:

View File

@ -3,12 +3,14 @@ from typing import List, Dict, Any
class PresenceManager:
"""Presence related functions."""
def __init__(self, storage, state_manager):
def __init__(self, storage, state_manager, dispatcher):
self.storage = storage
self.state_manager = state_manager
self.dispatcher = dispatcher
async def guild_presences(self, member_ids: List[int],
guild_id: int) -> List[Dict[Any, str]]:
"""Fetch all presences in a guild."""
states = self.state_manager.guild_states(member_ids, guild_id)
presences = []
@ -20,9 +22,38 @@ class PresenceManager:
presences.append({
'user': member['user'],
'roles': member['roles'],
'game': state.presence['game'],
'guild_id': guild_id,
'status': state.presence['status'],
'game': state.presence.get('game', None),
'status': state.presence.get('status', None),
})
return presences
async def dispatch_guild_pres(self, guild_id: int,
user_id: int, new_state: dict):
"""Dispatch a Presence update to an entire guild."""
state = dict(new_state)
if state['status'] == 'invisible':
state['status'] = 'offline'
member = await self.storage.get_member_data_one(guild_id, user_id)
await self.dispatcher.dispatch_guild(
guild_id, 'PRESENCE_UPDATE', {
'user': member['user'],
'roles': member['roles'],
'guild_id': guild_id,
'game': state['game'],
'status': state['status'],
}
)
async def dispatch_pres(self, user_id: int, state):
"""Dispatch a new presence to all guilds the user is in."""
# TODO: account for sharding
guild_ids = await self.storage.get_user_guilds(user_id)
for guild_id in guild_ids:
await self.dispatch_guild_pres(guild_id, user_id, state)

View File

@ -3,6 +3,7 @@ import re
from cerberus import Validator
from .errors import BadRequest
from .enums import ActivityType, StatusType
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
@ -22,17 +23,38 @@ class LitecordValidator(Validator):
"""Validate against the username regex."""
return bool(USERNAME_REGEX.match(value))
def _validate_type_snowflake(self, value: str) -> bool:
try:
int(value)
return True
except ValueError:
return False
def _validate_type_voice_region(self, value: str) -> bool:
# TODO: complete this list
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
def _validate_type_activity_type(self, value: int) -> bool:
return value in ActivityType.values()
def validate(reqjson, schema):
def _validate_type_status_external(self, value: str) -> bool:
statuses = StatusType.values()
# clients should send INVISIBLE instead of OFFLINE
statuses.remove(StatusType.OFFLINE.value)
return value in statuses
def validate(reqjson, schema, raise_err: bool = False):
validator = LitecordValidator(schema)
if not validator.validate(reqjson):
errs = validator.errors
raise BadRequest('bad payload', errs)
if raise_err:
raise BadRequest('bad payload', errs)
return None
return reqjson
@ -75,6 +97,7 @@ MEMBER_UPDATE = {
'channel_id': {'type': 'snowflake', 'required': False},
}
MESSAGE_CREATE = {
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
'nonce': {'type': 'string', 'required': False},
@ -82,3 +105,71 @@ MESSAGE_CREATE = {
# TODO: file, embed, payload_json
}
GW_ACTIVITY = {
'name': {'type': 'string', 'required': True},
'type': {'type': 'activity_type', 'required': True},
'url': {'type': 'string', 'required': False, 'nullable': True},
'timestamps': {
'type': 'dict',
'required': False,
'schema': {
'start': {'type': 'number', 'required': True},
'end': {'type': 'number', 'required': True},
},
},
'application_id': {'type': 'snowflake', 'required': False,
'nullable': False},
'details': {'type': 'string', 'required': False, 'nullable': True},
'state': {'type': 'string', 'required': False, 'nullable': True},
'party': {
'type': 'dict',
'required': False,
'schema': {
'id': {'type': 'snowflake', 'required': False},
'size': {'type': 'list', 'required': False},
}
},
'assets': {
'type': 'dict',
'required': False,
'schema': {
'large_image': {'type': 'snowflake', 'required': False},
'large_text': {'type': 'string', 'required': False},
'small_image': {'type': 'snowflake', 'required': False},
'small_text': {'type': 'string', 'required': False},
}
},
'secrets': {
'type': 'dict',
'required': False,
'schema': {
'join': {'type': 'string', 'required': False},
'spectate': {'type': 'string', 'required': False},
'match': {'type': 'string', 'required': False},
}
},
'instance': {'type': 'boolean', 'required': False},
'flags': {'type': 'number', 'required': False},
}
GW_STATUS_UPDATE = {
'status': {'type': 'status_external', 'required': False},
'afk': {'type': 'boolean', 'required': False},
'since': {'type': 'number', 'required': True, 'nullable': True},
'game': {
'type': 'dict',
'required': True,
'nullable': True,
'schema': GW_ACTIVITY,
},
}

View File

@ -1,9 +1,14 @@
from typing import List, Dict, Any
from logbook import Logger
from .enums import ChannelType
from .schemas import USER_MENTION, ROLE_MENTION
log = Logger(__name__)
async def _dummy(any_id):
return str(any_id)
@ -157,12 +162,13 @@ class Storage:
return members
async def _channels_extra(self, row, channel_type: int) -> Dict:
async def _channels_extra(self, row) -> Dict:
"""Fill in more information about a channel."""
# TODO: This could probably be better with a dictionary.
channel_type = row['type']
# TODO: dm and group dm?
if channel_type == ChannelType.GUILD_TEXT:
chan_type = ChannelType(channel_type)
if chan_type == ChannelType.GUILD_TEXT:
topic = await self.db.fetchval("""
SELECT topic FROM guild_text_channels
WHERE id = $1
@ -171,7 +177,7 @@ class Storage:
return {**row, **{
'topic': topic,
}}
elif channel_type == ChannelType.GUILD_VOICE:
elif chan_type == ChannelType.GUILD_VOICE:
vrow = await self.db.fetchval("""
SELECT bitrate, user_limit FROM guild_voice_channels
WHERE id = $1
@ -179,6 +185,8 @@ class Storage:
return {**row, **dict(vrow)}
log.warning('unknown channel type: {}', chan_type)
async def get_chan_type(self, channel_id) -> int:
return await self.db.fetchval("""
SELECT channel_type
@ -205,16 +213,19 @@ class Storage:
"""Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id)
if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY):
if ChannelType(chan_type) in (ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY):
base = await self.db.fetchrow("""
SELECT id, guild_id::text, parent_id, name, position, nsfw
FROM guild_channels
WHERE guild_channels.id = $1
""", channel_id)
res = await self._channels_extra(dict(base), chan_type)
res['type'] = chan_type
dbase = dict(base)
dbase['type'] = chan_type
res = await self._channels_extra(dbase)
res['permission_overwrites'] = \
list(await self._chan_overwrites(channel_id))
@ -240,8 +251,12 @@ class Storage:
WHERE id = $1
""", row['id'])
res = await self._channels_extra(dict(row), ctype)
res['type'] = ctype
drow = dict(row)
drow['type'] = ctype
res = await self._channels_extra(drow)
print(res)
res['permission_overwrites'] = \
list(await self._chan_overwrites(row['id']))

3
run.py
View File

@ -88,7 +88,8 @@ async def app_before_serving():
app.state_manager = StateManager()
app.dispatcher = EventDispatcher(app.state_manager)
app.storage = Storage(app.db)
app.presence = PresenceManager(app.storage, app.state_manager)
app.presence = PresenceManager(app.storage,
app.state_manager, app.dispatcher)
# start the websocket, etc
host, port = app.config['WS_HOST'], app.config['WS_PORT']