diff --git a/litecord/enums.py b/litecord/enums.py index 2c4e3b8..0a2ddf3 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -94,6 +94,7 @@ class ActivityType(EasyEnum): STREAMING = 1 LISTENING = 2 WATCHING = 3 + CUSTOM = 4 class MessageType(EasyEnum): diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 8be9a8f..a98752f 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -21,7 +21,8 @@ import collections import asyncio import pprint import zlib -from typing import List, Dict, Any, Iterable +import time +from typing import List, Dict, Any, Iterable, Optional from random import randint import websockets @@ -30,12 +31,15 @@ from logbook import Logger from quart import current_app as app from litecord.auth import raw_token_check -from litecord.enums import RelationshipType, ChannelType +from litecord.enums import RelationshipType, ChannelType, ActivityType from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.utils import ( task_wrapper, yield_chunks, maybe_int, + custom_status_to_activity, + custom_status_is_expired, + custom_status_set_null, want_bytes, want_string, ) @@ -87,6 +91,7 @@ class GatewayWebsocket: log.debug("websocket properties: {!r}", self.wsp) self.state = None + self._hb_counter = 0 self._set_encoders() @@ -301,7 +306,7 @@ class GatewayWebsocket: await self.dispatch("GUILD_CREATE", guild) - async def _user_ready(self) -> dict: + async def _user_ready(self, *, settings=None) -> dict: """Fetch information about users in the READY packet. This part of the API is completly undocumented. @@ -319,7 +324,7 @@ class GatewayWebsocket: ] friend_presences = await self.app.presence.friend_presences(friend_ids) - settings = await self.user_storage.get_user_settings(user_id) + settings = settings or await self.user_storage.get_user_settings(user_id) return { "user_settings": settings, @@ -336,7 +341,7 @@ class GatewayWebsocket: "analytics_token": "transbian", } - async def dispatch_ready(self): + async def dispatch_ready(self, **kwargs): """Dispatch the READY packet for a connecting account.""" guilds = await self._make_guild_list() @@ -346,7 +351,7 @@ class GatewayWebsocket: user_ready = {} if not self.state.bot: # user, fetch info - user_ready = await self._user_ready() + user_ready = await self._user_ready(**kwargs) private_channels = await self.user_storage.get_dms( user_id @@ -481,56 +486,111 @@ class GatewayWebsocket: for friend_id in friend_ids: await app.dispatcher.friend.sub(user_id, friend_id) - async def update_status(self, incoming_status: dict): - """Update the status of the current websocket connection.""" + async def update_presence( + self, + given_presence: dict, + *, + settings: Optional[dict] = None, + override_ratelimit=False, + ): + """Update the presence of the current websocket connection. + + Invalid presences are silently dropped. As well as when the state is + invalid/incomplete. + When the session is beyond the Status Update's ratelimits, the update + is silently dropped. + """ if not self.state: return - if self._check_ratelimit("presence", self.state.session_id): - # Presence Updates beyond the ratelimit - # are just silently dropped. + if not override_ratelimit and self._check_ratelimit( + "presence", self.state.session_id + ): return - status = { - "afk": False, - # TODO: fetch status from settings - "status": "online", - "game": None, - # TODO: this - "since": 0, - } - status.update(incoming_status or {}) + settings = settings or await self.user_storage.get_user_settings( + self.state.user_id + ) + + presence = BasePresence(status=settings["status"] or "online", game=None) + + custom_status = settings.get("custom_status") or None + if isinstance(custom_status, dict) and custom_status is not None: + presence.game = await custom_status_to_activity(custom_status) + if presence.game is None: + await custom_status_set_null(self.state.user_id) + + log.debug("pres={}, given pres={}", presence, given_presence) try: - status = validate(status, GW_STATUS_UPDATE) + given_presence = validate(given_presence, GW_STATUS_UPDATE) except BadRequest as err: log.warning(f"Invalid status update: {err}") return - # try to extract game from activities - # when game not provided - if not status.get("game"): - try: - game = status["activities"][0] - except (KeyError, IndexError): - game = None - else: - game = status["game"] + presence.update_from_incoming_dict(given_presence) - pres_status = status.get("status") or "online" - pres_status = "offline" if pres_status == "invisible" else pres_status - self.state.presence = BasePresence(status=pres_status, game=game) + # always try to use activities.0 to replace game when possible + activities: Optional[List[dict]] = given_presence.get("activities") + try: + activity: Optional[dict] = (activities or [])[0] + except IndexError: + activity = None + + game: Optional[dict] = activity or presence.game + + # hacky, but works (id and created_at aren't documented) + if game is not None and game["type"] == ActivityType.CUSTOM.value: + game["id"] = "custom" + game["created_at"] = int(time.time() * 1000) + + emoji = game.get("emoji") or {} + if emoji.get("id") is None and emoji.get("name") is not None: + # drop the other fields when we're using unicode emoji + game["emoji"] = {"name": emoji["name"]} + + presence.game = game + + if presence.status == "invisible": + presence.status = "offline" + + self.state.presence = presence log.info( - f'Updating presence status={status["status"]} for ' - f"uid={self.state.user_id}" + "updating presence status={} for uid={}", + presence.status, + self.state.user_id, ) + log.debug("full presence = {}", presence) await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence) + async def _custom_status_expire_check(self): + if not self.state: + return + + settings = await self.user_storage.get_user_settings(self.state.user_id) + custom_status = settings["custom_status"] + if custom_status is None: + return + + if not custom_status_is_expired(custom_status.get("expires_at")): + return + + await custom_status_set_null(self.state.user_id) + await self.update_presence( + {"status": self.state.presence.status, "game": None}, + override_ratelimit=True, + ) + async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" # give the client 3 more seconds before we # close the websocket + + self._hb_counter += 1 + if self._hb_counter % 2 == 0: + self.app.sched.spawn(self._custom_status_expire_check()) + self._hb_start((46 + 3) * 1000) cliseq = payload.get("d") @@ -562,7 +622,7 @@ class GatewayWebsocket: large = data.get("large_threshold", 50) shard = data.get("shard", [0, 1]) - presence = data.get("presence") + presence = data.get("presence") or {} try: user_id = await raw_token_check(token, self.app.db) @@ -596,17 +656,16 @@ class GatewayWebsocket: # link the state to the user self.app.state_manager.insert(self.state) - await self.update_status(presence) + settings = await self.user_storage.get_user_settings(user_id) + + await self.update_presence(presence, settings=settings) await self.subscribe_all(data.get("guild_subscriptions", True)) - await self.dispatch_ready() + await self.dispatch_ready(settings=settings) async def handle_3(self, payload: Dict[str, Any]): """Handle OP 3 Status Update.""" - presence = payload["d"] - - # update_status will take care of validation and - # setting new presence to state - await self.update_status(presence) + presence = payload["d"] or {} + await self.update_presence(presence) async def _vsu_get_prop(self, state, data): """Get voice state properties from data, fallbacking to diff --git a/litecord/presence.py b/litecord/presence.py index eafc118..a66dba0 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -32,6 +32,10 @@ class BasePresence: status: str game: Optional[dict] = None + @property + def activities(self) -> list: + return [self.game] if self.game else [] + @property def partial_dict(self) -> dict: return { @@ -40,9 +44,23 @@ class BasePresence: "since": 0, "client_status": {}, "mobile": False, - "activities": [self.game] if self.game else [], + "activities": self.activities, } + def update_from_incoming_dict(self, given_presence: dict) -> None: + given_status, given_game = ( + given_presence.get("status"), + given_presence.get("game"), + ) + + if given_status is not None: + assert isinstance(given_status, str) + self.status = given_status + + if given_game is not None: + assert isinstance(given_game, dict) + self.game = given_game + Presence = Dict[str, Any] @@ -139,6 +157,7 @@ class PresenceManager: "roles": member["roles"], "status": presence.status, "game": presence.game, + "activities": presence.activities, }, ) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index 4b85dc4..083c617 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -223,7 +223,7 @@ def merge(member: dict, presence: Presence) -> dict: "user": {"id": str(member["user"]["id"])}, "status": presence["status"], "game": presence["game"], - "activities": presence["activities"], + "activities": (presence.get("activities") or []), } }, } diff --git a/litecord/schemas.py b/litecord/schemas.py index 3962934..4f8fdc9 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -18,6 +18,8 @@ along with this program. If not, see . """ import re + +# from datetime import datetime from typing import Union, Dict, List, Optional from cerberus import Validator @@ -41,7 +43,10 @@ from litecord.embed.schemas import EMBED_OBJECT, EmbedURL log = Logger(__name__) +# TODO use any char instead of english lol USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A) + +# TODO better email regex maybe EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A) DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A) @@ -437,6 +442,16 @@ GW_ACTIVITY = { }, "instance": {"type": "boolean", "required": False}, "flags": {"type": "number", "required": False}, + "emoji": { + "type": "dict", + "required": False, + "nullable": True, + "schema": { + "animated": {"type": "boolean", "required": False, "default": False}, + "id": {"coerce": int, "nullable": True, "default": None}, + "name": {"type": "string", "required": True}, + }, + }, } GW_STATUS_UPDATE = { @@ -526,6 +541,19 @@ USER_SETTINGS = { "timezone_offset": {"type": "number", "required": False}, "status": {"type": "status_external", "required": False}, "theme": {"type": "theme", "required": False}, + "custom_status": { + "type": "dict", + "required": False, + "nullable": True, + "schema": { + "emoji_id": {"coerce": int, "nullable": True}, + "emoji_name": {"type": "string", "nullable": True}, + # discord's timestamps dont seem to work well with + # datetime.fromisoformat, so for now, we trust the client + "expires_at": {"type": "string", "nullable": True}, + "text": {"type": "string", "nullable": True}, + }, + }, } RELATIONSHIP = { diff --git a/litecord/user_storage.py b/litecord/user_storage.py index 45704cc..02979fd 100644 --- a/litecord/user_storage.py +++ b/litecord/user_storage.py @@ -56,7 +56,7 @@ class UserStorage: user_id, ) - if not row: + if row is None: log.info("Generating user settings for {}", user_id) await self.db.execute( diff --git a/litecord/utils.py b/litecord/utils.py index 7fbc43f..b516029 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -20,6 +20,8 @@ along with this program. If not, see . import asyncio import json import secrets +import datetime +import re from typing import Any, Iterable, Optional, Sequence, List, Dict, Union from logbook import Logger @@ -292,6 +294,75 @@ def rand_hex(length: int = 8) -> str: return secrets.token_hex(length)[:length] +def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]: + if timestamp: + splitted = re.split(r"[^\d]", timestamp.replace("+00:00", "")) + + # ignore last component (which can be empty, because of the last Z + # letter in a timestamp) + splitted = splitted[:7] + components = list(map(int, splitted)) + return datetime.datetime(*components) + + return None + + +def custom_status_is_expired(expired_at: Optional[str]) -> bool: + """Return if a custom status is expired.""" + expires_at = parse_time(expired_at) + now = datetime.datetime.utcnow() + return bool(expires_at and now > expires_at) + + +async def custom_status_set_null(user_id: int) -> None: + """Set a user's custom status in the database to NULL. + + This function does not do any gateway side effects. + """ + await app.db.execute( + """ + UPDATE user_settings + SET custom_status = NULL + WHERE user_id = $1 + """, + user_id, + ) + + +async def custom_status_to_activity(custom_status: dict) -> Optional[dict]: + """Convert a custom status coming from user settings to an activity. + + Returns None if the given custom status is invalid and shouldn't be + used anymore. + """ + text = custom_status.get("text") + emoji_id = custom_status.get("emoji_id") + emoji_name = custom_status.get("emoji_name") + emoji = None if emoji_id is None else await app.storage.get_emoji(emoji_id) + + activity = {"type": 4, "name": "Custom Status"} + + if emoji is not None: + activity["emoji"] = { + "animated": emoji["animated"], + "id": str(emoji["id"]), + "name": emoji["name"], + } + elif emoji_name is not None: + activity["emoji"] = {"name": emoji_name} + + if text is not None: + activity["state"] = text + + if "emoji" not in activity and "state" not in activity: + return None + + if custom_status_is_expired(custom_status.get("expired_at")): + return None + + return activity + + def want_bytes(data: Union[str, bytes]) -> bytes: return data if isinstance(data, bytes) else data.encode() diff --git a/manage/cmd/migration/scripts/4_add_custom_status_settings.sql b/manage/cmd/migration/scripts/4_add_custom_status_settings.sql new file mode 100644 index 0000000..cbe368b --- /dev/null +++ b/manage/cmd/migration/scripts/4_add_custom_status_settings.sql @@ -0,0 +1 @@ +ALTER TABLE user_settings ADD COLUMN custom_status jsonb DEFAULT NULL;