""" Litecord Copyright (C) 2018-2019 Luna Mendes This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 3 of the License. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ import asyncio import sys from multiprocessing.managers import BaseManager from typing import Dict, Tuple, List import asyncpg import logbook import logging import websockets from quart import Quart, jsonify, request from logbook import StreamHandler, Logger from logbook.compat import redirect_logging from aiohttp import ClientSession # import the config set by instance owner import config from litecord.blueprints import ( gateway, auth, users, guilds, channels, webhooks, science, voice, invites, relationships, dms, icons, nodeinfo, static, attachments, dm_channels, ) # those blueprints are separated from the "main" ones # for code readability if people want to dig through # the codebase. 
from litecord.blueprints.guild import (
    guild_roles,
    guild_members,
    guild_channels,
    guild_mod,
    guild_emoji,
)
from litecord.blueprints.channel import (
    channel_messages,
    channel_reactions,
    channel_pins,
)
from litecord.blueprints.user import user_settings, user_billing, fake_store
from litecord.blueprints.user.billing_job import payment_job
from litecord.blueprints.admin_api import (
    voice as voice_admin,
    features as features_admin,
    guilds as guilds_admin,
    users as users_admin,
    instance_invites,
)
from litecord.blueprints.admin_api.voice import guild_region_check
from litecord.ratelimits.handler import ratelimit_handler
from litecord.ratelimits.main import RatelimitManager
from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager
from litecord.storage import Storage
from litecord.user_storage import UserStorage
from litecord.dispatcher import EventDispatcher
from litecord.presence import PresenceManager
from litecord.images import IconManager
from litecord.jobs import JobManager
from litecord.voice.manager import VoiceManager
from litecord.guild_memory_store import GuildMemoryStore
from litecord.pubsub.lazy_guild import LazyGuildManager
from litecord.gateway.gateway import websocket_handler
from litecord.utils import LitecordJSONEncoder

# setup logbook
handler = StreamHandler(sys.stdout, level=logbook.INFO)
handler.push_application()
log = Logger("litecord.boot")
redirect_logging()


def make_app():
    """Create and configure the Quart application.

    Loads ``config.<MODE>`` into the app config, switches logging to debug
    when the loaded config sets ``DEBUG``, and installs the custom JSON
    encoder.
    """
    app = Quart(__name__)
    app.config.from_object(f"config.{config.MODE}")
    is_debug = app.config.get("DEBUG", False)
    app.debug = is_debug

    if is_debug:
        log.info("on debug")
        handler.level = logbook.DEBUG
        # app.logger is a stdlib logger; logbook's level constants are
        # numbered differently from the stdlib's, so use logging.* here.
        app.logger.setLevel(logging.DEBUG)

    # always keep websockets on INFO (stdlib logger -> stdlib level)
    logging.getLogger("websockets").setLevel(logging.INFO)

    # use our custom json encoder for custom data types
    app.json_encoder = LitecordJSONEncoder

    return app


def set_blueprints(app_):
    """Set the blueprints for a given app instance.

    Each blueprint maps to a URL suffix: ``None`` mounts it directly at
    ``/api/v6``, a string is appended to ``/api/v6``, and ``-1`` mounts it
    at the root with no prefix at all.
    """
    bps = {
        gateway: None,
        auth: "/auth",
        users: "/users",
        user_settings: "/users",
        user_billing: "/users",
        relationships: "/users",
        guilds: "/guilds",
        guild_roles: "/guilds",
        guild_members: "/guilds",
        guild_channels: "/guilds",
        guild_mod: "/guilds",
        guild_emoji: "/guilds",
        channels: "/channels",
        channel_messages: "/channels",
        channel_reactions: "/channels",
        channel_pins: "/channels",
        webhooks: None,
        science: None,
        voice: "/voice",
        invites: None,
        dms: "/users",
        dm_channels: "/channels",
        fake_store: None,
        icons: -1,
        attachments: -1,
        nodeinfo: -1,
        static: -1,
        voice_admin: "/admin/voice",
        features_admin: "/admin/guilds",
        guilds_admin: "/admin/guilds",
        users_admin: "/admin/users",
        instance_invites: "/admin/instance/invites",
    }

    for bp, suffix in bps.items():
        if suffix == -1:
            url_prefix = ""
        else:
            url_prefix = f'/api/v6{suffix or ""}'

        app_.register_blueprint(bp, url_prefix=url_prefix)


app = make_app()
set_blueprints(app)


@app.before_request
async def app_before_request():
    """Functions to call before the request actually takes place."""
    await ratelimit_handler()


@app.after_request
async def app_after_request(resp):
    """Handle CORS headers.

    Echoes the request's ``Origin`` header (or ``*`` when absent) and allows
    the methods listed in the response's ``allow`` header (or ``*``).
    """
    origin = request.headers.get("Origin", "*")
    resp.headers["Access-Control-Allow-Origin"] = origin
    resp.headers["Access-Control-Allow-Headers"] = (
        "*, X-Super-Properties, "
        "X-Fingerprint, "
        "X-Context-Properties, "
        "X-Failed-Requests, "
        "X-Debug-Options, "
        "Content-Type, "
        "Authorization, "
        "Origin, "
        "If-None-Match"
    )
    resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*")
    return resp


def _set_rtl_reset(bucket, resp):
    """Set ``X-RateLimit-Reset`` using the client's requested precision.

    The reset time is ``bucket._window + bucket.second``. The
    ``x-ratelimit-precision`` request header selects rounding to whole
    seconds (the default) or the raw value ("millisecond"); any other value
    produces an explanatory error string in the header.
    """
    reset = bucket._window + bucket.second
    precision = request.headers.get("x-ratelimit-precision", "second")

    if precision == "second":
        resp.headers["X-RateLimit-Reset"] = str(round(reset))
    elif precision == "millisecond":
        resp.headers["X-RateLimit-Reset"] = str(reset)
    else:
        resp.headers["X-RateLimit-Reset"] = (
            "Invalid X-RateLimit-Precision, "
            "valid options are (second, millisecond)"
        )


@app.after_request
async def app_set_ratelimit_headers(resp):
    """Set the specific ratelimit headers.

    Only applies when a ratelimit bucket was attached to the request;
    otherwise the response is returned untouched.
    """
    bucket = getattr(request, "bucket", None)
    if bucket is None:
        return resp

    # best-effort: if any of the request's ratelimit attributes is missing,
    # stop adding headers instead of failing the whole response.
    try:
        resp.headers["X-RateLimit-Limit"] = str(bucket.requests)
        resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens)
        resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower()
        _set_rtl_reset(bucket, resp)

        # only add Retry-After if we actually hit a ratelimit
        retry_after = request.retry_after
        if retry_after:
            resp.headers["Retry-After"] = str(retry_after)
    except AttributeError:
        pass

    return resp


async def init_app_db(app_):
    """Connect to databases.

    Creates the asyncpg pool from ``app_.config["POSTGRES"]`` and also
    spawns the job scheduler.
    """
    log.info("db connect")
    # use the app we were given, not the module-level global
    app_.db = await asyncpg.create_pool(**app_.config["POSTGRES"])

    app_.sched = JobManager()


def awooawoo():
    """Debug stub exposed through the manager in init_app_managers."""
    print("awoo")


def init_app_managers(app_: Quart, *, init_voice=True):
    """Initialize singleton classes.

    Attaches the ratelimiter, state manager, storages, icon manager,
    dispatcher, presence manager, guild store and lazy guild manager to
    ``app_``. The voice manager is only created when ``init_voice`` is true.
    """
    # NOTE(review): experimental multiprocessing manager. The authkey is
    # hard-coded in public source and manager traffic is pickled, so it is
    # bound to loopback only; it was previously bound to every interface.
    manager = BaseManager(("127.0.0.1", 36970), b"awooawoo")
    manager.register("awooawoo", awooawoo)

    app_.loop = asyncio.get_event_loop()
    # serve_forever blocks, so it runs on the default executor
    app_.loop.run_in_executor(None, manager.get_server().serve_forever)

    app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
    app_.state_manager = StateManager()

    app_.storage = Storage(app_)
    app_.user_storage = UserStorage(app_.storage)

    app_.icons = IconManager(app_)

    app_.dispatcher = EventDispatcher()

    app_.presence = PresenceManager(app_)
    app_.storage.presence = app_.presence

    # only start VoiceManager if needed.
    # we do this because of a bug on ./manage.py where it
    # cancels the LVSPManager's spawn regions task. we don't
    # need to start it on manage time.
    if init_voice:
        app_.voice = VoiceManager(app_)

    app_.guild_store = GuildMemoryStore()
    app_.lazy_guild = LazyGuildManager()


async def api_index(app_):
    """Log how much of ``discord_endpoints.txt`` the app implements.

    Each non-blank line of the file holds ``name method path`` separated by
    whitespace. Every registered URL rule is rewritten into that file's
    placeholder style and matched by (path, method); the compliance summary
    and the names of missing endpoints are logged at debug level.
    """
    to_find: Dict[Tuple[str, str], str] = {}

    with open("discord_endpoints.txt", encoding="utf-8") as fd:
        for line in fd:
            components = line.split()
            if not components:
                continue  # blank line

            if len(components) != 3:
                log.warning("malformed endpoint line: {!r}", line)
                continue

            name, method, path = components
            to_find[(f"/api/v6{path}", method)] = name

    # applied in order: underscores become dots first so that e.g.
    # <int:member_id> -> {member.id} -> {user.id}
    rewrites = (
        ("_", "."),
        ("<", "{"),
        (">", "}"),
        ("int:", ""),
        # change our parameters into user.id
        ("member.id", "user.id"),
        ("banned.id", "user.id"),
        ("target.id", "user.id"),
        ("other.id", "user.id"),
        ("peer.id", "user.id"),
    )

    found_set = set()
    for rule in app_.url_map.rules:
        # convert the path to the discord_endpoints file's style
        path = rule.rule
        for old, new in rewrites:
            path = path.replace(old, new)

        for method in rule.methods:
            pathname = to_find.get((path, method))
            if pathname:
                found_set.add(pathname)

    api = set(to_find.values())
    if not api:
        log.debug("API compliance: no endpoints listed, skipping")
        return

    missing = api - found_set
    percentage = round((len(found_set) / len(api)) * 100, 2)

    log.debug(
        "API compliance: {} out of {} ({} missing), {}% compliant",
        len(found_set),
        len(api),
        len(missing),
        percentage,
    )

    log.debug("missing: {}", missing)


async def post_app_start(app_):
    """Spawn the background jobs that run once the app is set up."""
    # we'll need to start a billing job
    app_.sched.spawn(payment_job())
    app_.sched.spawn(api_index(app_))
    app_.sched.spawn(guild_region_check())


def start_websocket(host, port, ws_handler) -> asyncio.Future:
    """Start a websocket.

    Returns the ``websockets.serve`` object; the caller awaits it to start
    listening.
    """
    log.info("starting websocket at {} {}", host, port)

    async def _wrapper(ws, url):
        # We wrap the main websocket_handler
        # so we can pass quart's app object.
        await ws_handler(app, ws, url)

    return websockets.serve(_wrapper, host, port)


@app.before_serving
async def app_before_serving():
    """Callback for variable setup.

    Also sets up the websocket handlers.
    """
    log.info("opening db")
    await init_app_db(app)

    app.session = ClientSession()

    init_app_managers(app)
    await post_app_start(app)

    # start gateway websocket
    # voice websocket is handled by the voice server
    ws_fut = start_websocket(
        app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler
    )

    await ws_fut


@app.after_serving
async def app_after_serving():
    """Shutdown tasks for the server."""

    # first close all clients, then close db
    tasks = app.state_manager.gen_close_tasks()
    if tasks:
        # asyncio.wait() lost its ``loop`` parameter in Python 3.10 and
        # rejects bare coroutines from 3.11; ensure_future() returns existing
        # tasks/futures unchanged and wraps coroutines into tasks.
        await asyncio.wait([asyncio.ensure_future(task) for task in tasks])

    app.state_manager.close()

    app.sched.close()

    log.info("closing db")
    await app.db.close()


@app.errorhandler(LitecordError)
async def handle_litecord_err(err):
    """Render a LitecordError as a JSON error response.

    The body carries ``error``, ``status`` and ``message``, plus any extra
    fields from ``err.json`` and, when present, ``err.error_code`` as
    ``code``.
    """
    try:
        # copy so adding "code" below never mutates the error's own payload
        ejson = dict(err.json)
    except (IndexError, AttributeError):
        ejson = {}

    try:
        ejson["code"] = err.error_code
    except AttributeError:
        pass

    log.warning("error: {} {!r}", err.status_code, err.message)

    return (
        jsonify(
            {"error": True, "status": err.status_code, "message": err.message, **ejson}
        ),
        err.status_code,
    )


@app.errorhandler(500)
async def handle_500(err):
    """Render unhandled server errors as a JSON 500 response."""
    return (
        jsonify({"error": True, "message": repr(err), "internal_server_error": True}),
        500,
    )