mirror of https://gitlab.com/litecord/litecord.git
litecord: add Storage
Storage serves as a way to reduce code repeatbility. So that we
don't need to keep repeating the same SQL statements over and over,
and to detach some SQL calls into their own code (like guild fetching)
- gateway.websocket: add WebsocketObjects to hold db, state_manager,
storage and loop
- gateway.websocket: add _make_guild_list
- schema: add members.deafened, members.muted
This commit is contained in:
parent
f7d530c787
commit
3eb6d5e60f
|
|
@ -2,7 +2,7 @@ import urllib.parse
|
|||
from .websocket import GatewayWebsocket
|
||||
|
||||
|
||||
async def websocket_handler(db, sm, ws, url):
|
||||
async def websocket_handler(prop, ws, url):
|
||||
qs = urllib.parse.parse_qs(
|
||||
urllib.parse.urlparse(url).query
|
||||
)
|
||||
|
|
@ -27,6 +27,6 @@ async def websocket_handler(db, sm, ws, url):
|
|||
if gw_compress and gw_compress not in ('zlib-stream',):
|
||||
return await ws.close(1000, 'Invalid gateway compress')
|
||||
|
||||
gws = GatewayWebsocket(sm, db, ws, v=gw_version,
|
||||
gws = GatewayWebsocket(ws, prop=prop, v=gw_version,
|
||||
encoding=gw_encoding, compress=gw_compress)
|
||||
await gws.run()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import collections
|
||||
from typing import List
|
||||
|
||||
import earl
|
||||
from logbook import Logger
|
||||
|
|
@ -17,6 +18,10 @@ WebsocketProperties = collections.namedtuple(
|
|||
'WebsocketProperties', 'v encoding compress'
|
||||
)
|
||||
|
||||
WebsocketObjects = collections.namedtuple(
|
||||
'WebsocketObjects', 'db state_manager storage loop'
|
||||
)
|
||||
|
||||
|
||||
def encode_json(payload) -> str:
|
||||
return json.dumps(payload)
|
||||
|
|
@ -30,16 +35,17 @@ def encode_etf(payload) -> str:
|
|||
return earl.pack(payload)
|
||||
|
||||
|
||||
def decode_etf(data):
|
||||
def decode_etf(data: bytes):
|
||||
return earl.unpack(data)
|
||||
|
||||
|
||||
class GatewayWebsocket:
|
||||
"""Main gateway websocket logic."""
|
||||
|
||||
def __init__(self, sm, db, ws, **kwargs):
|
||||
self.state_manager = sm
|
||||
self.db = db
|
||||
def __init__(self, ws, **kwargs):
|
||||
self.ext = WebsocketObjects(*kwargs['prop'])
|
||||
self.storage = self.ext.storage
|
||||
self.state_manager = self.ext.state_manager
|
||||
self.ws = ws
|
||||
|
||||
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||
|
|
@ -91,22 +97,53 @@ class GatewayWebsocket:
|
|||
'd': data,
|
||||
})
|
||||
|
||||
async def _make_guild_list(self) -> List[int]:
|
||||
# TODO: This function does not account for sharding.
|
||||
user_id = self.state.user_id
|
||||
|
||||
guild_ids = await self.ext.db.fetch("""
|
||||
SELECT guild_id
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
return [{
|
||||
'id': row[0],
|
||||
'unavailable': True,
|
||||
} for row in guild_ids]
|
||||
|
||||
async def guild_dispatch(self, unavailable_guilds: List[dict]):
|
||||
for guild_obj in unavailable_guilds:
|
||||
guild = await self.storage.get_guild(guild_obj['id'],
|
||||
self.state.user_id)
|
||||
|
||||
if not guild:
|
||||
continue
|
||||
|
||||
await self.dispatch('GUILD_CREATE', dict(guild))
|
||||
|
||||
async def dispatch_ready(self):
|
||||
"""Dispatch the READY packet for a connecting user."""
|
||||
guilds = await self._make_guild_list()
|
||||
user = await self.storage.get_user(self.state.user_id, True)
|
||||
|
||||
await self.dispatch('READY', {
|
||||
'v': 6,
|
||||
'user': {},
|
||||
'user': user,
|
||||
'private_channels': [],
|
||||
'guilds': [],
|
||||
'guilds': guilds,
|
||||
'session_id': self.state.session_id,
|
||||
'_trace': ['transbian']
|
||||
})
|
||||
|
||||
# async dispatch of guilds
|
||||
self.ext.loop.create_task(self.guild_dispatch(guilds))
|
||||
|
||||
async def _check_shards(self):
|
||||
shard = self.state.shard
|
||||
current_shard, shard_count = shard
|
||||
|
||||
guilds = await self.db.fetchval("""
|
||||
guilds = await self.ext.db.fetchval("""
|
||||
SELECT COUNT(*)
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
|
|
@ -139,7 +176,7 @@ class GatewayWebsocket:
|
|||
presence = data.get('presence')
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.db)
|
||||
user_id = await raw_token_check(token, self.ext.db)
|
||||
except AuthError:
|
||||
raise WebsocketClose(4004, 'Authentication failed')
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
from typing import Dict
|
||||
|
||||
|
||||
class Storage:
|
||||
"""Class for common SQL statements."""
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
|
||||
async def get_user(self, guild_id, secure=False):
|
||||
pass
|
||||
|
||||
async def get_guild(self, guild_id: int, state) -> Dict:
|
||||
row = await self.db.fetchrow("""
|
||||
SELECT *
|
||||
FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
|
||||
if not row:
|
||||
return
|
||||
|
||||
drow = dict(row)
|
||||
|
||||
if state:
|
||||
drow['owner'] = drow['owner_id'] == state.user_id
|
||||
|
||||
# TODO: Probably a really bad idea to repeat str() calls
|
||||
# Any ideas to make this simpler?
|
||||
# (No, changing the types on the db wouldn't be nice)
|
||||
drow['id'] = str(drow['id'])
|
||||
drow['owner_id'] = str(drow['owner_id'])
|
||||
drow['afk_channel_id'] = str(drow['afk_channel_id'])
|
||||
drow['embed_channel_id'] = str(drow['embed_channel_id'])
|
||||
drow['widget_channel_id'] = str(drow['widget_channel_id'])
|
||||
drow['system_channel_id'] = str(drow['system_channel_id'])
|
||||
|
||||
return {**drow, **{
|
||||
'roles': [],
|
||||
'emojis': [],
|
||||
}}
|
||||
|
||||
async def get_guild_extra(self, guild_id: int, state=None) -> Dict:
|
||||
"""Get extra information about a guild."""
|
||||
res = {}
|
||||
|
||||
member_count = await self.db.fetchval("""
|
||||
SELECT COUNT(*)
|
||||
FROM members
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
if state:
|
||||
joined_at = await self.db.fetchval("""
|
||||
SELECT joined_at
|
||||
FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, state.user_id)
|
||||
|
||||
res['large'] = state.large > member_count
|
||||
res['joined_at'] = joined_at.isoformat()
|
||||
|
||||
members_basic = await self.db.fetch("""
|
||||
SELECT user_id, nickname, joined_at
|
||||
FROM members
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
members = []
|
||||
|
||||
for row in members_basic:
|
||||
member_id = row['user_id']
|
||||
|
||||
members_roles = await self.db.fetch("""
|
||||
SELECT role_id
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
members.append({
|
||||
'user': await self.get_user(member_id),
|
||||
'nick': row['nickname'],
|
||||
'roles': [str(row[0]) for row in members_roles],
|
||||
'joined_at': row['joined_at'].isoformat(),
|
||||
'deaf': row['deafened'],
|
||||
'mute': row['muted'],
|
||||
})
|
||||
|
||||
return {**res, **{
|
||||
'member_count': member_count,
|
||||
'members': members,
|
||||
'voice_states': [],
|
||||
# TODO: finish those
|
||||
'channels': [],
|
||||
'presences': [],
|
||||
}}
|
||||
5
run.py
5
run.py
|
|
@ -12,6 +12,7 @@ from litecord.blueprints import gateway, auth
|
|||
from litecord.gateway import websocket_handler
|
||||
from litecord.errors import LitecordError
|
||||
from litecord.gateway.state_manager import StateManager
|
||||
from litecord.storage import Storage
|
||||
|
||||
# setup logbook
|
||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||
|
|
@ -46,6 +47,7 @@ async def app_before_serving():
|
|||
g.loop = asyncio.get_event_loop()
|
||||
|
||||
app.state_manager = StateManager()
|
||||
app.storage = Storage(app.db)
|
||||
|
||||
# start the websocket, etc
|
||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||
|
|
@ -54,7 +56,8 @@ async def app_before_serving():
|
|||
async def _wrapper(ws, url):
|
||||
# We wrap the main websocket_handler
|
||||
# so we can pass quart's app object.
|
||||
await websocket_handler(app.db, app.state_manager, ws, url)
|
||||
await websocket_handler((app.db, app.state_manager,
|
||||
app.storage, app.loop), ws, url)
|
||||
|
||||
ws_future = websockets.serve(_wrapper, host, port)
|
||||
|
||||
|
|
|
|||
|
|
@ -267,6 +267,8 @@ CREATE TABLE IF NOT EXISTS members (
|
|||
guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE,
|
||||
nickname varchar(100) DEFAULT NULL,
|
||||
joined_at timestamp without time zone default now(),
|
||||
deafened boolean DEFAULT false,
|
||||
muted boolean DEFAULT false,
|
||||
PRIMARY KEY (user_id, guild_id)
|
||||
);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue