mirror of https://gitlab.com/litecord/litecord.git
add better docstring to update_status
- simplify logic to fetch game - safer getting when making final status
This commit is contained in:
parent
8dc27ae9de
commit
b1d1f4f228
|
|
@ -94,6 +94,7 @@ class ActivityType(EasyEnum):
|
|||
STREAMING = 1
|
||||
LISTENING = 2
|
||||
WATCHING = 3
|
||||
CUSTOM = 4
|
||||
|
||||
|
||||
class MessageType(EasyEnum):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ import collections
|
|||
import asyncio
|
||||
import pprint
|
||||
import zlib
|
||||
from typing import List, Dict, Any, Iterable
|
||||
import time
|
||||
from typing import List, Dict, Any, Iterable, Optional
|
||||
from random import randint
|
||||
|
||||
import websockets
|
||||
|
|
@ -30,12 +31,15 @@ from logbook import Logger
|
|||
from quart import current_app as app
|
||||
|
||||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType
|
||||
from litecord.enums import RelationshipType, ChannelType, ActivityType
|
||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||
from litecord.utils import (
|
||||
task_wrapper,
|
||||
yield_chunks,
|
||||
maybe_int,
|
||||
custom_status_to_activity,
|
||||
custom_status_is_expired,
|
||||
custom_status_set_null,
|
||||
want_bytes,
|
||||
want_string,
|
||||
)
|
||||
|
|
@ -87,6 +91,7 @@ class GatewayWebsocket:
|
|||
log.debug("websocket properties: {!r}", self.wsp)
|
||||
|
||||
self.state = None
|
||||
self._hb_counter = 0
|
||||
|
||||
self._set_encoders()
|
||||
|
||||
|
|
@ -301,7 +306,7 @@ class GatewayWebsocket:
|
|||
|
||||
await self.dispatch("GUILD_CREATE", guild)
|
||||
|
||||
async def _user_ready(self) -> dict:
|
||||
async def _user_ready(self, *, settings=None) -> dict:
|
||||
"""Fetch information about users in the READY packet.
|
||||
|
||||
This part of the API is completly undocumented.
|
||||
|
|
@ -319,7 +324,7 @@ class GatewayWebsocket:
|
|||
]
|
||||
|
||||
friend_presences = await self.app.presence.friend_presences(friend_ids)
|
||||
settings = await self.user_storage.get_user_settings(user_id)
|
||||
settings = settings or await self.user_storage.get_user_settings(user_id)
|
||||
|
||||
return {
|
||||
"user_settings": settings,
|
||||
|
|
@ -336,7 +341,7 @@ class GatewayWebsocket:
|
|||
"analytics_token": "transbian",
|
||||
}
|
||||
|
||||
async def dispatch_ready(self):
|
||||
async def dispatch_ready(self, **kwargs):
|
||||
"""Dispatch the READY packet for a connecting account."""
|
||||
guilds = await self._make_guild_list()
|
||||
|
||||
|
|
@ -346,7 +351,7 @@ class GatewayWebsocket:
|
|||
user_ready = {}
|
||||
if not self.state.bot:
|
||||
# user, fetch info
|
||||
user_ready = await self._user_ready()
|
||||
user_ready = await self._user_ready(**kwargs)
|
||||
|
||||
private_channels = await self.user_storage.get_dms(
|
||||
user_id
|
||||
|
|
@ -481,56 +486,111 @@ class GatewayWebsocket:
|
|||
for friend_id in friend_ids:
|
||||
await app.dispatcher.friend.sub(user_id, friend_id)
|
||||
|
||||
async def update_status(self, incoming_status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
async def update_presence(
|
||||
self,
|
||||
given_presence: dict,
|
||||
*,
|
||||
settings: Optional[dict] = None,
|
||||
override_ratelimit=False,
|
||||
):
|
||||
"""Update the presence of the current websocket connection.
|
||||
|
||||
Invalid presences are silently dropped. As well as when the state is
|
||||
invalid/incomplete.
|
||||
When the session is beyond the Status Update's ratelimits, the update
|
||||
is silently dropped.
|
||||
"""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
if self._check_ratelimit("presence", self.state.session_id):
|
||||
# Presence Updates beyond the ratelimit
|
||||
# are just silently dropped.
|
||||
if not override_ratelimit and self._check_ratelimit(
|
||||
"presence", self.state.session_id
|
||||
):
|
||||
return
|
||||
|
||||
status = {
|
||||
"afk": False,
|
||||
# TODO: fetch status from settings
|
||||
"status": "online",
|
||||
"game": None,
|
||||
# TODO: this
|
||||
"since": 0,
|
||||
}
|
||||
status.update(incoming_status or {})
|
||||
settings = settings or await self.user_storage.get_user_settings(
|
||||
self.state.user_id
|
||||
)
|
||||
|
||||
presence = BasePresence(status=settings["status"] or "online", game=None)
|
||||
|
||||
custom_status = settings.get("custom_status") or None
|
||||
if isinstance(custom_status, dict) and custom_status is not None:
|
||||
presence.game = await custom_status_to_activity(custom_status)
|
||||
if presence.game is None:
|
||||
await custom_status_set_null(self.state.user_id)
|
||||
|
||||
log.debug("pres={}, given pres={}", presence, given_presence)
|
||||
|
||||
try:
|
||||
status = validate(status, GW_STATUS_UPDATE)
|
||||
given_presence = validate(given_presence, GW_STATUS_UPDATE)
|
||||
except BadRequest as err:
|
||||
log.warning(f"Invalid status update: {err}")
|
||||
return
|
||||
|
||||
# try to extract game from activities
|
||||
# when game not provided
|
||||
if not status.get("game"):
|
||||
try:
|
||||
game = status["activities"][0]
|
||||
except (KeyError, IndexError):
|
||||
game = None
|
||||
else:
|
||||
game = status["game"]
|
||||
presence.update_from_incoming_dict(given_presence)
|
||||
|
||||
pres_status = status.get("status") or "online"
|
||||
pres_status = "offline" if pres_status == "invisible" else pres_status
|
||||
self.state.presence = BasePresence(status=pres_status, game=game)
|
||||
# always try to use activities.0 to replace game when possible
|
||||
activities: Optional[List[dict]] = given_presence.get("activities")
|
||||
try:
|
||||
activity: Optional[dict] = (activities or [])[0]
|
||||
except IndexError:
|
||||
activity = None
|
||||
|
||||
game: Optional[dict] = activity or presence.game
|
||||
|
||||
# hacky, but works (id and created_at aren't documented)
|
||||
if game is not None and game["type"] == ActivityType.CUSTOM.value:
|
||||
game["id"] = "custom"
|
||||
game["created_at"] = int(time.time() * 1000)
|
||||
|
||||
emoji = game.get("emoji") or {}
|
||||
if emoji.get("id") is None and emoji.get("name") is not None:
|
||||
# drop the other fields when we're using unicode emoji
|
||||
game["emoji"] = {"name": emoji["name"]}
|
||||
|
||||
presence.game = game
|
||||
|
||||
if presence.status == "invisible":
|
||||
presence.status = "offline"
|
||||
|
||||
self.state.presence = presence
|
||||
|
||||
log.info(
|
||||
f'Updating presence status={status["status"]} for '
|
||||
f"uid={self.state.user_id}"
|
||||
"updating presence status={} for uid={}",
|
||||
presence.status,
|
||||
self.state.user_id,
|
||||
)
|
||||
log.debug("full presence = {}", presence)
|
||||
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
||||
|
||||
async def _custom_status_expire_check(self):
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
settings = await self.user_storage.get_user_settings(self.state.user_id)
|
||||
custom_status = settings["custom_status"]
|
||||
if custom_status is None:
|
||||
return
|
||||
|
||||
if not custom_status_is_expired(custom_status.get("expires_at")):
|
||||
return
|
||||
|
||||
await custom_status_set_null(self.state.user_id)
|
||||
await self.update_presence(
|
||||
{"status": self.state.presence.status, "game": None},
|
||||
override_ratelimit=True,
|
||||
)
|
||||
|
||||
async def handle_1(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 1 Heartbeat packets."""
|
||||
# give the client 3 more seconds before we
|
||||
# close the websocket
|
||||
|
||||
self._hb_counter += 1
|
||||
if self._hb_counter % 2 == 0:
|
||||
self.app.sched.spawn(self._custom_status_expire_check())
|
||||
|
||||
self._hb_start((46 + 3) * 1000)
|
||||
cliseq = payload.get("d")
|
||||
|
||||
|
|
@ -562,7 +622,7 @@ class GatewayWebsocket:
|
|||
large = data.get("large_threshold", 50)
|
||||
|
||||
shard = data.get("shard", [0, 1])
|
||||
presence = data.get("presence")
|
||||
presence = data.get("presence") or {}
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.app.db)
|
||||
|
|
@ -596,17 +656,16 @@ class GatewayWebsocket:
|
|||
# link the state to the user
|
||||
self.app.state_manager.insert(self.state)
|
||||
|
||||
await self.update_status(presence)
|
||||
settings = await self.user_storage.get_user_settings(user_id)
|
||||
|
||||
await self.update_presence(presence, settings=settings)
|
||||
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||
await self.dispatch_ready()
|
||||
await self.dispatch_ready(settings=settings)
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 3 Status Update."""
|
||||
presence = payload["d"]
|
||||
|
||||
# update_status will take care of validation and
|
||||
# setting new presence to state
|
||||
await self.update_status(presence)
|
||||
presence = payload["d"] or {}
|
||||
await self.update_presence(presence)
|
||||
|
||||
async def _vsu_get_prop(self, state, data):
|
||||
"""Get voice state properties from data, fallbacking to
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@ class BasePresence:
|
|||
status: str
|
||||
game: Optional[dict] = None
|
||||
|
||||
@property
|
||||
def activities(self) -> list:
|
||||
return [self.game] if self.game else []
|
||||
|
||||
@property
|
||||
def partial_dict(self) -> dict:
|
||||
return {
|
||||
|
|
@ -40,9 +44,23 @@ class BasePresence:
|
|||
"since": 0,
|
||||
"client_status": {},
|
||||
"mobile": False,
|
||||
"activities": [self.game] if self.game else [],
|
||||
"activities": self.activities,
|
||||
}
|
||||
|
||||
def update_from_incoming_dict(self, given_presence: dict) -> None:
|
||||
given_status, given_game = (
|
||||
given_presence.get("status"),
|
||||
given_presence.get("game"),
|
||||
)
|
||||
|
||||
if given_status is not None:
|
||||
assert isinstance(given_status, str)
|
||||
self.status = given_status
|
||||
|
||||
if given_game is not None:
|
||||
assert isinstance(given_game, dict)
|
||||
self.game = given_game
|
||||
|
||||
|
||||
Presence = Dict[str, Any]
|
||||
|
||||
|
|
@ -139,6 +157,7 @@ class PresenceManager:
|
|||
"roles": member["roles"],
|
||||
"status": presence.status,
|
||||
"game": presence.game,
|
||||
"activities": presence.activities,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -223,7 +223,7 @@ def merge(member: dict, presence: Presence) -> dict:
|
|||
"user": {"id": str(member["user"]["id"])},
|
||||
"status": presence["status"],
|
||||
"game": presence["game"],
|
||||
"activities": presence["activities"],
|
||||
"activities": (presence.get("activities") or []),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
import re
|
||||
|
||||
# from datetime import datetime
|
||||
from typing import Union, Dict, List, Optional
|
||||
|
||||
from cerberus import Validator
|
||||
|
|
@ -41,7 +43,10 @@ from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
|
|||
|
||||
log = Logger(__name__)
|
||||
|
||||
# TODO use any char instead of english lol
|
||||
USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
|
||||
|
||||
# TODO better email regex maybe
|
||||
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
|
||||
DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
|
||||
|
||||
|
|
@ -437,6 +442,16 @@ GW_ACTIVITY = {
|
|||
},
|
||||
"instance": {"type": "boolean", "required": False},
|
||||
"flags": {"type": "number", "required": False},
|
||||
"emoji": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"nullable": True,
|
||||
"schema": {
|
||||
"animated": {"type": "boolean", "required": False, "default": False},
|
||||
"id": {"coerce": int, "nullable": True, "default": None},
|
||||
"name": {"type": "string", "required": True},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
GW_STATUS_UPDATE = {
|
||||
|
|
@ -526,6 +541,19 @@ USER_SETTINGS = {
|
|||
"timezone_offset": {"type": "number", "required": False},
|
||||
"status": {"type": "status_external", "required": False},
|
||||
"theme": {"type": "theme", "required": False},
|
||||
"custom_status": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"nullable": True,
|
||||
"schema": {
|
||||
"emoji_id": {"coerce": int, "nullable": True},
|
||||
"emoji_name": {"type": "string", "nullable": True},
|
||||
# discord's timestamps dont seem to work well with
|
||||
# datetime.fromisoformat, so for now, we trust the client
|
||||
"expires_at": {"type": "string", "nullable": True},
|
||||
"text": {"type": "string", "nullable": True},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
RELATIONSHIP = {
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class UserStorage:
|
|||
user_id,
|
||||
)
|
||||
|
||||
if not row:
|
||||
if row is None:
|
||||
log.info("Generating user settings for {}", user_id)
|
||||
|
||||
await self.db.execute(
|
||||
|
|
|
|||
|
|
@ -20,6 +20,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
import asyncio
|
||||
import json
|
||||
import secrets
|
||||
import datetime
|
||||
import re
|
||||
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -292,6 +294,75 @@ def rand_hex(length: int = 8) -> str:
|
|||
return secrets.token_hex(length)[:length]
|
||||
|
||||
|
||||
def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
|
||||
if timestamp:
|
||||
splitted = re.split(r"[^\d]", timestamp.replace("+00:00", ""))
|
||||
|
||||
# ignore last component (which can be empty, because of the last Z
|
||||
# letter in a timestamp)
|
||||
splitted = splitted[:7]
|
||||
components = list(map(int, splitted))
|
||||
return datetime.datetime(*components)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def custom_status_is_expired(expired_at: Optional[str]) -> bool:
|
||||
"""Return if a custom status is expired."""
|
||||
expires_at = parse_time(expired_at)
|
||||
now = datetime.datetime.utcnow()
|
||||
return bool(expires_at and now > expires_at)
|
||||
|
||||
|
||||
async def custom_status_set_null(user_id: int) -> None:
|
||||
"""Set a user's custom status in the database to NULL.
|
||||
|
||||
This function does not do any gateway side effects.
|
||||
"""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_settings
|
||||
SET custom_status = NULL
|
||||
WHERE user_id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
async def custom_status_to_activity(custom_status: dict) -> Optional[dict]:
|
||||
"""Convert a custom status coming from user settings to an activity.
|
||||
|
||||
Returns None if the given custom status is invalid and shouldn't be
|
||||
used anymore.
|
||||
"""
|
||||
text = custom_status.get("text")
|
||||
emoji_id = custom_status.get("emoji_id")
|
||||
emoji_name = custom_status.get("emoji_name")
|
||||
emoji = None if emoji_id is None else await app.storage.get_emoji(emoji_id)
|
||||
|
||||
activity = {"type": 4, "name": "Custom Status"}
|
||||
|
||||
if emoji is not None:
|
||||
activity["emoji"] = {
|
||||
"animated": emoji["animated"],
|
||||
"id": str(emoji["id"]),
|
||||
"name": emoji["name"],
|
||||
}
|
||||
elif emoji_name is not None:
|
||||
activity["emoji"] = {"name": emoji_name}
|
||||
|
||||
if text is not None:
|
||||
activity["state"] = text
|
||||
|
||||
if "emoji" not in activity and "state" not in activity:
|
||||
return None
|
||||
|
||||
if custom_status_is_expired(custom_status.get("expired_at")):
|
||||
return None
|
||||
|
||||
return activity
|
||||
|
||||
|
||||
def want_bytes(data: Union[str, bytes]) -> bytes:
|
||||
return data if isinstance(data, bytes) else data.encode()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
ALTER TABLE user_settings ADD COLUMN custom_status jsonb DEFAULT NULL;
|
||||
Loading…
Reference in New Issue