add better docstring to update_status

- simplify logic to fetch game
 - safer getting when making final status
This commit is contained in:
Luna 2020-03-14 20:05:14 +00:00
parent 8dc27ae9de
commit b1d1f4f228
8 changed files with 225 additions and 46 deletions

View File

@ -94,6 +94,7 @@ class ActivityType(EasyEnum):
STREAMING = 1 STREAMING = 1
LISTENING = 2 LISTENING = 2
WATCHING = 3 WATCHING = 3
CUSTOM = 4
class MessageType(EasyEnum): class MessageType(EasyEnum):

View File

@ -21,7 +21,8 @@ import collections
import asyncio import asyncio
import pprint import pprint
import zlib import zlib
from typing import List, Dict, Any, Iterable import time
from typing import List, Dict, Any, Iterable, Optional
from random import randint from random import randint
import websockets import websockets
@ -30,12 +31,15 @@ from logbook import Logger
from quart import current_app as app from quart import current_app as app
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType from litecord.enums import RelationshipType, ChannelType, ActivityType
from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.schemas import validate, GW_STATUS_UPDATE
from litecord.utils import ( from litecord.utils import (
task_wrapper, task_wrapper,
yield_chunks, yield_chunks,
maybe_int, maybe_int,
custom_status_to_activity,
custom_status_is_expired,
custom_status_set_null,
want_bytes, want_bytes,
want_string, want_string,
) )
@ -87,6 +91,7 @@ class GatewayWebsocket:
log.debug("websocket properties: {!r}", self.wsp) log.debug("websocket properties: {!r}", self.wsp)
self.state = None self.state = None
self._hb_counter = 0
self._set_encoders() self._set_encoders()
@ -301,7 +306,7 @@ class GatewayWebsocket:
await self.dispatch("GUILD_CREATE", guild) await self.dispatch("GUILD_CREATE", guild)
async def _user_ready(self) -> dict: async def _user_ready(self, *, settings=None) -> dict:
"""Fetch information about users in the READY packet. """Fetch information about users in the READY packet.
This part of the API is completly undocumented. This part of the API is completly undocumented.
@ -319,7 +324,7 @@ class GatewayWebsocket:
] ]
friend_presences = await self.app.presence.friend_presences(friend_ids) friend_presences = await self.app.presence.friend_presences(friend_ids)
settings = await self.user_storage.get_user_settings(user_id) settings = settings or await self.user_storage.get_user_settings(user_id)
return { return {
"user_settings": settings, "user_settings": settings,
@ -336,7 +341,7 @@ class GatewayWebsocket:
"analytics_token": "transbian", "analytics_token": "transbian",
} }
async def dispatch_ready(self): async def dispatch_ready(self, **kwargs):
"""Dispatch the READY packet for a connecting account.""" """Dispatch the READY packet for a connecting account."""
guilds = await self._make_guild_list() guilds = await self._make_guild_list()
@ -346,7 +351,7 @@ class GatewayWebsocket:
user_ready = {} user_ready = {}
if not self.state.bot: if not self.state.bot:
# user, fetch info # user, fetch info
user_ready = await self._user_ready() user_ready = await self._user_ready(**kwargs)
private_channels = await self.user_storage.get_dms( private_channels = await self.user_storage.get_dms(
user_id user_id
@ -481,56 +486,111 @@ class GatewayWebsocket:
for friend_id in friend_ids: for friend_id in friend_ids:
await app.dispatcher.friend.sub(user_id, friend_id) await app.dispatcher.friend.sub(user_id, friend_id)
async def update_status(self, incoming_status: dict): async def update_presence(
"""Update the status of the current websocket connection.""" self,
given_presence: dict,
*,
settings: Optional[dict] = None,
override_ratelimit=False,
):
"""Update the presence of the current websocket connection.
Invalid presences are silently dropped. As well as when the state is
invalid/incomplete.
When the session is beyond the Status Update's ratelimits, the update
is silently dropped.
"""
if not self.state: if not self.state:
return return
if self._check_ratelimit("presence", self.state.session_id): if not override_ratelimit and self._check_ratelimit(
# Presence Updates beyond the ratelimit "presence", self.state.session_id
# are just silently dropped. ):
return return
status = { settings = settings or await self.user_storage.get_user_settings(
"afk": False, self.state.user_id
# TODO: fetch status from settings )
"status": "online",
"game": None, presence = BasePresence(status=settings["status"] or "online", game=None)
# TODO: this
"since": 0, custom_status = settings.get("custom_status") or None
} if isinstance(custom_status, dict) and custom_status is not None:
status.update(incoming_status or {}) presence.game = await custom_status_to_activity(custom_status)
if presence.game is None:
await custom_status_set_null(self.state.user_id)
log.debug("pres={}, given pres={}", presence, given_presence)
try: try:
status = validate(status, GW_STATUS_UPDATE) given_presence = validate(given_presence, GW_STATUS_UPDATE)
except BadRequest as err: except BadRequest as err:
log.warning(f"Invalid status update: {err}") log.warning(f"Invalid status update: {err}")
return return
# try to extract game from activities presence.update_from_incoming_dict(given_presence)
# when game not provided
if not status.get("game"):
try:
game = status["activities"][0]
except (KeyError, IndexError):
game = None
else:
game = status["game"]
pres_status = status.get("status") or "online" # always try to use activities.0 to replace game when possible
pres_status = "offline" if pres_status == "invisible" else pres_status activities: Optional[List[dict]] = given_presence.get("activities")
self.state.presence = BasePresence(status=pres_status, game=game) try:
activity: Optional[dict] = (activities or [])[0]
except IndexError:
activity = None
game: Optional[dict] = activity or presence.game
# hacky, but works (id and created_at aren't documented)
if game is not None and game["type"] == ActivityType.CUSTOM.value:
game["id"] = "custom"
game["created_at"] = int(time.time() * 1000)
emoji = game.get("emoji") or {}
if emoji.get("id") is None and emoji.get("name") is not None:
# drop the other fields when we're using unicode emoji
game["emoji"] = {"name": emoji["name"]}
presence.game = game
if presence.status == "invisible":
presence.status = "offline"
self.state.presence = presence
log.info( log.info(
f'Updating presence status={status["status"]} for ' "updating presence status={} for uid={}",
f"uid={self.state.user_id}" presence.status,
self.state.user_id,
) )
log.debug("full presence = {}", presence)
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence) await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
async def _custom_status_expire_check(self):
if not self.state:
return
settings = await self.user_storage.get_user_settings(self.state.user_id)
custom_status = settings["custom_status"]
if custom_status is None:
return
if not custom_status_is_expired(custom_status.get("expires_at")):
return
await custom_status_set_null(self.state.user_id)
await self.update_presence(
{"status": self.state.presence.status, "game": None},
override_ratelimit=True,
)
async def handle_1(self, payload: Dict[str, Any]): async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets.""" """Handle OP 1 Heartbeat packets."""
# give the client 3 more seconds before we # give the client 3 more seconds before we
# close the websocket # close the websocket
self._hb_counter += 1
if self._hb_counter % 2 == 0:
self.app.sched.spawn(self._custom_status_expire_check())
self._hb_start((46 + 3) * 1000) self._hb_start((46 + 3) * 1000)
cliseq = payload.get("d") cliseq = payload.get("d")
@ -562,7 +622,7 @@ class GatewayWebsocket:
large = data.get("large_threshold", 50) large = data.get("large_threshold", 50)
shard = data.get("shard", [0, 1]) shard = data.get("shard", [0, 1])
presence = data.get("presence") presence = data.get("presence") or {}
try: try:
user_id = await raw_token_check(token, self.app.db) user_id = await raw_token_check(token, self.app.db)
@ -596,17 +656,16 @@ class GatewayWebsocket:
# link the state to the user # link the state to the user
self.app.state_manager.insert(self.state) self.app.state_manager.insert(self.state)
await self.update_status(presence) settings = await self.user_storage.get_user_settings(user_id)
await self.update_presence(presence, settings=settings)
await self.subscribe_all(data.get("guild_subscriptions", True)) await self.subscribe_all(data.get("guild_subscriptions", True))
await self.dispatch_ready() await self.dispatch_ready(settings=settings)
async def handle_3(self, payload: Dict[str, Any]): async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update.""" """Handle OP 3 Status Update."""
presence = payload["d"] presence = payload["d"] or {}
await self.update_presence(presence)
# update_status will take care of validation and
# setting new presence to state
await self.update_status(presence)
async def _vsu_get_prop(self, state, data): async def _vsu_get_prop(self, state, data):
"""Get voice state properties from data, fallbacking to """Get voice state properties from data, fallbacking to

View File

@ -32,6 +32,10 @@ class BasePresence:
status: str status: str
game: Optional[dict] = None game: Optional[dict] = None
@property
def activities(self) -> list:
return [self.game] if self.game else []
@property @property
def partial_dict(self) -> dict: def partial_dict(self) -> dict:
return { return {
@ -40,9 +44,23 @@ class BasePresence:
"since": 0, "since": 0,
"client_status": {}, "client_status": {},
"mobile": False, "mobile": False,
"activities": [self.game] if self.game else [], "activities": self.activities,
} }
def update_from_incoming_dict(self, given_presence: dict) -> None:
given_status, given_game = (
given_presence.get("status"),
given_presence.get("game"),
)
if given_status is not None:
assert isinstance(given_status, str)
self.status = given_status
if given_game is not None:
assert isinstance(given_game, dict)
self.game = given_game
Presence = Dict[str, Any] Presence = Dict[str, Any]
@ -139,6 +157,7 @@ class PresenceManager:
"roles": member["roles"], "roles": member["roles"],
"status": presence.status, "status": presence.status,
"game": presence.game, "game": presence.game,
"activities": presence.activities,
}, },
) )

View File

@ -223,7 +223,7 @@ def merge(member: dict, presence: Presence) -> dict:
"user": {"id": str(member["user"]["id"])}, "user": {"id": str(member["user"]["id"])},
"status": presence["status"], "status": presence["status"],
"game": presence["game"], "game": presence["game"],
"activities": presence["activities"], "activities": (presence.get("activities") or []),
} }
}, },
} }

View File

@ -18,6 +18,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
import re import re
# from datetime import datetime
from typing import Union, Dict, List, Optional from typing import Union, Dict, List, Optional
from cerberus import Validator from cerberus import Validator
@ -41,7 +43,10 @@ from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
log = Logger(__name__) log = Logger(__name__)
# TODO use any char instead of english lol
USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A) USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
# TODO better email regex maybe
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A) EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A) DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
@ -437,6 +442,16 @@ GW_ACTIVITY = {
}, },
"instance": {"type": "boolean", "required": False}, "instance": {"type": "boolean", "required": False},
"flags": {"type": "number", "required": False}, "flags": {"type": "number", "required": False},
"emoji": {
"type": "dict",
"required": False,
"nullable": True,
"schema": {
"animated": {"type": "boolean", "required": False, "default": False},
"id": {"coerce": int, "nullable": True, "default": None},
"name": {"type": "string", "required": True},
},
},
} }
GW_STATUS_UPDATE = { GW_STATUS_UPDATE = {
@ -526,6 +541,19 @@ USER_SETTINGS = {
"timezone_offset": {"type": "number", "required": False}, "timezone_offset": {"type": "number", "required": False},
"status": {"type": "status_external", "required": False}, "status": {"type": "status_external", "required": False},
"theme": {"type": "theme", "required": False}, "theme": {"type": "theme", "required": False},
"custom_status": {
"type": "dict",
"required": False,
"nullable": True,
"schema": {
"emoji_id": {"coerce": int, "nullable": True},
"emoji_name": {"type": "string", "nullable": True},
# discord's timestamps dont seem to work well with
# datetime.fromisoformat, so for now, we trust the client
"expires_at": {"type": "string", "nullable": True},
"text": {"type": "string", "nullable": True},
},
},
} }
RELATIONSHIP = { RELATIONSHIP = {

View File

@ -56,7 +56,7 @@ class UserStorage:
user_id, user_id,
) )
if not row: if row is None:
log.info("Generating user settings for {}", user_id) log.info("Generating user settings for {}", user_id)
await self.db.execute( await self.db.execute(

View File

@ -20,6 +20,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import json import json
import secrets import secrets
import datetime
import re
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger from logbook import Logger
@ -292,6 +294,75 @@ def rand_hex(length: int = 8) -> str:
return secrets.token_hex(length)[:length] return secrets.token_hex(length)[:length]
def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
if timestamp:
splitted = re.split(r"[^\d]", timestamp.replace("+00:00", ""))
# ignore last component (which can be empty, because of the last Z
# letter in a timestamp)
splitted = splitted[:7]
components = list(map(int, splitted))
return datetime.datetime(*components)
return None
def custom_status_is_expired(expired_at: Optional[str]) -> bool:
"""Return if a custom status is expired."""
expires_at = parse_time(expired_at)
now = datetime.datetime.utcnow()
return bool(expires_at and now > expires_at)
async def custom_status_set_null(user_id: int) -> None:
"""Set a user's custom status in the database to NULL.
This function does not do any gateway side effects.
"""
await app.db.execute(
"""
UPDATE user_settings
SET custom_status = NULL
WHERE user_id = $1
""",
user_id,
)
async def custom_status_to_activity(custom_status: dict) -> Optional[dict]:
"""Convert a custom status coming from user settings to an activity.
Returns None if the given custom status is invalid and shouldn't be
used anymore.
"""
text = custom_status.get("text")
emoji_id = custom_status.get("emoji_id")
emoji_name = custom_status.get("emoji_name")
emoji = None if emoji_id is None else await app.storage.get_emoji(emoji_id)
activity = {"type": 4, "name": "Custom Status"}
if emoji is not None:
activity["emoji"] = {
"animated": emoji["animated"],
"id": str(emoji["id"]),
"name": emoji["name"],
}
elif emoji_name is not None:
activity["emoji"] = {"name": emoji_name}
if text is not None:
activity["state"] = text
if "emoji" not in activity and "state" not in activity:
return None
if custom_status_is_expired(custom_status.get("expired_at")):
return None
return activity
def want_bytes(data: Union[str, bytes]) -> bytes: def want_bytes(data: Union[str, bytes]) -> bytes:
return data if isinstance(data, bytes) else data.encode() return data if isinstance(data, bytes) else data.encode()

View File

@ -0,0 +1 @@
ALTER TABLE user_settings ADD COLUMN custom_status jsonb DEFAULT NULL;