mirror of https://gitlab.com/litecord/litecord.git
Rough presence management
Lots of changes to get it working.
One day changes will be able to be small enough to be split across
commits.
- enums: use enum.Enum, make EasyEnum subclass
- enums: add ActivityType, ActivityFlags, StatusType
- gateway.state: use 128 random bits instead of 256
- gateway.state: add MAX_STORE_SIZE in PayloadStore and check it when
adding a new payload
- gateway.websocket: add GatewayWebsocket.update_status
- presence: add PresenceManager.dispatch_guild_pres and
PresenceManager.dispatch_pres
- schema: add snowflake, activity_type, status_external types
- schema: add GW_ACTIVITY, GW_STATUS_UPDATE
- storage: fix _channels_extra and fixes for ChannelType as enum instead of
class
This commit is contained in:
parent
d39783e666
commit
cd5dbc4886
|
|
@ -1,5 +1,15 @@
|
|||
import ctypes
|
||||
|
||||
class ChannelType:
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EasyEnum(Enum):
|
||||
@classmethod
|
||||
def values(cls):
|
||||
return [v.value for v in cls.__members__.values()]
|
||||
|
||||
|
||||
class ChannelType(EasyEnum):
|
||||
GUILD_TEXT = 0
|
||||
DM = 1
|
||||
GUILD_VOICE = 2
|
||||
|
|
@ -7,7 +17,13 @@ class ChannelType:
|
|||
GUILD_CATEGORY = 4
|
||||
|
||||
|
||||
class MessageType:
|
||||
class ActivityType(EasyEnum):
|
||||
PLAYING = 0
|
||||
STREAMING = 1
|
||||
LISTENING = 2
|
||||
|
||||
|
||||
class MessageType(EasyEnum):
|
||||
DEFAULT = 0
|
||||
RECIPIENT_ADD = 1
|
||||
RECIPIENT_REMOVE = 2
|
||||
|
|
@ -18,8 +34,40 @@ class MessageType:
|
|||
GUILD_MEMBER_JOIN = 7
|
||||
|
||||
|
||||
class MessageActivityType:
|
||||
class MessageActivityType(EasyEnum):
|
||||
JOIN = 1
|
||||
SPECTATE = 2
|
||||
LISTEN = 3
|
||||
JOIN_REQUEST = 5
|
||||
|
||||
|
||||
uint8 = ctypes.c_uint8
|
||||
|
||||
|
||||
# use ctypes to interpret the bits in activity flags
|
||||
class ActivityFlagsBits(ctypes.LittleEndianStructure):
|
||||
_fields_ = [
|
||||
('instance', uint8, 1),
|
||||
('join', uint8, 1),
|
||||
('spectate', uint8, 1),
|
||||
('join_request', uint8, 1),
|
||||
('sync', uint8, 1),
|
||||
('play', uint8, 1),
|
||||
]
|
||||
|
||||
|
||||
class ActivityFlags(ctypes.Union):
|
||||
_anonymous_ = ('bit',)
|
||||
|
||||
_fields_ = [
|
||||
('bit', ActivityFlagsBits),
|
||||
('as_byte', uint8),
|
||||
]
|
||||
|
||||
|
||||
class StatusType(EasyEnum):
|
||||
ONLINE = 'online'
|
||||
DND = 'dnd'
|
||||
IDLE = 'idle'
|
||||
INVISIBLE = 'invisible'
|
||||
OFFLINE = 'offline'
|
||||
|
|
|
|||
|
|
@ -4,11 +4,13 @@ import os
|
|||
|
||||
def gen_session_id() -> str:
|
||||
"""Generate a random session ID."""
|
||||
return hashlib.sha1(os.urandom(256)).hexdigest()
|
||||
return hashlib.sha1(os.urandom(128)).hexdigest()
|
||||
|
||||
|
||||
class PayloadStore:
|
||||
"""Store manager for payloads."""
|
||||
MAX_STORE_SIZE = 250
|
||||
|
||||
def __init__(self):
|
||||
self.store = {}
|
||||
|
||||
|
|
@ -16,9 +18,26 @@ class PayloadStore:
|
|||
return self.store[opcode]
|
||||
|
||||
def __setitem__(self, opcode: int, payload: dict):
|
||||
if len(self.store) > 250:
|
||||
# if more than 250, remove old keys until we get 250
|
||||
opcodes = sorted(list(self.store.keys()))
|
||||
to_remove = len(opcodes) - self.MAX_STORE_SIZE
|
||||
|
||||
for idx in range(to_remove):
|
||||
opcode = opcodes[idx]
|
||||
self.store.pop(opcode)
|
||||
|
||||
self.store[opcode] = payload
|
||||
|
||||
|
||||
class Presence:
|
||||
def __init__(self, raw: dict):
|
||||
self.afk = raw.get('afk', False)
|
||||
self.status = raw.get('status', 'online')
|
||||
self.game = raw.get('game', None)
|
||||
self.since = raw.get('since', 0)
|
||||
|
||||
|
||||
class GatewayState:
|
||||
"""Main websocket state.
|
||||
|
||||
|
|
@ -32,6 +51,7 @@ class GatewayState:
|
|||
self.shard = kwargs.get('shard', [0, 1])
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.bot = kwargs.get('bot', False)
|
||||
self.presence = {}
|
||||
self.store = PayloadStore()
|
||||
|
||||
for key in kwargs:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from .errors import DecodeError, UnknownOPCode, \
|
|||
from .opcodes import OP
|
||||
from .state import GatewayState
|
||||
|
||||
from ..schemas import validate, GW_STATUS_UPDATE
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
WebsocketProperties = collections.namedtuple(
|
||||
|
|
@ -76,7 +78,7 @@ class GatewayWebsocket:
|
|||
This function accounts for the zlib-stream
|
||||
transport method used by Discord.
|
||||
"""
|
||||
log.debug('Sending {}', pprint.pformat(payload))
|
||||
log.debug('sending {}', pprint.pformat(payload))
|
||||
encoded = self.encoder(payload)
|
||||
|
||||
if not isinstance(encoded, bytes):
|
||||
|
|
@ -162,15 +164,29 @@ class GatewayWebsocket:
|
|||
"""
|
||||
|
||||
return {
|
||||
# TODO
|
||||
'relationships': [],
|
||||
|
||||
# TODO
|
||||
'user_guild_settings': [],
|
||||
|
||||
# TODO
|
||||
'notes': {},
|
||||
'friend_suggestion_count': 0,
|
||||
|
||||
# TODO
|
||||
'presences': [],
|
||||
|
||||
# TODO
|
||||
'read_state': [],
|
||||
|
||||
'experiments': [],
|
||||
'guild_experiments': [],
|
||||
|
||||
# TODO
|
||||
'connected_accounts': [],
|
||||
|
||||
# TODO: make those changeable
|
||||
'user_settings': {
|
||||
'afk_timeout': 300,
|
||||
'animate_emoji': True,
|
||||
|
|
@ -198,6 +214,7 @@ class GatewayWebsocket:
|
|||
'theme': 'dark',
|
||||
'timezone_offset': 420,
|
||||
},
|
||||
|
||||
'analytics_token': 'transbian',
|
||||
}
|
||||
|
||||
|
|
@ -246,19 +263,49 @@ class GatewayWebsocket:
|
|||
if current_shard > shard_count:
|
||||
raise InvalidShard('Shard count > Total shards')
|
||||
|
||||
async def subscribe_guilds(self):
|
||||
"""Subscribe to all available guilds"""
|
||||
async def _guild_ids(self):
|
||||
# TODO: account for sharding
|
||||
guild_ids = await self.ext.db.fetch("""
|
||||
SELECT guild_id
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", self.state.user_id)
|
||||
|
||||
guild_ids = [r['guild_id'] for r in guild_ids]
|
||||
return [r['guild_id'] for r in guild_ids]
|
||||
|
||||
async def subscribe_guilds(self):
|
||||
"""Subscribe to all available guilds"""
|
||||
guild_ids = await self._guild_ids()
|
||||
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
|
||||
|
||||
async def update_status(self, status: dict):
|
||||
if status is None:
|
||||
status = {
|
||||
'afk': False,
|
||||
|
||||
# TODO: fetch status from settings
|
||||
'status': 'online',
|
||||
'game': None,
|
||||
|
||||
# TODO: this
|
||||
'since': 0,
|
||||
}
|
||||
|
||||
self.state.presence = status
|
||||
|
||||
status = validate(status, GW_STATUS_UPDATE)
|
||||
|
||||
if not status:
|
||||
# invalid status, must ignore
|
||||
return
|
||||
|
||||
self.state.presence = status
|
||||
await self.ext.presence.dispatch_pres(self.state.user_id,
|
||||
self.state.presence)
|
||||
|
||||
async def handle_1(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 1 Heartbeat packets."""
|
||||
# TODO: handling heartbeats
|
||||
pass
|
||||
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
|
|
@ -294,7 +341,6 @@ class GatewayWebsocket:
|
|||
shard=shard,
|
||||
current_shard=shard[0],
|
||||
shard_count=shard[1],
|
||||
presence=presence,
|
||||
ws=self
|
||||
)
|
||||
|
||||
|
|
@ -304,6 +350,9 @@ class GatewayWebsocket:
|
|||
await self.dispatch_ready()
|
||||
await self.subscribe_guilds()
|
||||
|
||||
# dispatch presence only after subscribing
|
||||
await self.update_status(presence)
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 3 Status Update."""
|
||||
pass
|
||||
|
|
@ -426,8 +475,8 @@ class GatewayWebsocket:
|
|||
|
||||
payload = self.decoder(message)
|
||||
|
||||
pretty_printed = pprint.pformat(payload)
|
||||
log.debug('received message: {}', pretty_printed)
|
||||
log.debug('received message: {}',
|
||||
pprint.pformat(payload))
|
||||
|
||||
await self.process_message(payload)
|
||||
|
||||
|
|
@ -438,9 +487,9 @@ class GatewayWebsocket:
|
|||
await self.send_hello()
|
||||
await self.listen_messages()
|
||||
except websockets.exceptions.ConnectionClosed as err:
|
||||
log.warning('Client closed, state={}, err={}', self.state, err)
|
||||
log.warning('conn close, state={}, err={}', self.state, err)
|
||||
except WebsocketClose as err:
|
||||
log.warning('closed a client, state={} err={}', self.state, err)
|
||||
log.warning('ws close, state={} err={}', self.state, err)
|
||||
|
||||
await self.ws.close(code=err.code, reason=err.reason)
|
||||
except Exception as err:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,14 @@ from typing import List, Dict, Any
|
|||
|
||||
class PresenceManager:
|
||||
"""Presence related functions."""
|
||||
def __init__(self, storage, state_manager):
|
||||
def __init__(self, storage, state_manager, dispatcher):
|
||||
self.storage = storage
|
||||
self.state_manager = state_manager
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
async def guild_presences(self, member_ids: List[int],
|
||||
guild_id: int) -> List[Dict[Any, str]]:
|
||||
"""Fetch all presences in a guild."""
|
||||
states = self.state_manager.guild_states(member_ids, guild_id)
|
||||
|
||||
presences = []
|
||||
|
|
@ -20,9 +22,38 @@ class PresenceManager:
|
|||
presences.append({
|
||||
'user': member['user'],
|
||||
'roles': member['roles'],
|
||||
'game': state.presence['game'],
|
||||
'guild_id': guild_id,
|
||||
'status': state.presence['status'],
|
||||
'game': state.presence.get('game', None),
|
||||
'status': state.presence.get('status', None),
|
||||
})
|
||||
|
||||
return presences
|
||||
|
||||
async def dispatch_guild_pres(self, guild_id: int,
|
||||
user_id: int, new_state: dict):
|
||||
"""Dispatch a Presence update to an entire guild."""
|
||||
state = dict(new_state)
|
||||
|
||||
if state['status'] == 'invisible':
|
||||
state['status'] = 'offline'
|
||||
|
||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||
|
||||
await self.dispatcher.dispatch_guild(
|
||||
guild_id, 'PRESENCE_UPDATE', {
|
||||
'user': member['user'],
|
||||
'roles': member['roles'],
|
||||
'guild_id': guild_id,
|
||||
|
||||
'game': state['game'],
|
||||
'status': state['status'],
|
||||
}
|
||||
)
|
||||
|
||||
async def dispatch_pres(self, user_id: int, state):
|
||||
"""Dispatch a new presence to all guilds the user is in."""
|
||||
# TODO: account for sharding
|
||||
guild_ids = await self.storage.get_user_guilds(user_id)
|
||||
|
||||
for guild_id in guild_ids:
|
||||
await self.dispatch_guild_pres(guild_id, user_id, state)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import re
|
|||
from cerberus import Validator
|
||||
|
||||
from .errors import BadRequest
|
||||
from .enums import ActivityType, StatusType
|
||||
|
||||
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
|
||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
||||
|
|
@ -22,18 +23,39 @@ class LitecordValidator(Validator):
|
|||
"""Validate against the username regex."""
|
||||
return bool(USERNAME_REGEX.match(value))
|
||||
|
||||
def _validate_type_snowflake(self, value: str) -> bool:
|
||||
try:
|
||||
int(value)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _validate_type_voice_region(self, value: str) -> bool:
|
||||
# TODO: complete this list
|
||||
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
||||
|
||||
def _validate_type_activity_type(self, value: int) -> bool:
|
||||
return value in ActivityType.values()
|
||||
|
||||
def validate(reqjson, schema):
|
||||
def _validate_type_status_external(self, value: str) -> bool:
|
||||
statuses = StatusType.values()
|
||||
|
||||
# clients should send INVISIBLE instead of OFFLINE
|
||||
statuses.remove(StatusType.OFFLINE.value)
|
||||
|
||||
return value in statuses
|
||||
|
||||
|
||||
def validate(reqjson, schema, raise_err: bool = False):
|
||||
validator = LitecordValidator(schema)
|
||||
if not validator.validate(reqjson):
|
||||
errs = validator.errors
|
||||
|
||||
if raise_err:
|
||||
raise BadRequest('bad payload', errs)
|
||||
|
||||
return None
|
||||
|
||||
return reqjson
|
||||
|
||||
|
||||
|
|
@ -75,6 +97,7 @@ MEMBER_UPDATE = {
|
|||
'channel_id': {'type': 'snowflake', 'required': False},
|
||||
}
|
||||
|
||||
|
||||
MESSAGE_CREATE = {
|
||||
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
|
||||
'nonce': {'type': 'string', 'required': False},
|
||||
|
|
@ -82,3 +105,71 @@ MESSAGE_CREATE = {
|
|||
|
||||
# TODO: file, embed, payload_json
|
||||
}
|
||||
|
||||
|
||||
GW_ACTIVITY = {
|
||||
'name': {'type': 'string', 'required': True},
|
||||
'type': {'type': 'activity_type', 'required': True},
|
||||
|
||||
'url': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'timestamps': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'start': {'type': 'number', 'required': True},
|
||||
'end': {'type': 'number', 'required': True},
|
||||
},
|
||||
},
|
||||
|
||||
'application_id': {'type': 'snowflake', 'required': False,
|
||||
'nullable': False},
|
||||
'details': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'state': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'party': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'id': {'type': 'snowflake', 'required': False},
|
||||
'size': {'type': 'list', 'required': False},
|
||||
}
|
||||
},
|
||||
|
||||
'assets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'large_image': {'type': 'snowflake', 'required': False},
|
||||
'large_text': {'type': 'string', 'required': False},
|
||||
'small_image': {'type': 'snowflake', 'required': False},
|
||||
'small_text': {'type': 'string', 'required': False},
|
||||
}
|
||||
},
|
||||
|
||||
'secrets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'join': {'type': 'string', 'required': False},
|
||||
'spectate': {'type': 'string', 'required': False},
|
||||
'match': {'type': 'string', 'required': False},
|
||||
}
|
||||
},
|
||||
|
||||
'instance': {'type': 'boolean', 'required': False},
|
||||
'flags': {'type': 'number', 'required': False},
|
||||
}
|
||||
|
||||
GW_STATUS_UPDATE = {
|
||||
'status': {'type': 'status_external', 'required': False},
|
||||
'afk': {'type': 'boolean', 'required': False},
|
||||
|
||||
'since': {'type': 'number', 'required': True, 'nullable': True},
|
||||
'game': {
|
||||
'type': 'dict',
|
||||
'required': True,
|
||||
'nullable': True,
|
||||
'schema': GW_ACTIVITY,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
from typing import List, Dict, Any
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from .enums import ChannelType
|
||||
from .schemas import USER_MENTION, ROLE_MENTION
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def _dummy(any_id):
|
||||
return str(any_id)
|
||||
|
||||
|
|
@ -157,12 +162,13 @@ class Storage:
|
|||
|
||||
return members
|
||||
|
||||
async def _channels_extra(self, row, channel_type: int) -> Dict:
|
||||
async def _channels_extra(self, row) -> Dict:
|
||||
"""Fill in more information about a channel."""
|
||||
# TODO: This could probably be better with a dictionary.
|
||||
channel_type = row['type']
|
||||
|
||||
# TODO: dm and group dm?
|
||||
if channel_type == ChannelType.GUILD_TEXT:
|
||||
chan_type = ChannelType(channel_type)
|
||||
if chan_type == ChannelType.GUILD_TEXT:
|
||||
topic = await self.db.fetchval("""
|
||||
SELECT topic FROM guild_text_channels
|
||||
WHERE id = $1
|
||||
|
|
@ -171,7 +177,7 @@ class Storage:
|
|||
return {**row, **{
|
||||
'topic': topic,
|
||||
}}
|
||||
elif channel_type == ChannelType.GUILD_VOICE:
|
||||
elif chan_type == ChannelType.GUILD_VOICE:
|
||||
vrow = await self.db.fetchval("""
|
||||
SELECT bitrate, user_limit FROM guild_voice_channels
|
||||
WHERE id = $1
|
||||
|
|
@ -179,6 +185,8 @@ class Storage:
|
|||
|
||||
return {**row, **dict(vrow)}
|
||||
|
||||
log.warning('unknown channel type: {}', chan_type)
|
||||
|
||||
async def get_chan_type(self, channel_id) -> int:
|
||||
return await self.db.fetchval("""
|
||||
SELECT channel_type
|
||||
|
|
@ -205,7 +213,8 @@ class Storage:
|
|||
"""Fetch a single channel's information."""
|
||||
chan_type = await self.get_chan_type(channel_id)
|
||||
|
||||
if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
||||
if ChannelType(chan_type) in (ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE,
|
||||
ChannelType.GUILD_CATEGORY):
|
||||
base = await self.db.fetchrow("""
|
||||
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
||||
|
|
@ -213,8 +222,10 @@ class Storage:
|
|||
WHERE guild_channels.id = $1
|
||||
""", channel_id)
|
||||
|
||||
res = await self._channels_extra(dict(base), chan_type)
|
||||
res['type'] = chan_type
|
||||
dbase = dict(base)
|
||||
dbase['type'] = chan_type
|
||||
|
||||
res = await self._channels_extra(dbase)
|
||||
res['permission_overwrites'] = \
|
||||
list(await self._chan_overwrites(channel_id))
|
||||
|
||||
|
|
@ -240,8 +251,12 @@ class Storage:
|
|||
WHERE id = $1
|
||||
""", row['id'])
|
||||
|
||||
res = await self._channels_extra(dict(row), ctype)
|
||||
res['type'] = ctype
|
||||
drow = dict(row)
|
||||
drow['type'] = ctype
|
||||
|
||||
res = await self._channels_extra(drow)
|
||||
|
||||
print(res)
|
||||
|
||||
res['permission_overwrites'] = \
|
||||
list(await self._chan_overwrites(row['id']))
|
||||
|
|
|
|||
3
run.py
3
run.py
|
|
@ -88,7 +88,8 @@ async def app_before_serving():
|
|||
app.state_manager = StateManager()
|
||||
app.dispatcher = EventDispatcher(app.state_manager)
|
||||
app.storage = Storage(app.db)
|
||||
app.presence = PresenceManager(app.storage, app.state_manager)
|
||||
app.presence = PresenceManager(app.storage,
|
||||
app.state_manager, app.dispatcher)
|
||||
|
||||
# start the websocket, etc
|
||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||
|
|
|
|||
Loading…
Reference in New Issue