mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'enhance/refactor-and-move-internals' into 'master'
Refactor and move internals See merge request litecord/litecord!47
This commit is contained in:
commit
57220179bc
138
litecord/auth.py
138
litecord/auth.py
|
|
@ -19,17 +19,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
from random import randint
|
|
||||||
from typing import Tuple, Optional
|
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from asyncpg import UniqueViolationError
|
|
||||||
from itsdangerous import TimestampSigner, BadSignature
|
from itsdangerous import TimestampSigner, BadSignature
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart import request, current_app as app
|
from quart import request, current_app as app
|
||||||
|
|
||||||
from litecord.errors import Forbidden, Unauthorized, BadRequest
|
from litecord.errors import Forbidden, Unauthorized
|
||||||
from litecord.snowflake import get_snowflake
|
|
||||||
from litecord.enums import UserFlags
|
from litecord.enums import UserFlags
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -56,20 +52,20 @@ async def raw_token_check(token: str, db=None) -> int:
|
||||||
# just try by fragments instead of
|
# just try by fragments instead of
|
||||||
# unpacking
|
# unpacking
|
||||||
fragments = token.split(".")
|
fragments = token.split(".")
|
||||||
user_id = fragments[0]
|
user_id_str = fragments[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = base64.b64decode(user_id.encode())
|
user_id_decoded = base64.b64decode(user_id_str.encode())
|
||||||
user_id = int(user_id)
|
user_id = int(user_id_decoded)
|
||||||
except (ValueError, binascii.Error):
|
except (ValueError, binascii.Error):
|
||||||
raise Unauthorized("Invalid user ID type")
|
raise Unauthorized("Invalid user ID type")
|
||||||
|
|
||||||
pwd_hash = await db.fetchval(
|
pwd_hash = await db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -88,10 +84,10 @@ async def raw_token_check(token: str, db=None) -> int:
|
||||||
# with people leaving their clients open forever)
|
# with people leaving their clients open forever)
|
||||||
await db.execute(
|
await db.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET last_session = (now() at time zone 'utc')
|
SET last_session = (now() at time zone 'utc')
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -128,10 +124,10 @@ async def admin_check() -> int:
|
||||||
|
|
||||||
flags = await app.db.fetchval(
|
flags = await app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT flags
|
SELECT flags
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -150,105 +146,3 @@ async def hash_data(data: str, loop=None) -> str:
|
||||||
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
|
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
|
||||||
|
|
||||||
return hashed.decode()
|
return hashed.decode()
|
||||||
|
|
||||||
|
|
||||||
async def check_username_usage(username: str):
|
|
||||||
"""Raise an error if too many people are with the same username."""
|
|
||||||
same_username = await app.db.fetchval(
|
|
||||||
"""
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM users
|
|
||||||
WHERE username = $1
|
|
||||||
""",
|
|
||||||
username,
|
|
||||||
)
|
|
||||||
|
|
||||||
if same_username > 9000:
|
|
||||||
raise BadRequest(
|
|
||||||
"Too many people.",
|
|
||||||
{
|
|
||||||
"username": "Too many people used the same username. "
|
|
||||||
"Please choose another"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _raw_discrim() -> str:
|
|
||||||
discrim_number = randint(1, 9999)
|
|
||||||
return "%04d" % discrim_number
|
|
||||||
|
|
||||||
|
|
||||||
async def roll_discrim(username: str) -> Optional[str]:
|
|
||||||
"""Roll a discriminator for a DiscordTag.
|
|
||||||
|
|
||||||
Tries to generate one 10 times.
|
|
||||||
|
|
||||||
Calls check_username_usage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# we shouldn't roll discrims for usernames
|
|
||||||
# that have been used too much.
|
|
||||||
await check_username_usage(username)
|
|
||||||
|
|
||||||
# max 10 times for a reroll
|
|
||||||
for _ in range(10):
|
|
||||||
# generate random discrim
|
|
||||||
discrim = _raw_discrim()
|
|
||||||
|
|
||||||
# check if anyone is with it
|
|
||||||
res = await app.db.fetchval(
|
|
||||||
"""
|
|
||||||
SELECT id
|
|
||||||
FROM users
|
|
||||||
WHERE username = $1 AND discriminator = $2
|
|
||||||
""",
|
|
||||||
username,
|
|
||||||
discrim,
|
|
||||||
)
|
|
||||||
|
|
||||||
# if no user is found with the (username, discrim)
|
|
||||||
# pair, then this is unique! return it.
|
|
||||||
if res is None:
|
|
||||||
return discrim
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def create_user(username: str, email: str, password: str) -> Tuple[int, str]:
|
|
||||||
"""Create a single user.
|
|
||||||
|
|
||||||
Generates a distriminator and other information. You can fetch the user
|
|
||||||
data back with :meth:`Storage.get_user`.
|
|
||||||
"""
|
|
||||||
db = app.db
|
|
||||||
loop = app.loop
|
|
||||||
|
|
||||||
new_id = get_snowflake()
|
|
||||||
new_discrim = await roll_discrim(username)
|
|
||||||
|
|
||||||
if new_discrim is None:
|
|
||||||
raise BadRequest(
|
|
||||||
"Unable to register.",
|
|
||||||
{"username": "Too many people are with this username."},
|
|
||||||
)
|
|
||||||
|
|
||||||
pwd_hash = await hash_data(password, loop)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO users
|
|
||||||
(id, email, username, discriminator, password_hash)
|
|
||||||
VALUES
|
|
||||||
($1, $2, $3, $4, $5)
|
|
||||||
""",
|
|
||||||
new_id,
|
|
||||||
email,
|
|
||||||
username,
|
|
||||||
new_discrim,
|
|
||||||
pwd_hash,
|
|
||||||
)
|
|
||||||
except UniqueViolationError:
|
|
||||||
raise BadRequest("Email already used.")
|
|
||||||
|
|
||||||
return new_id, pwd_hash
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from quart import Blueprint, jsonify, current_app as app, request
|
||||||
from litecord.auth import admin_check
|
from litecord.auth import admin_check
|
||||||
from litecord.schemas import validate
|
from litecord.schemas import validate
|
||||||
from litecord.admin_schemas import GUILD_UPDATE
|
from litecord.admin_schemas import GUILD_UPDATE
|
||||||
from litecord.blueprints.guilds import delete_guild
|
from litecord.common.guilds import delete_guild
|
||||||
from litecord.errors import GuildNotFound
|
from litecord.errors import GuildNotFound
|
||||||
|
|
||||||
bp = Blueprint("guilds_admin", __name__)
|
bp = Blueprint("guilds_admin", __name__)
|
||||||
|
|
|
||||||
|
|
@ -20,13 +20,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
from quart import Blueprint, jsonify, current_app as app, request
|
from quart import Blueprint, jsonify, current_app as app, request
|
||||||
|
|
||||||
from litecord.auth import admin_check
|
from litecord.auth import admin_check
|
||||||
from litecord.blueprints.auth import create_user
|
|
||||||
from litecord.schemas import validate
|
from litecord.schemas import validate
|
||||||
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
|
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
|
||||||
from litecord.errors import BadRequest, Forbidden
|
from litecord.errors import BadRequest, Forbidden
|
||||||
from litecord.utils import async_map
|
from litecord.utils import async_map
|
||||||
from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
|
|
||||||
from litecord.enums import UserFlags
|
from litecord.enums import UserFlags
|
||||||
|
from litecord.common.users import (
|
||||||
|
create_user,
|
||||||
|
delete_user,
|
||||||
|
user_disconnect,
|
||||||
|
mass_user_update,
|
||||||
|
)
|
||||||
|
|
||||||
bp = Blueprint("users_admin", __name__)
|
bp = Blueprint("users_admin", __name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ async def deprecate_region(region):
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def guild_region_check(app_):
|
async def guild_region_check():
|
||||||
"""Check all guilds for voice region inconsistencies.
|
"""Check all guilds for voice region inconsistencies.
|
||||||
|
|
||||||
Since the voice migration caused all guilds.region columns
|
Since the voice migration caused all guilds.region columns
|
||||||
|
|
@ -126,23 +126,23 @@ async def guild_region_check(app_):
|
||||||
than one region setup.
|
than one region setup.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
regions = await app_.storage.all_voice_regions()
|
regions = await app.storage.all_voice_regions()
|
||||||
|
|
||||||
if not regions:
|
if not regions:
|
||||||
log.info("region check: no regions to move guilds to")
|
log.info("region check: no regions to move guilds to")
|
||||||
return
|
return
|
||||||
|
|
||||||
res = await app_.db.execute(
|
res = await app.db.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET region = (
|
SET region = (
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM voice_regions
|
FROM voice_regions
|
||||||
OFFSET floor(random()*$1)
|
OFFSET floor(random()*$1)
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
)
|
)
|
||||||
WHERE region = NULL
|
WHERE region = NULL
|
||||||
""",
|
""",
|
||||||
len(regions),
|
len(regions),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,8 @@ import bcrypt
|
||||||
from quart import Blueprint, jsonify, request, current_app as app
|
from quart import Blueprint, jsonify, request, current_app as app
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.auth import token_check, create_user
|
from litecord.auth import token_check
|
||||||
|
from litecord.common.users import create_user
|
||||||
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
|
from litecord.schemas import validate, REGISTER, REGISTER_WITH_INVITE
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
|
|
@ -120,7 +121,7 @@ async def _register_with_invite():
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id, pwd_hash = await create_user(
|
user_id, pwd_hash = await create_user(
|
||||||
data["username"], data["email"], data["password"], app.db
|
data["username"], data["email"], data["password"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
|
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
|
||||||
|
|
|
||||||
|
|
@ -17,65 +17,36 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.blueprints.auth import token_check
|
from litecord.blueprints.auth import token_check
|
||||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
from litecord.blueprints.dms import try_dm_state
|
from litecord.errors import MessageNotFound, Forbidden
|
||||||
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
|
||||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
from litecord.schemas import validate, MESSAGE_CREATE
|
from litecord.schemas import validate, MESSAGE_CREATE
|
||||||
from litecord.utils import pg_set_json
|
from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
from litecord.embed.sanitizer import fill_embed
|
from litecord.embed.sanitizer import fill_embed
|
||||||
from litecord.embed.messages import process_url_embed
|
from litecord.embed.messages import process_url_embed
|
||||||
from litecord.blueprints.channel.dm_checks import dm_pre_check
|
from litecord.common.channels import dm_pre_check, try_dm_state
|
||||||
from litecord.images import try_unlink
|
from litecord.images import try_unlink
|
||||||
|
from litecord.common.messages import (
|
||||||
|
msg_create_request,
|
||||||
|
msg_create_check_content,
|
||||||
|
msg_add_attachment,
|
||||||
|
msg_guild_text_mentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("channel_messages", __name__)
|
bp = Blueprint("channel_messages", __name__)
|
||||||
|
|
||||||
|
|
||||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
|
||||||
"""Extract a limit kwarg."""
|
|
||||||
try:
|
|
||||||
limit = int(request_.args.get("limit", default))
|
|
||||||
|
|
||||||
if limit not in range(0, max_val + 1):
|
|
||||||
raise ValueError()
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
raise BadRequest("limit not int")
|
|
||||||
|
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
|
||||||
"""Extract a 2-tuple out of request arguments."""
|
|
||||||
before, after = None, None
|
|
||||||
|
|
||||||
if "around" in request.args:
|
|
||||||
average = int(limit / 2)
|
|
||||||
around = int(args["around"])
|
|
||||||
|
|
||||||
after = around - average
|
|
||||||
before = around + average
|
|
||||||
|
|
||||||
elif "before" in args:
|
|
||||||
before = int(args["before"])
|
|
||||||
elif "after" in args:
|
|
||||||
before = int(args["after"])
|
|
||||||
|
|
||||||
return before, after
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:channel_id>/messages", methods=["GET"])
|
@bp.route("/<int:channel_id>/messages", methods=["GET"])
|
||||||
async def get_messages(channel_id):
|
async def get_messages(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -204,191 +175,8 @@ async def create_message(
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
|
|
||||||
async def msg_guild_text_mentions(
|
async def _spawn_embed(payload, **kwargs):
|
||||||
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
|
app.sched.spawn(process_url_embed(payload, **kwargs))
|
||||||
):
|
|
||||||
"""Calculates mention data side-effects."""
|
|
||||||
channel_id = int(payload["channel_id"])
|
|
||||||
|
|
||||||
# calculate the user ids we'll bump the mention count for
|
|
||||||
uids = set()
|
|
||||||
|
|
||||||
# first is extracting user mentions
|
|
||||||
for mention in payload["mentions"]:
|
|
||||||
uids.add(int(mention["id"]))
|
|
||||||
|
|
||||||
# then role mentions
|
|
||||||
for role_mention in payload["mention_roles"]:
|
|
||||||
role_id = int(role_mention)
|
|
||||||
member_ids = await app.storage.get_role_members(role_id)
|
|
||||||
|
|
||||||
for member_id in member_ids:
|
|
||||||
uids.add(member_id)
|
|
||||||
|
|
||||||
# at-here only updates the state
|
|
||||||
# for the users that have a state
|
|
||||||
# in the channel.
|
|
||||||
if mentions_here:
|
|
||||||
uids = set()
|
|
||||||
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count = mention_count + 1
|
|
||||||
WHERE channel_id = $1
|
|
||||||
""",
|
|
||||||
channel_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# at-here updates the read state
|
|
||||||
# for all users, including the ones
|
|
||||||
# that might not have read permissions
|
|
||||||
# to the channel.
|
|
||||||
if mentions_everyone:
|
|
||||||
uids = set()
|
|
||||||
|
|
||||||
member_ids = await app.storage.get_member_ids(guild_id)
|
|
||||||
|
|
||||||
await app.db.executemany(
|
|
||||||
"""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count = mention_count + 1
|
|
||||||
WHERE channel_id = $1 AND user_id = $2
|
|
||||||
""",
|
|
||||||
[(channel_id, uid) for uid in member_ids],
|
|
||||||
)
|
|
||||||
|
|
||||||
for user_id in uids:
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count = mention_count + 1
|
|
||||||
WHERE user_id = $1
|
|
||||||
AND channel_id = $2
|
|
||||||
""",
|
|
||||||
user_id,
|
|
||||||
channel_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def msg_create_request() -> tuple:
|
|
||||||
"""Extract the json input and any file information
|
|
||||||
the client gave to us in the request.
|
|
||||||
|
|
||||||
This only applies to create message route.
|
|
||||||
"""
|
|
||||||
form = await request.form
|
|
||||||
request_json = await request.get_json() or {}
|
|
||||||
|
|
||||||
# NOTE: embed isn't set on form data
|
|
||||||
json_from_form = {
|
|
||||||
"content": form.get("content", ""),
|
|
||||||
"nonce": form.get("nonce", "0"),
|
|
||||||
"tts": json.loads(form.get("tts", "false")),
|
|
||||||
}
|
|
||||||
|
|
||||||
payload_json = json.loads(form.get("payload_json", "{}"))
|
|
||||||
|
|
||||||
json_from_form.update(request_json)
|
|
||||||
json_from_form.update(payload_json)
|
|
||||||
|
|
||||||
files = await request.files
|
|
||||||
|
|
||||||
# we don't really care about the given fields on the files dict, so
|
|
||||||
# we only extract the values
|
|
||||||
return json_from_form, [v for k, v in files.items()]
|
|
||||||
|
|
||||||
|
|
||||||
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
|
||||||
"""Check if there is actually any content being sent to us."""
|
|
||||||
has_content = bool(payload.get("content", ""))
|
|
||||||
has_files = len(files) > 0
|
|
||||||
|
|
||||||
embed_field = "embeds" if use_embeds else "embed"
|
|
||||||
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
|
||||||
|
|
||||||
has_total_content = has_content or has_embed or has_files
|
|
||||||
|
|
||||||
if not has_total_content:
|
|
||||||
raise BadRequest("No content has been provided.")
|
|
||||||
|
|
||||||
|
|
||||||
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
|
|
||||||
"""Add an attachment to a message.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
message_id: int
|
|
||||||
The ID of the message getting the attachment.
|
|
||||||
channel_id: int
|
|
||||||
The ID of the channel the message belongs to.
|
|
||||||
|
|
||||||
Exists because the attachment URL scheme contains
|
|
||||||
a channel id. The purpose is unknown, but we are
|
|
||||||
implementing Discord's behavior.
|
|
||||||
attachment_file: quart.FileStorage
|
|
||||||
quart FileStorage instance of the file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
attachment_id = get_snowflake()
|
|
||||||
filename = attachment_file.filename
|
|
||||||
|
|
||||||
# understand file info
|
|
||||||
mime = attachment_file.mimetype
|
|
||||||
is_image = mime.startswith("image/")
|
|
||||||
|
|
||||||
img_width, img_height = None, None
|
|
||||||
|
|
||||||
# extract file size
|
|
||||||
# TODO: this is probably inneficient
|
|
||||||
file_size = attachment_file.stream.getbuffer().nbytes
|
|
||||||
|
|
||||||
if is_image:
|
|
||||||
# open with pillow, extract image size
|
|
||||||
image = Image.open(attachment_file.stream)
|
|
||||||
img_width, img_height = image.size
|
|
||||||
|
|
||||||
# NOTE: DO NOT close the image, as closing the image will
|
|
||||||
# also close the stream.
|
|
||||||
|
|
||||||
# reset it to 0 for later usage
|
|
||||||
attachment_file.stream.seek(0)
|
|
||||||
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO attachments
|
|
||||||
(id, channel_id, message_id,
|
|
||||||
filename, filesize,
|
|
||||||
image, width, height)
|
|
||||||
VALUES
|
|
||||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
|
||||||
""",
|
|
||||||
attachment_id,
|
|
||||||
channel_id,
|
|
||||||
message_id,
|
|
||||||
filename,
|
|
||||||
file_size,
|
|
||||||
is_image,
|
|
||||||
img_width,
|
|
||||||
img_height,
|
|
||||||
)
|
|
||||||
|
|
||||||
ext = filename.split(".")[-1]
|
|
||||||
|
|
||||||
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
|
|
||||||
attach_file.write(attachment_file.stream.read())
|
|
||||||
|
|
||||||
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
|
|
||||||
|
|
||||||
return attachment_id
|
|
||||||
|
|
||||||
|
|
||||||
async def _spawn_embed(app_, payload, **kwargs):
|
|
||||||
app_.sched.spawn(
|
|
||||||
process_url_embed(
|
|
||||||
app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:channel_id>/messages", methods=["POST"])
|
@bp.route("/<int:channel_id>/messages", methods=["POST"])
|
||||||
|
|
@ -458,7 +246,7 @@ async def _create_message(channel_id):
|
||||||
# spawn url processor for embedding of images
|
# spawn url processor for embedding of images
|
||||||
perms = await get_permissions(user_id, channel_id)
|
perms = await get_permissions(user_id, channel_id)
|
||||||
if perms.bits.embed_links:
|
if perms.bits.embed_links:
|
||||||
await _spawn_embed(app, payload)
|
await _spawn_embed(payload)
|
||||||
|
|
||||||
# update read state for the author
|
# update read state for the author
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
|
|
@ -536,7 +324,6 @@ async def edit_message(channel_id, message_id):
|
||||||
perms = await get_permissions(user_id, channel_id)
|
perms = await get_permissions(user_id, channel_id)
|
||||||
if perms.bits.embed_links:
|
if perms.bits.embed_links:
|
||||||
await _spawn_embed(
|
await _spawn_embed(
|
||||||
app,
|
|
||||||
{
|
{
|
||||||
"id": message_id,
|
"id": message_id,
|
||||||
"channel_id": channel_id,
|
"channel_id": channel_id,
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,9 @@ from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
||||||
from litecord.utils import async_map
|
from litecord.utils import async_map, query_tuple_from_args, extract_limit
|
||||||
from litecord.blueprints.auth import token_check
|
from litecord.blueprints.auth import token_check
|
||||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
|
|
||||||
|
|
||||||
from litecord.enums import GUILD_CHANS
|
from litecord.enums import GUILD_CHANS
|
||||||
|
|
||||||
|
|
@ -165,7 +164,8 @@ def _emoji_sql_simple(emoji: str, param=4):
|
||||||
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
||||||
|
|
||||||
|
|
||||||
async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
|
async def _remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
|
||||||
|
"""Remove given reaction from a message."""
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||||
|
|
@ -201,8 +201,7 @@ async def remove_own_reaction(channel_id, message_id, emoji):
|
||||||
"""Remove a reaction."""
|
"""Remove a reaction."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await remove_reaction(channel_id, message_id, user_id, emoji)
|
await _remove_reaction(channel_id, message_id, user_id, emoji)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -212,7 +211,7 @@ async def remove_user_reaction(channel_id, message_id, emoji, other_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_perm_check(user_id, channel_id, "manage_messages")
|
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||||
|
|
||||||
await remove_reaction(channel_id, message_id, other_id, emoji)
|
await _remove_reaction(channel_id, message_id, other_id, emoji)
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
||||||
from litecord.utils import search_result_from_list
|
from litecord.utils import search_result_from_list
|
||||||
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
||||||
from litecord.snowflake import snowflake_datetime
|
from litecord.snowflake import snowflake_datetime
|
||||||
|
from litecord.common.channels import channel_ack
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("channels", __name__)
|
bp = Blueprint("channels", __name__)
|
||||||
|
|
@ -136,7 +137,7 @@ async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
||||||
|
|
||||||
|
|
||||||
async def delete_messages(channel_id):
|
async def _delete_messages(channel_id):
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
DELETE FROM channel_pins
|
DELETE FROM channel_pins
|
||||||
|
|
@ -162,7 +163,7 @@ async def delete_messages(channel_id):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def guild_cleanup(channel_id):
|
async def _guild_cleanup(channel_id):
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
DELETE FROM channel_overwrites
|
DELETE FROM channel_overwrites
|
||||||
|
|
@ -220,8 +221,8 @@ async def close_channel(channel_id):
|
||||||
# didn't work on my setup, so I delete
|
# didn't work on my setup, so I delete
|
||||||
# everything before moving to the main
|
# everything before moving to the main
|
||||||
# channel table deletes
|
# channel table deletes
|
||||||
await delete_messages(channel_id)
|
await _delete_messages(channel_id)
|
||||||
await guild_cleanup(channel_id)
|
await _guild_cleanup(channel_id)
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
f"""
|
f"""
|
||||||
|
|
@ -595,48 +596,6 @@ async def trigger_typing(channel_id):
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
|
||||||
"""ACK a channel."""
|
|
||||||
|
|
||||||
if not message_id:
|
|
||||||
message_id = await app.storage.chan_last_message(channel_id)
|
|
||||||
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO user_read_state
|
|
||||||
(user_id, channel_id, last_message_id, mention_count)
|
|
||||||
VALUES
|
|
||||||
($1, $2, $3, 0)
|
|
||||||
ON CONFLICT ON CONSTRAINT user_read_state_pkey
|
|
||||||
DO
|
|
||||||
UPDATE
|
|
||||||
SET last_message_id = $3, mention_count = 0
|
|
||||||
WHERE user_read_state.user_id = $1
|
|
||||||
AND user_read_state.channel_id = $2
|
|
||||||
""",
|
|
||||||
user_id,
|
|
||||||
channel_id,
|
|
||||||
message_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if guild_id:
|
|
||||||
await app.dispatcher.dispatch_user_guild(
|
|
||||||
user_id,
|
|
||||||
guild_id,
|
|
||||||
"MESSAGE_ACK",
|
|
||||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# we don't use ChannelDispatcher here because since
|
|
||||||
# guild_id is None, all user devices are already subscribed
|
|
||||||
# to the given channel (a dm or a group dm)
|
|
||||||
await app.dispatcher.dispatch_user(
|
|
||||||
user_id,
|
|
||||||
"MESSAGE_ACK",
|
|
||||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
|
@bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
|
||||||
async def ack_channel(channel_id, message_id):
|
async def ack_channel(channel_id, message_id):
|
||||||
"""Acknowledge a channel."""
|
"""Acknowledge a channel."""
|
||||||
|
|
@ -799,7 +758,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
||||||
|
|
||||||
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
|
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
|
||||||
|
|
||||||
await msg_update_embeds(message, [], app.storage, app.dispatcher)
|
await msg_update_embeds(message, [])
|
||||||
elif not suppress and not url_embeds:
|
elif not suppress and not url_embeds:
|
||||||
# spawn process_url_embed to restore the embeds, if any
|
# spawn process_url_embed to restore the embeds, if any
|
||||||
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
|
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
|
||||||
|
|
@ -809,11 +768,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
app.sched.spawn(
|
app.sched.spawn(process_url_embed(message))
|
||||||
process_url_embed(
|
|
||||||
app.config, app.storage, app.dispatcher, app.session, message
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ from ..snowflake import get_snowflake
|
||||||
from .auth import token_check
|
from .auth import token_check
|
||||||
|
|
||||||
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
|
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
|
||||||
|
from litecord.common.channels import try_dm_state
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("dms", __name__)
|
bp = Blueprint("dms", __name__)
|
||||||
|
|
@ -44,24 +45,6 @@ async def get_dms():
|
||||||
return jsonify(dms)
|
return jsonify(dms)
|
||||||
|
|
||||||
|
|
||||||
async def try_dm_state(user_id: int, dm_id: int):
|
|
||||||
"""Try inserting the user into the dm state
|
|
||||||
for the given DM.
|
|
||||||
|
|
||||||
Does not do anything if the user is already
|
|
||||||
in the dm state.
|
|
||||||
"""
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO dm_channel_state (user_id, dm_id)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
ON CONFLICT DO NOTHING
|
|
||||||
""",
|
|
||||||
user_id,
|
|
||||||
dm_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def jsonify_dm(dm_id: int, user_id: int):
|
async def jsonify_dm(dm_id: int, user_id: int):
|
||||||
dm_chan = await app.storage.get_dm(dm_id, user_id)
|
dm_chan = await app.storage.get_dm(dm_id, user_id)
|
||||||
return jsonify(dm_chan)
|
return jsonify(dm_chan)
|
||||||
|
|
|
||||||
|
|
@ -27,77 +27,11 @@ from litecord.blueprints.guild.roles import gen_pairs
|
||||||
|
|
||||||
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
|
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
|
||||||
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
|
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
|
||||||
|
from litecord.common.guilds import create_guild_channel
|
||||||
|
|
||||||
bp = Blueprint("guild_channels", __name__)
|
bp = Blueprint("guild_channels", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
|
||||||
if ctype == ChannelType.GUILD_TEXT:
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO guild_text_channels (id, topic)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
""",
|
|
||||||
channel_id,
|
|
||||||
kwargs.get("topic", ""),
|
|
||||||
)
|
|
||||||
elif ctype == ChannelType.GUILD_VOICE:
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
|
|
||||||
VALUES ($1, $2, $3)
|
|
||||||
""",
|
|
||||||
channel_id,
|
|
||||||
kwargs.get("bitrate", 64),
|
|
||||||
kwargs.get("user_limit", 0),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_channel(
|
|
||||||
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
|
||||||
):
|
|
||||||
"""Create a channel in a guild."""
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO channels (id, channel_type)
|
|
||||||
VALUES ($1, $2)
|
|
||||||
""",
|
|
||||||
channel_id,
|
|
||||||
ctype.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# calc new pos
|
|
||||||
max_pos = await app.db.fetchval(
|
|
||||||
"""
|
|
||||||
SELECT MAX(position)
|
|
||||||
FROM guild_channels
|
|
||||||
WHERE guild_id = $1
|
|
||||||
""",
|
|
||||||
guild_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# account for the first channel in a guild too
|
|
||||||
max_pos = max_pos or 0
|
|
||||||
|
|
||||||
# all channels go to guild_channels
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
|
||||||
VALUES ($1, $2, $3, $4)
|
|
||||||
""",
|
|
||||||
channel_id,
|
|
||||||
guild_id,
|
|
||||||
kwargs["name"],
|
|
||||||
max_pos + 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# the rest of sql magic is dependant on the channel
|
|
||||||
# we're creating (a text or voice or category),
|
|
||||||
# so we use this function.
|
|
||||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:guild_id>/channels", methods=["GET"])
|
@bp.route("/<int:guild_id>/channels", methods=["GET"])
|
||||||
async def get_guild_channels(guild_id):
|
async def get_guild_channels(guild_id):
|
||||||
"""Get the list of channels in a guild."""
|
"""Get the list of channels in a guild."""
|
||||||
|
|
|
||||||
|
|
@ -23,47 +23,11 @@ from litecord.blueprints.auth import token_check
|
||||||
from litecord.blueprints.checks import guild_perm_check
|
from litecord.blueprints.checks import guild_perm_check
|
||||||
|
|
||||||
from litecord.schemas import validate, GUILD_PRUNE
|
from litecord.schemas import validate, GUILD_PRUNE
|
||||||
|
from litecord.common.guilds import remove_member, remove_member_multi
|
||||||
|
|
||||||
bp = Blueprint("guild_moderation", __name__)
|
bp = Blueprint("guild_moderation", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def remove_member(guild_id: int, member_id: int):
|
|
||||||
"""Do common tasks related to deleting a member from the guild,
|
|
||||||
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
|
||||||
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM members
|
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
|
||||||
""",
|
|
||||||
guild_id,
|
|
||||||
member_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user_guild(
|
|
||||||
member_id,
|
|
||||||
guild_id,
|
|
||||||
"GUILD_DELETE",
|
|
||||||
{"guild_id": str(guild_id), "unavailable": False},
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.dispatcher.unsub("guild", guild_id, member_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
|
||||||
guild_id,
|
|
||||||
"GUILD_MEMBER_REMOVE",
|
|
||||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def remove_member_multi(guild_id: int, members: list):
|
|
||||||
"""Remove multiple members."""
|
|
||||||
for member_id in members:
|
|
||||||
await remove_member(guild_id, member_id)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
|
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
|
||||||
async def kick_guild_member(guild_id, member_id):
|
async def kick_guild_member(guild_id, member_id):
|
||||||
"""Remove a member from a guild."""
|
"""Remove a member from a guild."""
|
||||||
|
|
@ -221,6 +185,5 @@ async def begin_guild_prune(guild_id):
|
||||||
days = j["days"]
|
days = j["days"]
|
||||||
member_ids = await get_prune(guild_id, days)
|
member_ids = await get_prune(guild_id, days)
|
||||||
|
|
||||||
app.loop.create_task(remove_member_multi(guild_id, member_ids))
|
app.sched.spawn(remove_member_multi(guild_id, member_ids))
|
||||||
|
|
||||||
return jsonify({"pruned": len(member_ids)})
|
return jsonify({"pruned": len(member_ids)})
|
||||||
|
|
|
||||||
|
|
@ -27,11 +27,9 @@ from litecord.auth import token_check
|
||||||
from litecord.blueprints.checks import guild_check, guild_perm_check
|
from litecord.blueprints.checks import guild_check, guild_perm_check
|
||||||
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
||||||
|
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.utils import maybe_lazy_guild_dispatch
|
||||||
from litecord.utils import dict_get
|
from litecord.common.guilds import create_role
|
||||||
from litecord.permissions import get_role_perms
|
|
||||||
|
|
||||||
DEFAULT_EVERYONE_PERMS = 104324161
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("guild_roles", __name__)
|
bp = Blueprint("guild_roles", __name__)
|
||||||
|
|
||||||
|
|
@ -45,71 +43,6 @@ async def get_guild_roles(guild_id):
|
||||||
return jsonify(await app.storage.get_role_data(guild_id))
|
return jsonify(await app.storage.get_role_data(guild_id))
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
|
|
||||||
# sometimes we want to dispatch an event
|
|
||||||
# even if the role isn't hoisted
|
|
||||||
|
|
||||||
# an example of such a case is when a role loses
|
|
||||||
# its hoist status.
|
|
||||||
|
|
||||||
# check if is a dict first because role_delete
|
|
||||||
# only receives the role id.
|
|
||||||
if isinstance(role, dict) and not role["hoist"] and not force:
|
|
||||||
return
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_role(guild_id, name: str, **kwargs):
|
|
||||||
"""Create a role in a guild."""
|
|
||||||
new_role_id = get_snowflake()
|
|
||||||
|
|
||||||
everyone_perms = await get_role_perms(guild_id, guild_id)
|
|
||||||
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
|
|
||||||
|
|
||||||
# update all roles so that we have space for pos 1, but without
|
|
||||||
# sending GUILD_ROLE_UPDATE for everyone
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
UPDATE roles
|
|
||||||
SET
|
|
||||||
position = position + 1
|
|
||||||
WHERE guild_id = $1
|
|
||||||
AND NOT (position = 0)
|
|
||||||
""",
|
|
||||||
guild_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO roles (id, guild_id, name, color,
|
|
||||||
hoist, position, permissions, managed, mentionable)
|
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
||||||
""",
|
|
||||||
new_role_id,
|
|
||||||
guild_id,
|
|
||||||
name,
|
|
||||||
dict_get(kwargs, "color", 0),
|
|
||||||
dict_get(kwargs, "hoist", False),
|
|
||||||
# always set ourselves on position 1
|
|
||||||
1,
|
|
||||||
int(dict_get(kwargs, "permissions", default_perms)),
|
|
||||||
False,
|
|
||||||
dict_get(kwargs, "mentionable", False),
|
|
||||||
)
|
|
||||||
|
|
||||||
role = await app.storage.get_role(new_role_id, guild_id)
|
|
||||||
|
|
||||||
# we need to update the lazy guild handlers for the newly created group
|
|
||||||
await _maybe_lg(guild_id, "new_role", role)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
|
||||||
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
|
||||||
)
|
|
||||||
|
|
||||||
return role
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:guild_id>/roles", methods=["POST"])
|
@bp.route("/<int:guild_id>/roles", methods=["POST"])
|
||||||
async def create_guild_role(guild_id: int):
|
async def create_guild_role(guild_id: int):
|
||||||
"""Add a role to a guild"""
|
"""Add a role to a guild"""
|
||||||
|
|
@ -132,7 +65,7 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
|
||||||
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
||||||
role = await app.storage.get_role(role_id, guild_id)
|
role = await app.storage.get_role(role_id, guild_id)
|
||||||
|
|
||||||
await _maybe_lg(guild_id, "role_pos_upd", role)
|
await maybe_lazy_guild_dispatch(guild_id, "role_pos_upd", role)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(
|
||||||
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
||||||
|
|
@ -343,7 +276,7 @@ async def update_guild_role(guild_id, role_id):
|
||||||
)
|
)
|
||||||
|
|
||||||
role = await _role_update_dispatch(role_id, guild_id)
|
role = await _role_update_dispatch(role_id, guild_id)
|
||||||
await _maybe_lg(guild_id, "role_update", role, True)
|
await maybe_lazy_guild_dispatch(guild_id, "role_update", role, True)
|
||||||
return jsonify(role)
|
return jsonify(role)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -369,7 +302,7 @@ async def delete_guild_role(guild_id, role_id):
|
||||||
if res == "DELETE 0":
|
if res == "DELETE 0":
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
await _maybe_lg(guild_id, "role_delete", role_id, True)
|
await maybe_lazy_guild_dispatch(guild_id, "role_delete", role_id, True)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(
|
||||||
guild_id,
|
guild_id,
|
||||||
|
|
|
||||||
|
|
@ -21,8 +21,12 @@ from typing import Optional, List
|
||||||
|
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
from litecord.blueprints.guild.channels import create_guild_channel
|
from litecord.common.guilds import (
|
||||||
from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
|
create_role,
|
||||||
|
create_guild_channel,
|
||||||
|
delete_guild,
|
||||||
|
create_guild_settings,
|
||||||
|
)
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
from ..snowflake import get_snowflake
|
from ..snowflake import get_snowflake
|
||||||
|
|
@ -34,44 +38,17 @@ from ..schemas import (
|
||||||
SEARCH_CHANNEL,
|
SEARCH_CHANNEL,
|
||||||
VANITY_URL_PATCH,
|
VANITY_URL_PATCH,
|
||||||
)
|
)
|
||||||
from .channels import channel_ack
|
|
||||||
from .checks import guild_check, guild_owner_check, guild_perm_check
|
from .checks import guild_check, guild_owner_check, guild_perm_check
|
||||||
|
from ..common.channels import channel_ack
|
||||||
from litecord.utils import to_update, search_result_from_list
|
from litecord.utils import to_update, search_result_from_list
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
|
DEFAULT_EVERYONE_PERMS = 104324161
|
||||||
|
|
||||||
bp = Blueprint("guilds", __name__)
|
bp = Blueprint("guilds", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_settings(guild_id: int, user_id: int):
|
|
||||||
"""Create guild settings for the user
|
|
||||||
joining the guild."""
|
|
||||||
|
|
||||||
# new guild_settings are based off the currently
|
|
||||||
# set guild settings (for the guild)
|
|
||||||
m_notifs = await app.db.fetchval(
|
|
||||||
"""
|
|
||||||
SELECT default_message_notifications
|
|
||||||
FROM guilds
|
|
||||||
WHERE id = $1
|
|
||||||
""",
|
|
||||||
guild_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO guild_settings
|
|
||||||
(user_id, guild_id, message_notifications)
|
|
||||||
VALUES
|
|
||||||
($1, $2, $3)
|
|
||||||
""",
|
|
||||||
user_id,
|
|
||||||
guild_id,
|
|
||||||
m_notifs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def add_member(guild_id: int, user_id: int):
|
async def add_member(guild_id: int, user_id: int):
|
||||||
"""Add a user to a guild."""
|
"""Add a user to a guild."""
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
|
|
@ -393,36 +370,6 @@ async def _update_guild(guild_id):
|
||||||
return jsonify(guild)
|
return jsonify(guild)
|
||||||
|
|
||||||
|
|
||||||
async def delete_guild(guild_id: int, *, app_=None):
|
|
||||||
"""Delete a single guild."""
|
|
||||||
app_ = app_ or app
|
|
||||||
|
|
||||||
await app_.db.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM guilds
|
|
||||||
WHERE guilds.id = $1
|
|
||||||
""",
|
|
||||||
guild_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Discord's client expects IDs being string
|
|
||||||
await app_.dispatcher.dispatch(
|
|
||||||
"guild",
|
|
||||||
guild_id,
|
|
||||||
"GUILD_DELETE",
|
|
||||||
{
|
|
||||||
"guild_id": str(guild_id),
|
|
||||||
"id": str(guild_id),
|
|
||||||
# 'unavailable': False,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# remove from the dispatcher so nobody
|
|
||||||
# becomes the little memer that tries to fuck up with
|
|
||||||
# everybody's gateway
|
|
||||||
await app_.dispatcher.remove("guild", guild_id)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/<int:guild_id>", methods=["DELETE"])
|
@bp.route("/<int:guild_id>", methods=["DELETE"])
|
||||||
# this endpoint is not documented, but used by the official client.
|
# this endpoint is not documented, but used by the official client.
|
||||||
@bp.route("/<int:guild_id>/delete", methods=["POST"])
|
@bp.route("/<int:guild_id>/delete", methods=["POST"])
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ async def _get_default_user_avatar(default_id: int):
|
||||||
|
|
||||||
|
|
||||||
async def _handle_webhook_avatar(md_url_redir: str):
|
async def _handle_webhook_avatar(md_url_redir: str):
|
||||||
md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
|
md_url = make_md_req_url("img", EmbedURL(md_url_redir))
|
||||||
return redirect(md_url)
|
return redirect(md_url)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,6 @@ from ..auth import token_check
|
||||||
from ..schemas import validate, INVITE
|
from ..schemas import validate, INVITE
|
||||||
from ..enums import ChannelType
|
from ..enums import ChannelType
|
||||||
from ..errors import BadRequest, Forbidden
|
from ..errors import BadRequest, Forbidden
|
||||||
from .guilds import create_guild_settings
|
|
||||||
from ..utils import async_map
|
from ..utils import async_map
|
||||||
|
|
||||||
from litecord.blueprints.checks import (
|
from litecord.blueprints.checks import (
|
||||||
|
|
@ -39,6 +38,7 @@ from litecord.blueprints.checks import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
|
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
|
||||||
|
from litecord.common.guilds import create_guild_settings
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("invites", __name__)
|
bp = Blueprint("invites", __name__)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from litecord.snowflake import snowflake_datetime, get_snowflake
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
from litecord.types import timestamp_, HOURS
|
from litecord.types import timestamp_, HOURS
|
||||||
from litecord.enums import UserFlags, PremiumType
|
from litecord.enums import UserFlags, PremiumType
|
||||||
from litecord.blueprints.users import mass_user_update
|
from litecord.common.users import mass_user_update
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint("users_billing", __name__)
|
bp = Blueprint("users_billing", __name__)
|
||||||
|
|
@ -122,16 +122,13 @@ async def get_payment_source_ids(user_id: int) -> list:
|
||||||
return [r["id"] for r in rows]
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_ids(user_id: int, db=None) -> list:
|
async def get_payment_ids(user_id: int) -> list:
|
||||||
if not db:
|
rows = await app.db.fetch(
|
||||||
db = app.db
|
|
||||||
|
|
||||||
rows = await db.fetch(
|
|
||||||
"""
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -151,18 +148,14 @@ async def get_subscription_ids(user_id: int) -> list:
|
||||||
return [r["id"] for r in rows]
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
async def get_payment_source(user_id: int, source_id: int) -> dict:
|
||||||
"""Get a payment source's information."""
|
"""Get a payment source's information."""
|
||||||
|
source_type = await app.db.fetchval(
|
||||||
if not db:
|
|
||||||
db = app.db
|
|
||||||
|
|
||||||
source_type = await db.fetchval(
|
|
||||||
"""
|
"""
|
||||||
SELECT source_type
|
SELECT source_type
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
WHERE id = $1 AND user_id = $2
|
WHERE id = $1 AND user_id = $2
|
||||||
""",
|
""",
|
||||||
source_id,
|
source_id,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||||
|
|
||||||
fields = ",".join(specific_fields)
|
fields = ",".join(specific_fields)
|
||||||
|
|
||||||
extras_row = await db.fetchrow(
|
extras_row = await app.db.fetchrow(
|
||||||
f"""
|
f"""
|
||||||
SELECT {fields}, billing_address, default_, id::text
|
SELECT {fields}, billing_address, default_, id::text
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
|
|
@ -199,22 +192,19 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||||
return {**source, **derow}
|
return {**source, **derow}
|
||||||
|
|
||||||
|
|
||||||
async def get_subscription(subscription_id: int, db=None):
|
async def get_subscription(subscription_id: int):
|
||||||
"""Get a subscription's information."""
|
"""Get a subscription's information."""
|
||||||
if not db:
|
row = await app.db.fetchrow(
|
||||||
db = app.db
|
|
||||||
|
|
||||||
row = await db.fetchrow(
|
|
||||||
"""
|
"""
|
||||||
SELECT id::text, source_id::text AS payment_source_id,
|
SELECT id::text, source_id::text AS payment_source_id,
|
||||||
user_id,
|
user_id,
|
||||||
payment_gateway, payment_gateway_plan_id,
|
payment_gateway, payment_gateway_plan_id,
|
||||||
period_start AS current_period_start,
|
period_start AS current_period_start,
|
||||||
period_end AS current_period_end,
|
period_end AS current_period_end,
|
||||||
canceled_at, s_type, status
|
canceled_at, s_type, status
|
||||||
FROM user_subscriptions
|
FROM user_subscriptions
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
subscription_id,
|
subscription_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -231,19 +221,16 @@ async def get_subscription(subscription_id: int, db=None):
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
|
|
||||||
async def get_payment(payment_id: int, db=None):
|
async def get_payment(payment_id: int):
|
||||||
"""Get a single payment's information."""
|
"""Get a single payment's information."""
|
||||||
if not db:
|
row = await app.db.fetchrow(
|
||||||
db = app.db
|
|
||||||
|
|
||||||
row = await db.fetchrow(
|
|
||||||
"""
|
"""
|
||||||
SELECT id::text, source_id, subscription_id, user_id,
|
SELECT id::text, source_id, subscription_id, user_id,
|
||||||
amount, amount_refunded, currency,
|
amount, amount_refunded, currency,
|
||||||
description, status, tax, tax_inclusive
|
description, status, tax, tax_inclusive
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
payment_id,
|
payment_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None):
|
||||||
|
|
||||||
drow["created_at"] = snowflake_datetime(int(drow["id"]))
|
drow["created_at"] = snowflake_datetime(int(drow["id"]))
|
||||||
|
|
||||||
drow["payment_source"] = await get_payment_source(
|
drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"])
|
||||||
row["user_id"], row["source_id"], db
|
|
||||||
)
|
|
||||||
|
|
||||||
drow["subscription"] = await get_subscription(row["subscription_id"], db)
|
drow["subscription"] = await get_subscription(row["subscription_id"])
|
||||||
|
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
|
|
||||||
async def create_payment(subscription_id, db=None):
|
async def create_payment(subscription_id):
|
||||||
"""Create a payment."""
|
"""Create a payment."""
|
||||||
if not db:
|
sub = await get_subscription(subscription_id)
|
||||||
db = app.db
|
|
||||||
|
|
||||||
sub = await get_subscription(subscription_id, db)
|
|
||||||
|
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
||||||
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
|
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
|
||||||
|
|
||||||
await db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO user_payments (
|
INSERT INTO user_payments (
|
||||||
id, source_id, subscription_id, user_id,
|
id, source_id, subscription_id, user_id,
|
||||||
|
|
@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None):
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
async def process_subscription(app, subscription_id: int):
|
async def process_subscription(subscription_id: int):
|
||||||
"""Process a single subscription."""
|
"""Process a single subscription."""
|
||||||
sub = await get_subscription(subscription_id, app.db)
|
sub = await get_subscription(subscription_id)
|
||||||
|
|
||||||
user_id = int(sub["user_id"])
|
user_id = int(sub["user_id"])
|
||||||
|
|
||||||
|
|
@ -313,10 +295,10 @@ async def process_subscription(app, subscription_id: int):
|
||||||
# payments), then we should update premium status
|
# payments), then we should update premium status
|
||||||
first_payment_id = await app.db.fetchval(
|
first_payment_id = await app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT MIN(id)
|
SELECT MIN(id)
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE subscription_id = $1
|
WHERE subscription_id = $1
|
||||||
""",
|
""",
|
||||||
subscription_id,
|
subscription_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -324,10 +306,10 @@ async def process_subscription(app, subscription_id: int):
|
||||||
|
|
||||||
premium_since = await app.db.fetchval(
|
premium_since = await app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT premium_since
|
SELECT premium_since
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -343,10 +325,10 @@ async def process_subscription(app, subscription_id: int):
|
||||||
|
|
||||||
old_flags = await app.db.fetchval(
|
old_flags = await app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT flags
|
SELECT flags
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -355,17 +337,17 @@ async def process_subscription(app, subscription_id: int):
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET premium_since = $1, flags = $2
|
SET premium_since = $1, flags = $2
|
||||||
WHERE id = $3
|
WHERE id = $3
|
||||||
""",
|
""",
|
||||||
first_payment_ts,
|
first_payment_ts,
|
||||||
new_flags,
|
new_flags,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# dispatch updated user to all possible clients
|
# dispatch updated user to all possible clients
|
||||||
await mass_user_update(user_id, app)
|
await mass_user_update(user_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/@me/billing/payment-sources", methods=["GET"])
|
@bp.route("/@me/billing/payment-sources", methods=["GET"])
|
||||||
|
|
@ -474,11 +456,11 @@ async def _create_subscription():
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
await create_payment(new_id, app.db)
|
await create_payment(new_id)
|
||||||
|
|
||||||
# make sure we update the user's premium status
|
# make sure we update the user's premium status
|
||||||
# and dispatch respective user updates to other people.
|
# and dispatch respective user updates to other people.
|
||||||
await process_subscription(app, new_id)
|
await process_subscription(new_id)
|
||||||
|
|
||||||
return jsonify(await get_subscription(new_id))
|
return jsonify(await get_subscription(new_id))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
this file only serves the periodic payment job code.
|
this file only serves the periodic payment job code.
|
||||||
"""
|
"""
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
from asyncio import sleep, CancelledError
|
from asyncio import sleep, CancelledError
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
@ -47,14 +49,14 @@ THRESHOLDS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _resched(app):
|
async def _resched():
|
||||||
log.debug("waiting 30 minutes for job.")
|
log.debug("waiting 30 minutes for job.")
|
||||||
await sleep(30 * MINUTES)
|
await sleep(30 * MINUTES)
|
||||||
app.sched.spawn(payment_job(app))
|
app.sched.spawn(payment_job())
|
||||||
|
|
||||||
|
|
||||||
async def _process_user_payments(app, user_id: int):
|
async def _process_user_payments(user_id: int):
|
||||||
payments = await get_payment_ids(user_id, app.db)
|
payments = await get_payment_ids(user_id)
|
||||||
|
|
||||||
if not payments:
|
if not payments:
|
||||||
log.debug("no payments for uid {}, skipping", user_id)
|
log.debug("no payments for uid {}, skipping", user_id)
|
||||||
|
|
@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int):
|
||||||
|
|
||||||
latest_payment = max(payments)
|
latest_payment = max(payments)
|
||||||
|
|
||||||
payment_data = await get_payment(latest_payment, app.db)
|
payment_data = await get_payment(latest_payment)
|
||||||
|
|
||||||
# calculate the difference between this payment
|
# calculate the difference between this payment
|
||||||
# and now.
|
# and now.
|
||||||
|
|
@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int):
|
||||||
delta = now - payment_tstamp
|
delta = now - payment_tstamp
|
||||||
|
|
||||||
sub_id = int(payment_data["subscription"]["id"])
|
sub_id = int(payment_data["subscription"]["id"])
|
||||||
subscription = await get_subscription(sub_id, app.db)
|
subscription = await get_subscription(sub_id)
|
||||||
|
|
||||||
# if the max payment is X days old, we create another.
|
# if the max payment is X days old, we create another.
|
||||||
# X is 30 for monthly subscriptions of nitro,
|
# X is 30 for monthly subscriptions of nitro,
|
||||||
|
|
@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int):
|
||||||
# create_payment does not call any Stripe
|
# create_payment does not call any Stripe
|
||||||
# or BrainTree APIs at all, since we'll just
|
# or BrainTree APIs at all, since we'll just
|
||||||
# give it as free.
|
# give it as free.
|
||||||
await create_payment(sub_id, app.db)
|
await create_payment(sub_id)
|
||||||
else:
|
else:
|
||||||
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
|
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
|
||||||
|
|
||||||
|
|
||||||
async def payment_job(app):
|
async def payment_job():
|
||||||
"""Main payment job function.
|
"""Main payment job function.
|
||||||
|
|
||||||
This function will check through users' payments
|
This function will check through users' payments
|
||||||
|
|
@ -104,9 +106,9 @@ async def payment_job(app):
|
||||||
|
|
||||||
user_ids = await app.db.fetch(
|
user_ids = await app.db.fetch(
|
||||||
"""
|
"""
|
||||||
SELECT DISTINCT user_id
|
SELECT DISTINCT user_id
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug("working {} users", len(user_ids))
|
log.debug("working {} users", len(user_ids))
|
||||||
|
|
@ -115,24 +117,24 @@ async def payment_job(app):
|
||||||
for row in user_ids:
|
for row in user_ids:
|
||||||
user_id = row["user_id"]
|
user_id = row["user_id"]
|
||||||
try:
|
try:
|
||||||
await _process_user_payments(app, user_id)
|
await _process_user_payments(user_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("error while processing user payments")
|
log.exception("error while processing user payments")
|
||||||
|
|
||||||
subscribers = await app.db.fetch(
|
subscribers = await app.db.fetch(
|
||||||
"""
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_subscriptions
|
FROM user_subscriptions
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in subscribers:
|
for row in subscribers:
|
||||||
try:
|
try:
|
||||||
await process_subscription(app, row["id"])
|
await process_subscription(row["id"])
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("error while processing subscription")
|
log.exception("error while processing subscription")
|
||||||
log.debug("rescheduling..")
|
log.debug("rescheduling..")
|
||||||
try:
|
try:
|
||||||
await _resched(app)
|
await _resched()
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
log.info("cancelled while waiting for resched")
|
log.info("cancelled while waiting for resched")
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from os import urandom
|
|
||||||
|
|
||||||
from asyncpg import UniqueViolationError
|
from asyncpg import UniqueViolationError
|
||||||
from quart import Blueprint, jsonify, request, current_app as app
|
from quart import Blueprint, jsonify, request, current_app as app
|
||||||
|
|
@ -27,8 +26,8 @@ from ..errors import Forbidden, BadRequest, Unauthorized
|
||||||
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
||||||
|
|
||||||
from .guilds import guild_check
|
from .guilds import guild_check
|
||||||
from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
|
from litecord.auth import token_check, hash_data
|
||||||
from litecord.blueprints.guild.mod import remove_member
|
from litecord.common.guilds import remove_member
|
||||||
|
|
||||||
from litecord.enums import PremiumType
|
from litecord.enums import PremiumType
|
||||||
from litecord.images import parse_data_uri
|
from litecord.images import parse_data_uri
|
||||||
|
|
@ -36,46 +35,18 @@ from litecord.permissions import base_permissions
|
||||||
|
|
||||||
from litecord.blueprints.auth import check_password
|
from litecord.blueprints.auth import check_password
|
||||||
from litecord.utils import to_update
|
from litecord.utils import to_update
|
||||||
|
from litecord.common.users import (
|
||||||
|
mass_user_update,
|
||||||
|
delete_user,
|
||||||
|
check_username_usage,
|
||||||
|
roll_discrim,
|
||||||
|
user_disconnect,
|
||||||
|
)
|
||||||
|
|
||||||
bp = Blueprint("user", __name__)
|
bp = Blueprint("user", __name__)
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def mass_user_update(user_id):
|
|
||||||
"""Dispatch USER_UPDATE in a mass way."""
|
|
||||||
# by using dispatch_with_filter
|
|
||||||
# we're guaranteeing all shards will get
|
|
||||||
# a USER_UPDATE once and not any others.
|
|
||||||
|
|
||||||
session_ids = []
|
|
||||||
|
|
||||||
public_user = await app.storage.get_user(user_id)
|
|
||||||
private_user = await app.storage.get_user(user_id, secure=True)
|
|
||||||
|
|
||||||
session_ids.extend(
|
|
||||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
|
||||||
)
|
|
||||||
|
|
||||||
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
|
||||||
friend_ids = await app.user_storage.get_friend_ids(user_id)
|
|
||||||
|
|
||||||
session_ids.extend(
|
|
||||||
await app.dispatcher.dispatch_many_filter_list(
|
|
||||||
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
session_ids.extend(
|
|
||||||
await app.dispatcher.dispatch_many_filter_list(
|
|
||||||
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
|
||||||
|
|
||||||
return public_user, private_user
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/@me", methods=["GET"])
|
@bp.route("/@me", methods=["GET"])
|
||||||
async def get_me():
|
async def get_me():
|
||||||
"""Get the current user's information."""
|
"""Get the current user's information."""
|
||||||
|
|
@ -276,7 +247,7 @@ async def patch_me():
|
||||||
|
|
||||||
user.pop("password_hash")
|
user.pop("password_hash")
|
||||||
|
|
||||||
_, private_user = await mass_user_update(user_id, app)
|
_, private_user = await mass_user_update(user_id)
|
||||||
return jsonify(private_user)
|
return jsonify(private_user)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -319,7 +290,6 @@ async def leave_guild(guild_id: int):
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
await remove_member(guild_id, user_id)
|
await remove_member(guild_id, user_id)
|
||||||
|
|
||||||
return "", 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -468,118 +438,6 @@ async def _get_mentions():
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
def rand_hex(length: int = 8) -> str:
|
|
||||||
"""Generate random hex characters."""
|
|
||||||
return urandom(length).hex()[:length]
|
|
||||||
|
|
||||||
|
|
||||||
async def _del_from_table(db, table: str, user_id: int):
|
|
||||||
"""Delete a row from a table."""
|
|
||||||
column = {
|
|
||||||
"channel_overwrites": "target_user",
|
|
||||||
"user_settings": "id",
|
|
||||||
"group_dm_members": "member_id",
|
|
||||||
}.get(table, "user_id")
|
|
||||||
|
|
||||||
res = await db.execute(
|
|
||||||
f"""
|
|
||||||
DELETE FROM {table}
|
|
||||||
WHERE {column} = $1
|
|
||||||
""",
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_user(user_id, *, mass_update: bool = True):
|
|
||||||
"""Delete a user. Does not disconnect the user."""
|
|
||||||
db = app.db
|
|
||||||
|
|
||||||
new_username = f"Deleted User {rand_hex()}"
|
|
||||||
|
|
||||||
# by using a random hex in password_hash
|
|
||||||
# we break attempts at using the default '123' password hash
|
|
||||||
# to issue valid tokens for deleted users.
|
|
||||||
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
UPDATE users
|
|
||||||
SET
|
|
||||||
username = $1,
|
|
||||||
email = NULL,
|
|
||||||
mfa_enabled = false,
|
|
||||||
verified = false,
|
|
||||||
avatar = NULL,
|
|
||||||
flags = 0,
|
|
||||||
premium_since = NULL,
|
|
||||||
phone = '',
|
|
||||||
password_hash = $2
|
|
||||||
WHERE
|
|
||||||
id = $3
|
|
||||||
""",
|
|
||||||
new_username,
|
|
||||||
rand_hex(32),
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# remove the user from various tables
|
|
||||||
await _del_from_table(db, "user_settings", user_id)
|
|
||||||
await _del_from_table(db, "user_payment_sources", user_id)
|
|
||||||
await _del_from_table(db, "user_subscriptions", user_id)
|
|
||||||
await _del_from_table(db, "user_payments", user_id)
|
|
||||||
await _del_from_table(db, "user_read_state", user_id)
|
|
||||||
await _del_from_table(db, "guild_settings", user_id)
|
|
||||||
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
|
|
||||||
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
DELETE FROM relationships
|
|
||||||
WHERE user_id = $1 OR peer_id = $1
|
|
||||||
""",
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# DMs are still maintained, but not the state.
|
|
||||||
await _del_from_table(db, "dm_channel_state", user_id)
|
|
||||||
|
|
||||||
# NOTE: we don't delete the group dms the user is an owner of...
|
|
||||||
# TODO: group dm owner reassign when the owner leaves a gdm
|
|
||||||
await _del_from_table(db, "group_dm_members", user_id)
|
|
||||||
|
|
||||||
await _del_from_table(db, "members", user_id)
|
|
||||||
await _del_from_table(db, "member_roles", user_id)
|
|
||||||
await _del_from_table(db, "channel_overwrites", user_id)
|
|
||||||
|
|
||||||
# after updating the user, we send USER_UPDATE so that all the other
|
|
||||||
# clients can refresh their caches on the now-deleted user
|
|
||||||
if mass_update:
|
|
||||||
await mass_user_update(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def user_disconnect(user_id: int):
|
|
||||||
"""Disconnects the given user's devices."""
|
|
||||||
# after removing the user from all tables, we need to force
|
|
||||||
# all known user states to reconnect, causing the user to not
|
|
||||||
# be online anymore.
|
|
||||||
user_states = app.state_manager.user_states(user_id)
|
|
||||||
|
|
||||||
for state in user_states:
|
|
||||||
# make it unable to resume
|
|
||||||
app.state_manager.remove(state)
|
|
||||||
|
|
||||||
if not state.ws:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# force a close, 4000 should make the client reconnect.
|
|
||||||
await state.ws.ws.close(4000)
|
|
||||||
|
|
||||||
# force everyone to see the user as offline
|
|
||||||
await app.presence.dispatch_pres(
|
|
||||||
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/@me/delete", methods=["POST"])
|
@bp.route("/@me/delete", methods=["POST"])
|
||||||
async def delete_account():
|
async def delete_account():
|
||||||
"""Delete own account.
|
"""Delete own account.
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ from litecord.snowflake import get_snowflake
|
||||||
from litecord.utils import async_map
|
from litecord.utils import async_map
|
||||||
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
|
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
|
||||||
|
|
||||||
from litecord.blueprints.channel.messages import (
|
from litecord.common.messages import (
|
||||||
msg_create_request,
|
msg_create_request,
|
||||||
msg_create_check_content,
|
msg_create_check_content,
|
||||||
msg_add_attachment,
|
msg_add_attachment,
|
||||||
|
|
@ -499,9 +499,7 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
||||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||||
|
|
||||||
# spawn embedder in the background, even when we're on a webhook.
|
# spawn embedder in the background, even when we're on a webhook.
|
||||||
app.sched.spawn(
|
app.sched.spawn(process_url_embed(payload))
|
||||||
process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
|
|
||||||
)
|
|
||||||
|
|
||||||
# we can assume its a guild text channel, so just call it
|
# we can assume its a guild text channel, so just call it
|
||||||
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
|
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
|
||||||
|
|
|
||||||
|
|
@ -16,15 +16,55 @@ You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from litecord.errors import Forbidden
|
|
||||||
|
from litecord.errors import ForbiddenDM
|
||||||
from litecord.enums import RelationshipType
|
from litecord.enums import RelationshipType
|
||||||
|
|
||||||
|
|
||||||
class ForbiddenDM(Forbidden):
|
async def channel_ack(
|
||||||
error_code = 50007
|
user_id: int, guild_id: int, channel_id: int, message_id: int = None
|
||||||
|
):
|
||||||
|
"""ACK a channel."""
|
||||||
|
|
||||||
|
if not message_id:
|
||||||
|
message_id = await app.storage.chan_last_message(channel_id)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO user_read_state
|
||||||
|
(user_id, channel_id, last_message_id, mention_count)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, 0)
|
||||||
|
ON CONFLICT ON CONSTRAINT user_read_state_pkey
|
||||||
|
DO
|
||||||
|
UPDATE
|
||||||
|
SET last_message_id = $3, mention_count = 0
|
||||||
|
WHERE user_read_state.user_id = $1
|
||||||
|
AND user_read_state.channel_id = $2
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if guild_id:
|
||||||
|
await app.dispatcher.dispatch_user_guild(
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
"MESSAGE_ACK",
|
||||||
|
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# we don't use ChannelDispatcher here because since
|
||||||
|
# guild_id is None, all user devices are already subscribed
|
||||||
|
# to the given channel (a dm or a group dm)
|
||||||
|
await app.dispatcher.dispatch_user(
|
||||||
|
user_id,
|
||||||
|
"MESSAGE_ACK",
|
||||||
|
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||||
|
|
@ -32,12 +72,12 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||||
# first step is checking if there is a block in any direction
|
# first step is checking if there is a block in any direction
|
||||||
blockrow = await app.db.fetchrow(
|
blockrow = await app.db.fetchrow(
|
||||||
"""
|
"""
|
||||||
SELECT rel_type
|
SELECT rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE rel_type = $3
|
WHERE rel_type = $3
|
||||||
AND user_id IN ($1, $2)
|
AND user_id IN ($1, $2)
|
||||||
AND peer_id IN ($1, $2)
|
AND peer_id IN ($1, $2)
|
||||||
""",
|
""",
|
||||||
user_id,
|
user_id,
|
||||||
peer_id,
|
peer_id,
|
||||||
RelationshipType.BLOCK.value,
|
RelationshipType.BLOCK.value,
|
||||||
|
|
@ -75,3 +115,21 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||||
# if after this filtering we don't have any more guilds, error
|
# if after this filtering we don't have any more guilds, error
|
||||||
if not mutual_guilds:
|
if not mutual_guilds:
|
||||||
raise ForbiddenDM()
|
raise ForbiddenDM()
|
||||||
|
|
||||||
|
|
||||||
|
async def try_dm_state(user_id: int, dm_id: int):
|
||||||
|
"""Try inserting the user into the dm state
|
||||||
|
for the given DM.
|
||||||
|
|
||||||
|
Does not do anything if the user is already
|
||||||
|
in the dm state.
|
||||||
|
"""
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO dm_channel_state (user_id, dm_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
dm_id,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,234 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
Litecord
|
||||||
|
Copyright (C) 2018-2019 Luna Mendes
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, version 3 of the License.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
|
|
||||||
|
from ..snowflake import get_snowflake
|
||||||
|
from ..permissions import get_role_perms
|
||||||
|
from ..utils import dict_get, maybe_lazy_guild_dispatch
|
||||||
|
from ..enums import ChannelType
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_member(guild_id: int, member_id: int):
|
||||||
|
"""Do common tasks related to deleting a member from the guild,
|
||||||
|
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM members
|
||||||
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
|
""",
|
||||||
|
guild_id,
|
||||||
|
member_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_user_guild(
|
||||||
|
member_id,
|
||||||
|
guild_id,
|
||||||
|
"GUILD_DELETE",
|
||||||
|
{"guild_id": str(guild_id), "unavailable": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.dispatcher.unsub("guild", guild_id, member_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id,
|
||||||
|
"GUILD_MEMBER_REMOVE",
|
||||||
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_member_multi(guild_id: int, members: list):
|
||||||
|
"""Remove multiple members."""
|
||||||
|
for member_id in members:
|
||||||
|
await remove_member(guild_id, member_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_role(guild_id, name: str, **kwargs):
|
||||||
|
"""Create a role in a guild."""
|
||||||
|
new_role_id = get_snowflake()
|
||||||
|
|
||||||
|
everyone_perms = await get_role_perms(guild_id, guild_id)
|
||||||
|
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
|
||||||
|
|
||||||
|
# update all roles so that we have space for pos 1, but without
|
||||||
|
# sending GUILD_ROLE_UPDATE for everyone
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE roles
|
||||||
|
SET
|
||||||
|
position = position + 1
|
||||||
|
WHERE guild_id = $1
|
||||||
|
AND NOT (position = 0)
|
||||||
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO roles (id, guild_id, name, color,
|
||||||
|
hoist, position, permissions, managed, mentionable)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
|
""",
|
||||||
|
new_role_id,
|
||||||
|
guild_id,
|
||||||
|
name,
|
||||||
|
dict_get(kwargs, "color", 0),
|
||||||
|
dict_get(kwargs, "hoist", False),
|
||||||
|
# always set ourselves on position 1
|
||||||
|
1,
|
||||||
|
int(dict_get(kwargs, "permissions", default_perms)),
|
||||||
|
False,
|
||||||
|
dict_get(kwargs, "mentionable", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
role = await app.storage.get_role(new_role_id, guild_id)
|
||||||
|
|
||||||
|
# we need to update the lazy guild handlers for the newly created group
|
||||||
|
await maybe_lazy_guild_dispatch(guild_id, "new_role", role)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_guild(
|
||||||
|
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
||||||
|
)
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||||
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO guild_text_channels (id, topic)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
kwargs.get("topic", ""),
|
||||||
|
)
|
||||||
|
elif ctype == ChannelType.GUILD_VOICE:
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO guild_voice_channels (id, bitrate, user_limit)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
kwargs.get("bitrate", 64),
|
||||||
|
kwargs.get("user_limit", 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_guild_channel(
|
||||||
|
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||||
|
):
|
||||||
|
"""Create a channel in a guild."""
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO channels (id, channel_type)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
ctype.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calc new pos
|
||||||
|
max_pos = await app.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT MAX(position)
|
||||||
|
FROM guild_channels
|
||||||
|
WHERE guild_id = $1
|
||||||
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# account for the first channel in a guild too
|
||||||
|
max_pos = max_pos or 0
|
||||||
|
|
||||||
|
# all channels go to guild_channels
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
guild_id,
|
||||||
|
kwargs["name"],
|
||||||
|
max_pos + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# the rest of sql magic is dependant on the channel
|
||||||
|
# we're creating (a text or voice or category),
|
||||||
|
# so we use this function.
|
||||||
|
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_guild(guild_id: int):
|
||||||
|
"""Delete a single guild."""
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM guilds
|
||||||
|
WHERE guilds.id = $1
|
||||||
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Discord's client expects IDs being string
|
||||||
|
await app.dispatcher.dispatch(
|
||||||
|
"guild",
|
||||||
|
guild_id,
|
||||||
|
"GUILD_DELETE",
|
||||||
|
{
|
||||||
|
"guild_id": str(guild_id),
|
||||||
|
"id": str(guild_id),
|
||||||
|
# 'unavailable': False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove from the dispatcher so nobody
|
||||||
|
# becomes the little memer that tries to fuck up with
|
||||||
|
# everybody's gateway
|
||||||
|
await app.dispatcher.remove("guild", guild_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_guild_settings(guild_id: int, user_id: int):
|
||||||
|
"""Create guild settings for the user
|
||||||
|
joining the guild."""
|
||||||
|
|
||||||
|
# new guild_settings are based off the currently
|
||||||
|
# set guild settings (for the guild)
|
||||||
|
m_notifs = await app.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT default_message_notifications
|
||||||
|
FROM guilds
|
||||||
|
WHERE id = $1
|
||||||
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO guild_settings
|
||||||
|
(user_id, guild_id, message_notifications)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3)
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
m_notifs,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,189 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from quart import request, current_app as app
|
||||||
|
|
||||||
|
from litecord.errors import BadRequest
|
||||||
|
from ..snowflake import get_snowflake
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def msg_create_request() -> tuple:
|
||||||
|
"""Extract the json input and any file information
|
||||||
|
the client gave to us in the request.
|
||||||
|
|
||||||
|
This only applies to create message route.
|
||||||
|
"""
|
||||||
|
form = await request.form
|
||||||
|
request_json = await request.get_json() or {}
|
||||||
|
|
||||||
|
# NOTE: embed isn't set on form data
|
||||||
|
json_from_form = {
|
||||||
|
"content": form.get("content", ""),
|
||||||
|
"nonce": form.get("nonce", "0"),
|
||||||
|
"tts": json.loads(form.get("tts", "false")),
|
||||||
|
}
|
||||||
|
|
||||||
|
payload_json = json.loads(form.get("payload_json", "{}"))
|
||||||
|
|
||||||
|
json_from_form.update(request_json)
|
||||||
|
json_from_form.update(payload_json)
|
||||||
|
|
||||||
|
files = await request.files
|
||||||
|
|
||||||
|
# we don't really care about the given fields on the files dict, so
|
||||||
|
# we only extract the values
|
||||||
|
return json_from_form, [v for k, v in files.items()]
|
||||||
|
|
||||||
|
|
||||||
|
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
||||||
|
"""Check if there is actually any content being sent to us."""
|
||||||
|
has_content = bool(payload.get("content", ""))
|
||||||
|
has_files = len(files) > 0
|
||||||
|
|
||||||
|
embed_field = "embeds" if use_embeds else "embed"
|
||||||
|
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
||||||
|
|
||||||
|
has_total_content = has_content or has_embed or has_files
|
||||||
|
|
||||||
|
if not has_total_content:
|
||||||
|
raise BadRequest("No content has been provided.")
|
||||||
|
|
||||||
|
|
||||||
|
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
|
||||||
|
"""Add an attachment to a message.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
message_id: int
|
||||||
|
The ID of the message getting the attachment.
|
||||||
|
channel_id: int
|
||||||
|
The ID of the channel the message belongs to.
|
||||||
|
|
||||||
|
Exists because the attachment URL scheme contains
|
||||||
|
a channel id. The purpose is unknown, but we are
|
||||||
|
implementing Discord's behavior.
|
||||||
|
attachment_file: quart.FileStorage
|
||||||
|
quart FileStorage instance of the file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attachment_id = get_snowflake()
|
||||||
|
filename = attachment_file.filename
|
||||||
|
|
||||||
|
# understand file info
|
||||||
|
mime = attachment_file.mimetype
|
||||||
|
is_image = mime.startswith("image/")
|
||||||
|
|
||||||
|
img_width, img_height = None, None
|
||||||
|
|
||||||
|
# extract file size
|
||||||
|
# TODO: this is probably inneficient
|
||||||
|
file_size = attachment_file.stream.getbuffer().nbytes
|
||||||
|
|
||||||
|
if is_image:
|
||||||
|
# open with pillow, extract image size
|
||||||
|
image = Image.open(attachment_file.stream)
|
||||||
|
img_width, img_height = image.size
|
||||||
|
|
||||||
|
# NOTE: DO NOT close the image, as closing the image will
|
||||||
|
# also close the stream.
|
||||||
|
|
||||||
|
# reset it to 0 for later usage
|
||||||
|
attachment_file.stream.seek(0)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO attachments
|
||||||
|
(id, channel_id, message_id,
|
||||||
|
filename, filesize,
|
||||||
|
image, width, height)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
|
""",
|
||||||
|
attachment_id,
|
||||||
|
channel_id,
|
||||||
|
message_id,
|
||||||
|
filename,
|
||||||
|
file_size,
|
||||||
|
is_image,
|
||||||
|
img_width,
|
||||||
|
img_height,
|
||||||
|
)
|
||||||
|
|
||||||
|
ext = filename.split(".")[-1]
|
||||||
|
|
||||||
|
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
|
||||||
|
attach_file.write(attachment_file.stream.read())
|
||||||
|
|
||||||
|
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
|
||||||
|
|
||||||
|
return attachment_id
|
||||||
|
|
||||||
|
|
||||||
|
async def msg_guild_text_mentions(
|
||||||
|
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
|
||||||
|
):
|
||||||
|
"""Calculates mention data side-effects."""
|
||||||
|
channel_id = int(payload["channel_id"])
|
||||||
|
|
||||||
|
# calculate the user ids we'll bump the mention count for
|
||||||
|
uids = set()
|
||||||
|
|
||||||
|
# first is extracting user mentions
|
||||||
|
for mention in payload["mentions"]:
|
||||||
|
uids.add(int(mention["id"]))
|
||||||
|
|
||||||
|
# then role mentions
|
||||||
|
for role_mention in payload["mention_roles"]:
|
||||||
|
role_id = int(role_mention)
|
||||||
|
member_ids = await app.storage.get_role_members(role_id)
|
||||||
|
|
||||||
|
for member_id in member_ids:
|
||||||
|
uids.add(member_id)
|
||||||
|
|
||||||
|
# at-here only updates the state
|
||||||
|
# for the users that have a state
|
||||||
|
# in the channel.
|
||||||
|
if mentions_here:
|
||||||
|
uids = set()
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count = mention_count + 1
|
||||||
|
WHERE channel_id = $1
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# at-here updates the read state
|
||||||
|
# for all users, including the ones
|
||||||
|
# that might not have read permissions
|
||||||
|
# to the channel.
|
||||||
|
if mentions_everyone:
|
||||||
|
uids = set()
|
||||||
|
|
||||||
|
member_ids = await app.storage.get_member_ids(guild_id)
|
||||||
|
|
||||||
|
await app.db.executemany(
|
||||||
|
"""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count = mention_count + 1
|
||||||
|
WHERE channel_id = $1 AND user_id = $2
|
||||||
|
""",
|
||||||
|
[(channel_id, uid) for uid in member_ids],
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id in uids:
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count = mention_count + 1
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND channel_id = $2
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,273 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
Litecord
|
||||||
|
Copyright (C) 2018-2019 Luna Mendes
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, version 3 of the License.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from random import randint
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
|
from asyncpg import UniqueViolationError
|
||||||
|
|
||||||
|
from ..snowflake import get_snowflake
|
||||||
|
from ..errors import BadRequest
|
||||||
|
from ..auth import hash_data
|
||||||
|
from ..utils import rand_hex
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def mass_user_update(user_id):
|
||||||
|
"""Dispatch USER_UPDATE in a mass way."""
|
||||||
|
# by using dispatch_with_filter
|
||||||
|
# we're guaranteeing all shards will get
|
||||||
|
# a USER_UPDATE once and not any others.
|
||||||
|
|
||||||
|
session_ids = []
|
||||||
|
|
||||||
|
public_user = await app.storage.get_user(user_id)
|
||||||
|
private_user = await app.storage.get_user(user_id, secure=True)
|
||||||
|
|
||||||
|
session_ids.extend(
|
||||||
|
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
||||||
|
)
|
||||||
|
|
||||||
|
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
||||||
|
friend_ids = await app.user_storage.get_friend_ids(user_id)
|
||||||
|
|
||||||
|
session_ids.extend(
|
||||||
|
await app.dispatcher.dispatch_many_filter_list(
|
||||||
|
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
session_ids.extend(
|
||||||
|
await app.dispatcher.dispatch_many_filter_list(
|
||||||
|
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
||||||
|
|
||||||
|
return public_user, private_user
|
||||||
|
|
||||||
|
|
||||||
|
async def check_username_usage(username: str):
|
||||||
|
"""Raise an error if too many people are with the same username."""
|
||||||
|
same_username = await app.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM users
|
||||||
|
WHERE username = $1
|
||||||
|
""",
|
||||||
|
username,
|
||||||
|
)
|
||||||
|
|
||||||
|
if same_username > 9000:
|
||||||
|
raise BadRequest(
|
||||||
|
"Too many people.",
|
||||||
|
{
|
||||||
|
"username": "Too many people used the same username. "
|
||||||
|
"Please choose another"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _raw_discrim() -> str:
|
||||||
|
discrim_number = randint(1, 9999)
|
||||||
|
return "%04d" % discrim_number
|
||||||
|
|
||||||
|
|
||||||
|
async def roll_discrim(username: str) -> Optional[str]:
|
||||||
|
"""Roll a discriminator for a DiscordTag.
|
||||||
|
|
||||||
|
Tries to generate one 10 times.
|
||||||
|
|
||||||
|
Calls check_username_usage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# we shouldn't roll discrims for usernames
|
||||||
|
# that have been used too much.
|
||||||
|
await check_username_usage(username)
|
||||||
|
|
||||||
|
# max 10 times for a reroll
|
||||||
|
for _ in range(10):
|
||||||
|
# generate random discrim
|
||||||
|
discrim = _raw_discrim()
|
||||||
|
|
||||||
|
# check if anyone is with it
|
||||||
|
res = await app.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT id
|
||||||
|
FROM users
|
||||||
|
WHERE username = $1 AND discriminator = $2
|
||||||
|
""",
|
||||||
|
username,
|
||||||
|
discrim,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if no user is found with the (username, discrim)
|
||||||
|
# pair, then this is unique! return it.
|
||||||
|
if res is None:
|
||||||
|
return discrim
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def create_user(username: str, email: str, password: str) -> Tuple[int, str]:
|
||||||
|
"""Create a single user.
|
||||||
|
|
||||||
|
Generates a distriminator and other information. You can fetch the user
|
||||||
|
data back with :meth:`Storage.get_user`.
|
||||||
|
"""
|
||||||
|
new_id = get_snowflake()
|
||||||
|
new_discrim = await roll_discrim(username)
|
||||||
|
|
||||||
|
if new_discrim is None:
|
||||||
|
raise BadRequest(
|
||||||
|
"Unable to register.",
|
||||||
|
{"username": "Too many people are with this username."},
|
||||||
|
)
|
||||||
|
|
||||||
|
pwd_hash = await hash_data(password)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO users
|
||||||
|
(id, email, username, discriminator, password_hash)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, $4, $5)
|
||||||
|
""",
|
||||||
|
new_id,
|
||||||
|
email,
|
||||||
|
username,
|
||||||
|
new_discrim,
|
||||||
|
pwd_hash,
|
||||||
|
)
|
||||||
|
except UniqueViolationError:
|
||||||
|
raise BadRequest("Email already used.")
|
||||||
|
|
||||||
|
return new_id, pwd_hash
|
||||||
|
|
||||||
|
|
||||||
|
async def _del_from_table(db, table: str, user_id: int):
|
||||||
|
"""Delete a row from a table."""
|
||||||
|
column = {
|
||||||
|
"channel_overwrites": "target_user",
|
||||||
|
"user_settings": "id",
|
||||||
|
"group_dm_members": "member_id",
|
||||||
|
}.get(table, "user_id")
|
||||||
|
|
||||||
|
res = await db.execute(
|
||||||
|
f"""
|
||||||
|
DELETE FROM {table}
|
||||||
|
WHERE {column} = $1
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_user(user_id, *, mass_update: bool = True):
|
||||||
|
"""Delete a user. Does not disconnect the user."""
|
||||||
|
db = app.db
|
||||||
|
|
||||||
|
new_username = f"Deleted User {rand_hex()}"
|
||||||
|
|
||||||
|
# by using a random hex in password_hash
|
||||||
|
# we break attempts at using the default '123' password hash
|
||||||
|
# to issue valid tokens for deleted users.
|
||||||
|
|
||||||
|
await db.execute(
|
||||||
|
"""
|
||||||
|
UPDATE users
|
||||||
|
SET
|
||||||
|
username = $1,
|
||||||
|
email = NULL,
|
||||||
|
mfa_enabled = false,
|
||||||
|
verified = false,
|
||||||
|
avatar = NULL,
|
||||||
|
flags = 0,
|
||||||
|
premium_since = NULL,
|
||||||
|
phone = '',
|
||||||
|
password_hash = $2
|
||||||
|
WHERE
|
||||||
|
id = $3
|
||||||
|
""",
|
||||||
|
new_username,
|
||||||
|
rand_hex(32),
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove the user from various tables
|
||||||
|
await _del_from_table(db, "user_settings", user_id)
|
||||||
|
await _del_from_table(db, "user_payment_sources", user_id)
|
||||||
|
await _del_from_table(db, "user_subscriptions", user_id)
|
||||||
|
await _del_from_table(db, "user_payments", user_id)
|
||||||
|
await _del_from_table(db, "user_read_state", user_id)
|
||||||
|
await _del_from_table(db, "guild_settings", user_id)
|
||||||
|
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
|
||||||
|
|
||||||
|
await db.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM relationships
|
||||||
|
WHERE user_id = $1 OR peer_id = $1
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# DMs are still maintained, but not the state.
|
||||||
|
await _del_from_table(db, "dm_channel_state", user_id)
|
||||||
|
|
||||||
|
# NOTE: we don't delete the group dms the user is an owner of...
|
||||||
|
# TODO: group dm owner reassign when the owner leaves a gdm
|
||||||
|
await _del_from_table(db, "group_dm_members", user_id)
|
||||||
|
|
||||||
|
await _del_from_table(db, "members", user_id)
|
||||||
|
await _del_from_table(db, "member_roles", user_id)
|
||||||
|
await _del_from_table(db, "channel_overwrites", user_id)
|
||||||
|
|
||||||
|
# after updating the user, we send USER_UPDATE so that all the other
|
||||||
|
# clients can refresh their caches on the now-deleted user
|
||||||
|
if mass_update:
|
||||||
|
await mass_user_update(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def user_disconnect(user_id: int):
|
||||||
|
"""Disconnects the given user's devices."""
|
||||||
|
# after removing the user from all tables, we need to force
|
||||||
|
# all known user states to reconnect, causing the user to not
|
||||||
|
# be online anymore.
|
||||||
|
user_states = app.state_manager.user_states(user_id)
|
||||||
|
|
||||||
|
for state in user_states:
|
||||||
|
# make it unable to resume
|
||||||
|
app.state_manager.remove(state)
|
||||||
|
|
||||||
|
if not state.ws:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# force a close, 4000 should make the client reconnect.
|
||||||
|
await state.ws.ws.close(4000)
|
||||||
|
|
||||||
|
# force everyone to see the user as offline
|
||||||
|
await app.presence.dispatch_pres(
|
||||||
|
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||||
|
)
|
||||||
|
|
@ -22,6 +22,7 @@ import asyncio
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from quart import current_app as app
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.embed.sanitizer import proxify, fetch_metadata, fetch_embed
|
from litecord.embed.sanitizer import proxify, fetch_metadata, fetch_embed
|
||||||
|
|
@ -33,10 +34,10 @@ log = Logger(__name__)
|
||||||
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
|
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
|
||||||
|
|
||||||
|
|
||||||
async def insert_media_meta(url, config, session):
|
async def insert_media_meta(url):
|
||||||
"""Insert media metadata as an embed."""
|
"""Insert media metadata as an embed."""
|
||||||
img_proxy_url = proxify(url, config=config)
|
img_proxy_url = proxify(url)
|
||||||
meta = await fetch_metadata(url, config=config, session=session)
|
meta = await fetch_metadata(url)
|
||||||
|
|
||||||
if meta is None:
|
if meta is None:
|
||||||
return
|
return
|
||||||
|
|
@ -56,19 +57,19 @@ async def insert_media_meta(url, config, session):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
|
async def msg_update_embeds(payload, new_embeds):
|
||||||
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
|
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
|
||||||
to users."""
|
to users."""
|
||||||
|
|
||||||
message_id = int(payload["id"])
|
message_id = int(payload["id"])
|
||||||
channel_id = int(payload["channel_id"])
|
channel_id = int(payload["channel_id"])
|
||||||
|
|
||||||
await storage.execute_with_json(
|
await app.storage.execute_with_json(
|
||||||
"""
|
"""
|
||||||
UPDATE messages
|
UPDATE messages
|
||||||
SET embeds = $1
|
SET embeds = $1
|
||||||
WHERE messages.id = $2
|
WHERE messages.id = $2
|
||||||
""",
|
""",
|
||||||
new_embeds,
|
new_embeds,
|
||||||
message_id,
|
message_id,
|
||||||
)
|
)
|
||||||
|
|
@ -85,7 +86,9 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
|
||||||
if "flags" in payload:
|
if "flags" in payload:
|
||||||
update_payload["flags"] = payload["flags"]
|
update_payload["flags"] = payload["flags"]
|
||||||
|
|
||||||
await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
|
await app.dispatcher.dispatch(
|
||||||
|
"channel", channel_id, "MESSAGE_UPDATE", update_payload
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_media_url(url) -> bool:
|
def is_media_url(url) -> bool:
|
||||||
|
|
@ -102,15 +105,13 @@ def is_media_url(url) -> bool:
|
||||||
return extension in MEDIA_EXTENSIONS
|
return extension in MEDIA_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
async def insert_mp_embed(parsed, config, session):
|
async def insert_mp_embed(parsed):
|
||||||
"""Insert mediaproxy embed."""
|
"""Insert mediaproxy embed."""
|
||||||
embed = await fetch_embed(parsed, config=config, session=session)
|
embed = await fetch_embed(parsed)
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
async def process_url_embed(
|
async def process_url_embed(payload: dict, *, delay=0):
|
||||||
config, storage, dispatcher, session, payload: dict, *, delay=0
|
|
||||||
):
|
|
||||||
"""Process URLs in a message and generate embeds based on that."""
|
"""Process URLs in a message and generate embeds based on that."""
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
|
@ -145,9 +146,9 @@ async def process_url_embed(
|
||||||
url = EmbedURL(url)
|
url = EmbedURL(url)
|
||||||
|
|
||||||
if is_media_url(url):
|
if is_media_url(url):
|
||||||
embed = await insert_media_meta(url, config, session)
|
embed = await insert_media_meta(url)
|
||||||
else:
|
else:
|
||||||
embed = await insert_mp_embed(url, config, session)
|
embed = await insert_mp_embed(url)
|
||||||
|
|
||||||
if not embed:
|
if not embed:
|
||||||
continue
|
continue
|
||||||
|
|
@ -160,4 +161,4 @@ async def process_url_embed(
|
||||||
|
|
||||||
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
|
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
|
||||||
|
|
||||||
await msg_update_embeds(payload, new_embeds, storage, dispatcher)
|
await msg_update_embeds(payload, new_embeds)
|
||||||
|
|
|
||||||
|
|
@ -75,35 +75,24 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _mk_cfg_sess(config, session) -> tuple:
|
def _md_base() -> Optional[tuple]:
|
||||||
"""Return a tuple of (config, session)."""
|
|
||||||
if config is None:
|
|
||||||
config = app.config
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
session = app.session
|
|
||||||
|
|
||||||
return config, session
|
|
||||||
|
|
||||||
|
|
||||||
def _md_base(config) -> Optional[tuple]:
|
|
||||||
"""Return the protocol and base url for the mediaproxy."""
|
"""Return the protocol and base url for the mediaproxy."""
|
||||||
md_base_url = config["MEDIA_PROXY"]
|
md_base_url = app.config["MEDIA_PROXY"]
|
||||||
if md_base_url is None:
|
if md_base_url is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
proto = "https" if config["IS_SSL"] else "http"
|
proto = "https" if app.config["IS_SSL"] else "http"
|
||||||
|
|
||||||
return proto, md_base_url
|
return proto, md_base_url
|
||||||
|
|
||||||
|
|
||||||
def make_md_req_url(config, scope: str, url):
|
def make_md_req_url(scope: str, url):
|
||||||
"""Make a mediaproxy request URL given the config, scope, and the url
|
"""Make a mediaproxy request URL given the scope and the url
|
||||||
to be proxied.
|
to be proxied.
|
||||||
|
|
||||||
When MEDIA_PROXY is None, however, returns the original URL.
|
When MEDIA_PROXY is None, however, returns the original URL.
|
||||||
"""
|
"""
|
||||||
base = _md_base(config)
|
base = _md_base()
|
||||||
if base is None:
|
if base is None:
|
||||||
return url.url if isinstance(url, EmbedURL) else url
|
return url.url if isinstance(url, EmbedURL) else url
|
||||||
|
|
||||||
|
|
@ -111,38 +100,25 @@ def make_md_req_url(config, scope: str, url):
|
||||||
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
|
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
|
||||||
|
|
||||||
|
|
||||||
def proxify(url, *, config=None) -> str:
|
def proxify(url) -> str:
|
||||||
"""Return a mediaproxy url for the given EmbedURL. Returns an
|
"""Return a mediaproxy url for the given EmbedURL. Returns an
|
||||||
/img/ scope."""
|
/img/ scope."""
|
||||||
config, _sess = _mk_cfg_sess(config, False)
|
|
||||||
|
|
||||||
if isinstance(url, str):
|
if isinstance(url, str):
|
||||||
url = EmbedURL(url)
|
url = EmbedURL(url)
|
||||||
|
|
||||||
return make_md_req_url(config, "img", url)
|
return make_md_req_url("img", url)
|
||||||
|
|
||||||
|
|
||||||
async def _md_client_req(
|
async def _md_client_req(
|
||||||
config, session, scope: str, url, *, ret_resp=False
|
scope: str, url, *, ret_resp=False
|
||||||
) -> Optional[Union[Tuple, Dict]]:
|
) -> Optional[Union[Tuple, Dict]]:
|
||||||
"""Makes a request to the mediaproxy.
|
"""Makes a request to the mediaproxy.
|
||||||
|
|
||||||
This has common code between all the main mediaproxy request functions
|
This has common code between all the main mediaproxy request functions
|
||||||
to decrease code repetition.
|
to decrease code repetition.
|
||||||
|
|
||||||
Note that config and session exist because there are cases where the app
|
|
||||||
isn't retrievable (as those functions usually run in background tasks,
|
|
||||||
not in the app itself).
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config: dict-like
|
|
||||||
the app configuration, if None, this will get the global one from the
|
|
||||||
app instance.
|
|
||||||
session: aiohttp client session
|
|
||||||
the aiohttp ClientSession instance to use, if None, this will get
|
|
||||||
the global one from the app.
|
|
||||||
|
|
||||||
scope: str
|
scope: str
|
||||||
the scope of your request. one of 'meta', 'img', or 'embed' are
|
the scope of your request. one of 'meta', 'img', or 'embed' are
|
||||||
available for the mediaproxy's API.
|
available for the mediaproxy's API.
|
||||||
|
|
@ -155,14 +131,12 @@ async def _md_client_req(
|
||||||
the raw bytes of the response, but by the time this function is
|
the raw bytes of the response, but by the time this function is
|
||||||
returned, the response object is invalid and the socket is closed
|
returned, the response object is invalid and the socket is closed
|
||||||
"""
|
"""
|
||||||
config, session = _mk_cfg_sess(config, session)
|
|
||||||
|
|
||||||
if not isinstance(url, EmbedURL):
|
if not isinstance(url, EmbedURL):
|
||||||
url = EmbedURL(url)
|
url = EmbedURL(url)
|
||||||
|
|
||||||
request_url = make_md_req_url(config, scope, url)
|
request_url = make_md_req_url(scope, url)
|
||||||
|
|
||||||
async with session.get(request_url) as resp:
|
async with app.session.get(request_url) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
if ret_resp:
|
if ret_resp:
|
||||||
return resp, await resp.read()
|
return resp, await resp.read()
|
||||||
|
|
@ -174,18 +148,18 @@ async def _md_client_req(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
|
async def fetch_metadata(url) -> Optional[Dict]:
|
||||||
"""Fetch metadata for a url (image width, mime, etc)."""
|
"""Fetch metadata for a url (image width, mime, etc)."""
|
||||||
return await _md_client_req(config, session, "meta", url)
|
return await _md_client_req("meta", url)
|
||||||
|
|
||||||
|
|
||||||
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
async def fetch_raw_img(url) -> Optional[tuple]:
|
||||||
"""Fetch raw data for a url (the bytes given off, used to proxy images).
|
"""Fetch raw data for a url (the bytes given off, used to proxy images).
|
||||||
|
|
||||||
Returns a tuple containing the response object and the raw bytes given by
|
Returns a tuple containing the response object and the raw bytes given by
|
||||||
the website.
|
the website.
|
||||||
"""
|
"""
|
||||||
tup = await _md_client_req(config, session, "img", url, ret_resp=True)
|
tup = await _md_client_req("img", url, ret_resp=True)
|
||||||
|
|
||||||
if not tup:
|
if not tup:
|
||||||
return None
|
return None
|
||||||
|
|
@ -193,13 +167,13 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
||||||
return tup
|
return tup
|
||||||
|
|
||||||
|
|
||||||
async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
|
async def fetch_embed(url) -> Dict[str, Any]:
|
||||||
"""Fetch an embed for a given webpage (an automatically generated embed
|
"""Fetch an embed for a given webpage (an automatically generated embed
|
||||||
by the mediaproxy, look over the project on how it generates embeds).
|
by the mediaproxy, look over the project on how it generates embeds).
|
||||||
|
|
||||||
Returns a discord embed object.
|
Returns a discord embed object.
|
||||||
"""
|
"""
|
||||||
return await _md_client_req(config, session, "embed", url)
|
return await _md_client_req("embed", url)
|
||||||
|
|
||||||
|
|
||||||
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
||||||
|
|
|
||||||
|
|
@ -54,27 +54,21 @@ class Flags:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init_subclass__(cls, **_kwargs):
|
def __init_subclass__(cls, **_kwargs):
|
||||||
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
|
# get only the members that represent a field
|
||||||
|
cls._attrs = inspect.getmembers(cls, lambda x: isinstance(x, int))
|
||||||
|
|
||||||
def _make_int(value):
|
@classmethod
|
||||||
res = Flags()
|
def from_int(cls, value: int):
|
||||||
|
"""Create a Flags from a given int value."""
|
||||||
|
res = Flags()
|
||||||
|
setattr(res, "value", value)
|
||||||
|
|
||||||
setattr(res, "value", value)
|
for attr, val in cls._attrs:
|
||||||
|
has_attr = (value & val) == val
|
||||||
|
# set attributes dynamically
|
||||||
|
setattr(res, f"is_{attr}", has_attr)
|
||||||
|
|
||||||
for attr, val in attrs:
|
return res
|
||||||
# get only the ones that represent a field in the
|
|
||||||
# number's bits
|
|
||||||
if not isinstance(val, int):
|
|
||||||
continue
|
|
||||||
|
|
||||||
has_attr = (value & val) == val
|
|
||||||
|
|
||||||
# set each attribute
|
|
||||||
setattr(res, f"is_{attr}", has_attr)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
cls.from_int = _make_int
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelType(EasyEnum):
|
class ChannelType(EasyEnum):
|
||||||
|
|
|
||||||
|
|
@ -116,6 +116,10 @@ class Forbidden(LitecordError):
|
||||||
status_code = 403
|
status_code = 403
|
||||||
|
|
||||||
|
|
||||||
|
class ForbiddenDM(Forbidden):
|
||||||
|
error_code = 50007
|
||||||
|
|
||||||
|
|
||||||
class NotFound(LitecordError):
|
class NotFound(LitecordError):
|
||||||
status_code = 404
|
status_code = 404
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ import collections
|
||||||
import asyncio
|
import asyncio
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Iterable
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
@ -56,41 +56,15 @@ WebsocketProperties = collections.namedtuple(
|
||||||
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
|
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
|
||||||
"WebsocketObjects",
|
|
||||||
(
|
|
||||||
"db",
|
|
||||||
"state_manager",
|
|
||||||
"storage",
|
|
||||||
"loop",
|
|
||||||
"dispatcher",
|
|
||||||
"presence",
|
|
||||||
"ratelimiter",
|
|
||||||
"user_storage",
|
|
||||||
"voice",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
|
|
||||||
def __init__(self, ws, app, **kwargs):
|
def __init__(self, ws, app, **kwargs):
|
||||||
self.ext = WebsocketObjects(
|
self.app = app
|
||||||
app.db,
|
self.storage = app.storage
|
||||||
app.state_manager,
|
self.user_storage = app.user_storage
|
||||||
app.storage,
|
self.presence = app.presence
|
||||||
app.loop,
|
|
||||||
app.dispatcher,
|
|
||||||
app.presence,
|
|
||||||
app.ratelimiter,
|
|
||||||
app.user_storage,
|
|
||||||
app.voice,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.storage = self.ext.storage
|
|
||||||
self.user_storage = self.ext.user_storage
|
|
||||||
self.presence = self.ext.presence
|
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
self.wsp = WebsocketProperties(
|
self.wsp = WebsocketProperties(
|
||||||
|
|
@ -225,7 +199,7 @@ class GatewayWebsocket:
|
||||||
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||||
|
|
||||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
def _check_ratelimit(self, key: str, ratelimit_key):
|
||||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
|
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
|
||||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
return bucket.update_rate_limit()
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
|
|
@ -245,7 +219,7 @@ class GatewayWebsocket:
|
||||||
if task:
|
if task:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
|
self.wsp.tasks["heartbeat"] = self.app.loop.create_task(
|
||||||
task_wrapper("hb wait", self._hb_wait(interval))
|
task_wrapper("hb wait", self._hb_wait(interval))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -330,7 +304,7 @@ class GatewayWebsocket:
|
||||||
if r["type"] == RelationshipType.FRIEND.value
|
if r["type"] == RelationshipType.FRIEND.value
|
||||||
]
|
]
|
||||||
|
|
||||||
friend_presences = await self.ext.presence.friend_presences(friend_ids)
|
friend_presences = await self.app.presence.friend_presences(friend_ids)
|
||||||
settings = await self.user_storage.get_user_settings(user_id)
|
settings = await self.user_storage.get_user_settings(user_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -377,14 +351,14 @@ class GatewayWebsocket:
|
||||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||||
|
|
||||||
# async dispatch of guilds
|
# async dispatch of guilds
|
||||||
self.ext.loop.create_task(self._guild_dispatch(guilds))
|
self.app.loop.create_task(self._guild_dispatch(guilds))
|
||||||
|
|
||||||
async def _check_shards(self, shard, user_id):
|
async def _check_shards(self, shard, user_id):
|
||||||
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
"""Check if the given `shard` value in IDENTIFY has good enough values.
|
||||||
"""
|
"""
|
||||||
current_shard, shard_count = shard
|
current_shard, shard_count = shard
|
||||||
|
|
||||||
guilds = await self.ext.db.fetchval(
|
guilds = await self.app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM members
|
FROM members
|
||||||
|
|
@ -460,7 +434,7 @@ class GatewayWebsocket:
|
||||||
("channel", gdm_ids),
|
("channel", gdm_ids),
|
||||||
]
|
]
|
||||||
|
|
||||||
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
await self.app.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||||
|
|
||||||
if not self.state.bot:
|
if not self.state.bot:
|
||||||
# subscribe to all friends
|
# subscribe to all friends
|
||||||
|
|
@ -468,7 +442,7 @@ class GatewayWebsocket:
|
||||||
# when they come online)
|
# when they come online)
|
||||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||||
log.info("subscribing to {} friends", len(friend_ids))
|
log.info("subscribing to {} friends", len(friend_ids))
|
||||||
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
|
await self.app.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
|
|
@ -520,7 +494,7 @@ class GatewayWebsocket:
|
||||||
f'Updating presence status={status["status"]} for '
|
f'Updating presence status={status["status"]} for '
|
||||||
f"uid={self.state.user_id}"
|
f"uid={self.state.user_id}"
|
||||||
)
|
)
|
||||||
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
await self.app.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
||||||
|
|
||||||
async def handle_1(self, payload: Dict[str, Any]):
|
async def handle_1(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 1 Heartbeat packets."""
|
"""Handle OP 1 Heartbeat packets."""
|
||||||
|
|
@ -558,13 +532,13 @@ class GatewayWebsocket:
|
||||||
presence = data.get("presence")
|
presence = data.get("presence")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.app.db)
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, "Authentication failed")
|
raise WebsocketClose(4004, "Authentication failed")
|
||||||
|
|
||||||
await self._connect_ratelimit(user_id)
|
await self._connect_ratelimit(user_id)
|
||||||
|
|
||||||
bot = await self.ext.db.fetchval(
|
bot = await self.app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
SELECT bot FROM users
|
SELECT bot FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -587,7 +561,7 @@ class GatewayWebsocket:
|
||||||
)
|
)
|
||||||
|
|
||||||
# link the state to the user
|
# link the state to the user
|
||||||
self.ext.state_manager.insert(self.state)
|
self.app.state_manager.insert(self.state)
|
||||||
|
|
||||||
await self.update_status(presence)
|
await self.update_status(presence)
|
||||||
await self.subscribe_all(data.get("guild_subscriptions", True))
|
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||||
|
|
@ -631,12 +605,12 @@ class GatewayWebsocket:
|
||||||
# if its null and null, disconnect the user from any voice
|
# if its null and null, disconnect the user from any voice
|
||||||
# TODO: maybe just leave from DMs? idk...
|
# TODO: maybe just leave from DMs? idk...
|
||||||
if channel_id is None and guild_id is None:
|
if channel_id is None and guild_id is None:
|
||||||
return await self.ext.voice.leave_all(self.state.user_id)
|
return await self.app.voice.leave_all(self.state.user_id)
|
||||||
|
|
||||||
# if guild is not none but channel is, we are leaving
|
# if guild is not none but channel is, we are leaving
|
||||||
# a guild's channel
|
# a guild's channel
|
||||||
if channel_id is None:
|
if channel_id is None:
|
||||||
return await self.ext.voice.leave(guild_id, self.state.user_id)
|
return await self.app.voice.leave(guild_id, self.state.user_id)
|
||||||
|
|
||||||
# fetch an existing state given user and guild OR user and channel
|
# fetch an existing state given user and guild OR user and channel
|
||||||
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
|
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
|
||||||
|
|
@ -659,10 +633,10 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
# this state id format takes care of that.
|
# this state id format takes care of that.
|
||||||
voice_key = (self.state.user_id, state_id2)
|
voice_key = (self.state.user_id, state_id2)
|
||||||
voice_state = await self.ext.voice.get_state(voice_key)
|
voice_state = await self.app.voice.get_state(voice_key)
|
||||||
|
|
||||||
if voice_state is None:
|
if voice_state is None:
|
||||||
return await self.ext.voice.create_state(voice_key, data)
|
return await self.app.voice.create_state(voice_key, data)
|
||||||
|
|
||||||
same_guild = guild_id == voice_state.guild_id
|
same_guild = guild_id == voice_state.guild_id
|
||||||
same_channel = channel_id == voice_state.channel_id
|
same_channel = channel_id == voice_state.channel_id
|
||||||
|
|
@ -670,10 +644,10 @@ class GatewayWebsocket:
|
||||||
prop = await self._vsu_get_prop(voice_state, data)
|
prop = await self._vsu_get_prop(voice_state, data)
|
||||||
|
|
||||||
if same_guild and same_channel:
|
if same_guild and same_channel:
|
||||||
return await self.ext.voice.update_state(voice_state, prop)
|
return await self.app.voice.update_state(voice_state, prop)
|
||||||
|
|
||||||
if same_guild and not same_channel:
|
if same_guild and not same_channel:
|
||||||
return await self.ext.voice.move_state(voice_state, channel_id)
|
return await self.app.voice.move_state(voice_state, channel_id)
|
||||||
|
|
||||||
async def _handle_5(self, payload: Dict[str, Any]):
|
async def _handle_5(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 5 Voice Server Ping.
|
"""Handle OP 5 Voice Server Ping.
|
||||||
|
|
@ -698,9 +672,9 @@ class GatewayWebsocket:
|
||||||
# since the state will be removed from
|
# since the state will be removed from
|
||||||
# the manager, it will become unreachable
|
# the manager, it will become unreachable
|
||||||
# when trying to resume.
|
# when trying to resume.
|
||||||
self.ext.state_manager.remove(self.state)
|
self.app.state_manager.remove(self.state)
|
||||||
|
|
||||||
async def _resume(self, replay_seqs: iter):
|
async def _resume(self, replay_seqs: Iterable):
|
||||||
presences = []
|
presences = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -740,12 +714,12 @@ class GatewayWebsocket:
|
||||||
raise DecodeError("Invalid resume payload")
|
raise DecodeError("Invalid resume payload")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.app.db)
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, "Invalid token")
|
raise WebsocketClose(4004, "Invalid token")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state = self.ext.state_manager.fetch(user_id, sess_id)
|
state = self.app.state_manager.fetch(user_id, sess_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return await self.invalidate_session(False)
|
return await self.invalidate_session(False)
|
||||||
|
|
||||||
|
|
@ -948,7 +922,7 @@ class GatewayWebsocket:
|
||||||
log.debug("lazy request: members: {}", data.get("members", []))
|
log.debug("lazy request: members: {}", data.get("members", []))
|
||||||
|
|
||||||
# make shard query
|
# make shard query
|
||||||
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
|
lazy_guilds = self.app.dispatcher.backends["lazy_guild"]
|
||||||
|
|
||||||
for chan_id, ranges in data.get("channels", {}).items():
|
for chan_id, ranges in data.get("channels", {}).items():
|
||||||
chan_id = int(chan_id)
|
chan_id = int(chan_id)
|
||||||
|
|
@ -992,10 +966,10 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
# close anyone trying to login while the
|
# close anyone trying to login while the
|
||||||
# server is shutting down
|
# server is shutting down
|
||||||
if self.ext.state_manager.closed:
|
if self.app.state_manager.closed:
|
||||||
raise WebsocketClose(4000, "state manager closed")
|
raise WebsocketClose(4000, "state manager closed")
|
||||||
|
|
||||||
if not self.ext.state_manager.accept_new:
|
if not self.app.state_manager.accept_new:
|
||||||
raise WebsocketClose(4000, "state manager closed for new")
|
raise WebsocketClose(4000, "state manager closed for new")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -1016,7 +990,7 @@ class GatewayWebsocket:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
self.ext.state_manager.remove(self.state)
|
self.app.state_manager.remove(self.state)
|
||||||
self.state.ws = None
|
self.state.ws = None
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
|
|
@ -1031,14 +1005,14 @@ class GatewayWebsocket:
|
||||||
# TODO: account for sharding
|
# TODO: account for sharding
|
||||||
# this only updates status to offline once
|
# this only updates status to offline once
|
||||||
# ALL shards have come offline
|
# ALL shards have come offline
|
||||||
states = self.ext.state_manager.user_states(user_id)
|
states = self.app.state_manager.user_states(user_id)
|
||||||
with_ws = [s for s in states if s.ws]
|
with_ws = [s for s in states if s.ws]
|
||||||
|
|
||||||
# there arent any other states with websocket
|
# there arent any other states with websocket
|
||||||
if not with_ws:
|
if not with_ws:
|
||||||
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||||
|
|
||||||
await self.ext.presence.dispatch_pres(user_id, offline)
|
await self.app.presence.dispatch_pres(user_id, offline)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Wrap :meth:`listen_messages` inside
|
"""Wrap :meth:`listen_messages` inside
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from quart.ctx import copy_current_app_context
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -47,9 +49,14 @@ class JobManager:
|
||||||
|
|
||||||
def spawn(self, coro):
|
def spawn(self, coro):
|
||||||
"""Spawn a given future or coroutine in the background."""
|
"""Spawn a given future or coroutine in the background."""
|
||||||
task = self.loop.create_task(self._wrapper(coro))
|
|
||||||
|
|
||||||
|
@copy_current_app_context
|
||||||
|
async def _ctx_wrapper_bg() -> Any:
|
||||||
|
return await coro
|
||||||
|
|
||||||
|
task = self.loop.create_task(self._wrapper(_ctx_wrapper_bg()))
|
||||||
self.jobs.append(task)
|
self.jobs.append(task)
|
||||||
|
return task
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Close the job manager, cancelling all existing jobs.
|
"""Close the job manager, cancelling all existing jobs.
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import secrets
|
||||||
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
|
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart.json import JSONEncoder
|
from quart.json import JSONEncoder
|
||||||
from quart import current_app as app
|
from quart import current_app as app, request
|
||||||
|
|
||||||
|
from .errors import BadRequest
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -233,3 +236,57 @@ def maybe_int(val: Any) -> Union[int, Any]:
|
||||||
return int(val)
|
return int(val)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
async def maybe_lazy_guild_dispatch(
|
||||||
|
guild_id: int, event: str, role, force: bool = False
|
||||||
|
):
|
||||||
|
# sometimes we want to dispatch an event
|
||||||
|
# even if the role isn't hoisted
|
||||||
|
|
||||||
|
# an example of such a case is when a role loses
|
||||||
|
# its hoist status.
|
||||||
|
|
||||||
|
# check if is a dict first because role_delete
|
||||||
|
# only receives the role id.
|
||||||
|
if isinstance(role, dict) and not role["hoist"] and not force:
|
||||||
|
return
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||||
|
"""Extract a limit kwarg."""
|
||||||
|
try:
|
||||||
|
limit = int(request_.args.get("limit", default))
|
||||||
|
|
||||||
|
if limit not in range(0, max_val + 1):
|
||||||
|
raise ValueError()
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise BadRequest("limit not int")
|
||||||
|
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
||||||
|
"""Extract a 2-tuple out of request arguments."""
|
||||||
|
before, after = None, None
|
||||||
|
|
||||||
|
if "around" in request.args:
|
||||||
|
average = int(limit / 2)
|
||||||
|
around = int(args["around"])
|
||||||
|
|
||||||
|
after = around - average
|
||||||
|
before = around + average
|
||||||
|
|
||||||
|
elif "before" in args:
|
||||||
|
before = int(args["before"])
|
||||||
|
elif "after" in args:
|
||||||
|
before = int(args["after"])
|
||||||
|
|
||||||
|
return before, after
|
||||||
|
|
||||||
|
|
||||||
|
def rand_hex(length: int = 8) -> str:
|
||||||
|
"""Generate random hex characters."""
|
||||||
|
return secrets.token_hex(length)[:length]
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from litecord.auth import create_user
|
from litecord.common.users import create_user, delete_user
|
||||||
from litecord.blueprints.auth import make_token
|
from litecord.blueprints.auth import make_token
|
||||||
from litecord.blueprints.users import delete_user
|
|
||||||
from litecord.enums import UserFlags
|
from litecord.enums import UserFlags
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
4
run.py
4
run.py
|
|
@ -337,9 +337,9 @@ async def api_index(app_):
|
||||||
|
|
||||||
async def post_app_start(app_):
|
async def post_app_start(app_):
|
||||||
# we'll need to start a billing job
|
# we'll need to start a billing job
|
||||||
app_.sched.spawn(payment_job(app_))
|
app_.sched.spawn(payment_job())
|
||||||
app_.sched.spawn(api_index(app_))
|
app_.sched.spawn(api_index(app_))
|
||||||
app_.sched.spawn(guild_region_check(app_))
|
app_.sched.spawn(guild_region_check())
|
||||||
|
|
||||||
|
|
||||||
def start_websocket(host, port, ws_handler) -> asyncio.Future:
|
def start_websocket(host, port, ws_handler) -> asyncio.Future:
|
||||||
|
|
|
||||||
|
|
@ -30,10 +30,9 @@ from tests.common import email, TestClient
|
||||||
|
|
||||||
from run import app as main_app, set_blueprints
|
from run import app as main_app, set_blueprints
|
||||||
|
|
||||||
from litecord.auth import create_user
|
from litecord.common.users import create_user, delete_user
|
||||||
from litecord.enums import UserFlags
|
from litecord.enums import UserFlags
|
||||||
from litecord.blueprints.auth import make_token
|
from litecord.blueprints.auth import make_token
|
||||||
from litecord.blueprints.users import delete_user
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="app")
|
@pytest.fixture(name="app")
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,11 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
||||||
return rjson
|
return rjson
|
||||||
|
|
||||||
|
|
||||||
|
async def _delete_guild(test_cli, guild_id: int):
|
||||||
|
async with test_cli.app.app_context():
|
||||||
|
await delete_guild(int(guild_id))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_guild_fetch(test_cli_staff):
|
async def test_guild_fetch(test_cli_staff):
|
||||||
"""Test the creation and fetching of a guild via the Admin API."""
|
"""Test the creation and fetching of a guild via the Admin API."""
|
||||||
|
|
@ -63,7 +68,7 @@ async def test_guild_fetch(test_cli_staff):
|
||||||
try:
|
try:
|
||||||
await _fetch_guild(test_cli_staff, guild_id)
|
await _fetch_guild(test_cli_staff, guild_id)
|
||||||
finally:
|
finally:
|
||||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
await _delete_guild(test_cli_staff, int(guild_id))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -91,7 +96,7 @@ async def test_guild_update(test_cli_staff):
|
||||||
rjson = await _fetch_guild(test_cli_staff, guild_id)
|
rjson = await _fetch_guild(test_cli_staff, guild_id)
|
||||||
assert rjson["unavailable"]
|
assert rjson["unavailable"]
|
||||||
finally:
|
finally:
|
||||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
await _delete_guild(test_cli_staff, int(guild_id))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -113,4 +118,4 @@ async def test_guild_delete(test_cli_staff):
|
||||||
assert rjson["error"]
|
assert rjson["error"]
|
||||||
assert rjson["code"] == GuildNotFound.error_code
|
assert rjson["code"] == GuildNotFound.error_code
|
||||||
finally:
|
finally:
|
||||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
await _delete_guild(test_cli_staff, int(guild_id))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue