diff --git a/litecord/enums.py b/litecord/enums.py
index 2c4e3b8..0a2ddf3 100644
--- a/litecord/enums.py
+++ b/litecord/enums.py
@@ -94,6 +94,7 @@ class ActivityType(EasyEnum):
STREAMING = 1
LISTENING = 2
WATCHING = 3
+ CUSTOM = 4
class MessageType(EasyEnum):
diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py
index 8be9a8f..a98752f 100644
--- a/litecord/gateway/websocket.py
+++ b/litecord/gateway/websocket.py
@@ -21,7 +21,8 @@ import collections
import asyncio
import pprint
import zlib
-from typing import List, Dict, Any, Iterable
+import time
+from typing import List, Dict, Any, Iterable, Optional
from random import randint
import websockets
@@ -30,12 +31,15 @@ from logbook import Logger
from quart import current_app as app
from litecord.auth import raw_token_check
-from litecord.enums import RelationshipType, ChannelType
+from litecord.enums import RelationshipType, ChannelType, ActivityType
from litecord.schemas import validate, GW_STATUS_UPDATE
from litecord.utils import (
task_wrapper,
yield_chunks,
maybe_int,
+ custom_status_to_activity,
+ custom_status_is_expired,
+ custom_status_set_null,
want_bytes,
want_string,
)
@@ -87,6 +91,7 @@ class GatewayWebsocket:
log.debug("websocket properties: {!r}", self.wsp)
self.state = None
+ self._hb_counter = 0
self._set_encoders()
@@ -301,7 +306,7 @@ class GatewayWebsocket:
await self.dispatch("GUILD_CREATE", guild)
- async def _user_ready(self) -> dict:
+ async def _user_ready(self, *, settings=None) -> dict:
"""Fetch information about users in the READY packet.
This part of the API is completly undocumented.
@@ -319,7 +324,7 @@ class GatewayWebsocket:
]
friend_presences = await self.app.presence.friend_presences(friend_ids)
- settings = await self.user_storage.get_user_settings(user_id)
+ settings = settings or await self.user_storage.get_user_settings(user_id)
return {
"user_settings": settings,
@@ -336,7 +341,7 @@ class GatewayWebsocket:
"analytics_token": "transbian",
}
- async def dispatch_ready(self):
+ async def dispatch_ready(self, **kwargs):
"""Dispatch the READY packet for a connecting account."""
guilds = await self._make_guild_list()
@@ -346,7 +351,7 @@ class GatewayWebsocket:
user_ready = {}
if not self.state.bot:
# user, fetch info
- user_ready = await self._user_ready()
+ user_ready = await self._user_ready(**kwargs)
private_channels = await self.user_storage.get_dms(
user_id
@@ -481,56 +486,111 @@ class GatewayWebsocket:
for friend_id in friend_ids:
await app.dispatcher.friend.sub(user_id, friend_id)
- async def update_status(self, incoming_status: dict):
- """Update the status of the current websocket connection."""
+ async def update_presence(
+ self,
+ given_presence: dict,
+ *,
+ settings: Optional[dict] = None,
+ override_ratelimit=False,
+ ):
+ """Update the presence of the current websocket connection.
+
+ Invalid presences are silently dropped. As well as when the state is
+ invalid/incomplete.
+ When the session is beyond the Status Update's ratelimits, the update
+ is silently dropped.
+ """
if not self.state:
return
- if self._check_ratelimit("presence", self.state.session_id):
- # Presence Updates beyond the ratelimit
- # are just silently dropped.
+ if not override_ratelimit and self._check_ratelimit(
+ "presence", self.state.session_id
+ ):
return
- status = {
- "afk": False,
- # TODO: fetch status from settings
- "status": "online",
- "game": None,
- # TODO: this
- "since": 0,
- }
- status.update(incoming_status or {})
+ settings = settings or await self.user_storage.get_user_settings(
+ self.state.user_id
+ )
+
+ presence = BasePresence(status=settings["status"] or "online", game=None)
+
+ custom_status = settings.get("custom_status") or None
+ if isinstance(custom_status, dict) and custom_status is not None:
+ presence.game = await custom_status_to_activity(custom_status)
+ if presence.game is None:
+ await custom_status_set_null(self.state.user_id)
+
+ log.debug("pres={}, given pres={}", presence, given_presence)
try:
- status = validate(status, GW_STATUS_UPDATE)
+ given_presence = validate(given_presence, GW_STATUS_UPDATE)
except BadRequest as err:
log.warning(f"Invalid status update: {err}")
return
- # try to extract game from activities
- # when game not provided
- if not status.get("game"):
- try:
- game = status["activities"][0]
- except (KeyError, IndexError):
- game = None
- else:
- game = status["game"]
+ presence.update_from_incoming_dict(given_presence)
- pres_status = status.get("status") or "online"
- pres_status = "offline" if pres_status == "invisible" else pres_status
- self.state.presence = BasePresence(status=pres_status, game=game)
+ # always try to use activities.0 to replace game when possible
+ activities: Optional[List[dict]] = given_presence.get("activities")
+ try:
+ activity: Optional[dict] = (activities or [])[0]
+ except IndexError:
+ activity = None
+
+ game: Optional[dict] = activity or presence.game
+
+ # hacky, but works (id and created_at aren't documented)
+ if game is not None and game["type"] == ActivityType.CUSTOM.value:
+ game["id"] = "custom"
+ game["created_at"] = int(time.time() * 1000)
+
+ emoji = game.get("emoji") or {}
+ if emoji.get("id") is None and emoji.get("name") is not None:
+ # drop the other fields when we're using unicode emoji
+ game["emoji"] = {"name": emoji["name"]}
+
+ presence.game = game
+
+ if presence.status == "invisible":
+ presence.status = "offline"
+
+ self.state.presence = presence
log.info(
- f'Updating presence status={status["status"]} for '
- f"uid={self.state.user_id}"
+ "updating presence status={} for uid={}",
+ presence.status,
+ self.state.user_id,
)
+ log.debug("full presence = {}", presence)
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
+ async def _custom_status_expire_check(self):
+ if not self.state:
+ return
+
+ settings = await self.user_storage.get_user_settings(self.state.user_id)
+ custom_status = settings["custom_status"]
+ if custom_status is None:
+ return
+
+ if not custom_status_is_expired(custom_status.get("expires_at")):
+ return
+
+ await custom_status_set_null(self.state.user_id)
+ await self.update_presence(
+ {"status": self.state.presence.status, "game": None},
+ override_ratelimit=True,
+ )
+
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
# give the client 3 more seconds before we
# close the websocket
+
+ self._hb_counter += 1
+ if self._hb_counter % 2 == 0:
+ self.app.sched.spawn(self._custom_status_expire_check())
+
self._hb_start((46 + 3) * 1000)
cliseq = payload.get("d")
@@ -562,7 +622,7 @@ class GatewayWebsocket:
large = data.get("large_threshold", 50)
shard = data.get("shard", [0, 1])
- presence = data.get("presence")
+ presence = data.get("presence") or {}
try:
user_id = await raw_token_check(token, self.app.db)
@@ -596,17 +656,16 @@ class GatewayWebsocket:
# link the state to the user
self.app.state_manager.insert(self.state)
- await self.update_status(presence)
+ settings = await self.user_storage.get_user_settings(user_id)
+
+ await self.update_presence(presence, settings=settings)
await self.subscribe_all(data.get("guild_subscriptions", True))
- await self.dispatch_ready()
+ await self.dispatch_ready(settings=settings)
async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update."""
- presence = payload["d"]
-
- # update_status will take care of validation and
- # setting new presence to state
- await self.update_status(presence)
+ presence = payload["d"] or {}
+ await self.update_presence(presence)
async def _vsu_get_prop(self, state, data):
"""Get voice state properties from data, fallbacking to
diff --git a/litecord/presence.py b/litecord/presence.py
index eafc118..a66dba0 100644
--- a/litecord/presence.py
+++ b/litecord/presence.py
@@ -32,6 +32,10 @@ class BasePresence:
status: str
game: Optional[dict] = None
+ @property
+ def activities(self) -> list:
+ return [self.game] if self.game else []
+
@property
def partial_dict(self) -> dict:
return {
@@ -40,9 +44,23 @@ class BasePresence:
"since": 0,
"client_status": {},
"mobile": False,
- "activities": [self.game] if self.game else [],
+ "activities": self.activities,
}
+ def update_from_incoming_dict(self, given_presence: dict) -> None:
+ given_status, given_game = (
+ given_presence.get("status"),
+ given_presence.get("game"),
+ )
+
+ if given_status is not None:
+ assert isinstance(given_status, str)
+ self.status = given_status
+
+ if given_game is not None:
+ assert isinstance(given_game, dict)
+ self.game = given_game
+
Presence = Dict[str, Any]
@@ -139,6 +157,7 @@ class PresenceManager:
"roles": member["roles"],
"status": presence.status,
"game": presence.game,
+ "activities": presence.activities,
},
)
diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py
index 4b85dc4..083c617 100644
--- a/litecord/pubsub/lazy_guild.py
+++ b/litecord/pubsub/lazy_guild.py
@@ -223,7 +223,7 @@ def merge(member: dict, presence: Presence) -> dict:
"user": {"id": str(member["user"]["id"])},
"status": presence["status"],
"game": presence["game"],
- "activities": presence["activities"],
+ "activities": (presence.get("activities") or []),
}
},
}
diff --git a/litecord/schemas.py b/litecord/schemas.py
index 3962934..4f8fdc9 100644
--- a/litecord/schemas.py
+++ b/litecord/schemas.py
@@ -18,6 +18,8 @@ along with this program. If not, see .
"""
import re
+
+# from datetime import datetime
from typing import Union, Dict, List, Optional
from cerberus import Validator
@@ -41,7 +43,10 @@ from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
log = Logger(__name__)
+# TODO use any char instead of english lol
USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
+
+# TODO better email regex maybe
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
@@ -437,6 +442,16 @@ GW_ACTIVITY = {
},
"instance": {"type": "boolean", "required": False},
"flags": {"type": "number", "required": False},
+ "emoji": {
+ "type": "dict",
+ "required": False,
+ "nullable": True,
+ "schema": {
+ "animated": {"type": "boolean", "required": False, "default": False},
+ "id": {"coerce": int, "nullable": True, "default": None},
+ "name": {"type": "string", "required": True},
+ },
+ },
}
GW_STATUS_UPDATE = {
@@ -526,6 +541,19 @@ USER_SETTINGS = {
"timezone_offset": {"type": "number", "required": False},
"status": {"type": "status_external", "required": False},
"theme": {"type": "theme", "required": False},
+ "custom_status": {
+ "type": "dict",
+ "required": False,
+ "nullable": True,
+ "schema": {
+ "emoji_id": {"coerce": int, "nullable": True},
+ "emoji_name": {"type": "string", "nullable": True},
+ # discord's timestamps dont seem to work well with
+ # datetime.fromisoformat, so for now, we trust the client
+ "expires_at": {"type": "string", "nullable": True},
+ "text": {"type": "string", "nullable": True},
+ },
+ },
}
RELATIONSHIP = {
diff --git a/litecord/user_storage.py b/litecord/user_storage.py
index 45704cc..02979fd 100644
--- a/litecord/user_storage.py
+++ b/litecord/user_storage.py
@@ -56,7 +56,7 @@ class UserStorage:
user_id,
)
- if not row:
+ if row is None:
log.info("Generating user settings for {}", user_id)
await self.db.execute(
diff --git a/litecord/utils.py b/litecord/utils.py
index 7fbc43f..b516029 100644
--- a/litecord/utils.py
+++ b/litecord/utils.py
@@ -20,6 +20,8 @@ along with this program. If not, see .
import asyncio
import json
import secrets
+import datetime
+import re
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger
@@ -292,6 +294,75 @@ def rand_hex(length: int = 8) -> str:
return secrets.token_hex(length)[:length]
+def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
+ if timestamp:
+ splitted = re.split(r"[^\d]", timestamp.replace("+00:00", ""))
+
+ # ignore last component (which can be empty, because of the last Z
+ # letter in a timestamp)
+ splitted = splitted[:7]
+ components = list(map(int, splitted))
+ return datetime.datetime(*components)
+
+ return None
+
+
+def custom_status_is_expired(expired_at: Optional[str]) -> bool:
+ """Return if a custom status is expired."""
+ expires_at = parse_time(expired_at)
+ now = datetime.datetime.utcnow()
+ return bool(expires_at and now > expires_at)
+
+
+async def custom_status_set_null(user_id: int) -> None:
+ """Set a user's custom status in the database to NULL.
+
+ This function does not do any gateway side effects.
+ """
+ await app.db.execute(
+ """
+ UPDATE user_settings
+ SET custom_status = NULL
+ WHERE user_id = $1
+ """,
+ user_id,
+ )
+
+
+async def custom_status_to_activity(custom_status: dict) -> Optional[dict]:
+ """Convert a custom status coming from user settings to an activity.
+
+ Returns None if the given custom status is invalid and shouldn't be
+ used anymore.
+ """
+ text = custom_status.get("text")
+ emoji_id = custom_status.get("emoji_id")
+ emoji_name = custom_status.get("emoji_name")
+ emoji = None if emoji_id is None else await app.storage.get_emoji(emoji_id)
+
+ activity = {"type": 4, "name": "Custom Status"}
+
+ if emoji is not None:
+ activity["emoji"] = {
+ "animated": emoji["animated"],
+ "id": str(emoji["id"]),
+ "name": emoji["name"],
+ }
+ elif emoji_name is not None:
+ activity["emoji"] = {"name": emoji_name}
+
+ if text is not None:
+ activity["state"] = text
+
+ if "emoji" not in activity and "state" not in activity:
+ return None
+
+ if custom_status_is_expired(custom_status.get("expired_at")):
+ return None
+
+ return activity
+
+
def want_bytes(data: Union[str, bytes]) -> bytes:
return data if isinstance(data, bytes) else data.encode()
diff --git a/manage/cmd/migration/scripts/4_add_custom_status_settings.sql b/manage/cmd/migration/scripts/4_add_custom_status_settings.sql
new file mode 100644
index 0000000..cbe368b
--- /dev/null
+++ b/manage/cmd/migration/scripts/4_add_custom_status_settings.sql
@@ -0,0 +1 @@
+ALTER TABLE user_settings ADD COLUMN custom_status jsonb DEFAULT NULL;