mirror of https://gitlab.com/litecord/litecord.git
Rough presence management
Lots of changes to get it working.
One day changes will be able to be small enough to be split across
commits.
- enums: use enum.Enum, make EasyEnum subclass
- enums: add ActivityType, ActivityFlags, StatusType
- gateway.state: use 128 random bits instead of 256
- gateway.state: add MAX_STORE_SIZE in PayloadStore and check it when
adding a new payload
- gateway.websocket: add GatewayWebsocket.update_status
- presence: add PresenceManager.dispatch_guild_pres and
PresenceManager.dispatch_pres
- schema: add snowflake, activity_type, status_external types
- schema: add GW_ACTIVITY, GW_STATUS_UPDATE
- storage: fix _channels_extra and fixes for ChannelType as enum instead of
class
This commit is contained in:
parent
d39783e666
commit
cd5dbc4886
|
|
@ -1,5 +1,15 @@
|
||||||
|
import ctypes
|
||||||
|
|
||||||
class ChannelType:
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class EasyEnum(Enum):
|
||||||
|
@classmethod
|
||||||
|
def values(cls):
|
||||||
|
return [v.value for v in cls.__members__.values()]
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelType(EasyEnum):
|
||||||
GUILD_TEXT = 0
|
GUILD_TEXT = 0
|
||||||
DM = 1
|
DM = 1
|
||||||
GUILD_VOICE = 2
|
GUILD_VOICE = 2
|
||||||
|
|
@ -7,7 +17,13 @@ class ChannelType:
|
||||||
GUILD_CATEGORY = 4
|
GUILD_CATEGORY = 4
|
||||||
|
|
||||||
|
|
||||||
class MessageType:
|
class ActivityType(EasyEnum):
|
||||||
|
PLAYING = 0
|
||||||
|
STREAMING = 1
|
||||||
|
LISTENING = 2
|
||||||
|
|
||||||
|
|
||||||
|
class MessageType(EasyEnum):
|
||||||
DEFAULT = 0
|
DEFAULT = 0
|
||||||
RECIPIENT_ADD = 1
|
RECIPIENT_ADD = 1
|
||||||
RECIPIENT_REMOVE = 2
|
RECIPIENT_REMOVE = 2
|
||||||
|
|
@ -18,8 +34,40 @@ class MessageType:
|
||||||
GUILD_MEMBER_JOIN = 7
|
GUILD_MEMBER_JOIN = 7
|
||||||
|
|
||||||
|
|
||||||
class MessageActivityType:
|
class MessageActivityType(EasyEnum):
|
||||||
JOIN = 1
|
JOIN = 1
|
||||||
SPECTATE = 2
|
SPECTATE = 2
|
||||||
LISTEN = 3
|
LISTEN = 3
|
||||||
JOIN_REQUEST = 5
|
JOIN_REQUEST = 5
|
||||||
|
|
||||||
|
|
||||||
|
uint8 = ctypes.c_uint8
|
||||||
|
|
||||||
|
|
||||||
|
# use ctypes to interpret the bits in activity flags
|
||||||
|
class ActivityFlagsBits(ctypes.LittleEndianStructure):
|
||||||
|
_fields_ = [
|
||||||
|
('instance', uint8, 1),
|
||||||
|
('join', uint8, 1),
|
||||||
|
('spectate', uint8, 1),
|
||||||
|
('join_request', uint8, 1),
|
||||||
|
('sync', uint8, 1),
|
||||||
|
('play', uint8, 1),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityFlags(ctypes.Union):
|
||||||
|
_anonymous_ = ('bit',)
|
||||||
|
|
||||||
|
_fields_ = [
|
||||||
|
('bit', ActivityFlagsBits),
|
||||||
|
('as_byte', uint8),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class StatusType(EasyEnum):
|
||||||
|
ONLINE = 'online'
|
||||||
|
DND = 'dnd'
|
||||||
|
IDLE = 'idle'
|
||||||
|
INVISIBLE = 'invisible'
|
||||||
|
OFFLINE = 'offline'
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,13 @@ import os
|
||||||
|
|
||||||
def gen_session_id() -> str:
|
def gen_session_id() -> str:
|
||||||
"""Generate a random session ID."""
|
"""Generate a random session ID."""
|
||||||
return hashlib.sha1(os.urandom(256)).hexdigest()
|
return hashlib.sha1(os.urandom(128)).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class PayloadStore:
|
class PayloadStore:
|
||||||
"""Store manager for payloads."""
|
"""Store manager for payloads."""
|
||||||
|
MAX_STORE_SIZE = 250
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.store = {}
|
self.store = {}
|
||||||
|
|
||||||
|
|
@ -16,9 +18,26 @@ class PayloadStore:
|
||||||
return self.store[opcode]
|
return self.store[opcode]
|
||||||
|
|
||||||
def __setitem__(self, opcode: int, payload: dict):
|
def __setitem__(self, opcode: int, payload: dict):
|
||||||
|
if len(self.store) > 250:
|
||||||
|
# if more than 250, remove old keys until we get 250
|
||||||
|
opcodes = sorted(list(self.store.keys()))
|
||||||
|
to_remove = len(opcodes) - self.MAX_STORE_SIZE
|
||||||
|
|
||||||
|
for idx in range(to_remove):
|
||||||
|
opcode = opcodes[idx]
|
||||||
|
self.store.pop(opcode)
|
||||||
|
|
||||||
self.store[opcode] = payload
|
self.store[opcode] = payload
|
||||||
|
|
||||||
|
|
||||||
|
class Presence:
|
||||||
|
def __init__(self, raw: dict):
|
||||||
|
self.afk = raw.get('afk', False)
|
||||||
|
self.status = raw.get('status', 'online')
|
||||||
|
self.game = raw.get('game', None)
|
||||||
|
self.since = raw.get('since', 0)
|
||||||
|
|
||||||
|
|
||||||
class GatewayState:
|
class GatewayState:
|
||||||
"""Main websocket state.
|
"""Main websocket state.
|
||||||
|
|
||||||
|
|
@ -32,6 +51,7 @@ class GatewayState:
|
||||||
self.shard = kwargs.get('shard', [0, 1])
|
self.shard = kwargs.get('shard', [0, 1])
|
||||||
self.user_id = kwargs.get('user_id')
|
self.user_id = kwargs.get('user_id')
|
||||||
self.bot = kwargs.get('bot', False)
|
self.bot = kwargs.get('bot', False)
|
||||||
|
self.presence = {}
|
||||||
self.store = PayloadStore()
|
self.store = PayloadStore()
|
||||||
|
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ from .errors import DecodeError, UnknownOPCode, \
|
||||||
from .opcodes import OP
|
from .opcodes import OP
|
||||||
from .state import GatewayState
|
from .state import GatewayState
|
||||||
|
|
||||||
|
from ..schemas import validate, GW_STATUS_UPDATE
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
WebsocketProperties = collections.namedtuple(
|
WebsocketProperties = collections.namedtuple(
|
||||||
|
|
@ -76,7 +78,7 @@ class GatewayWebsocket:
|
||||||
This function accounts for the zlib-stream
|
This function accounts for the zlib-stream
|
||||||
transport method used by Discord.
|
transport method used by Discord.
|
||||||
"""
|
"""
|
||||||
log.debug('Sending {}', pprint.pformat(payload))
|
log.debug('sending {}', pprint.pformat(payload))
|
||||||
encoded = self.encoder(payload)
|
encoded = self.encoder(payload)
|
||||||
|
|
||||||
if not isinstance(encoded, bytes):
|
if not isinstance(encoded, bytes):
|
||||||
|
|
@ -162,15 +164,29 @@ class GatewayWebsocket:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
# TODO
|
||||||
'relationships': [],
|
'relationships': [],
|
||||||
|
|
||||||
|
# TODO
|
||||||
'user_guild_settings': [],
|
'user_guild_settings': [],
|
||||||
|
|
||||||
|
# TODO
|
||||||
'notes': {},
|
'notes': {},
|
||||||
'friend_suggestion_count': 0,
|
'friend_suggestion_count': 0,
|
||||||
|
|
||||||
|
# TODO
|
||||||
'presences': [],
|
'presences': [],
|
||||||
|
|
||||||
|
# TODO
|
||||||
'read_state': [],
|
'read_state': [],
|
||||||
|
|
||||||
'experiments': [],
|
'experiments': [],
|
||||||
'guild_experiments': [],
|
'guild_experiments': [],
|
||||||
|
|
||||||
|
# TODO
|
||||||
'connected_accounts': [],
|
'connected_accounts': [],
|
||||||
|
|
||||||
|
# TODO: make those changeable
|
||||||
'user_settings': {
|
'user_settings': {
|
||||||
'afk_timeout': 300,
|
'afk_timeout': 300,
|
||||||
'animate_emoji': True,
|
'animate_emoji': True,
|
||||||
|
|
@ -198,6 +214,7 @@ class GatewayWebsocket:
|
||||||
'theme': 'dark',
|
'theme': 'dark',
|
||||||
'timezone_offset': 420,
|
'timezone_offset': 420,
|
||||||
},
|
},
|
||||||
|
|
||||||
'analytics_token': 'transbian',
|
'analytics_token': 'transbian',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -246,19 +263,49 @@ class GatewayWebsocket:
|
||||||
if current_shard > shard_count:
|
if current_shard > shard_count:
|
||||||
raise InvalidShard('Shard count > Total shards')
|
raise InvalidShard('Shard count > Total shards')
|
||||||
|
|
||||||
async def subscribe_guilds(self):
|
async def _guild_ids(self):
|
||||||
"""Subscribe to all available guilds"""
|
# TODO: account for sharding
|
||||||
guild_ids = await self.ext.db.fetch("""
|
guild_ids = await self.ext.db.fetch("""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
FROM members
|
FROM members
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", self.state.user_id)
|
""", self.state.user_id)
|
||||||
|
|
||||||
guild_ids = [r['guild_id'] for r in guild_ids]
|
return [r['guild_id'] for r in guild_ids]
|
||||||
|
|
||||||
|
async def subscribe_guilds(self):
|
||||||
|
"""Subscribe to all available guilds"""
|
||||||
|
guild_ids = await self._guild_ids()
|
||||||
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
|
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
|
||||||
|
|
||||||
|
async def update_status(self, status: dict):
|
||||||
|
if status is None:
|
||||||
|
status = {
|
||||||
|
'afk': False,
|
||||||
|
|
||||||
|
# TODO: fetch status from settings
|
||||||
|
'status': 'online',
|
||||||
|
'game': None,
|
||||||
|
|
||||||
|
# TODO: this
|
||||||
|
'since': 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.state.presence = status
|
||||||
|
|
||||||
|
status = validate(status, GW_STATUS_UPDATE)
|
||||||
|
|
||||||
|
if not status:
|
||||||
|
# invalid status, must ignore
|
||||||
|
return
|
||||||
|
|
||||||
|
self.state.presence = status
|
||||||
|
await self.ext.presence.dispatch_pres(self.state.user_id,
|
||||||
|
self.state.presence)
|
||||||
|
|
||||||
async def handle_1(self, payload: Dict[str, Any]):
|
async def handle_1(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 1 Heartbeat packets."""
|
"""Handle OP 1 Heartbeat packets."""
|
||||||
|
# TODO: handling heartbeats
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def handle_2(self, payload: Dict[str, Any]):
|
async def handle_2(self, payload: Dict[str, Any]):
|
||||||
|
|
@ -294,7 +341,6 @@ class GatewayWebsocket:
|
||||||
shard=shard,
|
shard=shard,
|
||||||
current_shard=shard[0],
|
current_shard=shard[0],
|
||||||
shard_count=shard[1],
|
shard_count=shard[1],
|
||||||
presence=presence,
|
|
||||||
ws=self
|
ws=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -304,6 +350,9 @@ class GatewayWebsocket:
|
||||||
await self.dispatch_ready()
|
await self.dispatch_ready()
|
||||||
await self.subscribe_guilds()
|
await self.subscribe_guilds()
|
||||||
|
|
||||||
|
# dispatch presence only after subscribing
|
||||||
|
await self.update_status(presence)
|
||||||
|
|
||||||
async def handle_3(self, payload: Dict[str, Any]):
|
async def handle_3(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 3 Status Update."""
|
"""Handle OP 3 Status Update."""
|
||||||
pass
|
pass
|
||||||
|
|
@ -426,8 +475,8 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
payload = self.decoder(message)
|
payload = self.decoder(message)
|
||||||
|
|
||||||
pretty_printed = pprint.pformat(payload)
|
log.debug('received message: {}',
|
||||||
log.debug('received message: {}', pretty_printed)
|
pprint.pformat(payload))
|
||||||
|
|
||||||
await self.process_message(payload)
|
await self.process_message(payload)
|
||||||
|
|
||||||
|
|
@ -438,9 +487,9 @@ class GatewayWebsocket:
|
||||||
await self.send_hello()
|
await self.send_hello()
|
||||||
await self.listen_messages()
|
await self.listen_messages()
|
||||||
except websockets.exceptions.ConnectionClosed as err:
|
except websockets.exceptions.ConnectionClosed as err:
|
||||||
log.warning('Client closed, state={}, err={}', self.state, err)
|
log.warning('conn close, state={}, err={}', self.state, err)
|
||||||
except WebsocketClose as err:
|
except WebsocketClose as err:
|
||||||
log.warning('closed a client, state={} err={}', self.state, err)
|
log.warning('ws close, state={} err={}', self.state, err)
|
||||||
|
|
||||||
await self.ws.close(code=err.code, reason=err.reason)
|
await self.ws.close(code=err.code, reason=err.reason)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,14 @@ from typing import List, Dict, Any
|
||||||
|
|
||||||
class PresenceManager:
|
class PresenceManager:
|
||||||
"""Presence related functions."""
|
"""Presence related functions."""
|
||||||
def __init__(self, storage, state_manager):
|
def __init__(self, storage, state_manager, dispatcher):
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.state_manager = state_manager
|
self.state_manager = state_manager
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
|
||||||
async def guild_presences(self, member_ids: List[int],
|
async def guild_presences(self, member_ids: List[int],
|
||||||
guild_id: int) -> List[Dict[Any, str]]:
|
guild_id: int) -> List[Dict[Any, str]]:
|
||||||
|
"""Fetch all presences in a guild."""
|
||||||
states = self.state_manager.guild_states(member_ids, guild_id)
|
states = self.state_manager.guild_states(member_ids, guild_id)
|
||||||
|
|
||||||
presences = []
|
presences = []
|
||||||
|
|
@ -20,9 +22,38 @@ class PresenceManager:
|
||||||
presences.append({
|
presences.append({
|
||||||
'user': member['user'],
|
'user': member['user'],
|
||||||
'roles': member['roles'],
|
'roles': member['roles'],
|
||||||
'game': state.presence['game'],
|
|
||||||
'guild_id': guild_id,
|
'guild_id': guild_id,
|
||||||
'status': state.presence['status'],
|
'game': state.presence.get('game', None),
|
||||||
|
'status': state.presence.get('status', None),
|
||||||
})
|
})
|
||||||
|
|
||||||
return presences
|
return presences
|
||||||
|
|
||||||
|
async def dispatch_guild_pres(self, guild_id: int,
|
||||||
|
user_id: int, new_state: dict):
|
||||||
|
"""Dispatch a Presence update to an entire guild."""
|
||||||
|
state = dict(new_state)
|
||||||
|
|
||||||
|
if state['status'] == 'invisible':
|
||||||
|
state['status'] = 'offline'
|
||||||
|
|
||||||
|
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||||
|
|
||||||
|
await self.dispatcher.dispatch_guild(
|
||||||
|
guild_id, 'PRESENCE_UPDATE', {
|
||||||
|
'user': member['user'],
|
||||||
|
'roles': member['roles'],
|
||||||
|
'guild_id': guild_id,
|
||||||
|
|
||||||
|
'game': state['game'],
|
||||||
|
'status': state['status'],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def dispatch_pres(self, user_id: int, state):
|
||||||
|
"""Dispatch a new presence to all guilds the user is in."""
|
||||||
|
# TODO: account for sharding
|
||||||
|
guild_ids = await self.storage.get_user_guilds(user_id)
|
||||||
|
|
||||||
|
for guild_id in guild_ids:
|
||||||
|
await self.dispatch_guild_pres(guild_id, user_id, state)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ import re
|
||||||
from cerberus import Validator
|
from cerberus import Validator
|
||||||
|
|
||||||
from .errors import BadRequest
|
from .errors import BadRequest
|
||||||
|
from .enums import ActivityType, StatusType
|
||||||
|
|
||||||
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
|
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
|
||||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
||||||
|
|
@ -22,17 +23,38 @@ class LitecordValidator(Validator):
|
||||||
"""Validate against the username regex."""
|
"""Validate against the username regex."""
|
||||||
return bool(USERNAME_REGEX.match(value))
|
return bool(USERNAME_REGEX.match(value))
|
||||||
|
|
||||||
|
def _validate_type_snowflake(self, value: str) -> bool:
|
||||||
|
try:
|
||||||
|
int(value)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
def _validate_type_voice_region(self, value: str) -> bool:
|
def _validate_type_voice_region(self, value: str) -> bool:
|
||||||
# TODO: complete this list
|
# TODO: complete this list
|
||||||
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia')
|
||||||
|
|
||||||
|
def _validate_type_activity_type(self, value: int) -> bool:
|
||||||
|
return value in ActivityType.values()
|
||||||
|
|
||||||
def validate(reqjson, schema):
|
def _validate_type_status_external(self, value: str) -> bool:
|
||||||
|
statuses = StatusType.values()
|
||||||
|
|
||||||
|
# clients should send INVISIBLE instead of OFFLINE
|
||||||
|
statuses.remove(StatusType.OFFLINE.value)
|
||||||
|
|
||||||
|
return value in statuses
|
||||||
|
|
||||||
|
|
||||||
|
def validate(reqjson, schema, raise_err: bool = False):
|
||||||
validator = LitecordValidator(schema)
|
validator = LitecordValidator(schema)
|
||||||
if not validator.validate(reqjson):
|
if not validator.validate(reqjson):
|
||||||
errs = validator.errors
|
errs = validator.errors
|
||||||
|
|
||||||
raise BadRequest('bad payload', errs)
|
if raise_err:
|
||||||
|
raise BadRequest('bad payload', errs)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
return reqjson
|
return reqjson
|
||||||
|
|
||||||
|
|
@ -75,6 +97,7 @@ MEMBER_UPDATE = {
|
||||||
'channel_id': {'type': 'snowflake', 'required': False},
|
'channel_id': {'type': 'snowflake', 'required': False},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
MESSAGE_CREATE = {
|
MESSAGE_CREATE = {
|
||||||
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
|
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
|
||||||
'nonce': {'type': 'string', 'required': False},
|
'nonce': {'type': 'string', 'required': False},
|
||||||
|
|
@ -82,3 +105,71 @@ MESSAGE_CREATE = {
|
||||||
|
|
||||||
# TODO: file, embed, payload_json
|
# TODO: file, embed, payload_json
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
GW_ACTIVITY = {
|
||||||
|
'name': {'type': 'string', 'required': True},
|
||||||
|
'type': {'type': 'activity_type', 'required': True},
|
||||||
|
|
||||||
|
'url': {'type': 'string', 'required': False, 'nullable': True},
|
||||||
|
|
||||||
|
'timestamps': {
|
||||||
|
'type': 'dict',
|
||||||
|
'required': False,
|
||||||
|
'schema': {
|
||||||
|
'start': {'type': 'number', 'required': True},
|
||||||
|
'end': {'type': 'number', 'required': True},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
'application_id': {'type': 'snowflake', 'required': False,
|
||||||
|
'nullable': False},
|
||||||
|
'details': {'type': 'string', 'required': False, 'nullable': True},
|
||||||
|
'state': {'type': 'string', 'required': False, 'nullable': True},
|
||||||
|
|
||||||
|
'party': {
|
||||||
|
'type': 'dict',
|
||||||
|
'required': False,
|
||||||
|
'schema': {
|
||||||
|
'id': {'type': 'snowflake', 'required': False},
|
||||||
|
'size': {'type': 'list', 'required': False},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
'assets': {
|
||||||
|
'type': 'dict',
|
||||||
|
'required': False,
|
||||||
|
'schema': {
|
||||||
|
'large_image': {'type': 'snowflake', 'required': False},
|
||||||
|
'large_text': {'type': 'string', 'required': False},
|
||||||
|
'small_image': {'type': 'snowflake', 'required': False},
|
||||||
|
'small_text': {'type': 'string', 'required': False},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
'secrets': {
|
||||||
|
'type': 'dict',
|
||||||
|
'required': False,
|
||||||
|
'schema': {
|
||||||
|
'join': {'type': 'string', 'required': False},
|
||||||
|
'spectate': {'type': 'string', 'required': False},
|
||||||
|
'match': {'type': 'string', 'required': False},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
'instance': {'type': 'boolean', 'required': False},
|
||||||
|
'flags': {'type': 'number', 'required': False},
|
||||||
|
}
|
||||||
|
|
||||||
|
GW_STATUS_UPDATE = {
|
||||||
|
'status': {'type': 'status_external', 'required': False},
|
||||||
|
'afk': {'type': 'boolean', 'required': False},
|
||||||
|
|
||||||
|
'since': {'type': 'number', 'required': True, 'nullable': True},
|
||||||
|
'game': {
|
||||||
|
'type': 'dict',
|
||||||
|
'required': True,
|
||||||
|
'nullable': True,
|
||||||
|
'schema': GW_ACTIVITY,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,14 @@
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
from .enums import ChannelType
|
from .enums import ChannelType
|
||||||
from .schemas import USER_MENTION, ROLE_MENTION
|
from .schemas import USER_MENTION, ROLE_MENTION
|
||||||
|
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def _dummy(any_id):
|
async def _dummy(any_id):
|
||||||
return str(any_id)
|
return str(any_id)
|
||||||
|
|
||||||
|
|
@ -157,12 +162,13 @@ class Storage:
|
||||||
|
|
||||||
return members
|
return members
|
||||||
|
|
||||||
async def _channels_extra(self, row, channel_type: int) -> Dict:
|
async def _channels_extra(self, row) -> Dict:
|
||||||
"""Fill in more information about a channel."""
|
"""Fill in more information about a channel."""
|
||||||
# TODO: This could probably be better with a dictionary.
|
channel_type = row['type']
|
||||||
|
|
||||||
# TODO: dm and group dm?
|
# TODO: dm and group dm?
|
||||||
if channel_type == ChannelType.GUILD_TEXT:
|
chan_type = ChannelType(channel_type)
|
||||||
|
if chan_type == ChannelType.GUILD_TEXT:
|
||||||
topic = await self.db.fetchval("""
|
topic = await self.db.fetchval("""
|
||||||
SELECT topic FROM guild_text_channels
|
SELECT topic FROM guild_text_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -171,7 +177,7 @@ class Storage:
|
||||||
return {**row, **{
|
return {**row, **{
|
||||||
'topic': topic,
|
'topic': topic,
|
||||||
}}
|
}}
|
||||||
elif channel_type == ChannelType.GUILD_VOICE:
|
elif chan_type == ChannelType.GUILD_VOICE:
|
||||||
vrow = await self.db.fetchval("""
|
vrow = await self.db.fetchval("""
|
||||||
SELECT bitrate, user_limit FROM guild_voice_channels
|
SELECT bitrate, user_limit FROM guild_voice_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -179,6 +185,8 @@ class Storage:
|
||||||
|
|
||||||
return {**row, **dict(vrow)}
|
return {**row, **dict(vrow)}
|
||||||
|
|
||||||
|
log.warning('unknown channel type: {}', chan_type)
|
||||||
|
|
||||||
async def get_chan_type(self, channel_id) -> int:
|
async def get_chan_type(self, channel_id) -> int:
|
||||||
return await self.db.fetchval("""
|
return await self.db.fetchval("""
|
||||||
SELECT channel_type
|
SELECT channel_type
|
||||||
|
|
@ -205,16 +213,19 @@ class Storage:
|
||||||
"""Fetch a single channel's information."""
|
"""Fetch a single channel's information."""
|
||||||
chan_type = await self.get_chan_type(channel_id)
|
chan_type = await self.get_chan_type(channel_id)
|
||||||
|
|
||||||
if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
if ChannelType(chan_type) in (ChannelType.GUILD_TEXT,
|
||||||
ChannelType.GUILD_CATEGORY):
|
ChannelType.GUILD_VOICE,
|
||||||
|
ChannelType.GUILD_CATEGORY):
|
||||||
base = await self.db.fetchrow("""
|
base = await self.db.fetchrow("""
|
||||||
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
SELECT id, guild_id::text, parent_id, name, position, nsfw
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
WHERE guild_channels.id = $1
|
WHERE guild_channels.id = $1
|
||||||
""", channel_id)
|
""", channel_id)
|
||||||
|
|
||||||
res = await self._channels_extra(dict(base), chan_type)
|
dbase = dict(base)
|
||||||
res['type'] = chan_type
|
dbase['type'] = chan_type
|
||||||
|
|
||||||
|
res = await self._channels_extra(dbase)
|
||||||
res['permission_overwrites'] = \
|
res['permission_overwrites'] = \
|
||||||
list(await self._chan_overwrites(channel_id))
|
list(await self._chan_overwrites(channel_id))
|
||||||
|
|
||||||
|
|
@ -240,8 +251,12 @@ class Storage:
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", row['id'])
|
""", row['id'])
|
||||||
|
|
||||||
res = await self._channels_extra(dict(row), ctype)
|
drow = dict(row)
|
||||||
res['type'] = ctype
|
drow['type'] = ctype
|
||||||
|
|
||||||
|
res = await self._channels_extra(drow)
|
||||||
|
|
||||||
|
print(res)
|
||||||
|
|
||||||
res['permission_overwrites'] = \
|
res['permission_overwrites'] = \
|
||||||
list(await self._chan_overwrites(row['id']))
|
list(await self._chan_overwrites(row['id']))
|
||||||
|
|
|
||||||
3
run.py
3
run.py
|
|
@ -88,7 +88,8 @@ async def app_before_serving():
|
||||||
app.state_manager = StateManager()
|
app.state_manager = StateManager()
|
||||||
app.dispatcher = EventDispatcher(app.state_manager)
|
app.dispatcher = EventDispatcher(app.state_manager)
|
||||||
app.storage = Storage(app.db)
|
app.storage = Storage(app.db)
|
||||||
app.presence = PresenceManager(app.storage, app.state_manager)
|
app.presence = PresenceManager(app.storage,
|
||||||
|
app.state_manager, app.dispatcher)
|
||||||
|
|
||||||
# start the websocket, etc
|
# start the websocket, etc
|
||||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue