litecord: add Storage

Storage serves as a way to reduce code repeatbility. So that we
don't need to keep repeating the same SQL statements over and over,
and to detach some SQL calls into their own code (like guild fetching)

 - gateway.websocket: add WebsocketObjects to hold db, state_manager,
    storage and loop

 - gateway.websocket: add _make_guild_list
 - schema: add members.deafened, members.muted
This commit is contained in:
Luna Mendes 2018-06-20 16:53:22 -03:00
parent f7d530c787
commit 3eb6d5e60f
5 changed files with 148 additions and 11 deletions

View File

@ -2,7 +2,7 @@ import urllib.parse
from .websocket import GatewayWebsocket from .websocket import GatewayWebsocket
async def websocket_handler(db, sm, ws, url): async def websocket_handler(prop, ws, url):
qs = urllib.parse.parse_qs( qs = urllib.parse.parse_qs(
urllib.parse.urlparse(url).query urllib.parse.urlparse(url).query
) )
@ -27,6 +27,6 @@ async def websocket_handler(db, sm, ws, url):
if gw_compress and gw_compress not in ('zlib-stream',): if gw_compress and gw_compress not in ('zlib-stream',):
return await ws.close(1000, 'Invalid gateway compress') return await ws.close(1000, 'Invalid gateway compress')
gws = GatewayWebsocket(sm, db, ws, v=gw_version, gws = GatewayWebsocket(ws, prop=prop, v=gw_version,
encoding=gw_encoding, compress=gw_compress) encoding=gw_encoding, compress=gw_compress)
await gws.run() await gws.run()

View File

@ -1,5 +1,6 @@
import json import json
import collections import collections
from typing import List
import earl import earl
from logbook import Logger from logbook import Logger
@ -17,6 +18,10 @@ WebsocketProperties = collections.namedtuple(
'WebsocketProperties', 'v encoding compress' 'WebsocketProperties', 'v encoding compress'
) )
WebsocketObjects = collections.namedtuple(
'WebsocketObjects', 'db state_manager storage loop'
)
def encode_json(payload) -> str: def encode_json(payload) -> str:
return json.dumps(payload) return json.dumps(payload)
@ -30,16 +35,17 @@ def encode_etf(payload) -> str:
return earl.pack(payload) return earl.pack(payload)
def decode_etf(data): def decode_etf(data: bytes):
return earl.unpack(data) return earl.unpack(data)
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
def __init__(self, sm, db, ws, **kwargs): def __init__(self, ws, **kwargs):
self.state_manager = sm self.ext = WebsocketObjects(*kwargs['prop'])
self.db = db self.storage = self.ext.storage
self.state_manager = self.ext.state_manager
self.ws = ws self.ws = ws
self.wsp = WebsocketProperties(kwargs.get('v'), self.wsp = WebsocketProperties(kwargs.get('v'),
@ -91,22 +97,53 @@ class GatewayWebsocket:
'd': data, 'd': data,
}) })
async def _make_guild_list(self) -> List[int]:
# TODO: This function does not account for sharding.
user_id = self.state.user_id
guild_ids = await self.ext.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", user_id)
return [{
'id': row[0],
'unavailable': True,
} for row in guild_ids]
async def guild_dispatch(self, unavailable_guilds: List[dict]):
for guild_obj in unavailable_guilds:
guild = await self.storage.get_guild(guild_obj['id'],
self.state.user_id)
if not guild:
continue
await self.dispatch('GUILD_CREATE', dict(guild))
async def dispatch_ready(self): async def dispatch_ready(self):
"""Dispatch the READY packet for a connecting user.""" """Dispatch the READY packet for a connecting user."""
guilds = await self._make_guild_list()
user = await self.storage.get_user(self.state.user_id, True)
await self.dispatch('READY', { await self.dispatch('READY', {
'v': 6, 'v': 6,
'user': {}, 'user': user,
'private_channels': [], 'private_channels': [],
'guilds': [], 'guilds': guilds,
'session_id': self.state.session_id, 'session_id': self.state.session_id,
'_trace': ['transbian'] '_trace': ['transbian']
}) })
# async dispatch of guilds
self.ext.loop.create_task(self.guild_dispatch(guilds))
async def _check_shards(self): async def _check_shards(self):
shard = self.state.shard shard = self.state.shard
current_shard, shard_count = shard current_shard, shard_count = shard
guilds = await self.db.fetchval(""" guilds = await self.ext.db.fetchval("""
SELECT COUNT(*) SELECT COUNT(*)
FROM members FROM members
WHERE user_id = $1 WHERE user_id = $1
@ -139,7 +176,7 @@ class GatewayWebsocket:
presence = data.get('presence') presence = data.get('presence')
try: try:
user_id = await raw_token_check(token, self.db) user_id = await raw_token_check(token, self.ext.db)
except AuthError: except AuthError:
raise WebsocketClose(4004, 'Authentication failed') raise WebsocketClose(4004, 'Authentication failed')

95
litecord/storage.py Normal file
View File

@ -0,0 +1,95 @@
from typing import Dict
class Storage:
"""Class for common SQL statements."""
def __init__(self, db):
self.db = db
async def get_user(self, guild_id, secure=False):
pass
async def get_guild(self, guild_id: int, state) -> Dict:
row = await self.db.fetchrow("""
SELECT *
FROM guilds
WHERE guilds.id = $1
""", guild_id)
if not row:
return
drow = dict(row)
if state:
drow['owner'] = drow['owner_id'] == state.user_id
# TODO: Probably a really bad idea to repeat str() calls
# Any ideas to make this simpler?
# (No, changing the types on the db wouldn't be nice)
drow['id'] = str(drow['id'])
drow['owner_id'] = str(drow['owner_id'])
drow['afk_channel_id'] = str(drow['afk_channel_id'])
drow['embed_channel_id'] = str(drow['embed_channel_id'])
drow['widget_channel_id'] = str(drow['widget_channel_id'])
drow['system_channel_id'] = str(drow['system_channel_id'])
return {**drow, **{
'roles': [],
'emojis': [],
}}
async def get_guild_extra(self, guild_id: int, state=None) -> Dict:
"""Get extra information about a guild."""
res = {}
member_count = await self.db.fetchval("""
SELECT COUNT(*)
FROM members
WHERE guild_id = $1
""", guild_id)
if state:
joined_at = await self.db.fetchval("""
SELECT joined_at
FROM members
WHERE guild_id = $1 AND user_id = $2
""", guild_id, state.user_id)
res['large'] = state.large > member_count
res['joined_at'] = joined_at.isoformat()
members_basic = await self.db.fetch("""
SELECT user_id, nickname, joined_at
FROM members
WHERE guild_id = $1
""", guild_id)
members = []
for row in members_basic:
member_id = row['user_id']
members_roles = await self.db.fetch("""
SELECT role_id
FROM member_roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
members.append({
'user': await self.get_user(member_id),
'nick': row['nickname'],
'roles': [str(row[0]) for row in members_roles],
'joined_at': row['joined_at'].isoformat(),
'deaf': row['deafened'],
'mute': row['muted'],
})
return {**res, **{
'member_count': member_count,
'members': members,
'voice_states': [],
# TODO: finish those
'channels': [],
'presences': [],
}}

5
run.py
View File

@ -12,6 +12,7 @@ from litecord.blueprints import gateway, auth
from litecord.gateway import websocket_handler from litecord.gateway import websocket_handler
from litecord.errors import LitecordError from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager from litecord.gateway.state_manager import StateManager
from litecord.storage import Storage
# setup logbook # setup logbook
handler = StreamHandler(sys.stdout, level=logbook.INFO) handler = StreamHandler(sys.stdout, level=logbook.INFO)
@ -46,6 +47,7 @@ async def app_before_serving():
g.loop = asyncio.get_event_loop() g.loop = asyncio.get_event_loop()
app.state_manager = StateManager() app.state_manager = StateManager()
app.storage = Storage(app.db)
# start the websocket, etc # start the websocket, etc
host, port = app.config['WS_HOST'], app.config['WS_PORT'] host, port = app.config['WS_HOST'], app.config['WS_PORT']
@ -54,7 +56,8 @@ async def app_before_serving():
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
# so we can pass quart's app object. # so we can pass quart's app object.
await websocket_handler(app.db, app.state_manager, ws, url) await websocket_handler((app.db, app.state_manager,
app.storage, app.loop), ws, url)
ws_future = websockets.serve(_wrapper, host, port) ws_future = websockets.serve(_wrapper, host, port)

View File

@ -267,6 +267,8 @@ CREATE TABLE IF NOT EXISTS members (
guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE, guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE,
nickname varchar(100) DEFAULT NULL, nickname varchar(100) DEFAULT NULL,
joined_at timestamp without time zone default now(), joined_at timestamp without time zone default now(),
deafened boolean DEFAULT false,
muted boolean DEFAULT false,
PRIMARY KEY (user_id, guild_id) PRIMARY KEY (user_id, guild_id)
); );