mirror of https://gitlab.com/litecord/litecord.git
444 lines
12 KiB
Python
444 lines
12 KiB
Python
"""
|
|
|
|
Litecord
|
|
Copyright (C) 2018-2019 Luna Mendes
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, version 3 of the License.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
"""
|
|
|
|
import asyncio
|
|
import sys
|
|
from multiprocessing.managers import BaseManager
|
|
from typing import Dict, Tuple, List
|
|
|
|
import asyncpg
|
|
import logbook
|
|
import logging
|
|
import websockets
|
|
from quart import Quart, jsonify, request
|
|
from logbook import StreamHandler, Logger
|
|
from logbook.compat import redirect_logging
|
|
from aiohttp import ClientSession
|
|
|
|
# import the config set by instance owner
|
|
import config
|
|
|
|
from litecord.blueprints import (
|
|
gateway,
|
|
auth,
|
|
users,
|
|
guilds,
|
|
channels,
|
|
webhooks,
|
|
science,
|
|
voice,
|
|
invites,
|
|
relationships,
|
|
dms,
|
|
icons,
|
|
nodeinfo,
|
|
static,
|
|
attachments,
|
|
dm_channels,
|
|
)
|
|
|
|
# those blueprints are separated from the "main" ones
|
|
# for code readability if people want to dig through
|
|
# the codebase.
|
|
from litecord.blueprints.guild import (
|
|
guild_roles,
|
|
guild_members,
|
|
guild_channels,
|
|
guild_mod,
|
|
guild_emoji,
|
|
)
|
|
|
|
from litecord.blueprints.channel import (
|
|
channel_messages,
|
|
channel_reactions,
|
|
channel_pins,
|
|
)
|
|
|
|
from litecord.blueprints.user import user_settings, user_billing, fake_store
|
|
|
|
from litecord.blueprints.user.billing_job import payment_job
|
|
|
|
from litecord.blueprints.admin_api import (
|
|
voice as voice_admin,
|
|
features as features_admin,
|
|
guilds as guilds_admin,
|
|
users as users_admin,
|
|
instance_invites,
|
|
)
|
|
|
|
from litecord.blueprints.admin_api.voice import guild_region_check
|
|
|
|
from litecord.ratelimits.handler import ratelimit_handler
|
|
from litecord.ratelimits.main import RatelimitManager
|
|
|
|
from litecord.errors import LitecordError
|
|
from litecord.gateway.state_manager import StateManager
|
|
from litecord.storage import Storage
|
|
from litecord.user_storage import UserStorage
|
|
from litecord.dispatcher import EventDispatcher
|
|
from litecord.presence import PresenceManager
|
|
from litecord.images import IconManager
|
|
from litecord.jobs import JobManager
|
|
from litecord.voice.manager import VoiceManager
|
|
from litecord.guild_memory_store import GuildMemoryStore
|
|
from litecord.pubsub.lazy_guild import LazyGuildManager
|
|
|
|
from litecord.gateway.gateway import websocket_handler
|
|
|
|
from litecord.utils import LitecordJSONEncoder
|
|
|
|
# setup logbook: a stdout handler at INFO is pushed as the application-wide
# handler, and records from the stdlib `logging` module are redirected into
# logbook so both logging systems end up on the same output.
handler = StreamHandler(sys.stdout, level=logbook.INFO)
handler.push_application()
redirect_logging()

log = Logger("litecord.boot")
|
|
|
|
|
|
def make_app():
    """Create and configure the Quart application.

    Configuration is loaded from the ``config.<MODE>`` object selected by
    ``config.MODE``. When that config enables ``DEBUG``, the logbook stdout
    handler and the Quart app logger are lowered to DEBUG. The
    ``websockets`` library logger is always kept at INFO, and the custom
    JSON encoder is installed.

    Returns:
        The configured ``Quart`` instance.
    """
    app = Quart(__name__)
    app.config.from_object(f"config.{config.MODE}")

    is_debug = app.config.get("DEBUG", False)
    app.debug = is_debug

    if is_debug:
        log.info("on debug")
        # `handler` is the logbook StreamHandler, so it takes logbook levels.
        handler.level = logbook.DEBUG
        # app.logger is a stdlib logging.Logger: use stdlib level constants
        # and setLevel(), which (unlike assigning `.level`) also invalidates
        # the logging module's cached isEnabledFor() results.
        app.logger.setLevel(logging.DEBUG)

    # always keep websockets on INFO. This is a stdlib logger, so it needs
    # logging.INFO (20); logbook.INFO is a different number (11).
    logging.getLogger("websockets").setLevel(logging.INFO)

    # use our custom json encoder for custom data types
    app.json_encoder = LitecordJSONEncoder

    return app
|
|
|
|
|
|
def set_blueprints(app_):
    """Register every blueprint on the given app instance.

    The value paired with each blueprint selects its URL prefix:

    * a string such as ``"/users"`` is appended to ``/api/v6``;
    * ``None`` mounts the blueprint directly under ``/api/v6``;
    * ``-1`` mounts the blueprint at the site root with no prefix.

    Blueprints are registered in the order listed below.
    """
    prefixes = {
        gateway: None,
        auth: "/auth",
        users: "/users",
        user_settings: "/users",
        user_billing: "/users",
        relationships: "/users",
        guilds: "/guilds",
        guild_roles: "/guilds",
        guild_members: "/guilds",
        guild_channels: "/guilds",
        guild_mod: "/guilds",
        guild_emoji: "/guilds",
        channels: "/channels",
        channel_messages: "/channels",
        channel_reactions: "/channels",
        channel_pins: "/channels",
        webhooks: None,
        science: None,
        voice: "/voice",
        invites: None,
        dms: "/users",
        dm_channels: "/channels",
        fake_store: None,
        icons: -1,
        attachments: -1,
        nodeinfo: -1,
        static: -1,
        voice_admin: "/admin/voice",
        features_admin: "/admin/guilds",
        guilds_admin: "/admin/guilds",
        users_admin: "/admin/users",
        instance_invites: "/admin/instance/invites",
    }

    for blueprint, suffix in prefixes.items():
        if suffix == -1:
            prefix = ""
        else:
            prefix = f'/api/v6{suffix or ""}'

        app_.register_blueprint(blueprint, url_prefix=prefix)
|
|
|
|
|
|
# Build the module-level application and mount every blueprint on it.
# The hook decorators further down (`@app.before_request`, ...) need this
# `app` object to exist already.
app = make_app()
set_blueprints(app)
|
|
|
|
|
|
@app.before_request
async def app_before_request():
    """Per-request hook run before the view function.

    Currently this only applies ratelimiting through ``ratelimit_handler()``.
    """
    await ratelimit_handler()
|
|
|
|
|
|
@app.after_request
async def app_after_request(resp):
    """Handle CORS headers.

    The request's ``Origin`` header is echoed back in
    ``Access-Control-Allow-Origin`` (``*`` when the request has none). A fixed
    list of request headers is allowed. The allowed methods mirror the
    response's ``Allow`` header (``*`` when it is absent).

    Because the Allow-Origin value depends on the request's ``Origin``,
    ``Origin`` is also added to ``Vary``. Without it, a shared cache could
    replay one origin's CORS headers to a different origin.
    """
    origin = request.headers.get("Origin", "*")
    resp.headers["Access-Control-Allow-Origin"] = origin

    # Add "Origin" to Vary without clobbering existing Vary entries, and
    # skip it when Vary already covers it (or is "*").
    vary = resp.headers.get("Vary", "")
    vary_tokens = {token.strip().lower() for token in vary.split(",")}
    if not vary_tokens & {"origin", "*"}:
        resp.headers["Vary"] = f"{vary}, Origin" if vary.strip() else "Origin"

    resp.headers["Access-Control-Allow-Headers"] = (
        "*, X-Super-Properties, "
        "X-Fingerprint, "
        "X-Context-Properties, "
        "X-Failed-Requests, "
        "X-Debug-Options, "
        "Content-Type, "
        "Authorization, "
        "Origin, "
        "If-None-Match"
    )
    resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*")

    return resp
|
|
|
|
|
|
def _set_rtl_reset(bucket, resp):
    """Write the ``X-RateLimit-Reset`` header for ``bucket`` into ``resp``.

    The reset moment is ``bucket._window + bucket.second``. Its format is
    chosen by the request's ``X-RateLimit-Precision`` header:

    * ``second`` (the default): the value passed through ``round()``;
    * ``millisecond``: the unrounded value;
    * anything else: an explanatory error string is written instead.
    """
    reset_at = bucket._window + bucket.second
    wanted = request.headers.get("x-ratelimit-precision", "second")

    if wanted == "millisecond":
        header_value = str(reset_at)
    elif wanted == "second":
        header_value = str(round(reset_at))
    else:
        header_value = (
            "Invalid X-RateLimit-Precision, "
            "valid options are (second, millisecond)"
        )

    resp.headers["X-RateLimit-Reset"] = header_value
|
|
|
|
|
|
@app.after_request
async def app_set_ratelimit_headers(resp):
    """Set the specific ratelimit headers from the request's bucket.

    Responses to requests without a bucket (attribute missing, or ``None``)
    are returned untouched. Setting the headers is best-effort: an
    AttributeError while reading bucket or request state stops further
    header writes, but the response is still returned.
    """
    bucket = getattr(request, "bucket", None)
    if bucket is None:
        return resp

    try:
        resp.headers["X-RateLimit-Limit"] = str(bucket.requests)
        resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens)
        resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower()
        _set_rtl_reset(bucket, resp)

        # Retry-After only appears when the ratelimit was actually hit
        retry_after = request.retry_after
        if retry_after:
            resp.headers["Retry-After"] = str(retry_after)
    except AttributeError:
        pass

    return resp
|
|
|
|
|
|
async def init_app_db(app_):
    """Connect to the database and create the job scheduler.

    Opens an asyncpg connection pool using ``app_.config["POSTGRES"]`` as
    keyword arguments and stores it on ``app_.db``. It then stores a new
    ``JobManager`` on ``app_.sched``.
    """
    log.info("db connect")
    # read the config from the app we were given, not the module-level
    # `app`, so this also works for app instances other than the global one
    app_.db = await asyncpg.create_pool(**app_.config["POSTGRES"])

    app_.sched = JobManager()
|
|
|
|
|
|
def awooawoo():
    """Print ``awoo`` to stdout.

    Registered under the ``awooawoo`` typeid on the experimental
    multiprocessing manager set up in ``init_app_managers``.
    """
    greeting = "awoo"
    print(greeting)
|
|
|
|
|
|
def init_app_managers(app_: Quart, *, init_voice=True):
    """Initialize singleton classes and attach them to ``app_``.

    Sets ``app_.loop``, ``ratelimiter``, ``state_manager``, ``storage``,
    ``user_storage``, ``icons``, ``dispatcher``, ``presence``, ``guild_store``
    and ``lazy_guild``. ``app_.voice`` is set only when ``init_voice`` is true.
    It also starts an experimental multiprocessing manager server on the
    event loop's default executor.
    """

    # Experimental cross-process manager exposing `awooawoo`.
    # BaseManager.register() is a classmethod: calling it on BaseManager
    # itself would add the typeid to the stdlib class's registry for every
    # BaseManager in the process. A dedicated subclass keeps it local.
    class _DebugManager(BaseManager):
        pass

    _DebugManager.register("awooawoo", awooawoo)

    # NOTE(review): multiprocessing managers exchange pickled data, and this
    # authkey is hard-coded in public source, so the server is bound to
    # loopback only instead of every interface ("").
    manager = _DebugManager(("127.0.0.1", 36970), b"awooawoo")

    app_.loop = asyncio.get_event_loop()

    # serve_forever() blocks, so it runs on the loop's default executor
    app_.loop.run_in_executor(None, manager.get_server().serve_forever)

    app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
    app_.state_manager = StateManager()

    app_.storage = Storage(app_)
    app_.user_storage = UserStorage(app_.storage)

    app_.icons = IconManager(app_)

    app_.dispatcher = EventDispatcher()
    app_.presence = PresenceManager(app_)

    app_.storage.presence = app_.presence

    # only start VoiceManager if needed.
    # we do this because of a bug on ./manage.py where it
    # cancels the LVSPManager's spawn regions task. we don't
    # need to start it on manage time.
    if init_voice:
        app_.voice = VoiceManager(app_)

    app_.guild_store = GuildMemoryStore()
    app_.lazy_guild = LazyGuildManager()
|
|
|
|
|
|
def _load_discord_endpoints(filename: str) -> Dict[Tuple[str, str], str]:
    """Parse an endpoint listing into ``{(path, method): endpoint_name}``.

    Each non-blank line must hold exactly three whitespace-separated fields:
    ``name method path``. Paths get the ``/api/v6`` prefix so they can be
    compared against our registered routes. Malformed lines are logged and
    skipped instead of aborting the whole report.
    """
    to_find: Dict[Tuple[str, str], str] = {}

    with open(filename, encoding="utf-8") as fd:
        for lineno, line in enumerate(fd, start=1):
            components = line.split()
            if not components:
                continue

            if len(components) != 3:
                log.warning(
                    "{}:{}: expected 'name method path', got {!r}",
                    filename,
                    lineno,
                    line,
                )
                continue

            name, method, path = components
            to_find[(f"/api/v6{path}", method)] = name

    return to_find


def _rule_to_discord_path(rule_path: str) -> str:
    """Convert a Quart route path into discord_endpoints.txt notation.

    Route variables such as ``<int:guild_id>`` become ``{guild.id}``: any
    converter prefix is dropped and underscores inside the variable name
    become dots. Literal path segments are left untouched. Our alternate
    names for user IDs (member, banned, target, other, peer) are then
    normalized to ``user.id``.
    """
    import re

    def _variable(match) -> str:
        return "{" + match.group(1).replace("_", ".") + "}"

    path = re.sub(r"<(?:[^<>:]+:)?([^<>:]+)>", _variable, rule_path)

    # change our parameters into user.id
    for alias in ("member", "banned", "target", "other", "peer"):
        path = path.replace(f"{alias}.id", "user.id")

    return path


async def api_index(app_):
    """Log how much of Discord's API surface this app implements.

    Reads ``discord_endpoints.txt`` from the current working directory and
    matches each (path, method) pair against the rules in
    ``app_.url_map.rules``. The compliance percentage and the names of
    missing endpoints are logged at DEBUG level. If the file is missing or
    lists no endpoints, a warning is logged and the report is skipped.
    """
    try:
        to_find = _load_discord_endpoints("discord_endpoints.txt")
    except FileNotFoundError:
        log.warning("discord_endpoints.txt not found, skipping API compliance report")
        return

    found = set()
    for rule in app_.url_map.rules:
        path = _rule_to_discord_path(rule.rule)

        for method in rule.methods:
            pathname = to_find.get((path, method))
            if pathname:
                found.add(pathname)

    api = set(to_find.values())
    if not api:
        # avoid dividing by zero on an empty endpoint list
        log.warning("no endpoints listed, skipping API compliance report")
        return

    missing = api - found
    percentage = round(len(found) / len(api) * 100, 2)

    log.debug(
        "API compliance: {} out of {} ({} missing), {}% compliant",
        len(found),
        len(api),
        len(missing),
        percentage,
    )

    log.debug("missing: {}", missing)
|
|
|
|
|
|
async def post_app_start(app_):
    """Spawn the background jobs on ``app_.sched``.

    In order: the billing (payment) job, the API-compliance report for
    ``app_``, and the guild voice-region check.
    """
    # each factory is called only right before its job is spawned, so the
    # creation/spawn interleaving is one job at a time
    job_factories = (
        payment_job,
        lambda: api_index(app_),
        guild_region_check,
    )

    for make_job in job_factories:
        app_.sched.spawn(make_job())
|
|
|
|
|
|
def start_websocket(host, port, ws_handler) -> asyncio.Future:
    """Build the websocket server for ``ws_handler`` on ``host``/``port``.

    Returns the object produced by ``websockets.serve()``; the caller awaits
    it to start serving. Each connection is forwarded as
    ``ws_handler(app, ws, url)``, so the handler also receives the
    module-level Quart app.
    """
    log.info("starting websocket at {} {}", host, port)

    async def _bind_app(ws, url):
        # inject quart's app object in front of the websocket arguments
        await ws_handler(app, ws, url)

    return websockets.serve(_bind_app, host, port)
|
|
|
|
|
|
@app.before_serving
async def app_before_serving():
    """Set up app-wide state before serving, then start the gateway.

    Connects the database (and scheduler), opens the shared aiohttp
    ``ClientSession``, initializes the manager singletons, spawns the
    background jobs, and finally starts the gateway websocket server.
    """
    log.info("opening db")
    await init_app_db(app)

    app.session = ClientSession()

    init_app_managers(app)
    await post_app_start(app)

    # only the gateway websocket is started here; the voice websocket is
    # handled by the voice server
    ws_host = app.config["WS_HOST"]
    ws_port = app.config["WS_PORT"]
    await start_websocket(ws_host, ws_port, websocket_handler)
|
|
|
|
|
|
@app.after_serving
async def app_after_serving():
    """Shutdown tasks for the server.

    Closes gateway clients first, then the state manager, the job scheduler,
    the shared aiohttp session and finally the database pool.
    """
    # first close all clients, then close db
    tasks = app.state_manager.gen_close_tasks()
    if tasks:
        # asyncio.wait() lost its `loop` argument in Python 3.10 and rejects
        # bare coroutines since 3.11. ensure_future() passes Tasks through
        # unchanged and wraps coroutines, so either shape works.
        await asyncio.wait([asyncio.ensure_future(task) for task in tasks])

    app.state_manager.close()

    app.sched.close()

    # the ClientSession opened in app_before_serving is otherwise never closed
    session = getattr(app, "session", None)
    if session is not None:
        await session.close()

    log.info("closing db")
    await app.db.close()
|
|
|
|
|
|
@app.errorhandler(LitecordError)
async def handle_litecord_err(err):
    """Render a ``LitecordError`` as a JSON error response.

    The body always has ``error``, ``status`` and ``message``. Extra fields
    from ``err.json`` (plus ``code`` from ``err.error_code``, when the error
    defines one) are merged on top. The HTTP status is ``err.status_code``.
    """
    try:
        # Copy, so adding "code" below never mutates a dict owned by the
        # error object (or one shared between error instances).
        ejson = dict(err.json)
    except IndexError:
        # err.json raises IndexError, presumably when the error was created
        # without an extra payload; fall back to no extra fields.
        ejson = {}

    try:
        ejson["code"] = err.error_code
    except AttributeError:
        # not every error type defines an error_code
        pass

    log.warning("error: {} {!r}", err.status_code, err.message)

    return (
        jsonify(
            {"error": True, "status": err.status_code, "message": err.message, **ejson}
        ),
        err.status_code,
    )
|
|
|
|
|
|
@app.errorhandler(500)
async def handle_500(err):
    """Render an internal server error as a JSON body with HTTP status 500.

    NOTE(review): the body carries ``repr(err)``, which can expose internal
    exception details to clients. Consider logging it and returning a
    generic message instead.
    """
    payload = {
        "error": True,
        "message": repr(err),
        "internal_server_error": True,
    }
    return jsonify(payload), 500
|