diff --git a/litecord/enums.py b/litecord/enums.py index 99c376d..2030a2c 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -1,5 +1,15 @@ +import ctypes -class ChannelType: +from enum import Enum + + +class EasyEnum(Enum): + @classmethod + def values(cls): + return [v.value for v in cls.__members__.values()] + + +class ChannelType(EasyEnum): GUILD_TEXT = 0 DM = 1 GUILD_VOICE = 2 @@ -7,7 +17,13 @@ class ChannelType: GUILD_CATEGORY = 4 -class MessageType: +class ActivityType(EasyEnum): + PLAYING = 0 + STREAMING = 1 + LISTENING = 2 + + +class MessageType(EasyEnum): DEFAULT = 0 RECIPIENT_ADD = 1 RECIPIENT_REMOVE = 2 @@ -18,8 +34,40 @@ class MessageType: GUILD_MEMBER_JOIN = 7 -class MessageActivityType: +class MessageActivityType(EasyEnum): JOIN = 1 SPECTATE = 2 LISTEN = 3 JOIN_REQUEST = 5 + + +uint8 = ctypes.c_uint8 + + +# use ctypes to interpret the bits in activity flags +class ActivityFlagsBits(ctypes.LittleEndianStructure): + _fields_ = [ + ('instance', uint8, 1), + ('join', uint8, 1), + ('spectate', uint8, 1), + ('join_request', uint8, 1), + ('sync', uint8, 1), + ('play', uint8, 1), + ] + + +class ActivityFlags(ctypes.Union): + _anonymous_ = ('bit',) + + _fields_ = [ + ('bit', ActivityFlagsBits), + ('as_byte', uint8), + ] + + +class StatusType(EasyEnum): + ONLINE = 'online' + DND = 'dnd' + IDLE = 'idle' + INVISIBLE = 'invisible' + OFFLINE = 'offline' diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index ae58bd1..eccf4a5 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -4,11 +4,13 @@ import os def gen_session_id() -> str: """Generate a random session ID.""" - return hashlib.sha1(os.urandom(256)).hexdigest() + return hashlib.sha1(os.urandom(128)).hexdigest() class PayloadStore: """Store manager for payloads.""" + MAX_STORE_SIZE = 250 + def __init__(self): self.store = {} @@ -16,9 +18,26 @@ class PayloadStore: return self.store[opcode] def __setitem__(self, opcode: int, payload: dict): + if len(self.store) > 250: + # if more than 250, remove old keys until we get 250 + opcodes = sorted(list(self.store.keys())) + to_remove = len(opcodes) - self.MAX_STORE_SIZE + + for idx in range(to_remove): + opcode = opcodes[idx] + self.store.pop(opcode) + self.store[opcode] = payload +class Presence: + def __init__(self, raw: dict): + self.afk = raw.get('afk', False) + self.status = raw.get('status', 'online') + self.game = raw.get('game', None) + self.since = raw.get('since', 0) + + class GatewayState: """Main websocket state. @@ -32,6 +51,7 @@ class GatewayState: self.shard = kwargs.get('shard', [0, 1]) self.user_id = kwargs.get('user_id') self.bot = kwargs.get('bot', False) + self.presence = {} self.store = PayloadStore() for key in kwargs: diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 6571334..bb993ff 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -15,6 +15,8 @@ from .errors import DecodeError, UnknownOPCode, \ from .opcodes import OP from .state import GatewayState +from ..schemas import validate, GW_STATUS_UPDATE + log = Logger(__name__) WebsocketProperties = collections.namedtuple( @@ -76,7 +78,7 @@ class GatewayWebsocket: This function accounts for the zlib-stream transport method used by Discord. """ - log.debug('Sending {}', pprint.pformat(payload)) + log.debug('sending {}', pprint.pformat(payload)) encoded = self.encoder(payload) if not isinstance(encoded, bytes): @@ -162,15 +164,29 @@ class GatewayWebsocket: """ return { + # TODO 'relationships': [], + + # TODO 'user_guild_settings': [], + + # TODO 'notes': {}, 'friend_suggestion_count': 0, + + # TODO 'presences': [], + + # TODO 'read_state': [], + 'experiments': [], 'guild_experiments': [], + + # TODO 'connected_accounts': [], + + # TODO: make those changeable 'user_settings': { 'afk_timeout': 300, 'animate_emoji': True, @@ -198,6 +214,7 @@ class GatewayWebsocket: 'theme': 'dark', 'timezone_offset': 420, }, + 'analytics_token': 'transbian', } @@ -246,19 +263,49 @@ class GatewayWebsocket: if current_shard > shard_count: raise InvalidShard('Shard count > Total shards') - async def subscribe_guilds(self): - """Subscribe to all available guilds""" + async def _guild_ids(self): + # TODO: account for sharding guild_ids = await self.ext.db.fetch(""" SELECT guild_id FROM members WHERE user_id = $1 """, self.state.user_id) - guild_ids = [r['guild_id'] for r in guild_ids] + return [r['guild_id'] for r in guild_ids] + + async def subscribe_guilds(self): + """Subscribe to all available guilds""" + guild_ids = await self._guild_ids() self.ext.dispatcher.sub_many(self.state.user_id, guild_ids) + async def update_status(self, status: dict): + if status is None: + status = { + 'afk': False, + + # TODO: fetch status from settings + 'status': 'online', + 'game': None, + + # TODO: this + 'since': 0, + } + + self.state.presence = status + + status = validate(status, GW_STATUS_UPDATE) + + if not status: + # invalid status, must ignore + return + + self.state.presence = status + await self.ext.presence.dispatch_pres(self.state.user_id, + self.state.presence) + async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" + # TODO: handling heartbeats pass async def handle_2(self, payload: Dict[str, Any]): @@ -294,7 +341,6 @@ class GatewayWebsocket: shard=shard, current_shard=shard[0], shard_count=shard[1], - presence=presence, ws=self ) @@ -304,6 +350,9 @@ class GatewayWebsocket: await self.dispatch_ready() await self.subscribe_guilds() + # dispatch presence only after subscribing + await self.update_status(presence) + async def handle_3(self, payload: Dict[str, Any]): """Handle OP 3 Status Update.""" pass @@ -426,8 +475,8 @@ class GatewayWebsocket: payload = self.decoder(message) - pretty_printed = pprint.pformat(payload) - log.debug('received message: {}', pretty_printed) + log.debug('received message: {}', + pprint.pformat(payload)) await self.process_message(payload) @@ -438,9 +487,9 @@ class GatewayWebsocket: await self.send_hello() await self.listen_messages() except websockets.exceptions.ConnectionClosed as err: - log.warning('Client closed, state={}, err={}', self.state, err) + log.warning('conn close, state={}, err={}', self.state, err) except WebsocketClose as err: - log.warning('closed a client, state={} err={}', self.state, err) + log.warning('ws close, state={} err={}', self.state, err) await self.ws.close(code=err.code, reason=err.reason) except Exception as err: diff --git a/litecord/presence.py b/litecord/presence.py index 2ac1532..68fc84b 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -3,12 +3,14 @@ from typing import List, Dict, Any class PresenceManager: """Presence related functions.""" - def __init__(self, storage, state_manager): + def __init__(self, storage, state_manager, dispatcher): self.storage = storage self.state_manager = state_manager + self.dispatcher = dispatcher async def guild_presences(self, member_ids: List[int], guild_id: int) -> List[Dict[Any, str]]: + """Fetch all presences in a guild.""" states = self.state_manager.guild_states(member_ids, guild_id) presences = [] @@ -20,9 +22,38 @@ class PresenceManager: presences.append({ 'user': member['user'], 'roles': member['roles'], - 'game': state.presence['game'], 'guild_id': guild_id, - 'status': state.presence['status'], + 'game': state.presence.get('game', None), + 'status': state.presence.get('status', None), }) return presences + + async def dispatch_guild_pres(self, guild_id: int, + user_id: int, new_state: dict): + """Dispatch a Presence update to an entire guild.""" + state = dict(new_state) + + if state['status'] == 'invisible': + state['status'] = 'offline' + + member = await self.storage.get_member_data_one(guild_id, user_id) + + await self.dispatcher.dispatch_guild( + guild_id, 'PRESENCE_UPDATE', { + 'user': member['user'], + 'roles': member['roles'], + 'guild_id': guild_id, + + 'game': state['game'], + 'status': state['status'], + } + ) + + async def dispatch_pres(self, user_id: int, state): + """Dispatch a new presence to all guilds the user is in.""" + # TODO: account for sharding + guild_ids = await self.storage.get_user_guilds(user_id) + + for guild_id in guild_ids: + await self.dispatch_guild_pres(guild_id, user_id, state) diff --git a/litecord/schemas.py b/litecord/schemas.py index 9aa8f6c..69a29e5 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -3,6 +3,7 @@ import re from cerberus import Validator from .errors import BadRequest +from .enums import ActivityType, StatusType USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A) EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', @@ -22,17 +23,38 @@ class LitecordValidator(Validator): """Validate against the username regex.""" return bool(USERNAME_REGEX.match(value)) + def _validate_type_snowflake(self, value: str) -> bool: + try: + int(value) + return True + except ValueError: + return False + def _validate_type_voice_region(self, value: str) -> bool: # TODO: complete this list return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia') + def _validate_type_activity_type(self, value: int) -> bool: + return value in ActivityType.values() -def validate(reqjson, schema): + def _validate_type_status_external(self, value: str) -> bool: + statuses = StatusType.values() + + # clients should send INVISIBLE instead of OFFLINE + statuses.remove(StatusType.OFFLINE.value) + + return value in statuses + + +def validate(reqjson, schema, raise_err: bool = False): validator = LitecordValidator(schema) if not validator.validate(reqjson): errs = validator.errors - raise BadRequest('bad payload', errs) + if raise_err: + raise BadRequest('bad payload', errs) + + return None return reqjson @@ -75,6 +97,7 @@ MEMBER_UPDATE = { 'channel_id': {'type': 'snowflake', 'required': False}, } + MESSAGE_CREATE = { 'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000}, 'nonce': {'type': 'string', 'required': False}, @@ -82,3 +105,71 @@ MESSAGE_CREATE = { # TODO: file, embed, payload_json } + + +GW_ACTIVITY = { + 'name': {'type': 'string', 'required': True}, + 'type': {'type': 'activity_type', 'required': True}, + + 'url': {'type': 'string', 'required': False, 'nullable': True}, + + 'timestamps': { + 'type': 'dict', + 'required': False, + 'schema': { + 'start': {'type': 'number', 'required': True}, + 'end': {'type': 'number', 'required': True}, + }, + }, + + 'application_id': {'type': 'snowflake', 'required': False, + 'nullable': False}, + 'details': {'type': 'string', 'required': False, 'nullable': True}, + 'state': {'type': 'string', 'required': False, 'nullable': True}, + + 'party': { + 'type': 'dict', + 'required': False, + 'schema': { + 'id': {'type': 'snowflake', 'required': False}, + 'size': {'type': 'list', 'required': False}, + } + }, + + 'assets': { + 'type': 'dict', + 'required': False, + 'schema': { + 'large_image': {'type': 'snowflake', 'required': False}, + 'large_text': {'type': 'string', 'required': False}, + 'small_image': {'type': 'snowflake', 'required': False}, + 'small_text': {'type': 'string', 'required': False}, + } + }, + + 'secrets': { + 'type': 'dict', + 'required': False, + 'schema': { + 'join': {'type': 'string', 'required': False}, + 'spectate': {'type': 'string', 'required': False}, + 'match': {'type': 'string', 'required': False}, + } + }, + + 'instance': {'type': 'boolean', 'required': False}, + 'flags': {'type': 'number', 'required': False}, +} + +GW_STATUS_UPDATE = { + 'status': {'type': 'status_external', 'required': False}, + 'afk': {'type': 'boolean', 'required': False}, + + 'since': {'type': 'number', 'required': True, 'nullable': True}, + 'game': { + 'type': 'dict', + 'required': True, + 'nullable': True, + 'schema': GW_ACTIVITY, + }, +} diff --git a/litecord/storage.py b/litecord/storage.py index 3e3f6c0..1f94dee 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -1,9 +1,14 @@ from typing import List, Dict, Any +from logbook import Logger + from .enums import ChannelType from .schemas import USER_MENTION, ROLE_MENTION +log = Logger(__name__) + + async def _dummy(any_id): return str(any_id) @@ -157,12 +162,13 @@ class Storage: return members - async def _channels_extra(self, row, channel_type: int) -> Dict: + async def _channels_extra(self, row) -> Dict: """Fill in more information about a channel.""" - # TODO: This could probably be better with a dictionary. + channel_type = row['type'] # TODO: dm and group dm? - if channel_type == ChannelType.GUILD_TEXT: + chan_type = ChannelType(channel_type) + if chan_type == ChannelType.GUILD_TEXT: topic = await self.db.fetchval(""" SELECT topic FROM guild_text_channels WHERE id = $1 @@ -171,7 +177,7 @@ class Storage: return {**row, **{ 'topic': topic, }} - elif channel_type == ChannelType.GUILD_VOICE: + elif chan_type == ChannelType.GUILD_VOICE: vrow = await self.db.fetchval(""" SELECT bitrate, user_limit FROM guild_voice_channels WHERE id = $1 @@ -179,6 +185,8 @@ class Storage: return {**row, **dict(vrow)} + log.warning('unknown channel type: {}', chan_type) + async def get_chan_type(self, channel_id) -> int: return await self.db.fetchval(""" SELECT channel_type @@ -205,16 +213,19 @@ class Storage: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) - if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, - ChannelType.GUILD_CATEGORY): + if ChannelType(chan_type) in (ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GUILD_CATEGORY): base = await self.db.fetchrow(""" SELECT id, guild_id::text, parent_id, name, position, nsfw FROM guild_channels WHERE guild_channels.id = $1 """, channel_id) - res = await self._channels_extra(dict(base), chan_type) - res['type'] = chan_type + dbase = dict(base) + dbase['type'] = chan_type + + res = await self._channels_extra(dbase) res['permission_overwrites'] = \ list(await self._chan_overwrites(channel_id)) @@ -240,8 +251,12 @@ class Storage: WHERE id = $1 """, row['id']) - res = await self._channels_extra(dict(row), ctype) - res['type'] = ctype + drow = dict(row) + drow['type'] = ctype + + res = await self._channels_extra(drow) + + print(res) res['permission_overwrites'] = \ list(await self._chan_overwrites(row['id'])) diff --git a/run.py b/run.py index 9c744fb..8f19dcd 100644 --- a/run.py +++ b/run.py @@ -88,7 +88,8 @@ async def app_before_serving(): app.state_manager = StateManager() app.dispatcher = EventDispatcher(app.state_manager) app.storage = Storage(app.db) - app.presence = PresenceManager(app.storage, app.state_manager) + app.presence = PresenceManager(app.storage, + app.state_manager, app.dispatcher) # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT']