mirror of https://gitlab.com/litecord/litecord.git
black fmt pass
This commit is contained in:
parent
0bc4b1ba3f
commit
83a1c1ae29
25
config.ci.py
25
config.ci.py
|
|
@ -17,13 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
MODE = 'CI'
|
||||
MODE = "CI"
|
||||
|
||||
|
||||
class Config:
|
||||
"""Default configuration values for litecord."""
|
||||
MAIN_URL = 'localhost:1'
|
||||
NAME = 'gitlab ci'
|
||||
|
||||
MAIN_URL = "localhost:1"
|
||||
NAME = "gitlab ci"
|
||||
|
||||
# Enable debug logging?
|
||||
DEBUG = False
|
||||
|
|
@ -37,11 +38,11 @@ class Config:
|
|||
# Set this url to somewhere *your users*
|
||||
# will hit the websocket.
|
||||
# e.g 'gateway.example.com' for reverse proxies.
|
||||
WEBSOCKET_URL = 'localhost:5001'
|
||||
WEBSOCKET_URL = "localhost:5001"
|
||||
|
||||
# Where to host the websocket?
|
||||
# (a local address the server will bind to)
|
||||
WS_HOST = 'localhost'
|
||||
WS_HOST = "localhost"
|
||||
WS_PORT = 5001
|
||||
|
||||
# Postgres credentials
|
||||
|
|
@ -51,10 +52,10 @@ class Config:
|
|||
class Development(Config):
|
||||
DEBUG = True
|
||||
POSTGRES = {
|
||||
'host': 'localhost',
|
||||
'user': 'litecord',
|
||||
'password': '123',
|
||||
'database': 'litecord',
|
||||
"host": "localhost",
|
||||
"user": "litecord",
|
||||
"password": "123",
|
||||
"database": "litecord",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -66,8 +67,4 @@ class Production(Config):
|
|||
class CI(Config):
|
||||
DEBUG = True
|
||||
|
||||
POSTGRES = {
|
||||
'host': 'postgres',
|
||||
'user': 'postgres',
|
||||
'password': ''
|
||||
}
|
||||
POSTGRES = {"host": "postgres", "user": "postgres", "password": ""}
|
||||
|
|
|
|||
|
|
@ -17,16 +17,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
MODE = 'Development'
|
||||
MODE = "Development"
|
||||
|
||||
|
||||
class Config:
|
||||
"""Default configuration values for litecord."""
|
||||
|
||||
#: Main URL of the instance.
|
||||
MAIN_URL = 'discordapp.io'
|
||||
MAIN_URL = "discordapp.io"
|
||||
|
||||
#: Name of the instance
|
||||
NAME = 'Litecord/Nya'
|
||||
NAME = "Litecord/Nya"
|
||||
|
||||
#: Enable debug logging?
|
||||
DEBUG = False
|
||||
|
|
@ -45,17 +46,17 @@ class Config:
|
|||
# Set this url to somewhere *your users*
|
||||
# will hit the websocket.
|
||||
# e.g 'gateway.example.com' for reverse proxies.
|
||||
WEBSOCKET_URL = 'localhost:5001'
|
||||
WEBSOCKET_URL = "localhost:5001"
|
||||
|
||||
#: Where to host the websocket?
|
||||
# (a local address the server will bind to)
|
||||
WS_HOST = '0.0.0.0'
|
||||
WS_HOST = "0.0.0.0"
|
||||
WS_PORT = 5001
|
||||
|
||||
#: Mediaproxy URL on the internet
|
||||
# mediaproxy is made to prevent client IPs being leaked.
|
||||
# None is a valid value if you don't want to deploy mediaproxy.
|
||||
MEDIA_PROXY = 'localhost:5002'
|
||||
MEDIA_PROXY = "localhost:5002"
|
||||
|
||||
#: Postgres credentials
|
||||
POSTGRES = {}
|
||||
|
|
@ -65,10 +66,10 @@ class Development(Config):
|
|||
DEBUG = True
|
||||
|
||||
POSTGRES = {
|
||||
'host': 'localhost',
|
||||
'user': 'litecord',
|
||||
'password': '123',
|
||||
'database': 'litecord',
|
||||
"host": "localhost",
|
||||
"user": "litecord",
|
||||
"password": "123",
|
||||
"database": "litecord",
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -77,8 +78,8 @@ class Production(Config):
|
|||
IS_SSL = True
|
||||
|
||||
POSTGRES = {
|
||||
'host': 'some_production_postgres',
|
||||
'user': 'some_production_user',
|
||||
'password': 'some_production_password',
|
||||
'database': 'litecord_or_anything_else_really',
|
||||
"host": "some_production_postgres",
|
||||
"user": "some_production_user",
|
||||
"password": "some_production_password",
|
||||
"database": "litecord_or_anything_else_really",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -19,42 +19,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
from litecord.enums import Feature, UserFlags
|
||||
|
||||
VOICE_SERVER = {
|
||||
'hostname': {'type': 'string', 'maxlength': 255, 'required': True}
|
||||
}
|
||||
VOICE_SERVER = {"hostname": {"type": "string", "maxlength": 255, "required": True}}
|
||||
|
||||
VOICE_REGION = {
|
||||
'id': {'type': 'string', 'maxlength': 255, 'required': True},
|
||||
'name': {'type': 'string', 'maxlength': 255, 'required': True},
|
||||
|
||||
'vip': {'type': 'boolean', 'default': False},
|
||||
'deprecated': {'type': 'boolean', 'default': False},
|
||||
'custom': {'type': 'boolean', 'default': False},
|
||||
"id": {"type": "string", "maxlength": 255, "required": True},
|
||||
"name": {"type": "string", "maxlength": 255, "required": True},
|
||||
"vip": {"type": "boolean", "default": False},
|
||||
"deprecated": {"type": "boolean", "default": False},
|
||||
"custom": {"type": "boolean", "default": False},
|
||||
}
|
||||
|
||||
FEATURES = {
|
||||
'features': {
|
||||
'type': 'list', 'required': True,
|
||||
|
||||
"features": {
|
||||
"type": "list",
|
||||
"required": True,
|
||||
# using Feature doesn't seem to work with a "not callable" error.
|
||||
'schema': {'coerce': lambda x: Feature(x)}
|
||||
"schema": {"coerce": lambda x: Feature(x)},
|
||||
}
|
||||
}
|
||||
|
||||
USER_CREATE = {
|
||||
'username': {'type': 'username', 'required': True},
|
||||
'email': {'type': 'email', 'required': True},
|
||||
'password': {'type': 'string', 'minlength': 5, 'required': True},
|
||||
"username": {"type": "username", "required": True},
|
||||
"email": {"type": "email", "required": True},
|
||||
"password": {"type": "string", "minlength": 5, "required": True},
|
||||
}
|
||||
|
||||
INSTANCE_INVITE = {
|
||||
'max_uses': {'type': 'integer', 'required': True}
|
||||
}
|
||||
INSTANCE_INVITE = {"max_uses": {"type": "integer", "required": True}}
|
||||
|
||||
GUILD_UPDATE = {
|
||||
'unavailable': {'type': 'boolean', 'required': False}
|
||||
}
|
||||
GUILD_UPDATE = {"unavailable": {"type": "boolean", "required": False}}
|
||||
|
||||
USER_UPDATE = {
|
||||
'flags': {'required': False, 'coerce': UserFlags.from_int}
|
||||
}
|
||||
USER_UPDATE = {"flags": {"required": False, "coerce": UserFlags.from_int}}
|
||||
|
|
|
|||
100
litecord/auth.py
100
litecord/auth.py
|
|
@ -55,44 +55,50 @@ async def raw_token_check(token: str, db=None) -> int:
|
|||
|
||||
# just try by fragments instead of
|
||||
# unpacking
|
||||
fragments = token.split('.')
|
||||
fragments = token.split(".")
|
||||
user_id = fragments[0]
|
||||
|
||||
try:
|
||||
user_id = base64.b64decode(user_id.encode())
|
||||
user_id = int(user_id)
|
||||
except (ValueError, binascii.Error):
|
||||
raise Unauthorized('Invalid user ID type')
|
||||
raise Unauthorized("Invalid user ID type")
|
||||
|
||||
pwd_hash = await db.fetchval("""
|
||||
pwd_hash = await db.fetchval(
|
||||
"""
|
||||
SELECT password_hash
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if not pwd_hash:
|
||||
raise Unauthorized('User ID not found')
|
||||
raise Unauthorized("User ID not found")
|
||||
|
||||
signer = TimestampSigner(pwd_hash)
|
||||
|
||||
try:
|
||||
signer.unsign(token)
|
||||
log.debug('login for uid {} successful', user_id)
|
||||
log.debug("login for uid {} successful", user_id)
|
||||
|
||||
# update the user's last_session field
|
||||
# so that we can keep an exact track of activity,
|
||||
# even on long-lived single sessions (that can happen
|
||||
# with people leaving their clients open forever)
|
||||
await db.execute("""
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET last_session = (now() at time zone 'utc')
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return user_id
|
||||
except BadSignature:
|
||||
log.warning('token failed for uid {}', user_id)
|
||||
raise Forbidden('Invalid token')
|
||||
log.warning("token failed for uid {}", user_id)
|
||||
raise Forbidden("Invalid token")
|
||||
|
||||
|
||||
async def token_check() -> int:
|
||||
|
|
@ -104,12 +110,12 @@ async def token_check() -> int:
|
|||
pass
|
||||
|
||||
try:
|
||||
token = request.headers['Authorization']
|
||||
token = request.headers["Authorization"]
|
||||
except KeyError:
|
||||
raise Unauthorized('No token provided')
|
||||
raise Unauthorized("No token provided")
|
||||
|
||||
if token.startswith('Bot '):
|
||||
token = token.replace('Bot ', '')
|
||||
if token.startswith("Bot "):
|
||||
token = token.replace("Bot ", "")
|
||||
|
||||
user_id = await raw_token_check(token)
|
||||
request.user_id = user_id
|
||||
|
|
@ -120,15 +126,18 @@ async def admin_check() -> int:
|
|||
"""Check if the user is an admin."""
|
||||
user_id = await token_check()
|
||||
|
||||
flags = await app.db.fetchval("""
|
||||
flags = await app.db.fetchval(
|
||||
"""
|
||||
SELECT flags
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
flags = UserFlags.from_int(flags)
|
||||
if not flags.is_staff:
|
||||
raise Unauthorized('you are not staff')
|
||||
raise Unauthorized("you are not staff")
|
||||
|
||||
return user_id
|
||||
|
||||
|
|
@ -138,9 +147,7 @@ async def hash_data(data: str, loop=None) -> str:
|
|||
loop = loop or app.loop
|
||||
buf = data.encode()
|
||||
|
||||
hashed = await loop.run_in_executor(
|
||||
None, bcrypt.hashpw, buf, bcrypt.gensalt(14)
|
||||
)
|
||||
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
|
||||
|
||||
return hashed.decode()
|
||||
|
||||
|
|
@ -148,22 +155,28 @@ async def hash_data(data: str, loop=None) -> str:
|
|||
async def check_username_usage(username: str, db=None):
|
||||
"""Raise an error if too many people are with the same username."""
|
||||
db = db or app.db
|
||||
same_username = await db.fetchval("""
|
||||
same_username = await db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
""", username)
|
||||
""",
|
||||
username,
|
||||
)
|
||||
|
||||
if same_username > 9000:
|
||||
raise BadRequest('Too many people.', {
|
||||
'username': 'Too many people used the same username. '
|
||||
'Please choose another'
|
||||
})
|
||||
raise BadRequest(
|
||||
"Too many people.",
|
||||
{
|
||||
"username": "Too many people used the same username. "
|
||||
"Please choose another"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _raw_discrim() -> str:
|
||||
new_discrim = randint(1, 9999)
|
||||
new_discrim = '%04d' % new_discrim
|
||||
new_discrim = "%04d" % new_discrim
|
||||
return new_discrim
|
||||
|
||||
|
||||
|
|
@ -186,11 +199,15 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
|
|||
discrim = _raw_discrim()
|
||||
|
||||
# check if anyone is with it
|
||||
res = await db.fetchval("""
|
||||
res = await db.fetchval(
|
||||
"""
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE username = $1 AND discriminator = $2
|
||||
""", username, discrim)
|
||||
""",
|
||||
username,
|
||||
discrim,
|
||||
)
|
||||
|
||||
# if no user is found with the (username, discrim)
|
||||
# pair, then this is unique! return it.
|
||||
|
|
@ -200,8 +217,9 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
async def create_user(username: str, email: str, password: str,
|
||||
db=None, loop=None) -> Tuple[int, str]:
|
||||
async def create_user(
|
||||
username: str, email: str, password: str, db=None, loop=None
|
||||
) -> Tuple[int, str]:
|
||||
"""Create a single user.
|
||||
|
||||
Generates a distriminator and other information. You can fetch the user
|
||||
|
|
@ -214,20 +232,28 @@ async def create_user(username: str, email: str, password: str,
|
|||
new_discrim = await roll_discrim(username, db=db)
|
||||
|
||||
if new_discrim is None:
|
||||
raise BadRequest('Unable to register.', {
|
||||
'username': 'Too many people are with this username.'
|
||||
})
|
||||
raise BadRequest(
|
||||
"Unable to register.",
|
||||
{"username": "Too many people are with this username."},
|
||||
)
|
||||
|
||||
pwd_hash = await hash_data(password, loop)
|
||||
|
||||
try:
|
||||
await db.execute("""
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO users
|
||||
(id, email, username, discriminator, password_hash)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5)
|
||||
""", new_id, email, username, new_discrim, pwd_hash)
|
||||
""",
|
||||
new_id,
|
||||
email,
|
||||
username,
|
||||
new_discrim,
|
||||
pwd_hash,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise BadRequest('Email already used.')
|
||||
raise BadRequest("Email already used.")
|
||||
|
||||
return new_id, pwd_hash
|
||||
|
|
|
|||
|
|
@ -34,7 +34,21 @@ from .static import bp as static
|
|||
from .attachments import bp as attachments
|
||||
from .dm_channels import bp as dm_channels
|
||||
|
||||
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
|
||||
'webhooks', 'science', 'voice', 'invites', 'relationships',
|
||||
'dms', 'icons', 'nodeinfo', 'static', 'attachments',
|
||||
'dm_channels']
|
||||
__all__ = [
|
||||
"gateway",
|
||||
"auth",
|
||||
"users",
|
||||
"guilds",
|
||||
"channels",
|
||||
"webhooks",
|
||||
"science",
|
||||
"voice",
|
||||
"invites",
|
||||
"relationships",
|
||||
"dms",
|
||||
"icons",
|
||||
"nodeinfo",
|
||||
"static",
|
||||
"attachments",
|
||||
"dm_channels",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -23,4 +23,4 @@ from .guilds import bp as guilds
|
|||
from .users import bp as users
|
||||
from .instance_invites import bp as instance_invites
|
||||
|
||||
__all__ = ['voice', 'features', 'guilds', 'users', 'instance_invites']
|
||||
__all__ = ["voice", "features", "guilds", "users", "instance_invites"]
|
||||
|
|
|
|||
|
|
@ -25,45 +25,53 @@ from litecord.errors import BadRequest
|
|||
from litecord.schemas import validate
|
||||
from litecord.admin_schemas import FEATURES
|
||||
|
||||
bp = Blueprint('features_admin', __name__)
|
||||
bp = Blueprint("features_admin", __name__)
|
||||
|
||||
|
||||
async def _features_from_req() -> List[str]:
|
||||
j = validate(await request.get_json(), FEATURES)
|
||||
return [feature.value for feature in j['features']]
|
||||
return [feature.value for feature in j["features"]]
|
||||
|
||||
|
||||
async def _features(guild_id: int):
|
||||
return jsonify({
|
||||
'features': await app.storage.guild_features(guild_id)
|
||||
})
|
||||
return jsonify({"features": await app.storage.guild_features(guild_id)})
|
||||
|
||||
|
||||
async def _update_features(guild_id: int, features: list):
|
||||
if 'VANITY_URL' not in features:
|
||||
if "VANITY_URL" not in features:
|
||||
existing_inv = await app.storage.vanity_invite(guild_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM vanity_invites
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM invites
|
||||
WHERE code = $1
|
||||
""", existing_inv)
|
||||
""",
|
||||
existing_inv,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET features = $1
|
||||
WHERE id = $2
|
||||
""", features, guild_id)
|
||||
""",
|
||||
features,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
guild = await app.storage.get_guild_full(guild_id)
|
||||
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_UPDATE', guild)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/features', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
|
||||
async def replace_features(guild_id: int):
|
||||
"""Replace the feature list in a guild"""
|
||||
await admin_check()
|
||||
|
|
@ -76,7 +84,7 @@ async def replace_features(guild_id: int):
|
|||
return await _features(guild_id)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/features', methods=['PUT'])
|
||||
@bp.route("/<int:guild_id>/features", methods=["PUT"])
|
||||
async def insert_features(guild_id: int):
|
||||
"""Insert a feature on a guild."""
|
||||
await admin_check()
|
||||
|
|
@ -93,7 +101,7 @@ async def insert_features(guild_id: int):
|
|||
return await _features(guild_id)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/features', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>/features", methods=["DELETE"])
|
||||
async def remove_features(guild_id: int):
|
||||
"""Remove a feature from a guild"""
|
||||
await admin_check()
|
||||
|
|
@ -104,7 +112,7 @@ async def remove_features(guild_id: int):
|
|||
try:
|
||||
features.remove(feature)
|
||||
except ValueError:
|
||||
raise BadRequest('Trying to remove already removed feature.')
|
||||
raise BadRequest("Trying to remove already removed feature.")
|
||||
|
||||
await _update_features(guild_id, features)
|
||||
return await _features(guild_id)
|
||||
|
|
|
|||
|
|
@ -25,9 +25,10 @@ from litecord.admin_schemas import GUILD_UPDATE
|
|||
from litecord.blueprints.guilds import delete_guild
|
||||
from litecord.errors import GuildNotFound
|
||||
|
||||
bp = Blueprint('guilds_admin', __name__)
|
||||
bp = Blueprint("guilds_admin", __name__)
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['GET'])
|
||||
|
||||
@bp.route("/<int:guild_id>", methods=["GET"])
|
||||
async def get_guild(guild_id: int):
|
||||
"""Get a basic guild payload."""
|
||||
await admin_check()
|
||||
|
|
@ -40,7 +41,7 @@ async def get_guild(guild_id: int):
|
|||
return jsonify(guild)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>", methods=["PATCH"])
|
||||
async def update_guild(guild_id: int):
|
||||
await admin_check()
|
||||
|
||||
|
|
@ -48,13 +49,13 @@ async def update_guild(guild_id: int):
|
|||
|
||||
# TODO: what happens to the other guild attributes when its
|
||||
# unavailable? do they vanish?
|
||||
old_unavailable = app.guild_store.get(guild_id, 'unavailable')
|
||||
new_unavailable = j.get('unavailable', old_unavailable)
|
||||
old_unavailable = app.guild_store.get(guild_id, "unavailable")
|
||||
new_unavailable = j.get("unavailable", old_unavailable)
|
||||
|
||||
# always set unavailable status since new_unavailable will be
|
||||
# old_unavailable when not provided, so we don't need to check if
|
||||
# j.unavailable is there
|
||||
app.guild_store.set(guild_id, 'unavailable', j['unavailable'])
|
||||
app.guild_store.set(guild_id, "unavailable", j["unavailable"])
|
||||
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
|
||||
|
|
@ -62,17 +63,17 @@ async def update_guild(guild_id: int):
|
|||
|
||||
if old_unavailable and not new_unavailable:
|
||||
# guild became available
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
|
||||
else:
|
||||
# guild became unavailable
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_DELETE', guild)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
|
||||
|
||||
return jsonify(guild)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>", methods=["DELETE"])
|
||||
async def delete_guild_as_admin(guild_id):
|
||||
"""Delete a single guild via the admin API without ownership checks."""
|
||||
await admin_check()
|
||||
await delete_guild(guild_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -27,13 +27,13 @@ from litecord.types import timestamp_
|
|||
from litecord.schemas import validate
|
||||
from litecord.admin_schemas import INSTANCE_INVITE
|
||||
|
||||
bp = Blueprint('instance_invites', __name__)
|
||||
bp = Blueprint("instance_invites", __name__)
|
||||
ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
|
||||
|
||||
|
||||
async def _gen_inv() -> str:
|
||||
"""Generate an invite code"""
|
||||
return ''.join(choice(ALPHABET) for _ in range(6))
|
||||
return "".join(choice(ALPHABET) for _ in range(6))
|
||||
|
||||
|
||||
async def gen_inv(ctx) -> str:
|
||||
|
|
@ -41,11 +41,14 @@ async def gen_inv(ctx) -> str:
|
|||
for _ in range(10):
|
||||
possible_inv = await _gen_inv()
|
||||
|
||||
created_at = await ctx.db.fetchval("""
|
||||
created_at = await ctx.db.fetchval(
|
||||
"""
|
||||
SELECT created_at
|
||||
FROM instance_invites
|
||||
WHERE code = $1
|
||||
""", possible_inv)
|
||||
""",
|
||||
possible_inv,
|
||||
)
|
||||
|
||||
if created_at is None:
|
||||
return possible_inv
|
||||
|
|
@ -53,57 +56,71 @@ async def gen_inv(ctx) -> str:
|
|||
return None
|
||||
|
||||
|
||||
@bp.route('', methods=['GET'])
|
||||
@bp.route("", methods=["GET"])
|
||||
async def _all_instance_invites():
|
||||
await admin_check()
|
||||
|
||||
rows = await app.db.fetch("""
|
||||
rows = await app.db.fetch(
|
||||
"""
|
||||
SELECT code, created_at, uses, max_uses
|
||||
FROM instance_invites
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
rows = [dict(row) for row in rows]
|
||||
|
||||
for row in rows:
|
||||
row['created_at'] = timestamp_(row['created_at'])
|
||||
row["created_at"] = timestamp_(row["created_at"])
|
||||
|
||||
return jsonify(rows)
|
||||
|
||||
|
||||
@bp.route('', methods=['PUT'])
|
||||
@bp.route("", methods=["PUT"])
|
||||
async def _create_invite():
|
||||
await admin_check()
|
||||
|
||||
code = await gen_inv(app)
|
||||
if code is None:
|
||||
return 'failed to make invite', 500
|
||||
return "failed to make invite", 500
|
||||
|
||||
j = validate(await request.get_json(), INSTANCE_INVITE)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO instance_invites (code, max_uses)
|
||||
VALUES ($1, $2)
|
||||
""", code, j['max_uses'])
|
||||
""",
|
||||
code,
|
||||
j["max_uses"],
|
||||
)
|
||||
|
||||
inv = dict(await app.db.fetchrow("""
|
||||
inv = dict(
|
||||
await app.db.fetchrow(
|
||||
"""
|
||||
SELECT code, created_at, uses, max_uses
|
||||
FROM instance_invites
|
||||
WHERE code = $1
|
||||
""", code))
|
||||
""",
|
||||
code,
|
||||
)
|
||||
)
|
||||
|
||||
return jsonify(dict(inv))
|
||||
|
||||
|
||||
@bp.route('/<invite>', methods=['DELETE'])
|
||||
@bp.route("/<invite>", methods=["DELETE"])
|
||||
async def _del_invite(invite: str):
|
||||
await admin_check()
|
||||
|
||||
res = await app.db.execute("""
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
DELETE FROM instance_invites
|
||||
WHERE code = $1
|
||||
""", invite)
|
||||
""",
|
||||
invite,
|
||||
)
|
||||
|
||||
if res.lower() == 'delete 0':
|
||||
return 'invite not found', 404
|
||||
if res.lower() == "delete 0":
|
||||
return "invite not found", 404
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -25,24 +25,21 @@ from litecord.schemas import validate
|
|||
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
|
||||
from litecord.errors import BadRequest, Forbidden
|
||||
from litecord.utils import async_map
|
||||
from litecord.blueprints.users import (
|
||||
delete_user, user_disconnect, mass_user_update
|
||||
)
|
||||
from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
|
||||
from litecord.enums import UserFlags
|
||||
|
||||
bp = Blueprint('users_admin', __name__)
|
||||
bp = Blueprint("users_admin", __name__)
|
||||
|
||||
@bp.route('', methods=['POST'])
|
||||
|
||||
@bp.route("", methods=["POST"])
|
||||
async def _create_user():
|
||||
await admin_check()
|
||||
|
||||
j = validate(await request.get_json(), USER_CREATE)
|
||||
|
||||
user_id, _ = await create_user(j['username'], j['email'], j['password'])
|
||||
user_id, _ = await create_user(j["username"], j["email"], j["password"])
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_user(user_id)
|
||||
)
|
||||
return jsonify(await app.storage.get_user(user_id))
|
||||
|
||||
|
||||
def args_try(args: dict, typ, field: str, default):
|
||||
|
|
@ -51,29 +48,29 @@ def args_try(args: dict, typ, field: str, default):
|
|||
try:
|
||||
return typ(args.get(field, default))
|
||||
except (TypeError, ValueError):
|
||||
raise BadRequest(f'invalid {field} value')
|
||||
raise BadRequest(f"invalid {field} value")
|
||||
|
||||
|
||||
@bp.route('', methods=['GET'])
|
||||
@bp.route("", methods=["GET"])
|
||||
async def _search_users():
|
||||
await admin_check()
|
||||
|
||||
args = request.args
|
||||
|
||||
username, discrim = args.get('username'), args.get('discriminator')
|
||||
username, discrim = args.get("username"), args.get("discriminator")
|
||||
|
||||
per_page = args_try(args, int, 'per_page', 20)
|
||||
page = args_try(args, int, 'page', 0)
|
||||
per_page = args_try(args, int, "per_page", 20)
|
||||
page = args_try(args, int, "page", 0)
|
||||
|
||||
if page < 0:
|
||||
raise BadRequest('invalid page number')
|
||||
raise BadRequest("invalid page number")
|
||||
|
||||
if per_page > 50:
|
||||
raise BadRequest('invalid per page number')
|
||||
raise BadRequest("invalid per page number")
|
||||
|
||||
# any of those must be available.
|
||||
if not any((username, discrim)):
|
||||
raise BadRequest('must insert username or discrim')
|
||||
raise BadRequest("must insert username or discrim")
|
||||
|
||||
wheres, args = [], []
|
||||
|
||||
|
|
@ -82,29 +79,31 @@ async def _search_users():
|
|||
args.append(username)
|
||||
|
||||
if discrim:
|
||||
wheres.append(f'discriminator = ${len(args) + 2}')
|
||||
wheres.append(f"discriminator = ${len(args) + 2}")
|
||||
args.append(discrim)
|
||||
|
||||
where_tot = 'WHERE ' if args else ''
|
||||
where_tot += ' AND '.join(wheres)
|
||||
where_tot = "WHERE " if args else ""
|
||||
where_tot += " AND ".join(wheres)
|
||||
|
||||
rows = await app.db.fetch(f"""
|
||||
rows = await app.db.fetch(
|
||||
f"""
|
||||
SELECT id
|
||||
FROM users
|
||||
{where_tot}
|
||||
ORDER BY id ASC
|
||||
LIMIT {per_page}
|
||||
OFFSET ($1 * {per_page})
|
||||
""", page, *args)
|
||||
|
||||
rows = [r['id'] for r in rows]
|
||||
|
||||
return jsonify(
|
||||
await async_map(app.storage.get_user, rows)
|
||||
""",
|
||||
page,
|
||||
*args,
|
||||
)
|
||||
|
||||
rows = [r["id"] for r in rows]
|
||||
|
||||
@bp.route('/<int:user_id>', methods=['DELETE'])
|
||||
return jsonify(await async_map(app.storage.get_user, rows))
|
||||
|
||||
|
||||
@bp.route("/<int:user_id>", methods=["DELETE"])
|
||||
async def _delete_single_user(user_id: int):
|
||||
await admin_check()
|
||||
|
||||
|
|
@ -115,13 +114,10 @@ async def _delete_single_user(user_id: int):
|
|||
|
||||
new_user = await app.storage.get_user(user_id)
|
||||
|
||||
return jsonify({
|
||||
'old': old_user,
|
||||
'new': new_user
|
||||
})
|
||||
return jsonify({"old": old_user, "new": new_user})
|
||||
|
||||
|
||||
@bp.route('/<int:user_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:user_id>", methods=["PATCH"])
|
||||
async def patch_user(user_id: int):
|
||||
await admin_check()
|
||||
|
||||
|
|
@ -129,21 +125,25 @@ async def patch_user(user_id: int):
|
|||
|
||||
# get the original user for flags checking
|
||||
user = await app.storage.get_user(user_id)
|
||||
old_flags = UserFlags.from_int(user['flags'])
|
||||
old_flags = UserFlags.from_int(user["flags"])
|
||||
|
||||
# j.flags is already a UserFlags since we coerce it.
|
||||
if 'flags' in j:
|
||||
new_flags = j['flags']
|
||||
if "flags" in j:
|
||||
new_flags = j["flags"]
|
||||
|
||||
# disallow any changes to the staff badge
|
||||
if new_flags.is_staff != old_flags.is_staff:
|
||||
raise Forbidden('you can not change a users staff badge')
|
||||
raise Forbidden("you can not change a users staff badge")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET flags = $1
|
||||
WHERE id = $2
|
||||
""", new_flags.value, user_id)
|
||||
""",
|
||||
new_flags.value,
|
||||
user_id,
|
||||
)
|
||||
|
||||
public_user, _ = await mass_user_update(user_id, app)
|
||||
return jsonify(public_user)
|
||||
|
|
|
|||
|
|
@ -27,10 +27,10 @@ from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION
|
|||
from litecord.errors import BadRequest
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('voice_admin', __name__)
|
||||
bp = Blueprint("voice_admin", __name__)
|
||||
|
||||
|
||||
@bp.route('/regions/<region>', methods=['GET'])
|
||||
@bp.route("/regions/<region>", methods=["GET"])
|
||||
async def get_region_servers(region):
|
||||
"""Return a list of all servers for a region."""
|
||||
await admin_check()
|
||||
|
|
@ -38,18 +38,25 @@ async def get_region_servers(region):
|
|||
return jsonify(servers)
|
||||
|
||||
|
||||
@bp.route('/regions', methods=['PUT'])
|
||||
@bp.route("/regions", methods=["PUT"])
|
||||
async def insert_new_region():
|
||||
"""Create a voice region."""
|
||||
await admin_check()
|
||||
j = validate(await request.get_json(), VOICE_REGION)
|
||||
|
||||
j['id'] = j['id'].lower()
|
||||
j["id"] = j["id"].lower()
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO voice_regions (id, name, vip, deprecated, custom)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", j['id'], j['name'], j['vip'], j['deprecated'], j['custom'])
|
||||
""",
|
||||
j["id"],
|
||||
j["name"],
|
||||
j["vip"],
|
||||
j["deprecated"],
|
||||
j["custom"],
|
||||
)
|
||||
|
||||
regions = await app.storage.all_voice_regions()
|
||||
region_count = len(regions)
|
||||
|
|
@ -57,34 +64,41 @@ async def insert_new_region():
|
|||
# if region count is 1, this is the first region to be created,
|
||||
# so we should update all guilds to that region
|
||||
if region_count == 1:
|
||||
res = await app.db.execute("""
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET region = $1
|
||||
""", j['id'])
|
||||
""",
|
||||
j["id"],
|
||||
)
|
||||
|
||||
log.info('updating guilds to first voice region: {}', res)
|
||||
log.info("updating guilds to first voice region: {}", res)
|
||||
|
||||
return jsonify(regions)
|
||||
|
||||
|
||||
@bp.route('/regions/<region>/servers', methods=['PUT'])
|
||||
@bp.route("/regions/<region>/servers", methods=["PUT"])
|
||||
async def put_region_server(region):
|
||||
"""Insert a voice server to a region"""
|
||||
await admin_check()
|
||||
j = validate(await request.get_json(), VOICE_SERVER)
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO voice_servers (hostname, region_id)
|
||||
VALUES ($1, $2)
|
||||
""", j['hostname'], region)
|
||||
""",
|
||||
j["hostname"],
|
||||
region,
|
||||
)
|
||||
except asyncpg.UniqueViolationError:
|
||||
raise BadRequest('voice server already exists with given hostname')
|
||||
raise BadRequest("voice server already exists with given hostname")
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/regions/<region>/deprecate', methods=['PUT'])
|
||||
@bp.route("/regions/<region>/deprecate", methods=["PUT"])
|
||||
async def deprecate_region(region):
|
||||
"""Deprecate a voice region."""
|
||||
await admin_check()
|
||||
|
|
@ -92,13 +106,16 @@ async def deprecate_region(region):
|
|||
# TODO: write this
|
||||
await app.voice.disable_region(region)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE voice_regions
|
||||
SET deprecated = true
|
||||
WHERE id = $1
|
||||
""", region)
|
||||
""",
|
||||
region,
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
async def guild_region_check(app_):
|
||||
|
|
@ -112,10 +129,11 @@ async def guild_region_check(app_):
|
|||
regions = await app_.storage.all_voice_regions()
|
||||
|
||||
if not regions:
|
||||
log.info('region check: no regions to move guilds to')
|
||||
log.info("region check: no regions to move guilds to")
|
||||
return
|
||||
|
||||
res = await app_.db.execute("""
|
||||
res = await app_.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET region = (
|
||||
SELECT id
|
||||
|
|
@ -124,6 +142,8 @@ async def guild_region_check(app_):
|
|||
LIMIT 1
|
||||
)
|
||||
WHERE region = NULL
|
||||
""", len(regions))
|
||||
""",
|
||||
len(regions),
|
||||
)
|
||||
|
||||
log.info('region check: updating guild.region=null: {!r}', res)
|
||||
log.info("region check: updating guild.region=null: {!r}", res)
|
||||
|
|
|
|||
|
|
@ -24,16 +24,17 @@ from PIL import Image
|
|||
|
||||
from litecord.images import resize_gif
|
||||
|
||||
bp = Blueprint('attachments', __name__)
|
||||
ATTACHMENTS = Path.cwd() / 'attachments'
|
||||
bp = Blueprint("attachments", __name__)
|
||||
ATTACHMENTS = Path.cwd() / "attachments"
|
||||
|
||||
|
||||
async def _resize_gif(attach_id: int, resized_path: Path,
|
||||
width: int, height: int) -> str:
|
||||
async def _resize_gif(
|
||||
attach_id: int, resized_path: Path, width: int, height: int
|
||||
) -> str:
|
||||
"""Resize a GIF attachment."""
|
||||
|
||||
# get original gif bytes
|
||||
orig_path = ATTACHMENTS / f'{attach_id}.gif'
|
||||
orig_path = ATTACHMENTS / f"{attach_id}.gif"
|
||||
orig_bytes = orig_path.read_bytes()
|
||||
|
||||
# give them and the target size to the
|
||||
|
|
@ -47,10 +48,7 @@ async def _resize_gif(attach_id: int, resized_path: Path,
|
|||
return str(resized_path)
|
||||
|
||||
|
||||
FORMAT_HARDCODE = {
|
||||
'jpg': 'jpeg',
|
||||
'jpe': 'jpeg'
|
||||
}
|
||||
FORMAT_HARDCODE = {"jpg": "jpeg", "jpe": "jpeg"}
|
||||
|
||||
|
||||
def to_format(ext: str) -> str:
|
||||
|
|
@ -63,11 +61,10 @@ def to_format(ext: str) -> str:
|
|||
return ext
|
||||
|
||||
|
||||
async def _resize(image, attach_id: int, ext: str,
|
||||
width: int, height: int) -> str:
|
||||
async def _resize(image, attach_id: int, ext: str, width: int, height: int) -> str:
|
||||
"""Resize an image."""
|
||||
# check if we have it on the folder
|
||||
resized_path = ATTACHMENTS / f'{attach_id}_{width}_{height}.{ext}'
|
||||
resized_path = ATTACHMENTS / f"{attach_id}_{width}_{height}.{ext}"
|
||||
|
||||
# keep a str-fied instance since that is what
|
||||
# we'll return.
|
||||
|
|
@ -81,7 +78,7 @@ async def _resize(image, attach_id: int, ext: str,
|
|||
|
||||
# the process is different for gif files because we need
|
||||
# gifsicle. doing it manually is too troublesome.
|
||||
if ext == 'gif':
|
||||
if ext == "gif":
|
||||
return await _resize_gif(attach_id, resized_path, width, height)
|
||||
|
||||
# NOTE: this is the same resize mode for icons.
|
||||
|
|
@ -91,38 +88,42 @@ async def _resize(image, attach_id: int, ext: str,
|
|||
return resized_path_s
|
||||
|
||||
|
||||
@bp.route('/attachments'
|
||||
'/<int:channel_id>/<int:message_id>/<filename>',
|
||||
methods=['GET'])
|
||||
async def _get_attachment(channel_id: int, message_id: int,
|
||||
filename: str):
|
||||
@bp.route(
|
||||
"/attachments" "/<int:channel_id>/<int:message_id>/<filename>", methods=["GET"]
|
||||
)
|
||||
async def _get_attachment(channel_id: int, message_id: int, filename: str):
|
||||
|
||||
attach_id = await app.db.fetchval("""
|
||||
attach_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT id
|
||||
FROM attachments
|
||||
WHERE channel_id = $1
|
||||
AND message_id = $2
|
||||
AND filename = $3
|
||||
""", channel_id, message_id, filename)
|
||||
""",
|
||||
channel_id,
|
||||
message_id,
|
||||
filename,
|
||||
)
|
||||
|
||||
if attach_id is None:
|
||||
return '', 404
|
||||
return "", 404
|
||||
|
||||
ext = filename.split('.')[-1]
|
||||
filepath = f'./attachments/{attach_id}.{ext}'
|
||||
ext = filename.split(".")[-1]
|
||||
filepath = f"./attachments/{attach_id}.{ext}"
|
||||
|
||||
image = Image.open(filepath)
|
||||
im_width, im_height = image.size
|
||||
|
||||
try:
|
||||
width = int(request.args.get('width', 0)) or im_width
|
||||
width = int(request.args.get("width", 0)) or im_width
|
||||
except ValueError:
|
||||
return '', 400
|
||||
return "", 400
|
||||
|
||||
try:
|
||||
height = int(request.args.get('height', 0)) or im_height
|
||||
height = int(request.args.get("height", 0)) or im_height
|
||||
except ValueError:
|
||||
return '', 400
|
||||
return "", 400
|
||||
|
||||
# if width and height are the same (happens if they weren't provided)
|
||||
if width == im_width and height == im_height:
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from litecord.snowflake import get_snowflake
|
|||
from .invites import use_invite
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('auth', __name__)
|
||||
bp = Blueprint("auth", __name__)
|
||||
|
||||
|
||||
async def check_password(pwd_hash: str, given_password: str) -> bool:
|
||||
|
|
@ -53,141 +53,139 @@ def make_token(user_id, user_pwd_hash) -> str:
|
|||
return signer.sign(user_id).decode()
|
||||
|
||||
|
||||
@bp.route('/register', methods=['POST'])
|
||||
@bp.route("/register", methods=["POST"])
|
||||
async def register():
|
||||
"""Register a single user."""
|
||||
enabled = app.config.get('REGISTRATIONS')
|
||||
enabled = app.config.get("REGISTRATIONS")
|
||||
if not enabled:
|
||||
raise BadRequest('Registrations disabled', {
|
||||
'email': 'Registrations are disabled.'
|
||||
})
|
||||
raise BadRequest(
|
||||
"Registrations disabled", {"email": "Registrations are disabled."}
|
||||
)
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
if not 'password' in j:
|
||||
if not "password" in j:
|
||||
# we need a password to generate a token.
|
||||
# passwords are optional, so
|
||||
j['password'] = 'default_password'
|
||||
j["password"] = "default_password"
|
||||
|
||||
j = validate(j, REGISTER)
|
||||
|
||||
# they're optional
|
||||
email = j.get('email')
|
||||
invite = j.get('invite')
|
||||
email = j.get("email")
|
||||
invite = j.get("invite")
|
||||
|
||||
username, password = j['username'], j['password']
|
||||
username, password = j["username"], j["password"]
|
||||
|
||||
new_id, pwd_hash = await create_user(
|
||||
username, email, password, app.db
|
||||
)
|
||||
new_id, pwd_hash = await create_user(username, email, password, app.db)
|
||||
|
||||
if invite:
|
||||
try:
|
||||
await use_invite(new_id, invite)
|
||||
except Exception:
|
||||
log.exception('failed to use invite for register {} {!r}',
|
||||
new_id, invite)
|
||||
log.exception("failed to use invite for register {} {!r}", new_id, invite)
|
||||
|
||||
return jsonify({
|
||||
'token': make_token(new_id, pwd_hash)
|
||||
})
|
||||
return jsonify({"token": make_token(new_id, pwd_hash)})
|
||||
|
||||
|
||||
@bp.route('/register_inv', methods=['POST'])
|
||||
@bp.route("/register_inv", methods=["POST"])
|
||||
async def _register_with_invite():
|
||||
data = await request.form
|
||||
data = validate(await request.form, REGISTER_WITH_INVITE)
|
||||
|
||||
invcode = data['invcode']
|
||||
invcode = data["invcode"]
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
row = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT uses, max_uses
|
||||
FROM instance_invites
|
||||
WHERE code = $1
|
||||
""", invcode)
|
||||
""",
|
||||
invcode,
|
||||
)
|
||||
|
||||
if row is None:
|
||||
raise BadRequest('unknown instance invite')
|
||||
raise BadRequest("unknown instance invite")
|
||||
|
||||
if row['max_uses'] != -1 and row['uses'] >= row['max_uses']:
|
||||
raise BadRequest('invite expired')
|
||||
if row["max_uses"] != -1 and row["uses"] >= row["max_uses"]:
|
||||
raise BadRequest("invite expired")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE instance_invites
|
||||
SET uses = uses + 1
|
||||
WHERE code = $1
|
||||
""", invcode)
|
||||
""",
|
||||
invcode,
|
||||
)
|
||||
|
||||
user_id, pwd_hash = await create_user(
|
||||
data['username'], data['email'], data['password'], app.db)
|
||||
data["username"], data["email"], data["password"], app.db
|
||||
)
|
||||
|
||||
return jsonify({
|
||||
'token': make_token(user_id, pwd_hash),
|
||||
'user_id': str(user_id),
|
||||
})
|
||||
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
|
||||
|
||||
|
||||
@bp.route('/login', methods=['POST'])
|
||||
@bp.route("/login", methods=["POST"])
|
||||
async def login():
|
||||
j = await request.get_json()
|
||||
email, password = j['email'], j['password']
|
||||
email, password = j["email"], j["password"]
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
row = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT id, password_hash
|
||||
FROM users
|
||||
WHERE email = $1
|
||||
""", email)
|
||||
""",
|
||||
email,
|
||||
)
|
||||
|
||||
if not row:
|
||||
return jsonify({'email': ['User not found.']}), 401
|
||||
return jsonify({"email": ["User not found."]}), 401
|
||||
|
||||
user_id, pwd_hash = row
|
||||
|
||||
if not await check_password(pwd_hash, password):
|
||||
return jsonify({'password': ['Password does not match.']}), 401
|
||||
return jsonify({"password": ["Password does not match."]}), 401
|
||||
|
||||
return jsonify({
|
||||
'token': make_token(user_id, pwd_hash)
|
||||
})
|
||||
return jsonify({"token": make_token(user_id, pwd_hash)})
|
||||
|
||||
|
||||
@bp.route('/consent-required', methods=['GET'])
|
||||
@bp.route("/consent-required", methods=["GET"])
|
||||
async def consent_required():
|
||||
return jsonify({
|
||||
'required': True,
|
||||
})
|
||||
return jsonify({"required": True})
|
||||
|
||||
|
||||
@bp.route('/verify/resend', methods=['POST'])
|
||||
@bp.route("/verify/resend", methods=["POST"])
|
||||
async def verify_user():
|
||||
user_id = await token_check()
|
||||
|
||||
# TODO: actually verify a user by sending an email
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET verified = true
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
new_user = await app.storage.get_user(user_id, True)
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, 'USER_UPDATE', new_user)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/logout', methods=['POST'])
|
||||
@bp.route("/logout", methods=["POST"])
|
||||
async def _logout():
|
||||
"""Called by the client to logout."""
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/fingerprint', methods=['POST'])
|
||||
@bp.route("/fingerprint", methods=["POST"])
|
||||
async def _fingerprint():
|
||||
"""No idea what this route is about."""
|
||||
fingerprint_id = get_snowflake()
|
||||
fingerprint = f'{fingerprint_id}.{secrets.token_urlsafe(32)}'
|
||||
fingerprint = f"{fingerprint_id}.{secrets.token_urlsafe(32)}"
|
||||
|
||||
return jsonify({
|
||||
'fingerprint': fingerprint
|
||||
})
|
||||
return jsonify({"fingerprint": fingerprint})
|
||||
|
|
|
|||
|
|
@ -21,4 +21,4 @@ from .messages import bp as channel_messages
|
|||
from .reactions import bp as channel_reactions
|
||||
from .pins import bp as channel_pins
|
||||
|
||||
__all__ = ['channel_messages', 'channel_reactions', 'channel_pins']
|
||||
__all__ = ["channel_messages", "channel_reactions", "channel_pins"]
|
||||
|
|
|
|||
|
|
@ -30,13 +30,18 @@ class ForbiddenDM(Forbidden):
|
|||
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||
"""Check if the user can DM the peer."""
|
||||
# first step is checking if there is a block in any direction
|
||||
blockrow = await app.db.fetchrow("""
|
||||
blockrow = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT rel_type
|
||||
FROM relationships
|
||||
WHERE rel_type = $3
|
||||
AND user_id IN ($1, $2)
|
||||
AND peer_id IN ($1, $2)
|
||||
""", user_id, peer_id, RelationshipType.BLOCK.value)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
RelationshipType.BLOCK.value,
|
||||
)
|
||||
|
||||
if blockrow is not None:
|
||||
raise ForbiddenDM()
|
||||
|
|
@ -58,8 +63,8 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
|||
user_settings = await app.user_storage.get_user_settings(user_id)
|
||||
peer_settings = await app.user_storage.get_user_settings(peer_id)
|
||||
|
||||
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
|
||||
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
|
||||
restricted_user_ = [int(v) for v in user_settings["restricted_guilds"]]
|
||||
restricted_peer_ = [int(v) for v in peer_settings["restricted_guilds"]]
|
||||
|
||||
restricted_user = set(restricted_user_)
|
||||
restricted_peer = set(restricted_peer_)
|
||||
|
|
|
|||
|
|
@ -41,18 +41,18 @@ from litecord.images import try_unlink
|
|||
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channel_messages', __name__)
|
||||
bp = Blueprint("channel_messages", __name__)
|
||||
|
||||
|
||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||
"""Extract a limit kwarg."""
|
||||
try:
|
||||
limit = int(request_.args.get('limit', default))
|
||||
limit = int(request_.args.get("limit", default))
|
||||
|
||||
if limit not in range(0, max_val + 1):
|
||||
raise ValueError()
|
||||
except (TypeError, ValueError):
|
||||
raise BadRequest('limit not int')
|
||||
raise BadRequest("limit not int")
|
||||
|
||||
return limit
|
||||
|
||||
|
|
@ -61,27 +61,27 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
|||
"""Extract a 2-tuple out of request arguments."""
|
||||
before, after = None, None
|
||||
|
||||
if 'around' in request.args:
|
||||
if "around" in request.args:
|
||||
average = int(limit / 2)
|
||||
around = int(args['around'])
|
||||
around = int(args["around"])
|
||||
|
||||
after = around - average
|
||||
before = around + average
|
||||
|
||||
elif 'before' in args:
|
||||
before = int(args['before'])
|
||||
elif 'after' in args:
|
||||
before = int(args['after'])
|
||||
elif "before" in args:
|
||||
before = int(args["before"])
|
||||
elif "after" in args:
|
||||
before = int(args["after"])
|
||||
|
||||
return before, after
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
||||
@bp.route("/<int:channel_id>/messages", methods=["GET"])
|
||||
async def get_messages(channel_id):
|
||||
user_id = await token_check()
|
||||
|
||||
ctype, peer_id = await channel_check(user_id, channel_id)
|
||||
await channel_perm_check(user_id, channel_id, 'read_history')
|
||||
await channel_perm_check(user_id, channel_id, "read_history")
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
# make sure both parties will be subbed
|
||||
|
|
@ -91,42 +91,45 @@ async def get_messages(channel_id):
|
|||
|
||||
limit = extract_limit(request, 50)
|
||||
|
||||
where_clause = ''
|
||||
where_clause = ""
|
||||
before, after = query_tuple_from_args(request.args, limit)
|
||||
|
||||
if before:
|
||||
where_clause += f'AND id < {before}'
|
||||
where_clause += f"AND id < {before}"
|
||||
|
||||
if after:
|
||||
where_clause += f'AND id > {after}'
|
||||
where_clause += f"AND id > {after}"
|
||||
|
||||
message_ids = await app.db.fetch(f"""
|
||||
message_ids = await app.db.fetch(
|
||||
f"""
|
||||
SELECT id
|
||||
FROM messages
|
||||
WHERE channel_id = $1 {where_clause}
|
||||
ORDER BY id DESC
|
||||
LIMIT {limit}
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
result = []
|
||||
|
||||
for message_id in message_ids:
|
||||
msg = await app.storage.get_message(message_id['id'], user_id)
|
||||
msg = await app.storage.get_message(message_id["id"], user_id)
|
||||
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
|
||||
log.info('Fetched {} messages', len(result))
|
||||
log.info("Fetched {} messages", len(result))
|
||||
return jsonify(result)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
||||
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["GET"])
|
||||
async def get_single_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
await channel_perm_check(user_id, channel_id, 'read_history')
|
||||
await channel_perm_check(user_id, channel_id, "read_history")
|
||||
|
||||
message = await app.storage.get_message(message_id, user_id)
|
||||
|
||||
|
|
@ -142,11 +145,15 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
|||
|
||||
# check the other party's dm_channel_state
|
||||
|
||||
dm_state = await app.db.fetchval("""
|
||||
dm_state = await app.db.fetchval(
|
||||
"""
|
||||
SELECT dm_id
|
||||
FROM dm_channel_state
|
||||
WHERE user_id = $1 AND dm_id = $2
|
||||
""", peer_id, channel_id)
|
||||
""",
|
||||
peer_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
if dm_state:
|
||||
# the peer already has the channel
|
||||
|
|
@ -157,18 +164,19 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
|||
|
||||
# dispatch CHANNEL_CREATE so the client knows which
|
||||
# channel the future event is about
|
||||
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
||||
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||
|
||||
# subscribe the peer to the channel
|
||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||
|
||||
# insert it on dm_channel_state so the client
|
||||
# is subscribed on the future
|
||||
await try_dm_state(peer_id, channel_id)
|
||||
|
||||
|
||||
async def create_message(channel_id: int, actual_guild_id: int,
|
||||
author_id: int, data: dict) -> int:
|
||||
async def create_message(
|
||||
channel_id: int, actual_guild_id: int, author_id: int, data: dict
|
||||
) -> int:
|
||||
message_id = get_snowflake()
|
||||
|
||||
async with app.db.acquire() as conn:
|
||||
|
|
@ -185,32 +193,32 @@ async def create_message(channel_id: int, actual_guild_id: int,
|
|||
channel_id,
|
||||
actual_guild_id,
|
||||
author_id,
|
||||
data['content'],
|
||||
|
||||
data['tts'],
|
||||
data['everyone_mention'],
|
||||
|
||||
data['nonce'],
|
||||
data["content"],
|
||||
data["tts"],
|
||||
data["everyone_mention"],
|
||||
data["nonce"],
|
||||
MessageType.DEFAULT.value,
|
||||
data.get('embeds') or []
|
||||
data.get("embeds") or [],
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
async def msg_guild_text_mentions(payload: dict, guild_id: int,
|
||||
mentions_everyone: bool, mentions_here: bool):
|
||||
|
||||
async def msg_guild_text_mentions(
|
||||
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
|
||||
):
|
||||
"""Calculates mention data side-effects."""
|
||||
channel_id = int(payload['channel_id'])
|
||||
channel_id = int(payload["channel_id"])
|
||||
|
||||
# calculate the user ids we'll bump the mention count for
|
||||
uids = set()
|
||||
|
||||
# first is extracting user mentions
|
||||
for mention in payload['mentions']:
|
||||
uids.add(int(mention['id']))
|
||||
for mention in payload["mentions"]:
|
||||
uids.add(int(mention["id"]))
|
||||
|
||||
# then role mentions
|
||||
for role_mention in payload['mention_roles']:
|
||||
for role_mention in payload["mention_roles"]:
|
||||
role_id = int(role_mention)
|
||||
member_ids = await app.storage.get_role_members(role_id)
|
||||
|
||||
|
|
@ -223,11 +231,14 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
|
|||
if mentions_here:
|
||||
uids = set()
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# at-here updates the read state
|
||||
# for all users, including the ones
|
||||
|
|
@ -238,19 +249,26 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
|
|||
|
||||
member_ids = await app.storage.get_member_ids(guild_id)
|
||||
|
||||
await app.db.executemany("""
|
||||
await app.db.executemany(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE channel_id = $1 AND user_id = $2
|
||||
""", [(channel_id, uid) for uid in member_ids])
|
||||
""",
|
||||
[(channel_id, uid) for uid in member_ids],
|
||||
)
|
||||
|
||||
for user_id in uids:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
WHERE user_id = $1
|
||||
AND channel_id = $2
|
||||
""", user_id, channel_id)
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
async def msg_create_request() -> tuple:
|
||||
|
|
@ -264,12 +282,12 @@ async def msg_create_request() -> tuple:
|
|||
|
||||
# NOTE: embed isn't set on form data
|
||||
json_from_form = {
|
||||
'content': form.get('content', ''),
|
||||
'nonce': form.get('nonce', '0'),
|
||||
'tts': json.loads(form.get('tts', 'false')),
|
||||
"content": form.get("content", ""),
|
||||
"nonce": form.get("nonce", "0"),
|
||||
"tts": json.loads(form.get("tts", "false")),
|
||||
}
|
||||
|
||||
payload_json = json.loads(form.get('payload_json', '{}'))
|
||||
payload_json = json.loads(form.get("payload_json", "{}"))
|
||||
|
||||
json_from_form.update(request_json)
|
||||
json_from_form.update(payload_json)
|
||||
|
|
@ -283,20 +301,19 @@ async def msg_create_request() -> tuple:
|
|||
|
||||
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
||||
"""Check if there is actually any content being sent to us."""
|
||||
has_content = bool(payload.get('content', ''))
|
||||
has_content = bool(payload.get("content", ""))
|
||||
has_files = len(files) > 0
|
||||
|
||||
embed_field = 'embeds' if use_embeds else 'embed'
|
||||
embed_field = "embeds" if use_embeds else "embed"
|
||||
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
||||
|
||||
has_total_content = has_content or has_embed or has_files
|
||||
|
||||
if not has_total_content:
|
||||
raise BadRequest('No content has been provided.')
|
||||
raise BadRequest("No content has been provided.")
|
||||
|
||||
|
||||
async def msg_add_attachment(message_id: int, channel_id: int,
|
||||
attachment_file) -> int:
|
||||
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
|
||||
"""Add an attachment to a message.
|
||||
|
||||
Parameters
|
||||
|
|
@ -318,7 +335,7 @@ async def msg_add_attachment(message_id: int, channel_id: int,
|
|||
|
||||
# understand file info
|
||||
mime = attachment_file.mimetype
|
||||
is_image = mime.startswith('image/')
|
||||
is_image = mime.startswith("image/")
|
||||
|
||||
img_width, img_height = None, None
|
||||
|
||||
|
|
@ -346,17 +363,22 @@ async def msg_add_attachment(message_id: int, channel_id: int,
|
|||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""",
|
||||
attachment_id, channel_id, message_id,
|
||||
filename, file_size,
|
||||
is_image, img_width, img_height)
|
||||
attachment_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
filename,
|
||||
file_size,
|
||||
is_image,
|
||||
img_width,
|
||||
img_height,
|
||||
)
|
||||
|
||||
ext = filename.split('.')[-1]
|
||||
ext = filename.split(".")[-1]
|
||||
|
||||
with open(f'attachments/{attachment_id}.{ext}', 'wb') as attach_file:
|
||||
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
|
||||
attach_file.write(attachment_file.stream.read())
|
||||
|
||||
log.debug('written {} bytes for attachment id {}',
|
||||
file_size, attachment_id)
|
||||
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
|
||||
|
||||
return attachment_id
|
||||
|
||||
|
|
@ -364,12 +386,12 @@ async def msg_add_attachment(message_id: int, channel_id: int,
|
|||
async def _spawn_embed(app_, payload, **kwargs):
|
||||
app_.sched.spawn(
|
||||
process_url_embed(
|
||||
app_.config, app_.storage, app_.dispatcher, app_.session,
|
||||
payload, **kwargs)
|
||||
app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||
@bp.route("/<int:channel_id>/messages", methods=["POST"])
|
||||
async def _create_message(channel_id):
|
||||
"""Create a message."""
|
||||
|
||||
|
|
@ -379,7 +401,7 @@ async def _create_message(channel_id):
|
|||
actual_guild_id = None
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
await channel_perm_check(user_id, channel_id, 'send_messages')
|
||||
await channel_perm_check(user_id, channel_id, "send_messages")
|
||||
actual_guild_id = guild_id
|
||||
|
||||
payload_json, files = await msg_create_request()
|
||||
|
|
@ -394,29 +416,31 @@ async def _create_message(channel_id):
|
|||
await dm_pre_check(user_id, channel_id, guild_id)
|
||||
|
||||
can_everyone = await channel_perm_check(
|
||||
user_id, channel_id, 'mention_everyone', False
|
||||
user_id, channel_id, "mention_everyone", False
|
||||
)
|
||||
|
||||
mentions_everyone = ('@everyone' in j['content']) and can_everyone
|
||||
mentions_here = ('@here' in j['content']) and can_everyone
|
||||
mentions_everyone = ("@everyone" in j["content"]) and can_everyone
|
||||
mentions_here = ("@here" in j["content"]) and can_everyone
|
||||
|
||||
is_tts = (j.get('tts', False) and
|
||||
await channel_perm_check(
|
||||
user_id, channel_id, 'send_tts_messages', False
|
||||
))
|
||||
is_tts = j.get("tts", False) and await channel_perm_check(
|
||||
user_id, channel_id, "send_tts_messages", False
|
||||
)
|
||||
|
||||
message_id = await create_message(
|
||||
channel_id, actual_guild_id, user_id, {
|
||||
'content': j['content'],
|
||||
'tts': is_tts,
|
||||
'nonce': int(j.get('nonce', 0)),
|
||||
'everyone_mention': mentions_everyone or mentions_here,
|
||||
|
||||
channel_id,
|
||||
actual_guild_id,
|
||||
user_id,
|
||||
{
|
||||
"content": j["content"],
|
||||
"tts": is_tts,
|
||||
"nonce": int(j.get("nonce", 0)),
|
||||
"everyone_mention": mentions_everyone or mentions_here,
|
||||
# fill_embed takes care of filling proxy and width/height
|
||||
'embeds': ([await fill_embed(j['embed'])]
|
||||
if j.get('embed') is not None
|
||||
else []),
|
||||
})
|
||||
"embeds": (
|
||||
[await fill_embed(j["embed"])] if j.get("embed") is not None else []
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
# for each file given, we add it as an attachment
|
||||
for pre_attachment in files:
|
||||
|
|
@ -429,8 +453,7 @@ async def _create_message(channel_id):
|
|||
await _dm_pre_dispatch(channel_id, user_id)
|
||||
await _dm_pre_dispatch(channel_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_CREATE', payload)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||
|
||||
# spawn url processor for embedding of images
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
|
|
@ -438,54 +461,71 @@ async def _create_message(channel_id):
|
|||
await _spawn_embed(app, payload)
|
||||
|
||||
# update read state for the author
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE user_read_state
|
||||
SET last_message_id = $1
|
||||
WHERE channel_id = $2 AND user_id = $3
|
||||
""", message_id, channel_id, user_id)
|
||||
""",
|
||||
message_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if ctype == ChannelType.GUILD_TEXT:
|
||||
await msg_guild_text_mentions(
|
||||
payload, guild_id, mentions_everyone, mentions_here)
|
||||
payload, guild_id, mentions_everyone, mentions_here
|
||||
)
|
||||
|
||||
return jsonify(payload)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["PATCH"])
|
||||
async def edit_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, _guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
author_id = await app.db.fetchval("""
|
||||
author_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
if not author_id == user_id:
|
||||
raise Forbidden('You can not edit this message')
|
||||
raise Forbidden("You can not edit this message")
|
||||
|
||||
j = await request.get_json()
|
||||
updated = 'content' in j or 'embed' in j
|
||||
updated = "content" in j or "embed" in j
|
||||
old_message = await app.storage.get_message(message_id)
|
||||
|
||||
if 'content' in j:
|
||||
await app.db.execute("""
|
||||
if "content" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE messages
|
||||
SET content=$1
|
||||
WHERE messages.id = $2
|
||||
""", j['content'], message_id)
|
||||
""",
|
||||
j["content"],
|
||||
message_id,
|
||||
)
|
||||
|
||||
if 'embed' in j:
|
||||
embeds = [await fill_embed(j['embed'])]
|
||||
if "embed" in j:
|
||||
embeds = [await fill_embed(j["embed"])]
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE messages
|
||||
SET embeds=$1
|
||||
WHERE messages.id = $2
|
||||
""", embeds, message_id)
|
||||
""",
|
||||
embeds,
|
||||
message_id,
|
||||
)
|
||||
|
||||
# do not spawn process_url_embed since we already have embeds.
|
||||
elif 'content' in j:
|
||||
elif "content" in j:
|
||||
# if there weren't any embed changes BUT
|
||||
# we had a content change, we dispatch process_url_embed but with
|
||||
# an artificial delay.
|
||||
|
|
@ -495,46 +535,55 @@ async def edit_message(channel_id, message_id):
|
|||
# BEFORE the MESSAGE_UPDATE with the new embeds (based on content)
|
||||
perms = await get_permissions(user_id, channel_id)
|
||||
if perms.bits.embed_links:
|
||||
await _spawn_embed(app, {
|
||||
'id': message_id,
|
||||
'channel_id': channel_id,
|
||||
'content': j['content'],
|
||||
'embeds': old_message['embeds']
|
||||
}, delay=0.2)
|
||||
await _spawn_embed(
|
||||
app,
|
||||
{
|
||||
"id": message_id,
|
||||
"channel_id": channel_id,
|
||||
"content": j["content"],
|
||||
"embeds": old_message["embeds"],
|
||||
},
|
||||
delay=0.2,
|
||||
)
|
||||
|
||||
# only set new timestamp upon actual update
|
||||
if updated:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE messages
|
||||
SET edited_at = (now() at time zone 'utc')
|
||||
WHERE id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
message = await app.storage.get_message(message_id, user_id)
|
||||
|
||||
# only dispatch MESSAGE_UPDATE if any update
|
||||
# actually happened
|
||||
if updated:
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_UPDATE', message)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
|
||||
|
||||
return jsonify(message)
|
||||
|
||||
|
||||
async def _del_msg_fkeys(message_id: int):
|
||||
attachs = await app.db.fetch("""
|
||||
attachs = await app.db.fetch(
|
||||
"""
|
||||
SELECT id FROM attachments
|
||||
WHERE message_id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
attachs = [r['id'] for r in attachs]
|
||||
attachs = [r["id"] for r in attachs]
|
||||
|
||||
attachments = Path('./attachments')
|
||||
attachments = Path("./attachments")
|
||||
for attach_id in attachs:
|
||||
# anything starting with the given attachment shall be
|
||||
# deleted, because there may be resizes of the original
|
||||
# attachment laying around.
|
||||
for filepath in attachments.glob(f'{attach_id}*'):
|
||||
for filepath in attachments.glob(f"{attach_id}*"):
|
||||
try_unlink(filepath)
|
||||
|
||||
# after trying to delete all available attachments, delete
|
||||
|
|
@ -542,51 +591,64 @@ async def _del_msg_fkeys(message_id: int):
|
|||
|
||||
# take the chance and delete all the data from the other tables too!
|
||||
|
||||
tables = ['attachments', 'message_webhook_info',
|
||||
'message_reactions', 'channel_pins']
|
||||
tables = [
|
||||
"attachments",
|
||||
"message_webhook_info",
|
||||
"message_reactions",
|
||||
"channel_pins",
|
||||
]
|
||||
|
||||
for table in tables:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
DELETE FROM {table}
|
||||
WHERE message_id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["DELETE"])
|
||||
async def delete_message(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
author_id = await app.db.fetchval("""
|
||||
author_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
|
||||
by_perm = await channel_perm_check(
|
||||
user_id, channel_id, 'manage_messages', False
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
by_perm = await channel_perm_check(user_id, channel_id, "manage_messages", False)
|
||||
|
||||
by_ownership = author_id == user_id
|
||||
|
||||
can_delete = by_perm or by_ownership
|
||||
if not can_delete:
|
||||
raise Forbidden('You can not delete this message')
|
||||
raise Forbidden("You can not delete this message")
|
||||
|
||||
await _del_msg_fkeys(message_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id,
|
||||
'MESSAGE_DELETE', {
|
||||
'id': str(message_id),
|
||||
'channel_id': str(channel_id),
|
||||
|
||||
"channel",
|
||||
channel_id,
|
||||
"MESSAGE_DELETE",
|
||||
{
|
||||
"id": str(message_id),
|
||||
"channel_id": str(channel_id),
|
||||
# for lazy guilds
|
||||
'guild_id': str(guild_id),
|
||||
})
|
||||
"guild_id": str(guild_id),
|
||||
},
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -28,28 +28,32 @@ from litecord.system_messages import send_sys_message
|
|||
from litecord.enums import MessageType, SYS_MESSAGES
|
||||
from litecord.errors import BadRequest
|
||||
|
||||
bp = Blueprint('channel_pins', __name__)
|
||||
bp = Blueprint("channel_pins", __name__)
|
||||
|
||||
|
||||
class SysMsgInvalidAction(BadRequest):
|
||||
"""Invalid action on a system message."""
|
||||
|
||||
error_code = 50021
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins', methods=['GET'])
|
||||
@bp.route("/<int:channel_id>/pins", methods=["GET"])
|
||||
async def get_pins(channel_id):
|
||||
"""Get the pins for a channel"""
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
ids = await app.db.fetch("""
|
||||
ids = await app.db.fetch(
|
||||
"""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id DESC
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
ids = [r['message_id'] for r in ids]
|
||||
ids = [r["message_id"] for r in ids]
|
||||
res = []
|
||||
|
||||
for message_id in ids:
|
||||
|
|
@ -60,80 +64,96 @@ async def get_pins(channel_id):
|
|||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
|
||||
@bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["PUT"])
|
||||
async def add_pin(channel_id, message_id):
|
||||
"""Add a pin to a channel"""
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
||||
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||
|
||||
mtype = await app.db.fetchval("""
|
||||
mtype = await app.db.fetchval(
|
||||
"""
|
||||
SELECT message_type
|
||||
FROM messages
|
||||
WHERE id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
if mtype in SYS_MESSAGES:
|
||||
raise SysMsgInvalidAction(
|
||||
'Cannot execute action on a system message')
|
||||
raise SysMsgInvalidAction("Cannot execute action on a system message")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO channel_pins (channel_id, message_id)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, message_id)
|
||||
""",
|
||||
channel_id,
|
||||
message_id,
|
||||
)
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
row = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
LIMIT 1
|
||||
""", channel_id)
|
||||
|
||||
timestamp = snowflake_datetime(row['message_id'])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE',
|
||||
{
|
||||
'channel_id': str(channel_id),
|
||||
'last_pin_timestamp': timestamp_(timestamp)
|
||||
}
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await send_sys_message(app, channel_id,
|
||||
MessageType.CHANNEL_PINNED_MESSAGE,
|
||||
message_id, user_id)
|
||||
timestamp = snowflake_datetime(row["message_id"])
|
||||
|
||||
return '', 204
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
channel_id,
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["DELETE"])
|
||||
async def delete_pin(channel_id, message_id):
|
||||
user_id = await token_check()
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
||||
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM channel_pins
|
||||
WHERE channel_id = $1 AND message_id = $2
|
||||
""", channel_id, message_id)
|
||||
""",
|
||||
channel_id,
|
||||
message_id,
|
||||
)
|
||||
|
||||
row = await app.db.fetchrow("""
|
||||
row = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT message_id
|
||||
FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
ORDER BY message_id ASC
|
||||
LIMIT 1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
timestamp = snowflake_datetime(row['message_id'])
|
||||
timestamp = snowflake_datetime(row["message_id"])
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
|
||||
'channel_id': str(channel_id),
|
||||
'last_pin_timestamp': timestamp.isoformat()
|
||||
})
|
||||
"channel",
|
||||
channel_id,
|
||||
"CHANNEL_PINS_UPDATE",
|
||||
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -26,17 +26,15 @@ from logbook import Logger
|
|||
from litecord.utils import async_map
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||
from litecord.blueprints.channel.messages import (
|
||||
query_tuple_from_args, extract_limit
|
||||
)
|
||||
from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
|
||||
|
||||
from litecord.enums import GUILD_CHANS
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channel_reactions', __name__)
|
||||
bp = Blueprint("channel_reactions", __name__)
|
||||
|
||||
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions'
|
||||
BASEPATH = "/<int:channel_id>/messages/<int:message_id>/reactions"
|
||||
|
||||
|
||||
class EmojiType(IntEnum):
|
||||
|
|
@ -51,16 +49,14 @@ def emoji_info_from_str(emoji: str) -> tuple:
|
|||
# unicode emoji just have the raw unicode.
|
||||
|
||||
# try checking if the emoji is custom or unicode
|
||||
emoji_type = 0 if ':' in emoji else 1
|
||||
emoji_type = 0 if ":" in emoji else 1
|
||||
emoji_type = EmojiType(emoji_type)
|
||||
|
||||
# extract the emoji id OR the unicode value of the emoji
|
||||
# depending if it is custom or not
|
||||
emoji_id = (int(emoji.split(':')[1])
|
||||
if emoji_type == EmojiType.CUSTOM
|
||||
else emoji)
|
||||
emoji_id = int(emoji.split(":")[1]) if emoji_type == EmojiType.CUSTOM else emoji
|
||||
|
||||
emoji_name = emoji.split(':')[0]
|
||||
emoji_name = emoji.split(":")[0]
|
||||
|
||||
return emoji_type, emoji_id, emoji_name
|
||||
|
||||
|
|
@ -68,27 +64,27 @@ def emoji_info_from_str(emoji: str) -> tuple:
|
|||
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
|
||||
print(emoji_type, emoji_id, emoji_name)
|
||||
return {
|
||||
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
|
||||
"id": None if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||
"name": emoji_name if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||
}
|
||||
|
||||
|
||||
def _make_payload(user_id, channel_id, message_id, partial):
|
||||
return {
|
||||
'user_id': str(user_id),
|
||||
'channel_id': str(channel_id),
|
||||
'message_id': str(message_id),
|
||||
'emoji': partial
|
||||
"user_id": str(user_id),
|
||||
"channel_id": str(channel_id),
|
||||
"message_id": str(message_id),
|
||||
"emoji": partial,
|
||||
}
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT'])
|
||||
@bp.route(f"{BASEPATH}/<emoji>/@me", methods=["PUT"])
|
||||
async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
||||
"""Put a reaction."""
|
||||
user_id = await token_check()
|
||||
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
await channel_perm_check(user_id, channel_id, 'read_history')
|
||||
await channel_perm_check(user_id, channel_id, "read_history")
|
||||
|
||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||
|
||||
|
|
@ -97,52 +93,64 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
|||
|
||||
# ADD_REACTIONS is only checked when this is the first
|
||||
# reaction in a message.
|
||||
reaction_count = await app.db.fetchval("""
|
||||
reaction_count = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM message_reactions
|
||||
WHERE message_id = $1
|
||||
AND emoji_type = $2
|
||||
AND emoji_id = $3
|
||||
AND emoji_text = $4
|
||||
""", message_id, emoji_type, emoji_id, emoji_text)
|
||||
""",
|
||||
message_id,
|
||||
emoji_type,
|
||||
emoji_id,
|
||||
emoji_text,
|
||||
)
|
||||
|
||||
if reaction_count == 0:
|
||||
await channel_perm_check(user_id, channel_id, 'add_reactions')
|
||||
await channel_perm_check(user_id, channel_id, "add_reactions")
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO message_reactions (message_id, user_id,
|
||||
emoji_type, emoji_id, emoji_text)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", message_id, user_id, emoji_type,
|
||||
|
||||
""",
|
||||
message_id,
|
||||
user_id,
|
||||
emoji_type,
|
||||
# if it is custom, we put the emoji_id on emoji_id
|
||||
# column, if it isn't, we put it on emoji_text
|
||||
# column.
|
||||
emoji_id, emoji_text
|
||||
emoji_id,
|
||||
emoji_text,
|
||||
)
|
||||
|
||||
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
payload['guild_id'] = str(guild_id)
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
|
||||
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
|
||||
"""Extract SQL clauses to search for specific emoji
|
||||
in the message_reactions table."""
|
||||
param = f'${param}'
|
||||
param = f"${param}"
|
||||
|
||||
# know which column to filter with
|
||||
where_ext = (f'AND emoji_id = {param}'
|
||||
if emoji_type == EmojiType.CUSTOM else
|
||||
f'AND emoji_text = {param}')
|
||||
where_ext = (
|
||||
f"AND emoji_id = {param}"
|
||||
if emoji_type == EmojiType.CUSTOM
|
||||
else f"AND emoji_text = {param}"
|
||||
)
|
||||
|
||||
# which emoji to remove (custom or unicode)
|
||||
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
|
||||
|
|
@ -157,8 +165,7 @@ def _emoji_sql_simple(emoji: str, param=4):
|
|||
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
||||
|
||||
|
||||
async def remove_reaction(channel_id: int, message_id: int,
|
||||
user_id: int, emoji: str):
|
||||
async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||
|
|
@ -171,39 +178,45 @@ async def remove_reaction(channel_id: int, message_id: int,
|
|||
AND user_id = $2
|
||||
AND emoji_type = $3
|
||||
{where_ext}
|
||||
""", message_id, user_id, emoji_type, main_emoji)
|
||||
""",
|
||||
message_id,
|
||||
user_id,
|
||||
emoji_type,
|
||||
main_emoji,
|
||||
)
|
||||
|
||||
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
payload['guild_id'] = str(guild_id)
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
|
||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
|
||||
)
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE'])
|
||||
@bp.route(f"{BASEPATH}/<emoji>/@me", methods=["DELETE"])
|
||||
async def remove_own_reaction(channel_id, message_id, emoji):
|
||||
"""Remove a reaction."""
|
||||
user_id = await token_check()
|
||||
|
||||
await remove_reaction(channel_id, message_id, user_id, emoji)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE'])
|
||||
@bp.route(f"{BASEPATH}/<emoji>/<int:other_id>", methods=["DELETE"])
|
||||
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
|
||||
"""Remove a reaction made by another user."""
|
||||
user_id = await token_check()
|
||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
||||
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||
|
||||
await remove_reaction(channel_id, message_id, other_id, emoji)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET'])
|
||||
@bp.route(f"{BASEPATH}/<emoji>", methods=["GET"])
|
||||
async def list_users_reaction(channel_id, message_id, emoji):
|
||||
"""Get the list of all users who reacted with a certain emoji."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -215,42 +228,49 @@ async def list_users_reaction(channel_id, message_id, emoji):
|
|||
limit = extract_limit(request, 25)
|
||||
before, after = query_tuple_from_args(request.args, limit)
|
||||
|
||||
before_clause = 'AND user_id < $2' if before else ''
|
||||
after_clause = 'AND user_id > $3' if after else ''
|
||||
before_clause = "AND user_id < $2" if before else ""
|
||||
after_clause = "AND user_id > $3" if after else ""
|
||||
|
||||
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
|
||||
|
||||
rows = await app.db.fetch(f"""
|
||||
rows = await app.db.fetch(
|
||||
f"""
|
||||
SELECT user_id
|
||||
FROM message_reactions
|
||||
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
|
||||
""", message_id, before, after, main_emoji)
|
||||
""",
|
||||
message_id,
|
||||
before,
|
||||
after,
|
||||
main_emoji,
|
||||
)
|
||||
|
||||
user_ids = [r['user_id'] for r in rows]
|
||||
user_ids = [r["user_id"] for r in rows]
|
||||
users = await async_map(app.storage.get_user, user_ids)
|
||||
return jsonify(users)
|
||||
|
||||
|
||||
@bp.route(f'{BASEPATH}', methods=['DELETE'])
|
||||
@bp.route(f"{BASEPATH}", methods=["DELETE"])
|
||||
async def remove_all_reactions(channel_id, message_id):
|
||||
"""Remove all reactions in a message."""
|
||||
user_id = await token_check()
|
||||
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
||||
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM message_reactions
|
||||
WHERE message_id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
payload = {
|
||||
'channel_id': str(channel_id),
|
||||
'message_id': str(message_id),
|
||||
}
|
||||
payload = {"channel_id": str(channel_id), "message_id": str(message_id)}
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
payload['guild_id'] = str(guild_id)
|
||||
payload["guild_id"] = str(guild_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)
|
||||
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,24 +28,26 @@ from litecord.auth import token_check
|
|||
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
||||
from litecord.errors import ChannelNotFound, Forbidden, BadRequest
|
||||
from litecord.schemas import (
|
||||
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE,
|
||||
validate,
|
||||
CHAN_UPDATE,
|
||||
CHAN_OVERWRITE,
|
||||
SEARCH_CHANNEL,
|
||||
GROUP_DM_UPDATE,
|
||||
BULK_DELETE,
|
||||
)
|
||||
|
||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||
from litecord.system_messages import send_sys_message
|
||||
from litecord.blueprints.dm_channels import (
|
||||
gdm_remove_recipient, gdm_destroy
|
||||
)
|
||||
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
||||
from litecord.utils import search_result_from_list
|
||||
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
||||
from litecord.snowflake import snowflake_datetime
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channels', __name__)
|
||||
bp = Blueprint("channels", __name__)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>', methods=['GET'])
|
||||
@bp.route("/<int:channel_id>", methods=["GET"])
|
||||
async def get_channel(channel_id):
|
||||
"""Get a single channel's information"""
|
||||
user_id = await token_check()
|
||||
|
|
@ -56,7 +58,7 @@ async def get_channel(channel_id):
|
|||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
if not chan:
|
||||
raise ChannelNotFound('single channel not found')
|
||||
raise ChannelNotFound("single channel not found")
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
|
@ -64,106 +66,129 @@ async def get_channel(channel_id):
|
|||
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
||||
"""Update a guild's channel id field to NULL,
|
||||
if it was set to the given channel id before."""
|
||||
return await app.db.execute(f"""
|
||||
return await app.db.execute(
|
||||
f"""
|
||||
UPDATE guilds
|
||||
SET {field} = NULL
|
||||
WHERE guilds.id = $1 AND {field} = $2
|
||||
""", guild_id, channel_id)
|
||||
""",
|
||||
guild_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||
res_embed = await __guild_chan_sql(
|
||||
guild_id, channel_id, 'embed_channel_id')
|
||||
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
|
||||
|
||||
res_widget = await __guild_chan_sql(
|
||||
guild_id, channel_id, 'widget_channel_id')
|
||||
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
|
||||
|
||||
res_system = await __guild_chan_sql(
|
||||
guild_id, channel_id, 'system_channel_id')
|
||||
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
|
||||
|
||||
# if none of them were actually updated,
|
||||
# ignore and dont dispatch anything
|
||||
if 'UPDATE 1' not in (res_embed, res_widget, res_system):
|
||||
if "UPDATE 1" not in (res_embed, res_widget, res_system):
|
||||
return
|
||||
|
||||
# at least one of the fields were updated,
|
||||
# dispatch GUILD_UPDATE
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'GUILD_UPDATE', guild)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
|
||||
|
||||
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||
res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id')
|
||||
res = await __guild_chan_sql(guild_id, channel_id, "afk_channel_id")
|
||||
|
||||
# guild didnt update
|
||||
if res == 'UPDATE 0':
|
||||
if res == "UPDATE 0":
|
||||
return
|
||||
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'GUILD_UPDATE', guild)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
|
||||
|
||||
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||
# get all channels that were childs of the category
|
||||
childs = await app.db.fetch("""
|
||||
childs = await app.db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1 AND parent_id = $2
|
||||
""", guild_id, channel_id)
|
||||
childs = [c['id'] for c in childs]
|
||||
""",
|
||||
guild_id,
|
||||
channel_id,
|
||||
)
|
||||
childs = [c["id"] for c in childs]
|
||||
|
||||
# update every child channel to parent_id = NULL
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guild_channels
|
||||
SET parent_id = NULL
|
||||
WHERE guild_id = $1 AND parent_id = $2
|
||||
""", guild_id, channel_id)
|
||||
""",
|
||||
guild_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# tell all people in the guild of the category removal
|
||||
for child_id in childs:
|
||||
child = await app.storage.get_channel(child_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'CHANNEL_UPDATE', child
|
||||
)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
||||
|
||||
|
||||
async def delete_messages(channel_id):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM channel_pins
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM user_read_state
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM messages
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
async def guild_cleanup(channel_id):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM channel_overwrites
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM invites
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM webhooks
|
||||
WHERE channel_id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:channel_id>", methods=["DELETE"])
|
||||
async def close_channel(channel_id):
|
||||
"""Close or delete a channel."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -184,9 +209,8 @@ async def close_channel(channel_id):
|
|||
}[ctype]
|
||||
|
||||
main_tbl = {
|
||||
ChannelType.GUILD_TEXT: 'guild_text_channels',
|
||||
ChannelType.GUILD_VOICE: 'guild_voice_channels',
|
||||
|
||||
ChannelType.GUILD_TEXT: "guild_text_channels",
|
||||
ChannelType.GUILD_VOICE: "guild_voice_channels",
|
||||
# TODO: categories?
|
||||
}[ctype]
|
||||
|
||||
|
|
@ -199,29 +223,37 @@ async def close_channel(channel_id):
|
|||
await delete_messages(channel_id)
|
||||
await guild_cleanup(channel_id)
|
||||
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
DELETE FROM {main_tbl}
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM guild_channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# clean its member list representation
|
||||
lazy_guilds = app.dispatcher.backends['lazy_guild']
|
||||
lazy_guilds = app.dispatcher.backends["lazy_guild"]
|
||||
lazy_guilds.remove_channel(channel_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'CHANNEL_DELETE', chan)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
await app.dispatcher.remove('channel', channel_id)
|
||||
await app.dispatcher.remove("channel", channel_id)
|
||||
return jsonify(chan)
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
|
|
@ -231,27 +263,34 @@ async def close_channel(channel_id):
|
|||
# instead, we close the channel for the user that is making
|
||||
# the request via removing the link between them and
|
||||
# the channel on dm_channel_state
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM dm_channel_state
|
||||
WHERE user_id = $1 AND dm_id = $2
|
||||
""", user_id, channel_id)
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# unsubscribe
|
||||
await app.dispatcher.unsub('channel', channel_id, user_id)
|
||||
await app.dispatcher.unsub("channel", channel_id, user_id)
|
||||
|
||||
# nothing happens to the other party of the dm channel
|
||||
await app.dispatcher.dispatch_user(user_id, 'CHANNEL_DELETE', chan)
|
||||
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
if ctype == ChannelType.GROUP_DM:
|
||||
await gdm_remove_recipient(channel_id, user_id)
|
||||
|
||||
gdm_count = await app.db.fetchval("""
|
||||
gdm_count = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM group_dm_members
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
if gdm_count == 0:
|
||||
# destroy dm
|
||||
|
|
@ -261,11 +300,15 @@ async def close_channel(channel_id):
|
|||
|
||||
|
||||
async def _update_pos(channel_id, pos: int):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guild_channels
|
||||
SET position = $1
|
||||
WHERE id = $2
|
||||
""", pos, channel_id)
|
||||
""",
|
||||
pos,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
||||
|
|
@ -274,20 +317,19 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
|||
continue
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch(
|
||||
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
|
||||
async def _process_overwrites(channel_id: int, overwrites: list):
|
||||
for overwrite in overwrites:
|
||||
|
||||
# 0 for member overwrite, 1 for role overwrite
|
||||
target_type = 0 if overwrite['type'] == 'member' else 1
|
||||
target_role = None if target_type == 0 else overwrite['id']
|
||||
target_user = overwrite['id'] if target_type == 0 else None
|
||||
target_type = 0 if overwrite["type"] == "member" else 1
|
||||
target_role = None if target_type == 0 else overwrite["id"]
|
||||
target_user = overwrite["id"] if target_type == 0 else None
|
||||
|
||||
col_name = 'target_user' if target_type == 0 else 'target_role'
|
||||
constraint_name = f'channel_overwrites_{col_name}_uniq'
|
||||
col_name = "target_user" if target_type == 0 else "target_role"
|
||||
constraint_name = f"channel_overwrites_{col_name}_uniq"
|
||||
|
||||
await app.db.execute(
|
||||
f"""
|
||||
|
|
@ -301,53 +343,66 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
|||
UPDATE
|
||||
SET allow = $5, deny = $6
|
||||
""",
|
||||
channel_id, target_type,
|
||||
target_role, target_user,
|
||||
overwrite['allow'], overwrite['deny'])
|
||||
channel_id,
|
||||
target_type,
|
||||
target_role,
|
||||
target_user,
|
||||
overwrite["allow"],
|
||||
overwrite["deny"],
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT'])
|
||||
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
|
||||
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||
"""Insert or modify a channel overwrite."""
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
if ctype not in GUILD_CHANS:
|
||||
raise ChannelNotFound('Only usable for guild channels.')
|
||||
raise ChannelNotFound("Only usable for guild channels.")
|
||||
|
||||
await channel_perm_check(user_id, guild_id, 'manage_roles')
|
||||
await channel_perm_check(user_id, guild_id, "manage_roles")
|
||||
|
||||
j = validate(
|
||||
# inserting a fake id on the payload so validation passes through
|
||||
{**await request.get_json(), **{'id': -1}},
|
||||
CHAN_OVERWRITE
|
||||
{**await request.get_json(), **{"id": -1}},
|
||||
CHAN_OVERWRITE,
|
||||
)
|
||||
|
||||
await _process_overwrites(channel_id, [{
|
||||
'allow': j['allow'],
|
||||
'deny': j['deny'],
|
||||
'type': j['type'],
|
||||
'id': overwrite_id
|
||||
}])
|
||||
await _process_overwrites(
|
||||
channel_id,
|
||||
[
|
||||
{
|
||||
"allow": j["allow"],
|
||||
"deny": j["deny"],
|
||||
"type": j["type"],
|
||||
"id": overwrite_id,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
await _mass_chan_update(guild_id, [channel_id])
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||
if 'name' in j:
|
||||
await app.db.execute("""
|
||||
if "name" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guild_channels
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""", j['name'], channel_id)
|
||||
""",
|
||||
j["name"],
|
||||
channel_id,
|
||||
)
|
||||
|
||||
if 'position' in j:
|
||||
if "position" in j:
|
||||
channel_data = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
chans = [None] * len(channel_data)
|
||||
for chandata in channel_data:
|
||||
chans.insert(chandata['position'], int(chandata['id']))
|
||||
chans.insert(chandata["position"], int(chandata["id"]))
|
||||
|
||||
# are we changing to the left or to the right?
|
||||
|
||||
|
|
@ -358,7 +413,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
|||
# channelN-1 going to the position channel2
|
||||
# was occupying.
|
||||
current_pos = chans.index(channel_id)
|
||||
new_pos = j['position']
|
||||
new_pos = j["position"]
|
||||
|
||||
# if the new position is bigger than the current one,
|
||||
# we're making a left shift of all the channels that are
|
||||
|
|
@ -366,113 +421,136 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
|||
left_shift = new_pos > current_pos
|
||||
|
||||
# find all channels that we'll have to shift
|
||||
shift_block = (chans[current_pos:new_pos]
|
||||
if left_shift else
|
||||
chans[new_pos:current_pos]
|
||||
)
|
||||
shift_block = (
|
||||
chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
|
||||
)
|
||||
|
||||
shift = -1 if left_shift else 1
|
||||
|
||||
# do the shift (to the left or to the right)
|
||||
await app.db.executemany("""
|
||||
await app.db.executemany(
|
||||
"""
|
||||
UPDATE guild_channels
|
||||
SET position = position + $1
|
||||
WHERE id = $2
|
||||
""", [(shift, chan_id) for chan_id in shift_block])
|
||||
""",
|
||||
[(shift, chan_id) for chan_id in shift_block],
|
||||
)
|
||||
|
||||
await _mass_chan_update(guild_id, shift_block)
|
||||
|
||||
# since theres now an empty slot, move current channel to it
|
||||
await _update_pos(channel_id, new_pos)
|
||||
|
||||
if 'channel_overwrites' in j:
|
||||
overwrites = j['channel_overwrites']
|
||||
if "channel_overwrites" in j:
|
||||
overwrites = j["channel_overwrites"]
|
||||
await _process_overwrites(channel_id, overwrites)
|
||||
|
||||
|
||||
async def _common_guild_chan(channel_id, j: dict):
|
||||
# common updates to the guild_channels table
|
||||
for field in [field for field in j.keys()
|
||||
if field in ('nsfw', 'parent_id')]:
|
||||
await app.db.execute(f"""
|
||||
for field in [field for field in j.keys() if field in ("nsfw", "parent_id")]:
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guild_channels
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], channel_id)
|
||||
""",
|
||||
j[field],
|
||||
channel_id,
|
||||
)
|
||||
|
||||
|
||||
async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
|
||||
# first do the specific ones related to guild_text_channels
|
||||
for field in [field for field in j.keys()
|
||||
if field in ('topic', 'rate_limit_per_user')]:
|
||||
await app.db.execute(f"""
|
||||
for field in [
|
||||
field for field in j.keys() if field in ("topic", "rate_limit_per_user")
|
||||
]:
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guild_text_channels
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], channel_id)
|
||||
""",
|
||||
j[field],
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await _common_guild_chan(channel_id, j)
|
||||
|
||||
|
||||
async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
|
||||
# first do the specific ones in guild_voice_channels
|
||||
for field in [field for field in j.keys()
|
||||
if field in ('bitrate', 'user_limit')]:
|
||||
await app.db.execute(f"""
|
||||
for field in [field for field in j.keys() if field in ("bitrate", "user_limit")]:
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guild_voice_channels
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], channel_id)
|
||||
""",
|
||||
j[field],
|
||||
channel_id,
|
||||
)
|
||||
|
||||
# yes, i'm letting voice channels have nsfw, you cant stop me
|
||||
await _common_guild_chan(channel_id, j)
|
||||
|
||||
|
||||
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
|
||||
if 'name' in j:
|
||||
await app.db.execute("""
|
||||
if "name" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE group_dm_channels
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""", j['name'], channel_id)
|
||||
""",
|
||||
j["name"],
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
|
||||
)
|
||||
|
||||
if 'icon' in j:
|
||||
if "icon" in j:
|
||||
new_icon = await app.icons.update(
|
||||
'channel-icons', channel_id, j['icon'], always_icon=True
|
||||
"channel-icons", channel_id, j["icon"], always_icon=True
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE group_dm_channels
|
||||
SET icon = $1
|
||||
WHERE id = $2
|
||||
""", new_icon.icon_hash, channel_id)
|
||||
""",
|
||||
new_icon.icon_hash,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
|
||||
@bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
|
||||
async def update_channel(channel_id):
|
||||
"""Update a channel's information"""
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
||||
ChannelType.GROUP_DM):
|
||||
raise ChannelNotFound('unable to edit unsupported chan type')
|
||||
if ctype not in (
|
||||
ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE,
|
||||
ChannelType.GROUP_DM,
|
||||
):
|
||||
raise ChannelNotFound("unable to edit unsupported chan type")
|
||||
|
||||
is_guild = ctype in GUILD_CHANS
|
||||
|
||||
if is_guild:
|
||||
await channel_perm_check(user_id, channel_id, 'manage_channels')
|
||||
await channel_perm_check(user_id, channel_id, "manage_channels")
|
||||
|
||||
j = validate(await request.get_json(),
|
||||
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
|
||||
j = validate(await request.get_json(), CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
|
||||
|
||||
# TODO: categories
|
||||
update_handler = {
|
||||
|
|
@ -489,30 +567,32 @@ async def update_channel(channel_id):
|
|||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
if is_guild:
|
||||
await app.dispatcher.dispatch(
|
||||
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
||||
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||
else:
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_UPDATE', chan)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/typing', methods=['POST'])
|
||||
@bp.route("/<int:channel_id>/typing", methods=["POST"])
|
||||
async def trigger_typing(channel_id):
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
await app.dispatcher.dispatch('channel', channel_id, 'TYPING_START', {
|
||||
'channel_id': str(channel_id),
|
||||
'user_id': str(user_id),
|
||||
'timestamp': int(time.time()),
|
||||
await app.dispatcher.dispatch(
|
||||
"channel",
|
||||
channel_id,
|
||||
"TYPING_START",
|
||||
{
|
||||
"channel_id": str(channel_id),
|
||||
"user_id": str(user_id),
|
||||
"timestamp": int(time.time()),
|
||||
# guild_id for lazy guilds
|
||||
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||
},
|
||||
)
|
||||
|
||||
# guild_id for lazy guilds
|
||||
'guild_id': str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||
})
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||
|
|
@ -521,7 +601,8 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
|||
if not message_id:
|
||||
message_id = await app.storage.chan_last_message(channel_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO user_read_state
|
||||
(user_id, channel_id, last_message_id, mention_count)
|
||||
VALUES
|
||||
|
|
@ -532,26 +613,31 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
|||
SET last_message_id = $3, mention_count = 0
|
||||
WHERE user_read_state.user_id = $1
|
||||
AND user_read_state.channel_id = $2
|
||||
""", user_id, channel_id, message_id)
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
message_id,
|
||||
)
|
||||
|
||||
if guild_id:
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
user_id, guild_id, 'MESSAGE_ACK', {
|
||||
'message_id': str(message_id),
|
||||
'channel_id': str(channel_id)
|
||||
})
|
||||
user_id,
|
||||
guild_id,
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
)
|
||||
else:
|
||||
# we don't use ChannelDispatcher here because since
|
||||
# guild_id is None, all user devices are already subscribed
|
||||
# to the given channel (a dm or a group dm)
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, 'MESSAGE_ACK', {
|
||||
'message_id': str(message_id),
|
||||
'channel_id': str(channel_id)
|
||||
})
|
||||
user_id,
|
||||
"MESSAGE_ACK",
|
||||
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
||||
@bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
|
||||
async def ack_channel(channel_id, message_id):
|
||||
"""Acknowledge a channel."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -562,40 +648,47 @@ async def ack_channel(channel_id, message_id):
|
|||
|
||||
await channel_ack(user_id, guild_id, channel_id, message_id)
|
||||
|
||||
return jsonify({
|
||||
# token seems to be used for
|
||||
# data collection activities,
|
||||
# so we never use it.
|
||||
'token': None
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
# token seems to be used for
|
||||
# data collection activities,
|
||||
# so we never use it.
|
||||
"token": None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
|
||||
@bp.route("/<int:channel_id>/messages/ack", methods=["DELETE"])
|
||||
async def delete_read_state(channel_id):
|
||||
"""Delete the read state of a channel."""
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM user_read_state
|
||||
WHERE user_id = $1 AND channel_id = $2
|
||||
""", user_id, channel_id)
|
||||
""",
|
||||
user_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/search', methods=['GET'])
|
||||
@bp.route("/<int:channel_id>/messages/search", methods=["GET"])
|
||||
async def _search_channel(channel_id):
|
||||
"""Search in DMs or group DMs"""
|
||||
user_id = await token_check()
|
||||
await channel_check(user_id, channel_id)
|
||||
await channel_perm_check(user_id, channel_id, 'read_messages')
|
||||
await channel_perm_check(user_id, channel_id, "read_messages")
|
||||
|
||||
j = validate(dict(request.args), SEARCH_CHANNEL)
|
||||
|
||||
# main search query
|
||||
# the context (before/after) columns are copied from the guilds blueprint.
|
||||
rows = await app.db.fetch(f"""
|
||||
rows = await app.db.fetch(
|
||||
f"""
|
||||
SELECT orig.id AS current_id,
|
||||
COUNT(*) OVER() AS total_results,
|
||||
array((SELECT messages.id AS before_id
|
||||
|
|
@ -611,28 +704,40 @@ async def _search_channel(channel_id):
|
|||
ORDER BY orig.id DESC
|
||||
LIMIT 50
|
||||
OFFSET $2
|
||||
""", channel_id, j['offset'], j['content'])
|
||||
""",
|
||||
channel_id,
|
||||
j["offset"],
|
||||
j["content"],
|
||||
)
|
||||
|
||||
return jsonify(await search_result_from_list(rows))
|
||||
|
||||
|
||||
# NOTE that those functions stay here until some other
|
||||
# route or code wants it.
|
||||
|
||||
|
||||
async def _msg_update_flags(message_id: int, flags: int):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE messages
|
||||
SET flags = $1
|
||||
WHERE id = $2
|
||||
""", flags, message_id)
|
||||
""",
|
||||
flags,
|
||||
message_id,
|
||||
)
|
||||
|
||||
|
||||
async def _msg_get_flags(message_id: int):
|
||||
return await app.db.fetchval("""
|
||||
return await app.db.fetchval(
|
||||
"""
|
||||
SELECT flags
|
||||
FROM messages
|
||||
WHERE id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
|
||||
async def _msg_set_flags(message_id: int, new_flags: int):
|
||||
|
|
@ -647,8 +752,9 @@ async def _msg_unset_flags(message_id: int, unset_flags: int):
|
|||
await _msg_update_flags(message_id, flags)
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/suppress-embeds',
|
||||
methods=['POST'])
|
||||
@bp.route(
|
||||
"/<int:channel_id>/messages/<int:message_id>/suppress-embeds", methods=["POST"]
|
||||
)
|
||||
async def suppress_embeds(channel_id: int, message_id: int):
|
||||
"""Toggle the embeds in a message.
|
||||
|
||||
|
|
@ -661,29 +767,27 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
|||
# the checks here have been copied from the delete_message()
|
||||
# handler on blueprints.channel.messages. maybe we can combine
|
||||
# them someday?
|
||||
author_id = await app.db.fetchval("""
|
||||
author_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT author_id FROM messages
|
||||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
""",
|
||||
message_id,
|
||||
)
|
||||
|
||||
by_perms = await channel_perm_check(
|
||||
user_id, channel_id, 'manage_messages', False)
|
||||
by_perms = await channel_perm_check(user_id, channel_id, "manage_messages", False)
|
||||
|
||||
by_author = author_id == user_id
|
||||
|
||||
can_suppress = by_perms or by_author
|
||||
if not can_suppress:
|
||||
raise Forbidden('Not enough permissions.')
|
||||
raise Forbidden("Not enough permissions.")
|
||||
|
||||
j = validate(
|
||||
await request.get_json(),
|
||||
{'suppress': {'type': 'boolean'}},
|
||||
)
|
||||
j = validate(await request.get_json(), {"suppress": {"type": "boolean"}})
|
||||
|
||||
suppress = j['suppress']
|
||||
suppress = j["suppress"]
|
||||
message = await app.storage.get_message(message_id)
|
||||
url_embeds = sum(
|
||||
1 for embed in message['embeds'] if embed['type'] == 'url')
|
||||
url_embeds = sum(1 for embed in message["embeds"] if embed["type"] == "url")
|
||||
|
||||
# NOTE for any future self. discord doing flags an optional thing instead
|
||||
# of just giving 0 is a pretty bad idea because now i have to deal with
|
||||
|
|
@ -693,8 +797,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
|||
# delete all embeds then dispatch an update
|
||||
await _msg_set_flags(message_id, MessageFlags.suppress_embeds)
|
||||
|
||||
message['flags'] = \
|
||||
message.get('flags', 0) | MessageFlags.suppress_embeds
|
||||
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
|
||||
|
||||
await msg_update_embeds(message, [], app.storage, app.dispatcher)
|
||||
elif not suppress and not url_embeds:
|
||||
|
|
@ -702,30 +805,29 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
|||
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
|
||||
|
||||
try:
|
||||
message.pop('flags')
|
||||
message.pop("flags")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
app.sched.spawn(
|
||||
process_url_embed(
|
||||
app.config, app.storage, app.dispatcher, app.session,
|
||||
message
|
||||
app.config, app.storage, app.dispatcher, app.session, message
|
||||
)
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:channel_id>/messages/bulk-delete', methods=['POST'])
|
||||
@bp.route("/<int:channel_id>/messages/bulk-delete", methods=["POST"])
|
||||
async def bulk_delete(channel_id: int):
|
||||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
guild_id = guild_id if ctype in GUILD_CHANS else None
|
||||
|
||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
||||
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||
|
||||
j = validate(await request.get_json(), BULK_DELETE)
|
||||
message_ids = set(j['messages'])
|
||||
message_ids = set(j["messages"])
|
||||
|
||||
# as per discord behavior, if any id here is older than two weeks,
|
||||
# we must error. a cuter behavior would be returning the message ids
|
||||
|
|
@ -738,25 +840,28 @@ async def bulk_delete(channel_id: int):
|
|||
raise BadRequest(50034)
|
||||
|
||||
payload = {
|
||||
'guild_id': str(guild_id),
|
||||
'channel_id': str(channel_id),
|
||||
'ids': list(map(str, message_ids)),
|
||||
"guild_id": str(guild_id),
|
||||
"channel_id": str(channel_id),
|
||||
"ids": list(map(str, message_ids)),
|
||||
}
|
||||
|
||||
# payload.guild_id is optional in the event, not nullable.
|
||||
if guild_id is None:
|
||||
payload.pop('guild_id')
|
||||
payload.pop("guild_id")
|
||||
|
||||
res = await app.db.execute("""
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
DELETE FROM messages
|
||||
WHERE
|
||||
channel_id = $1
|
||||
AND ARRAY[id] <@ $2::bigint[]
|
||||
""", channel_id, list(message_ids))
|
||||
""",
|
||||
channel_id,
|
||||
list(message_ids),
|
||||
)
|
||||
|
||||
if res == 'DELETE 0':
|
||||
raise BadRequest('No messages were removed')
|
||||
if res == "DELETE 0":
|
||||
raise BadRequest("No messages were removed")
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_DELETE_BULK', payload)
|
||||
return '', 204
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -23,46 +23,57 @@ from quart import current_app as app
|
|||
|
||||
from litecord.enums import ChannelType, GUILD_CHANS
|
||||
from litecord.errors import (
|
||||
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
|
||||
GuildNotFound,
|
||||
ChannelNotFound,
|
||||
Forbidden,
|
||||
MissingPermissions,
|
||||
)
|
||||
from litecord.permissions import base_permissions, get_permissions
|
||||
|
||||
|
||||
async def guild_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is in a guild."""
|
||||
joined_at = await app.db.fetchval("""
|
||||
joined_at = await app.db.fetchval(
|
||||
"""
|
||||
SELECT joined_at
|
||||
FROM members
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if not joined_at:
|
||||
raise GuildNotFound('guild not found')
|
||||
raise GuildNotFound("guild not found")
|
||||
|
||||
|
||||
async def guild_owner_check(user_id: int, guild_id: int):
|
||||
"""Check if a user is the owner of the guild."""
|
||||
owner_id = await app.db.fetchval("""
|
||||
owner_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT owner_id
|
||||
FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if not owner_id:
|
||||
raise GuildNotFound()
|
||||
|
||||
if user_id != owner_id:
|
||||
raise Forbidden('You are not the owner of the guild')
|
||||
raise Forbidden("You are not the owner of the guild")
|
||||
|
||||
|
||||
async def channel_check(user_id, channel_id, *,
|
||||
only: Union[ChannelType, List[ChannelType]] = None):
|
||||
async def channel_check(
|
||||
user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
|
||||
):
|
||||
"""Check if the current user is authorized
|
||||
to read the channel's information."""
|
||||
chan_type = await app.storage.get_chan_type(channel_id)
|
||||
|
||||
if chan_type is None:
|
||||
raise ChannelNotFound('channel type not found')
|
||||
raise ChannelNotFound("channel type not found")
|
||||
|
||||
ctype = ChannelType(chan_type)
|
||||
|
||||
|
|
@ -70,14 +81,17 @@ async def channel_check(user_id, channel_id, *,
|
|||
only = [only]
|
||||
|
||||
if only and ctype not in only:
|
||||
raise ChannelNotFound('invalid channel type')
|
||||
raise ChannelNotFound("invalid channel type")
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
guild_id = await app.db.fetchval("""
|
||||
guild_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT guild_id
|
||||
FROM guild_channels
|
||||
WHERE guild_channels.id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
return ctype, guild_id
|
||||
|
|
@ -87,11 +101,14 @@ async def channel_check(user_id, channel_id, *,
|
|||
return ctype, peer_id
|
||||
|
||||
if ctype == ChannelType.GROUP_DM:
|
||||
owner_id = await app.db.fetchval("""
|
||||
owner_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT owner_id
|
||||
FROM group_dm_channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
return ctype, owner_id
|
||||
|
||||
|
|
@ -102,18 +119,17 @@ async def guild_perm_check(user_id, guild_id, permission: str):
|
|||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
if not hasperm:
|
||||
raise MissingPermissions('Missing permissions.')
|
||||
raise MissingPermissions("Missing permissions.")
|
||||
|
||||
return bool(hasperm)
|
||||
|
||||
|
||||
async def channel_perm_check(user_id, channel_id,
|
||||
permission: str, raise_err=True):
|
||||
async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
|
||||
"""Check channel permissions for a user."""
|
||||
base_perms = await get_permissions(user_id, channel_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
if not hasperm and raise_err:
|
||||
raise MissingPermissions('Missing permissions.')
|
||||
raise MissingPermissions("Missing permissions.")
|
||||
|
||||
return bool(hasperm)
|
||||
|
|
|
|||
|
|
@ -29,21 +29,29 @@ from litecord.system_messages import send_sys_message
|
|||
from litecord.pubsub.channel import gdm_recipient_view
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('dm_channels', __name__)
|
||||
bp = Blueprint("dm_channels", __name__)
|
||||
|
||||
|
||||
async def _raw_gdm_add(channel_id, user_id):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO group_dm_members (id, member_id)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, user_id)
|
||||
""",
|
||||
channel_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
async def _raw_gdm_remove(channel_id, user_id):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM group_dm_members
|
||||
WHERE id = $1 AND member_id = $2
|
||||
""", channel_id, user_id)
|
||||
""",
|
||||
channel_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
async def gdm_create(user_id, peer_id) -> int:
|
||||
|
|
@ -53,24 +61,32 @@ async def gdm_create(user_id, peer_id) -> int:
|
|||
"""
|
||||
channel_id = get_snowflake()
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, ChannelType.GROUP_DM.value)
|
||||
""",
|
||||
channel_id,
|
||||
ChannelType.GROUP_DM.value,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO group_dm_channels (id, owner_id, name, icon)
|
||||
VALUES ($1, $2, NULL, NULL)
|
||||
""", channel_id, user_id)
|
||||
""",
|
||||
channel_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
await _raw_gdm_add(channel_id, user_id)
|
||||
await _raw_gdm_add(channel_id, peer_id)
|
||||
|
||||
await app.dispatcher.sub('channel', channel_id, user_id)
|
||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||
await app.dispatcher.sub("channel", channel_id, user_id)
|
||||
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
|
||||
|
||||
return channel_id
|
||||
|
||||
|
|
@ -89,17 +105,16 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
|
||||
# the reasoning behind gdm_recipient_view is in its docstring.
|
||||
await app.dispatcher.dispatch(
|
||||
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
|
||||
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_UPDATE', chan)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
await app.dispatcher.sub('channel', peer_id)
|
||||
await app.dispatcher.sub("channel", peer_id)
|
||||
|
||||
if user_id:
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.RECIPIENT_ADD,
|
||||
user_id, peer_id
|
||||
app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -116,22 +131,22 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
|||
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch(
|
||||
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
|
||||
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
|
||||
)
|
||||
|
||||
await app.dispatcher.unsub('channel', peer_id)
|
||||
await app.dispatcher.unsub("channel", peer_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
|
||||
'channel_id': str(channel_id),
|
||||
'user': await app.storage.get_user(peer_id)
|
||||
}
|
||||
"channel",
|
||||
channel_id,
|
||||
"CHANNEL_RECIPIENT_REMOVE",
|
||||
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
|
||||
)
|
||||
|
||||
author_id = peer_id if user_id is None else user_id
|
||||
|
||||
await send_sys_message(
|
||||
app, channel_id, MessageType.RECIPIENT_REMOVE,
|
||||
author_id, peer_id
|
||||
app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -139,40 +154,51 @@ async def gdm_destroy(channel_id):
|
|||
"""Destroy a Group DM."""
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM group_dm_members
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM group_dm_channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
|
||||
await app.db.execute("""
|
||||
DELETE FROM channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'CHANNEL_DELETE', chan
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.remove('channel', channel_id)
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM group_dm_channels
|
||||
WHERE id = $1
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM channels
|
||||
WHERE id = $1
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
|
||||
|
||||
await app.dispatcher.remove("channel", channel_id)
|
||||
|
||||
|
||||
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
||||
"""Return if the given user is a member of the Group DM."""
|
||||
row = await app.db.fetchval("""
|
||||
row = await app.db.fetchval(
|
||||
"""
|
||||
SELECT id
|
||||
FROM group_dm_members
|
||||
WHERE id = $1 AND member_id = $2
|
||||
""", channel_id, user_id)
|
||||
""",
|
||||
channel_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return row is not None
|
||||
|
||||
|
||||
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT'])
|
||||
@bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["PUT"])
|
||||
async def add_to_group_dm(dm_chan, peer_id):
|
||||
"""Adds a member to a group dm OR creates a group dm."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -182,8 +208,7 @@ async def add_to_group_dm(dm_chan, peer_id):
|
|||
|
||||
# other_id is the peer of the dm if the given channel is a dm
|
||||
ctype, other_id = await channel_check(
|
||||
user_id, dm_chan,
|
||||
only=[ChannelType.DM, ChannelType.GROUP_DM]
|
||||
user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM]
|
||||
)
|
||||
|
||||
# check relationship with the given user id
|
||||
|
|
@ -191,30 +216,24 @@ async def add_to_group_dm(dm_chan, peer_id):
|
|||
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
||||
|
||||
if not friends:
|
||||
raise BadRequest('Cant insert peer into dm')
|
||||
raise BadRequest("Cant insert peer into dm")
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
dm_chan = await gdm_create(
|
||||
user_id, other_id
|
||||
)
|
||||
dm_chan = await gdm_create(user_id, other_id)
|
||||
|
||||
await gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_channel(dm_chan)
|
||||
)
|
||||
return jsonify(await app.storage.get_channel(dm_chan))
|
||||
|
||||
|
||||
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["DELETE"])
|
||||
async def remove_from_group_dm(dm_chan, peer_id):
|
||||
"""Remove users from group dm."""
|
||||
user_id = await token_check()
|
||||
_ctype, owner_id = await channel_check(
|
||||
user_id, dm_chan, only=ChannelType.GROUP_DM
|
||||
)
|
||||
_ctype, owner_id = await channel_check(user_id, dm_chan, only=ChannelType.GROUP_DM)
|
||||
|
||||
if owner_id != user_id:
|
||||
raise Forbidden('You are now the owner of the group DM')
|
||||
raise Forbidden("You are now the owner of the group DM")
|
||||
|
||||
await gdm_remove_recipient(dm_chan, peer_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -30,15 +30,13 @@ from ..snowflake import get_snowflake
|
|||
|
||||
from .auth import token_check
|
||||
|
||||
from litecord.blueprints.dm_channels import (
|
||||
gdm_create, gdm_add_recipient
|
||||
)
|
||||
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('dms', __name__)
|
||||
bp = Blueprint("dms", __name__)
|
||||
|
||||
|
||||
@bp.route('/@me/channels', methods=['GET'])
|
||||
@bp.route("/@me/channels", methods=["GET"])
|
||||
async def get_dms():
|
||||
"""Get the open DMs for the user."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -53,11 +51,15 @@ async def try_dm_state(user_id: int, dm_id: int):
|
|||
Does not do anything if the user is already
|
||||
in the dm state.
|
||||
"""
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO dm_channel_state (user_id, dm_id)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT DO NOTHING
|
||||
""", user_id, dm_id)
|
||||
""",
|
||||
user_id,
|
||||
dm_id,
|
||||
)
|
||||
|
||||
|
||||
async def jsonify_dm(dm_id: int, user_id: int):
|
||||
|
|
@ -69,12 +71,16 @@ async def create_dm(user_id, recipient_id):
|
|||
"""Create a new dm with a user,
|
||||
or get the existing DM id if it already exists."""
|
||||
|
||||
dm_id = await app.db.fetchval("""
|
||||
dm_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT id
|
||||
FROM dm_channels
|
||||
WHERE (party1_id = $1 OR party2_id = $1) AND
|
||||
(party1_id = $2 OR party2_id = $2)
|
||||
""", user_id, recipient_id)
|
||||
""",
|
||||
user_id,
|
||||
recipient_id,
|
||||
)
|
||||
|
||||
if dm_id:
|
||||
return await jsonify_dm(dm_id, user_id)
|
||||
|
|
@ -82,15 +88,24 @@ async def create_dm(user_id, recipient_id):
|
|||
# if no dm was found, create a new one
|
||||
|
||||
dm_id = get_snowflake()
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", dm_id, ChannelType.DM.value)
|
||||
""",
|
||||
dm_id,
|
||||
ChannelType.DM.value,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", dm_id, user_id, recipient_id)
|
||||
""",
|
||||
dm_id,
|
||||
user_id,
|
||||
recipient_id,
|
||||
)
|
||||
|
||||
# the dm state is something we use
|
||||
# to give the currently "open dms"
|
||||
|
|
@ -103,24 +118,24 @@ async def create_dm(user_id, recipient_id):
|
|||
return await jsonify_dm(dm_id, user_id)
|
||||
|
||||
|
||||
@bp.route('/@me/channels', methods=['POST'])
|
||||
@bp.route("/@me/channels", methods=["POST"])
|
||||
async def start_dm():
|
||||
"""Create a DM with a user."""
|
||||
user_id = await token_check()
|
||||
j = validate(await request.get_json(), CREATE_DM)
|
||||
recipient_id = j['recipient_id']
|
||||
recipient_id = j["recipient_id"]
|
||||
|
||||
return await create_dm(user_id, recipient_id)
|
||||
|
||||
|
||||
@bp.route('/<int:p_user_id>/channels', methods=['POST'])
|
||||
@bp.route("/<int:p_user_id>/channels", methods=["POST"])
|
||||
async def create_group_dm(p_user_id: int):
|
||||
"""Create a DM or a Group DM with user(s)."""
|
||||
user_id = await token_check()
|
||||
assert user_id == p_user_id
|
||||
|
||||
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
||||
recipients = j['recipients']
|
||||
recipients = j["recipients"]
|
||||
|
||||
if len(recipients) == 1:
|
||||
# its a group dm with 1 user... a dm!
|
||||
|
|
|
|||
|
|
@ -23,37 +23,38 @@ from quart import Blueprint, jsonify, current_app as app
|
|||
|
||||
from ..auth import token_check
|
||||
|
||||
bp = Blueprint('gateway', __name__)
|
||||
bp = Blueprint("gateway", __name__)
|
||||
|
||||
|
||||
def get_gw():
|
||||
"""Get the gateway's web"""
|
||||
proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
|
||||
proto = "wss://" if app.config["IS_SSL"] else "ws://"
|
||||
return f'{proto}{app.config["WEBSOCKET_URL"]}'
|
||||
|
||||
|
||||
@bp.route('/gateway')
|
||||
@bp.route("/gateway")
|
||||
def api_gateway():
|
||||
"""Get the raw URL."""
|
||||
return jsonify({
|
||||
'url': get_gw()
|
||||
})
|
||||
return jsonify({"url": get_gw()})
|
||||
|
||||
|
||||
@bp.route('/gateway/bot')
|
||||
@bp.route("/gateway/bot")
|
||||
async def api_gateway_bot():
|
||||
user_id = await token_check()
|
||||
|
||||
guild_count = await app.db.fetchval("""
|
||||
guild_count = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
shards = max(int(guild_count / 1000), 1)
|
||||
|
||||
# get _ws.session ratelimit
|
||||
ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
|
||||
ratelimit = app.ratelimiter.get_ratelimit("_ws.session")
|
||||
bucket = ratelimit.get_bucket(user_id)
|
||||
|
||||
# timestamp of bucket reset
|
||||
|
|
@ -62,13 +63,14 @@ async def api_gateway_bot():
|
|||
# how many seconds until bucket reset
|
||||
reset_after_ts = reset_ts - time.time()
|
||||
|
||||
return jsonify({
|
||||
'url': get_gw(),
|
||||
'shards': shards,
|
||||
|
||||
'session_start_limit': {
|
||||
'total': bucket.requests,
|
||||
'remaining': bucket._tokens,
|
||||
'reset_after': int(reset_after_ts * 1000),
|
||||
return jsonify(
|
||||
{
|
||||
"url": get_gw(),
|
||||
"shards": shards,
|
||||
"session_start_limit": {
|
||||
"total": bucket.requests,
|
||||
"remaining": bucket._tokens,
|
||||
"reset_after": int(reset_after_ts * 1000),
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,5 +23,4 @@ from .channels import bp as guild_channels
|
|||
from .mod import bp as guild_mod
|
||||
from .emoji import bp as guild_emoji
|
||||
|
||||
__all__ = ['guild_roles', 'guild_members', 'guild_channels', 'guild_mod',
|
||||
'guild_emoji']
|
||||
__all__ = ["guild_roles", "guild_members", "guild_channels", "guild_mod", "guild_emoji"]
|
||||
|
|
|
|||
|
|
@ -25,23 +25,23 @@ from litecord.errors import BadRequest
|
|||
from litecord.enums import ChannelType
|
||||
from litecord.blueprints.guild.roles import gen_pairs
|
||||
|
||||
from litecord.schemas import (
|
||||
validate, ROLE_UPDATE_POSITION, CHAN_CREATE
|
||||
)
|
||||
from litecord.blueprints.checks import (
|
||||
guild_check, guild_owner_check, guild_perm_check
|
||||
)
|
||||
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
|
||||
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
|
||||
|
||||
|
||||
bp = Blueprint('guild_channels', __name__)
|
||||
bp = Blueprint("guild_channels", __name__)
|
||||
|
||||
|
||||
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||
if ctype == ChannelType.GUILD_TEXT:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_text_channels (id, topic)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, kwargs.get('topic', ''))
|
||||
""",
|
||||
channel_id,
|
||||
kwargs.get("topic", ""),
|
||||
)
|
||||
elif ctype == ChannelType.GUILD_VOICE:
|
||||
await app.db.execute(
|
||||
"""
|
||||
|
|
@ -49,34 +49,48 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
|
|||
VALUES ($1, $2, $3)
|
||||
""",
|
||||
channel_id,
|
||||
kwargs.get('bitrate', 64),
|
||||
kwargs.get('user_limit', 0)
|
||||
kwargs.get("bitrate", 64),
|
||||
kwargs.get("user_limit", 0),
|
||||
)
|
||||
|
||||
|
||||
async def create_guild_channel(guild_id: int, channel_id: int,
|
||||
ctype: ChannelType, **kwargs):
|
||||
async def create_guild_channel(
|
||||
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||
):
|
||||
"""Create a channel in a guild."""
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", channel_id, ctype.value)
|
||||
""",
|
||||
channel_id,
|
||||
ctype.value,
|
||||
)
|
||||
|
||||
# calc new pos
|
||||
max_pos = await app.db.fetchval("""
|
||||
max_pos = await app.db.fetchval(
|
||||
"""
|
||||
SELECT MAX(position)
|
||||
FROM guild_channels
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# account for the first channel in a guild too
|
||||
max_pos = max_pos or 0
|
||||
|
||||
# all channels go to guild_channels
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", channel_id, guild_id, kwargs['name'], max_pos + 1)
|
||||
""",
|
||||
channel_id,
|
||||
guild_id,
|
||||
kwargs["name"],
|
||||
max_pos + 1,
|
||||
)
|
||||
|
||||
# the rest of sql magic is dependant on the channel
|
||||
# we're creating (a text or voice or category),
|
||||
|
|
@ -84,35 +98,32 @@ async def create_guild_channel(guild_id: int, channel_id: int,
|
|||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/channels", methods=["GET"])
|
||||
async def get_guild_channels(guild_id):
|
||||
"""Get the list of channels in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_channel_data(guild_id))
|
||||
return jsonify(await app.storage.get_channel_data(guild_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['POST'])
|
||||
@bp.route("/<int:guild_id>/channels", methods=["POST"])
|
||||
async def create_channel(guild_id):
|
||||
"""Create a channel in a guild."""
|
||||
user_id = await token_check()
|
||||
j = validate(await request.get_json(), CHAN_CREATE)
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
||||
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||
|
||||
channel_type = j.get('type', ChannelType.GUILD_TEXT)
|
||||
channel_type = j.get("type", ChannelType.GUILD_TEXT)
|
||||
channel_type = ChannelType(channel_type)
|
||||
|
||||
if channel_type not in (ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE):
|
||||
raise BadRequest('Invalid channel type')
|
||||
if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE):
|
||||
raise BadRequest("Invalid channel type")
|
||||
|
||||
new_channel_id = get_snowflake()
|
||||
await create_guild_channel(
|
||||
guild_id, new_channel_id, channel_type, **j)
|
||||
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
|
||||
|
||||
# TODO: do a better method
|
||||
# subscribe the currently subscribed users to the new channel
|
||||
|
|
@ -120,14 +131,13 @@ async def create_channel(guild_id):
|
|||
|
||||
# since GuildDispatcher calls Storage.get_channel_ids,
|
||||
# it will subscribe all users to the newly created channel.
|
||||
guild_pubsub = app.dispatcher.backends['guild']
|
||||
guild_pubsub = app.dispatcher.backends["guild"]
|
||||
user_ids = guild_pubsub.state[guild_id]
|
||||
for uid in user_ids:
|
||||
await app.dispatcher.sub('guild', guild_id, uid)
|
||||
await app.dispatcher.sub("guild", guild_id, uid)
|
||||
|
||||
chan = await app.storage.get_channel(new_channel_id)
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'CHANNEL_CREATE', chan)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
|
||||
return jsonify(chan)
|
||||
|
||||
|
||||
|
|
@ -135,7 +145,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
|||
"""Fetch new information about the channel and dispatch
|
||||
a single CHANNEL_UPDATE event to the guild."""
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
|
||||
|
||||
|
||||
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||
|
|
@ -149,13 +159,14 @@ async def _do_single_swap(guild_id: int, pair: tuple):
|
|||
conn = await app.db.acquire()
|
||||
|
||||
async with conn.transaction():
|
||||
await conn.executemany("""
|
||||
await conn.executemany(
|
||||
"""
|
||||
UPDATE guild_channels
|
||||
SET position = $1
|
||||
WHERE id = $2 AND guild_id = $3
|
||||
""", [
|
||||
(new_pos_1, channel_1, guild_id),
|
||||
(new_pos_2, channel_2, guild_id)])
|
||||
""",
|
||||
[(new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)],
|
||||
)
|
||||
|
||||
await app.db.release(conn)
|
||||
|
||||
|
|
@ -173,30 +184,26 @@ async def _do_channel_swaps(guild_id: int, swap_pairs: list):
|
|||
await _do_single_swap(guild_id, pair)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/channels", methods=["PATCH"])
|
||||
async def modify_channel_pos(guild_id):
|
||||
"""Change positions of channels in a guild."""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
||||
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||
|
||||
# same thing as guild.roles, so we use
|
||||
# the same schema and all.
|
||||
raw_j = await request.get_json()
|
||||
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
||||
j = j['roles']
|
||||
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
|
||||
j = j["roles"]
|
||||
|
||||
channels = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
channel_positions = {chan['position']: int(chan['id'])
|
||||
for chan in channels}
|
||||
channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
|
||||
|
||||
swap_pairs = gen_pairs(
|
||||
j,
|
||||
channel_positions
|
||||
)
|
||||
swap_pairs = gen_pairs(j, channel_positions)
|
||||
|
||||
await _do_channel_swaps(guild_id, swap_pairs)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -27,78 +27,85 @@ from litecord.types import KILOBYTES
|
|||
from litecord.images import parse_data_uri
|
||||
from litecord.errors import BadRequest
|
||||
|
||||
bp = Blueprint('guild.emoji', __name__)
|
||||
bp = Blueprint("guild.emoji", __name__)
|
||||
|
||||
|
||||
async def _dispatch_emojis(guild_id):
|
||||
"""Dispatch a Guild Emojis Update payload to a guild."""
|
||||
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_EMOJIS_UPDATE', {
|
||||
'guild_id': str(guild_id),
|
||||
'emojis': await app.storage.get_guild_emojis(guild_id)
|
||||
})
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
guild_id,
|
||||
"GUILD_EMOJIS_UPDATE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"emojis": await app.storage.get_guild_emojis(guild_id),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/emojis', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/emojis", methods=["GET"])
|
||||
async def _get_guild_emoji(guild_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
return jsonify(
|
||||
await app.storage.get_guild_emojis(guild_id)
|
||||
)
|
||||
return jsonify(await app.storage.get_guild_emojis(guild_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["GET"])
|
||||
async def _get_guild_emoji_one(guild_id, emoji_id):
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
return jsonify(
|
||||
await app.storage.get_emoji(emoji_id)
|
||||
)
|
||||
return jsonify(await app.storage.get_emoji(emoji_id))
|
||||
|
||||
|
||||
async def _guild_emoji_size_check(guild_id: int, mime: str):
|
||||
limit = 50
|
||||
if await app.storage.has_feature(guild_id, 'MORE_EMOJI'):
|
||||
if await app.storage.has_feature(guild_id, "MORE_EMOJI"):
|
||||
limit = 200
|
||||
|
||||
# NOTE: I'm assuming you can have 200 animated emojis.
|
||||
select_animated = mime == 'image/gif'
|
||||
select_animated = mime == "image/gif"
|
||||
|
||||
total_emoji = await app.db.fetchval("""
|
||||
total_emoji = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*) FROM guild_emoji
|
||||
WHERE guild_id = $1 AND animated = $2
|
||||
""", guild_id, select_animated)
|
||||
""",
|
||||
guild_id,
|
||||
select_animated,
|
||||
)
|
||||
|
||||
if total_emoji >= limit:
|
||||
# TODO: really return a BadRequest? needs more looking.
|
||||
raise BadRequest(f'too many emoji ({limit})')
|
||||
raise BadRequest(f"too many emoji ({limit})")
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/emojis', methods=['POST'])
|
||||
@bp.route("/<int:guild_id>/emojis", methods=["POST"])
|
||||
async def _put_emoji(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_emojis')
|
||||
await guild_perm_check(user_id, guild_id, "manage_emojis")
|
||||
|
||||
j = validate(await request.get_json(), NEW_EMOJI)
|
||||
|
||||
# we have to parse it before passing on so that we know which
|
||||
# size to check.
|
||||
mime, _ = parse_data_uri(j['image'])
|
||||
mime, _ = parse_data_uri(j["image"])
|
||||
await _guild_emoji_size_check(guild_id, mime)
|
||||
|
||||
emoji_id = get_snowflake()
|
||||
|
||||
icon = await app.icons.put(
|
||||
'emoji', emoji_id, j['image'],
|
||||
|
||||
"emoji",
|
||||
emoji_id,
|
||||
j["image"],
|
||||
# limits to emojis
|
||||
bsize=128 * KILOBYTES, size=(128, 128)
|
||||
bsize=128 * KILOBYTES,
|
||||
size=(128, 128),
|
||||
)
|
||||
|
||||
if not icon:
|
||||
return '', 400
|
||||
return "", 400
|
||||
|
||||
# TODO: better way to detect animated emoji rather than just gifs,
|
||||
# maybe a list perhaps?
|
||||
|
|
@ -109,25 +116,25 @@ async def _put_emoji(guild_id):
|
|||
VALUES
|
||||
($1, $2, $3, $4, $5, $6)
|
||||
""",
|
||||
emoji_id, guild_id, user_id,
|
||||
j['name'],
|
||||
emoji_id,
|
||||
guild_id,
|
||||
user_id,
|
||||
j["name"],
|
||||
icon.icon_hash,
|
||||
icon.mime == 'image/gif'
|
||||
icon.mime == "image/gif",
|
||||
)
|
||||
|
||||
await _dispatch_emojis(guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_emoji(emoji_id)
|
||||
)
|
||||
return jsonify(await app.storage.get_emoji(emoji_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["PATCH"])
|
||||
async def _patch_emoji(guild_id, emoji_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_emojis')
|
||||
await guild_perm_check(user_id, guild_id, "manage_emojis")
|
||||
|
||||
j = validate(await request.get_json(), PATCH_EMOJI)
|
||||
emoji = await app.storage.get_emoji(emoji_id)
|
||||
|
|
@ -135,34 +142,39 @@ async def _patch_emoji(guild_id, emoji_id):
|
|||
# if emoji.name is still the same, we don't update anything
|
||||
# or send ane events, just return the same emoji we'd send
|
||||
# as if we updated it.
|
||||
if j['name'] == emoji['name']:
|
||||
if j["name"] == emoji["name"]:
|
||||
return jsonify(emoji)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guild_emoji
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""", j['name'], emoji_id)
|
||||
""",
|
||||
j["name"],
|
||||
emoji_id,
|
||||
)
|
||||
|
||||
await _dispatch_emojis(guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_emoji(emoji_id)
|
||||
)
|
||||
return jsonify(await app.storage.get_emoji(emoji_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["DELETE"])
|
||||
async def _del_emoji(guild_id, emoji_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_emojis')
|
||||
await guild_perm_check(user_id, guild_id, "manage_emojis")
|
||||
|
||||
# TODO: check if actually deleted
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM guild_emoji
|
||||
WHERE id = $2
|
||||
""", emoji_id)
|
||||
""",
|
||||
emoji_id,
|
||||
)
|
||||
|
||||
await _dispatch_emojis(guild_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -22,18 +22,14 @@ from quart import Blueprint, request, current_app as app, jsonify
|
|||
from litecord.blueprints.auth import token_check
|
||||
from litecord.errors import BadRequest
|
||||
|
||||
from litecord.schemas import (
|
||||
validate, MEMBER_UPDATE
|
||||
)
|
||||
from litecord.schemas import validate, MEMBER_UPDATE
|
||||
|
||||
from litecord.blueprints.checks import (
|
||||
guild_check, guild_owner_check, guild_perm_check
|
||||
)
|
||||
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
|
||||
|
||||
bp = Blueprint('guild_members', __name__)
|
||||
bp = Blueprint("guild_members", __name__)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["GET"])
|
||||
async def get_guild_member(guild_id, member_id):
|
||||
"""Get a member's information in a guild."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -42,7 +38,7 @@ async def get_guild_member(guild_id, member_id):
|
|||
return jsonify(member)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/members", methods=["GET"])
|
||||
async def get_members(guild_id):
|
||||
"""Get members inside a guild."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -50,34 +46,41 @@ async def get_members(guild_id):
|
|||
|
||||
j = await request.get_json()
|
||||
|
||||
limit, after = int(j.get('limit', 1)), j.get('after', 0)
|
||||
limit, after = int(j.get("limit", 1)), j.get("after", 0)
|
||||
|
||||
if limit < 1 or limit > 1000:
|
||||
raise BadRequest('limit not in 1-1000 range')
|
||||
raise BadRequest("limit not in 1-1000 range")
|
||||
|
||||
user_ids = await app.db.fetch(f"""
|
||||
user_ids = await app.db.fetch(
|
||||
f"""
|
||||
SELECT user_id
|
||||
WHERE guild_id = $1, user_id > $2
|
||||
LIMIT {limit}
|
||||
ORDER BY user_id ASC
|
||||
""", guild_id, after)
|
||||
""",
|
||||
guild_id,
|
||||
after,
|
||||
)
|
||||
|
||||
user_ids = [r[0] for r in user_ids]
|
||||
members = await app.storage.get_member_multi(guild_id, user_ids)
|
||||
return jsonify(members)
|
||||
|
||||
|
||||
async def _update_member_roles(guild_id: int, member_id: int,
|
||||
wanted_roles: set):
|
||||
async def _update_member_roles(guild_id: int, member_id: int, wanted_roles: set):
|
||||
"""Update the roles a member has."""
|
||||
|
||||
# first, fetch all current roles
|
||||
roles = await app.db.fetch("""
|
||||
roles = await app.db.fetch(
|
||||
"""
|
||||
SELECT role_id from member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
""",
|
||||
guild_id,
|
||||
member_id,
|
||||
)
|
||||
|
||||
roles = [r['role_id'] for r in roles]
|
||||
roles = [r["role_id"] for r in roles]
|
||||
|
||||
roles = set(roles)
|
||||
wanted_roles = set(wanted_roles)
|
||||
|
|
@ -96,26 +99,30 @@ async def _update_member_roles(guild_id: int, member_id: int,
|
|||
|
||||
async with conn.transaction():
|
||||
# add roles
|
||||
await app.db.executemany("""
|
||||
await app.db.executemany(
|
||||
"""
|
||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", [(member_id, guild_id, role_id)
|
||||
for role_id in added_roles])
|
||||
""",
|
||||
[(member_id, guild_id, role_id) for role_id in added_roles],
|
||||
)
|
||||
|
||||
# remove roles
|
||||
await app.db.executemany("""
|
||||
await app.db.executemany(
|
||||
"""
|
||||
DELETE FROM member_roles
|
||||
WHERE
|
||||
user_id = $1
|
||||
AND guild_id = $2
|
||||
AND role_id = $3
|
||||
""", [(member_id, guild_id, role_id)
|
||||
for role_id in removed_roles])
|
||||
""",
|
||||
[(member_id, guild_id, role_id) for role_id in removed_roles],
|
||||
)
|
||||
|
||||
await app.db.release(conn)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["PATCH"])
|
||||
async def modify_guild_member(guild_id, member_id):
|
||||
"""Modify a members' information in a guild."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -124,96 +131,112 @@ async def modify_guild_member(guild_id, member_id):
|
|||
j = validate(await request.get_json(), MEMBER_UPDATE)
|
||||
nick_flag = False
|
||||
|
||||
if 'nick' in j:
|
||||
await guild_perm_check(user_id, guild_id, 'manage_nicknames')
|
||||
if "nick" in j:
|
||||
await guild_perm_check(user_id, guild_id, "manage_nicknames")
|
||||
|
||||
nick = j['nick'] or None
|
||||
nick = j["nick"] or None
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE members
|
||||
SET nickname = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", nick, member_id, guild_id)
|
||||
""",
|
||||
nick,
|
||||
member_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
nick_flag = True
|
||||
|
||||
if 'mute' in j:
|
||||
await guild_perm_check(user_id, guild_id, 'mute_members')
|
||||
if "mute" in j:
|
||||
await guild_perm_check(user_id, guild_id, "mute_members")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE members
|
||||
SET muted = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['mute'], member_id, guild_id)
|
||||
""",
|
||||
j["mute"],
|
||||
member_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if 'deaf' in j:
|
||||
await guild_perm_check(user_id, guild_id, 'deafen_members')
|
||||
if "deaf" in j:
|
||||
await guild_perm_check(user_id, guild_id, "deafen_members")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE members
|
||||
SET deafened = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j['deaf'], member_id, guild_id)
|
||||
""",
|
||||
j["deaf"],
|
||||
member_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if 'channel_id' in j:
|
||||
if "channel_id" in j:
|
||||
# TODO: check MOVE_MEMBERS and CONNECT to the channel
|
||||
# TODO: change the member's voice channel
|
||||
pass
|
||||
|
||||
if 'roles' in j:
|
||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
||||
await _update_member_roles(guild_id, member_id, j['roles'])
|
||||
if "roles" in j:
|
||||
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||
await _update_member_roles(guild_id, member_id, j["roles"])
|
||||
|
||||
member = await app.storage.get_member_data_one(guild_id, member_id)
|
||||
member.pop('joined_at')
|
||||
member.pop("joined_at")
|
||||
|
||||
# call pres_update for role and nick changes.
|
||||
partial = {
|
||||
'roles': member['roles']
|
||||
}
|
||||
partial = {"roles": member["roles"]}
|
||||
|
||||
if nick_flag:
|
||||
partial['nick'] = j['nick']
|
||||
partial["nick"] = j["nick"]
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'lazy_guild', guild_id, 'pres_update', user_id, partial)
|
||||
"lazy_guild", guild_id, "pres_update", user_id, partial
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||
'guild_id': str(guild_id)
|
||||
}, **member})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/members/@me/nick", methods=["PATCH"])
|
||||
async def update_nickname(guild_id):
|
||||
"""Update a member's nickname in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
j = validate(await request.get_json(), {
|
||||
'nick': {'type': 'nickname'}
|
||||
})
|
||||
j = validate(await request.get_json(), {"nick": {"type": "nickname"}})
|
||||
|
||||
nick = j['nick'] or None
|
||||
nick = j["nick"] or None
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE members
|
||||
SET nickname = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", nick, user_id, guild_id)
|
||||
""",
|
||||
nick,
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||
member.pop('joined_at')
|
||||
member.pop("joined_at")
|
||||
|
||||
# call pres_update for nick changes, etc.
|
||||
await app.dispatcher.dispatch(
|
||||
'lazy_guild', guild_id, 'pres_update', user_id, {
|
||||
'nick': j['nick']
|
||||
})
|
||||
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
||||
'guild_id': str(guild_id)
|
||||
}, **member})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||
)
|
||||
|
||||
return j['nick']
|
||||
return j["nick"]
|
||||
|
|
|
|||
|
|
@ -24,33 +24,38 @@ from litecord.blueprints.checks import guild_perm_check
|
|||
|
||||
from litecord.schemas import validate, GUILD_PRUNE
|
||||
|
||||
bp = Blueprint('guild_moderation', __name__)
|
||||
bp = Blueprint("guild_moderation", __name__)
|
||||
|
||||
|
||||
async def remove_member(guild_id: int, member_id: int):
|
||||
"""Do common tasks related to deleting a member from the guild,
|
||||
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
""",
|
||||
guild_id,
|
||||
member_id,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
member_id, guild_id, 'GUILD_DELETE', {
|
||||
'guild_id': str(guild_id),
|
||||
'unavailable': False,
|
||||
})
|
||||
member_id,
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{"guild_id": str(guild_id), "unavailable": False},
|
||||
)
|
||||
|
||||
await app.dispatcher.unsub('guild', guild_id, member_id)
|
||||
await app.dispatcher.unsub("guild", guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'lazy_guild', guild_id, 'remove_member', member_id)
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
||||
'guild_id': str(guild_id),
|
||||
'user': await app.storage.get_user(member_id),
|
||||
})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id,
|
||||
"GUILD_MEMBER_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
)
|
||||
|
||||
|
||||
async def remove_member_multi(guild_id: int, members: list):
|
||||
|
|
@ -59,84 +64,100 @@ async def remove_member_multi(guild_id: int, members: list):
|
|||
await remove_member(guild_id, member_id)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
|
||||
async def kick_guild_member(guild_id, member_id):
|
||||
"""Remove a member from a guild."""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'kick_members')
|
||||
await guild_perm_check(user_id, guild_id, "kick_members")
|
||||
await remove_member(guild_id, member_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/bans", methods=["GET"])
|
||||
async def get_bans(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'ban_members')
|
||||
await guild_perm_check(user_id, guild_id, "ban_members")
|
||||
|
||||
bans = await app.db.fetch("""
|
||||
bans = await app.db.fetch(
|
||||
"""
|
||||
SELECT user_id, reason
|
||||
FROM bans
|
||||
WHERE bans.guild_id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
res = []
|
||||
|
||||
for ban in bans:
|
||||
res.append({
|
||||
'reason': ban['reason'],
|
||||
'user': await app.storage.get_user(ban['user_id'])
|
||||
})
|
||||
res.append(
|
||||
{
|
||||
"reason": ban["reason"],
|
||||
"user": await app.storage.get_user(ban["user_id"]),
|
||||
}
|
||||
)
|
||||
|
||||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
|
||||
@bp.route("/<int:guild_id>/bans/<int:member_id>", methods=["PUT"])
|
||||
async def create_ban(guild_id, member_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'ban_members')
|
||||
await guild_perm_check(user_id, guild_id, "ban_members")
|
||||
|
||||
j = await request.get_json()
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO bans (guild_id, user_id, reason)
|
||||
VALUES ($1, $2, $3)
|
||||
""", guild_id, member_id, j.get('reason', ''))
|
||||
""",
|
||||
guild_id,
|
||||
member_id,
|
||||
j.get("reason", ""),
|
||||
)
|
||||
|
||||
await remove_member(guild_id, member_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
|
||||
'guild_id': str(guild_id),
|
||||
'user': await app.storage.get_user(member_id)
|
||||
})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id,
|
||||
"GUILD_BAN_ADD",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>/bans/<int:banned_id>", methods=["DELETE"])
|
||||
async def remove_ban(guild_id, banned_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'ban_members')
|
||||
await guild_perm_check(user_id, guild_id, "ban_members")
|
||||
|
||||
res = await app.db.execute("""
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
DELETE FROM bans
|
||||
WHERE guild_id = $1 AND user_id = $@
|
||||
""", guild_id, banned_id)
|
||||
""",
|
||||
guild_id,
|
||||
banned_id,
|
||||
)
|
||||
|
||||
# we don't really need to dispatch GUILD_BAN_REMOVE
|
||||
# when no bans were actually removed.
|
||||
if res == 'DELETE 0':
|
||||
return '', 204
|
||||
if res == "DELETE 0":
|
||||
return "", 204
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
|
||||
'guild_id': str(guild_id),
|
||||
'user': await app.storage.get_user(banned_id)
|
||||
})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id,
|
||||
"GUILD_BAN_REMOVE",
|
||||
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
async def get_prune(guild_id: int, days: int) -> list:
|
||||
|
|
@ -146,23 +167,30 @@ async def get_prune(guild_id: int, days: int) -> list:
|
|||
- don't have any roles.
|
||||
"""
|
||||
# a good solution would be in pure sql.
|
||||
member_ids = await app.db.fetch(f"""
|
||||
member_ids = await app.db.fetch(
|
||||
f"""
|
||||
SELECT id
|
||||
FROM users
|
||||
JOIN members
|
||||
ON members.guild_id = $1 AND members.user_id = users.id
|
||||
WHERE users.last_session < (now() - (interval '{days} days'))
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
member_ids = [r['id'] for r in member_ids]
|
||||
member_ids = [r["id"] for r in member_ids]
|
||||
members = []
|
||||
|
||||
for member_id in member_ids:
|
||||
role_count = await app.db.fetchval("""
|
||||
role_count = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
""",
|
||||
guild_id,
|
||||
member_id,
|
||||
)
|
||||
|
||||
if role_count == 0:
|
||||
members.append(member_id)
|
||||
|
|
@ -170,33 +198,29 @@ async def get_prune(guild_id: int, days: int) -> list:
|
|||
return members
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/prune', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/prune", methods=["GET"])
|
||||
async def get_guild_prune_count(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'kick_members')
|
||||
await guild_perm_check(user_id, guild_id, "kick_members")
|
||||
|
||||
j = validate(request.args, GUILD_PRUNE)
|
||||
days = j['days']
|
||||
days = j["days"]
|
||||
member_ids = await get_prune(guild_id, days)
|
||||
|
||||
return jsonify({
|
||||
'pruned': len(member_ids),
|
||||
})
|
||||
return jsonify({"pruned": len(member_ids)})
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/prune', methods=['POST'])
|
||||
@bp.route("/<int:guild_id>/prune", methods=["POST"])
|
||||
async def begin_guild_prune(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'kick_members')
|
||||
await guild_perm_check(user_id, guild_id, "kick_members")
|
||||
|
||||
j = validate(request.args, GUILD_PRUNE)
|
||||
days = j['days']
|
||||
days = j["days"]
|
||||
member_ids = await get_prune(guild_id, days)
|
||||
|
||||
app.loop.create_task(remove_member_multi(guild_id, member_ids))
|
||||
|
||||
return jsonify({
|
||||
'pruned': len(member_ids)
|
||||
})
|
||||
return jsonify({"pruned": len(member_ids)})
|
||||
|
|
|
|||
|
|
@ -24,12 +24,8 @@ from logbook import Logger
|
|||
|
||||
from litecord.auth import token_check
|
||||
|
||||
from litecord.blueprints.checks import (
|
||||
guild_check, guild_perm_check
|
||||
)
|
||||
from litecord.schemas import (
|
||||
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
||||
)
|
||||
from litecord.blueprints.checks import guild_check, guild_perm_check
|
||||
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
||||
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.utils import dict_get
|
||||
|
|
@ -37,22 +33,19 @@ from litecord.permissions import get_role_perms
|
|||
|
||||
DEFAULT_EVERYONE_PERMS = 104324161
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('guild_roles', __name__)
|
||||
bp = Blueprint("guild_roles", __name__)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/roles", methods=["GET"])
|
||||
async def get_guild_roles(guild_id):
|
||||
"""Get all roles in a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_role_data(guild_id)
|
||||
)
|
||||
return jsonify(await app.storage.get_role_data(guild_id))
|
||||
|
||||
|
||||
async def _maybe_lg(guild_id: int, event: str,
|
||||
role, force: bool = False):
|
||||
async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
|
||||
# sometimes we want to dispatch an event
|
||||
# even if the role isn't hoisted
|
||||
|
||||
|
|
@ -61,11 +54,10 @@ async def _maybe_lg(guild_id: int, event: str,
|
|||
|
||||
# check if is a dict first because role_delete
|
||||
# only receives the role id.
|
||||
if isinstance(role, dict) and not role['hoist'] and not force:
|
||||
if isinstance(role, dict) and not role["hoist"] and not force:
|
||||
return
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'lazy_guild', guild_id, event, role)
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
||||
|
||||
|
||||
async def create_role(guild_id, name: str, **kwargs):
|
||||
|
|
@ -73,18 +65,20 @@ async def create_role(guild_id, name: str, **kwargs):
|
|||
new_role_id = get_snowflake()
|
||||
|
||||
everyone_perms = await get_role_perms(guild_id, guild_id)
|
||||
default_perms = dict_get(kwargs, 'default_perms',
|
||||
everyone_perms.binary)
|
||||
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
|
||||
|
||||
# update all roles so that we have space for pos 1, but without
|
||||
# sending GUILD_ROLE_UPDATE for everyone
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE roles
|
||||
SET
|
||||
position = position + 1
|
||||
WHERE guild_id = $1
|
||||
AND NOT (position = 0)
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
|
|
@ -95,42 +89,39 @@ async def create_role(guild_id, name: str, **kwargs):
|
|||
new_role_id,
|
||||
guild_id,
|
||||
name,
|
||||
dict_get(kwargs, 'color', 0),
|
||||
dict_get(kwargs, 'hoist', False),
|
||||
|
||||
dict_get(kwargs, "color", 0),
|
||||
dict_get(kwargs, "hoist", False),
|
||||
# always set ourselves on position 1
|
||||
1,
|
||||
int(dict_get(kwargs, 'permissions', default_perms)),
|
||||
int(dict_get(kwargs, "permissions", default_perms)),
|
||||
False,
|
||||
dict_get(kwargs, 'mentionable', False)
|
||||
dict_get(kwargs, "mentionable", False),
|
||||
)
|
||||
|
||||
role = await app.storage.get_role(new_role_id, guild_id)
|
||||
|
||||
# we need to update the lazy guild handlers for the newly created group
|
||||
await _maybe_lg(guild_id, 'new_role', role)
|
||||
await _maybe_lg(guild_id, "new_role", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'GUILD_ROLE_CREATE', {
|
||||
'guild_id': str(guild_id),
|
||||
'role': role,
|
||||
})
|
||||
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
||||
)
|
||||
|
||||
return role
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles', methods=['POST'])
|
||||
@bp.route("/<int:guild_id>/roles", methods=["POST"])
|
||||
async def create_guild_role(guild_id: int):
|
||||
"""Add a role to a guild"""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
||||
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||
|
||||
# client can just send null
|
||||
j = validate(await request.get_json() or {}, ROLE_CREATE)
|
||||
|
||||
role_name = j['name']
|
||||
j.pop('name')
|
||||
role_name = j["name"]
|
||||
j.pop("name")
|
||||
|
||||
role = await create_role(guild_id, role_name, **j)
|
||||
|
||||
|
|
@ -141,12 +132,11 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
|
|||
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
||||
role = await app.storage.get_role(role_id, guild_id)
|
||||
|
||||
await _maybe_lg(guild_id, 'role_pos_upd', role)
|
||||
await _maybe_lg(guild_id, "role_pos_upd", role)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
|
||||
'guild_id': str(guild_id),
|
||||
'role': role,
|
||||
})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
||||
)
|
||||
|
||||
return role
|
||||
|
||||
|
|
@ -166,17 +156,25 @@ async def _role_pairs_update(guild_id: int, pairs: list):
|
|||
async with conn.transaction():
|
||||
# update happens in a transaction
|
||||
# so we don't fuck it up
|
||||
await conn.execute("""
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE roles
|
||||
SET position = $1
|
||||
WHERE roles.id = $2
|
||||
""", new_pos_1, role_1)
|
||||
""",
|
||||
new_pos_1,
|
||||
role_1,
|
||||
)
|
||||
|
||||
await conn.execute("""
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE roles
|
||||
SET position = $1
|
||||
WHERE roles.id = $2
|
||||
""", new_pos_2, role_2)
|
||||
""",
|
||||
new_pos_2,
|
||||
role_2,
|
||||
)
|
||||
|
||||
await app.db.release(conn)
|
||||
|
||||
|
|
@ -184,11 +182,15 @@ async def _role_pairs_update(guild_id: int, pairs: list):
|
|||
await _role_update_dispatch(role_1, guild_id)
|
||||
await _role_update_dispatch(role_2, guild_id)
|
||||
|
||||
|
||||
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
|
||||
|
||||
def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||
current_state: Dict[int, int],
|
||||
blacklist: List[int] = None) -> PairList:
|
||||
|
||||
def gen_pairs(
|
||||
list_of_changes: List[Dict[str, int]],
|
||||
current_state: Dict[int, int],
|
||||
blacklist: List[int] = None,
|
||||
) -> PairList:
|
||||
"""Generate a list of pairs that, when applied to the database,
|
||||
will generate the desired state given in list_of_changes.
|
||||
|
||||
|
|
@ -230,8 +232,9 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
|||
pairs = []
|
||||
blacklist = blacklist or []
|
||||
|
||||
preferred_state = {element['id']: element['position']
|
||||
for element in list_of_changes}
|
||||
preferred_state = {
|
||||
element["id"]: element["position"] for element in list_of_changes
|
||||
}
|
||||
|
||||
for blacklisted_id in blacklist:
|
||||
preferred_state.pop(blacklisted_id)
|
||||
|
|
@ -239,7 +242,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
|||
# for each change, we must find a matching change
|
||||
# in the same list, so we can make a swap pair
|
||||
for change in list_of_changes:
|
||||
element_1, new_pos_1 = change['id'], change['position']
|
||||
element_1, new_pos_1 = change["id"], change["position"]
|
||||
|
||||
# check current pairs
|
||||
# so we don't repeat an element
|
||||
|
|
@ -267,36 +270,34 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
|||
# if its being swapped to leave space, add it
|
||||
# to the pairs list
|
||||
if new_pos_2 is not None:
|
||||
pairs.append(
|
||||
((element_1, new_pos_1), (element_2, new_pos_2))
|
||||
)
|
||||
pairs.append(((element_1, new_pos_1), (element_2, new_pos_2)))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/roles", methods=["PATCH"])
|
||||
async def update_guild_role_positions(guild_id):
|
||||
"""Update the positions for a bunch of roles."""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
||||
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||
|
||||
raw_j = await request.get_json()
|
||||
|
||||
# we need to do this hackiness because thats
|
||||
# cerberus for ya.
|
||||
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
||||
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
|
||||
|
||||
# extract the list out
|
||||
j = j['roles']
|
||||
j = j["roles"]
|
||||
|
||||
log.debug('role stuff: {!r}', j)
|
||||
log.debug("role stuff: {!r}", j)
|
||||
|
||||
all_roles = await app.storage.get_role_data(guild_id)
|
||||
|
||||
# we'll have to calculate pairs of changing roles,
|
||||
# then do the changes, etc.
|
||||
roles_pos = {role['position']: int(role['id']) for role in all_roles}
|
||||
roles_pos = {role["position"]: int(role["id"]) for role in all_roles}
|
||||
|
||||
# TODO: check if the user can even change the roles in the first place,
|
||||
# preferrably when we have a proper perms system.
|
||||
|
|
@ -306,10 +307,9 @@ async def update_guild_role_positions(guild_id):
|
|||
pairs = gen_pairs(
|
||||
j,
|
||||
roles_pos,
|
||||
|
||||
# always ignore people trying to change
|
||||
# the @everyone's role position
|
||||
[guild_id]
|
||||
[guild_id],
|
||||
)
|
||||
|
||||
await _role_pairs_update(guild_id, pairs)
|
||||
|
|
@ -318,31 +318,36 @@ async def update_guild_role_positions(guild_id):
|
|||
return jsonify(await app.storage.get_role_data(guild_id))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["PATCH"])
|
||||
async def update_guild_role(guild_id, role_id):
|
||||
"""Update a single role's information."""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
||||
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||
|
||||
j = validate(await request.get_json(), ROLE_UPDATE)
|
||||
|
||||
# we only update ints on the db, not Permissions
|
||||
j['permissions'] = int(j['permissions'])
|
||||
j["permissions"] = int(j["permissions"])
|
||||
|
||||
for field in j:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE roles
|
||||
SET {field} = $1
|
||||
WHERE roles.id = $2 AND roles.guild_id = $3
|
||||
""", j[field], role_id, guild_id)
|
||||
""",
|
||||
j[field],
|
||||
role_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
role = await _role_update_dispatch(role_id, guild_id)
|
||||
await _maybe_lg(guild_id, 'role_update', role, True)
|
||||
await _maybe_lg(guild_id, "role_update", role, True)
|
||||
return jsonify(role)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["DELETE"])
|
||||
async def delete_guild_role(guild_id, role_id):
|
||||
"""Delete a role.
|
||||
|
||||
|
|
@ -350,21 +355,26 @@ async def delete_guild_role(guild_id, role_id):
|
|||
"""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
||||
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||
|
||||
res = await app.db.execute("""
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
DELETE FROM roles
|
||||
WHERE guild_id = $1 AND id = $2
|
||||
""", guild_id, role_id)
|
||||
""",
|
||||
guild_id,
|
||||
role_id,
|
||||
)
|
||||
|
||||
if res == 'DELETE 0':
|
||||
return '', 204
|
||||
if res == "DELETE 0":
|
||||
return "", 204
|
||||
|
||||
await _maybe_lg(guild_id, 'role_delete', role_id, True)
|
||||
await _maybe_lg(guild_id, "role_delete", role_id, True)
|
||||
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
|
||||
'guild_id': str(guild_id),
|
||||
'role_id': str(role_id),
|
||||
})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id,
|
||||
"GUILD_ROLE_DELETE",
|
||||
{"guild_id": str(guild_id), "role_id": str(role_id)},
|
||||
)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -22,16 +22,17 @@ from typing import Optional, List
|
|||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
|
||||
from litecord.blueprints.guild.channels import create_guild_channel
|
||||
from litecord.blueprints.guild.roles import (
|
||||
create_role, DEFAULT_EVERYONE_PERMS
|
||||
)
|
||||
from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
|
||||
|
||||
from ..auth import token_check
|
||||
from ..snowflake import get_snowflake
|
||||
from ..enums import ChannelType
|
||||
from ..schemas import (
|
||||
validate, GUILD_CREATE, GUILD_UPDATE, SEARCH_CHANNEL,
|
||||
VANITY_URL_PATCH
|
||||
validate,
|
||||
GUILD_CREATE,
|
||||
GUILD_UPDATE,
|
||||
SEARCH_CHANNEL,
|
||||
VANITY_URL_PATCH,
|
||||
)
|
||||
from .channels import channel_ack
|
||||
from .checks import guild_check, guild_owner_check, guild_perm_check
|
||||
|
|
@ -40,7 +41,7 @@ from litecord.errors import BadRequest
|
|||
from litecord.permissions import get_permissions
|
||||
|
||||
|
||||
bp = Blueprint('guilds', __name__)
|
||||
bp = Blueprint("guilds", __name__)
|
||||
|
||||
|
||||
async def create_guild_settings(guild_id: int, user_id: int):
|
||||
|
|
@ -49,26 +50,38 @@ async def create_guild_settings(guild_id: int, user_id: int):
|
|||
|
||||
# new guild_settings are based off the currently
|
||||
# set guild settings (for the guild)
|
||||
m_notifs = await app.db.fetchval("""
|
||||
m_notifs = await app.db.fetchval(
|
||||
"""
|
||||
SELECT default_message_notifications
|
||||
FROM guilds
|
||||
WHERE id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_settings
|
||||
(user_id, guild_id, message_notifications)
|
||||
VALUES
|
||||
($1, $2, $3)
|
||||
""", user_id, guild_id, m_notifs)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
m_notifs,
|
||||
)
|
||||
|
||||
|
||||
async def add_member(guild_id: int, user_id: int):
|
||||
"""Add a user to a guild."""
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO members (user_id, guild_id)
|
||||
VALUES ($1, $2)
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
await create_guild_settings(guild_id, user_id)
|
||||
|
||||
|
|
@ -83,28 +96,28 @@ async def guild_create_roles_prep(guild_id: int, roles: list):
|
|||
# are patches to the @everyone role
|
||||
everyone_patches = roles[0]
|
||||
for field in everyone_patches:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE roles
|
||||
SET {field}={everyone_patches[field]}
|
||||
WHERE roles.id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
default_perms = (everyone_patches.get('permissions')
|
||||
or DEFAULT_EVERYONE_PERMS)
|
||||
default_perms = everyone_patches.get("permissions") or DEFAULT_EVERYONE_PERMS
|
||||
|
||||
# from the 2nd and forward,
|
||||
# should be treated as new roles
|
||||
for role in roles[1:]:
|
||||
await create_role(
|
||||
guild_id, role['name'], default_perms=default_perms, **role
|
||||
)
|
||||
await create_role(guild_id, role["name"], default_perms=default_perms, **role)
|
||||
|
||||
|
||||
async def guild_create_channels_prep(guild_id: int, channels: list):
|
||||
"""Create channels pre-guild create"""
|
||||
for channel_raw in channels:
|
||||
channel_id = get_snowflake()
|
||||
ctype = ChannelType(channel_raw['type'])
|
||||
ctype = ChannelType(channel_raw["type"])
|
||||
|
||||
await create_guild_channel(guild_id, channel_id, ctype)
|
||||
|
||||
|
|
@ -114,37 +127,29 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
|
|||
|
||||
Defaults to a jpeg icon when the header isn't given.
|
||||
"""
|
||||
if icon and icon.startswith('data'):
|
||||
if icon and icon.startswith("data"):
|
||||
return icon
|
||||
|
||||
return (f'data:image/jpeg;base64,{icon}'
|
||||
if icon
|
||||
else None)
|
||||
return f"data:image/jpeg;base64,{icon}" if icon else None
|
||||
|
||||
|
||||
async def _general_guild_icon(scope: str, guild_id: int,
|
||||
icon: str, **kwargs):
|
||||
async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
|
||||
encoded = sanitize_icon(icon)
|
||||
|
||||
icon_kwargs = {
|
||||
'always_icon': True
|
||||
}
|
||||
icon_kwargs = {"always_icon": True}
|
||||
|
||||
if 'size' in kwargs:
|
||||
icon_kwargs['size'] = kwargs['size']
|
||||
if "size" in kwargs:
|
||||
icon_kwargs["size"] = kwargs["size"]
|
||||
|
||||
return await app.icons.put(
|
||||
scope, guild_id, encoded,
|
||||
**icon_kwargs
|
||||
)
|
||||
return await app.icons.put(scope, guild_id, encoded, **icon_kwargs)
|
||||
|
||||
|
||||
async def put_guild_icon(guild_id: int, icon: Optional[str]):
|
||||
"""Insert a guild icon on the icon database."""
|
||||
return await _general_guild_icon('guild', guild_id, icon, size=(128, 128))
|
||||
return await _general_guild_icon("guild", guild_id, icon, size=(128, 128))
|
||||
|
||||
|
||||
@bp.route('', methods=['POST'])
|
||||
@bp.route("", methods=["POST"])
|
||||
async def create_guild():
|
||||
"""Create a new guild, assigning
|
||||
the user creating it as the owner and
|
||||
|
|
@ -154,8 +159,8 @@ async def create_guild():
|
|||
|
||||
guild_id = get_snowflake()
|
||||
|
||||
if 'icon' in j:
|
||||
image = await put_guild_icon(guild_id, j['icon'])
|
||||
if "icon" in j:
|
||||
image = await put_guild_icon(guild_id, j["icon"])
|
||||
image = image.icon_hash
|
||||
else:
|
||||
image = None
|
||||
|
|
@ -166,10 +171,16 @@ async def create_guild():
|
|||
verification_level, default_message_notifications,
|
||||
explicit_content_filter)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
""", guild_id, j['name'], j['region'], image, user_id,
|
||||
j.get('verification_level', 0),
|
||||
j.get('default_message_notifications', 0),
|
||||
j.get('explicit_content_filter', 0))
|
||||
""",
|
||||
guild_id,
|
||||
j["name"],
|
||||
j["region"],
|
||||
image,
|
||||
user_id,
|
||||
j.get("verification_level", 0),
|
||||
j.get("default_message_notifications", 0),
|
||||
j.get("explicit_content_filter", 0),
|
||||
)
|
||||
|
||||
await add_member(guild_id, user_id)
|
||||
|
||||
|
|
@ -179,107 +190,127 @@ async def create_guild():
|
|||
# we also don't use create_role because the id of the role
|
||||
# is the same as the id of the guild, and create_role
|
||||
# generates a new snowflake.
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO roles (id, guild_id, name, position, permissions)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
|
||||
""",
|
||||
guild_id,
|
||||
guild_id,
|
||||
"@everyone",
|
||||
0,
|
||||
DEFAULT_EVERYONE_PERMS,
|
||||
)
|
||||
|
||||
# add the @everyone role to the guild creator
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", user_id, guild_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# create a single #general channel.
|
||||
general_id = get_snowflake()
|
||||
|
||||
await create_guild_channel(
|
||||
guild_id, general_id, ChannelType.GUILD_TEXT,
|
||||
name='general')
|
||||
guild_id, general_id, ChannelType.GUILD_TEXT, name="general"
|
||||
)
|
||||
|
||||
if j.get('roles'):
|
||||
await guild_create_roles_prep(guild_id, j['roles'])
|
||||
if j.get("roles"):
|
||||
await guild_create_roles_prep(guild_id, j["roles"])
|
||||
|
||||
if j.get('channels'):
|
||||
await guild_create_channels_prep(guild_id, j['channels'])
|
||||
if j.get("channels"):
|
||||
await guild_create_channels_prep(guild_id, j["channels"])
|
||||
|
||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
|
||||
await app.dispatcher.sub('guild', guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild_total)
|
||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
|
||||
return jsonify(guild_total)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>", methods=["GET"])
|
||||
async def get_guild(guild_id):
|
||||
"""Get a single guilds' information."""
|
||||
user_id = await token_check()
|
||||
await guild_check(user_id, guild_id)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
)
|
||||
return jsonify(await app.storage.get_guild_full(guild_id, user_id, 250))
|
||||
|
||||
|
||||
async def _guild_update_icon(scope: str, guild_id: int,
|
||||
icon: Optional[str], **kwargs):
|
||||
async def _guild_update_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
|
||||
"""Update icon."""
|
||||
new_icon = await app.icons.update(
|
||||
scope, guild_id, icon, always_icon=True, **kwargs
|
||||
)
|
||||
new_icon = await app.icons.update(scope, guild_id, icon, always_icon=True, **kwargs)
|
||||
|
||||
table = {
|
||||
'guild': 'icon',
|
||||
}.get(scope, scope)
|
||||
table = {"guild": "icon"}.get(scope, scope)
|
||||
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guilds
|
||||
SET {table} = $1
|
||||
WHERE id = $2
|
||||
""", new_icon.icon_hash, guild_id)
|
||||
""",
|
||||
new_icon.icon_hash,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
|
||||
async def _guild_update_region(guild_id, region):
|
||||
is_vip = region.vip
|
||||
can_vip = await app.storage.has_feature(guild_id, 'VIP_REGIONS')
|
||||
can_vip = await app.storage.has_feature(guild_id, "VIP_REGIONS")
|
||||
|
||||
if is_vip and not can_vip:
|
||||
raise BadRequest('can not assign guild to vip-only region')
|
||||
raise BadRequest("can not assign guild to vip-only region")
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET region = $1
|
||||
WHERE id = $2
|
||||
""", region.id, guild_id)
|
||||
""",
|
||||
region.id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>", methods=["PATCH"])
|
||||
async def _update_guild(guild_id):
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
||||
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||
j = validate(await request.get_json(), GUILD_UPDATE)
|
||||
|
||||
if 'owner_id' in j:
|
||||
if "owner_id" in j:
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET owner_id = $1
|
||||
WHERE id = $2
|
||||
""", int(j['owner_id']), guild_id)
|
||||
""",
|
||||
int(j["owner_id"]),
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if 'name' in j:
|
||||
await app.db.execute("""
|
||||
if "name" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""", j['name'], guild_id)
|
||||
""",
|
||||
j["name"],
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if 'region' in j:
|
||||
region = app.voice.lvsp.region(j['region'])
|
||||
if "region" in j:
|
||||
region = app.voice.lvsp.region(j["region"])
|
||||
|
||||
if region is not None:
|
||||
await _guild_update_region(guild_id, region)
|
||||
|
|
@ -287,65 +318,77 @@ async def _update_guild(guild_id):
|
|||
# small guild to work with to_update()
|
||||
guild = await app.storage.get_guild(guild_id)
|
||||
|
||||
if to_update(j, guild, 'icon'):
|
||||
await _guild_update_icon(
|
||||
'guild', guild_id, j['icon'], size=(128, 128))
|
||||
if to_update(j, guild, "icon"):
|
||||
await _guild_update_icon("guild", guild_id, j["icon"], size=(128, 128))
|
||||
|
||||
if to_update(j, guild, 'splash'):
|
||||
if not await app.storage.has_feature(guild_id, 'INVITE_SPLASH'):
|
||||
raise BadRequest('guild does not have INVITE_SPLASH feature')
|
||||
if to_update(j, guild, "splash"):
|
||||
if not await app.storage.has_feature(guild_id, "INVITE_SPLASH"):
|
||||
raise BadRequest("guild does not have INVITE_SPLASH feature")
|
||||
|
||||
await _guild_update_icon('splash', guild_id, j['splash'])
|
||||
await _guild_update_icon("splash", guild_id, j["splash"])
|
||||
|
||||
if to_update(j, guild, 'banner'):
|
||||
if not await app.storage.has_feature(guild_id, 'VERIFIED'):
|
||||
raise BadRequest('guild is not verified')
|
||||
if to_update(j, guild, "banner"):
|
||||
if not await app.storage.has_feature(guild_id, "VERIFIED"):
|
||||
raise BadRequest("guild is not verified")
|
||||
|
||||
await _guild_update_icon('banner', guild_id, j['banner'])
|
||||
await _guild_update_icon("banner", guild_id, j["banner"])
|
||||
|
||||
fields = ['verification_level', 'default_message_notifications',
|
||||
'explicit_content_filter', 'afk_timeout', 'description']
|
||||
fields = [
|
||||
"verification_level",
|
||||
"default_message_notifications",
|
||||
"explicit_content_filter",
|
||||
"afk_timeout",
|
||||
"description",
|
||||
]
|
||||
|
||||
for field in [f for f in fields if f in j]:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guilds
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], guild_id)
|
||||
""",
|
||||
j[field],
|
||||
guild_id,
|
||||
)
|
||||
|
||||
channel_fields = ['afk_channel_id', 'system_channel_id']
|
||||
channel_fields = ["afk_channel_id", "system_channel_id"]
|
||||
for field in [f for f in channel_fields if f in j]:
|
||||
# setting to null should remove the link between the afk/sys channel
|
||||
# to the guild.
|
||||
if j[field] is None:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guilds
|
||||
SET {field} = NULL
|
||||
WHERE id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
chan = await app.storage.get_channel(int(j[field]))
|
||||
|
||||
if chan is None:
|
||||
raise BadRequest('invalid channel id')
|
||||
raise BadRequest("invalid channel id")
|
||||
|
||||
if chan['guild_id'] != str(guild_id):
|
||||
raise BadRequest('channel id not linked to guild')
|
||||
if chan["guild_id"] != str(guild_id):
|
||||
raise BadRequest("channel id not linked to guild")
|
||||
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guilds
|
||||
SET {field} = $1
|
||||
WHERE id = $2
|
||||
""", j[field], guild_id)
|
||||
""",
|
||||
j[field],
|
||||
guild_id,
|
||||
)
|
||||
|
||||
guild = await app.storage.get_guild_full(
|
||||
guild_id, user_id
|
||||
)
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id)
|
||||
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, 'GUILD_UPDATE', guild)
|
||||
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||
|
||||
return jsonify(guild)
|
||||
|
||||
|
|
@ -354,33 +397,41 @@ async def delete_guild(guild_id: int, *, app_=None):
|
|||
"""Delete a single guild."""
|
||||
app_ = app_ or app
|
||||
|
||||
await app_.db.execute("""
|
||||
await app_.db.execute(
|
||||
"""
|
||||
DELETE FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# Discord's client expects IDs being string
|
||||
await app_.dispatcher.dispatch('guild', guild_id, 'GUILD_DELETE', {
|
||||
'guild_id': str(guild_id),
|
||||
'id': str(guild_id),
|
||||
# 'unavailable': False,
|
||||
})
|
||||
await app_.dispatcher.dispatch(
|
||||
"guild",
|
||||
guild_id,
|
||||
"GUILD_DELETE",
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"id": str(guild_id),
|
||||
# 'unavailable': False,
|
||||
},
|
||||
)
|
||||
|
||||
# remove from the dispatcher so nobody
|
||||
# becomes the little memer that tries to fuck up with
|
||||
# everybody's gateway
|
||||
await app_.dispatcher.remove('guild', guild_id)
|
||||
await app_.dispatcher.remove("guild", guild_id)
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
||||
@bp.route("/<int:guild_id>", methods=["DELETE"])
|
||||
# this endpoint is not documented, but used by the official client.
|
||||
@bp.route('/<int:guild_id>/delete', methods=['POST'])
|
||||
@bp.route("/<int:guild_id>/delete", methods=["POST"])
|
||||
async def delete_guild_handler(guild_id):
|
||||
"""Delete a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_owner_check(user_id, guild_id)
|
||||
await delete_guild(guild_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
|
||||
|
|
@ -397,7 +448,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
|
|||
return res
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/messages/search', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/messages/search", methods=["GET"])
|
||||
async def search_messages(guild_id):
|
||||
"""Search messages in a guild.
|
||||
|
||||
|
|
@ -415,7 +466,8 @@ async def search_messages(guild_id):
|
|||
# use that list on the main search query.
|
||||
can_read = await fetch_readable_channels(guild_id, user_id)
|
||||
|
||||
rows = await app.db.fetch(f"""
|
||||
rows = await app.db.fetch(
|
||||
f"""
|
||||
SELECT orig.id AS current_id,
|
||||
COUNT(*) OVER() as total_results,
|
||||
array((SELECT messages.id AS before_id
|
||||
|
|
@ -432,12 +484,17 @@ async def search_messages(guild_id):
|
|||
ORDER BY orig.id DESC
|
||||
LIMIT 50
|
||||
OFFSET $3
|
||||
""", guild_id, j['content'], j['offset'], can_read)
|
||||
""",
|
||||
guild_id,
|
||||
j["content"],
|
||||
j["offset"],
|
||||
can_read,
|
||||
)
|
||||
|
||||
return jsonify(await search_result_from_list(rows))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
||||
@bp.route("/<int:guild_id>/ack", methods=["POST"])
|
||||
async def ack_guild(guild_id):
|
||||
"""ACKnowledge all messages in the guild."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -448,45 +505,43 @@ async def ack_guild(guild_id):
|
|||
for chan_id in chan_ids:
|
||||
await channel_ack(user_id, guild_id, chan_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/vanity-url', methods=['GET'])
|
||||
@bp.route("/<int:guild_id>/vanity-url", methods=["GET"])
|
||||
async def get_vanity_url(guild_id: int):
|
||||
"""Get the vanity url of a guild."""
|
||||
user_id = await token_check()
|
||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
||||
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||
|
||||
inv_code = await app.storage.vanity_invite(guild_id)
|
||||
|
||||
if inv_code is None:
|
||||
return jsonify({'code': None})
|
||||
return jsonify({"code": None})
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_invite(inv_code)
|
||||
)
|
||||
return jsonify(await app.storage.get_invite(inv_code))
|
||||
|
||||
|
||||
@bp.route('/<int:guild_id>/vanity-url', methods=['PATCH'])
|
||||
@bp.route("/<int:guild_id>/vanity-url", methods=["PATCH"])
|
||||
async def change_vanity_url(guild_id: int):
|
||||
"""Get the vanity url of a guild."""
|
||||
user_id = await token_check()
|
||||
|
||||
if not await app.storage.has_feature(guild_id, 'VANITY_URL'):
|
||||
if not await app.storage.has_feature(guild_id, "VANITY_URL"):
|
||||
# TODO: is this the right error
|
||||
raise BadRequest('guild has no vanity url support')
|
||||
raise BadRequest("guild has no vanity url support")
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
||||
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||
|
||||
j = validate(await request.get_json(), VANITY_URL_PATCH)
|
||||
inv_code = j['code']
|
||||
inv_code = j["code"]
|
||||
|
||||
# store old vanity in a variable to delete it from
|
||||
# invites table
|
||||
old_vanity = await app.storage.vanity_invite(guild_id)
|
||||
|
||||
if old_vanity == inv_code:
|
||||
raise BadRequest('can not change to same invite')
|
||||
raise BadRequest("can not change to same invite")
|
||||
|
||||
# this is sad because we don't really use the things
|
||||
# sql gives us, but i havent really found a way to put
|
||||
|
|
@ -494,19 +549,22 @@ async def change_vanity_url(guild_id: int):
|
|||
# guild_id_fkey fails but INSERT when code_fkey fails..
|
||||
inv = await app.storage.get_invite(inv_code)
|
||||
if inv:
|
||||
raise BadRequest('invite already exists')
|
||||
raise BadRequest("invite already exists")
|
||||
|
||||
# TODO: this is bad, what if a guild has no channels?
|
||||
# we should probably choose the first channel that has
|
||||
# @everyone read messages
|
||||
channels = await app.storage.get_channel_data(guild_id)
|
||||
channel_id = int(channels[0]['id'])
|
||||
channel_id = int(channels[0]["id"])
|
||||
|
||||
# delete the old invite, insert new one
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM invites
|
||||
WHERE code = $1
|
||||
""", old_vanity)
|
||||
""",
|
||||
old_vanity,
|
||||
)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
|
|
@ -515,21 +573,27 @@ async def change_vanity_url(guild_id: int):
|
|||
max_age, temporary)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
inv_code, guild_id, channel_id, user_id,
|
||||
|
||||
inv_code,
|
||||
guild_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
# sane defaults for vanity urls.
|
||||
0, 0, False,
|
||||
0,
|
||||
0,
|
||||
False,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO vanity_invites (guild_id, code)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO
|
||||
UPDATE
|
||||
SET code = $2
|
||||
WHERE vanity_invites.guild_id = $1
|
||||
""", guild_id, inv_code)
|
||||
|
||||
return jsonify(
|
||||
await app.storage.get_invite(inv_code)
|
||||
""",
|
||||
guild_id,
|
||||
inv_code,
|
||||
)
|
||||
|
||||
return jsonify(await app.storage.get_invite(inv_code))
|
||||
|
|
|
|||
|
|
@ -24,41 +24,39 @@ from quart import Blueprint, current_app as app, send_file, redirect
|
|||
from litecord.embed.sanitizer import make_md_req_url
|
||||
from litecord.embed.schemas import EmbedURL
|
||||
|
||||
bp = Blueprint('images', __name__)
|
||||
bp = Blueprint("images", __name__)
|
||||
|
||||
|
||||
async def send_icon(scope, key, icon_hash, **kwargs):
|
||||
"""Send an icon."""
|
||||
icon = await app.icons.generic_get(
|
||||
scope, key, icon_hash, **kwargs)
|
||||
icon = await app.icons.generic_get(scope, key, icon_hash, **kwargs)
|
||||
|
||||
if not icon:
|
||||
return '', 404
|
||||
return "", 404
|
||||
|
||||
return await send_file(icon.as_path)
|
||||
|
||||
|
||||
def splitext_(filepath):
|
||||
name, ext = splitext(filepath)
|
||||
return name, ext.strip('.')
|
||||
return name, ext.strip(".")
|
||||
|
||||
|
||||
@bp.route('/emojis/<emoji_file>', methods=['GET'])
|
||||
@bp.route("/emojis/<emoji_file>", methods=["GET"])
|
||||
async def _get_raw_emoji(emoji_file):
|
||||
# emoji = app.icons.get_emoji(emoji_id, ext=ext)
|
||||
# just a test file for now
|
||||
emoji_id, ext = splitext_(emoji_file)
|
||||
return await send_icon(
|
||||
'emoji', emoji_id, None, ext=ext)
|
||||
return await send_icon("emoji", emoji_id, None, ext=ext)
|
||||
|
||||
|
||||
@bp.route('/icons/<int:guild_id>/<icon_file>', methods=['GET'])
|
||||
@bp.route("/icons/<int:guild_id>/<icon_file>", methods=["GET"])
|
||||
async def _get_guild_icon(guild_id: int, icon_file: str):
|
||||
icon_hash, ext = splitext_(icon_file)
|
||||
return await send_icon('guild', guild_id, icon_hash, ext=ext)
|
||||
return await send_icon("guild", guild_id, icon_hash, ext=ext)
|
||||
|
||||
|
||||
@bp.route('/embed/avatars/<int:default_id>.png')
|
||||
@bp.route("/embed/avatars/<int:default_id>.png")
|
||||
async def _get_default_user_avatar(default_id: int):
|
||||
# TODO: how do we determine which assets to use for this?
|
||||
# I don't think we can use discord assets.
|
||||
|
|
@ -66,25 +64,29 @@ async def _get_default_user_avatar(default_id: int):
|
|||
|
||||
|
||||
async def _handle_webhook_avatar(md_url_redir: str):
|
||||
md_url = make_md_req_url(app.config, 'img', EmbedURL(md_url_redir))
|
||||
md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
|
||||
return redirect(md_url)
|
||||
|
||||
|
||||
@bp.route('/avatars/<int:user_id>/<avatar_file>')
|
||||
@bp.route("/avatars/<int:user_id>/<avatar_file>")
|
||||
async def _get_user_avatar(user_id, avatar_file):
|
||||
avatar_hash, ext = splitext_(avatar_file)
|
||||
|
||||
# first, check if this is a webhook avatar to redir to
|
||||
md_url_redir = await app.db.fetchval("""
|
||||
md_url_redir = await app.db.fetchval(
|
||||
"""
|
||||
SELECT md_url_redir
|
||||
FROM webhook_avatars
|
||||
WHERE webhook_id = $1 AND hash = $2
|
||||
""", user_id, avatar_hash)
|
||||
""",
|
||||
user_id,
|
||||
avatar_hash,
|
||||
)
|
||||
|
||||
if md_url_redir:
|
||||
return await _handle_webhook_avatar(md_url_redir)
|
||||
|
||||
return await send_icon('user', user_id, avatar_hash, ext=ext)
|
||||
return await send_icon("user", user_id, avatar_hash, ext=ext)
|
||||
|
||||
|
||||
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
|
||||
|
|
@ -92,19 +94,19 @@ async def get_app_icon(application_id, icon_hash, ext):
|
|||
pass
|
||||
|
||||
|
||||
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET'])
|
||||
@bp.route("/channel-icons/<int:channel_id>/<icon_file>", methods=["GET"])
|
||||
async def _get_gdm_icon(channel_id: int, icon_file: str):
|
||||
icon_hash, ext = splitext_(icon_file)
|
||||
return await send_icon('channel-icons', channel_id, icon_hash, ext=ext)
|
||||
return await send_icon("channel-icons", channel_id, icon_hash, ext=ext)
|
||||
|
||||
|
||||
@bp.route('/splashes/<int:guild_id>/<icon_file>', methods=['GET'])
|
||||
@bp.route("/splashes/<int:guild_id>/<icon_file>", methods=["GET"])
|
||||
async def _get_guild_splash(guild_id: int, icon_file: str):
|
||||
icon_hash, ext = splitext_(icon_file)
|
||||
return await send_icon('splash', guild_id, icon_hash, ext=ext)
|
||||
return await send_icon("splash", guild_id, icon_hash, ext=ext)
|
||||
|
||||
|
||||
@bp.route('/banners/<int:guild_id>/<icon_file>', methods=['GET'])
|
||||
@bp.route("/banners/<int:guild_id>/<icon_file>", methods=["GET"])
|
||||
async def _get_guild_banner(guild_id: int, icon_file: str):
|
||||
icon_hash, ext = splitext_(icon_file)
|
||||
return await send_icon('banner', guild_id, icon_hash, ext=ext)
|
||||
return await send_icon("banner", guild_id, icon_hash, ext=ext)
|
||||
|
|
|
|||
|
|
@ -32,13 +32,16 @@ from .guilds import create_guild_settings
|
|||
from ..utils import async_map
|
||||
|
||||
from litecord.blueprints.checks import (
|
||||
channel_check, channel_perm_check, guild_check, guild_perm_check
|
||||
channel_check,
|
||||
channel_perm_check,
|
||||
guild_check,
|
||||
guild_perm_check,
|
||||
)
|
||||
|
||||
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('invites', __name__)
|
||||
bp = Blueprint("invites", __name__)
|
||||
|
||||
|
||||
class UnknownInvite(BadRequest):
|
||||
|
|
@ -48,16 +51,18 @@ class UnknownInvite(BadRequest):
|
|||
class InvalidInvite(Forbidden):
|
||||
error_code = 50020
|
||||
|
||||
|
||||
class AlreadyInvited(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
def gen_inv_code() -> str:
|
||||
"""Generate an invite code.
|
||||
|
||||
This is a primitive and does not guarantee uniqueness.
|
||||
"""
|
||||
raw = secrets.token_urlsafe(10)
|
||||
raw = re.sub(r'\/|\+|\-|\_', '', raw)
|
||||
raw = re.sub(r"\/|\+|\-|\_", "", raw)
|
||||
|
||||
return raw[:7]
|
||||
|
||||
|
|
@ -65,23 +70,31 @@ def gen_inv_code() -> str:
|
|||
async def invite_precheck(user_id: int, guild_id: int):
|
||||
"""pre-check invite use in the context of a guild."""
|
||||
|
||||
joined = await app.db.fetchval("""
|
||||
joined = await app.db.fetchval(
|
||||
"""
|
||||
SELECT joined_at
|
||||
FROM members
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if joined is not None:
|
||||
raise AlreadyInvited('You are already in the guild')
|
||||
raise AlreadyInvited("You are already in the guild")
|
||||
|
||||
banned = await app.db.fetchval("""
|
||||
banned = await app.db.fetchval(
|
||||
"""
|
||||
SELECT reason
|
||||
FROM bans
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if banned is not None:
|
||||
raise InvalidInvite('You are banned.')
|
||||
raise InvalidInvite("You are banned.")
|
||||
|
||||
|
||||
async def invite_precheck_gdm(user_id: int, channel_id: int):
|
||||
|
|
@ -89,23 +102,23 @@ async def invite_precheck_gdm(user_id: int, channel_id: int):
|
|||
is_member = await gdm_is_member(channel_id, user_id)
|
||||
|
||||
if is_member:
|
||||
raise AlreadyInvited('You are already in the Group DM')
|
||||
raise AlreadyInvited("You are already in the Group DM")
|
||||
|
||||
|
||||
async def _inv_check_age(inv: dict):
|
||||
if inv['max_age'] == 0:
|
||||
if inv["max_age"] == 0:
|
||||
return
|
||||
|
||||
now = datetime.datetime.utcnow()
|
||||
delta_sec = (now - inv['created_at']).total_seconds()
|
||||
delta_sec = (now - inv["created_at"]).total_seconds()
|
||||
|
||||
if delta_sec > inv['max_age']:
|
||||
await delete_invite(inv['code'])
|
||||
raise InvalidInvite('Invite is expired')
|
||||
if delta_sec > inv["max_age"]:
|
||||
await delete_invite(inv["code"])
|
||||
raise InvalidInvite("Invite is expired")
|
||||
|
||||
if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']:
|
||||
await delete_invite(inv['code'])
|
||||
raise InvalidInvite('Too many uses')
|
||||
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
|
||||
await delete_invite(inv["code"])
|
||||
raise InvalidInvite("Too many uses")
|
||||
|
||||
|
||||
async def _guild_add_member(guild_id: int, user_id: int):
|
||||
|
|
@ -119,78 +132,89 @@ async def _guild_add_member(guild_id: int, user_id: int):
|
|||
"""
|
||||
|
||||
# TODO: system message for member join
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO members (user_id, guild_id)
|
||||
VALUES ($1, $2)
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
await create_guild_settings(guild_id, user_id)
|
||||
|
||||
# add the @everyone role to the invited member
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||
VALUES ($1, $2, $3)
|
||||
""", user_id, guild_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
# tell current members a new member came up
|
||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_ADD', {
|
||||
**member,
|
||||
**{
|
||||
'guild_id': str(guild_id),
|
||||
},
|
||||
})
|
||||
await app.dispatcher.dispatch_guild(
|
||||
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
|
||||
)
|
||||
|
||||
# update member lists for the new member
|
||||
await app.dispatcher.dispatch(
|
||||
'lazy_guild', guild_id, 'new_member', user_id)
|
||||
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
|
||||
|
||||
# subscribe new member to guild, so they get events n stuff
|
||||
await app.dispatcher.sub('guild', guild_id, user_id)
|
||||
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||
|
||||
# tell the new member that theres the guild it just joined.
|
||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
||||
# just to the shards that are actually tied to it.
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
await app.dispatcher.dispatch_user_guild(
|
||||
user_id, guild_id, 'GUILD_CREATE', guild)
|
||||
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
|
||||
|
||||
|
||||
async def use_invite(user_id, invite_code):
|
||||
"""Try using an invite"""
|
||||
inv = await app.db.fetchrow("""
|
||||
inv = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT code, channel_id, guild_id, created_at,
|
||||
max_age, uses, max_uses
|
||||
FROM invites
|
||||
WHERE code = $1
|
||||
""", invite_code)
|
||||
""",
|
||||
invite_code,
|
||||
)
|
||||
|
||||
if inv is None:
|
||||
raise UnknownInvite('Unknown invite')
|
||||
raise UnknownInvite("Unknown invite")
|
||||
|
||||
await _inv_check_age(inv)
|
||||
|
||||
# NOTE: if group dm invite, guild_id is null.
|
||||
guild_id = inv['guild_id']
|
||||
guild_id = inv["guild_id"]
|
||||
|
||||
try:
|
||||
try:
|
||||
if guild_id is None:
|
||||
channel_id = inv['channel_id']
|
||||
await invite_precheck_gdm(user_id, inv['channel_id'])
|
||||
channel_id = inv["channel_id"]
|
||||
await invite_precheck_gdm(user_id, inv["channel_id"])
|
||||
await gdm_add_recipient(channel_id, user_id)
|
||||
else:
|
||||
await invite_precheck(user_id, guild_id)
|
||||
await _guild_add_member(guild_id, user_id)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE invites
|
||||
SET uses = uses + 1
|
||||
WHERE code = $1
|
||||
""", invite_code)
|
||||
""",
|
||||
invite_code,
|
||||
)
|
||||
except AlreadyInvited:
|
||||
pass
|
||||
|
||||
@bp.route('/channels/<int:channel_id>/invites', methods=['POST'])
|
||||
|
||||
@bp.route("/channels/<int:channel_id>/invites", methods=["POST"])
|
||||
async def create_invite(channel_id):
|
||||
"""Create an invite to a channel."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -201,12 +225,14 @@ async def create_invite(channel_id):
|
|||
|
||||
# NOTE: this works on group dms, since it returns ALL_PERMISSIONS on
|
||||
# non-guild channels.
|
||||
await channel_perm_check(user_id, channel_id, 'create_invites')
|
||||
await channel_perm_check(user_id, channel_id, "create_invites")
|
||||
|
||||
if chantype not in (ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE,
|
||||
ChannelType.GROUP_DM):
|
||||
raise BadRequest('Invalid channel type')
|
||||
if chantype not in (
|
||||
ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE,
|
||||
ChannelType.GROUP_DM,
|
||||
):
|
||||
raise BadRequest("Invalid channel type")
|
||||
|
||||
invite_code = gen_inv_code()
|
||||
|
||||
|
|
@ -222,101 +248,122 @@ async def create_invite(channel_id):
|
|||
max_age, temporary)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
invite_code, guild_id, channel_id, user_id,
|
||||
j['max_uses'], j['max_age'], j['temporary']
|
||||
invite_code,
|
||||
guild_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
j["max_uses"],
|
||||
j["max_age"],
|
||||
j["temporary"],
|
||||
)
|
||||
|
||||
invite = await app.storage.get_invite(invite_code)
|
||||
return jsonify(invite)
|
||||
|
||||
|
||||
@bp.route('/invite/<invite_code>', methods=['GET'])
|
||||
@bp.route('/invites/<invite_code>', methods=['GET'])
|
||||
@bp.route("/invite/<invite_code>", methods=["GET"])
|
||||
@bp.route("/invites/<invite_code>", methods=["GET"])
|
||||
async def get_invite(invite_code: str):
|
||||
inv = await app.storage.get_invite(invite_code)
|
||||
|
||||
if not inv:
|
||||
return '', 404
|
||||
return "", 404
|
||||
|
||||
if request.args.get('with_counts'):
|
||||
if request.args.get("with_counts"):
|
||||
extra = await app.storage.get_invite_extra(invite_code)
|
||||
inv.update(extra)
|
||||
|
||||
return jsonify(inv)
|
||||
|
||||
|
||||
async def delete_invite(invite_code: str):
|
||||
"""Delete an invite."""
|
||||
await app.db.fetchval("""
|
||||
await app.db.fetchval(
|
||||
"""
|
||||
DELETE FROM invites
|
||||
WHERE code = $1
|
||||
""", invite_code)
|
||||
""",
|
||||
invite_code,
|
||||
)
|
||||
|
||||
@bp.route('/invite/<invite_code>', methods=['DELETE'])
|
||||
@bp.route('/invites/<invite_code>', methods=['DELETE'])
|
||||
|
||||
@bp.route("/invite/<invite_code>", methods=["DELETE"])
|
||||
@bp.route("/invites/<invite_code>", methods=["DELETE"])
|
||||
async def _delete_invite(invite_code: str):
|
||||
user_id = await token_check()
|
||||
|
||||
guild_id = await app.db.fetchval("""
|
||||
guild_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT guild_id
|
||||
FROM invites
|
||||
WHERE code = $1
|
||||
""", invite_code)
|
||||
""",
|
||||
invite_code,
|
||||
)
|
||||
|
||||
if guild_id is None:
|
||||
raise BadRequest('Unknown invite')
|
||||
raise BadRequest("Unknown invite")
|
||||
|
||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
||||
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||
|
||||
inv = await app.storage.get_invite(invite_code)
|
||||
await delete_invite(invite_code)
|
||||
return jsonify(inv)
|
||||
|
||||
|
||||
async def _get_inv(code):
|
||||
inv = await app.storage.get_invite(code)
|
||||
meta = await app.storage.get_invite_metadata(code)
|
||||
return {**inv, **meta}
|
||||
|
||||
|
||||
@bp.route('/guilds/<int:guild_id>/invites', methods=['GET'])
|
||||
@bp.route("/guilds/<int:guild_id>/invites", methods=["GET"])
|
||||
async def get_guild_invites(guild_id: int):
|
||||
"""Get all invites for a guild."""
|
||||
user_id = await token_check()
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
||||
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||
|
||||
inv_codes = await app.db.fetch("""
|
||||
inv_codes = await app.db.fetch(
|
||||
"""
|
||||
SELECT code
|
||||
FROM invites
|
||||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
inv_codes = [r['code'] for r in inv_codes]
|
||||
inv_codes = [r["code"] for r in inv_codes]
|
||||
invs = await async_map(_get_inv, inv_codes)
|
||||
return jsonify(invs)
|
||||
|
||||
|
||||
@bp.route('/channels/<int:channel_id>/invites', methods=['GET'])
|
||||
@bp.route("/channels/<int:channel_id>/invites", methods=["GET"])
|
||||
async def get_channel_invites(channel_id: int):
|
||||
"""Get all invites for a channel."""
|
||||
user_id = await token_check()
|
||||
|
||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
||||
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||
|
||||
inv_codes = await app.db.fetch("""
|
||||
inv_codes = await app.db.fetch(
|
||||
"""
|
||||
SELECT code
|
||||
FROM invites
|
||||
WHERE guild_id = $1 AND channel_id = $2
|
||||
""", guild_id, channel_id)
|
||||
""",
|
||||
guild_id,
|
||||
channel_id,
|
||||
)
|
||||
|
||||
inv_codes = [r['code'] for r in inv_codes]
|
||||
inv_codes = [r["code"] for r in inv_codes]
|
||||
invs = await async_map(_get_inv, inv_codes)
|
||||
return jsonify(invs)
|
||||
|
||||
|
||||
@bp.route('/invite/<invite_code>', methods=['POST'])
|
||||
@bp.route('/invites/<invite_code>', methods=['POST'])
|
||||
@bp.route("/invite/<invite_code>", methods=["POST"])
|
||||
@bp.route("/invites/<invite_code>", methods=["POST"])
|
||||
async def _use_invite(invite_code):
|
||||
"""Use an invite."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -327,9 +374,4 @@ async def _use_invite(invite_code):
|
|||
inv = await app.storage.get_invite(invite_code)
|
||||
inv_meta = await app.storage.get_invite_metadata(invite_code)
|
||||
|
||||
return jsonify({
|
||||
**inv,
|
||||
**{
|
||||
'inviter': inv_meta['inviter']
|
||||
}
|
||||
})
|
||||
return jsonify({**inv, **{"inviter": inv_meta["inviter"]}})
|
||||
|
|
|
|||
|
|
@ -19,83 +19,75 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
from quart import Blueprint, current_app as app, jsonify, request
|
||||
|
||||
bp = Blueprint('nodeinfo', __name__)
|
||||
bp = Blueprint("nodeinfo", __name__)
|
||||
|
||||
|
||||
@bp.route('/.well-known/nodeinfo')
|
||||
@bp.route("/.well-known/nodeinfo")
|
||||
async def _dummy_nodeinfo_index():
|
||||
proto = 'http' if not app.config['IS_SSL'] else 'https'
|
||||
main_url = app.config.get('MAIN_URL', request.host)
|
||||
proto = "http" if not app.config["IS_SSL"] else "https"
|
||||
main_url = app.config.get("MAIN_URL", request.host)
|
||||
|
||||
return jsonify({
|
||||
'links': [{
|
||||
'href': f'{proto}://{main_url}/nodeinfo/2.0.json',
|
||||
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.0'
|
||||
}, {
|
||||
'href': f'{proto}://{main_url}/nodeinfo/2.1.json',
|
||||
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
|
||||
}]
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
"links": [
|
||||
{
|
||||
"href": f"{proto}://{main_url}/nodeinfo/2.0.json",
|
||||
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
|
||||
},
|
||||
{
|
||||
"href": f"{proto}://{main_url}/nodeinfo/2.1.json",
|
||||
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def fetch_nodeinfo_20():
|
||||
usercount = await app.db.fetchval("""
|
||||
usercount = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM users
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
message_count = await app.db.fetchval("""
|
||||
message_count = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM messages
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
return {
|
||||
'metadata': {
|
||||
'features': [
|
||||
'discord_api'
|
||||
],
|
||||
|
||||
'nodeDescription': 'A Litecord instance',
|
||||
'nodeName': 'Litecord/Nya',
|
||||
'private': False,
|
||||
|
||||
'federation': {}
|
||||
"metadata": {
|
||||
"features": ["discord_api"],
|
||||
"nodeDescription": "A Litecord instance",
|
||||
"nodeName": "Litecord/Nya",
|
||||
"private": False,
|
||||
"federation": {},
|
||||
},
|
||||
'openRegistrations': app.config['REGISTRATIONS'],
|
||||
'protocols': [],
|
||||
'software': {
|
||||
'name': 'litecord',
|
||||
'version': 'litecord v0',
|
||||
},
|
||||
|
||||
'services': {
|
||||
'inbound': [],
|
||||
'outbound': [],
|
||||
},
|
||||
|
||||
'usage': {
|
||||
'localPosts': message_count,
|
||||
'users': {
|
||||
'total': usercount
|
||||
}
|
||||
},
|
||||
'version': '2.0',
|
||||
"openRegistrations": app.config["REGISTRATIONS"],
|
||||
"protocols": [],
|
||||
"software": {"name": "litecord", "version": "litecord v0"},
|
||||
"services": {"inbound": [], "outbound": []},
|
||||
"usage": {"localPosts": message_count, "users": {"total": usercount}},
|
||||
"version": "2.0",
|
||||
}
|
||||
|
||||
|
||||
@bp.route('/nodeinfo/2.0.json')
|
||||
@bp.route("/nodeinfo/2.0.json")
|
||||
async def _nodeinfo_20():
|
||||
"""Handler for nodeinfo 2.0."""
|
||||
raw_nodeinfo = await fetch_nodeinfo_20()
|
||||
return jsonify(raw_nodeinfo)
|
||||
|
||||
|
||||
@bp.route('/nodeinfo/2.1.json')
|
||||
@bp.route("/nodeinfo/2.1.json")
|
||||
async def _nodeinfo_21():
|
||||
"""Handler for nodeinfo 2.1."""
|
||||
raw_nodeinfo = await fetch_nodeinfo_20()
|
||||
|
||||
raw_nodeinfo['software']['repository'] = 'https://gitlab.com/litecord/litecord'
|
||||
raw_nodeinfo['version'] = '2.1'
|
||||
raw_nodeinfo["software"]["repository"] = "https://gitlab.com/litecord/litecord"
|
||||
raw_nodeinfo["version"] = "2.1"
|
||||
|
||||
return jsonify(raw_nodeinfo)
|
||||
|
|
|
|||
|
|
@ -26,76 +26,89 @@ from ..enums import RelationshipType
|
|||
from litecord.errors import BadRequest
|
||||
|
||||
|
||||
bp = Blueprint('relationship', __name__)
|
||||
bp = Blueprint("relationship", __name__)
|
||||
|
||||
|
||||
@bp.route('/@me/relationships', methods=['GET'])
|
||||
@bp.route("/@me/relationships", methods=["GET"])
|
||||
async def get_me_relationships():
|
||||
user_id = await token_check()
|
||||
return jsonify(
|
||||
await app.user_storage.get_relationships(user_id))
|
||||
return jsonify(await app.user_storage.get_relationships(user_id))
|
||||
|
||||
|
||||
async def _dispatch_single_pres(user_id, presence: dict):
|
||||
await app.dispatcher.dispatch(
|
||||
'user', user_id, 'PRESENCE_UPDATE', presence
|
||||
)
|
||||
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
|
||||
|
||||
|
||||
async def _unsub_friend(user_id, peer_id):
|
||||
await app.dispatcher.unsub('friend', user_id, peer_id)
|
||||
await app.dispatcher.unsub('friend', peer_id, user_id)
|
||||
await app.dispatcher.unsub("friend", user_id, peer_id)
|
||||
await app.dispatcher.unsub("friend", peer_id, user_id)
|
||||
|
||||
|
||||
async def _sub_friend(user_id, peer_id):
|
||||
await app.dispatcher.sub('friend', user_id, peer_id)
|
||||
await app.dispatcher.sub('friend', peer_id, user_id)
|
||||
await app.dispatcher.sub("friend", user_id, peer_id)
|
||||
await app.dispatcher.sub("friend", peer_id, user_id)
|
||||
|
||||
# dispatch presence update to the user and peer about
|
||||
# eachother's presence.
|
||||
user_pres, peer_pres = await app.presence.friend_presences(
|
||||
[user_id, peer_id]
|
||||
)
|
||||
user_pres, peer_pres = await app.presence.friend_presences([user_id, peer_id])
|
||||
|
||||
await _dispatch_single_pres(user_id, peer_pres)
|
||||
await _dispatch_single_pres(peer_id, user_pres)
|
||||
|
||||
|
||||
async def make_friend(user_id: int, peer_id: int,
|
||||
rel_type=RelationshipType.FRIEND.value):
|
||||
async def make_friend(
|
||||
user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value
|
||||
):
|
||||
_friend = RelationshipType.FRIEND.value
|
||||
_block = RelationshipType.BLOCK.value
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO relationships (user_id, peer_id, rel_type)
|
||||
VALUES ($1, $2, $3)
|
||||
""", user_id, peer_id, rel_type)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
rel_type,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
# try to update rel_type
|
||||
old_rel_type = await app.db.fetchval("""
|
||||
old_rel_type = await app.db.fetchval(
|
||||
"""
|
||||
SELECT rel_type
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND peer_id = $2
|
||||
""", user_id, peer_id)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
)
|
||||
|
||||
if old_rel_type == _friend and rel_type == _block:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE relationships
|
||||
SET rel_type = $1
|
||||
WHERE user_id = $2 AND peer_id = $3
|
||||
""", rel_type, user_id, peer_id)
|
||||
""",
|
||||
rel_type,
|
||||
user_id,
|
||||
peer_id,
|
||||
)
|
||||
|
||||
# remove any existing friendship before the block
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM relationships
|
||||
WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3
|
||||
""", peer_id, user_id, _friend)
|
||||
""",
|
||||
peer_id,
|
||||
user_id,
|
||||
_friend,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user(
|
||||
peer_id, 'RELATIONSHIP_REMOVE', {
|
||||
'type': _friend,
|
||||
'id': str(user_id)
|
||||
}
|
||||
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
|
@ -106,95 +119,118 @@ async def make_friend(user_id: int, peer_id: int,
|
|||
|
||||
# check if this is an acceptance
|
||||
# of a friend request
|
||||
existing = await app.db.fetchrow("""
|
||||
existing = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT user_id, peer_id
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
||||
""", peer_id, user_id, _friend)
|
||||
""",
|
||||
peer_id,
|
||||
user_id,
|
||||
_friend,
|
||||
)
|
||||
|
||||
_dispatch = app.dispatcher.dispatch_user
|
||||
|
||||
if existing:
|
||||
# accepted a friend request, dispatch respective
|
||||
# relationship events
|
||||
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
|
||||
'type': RelationshipType.INCOMING.value,
|
||||
'id': str(peer_id)
|
||||
})
|
||||
await _dispatch(
|
||||
user_id,
|
||||
"RELATIONSHIP_REMOVE",
|
||||
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
||||
)
|
||||
|
||||
await _dispatch(user_id, 'RELATIONSHIP_ADD', {
|
||||
'type': _friend,
|
||||
'id': str(peer_id),
|
||||
'user': await app.storage.get_user(peer_id)
|
||||
})
|
||||
await _dispatch(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(peer_id),
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
)
|
||||
|
||||
await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
|
||||
'type': _friend,
|
||||
'id': str(user_id),
|
||||
'user': await app.storage.get_user(user_id)
|
||||
})
|
||||
await _dispatch(
|
||||
peer_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"type": _friend,
|
||||
"id": str(user_id),
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
await _sub_friend(user_id, peer_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
# check if friend AND not acceptance of fr
|
||||
if rel_type == _friend:
|
||||
await _dispatch(user_id, 'RELATIONSHIP_ADD', {
|
||||
'id': str(peer_id),
|
||||
'type': RelationshipType.OUTGOING.value,
|
||||
'user': await app.storage.get_user(peer_id),
|
||||
})
|
||||
await _dispatch(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(peer_id),
|
||||
"type": RelationshipType.OUTGOING.value,
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
)
|
||||
|
||||
await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
|
||||
'id': str(user_id),
|
||||
'type': RelationshipType.INCOMING.value,
|
||||
'user': await app.storage.get_user(user_id)
|
||||
})
|
||||
await _dispatch(
|
||||
peer_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(user_id),
|
||||
"type": RelationshipType.INCOMING.value,
|
||||
"user": await app.storage.get_user(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
# we don't make the pubsub link
|
||||
# until the peer accepts the friend request
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
return
|
||||
|
||||
|
||||
class RelationshipFailed(BadRequest):
|
||||
"""Exception for general relationship errors."""
|
||||
|
||||
error_code = 80004
|
||||
|
||||
|
||||
class RelationshipBlocked(BadRequest):
|
||||
"""Exception for when the peer has blocked the user."""
|
||||
|
||||
error_code = 80001
|
||||
|
||||
|
||||
@bp.route('/@me/relationships', methods=['POST'])
|
||||
@bp.route("/@me/relationships", methods=["POST"])
|
||||
async def post_relationship():
|
||||
user_id = await token_check()
|
||||
j = validate(await request.get_json(), SPECIFIC_FRIEND)
|
||||
|
||||
uid = await app.storage.search_user(j['username'],
|
||||
str(j['discriminator']))
|
||||
uid = await app.storage.search_user(j["username"], str(j["discriminator"]))
|
||||
|
||||
if not uid:
|
||||
raise RelationshipFailed('No users with DiscordTag exist')
|
||||
raise RelationshipFailed("No users with DiscordTag exist")
|
||||
|
||||
res = await make_friend(user_id, uid)
|
||||
|
||||
if res is None:
|
||||
raise RelationshipBlocked('Can not friend user due to block')
|
||||
raise RelationshipBlocked("Can not friend user due to block")
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/@me/relationships/<int:peer_id>', methods=['PUT'])
|
||||
@bp.route("/@me/relationships/<int:peer_id>", methods=["PUT"])
|
||||
async def add_relationship(peer_id: int):
|
||||
"""Add a relationship to the peer."""
|
||||
user_id = await token_check()
|
||||
payload = validate(await request.get_json(), RELATIONSHIP)
|
||||
rel_type = payload['type']
|
||||
rel_type = payload["type"]
|
||||
|
||||
res = await make_friend(user_id, peer_id, rel_type)
|
||||
|
||||
|
|
@ -204,18 +240,22 @@ async def add_relationship(peer_id: int):
|
|||
# make_friend did not succeed, so we
|
||||
# assume it is a block and dispatch
|
||||
# the respective RELATIONSHIP_ADD.
|
||||
await app.dispatcher.dispatch_user(user_id, 'RELATIONSHIP_ADD', {
|
||||
'id': str(peer_id),
|
||||
'type': RelationshipType.BLOCK.value,
|
||||
'user': await app.storage.get_user(peer_id)
|
||||
})
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id,
|
||||
"RELATIONSHIP_ADD",
|
||||
{
|
||||
"id": str(peer_id),
|
||||
"type": RelationshipType.BLOCK.value,
|
||||
"user": await app.storage.get_user(peer_id),
|
||||
},
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/@me/relationships/<int:peer_id>', methods=['DELETE'])
|
||||
@bp.route("/@me/relationships/<int:peer_id>", methods=["DELETE"])
|
||||
async def remove_relationship(peer_id: int):
|
||||
"""Remove an existing relationship"""
|
||||
user_id = await token_check()
|
||||
|
|
@ -223,69 +263,86 @@ async def remove_relationship(peer_id: int):
|
|||
_block = RelationshipType.BLOCK.value
|
||||
_dispatch = app.dispatcher.dispatch_user
|
||||
|
||||
rel_type = await app.db.fetchval("""
|
||||
rel_type = await app.db.fetchval(
|
||||
"""
|
||||
SELECT rel_type
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND peer_id = $2
|
||||
""", user_id, peer_id)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
)
|
||||
|
||||
incoming_rel_type = await app.db.fetchval("""
|
||||
incoming_rel_type = await app.db.fetchval(
|
||||
"""
|
||||
SELECT rel_type
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND peer_id = $2
|
||||
""", peer_id, user_id)
|
||||
""",
|
||||
peer_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
# if any of those are friend
|
||||
if _friend in (rel_type, incoming_rel_type):
|
||||
# closing the friendship, have to delete both rows
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM relationships
|
||||
WHERE (
|
||||
(user_id = $1 AND peer_id = $2) OR
|
||||
(user_id = $2 AND peer_id = $1)
|
||||
) AND rel_type = $3
|
||||
""", user_id, peer_id, _friend)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
_friend,
|
||||
)
|
||||
|
||||
# if there wasnt any mutual friendship before,
|
||||
# assume they were requests of INCOMING
|
||||
# and OUTGOING.
|
||||
user_del_type = RelationshipType.OUTGOING.value if \
|
||||
incoming_rel_type != _friend else _friend
|
||||
user_del_type = (
|
||||
RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend
|
||||
)
|
||||
|
||||
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
|
||||
'id': str(peer_id),
|
||||
'type': user_del_type,
|
||||
})
|
||||
await _dispatch(
|
||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
|
||||
)
|
||||
|
||||
peer_del_type = RelationshipType.INCOMING.value if \
|
||||
incoming_rel_type != _friend else _friend
|
||||
peer_del_type = (
|
||||
RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend
|
||||
)
|
||||
|
||||
await _dispatch(peer_id, 'RELATIONSHIP_REMOVE', {
|
||||
'id': str(user_id),
|
||||
'type': peer_del_type,
|
||||
})
|
||||
await _dispatch(
|
||||
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
# was a block!
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM relationships
|
||||
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
||||
""", user_id, peer_id, _block)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
_block,
|
||||
)
|
||||
|
||||
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
|
||||
'id': str(peer_id),
|
||||
'type': _block,
|
||||
})
|
||||
await _dispatch(
|
||||
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
|
||||
)
|
||||
|
||||
await _unsub_friend(user_id, peer_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/<int:peer_id>/relationships', methods=['GET'])
|
||||
@bp.route("/<int:peer_id>/relationships", methods=["GET"])
|
||||
async def get_mutual_friends(peer_id: int):
|
||||
"""Fetch a users' mutual friends with the current user."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -294,17 +351,15 @@ async def get_mutual_friends(peer_id: int):
|
|||
peer = await app.storage.get_user(peer_id)
|
||||
|
||||
if not peer:
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
# NOTE: maybe this could be better with pure SQL calculations
|
||||
# but it would be beyond my current SQL knowledge, so...
|
||||
user_rels = await app.user_storage.get_relationships(user_id)
|
||||
peer_rels = await app.user_storage.get_relationships(peer_id)
|
||||
|
||||
user_friends = {rel['user']['id']
|
||||
for rel in user_rels if rel['type'] == _friend}
|
||||
peer_friends = {rel['user']['id']
|
||||
for rel in peer_rels if rel['type'] == _friend}
|
||||
user_friends = {rel["user"]["id"] for rel in user_rels if rel["type"] == _friend}
|
||||
peer_friends = {rel["user"]["id"] for rel in peer_rels if rel["type"] == _friend}
|
||||
|
||||
# get the intersection, then map them to Storage.get_user() calls
|
||||
mutual_ids = user_friends & peer_friends
|
||||
|
|
@ -312,8 +367,6 @@ async def get_mutual_friends(peer_id: int):
|
|||
mutual_friends = []
|
||||
|
||||
for friend_id in mutual_ids:
|
||||
mutual_friends.append(
|
||||
await app.storage.get_user(int(friend_id))
|
||||
)
|
||||
mutual_friends.append(await app.storage.get_user(int(friend_id)))
|
||||
|
||||
return jsonify(mutual_friends)
|
||||
|
|
|
|||
|
|
@ -19,21 +19,19 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
from quart import Blueprint, jsonify
|
||||
|
||||
bp = Blueprint('science', __name__)
|
||||
bp = Blueprint("science", __name__)
|
||||
|
||||
|
||||
@bp.route('/science', methods=['POST'])
|
||||
@bp.route("/science", methods=["POST"])
|
||||
async def science():
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/applications', methods=['GET'])
|
||||
@bp.route("/applications", methods=["GET"])
|
||||
async def applications():
|
||||
return jsonify([])
|
||||
|
||||
|
||||
@bp.route('/experiments', methods=['GET'])
|
||||
@bp.route("/experiments", methods=["GET"])
|
||||
async def experiments():
|
||||
return jsonify({
|
||||
'assignments': []
|
||||
})
|
||||
return jsonify({"assignments": []})
|
||||
|
|
|
|||
|
|
@ -20,23 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
from quart import Blueprint, current_app as app, render_template_string
|
||||
from pathlib import Path
|
||||
|
||||
bp = Blueprint('static', __name__)
|
||||
bp = Blueprint("static", __name__)
|
||||
|
||||
|
||||
@bp.route('/<path:path>')
|
||||
@bp.route("/<path:path>")
|
||||
async def static_pages(path):
|
||||
"""Map requests from / to /static."""
|
||||
if '..' in path:
|
||||
return 'no', 404
|
||||
if ".." in path:
|
||||
return "no", 404
|
||||
|
||||
static_path = Path.cwd() / Path('static') / path
|
||||
static_path = Path.cwd() / Path("static") / path
|
||||
return await app.send_static_file(str(static_path))
|
||||
|
||||
|
||||
@bp.route('/')
|
||||
@bp.route('/api')
|
||||
@bp.route("/")
|
||||
@bp.route("/api")
|
||||
async def index_handler():
|
||||
"""Handler for the index page."""
|
||||
index_path = Path.cwd() / Path('static') / 'index.html'
|
||||
index_path = Path.cwd() / Path("static") / "index.html"
|
||||
return await render_template_string(
|
||||
index_path.read_text(), inst_name=app.config['NAME'])
|
||||
index_path.read_text(), inst_name=app.config["NAME"]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,4 +21,4 @@ from .billing import bp as user_billing
|
|||
from .settings import bp as user_settings
|
||||
from .fake_store import bp as fake_store
|
||||
|
||||
__all__ = ['user_billing', 'user_settings', 'fake_store']
|
||||
__all__ = ["user_billing", "user_settings", "fake_store"]
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from litecord.enums import UserFlags, PremiumType
|
|||
from litecord.blueprints.users import mass_user_update
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('users_billing', __name__)
|
||||
bp = Blueprint("users_billing", __name__)
|
||||
|
||||
|
||||
class PaymentSource(Enum):
|
||||
|
|
@ -68,78 +68,87 @@ class PaymentStatus:
|
|||
|
||||
|
||||
PLAN_ID_TO_TYPE = {
|
||||
'premium_month_tier_1': PremiumType.TIER_1,
|
||||
'premium_month_tier_2': PremiumType.TIER_2,
|
||||
'premium_year_tier_1': PremiumType.TIER_1,
|
||||
'premium_year_tier_2': PremiumType.TIER_2,
|
||||
"premium_month_tier_1": PremiumType.TIER_1,
|
||||
"premium_month_tier_2": PremiumType.TIER_2,
|
||||
"premium_year_tier_1": PremiumType.TIER_1,
|
||||
"premium_year_tier_2": PremiumType.TIER_2,
|
||||
}
|
||||
|
||||
|
||||
# how much should a payment be, depending
|
||||
# of the subscription
|
||||
AMOUNTS = {
|
||||
'premium_month_tier_1': 499,
|
||||
'premium_month_tier_2': 999,
|
||||
'premium_year_tier_1': 4999,
|
||||
'premium_year_tier_2': 9999,
|
||||
"premium_month_tier_1": 499,
|
||||
"premium_month_tier_2": 999,
|
||||
"premium_year_tier_1": 4999,
|
||||
"premium_year_tier_2": 9999,
|
||||
}
|
||||
|
||||
|
||||
CREATE_SUBSCRIPTION = {
|
||||
'payment_gateway_plan_id': {'type': 'string'},
|
||||
'payment_source_id': {'coerce': int}
|
||||
"payment_gateway_plan_id": {"type": "string"},
|
||||
"payment_source_id": {"coerce": int},
|
||||
}
|
||||
|
||||
|
||||
PAYMENT_SOURCE = {
|
||||
'billing_address': {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'country': {'type': 'string', 'required': True},
|
||||
'city': {'type': 'string', 'required': True},
|
||||
'name': {'type': 'string', 'required': True},
|
||||
'line_1': {'type': 'string', 'required': False},
|
||||
'line_2': {'type': 'string', 'required': False},
|
||||
'postal_code': {'type': 'string', 'required': True},
|
||||
'state': {'type': 'string', 'required': True},
|
||||
}
|
||||
"billing_address": {
|
||||
"type": "dict",
|
||||
"schema": {
|
||||
"country": {"type": "string", "required": True},
|
||||
"city": {"type": "string", "required": True},
|
||||
"name": {"type": "string", "required": True},
|
||||
"line_1": {"type": "string", "required": False},
|
||||
"line_2": {"type": "string", "required": False},
|
||||
"postal_code": {"type": "string", "required": True},
|
||||
"state": {"type": "string", "required": True},
|
||||
},
|
||||
},
|
||||
'payment_gateway': {'type': 'number', 'required': True},
|
||||
'token': {'type': 'string', 'required': True},
|
||||
"payment_gateway": {"type": "number", "required": True},
|
||||
"token": {"type": "string", "required": True},
|
||||
}
|
||||
|
||||
|
||||
async def get_payment_source_ids(user_id: int) -> list:
|
||||
rows = await app.db.fetch("""
|
||||
rows = await app.db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM user_payment_sources
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return [r['id'] for r in rows]
|
||||
return [r["id"] for r in rows]
|
||||
|
||||
|
||||
async def get_payment_ids(user_id: int, db=None) -> list:
|
||||
if not db:
|
||||
db = app.db
|
||||
|
||||
rows = await db.fetch("""
|
||||
rows = await db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM user_payments
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return [r['id'] for r in rows]
|
||||
return [r["id"] for r in rows]
|
||||
|
||||
|
||||
async def get_subscription_ids(user_id: int) -> list:
|
||||
rows = await app.db.fetch("""
|
||||
rows = await app.db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM user_subscriptions
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return [r['id'] for r in rows]
|
||||
return [r["id"] for r in rows]
|
||||
|
||||
|
||||
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||
|
|
@ -148,41 +157,44 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
|||
if not db:
|
||||
db = app.db
|
||||
|
||||
source_type = await db.fetchval("""
|
||||
source_type = await db.fetchval(
|
||||
"""
|
||||
SELECT source_type
|
||||
FROM user_payment_sources
|
||||
WHERE id = $1 AND user_id = $2
|
||||
""", source_id, user_id)
|
||||
""",
|
||||
source_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
source_type = PaymentSource(source_type)
|
||||
|
||||
specific_fields = {
|
||||
PaymentSource.PAYPAL: ['paypal_email'],
|
||||
PaymentSource.CREDIT: ['expires_month', 'expires_year',
|
||||
'brand', 'cc_full']
|
||||
PaymentSource.PAYPAL: ["paypal_email"],
|
||||
PaymentSource.CREDIT: ["expires_month", "expires_year", "brand", "cc_full"],
|
||||
}[source_type]
|
||||
|
||||
fields = ','.join(specific_fields)
|
||||
fields = ",".join(specific_fields)
|
||||
|
||||
extras_row = await db.fetchrow(f"""
|
||||
extras_row = await db.fetchrow(
|
||||
f"""
|
||||
SELECT {fields}, billing_address, default_, id::text
|
||||
FROM user_payment_sources
|
||||
WHERE id = $1
|
||||
""", source_id)
|
||||
""",
|
||||
source_id,
|
||||
)
|
||||
|
||||
derow = dict(extras_row)
|
||||
|
||||
if source_type == PaymentSource.CREDIT:
|
||||
derow['last_4'] = derow['cc_full'][-4:]
|
||||
derow.pop('cc_full')
|
||||
derow["last_4"] = derow["cc_full"][-4:]
|
||||
derow.pop("cc_full")
|
||||
|
||||
derow['default'] = derow['default_']
|
||||
derow.pop('default_')
|
||||
derow["default"] = derow["default_"]
|
||||
derow.pop("default_")
|
||||
|
||||
source = {
|
||||
'id': str(source_id),
|
||||
'type': source_type.value,
|
||||
}
|
||||
source = {"id": str(source_id), "type": source_type.value}
|
||||
|
||||
return {**source, **derow}
|
||||
|
||||
|
|
@ -192,7 +204,8 @@ async def get_subscription(subscription_id: int, db=None):
|
|||
if not db:
|
||||
db = app.db
|
||||
|
||||
row = await db.fetchrow("""
|
||||
row = await db.fetchrow(
|
||||
"""
|
||||
SELECT id::text, source_id::text AS payment_source_id,
|
||||
user_id,
|
||||
payment_gateway, payment_gateway_plan_id,
|
||||
|
|
@ -201,14 +214,16 @@ async def get_subscription(subscription_id: int, db=None):
|
|||
canceled_at, s_type, status
|
||||
FROM user_subscriptions
|
||||
WHERE id = $1
|
||||
""", subscription_id)
|
||||
""",
|
||||
subscription_id,
|
||||
)
|
||||
|
||||
drow = dict(row)
|
||||
|
||||
drow['type'] = drow['s_type']
|
||||
drow.pop('s_type')
|
||||
drow["type"] = drow["s_type"]
|
||||
drow.pop("s_type")
|
||||
|
||||
to_tstamp = ['current_period_start', 'current_period_end', 'canceled_at']
|
||||
to_tstamp = ["current_period_start", "current_period_end", "canceled_at"]
|
||||
|
||||
for field in to_tstamp:
|
||||
drow[field] = timestamp_(drow[field])
|
||||
|
|
@ -221,27 +236,30 @@ async def get_payment(payment_id: int, db=None):
|
|||
if not db:
|
||||
db = app.db
|
||||
|
||||
row = await db.fetchrow("""
|
||||
row = await db.fetchrow(
|
||||
"""
|
||||
SELECT id::text, source_id, subscription_id, user_id,
|
||||
amount, amount_refunded, currency,
|
||||
description, status, tax, tax_inclusive
|
||||
FROM user_payments
|
||||
WHERE id = $1
|
||||
""", payment_id)
|
||||
""",
|
||||
payment_id,
|
||||
)
|
||||
|
||||
drow = dict(row)
|
||||
|
||||
drow.pop('source_id')
|
||||
drow.pop('subscription_id')
|
||||
drow.pop('user_id')
|
||||
drow.pop("source_id")
|
||||
drow.pop("subscription_id")
|
||||
drow.pop("user_id")
|
||||
|
||||
drow['created_at'] = snowflake_datetime(int(drow['id']))
|
||||
drow["created_at"] = snowflake_datetime(int(drow["id"]))
|
||||
|
||||
drow['payment_source'] = await get_payment_source(
|
||||
row['user_id'], row['source_id'], db)
|
||||
drow["payment_source"] = await get_payment_source(
|
||||
row["user_id"], row["source_id"], db
|
||||
)
|
||||
|
||||
drow['subscription'] = await get_subscription(
|
||||
row['subscription_id'], db)
|
||||
drow["subscription"] = await get_subscription(row["subscription_id"], db)
|
||||
|
||||
return drow
|
||||
|
||||
|
|
@ -255,7 +273,7 @@ async def create_payment(subscription_id, db=None):
|
|||
|
||||
new_id = get_snowflake()
|
||||
|
||||
amount = AMOUNTS[sub['payment_gateway_plan_id']]
|
||||
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
|
||||
|
||||
await db.execute(
|
||||
"""
|
||||
|
|
@ -266,10 +284,16 @@ async def create_payment(subscription_id, db=None):
|
|||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
|
||||
""", new_id, int(sub['payment_source_id']),
|
||||
subscription_id, int(sub['user_id']),
|
||||
amount, 'usd', 'FUCK NITRO',
|
||||
PaymentStatus.SUCCESS)
|
||||
""",
|
||||
new_id,
|
||||
int(sub["payment_source_id"]),
|
||||
subscription_id,
|
||||
int(sub["user_id"]),
|
||||
amount,
|
||||
"usd",
|
||||
"FUCK NITRO",
|
||||
PaymentStatus.SUCCESS,
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
|
|
@ -278,29 +302,34 @@ async def process_subscription(app, subscription_id: int):
|
|||
"""Process a single subscription."""
|
||||
sub = await get_subscription(subscription_id, app.db)
|
||||
|
||||
user_id = int(sub['user_id'])
|
||||
user_id = int(sub["user_id"])
|
||||
|
||||
if sub['status'] != SubscriptionStatus.ACTIVE:
|
||||
log.debug('ignoring sub {}, not active',
|
||||
subscription_id)
|
||||
if sub["status"] != SubscriptionStatus.ACTIVE:
|
||||
log.debug("ignoring sub {}, not active", subscription_id)
|
||||
return
|
||||
|
||||
# if the subscription is still active
|
||||
# (should get cancelled status on failed
|
||||
# payments), then we should update premium status
|
||||
first_payment_id = await app.db.fetchval("""
|
||||
first_payment_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT MIN(id)
|
||||
FROM user_payments
|
||||
WHERE subscription_id = $1
|
||||
""", subscription_id)
|
||||
""",
|
||||
subscription_id,
|
||||
)
|
||||
|
||||
first_payment_ts = snowflake_datetime(first_payment_id)
|
||||
|
||||
premium_since = await app.db.fetchval("""
|
||||
premium_since = await app.db.fetchval(
|
||||
"""
|
||||
SELECT premium_since
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
premium_since = premium_since or datetime.datetime.fromtimestamp(0)
|
||||
|
||||
|
|
@ -312,27 +341,34 @@ async def process_subscription(app, subscription_id: int):
|
|||
if delta.total_seconds() < 24 * HOURS:
|
||||
return
|
||||
|
||||
old_flags = await app.db.fetchval("""
|
||||
old_flags = await app.db.fetchval(
|
||||
"""
|
||||
SELECT flags
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
new_flags = old_flags | UserFlags.premium_early
|
||||
log.debug('updating flags {}, {} => {}',
|
||||
user_id, old_flags, new_flags)
|
||||
log.debug("updating flags {}, {} => {}", user_id, old_flags, new_flags)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET premium_since = $1, flags = $2
|
||||
WHERE id = $3
|
||||
""", first_payment_ts, new_flags, user_id)
|
||||
""",
|
||||
first_payment_ts,
|
||||
new_flags,
|
||||
user_id,
|
||||
)
|
||||
|
||||
# dispatch updated user to all possible clients
|
||||
await mass_user_update(user_id, app)
|
||||
|
||||
|
||||
@bp.route('/@me/billing/payment-sources', methods=['GET'])
|
||||
@bp.route("/@me/billing/payment-sources", methods=["GET"])
|
||||
async def _get_billing_sources():
|
||||
user_id = await token_check()
|
||||
source_ids = await get_payment_source_ids(user_id)
|
||||
|
|
@ -346,7 +382,7 @@ async def _get_billing_sources():
|
|||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/@me/billing/subscriptions', methods=['GET'])
|
||||
@bp.route("/@me/billing/subscriptions", methods=["GET"])
|
||||
async def _get_billing_subscriptions():
|
||||
user_id = await token_check()
|
||||
sub_ids = await get_subscription_ids(user_id)
|
||||
|
|
@ -358,7 +394,7 @@ async def _get_billing_subscriptions():
|
|||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/@me/billing/payments', methods=['GET'])
|
||||
@bp.route("/@me/billing/payments", methods=["GET"])
|
||||
async def _get_billing_payments():
|
||||
user_id = await token_check()
|
||||
payment_ids = await get_payment_ids(user_id)
|
||||
|
|
@ -370,7 +406,7 @@ async def _get_billing_payments():
|
|||
return jsonify(res)
|
||||
|
||||
|
||||
@bp.route('/@me/billing/payment-sources', methods=['POST'])
|
||||
@bp.route("/@me/billing/payment-sources", methods=["POST"])
|
||||
async def _create_payment_source():
|
||||
user_id = await token_check()
|
||||
j = validate(await request.get_json(), PAYMENT_SOURCE)
|
||||
|
|
@ -383,34 +419,40 @@ async def _create_payment_source():
|
|||
default_, expires_month, expires_year, brand, cc_full,
|
||||
billing_address)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
""", new_source_id, user_id, PaymentSource.CREDIT.value,
|
||||
True, 12, 6969, 'Visa', '4242424242424242',
|
||||
json.dumps(j['billing_address']))
|
||||
|
||||
return jsonify(
|
||||
await get_payment_source(user_id, new_source_id)
|
||||
""",
|
||||
new_source_id,
|
||||
user_id,
|
||||
PaymentSource.CREDIT.value,
|
||||
True,
|
||||
12,
|
||||
6969,
|
||||
"Visa",
|
||||
"4242424242424242",
|
||||
json.dumps(j["billing_address"]),
|
||||
)
|
||||
|
||||
return jsonify(await get_payment_source(user_id, new_source_id))
|
||||
|
||||
@bp.route('/@me/billing/subscriptions', methods=['POST'])
|
||||
|
||||
@bp.route("/@me/billing/subscriptions", methods=["POST"])
|
||||
async def _create_subscription():
|
||||
user_id = await token_check()
|
||||
j = validate(await request.get_json(), CREATE_SUBSCRIPTION)
|
||||
|
||||
source = await get_payment_source(user_id, j['payment_source_id'])
|
||||
source = await get_payment_source(user_id, j["payment_source_id"])
|
||||
if not source:
|
||||
raise BadRequest('invalid source id')
|
||||
raise BadRequest("invalid source id")
|
||||
|
||||
plan_id = j['payment_gateway_plan_id']
|
||||
plan_id = j["payment_gateway_plan_id"]
|
||||
|
||||
# tier 1 is lightro / classic
|
||||
# tier 2 is nitro
|
||||
|
||||
period_end = {
|
||||
'premium_month_tier_1': '1 month',
|
||||
'premium_month_tier_2': '1 month',
|
||||
'premium_year_tier_1': '1 year',
|
||||
'premium_year_tier_2': '1 year',
|
||||
"premium_month_tier_1": "1 month",
|
||||
"premium_month_tier_2": "1 month",
|
||||
"premium_year_tier_1": "1 year",
|
||||
"premium_year_tier_2": "1 year",
|
||||
}[plan_id]
|
||||
|
||||
new_id = get_snowflake()
|
||||
|
|
@ -422,9 +464,15 @@ async def _create_subscription():
|
|||
status, period_end)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7,
|
||||
now()::timestamp + interval '{period_end}')
|
||||
""", new_id, j['payment_source_id'], user_id,
|
||||
SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
|
||||
plan_id, 1)
|
||||
""",
|
||||
new_id,
|
||||
j["payment_source_id"],
|
||||
user_id,
|
||||
SubscriptionType.PURCHASE,
|
||||
PaymentGateway.STRIPE,
|
||||
plan_id,
|
||||
1,
|
||||
)
|
||||
|
||||
await create_payment(new_id, app.db)
|
||||
|
||||
|
|
@ -432,21 +480,17 @@ async def _create_subscription():
|
|||
# and dispatch respective user updates to other people.
|
||||
await process_subscription(app, new_id)
|
||||
|
||||
return jsonify(
|
||||
await get_subscription(new_id)
|
||||
)
|
||||
return jsonify(await get_subscription(new_id))
|
||||
|
||||
|
||||
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
|
||||
methods=['DELETE'])
|
||||
@bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["DELETE"])
|
||||
async def _delete_subscription(subscription_id):
|
||||
# user_id = await token_check()
|
||||
# return '', 204
|
||||
pass
|
||||
|
||||
|
||||
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
|
||||
methods=['PATCH'])
|
||||
@bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["PATCH"])
|
||||
async def _patch_subscription(subscription_id):
|
||||
"""change a subscription's payment source"""
|
||||
# user_id = await token_check()
|
||||
|
|
|
|||
|
|
@ -25,8 +25,11 @@ from asyncio import sleep, CancelledError
|
|||
from logbook import Logger
|
||||
|
||||
from litecord.blueprints.user.billing import (
|
||||
get_subscription, get_payment_ids, get_payment, create_payment,
|
||||
process_subscription
|
||||
get_subscription,
|
||||
get_payment_ids,
|
||||
get_payment,
|
||||
create_payment,
|
||||
process_subscription,
|
||||
)
|
||||
|
||||
from litecord.snowflake import snowflake_datetime
|
||||
|
|
@ -37,15 +40,15 @@ log = Logger(__name__)
|
|||
# how many days until a payment needs
|
||||
# to be issued
|
||||
THRESHOLDS = {
|
||||
'premium_month_tier_1': 30,
|
||||
'premium_month_tier_2': 30,
|
||||
'premium_year_tier_1': 365,
|
||||
'premium_year_tier_2': 365,
|
||||
"premium_month_tier_1": 30,
|
||||
"premium_month_tier_2": 30,
|
||||
"premium_year_tier_1": 365,
|
||||
"premium_year_tier_2": 365,
|
||||
}
|
||||
|
||||
|
||||
async def _resched(app):
|
||||
log.debug('waiting 30 minutes for job.')
|
||||
log.debug("waiting 30 minutes for job.")
|
||||
await sleep(30 * MINUTES)
|
||||
app.sched.spawn(payment_job(app))
|
||||
|
||||
|
|
@ -54,10 +57,10 @@ async def _process_user_payments(app, user_id: int):
|
|||
payments = await get_payment_ids(user_id, app.db)
|
||||
|
||||
if not payments:
|
||||
log.debug('no payments for uid {}, skipping', user_id)
|
||||
log.debug("no payments for uid {}, skipping", user_id)
|
||||
return
|
||||
|
||||
log.debug('{} payments for uid {}', len(payments), user_id)
|
||||
log.debug("{} payments for uid {}", len(payments), user_id)
|
||||
|
||||
latest_payment = max(payments)
|
||||
|
||||
|
|
@ -66,33 +69,29 @@ async def _process_user_payments(app, user_id: int):
|
|||
# calculate the difference between this payment
|
||||
# and now.
|
||||
now = datetime.datetime.now()
|
||||
payment_tstamp = snowflake_datetime(int(payment_data['id']))
|
||||
payment_tstamp = snowflake_datetime(int(payment_data["id"]))
|
||||
|
||||
delta = now - payment_tstamp
|
||||
|
||||
sub_id = int(payment_data['subscription']['id'])
|
||||
subscription = await get_subscription(
|
||||
sub_id, app.db)
|
||||
sub_id = int(payment_data["subscription"]["id"])
|
||||
subscription = await get_subscription(sub_id, app.db)
|
||||
|
||||
# if the max payment is X days old, we create another.
|
||||
# X is 30 for monthly subscriptions of nitro,
|
||||
# X is 365 for yearly subscriptions of nitro
|
||||
threshold = THRESHOLDS[subscription['payment_gateway_plan_id']]
|
||||
threshold = THRESHOLDS[subscription["payment_gateway_plan_id"]]
|
||||
|
||||
log.debug('delta {} delta days {} threshold {}',
|
||||
delta, delta.days, threshold)
|
||||
log.debug("delta {} delta days {} threshold {}", delta, delta.days, threshold)
|
||||
|
||||
if delta.days > threshold:
|
||||
log.info('creating payment for sid={}',
|
||||
sub_id)
|
||||
log.info("creating payment for sid={}", sub_id)
|
||||
|
||||
# create_payment does not call any Stripe
|
||||
# or BrainTree APIs at all, since we'll just
|
||||
# give it as free.
|
||||
await create_payment(sub_id, app.db)
|
||||
else:
|
||||
log.debug('sid={}, missing {} days',
|
||||
sub_id, threshold - delta.days)
|
||||
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
|
||||
|
||||
|
||||
async def payment_job(app):
|
||||
|
|
@ -101,35 +100,39 @@ async def payment_job(app):
|
|||
This function will check through users' payments
|
||||
and add a new one once a month / year.
|
||||
"""
|
||||
log.debug('payment job start!')
|
||||
log.debug("payment job start!")
|
||||
|
||||
user_ids = await app.db.fetch("""
|
||||
user_ids = await app.db.fetch(
|
||||
"""
|
||||
SELECT DISTINCT user_id
|
||||
FROM user_payments
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
log.debug('working {} users', len(user_ids))
|
||||
log.debug("working {} users", len(user_ids))
|
||||
|
||||
# go through each user's payments
|
||||
for row in user_ids:
|
||||
user_id = row['user_id']
|
||||
user_id = row["user_id"]
|
||||
try:
|
||||
await _process_user_payments(app, user_id)
|
||||
except Exception:
|
||||
log.exception('error while processing user payments')
|
||||
log.exception("error while processing user payments")
|
||||
|
||||
subscribers = await app.db.fetch("""
|
||||
subscribers = await app.db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM user_subscriptions
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
for row in subscribers:
|
||||
try:
|
||||
await process_subscription(app, row['id'])
|
||||
await process_subscription(app, row["id"])
|
||||
except Exception:
|
||||
log.exception('error while processing subscription')
|
||||
log.debug('rescheduling..')
|
||||
log.exception("error while processing subscription")
|
||||
log.debug("rescheduling..")
|
||||
try:
|
||||
await _resched(app)
|
||||
except CancelledError:
|
||||
log.info('cancelled while waiting for resched')
|
||||
log.info("cancelled while waiting for resched")
|
||||
|
|
|
|||
|
|
@ -22,24 +22,26 @@ fake routes for discord store
|
|||
"""
|
||||
from quart import Blueprint, jsonify
|
||||
|
||||
bp = Blueprint('fake_store', __name__)
|
||||
bp = Blueprint("fake_store", __name__)
|
||||
|
||||
|
||||
@bp.route('/promotions')
|
||||
@bp.route("/promotions")
|
||||
async def _get_promotions():
|
||||
return jsonify([])
|
||||
|
||||
|
||||
@bp.route('/users/@me/library')
|
||||
@bp.route("/users/@me/library")
|
||||
async def _get_library():
|
||||
return jsonify([])
|
||||
|
||||
|
||||
@bp.route('/users/@me/feed/settings')
|
||||
@bp.route("/users/@me/feed/settings")
|
||||
async def _get_feed_settings():
|
||||
return jsonify({
|
||||
'subscribed_games': [],
|
||||
'subscribed_users': [],
|
||||
'unsubscribed_users': [],
|
||||
'unsubscribed_games': [],
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
"subscribed_games": [],
|
||||
"subscribed_users": [],
|
||||
"unsubscribed_users": [],
|
||||
"unsubscribed_games": [],
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,10 +23,10 @@ from litecord.auth import token_check
|
|||
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
||||
from litecord.blueprints.checks import guild_check
|
||||
|
||||
bp = Blueprint('users_settings', __name__)
|
||||
bp = Blueprint("users_settings", __name__)
|
||||
|
||||
|
||||
@bp.route('/@me/settings', methods=['GET'])
|
||||
@bp.route("/@me/settings", methods=["GET"])
|
||||
async def get_user_settings():
|
||||
"""Get the current user's settings."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -34,7 +34,7 @@ async def get_user_settings():
|
|||
return jsonify(settings)
|
||||
|
||||
|
||||
@bp.route('/@me/settings', methods=['PATCH'])
|
||||
@bp.route("/@me/settings", methods=["PATCH"])
|
||||
async def patch_current_settings():
|
||||
"""Patch the users' current settings.
|
||||
|
||||
|
|
@ -47,19 +47,22 @@ async def patch_current_settings():
|
|||
for key in j:
|
||||
val = j[key]
|
||||
|
||||
await app.storage.execute_with_json(f"""
|
||||
await app.storage.execute_with_json(
|
||||
f"""
|
||||
UPDATE user_settings
|
||||
SET {key}=$1
|
||||
WHERE id = $2
|
||||
""", val, user_id)
|
||||
""",
|
||||
val,
|
||||
user_id,
|
||||
)
|
||||
|
||||
settings = await app.user_storage.get_user_settings(user_id)
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, 'USER_SETTINGS_UPDATE', settings)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
|
||||
return jsonify(settings)
|
||||
|
||||
|
||||
@bp.route('/@me/guilds/<int:guild_id>/settings', methods=['PATCH'])
|
||||
@bp.route("/@me/guilds/<int:guild_id>/settings", methods=["PATCH"])
|
||||
async def patch_guild_settings(guild_id: int):
|
||||
"""Update the users' guild settings for a given guild.
|
||||
|
||||
|
|
@ -74,16 +77,21 @@ async def patch_guild_settings(guild_id: int):
|
|||
# will make sure they exist in the table.
|
||||
await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||
|
||||
for field in (k for k in j.keys() if k != 'channel_overrides'):
|
||||
await app.db.execute(f"""
|
||||
for field in (k for k in j.keys() if k != "channel_overrides"):
|
||||
await app.db.execute(
|
||||
f"""
|
||||
UPDATE guild_settings
|
||||
SET {field} = $1
|
||||
WHERE user_id = $2 AND guild_id = $3
|
||||
""", j[field], user_id, guild_id)
|
||||
""",
|
||||
j[field],
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||
|
||||
for chandata in j.get('channel_overrides', {}).items():
|
||||
for chandata in j.get("channel_overrides", {}).items():
|
||||
chan_id, chan_overrides = chandata
|
||||
chan_id = int(chan_id)
|
||||
|
||||
|
|
@ -92,7 +100,8 @@ async def patch_guild_settings(guild_id: int):
|
|||
continue
|
||||
|
||||
for field in chan_overrides:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
INSERT INTO guild_settings_channel_overrides
|
||||
(user_id, guild_id, channel_id, {field})
|
||||
VALUES
|
||||
|
|
@ -105,18 +114,21 @@ async def patch_guild_settings(guild_id: int):
|
|||
WHERE guild_settings_channel_overrides.user_id = $1
|
||||
AND guild_settings_channel_overrides.guild_id = $2
|
||||
AND guild_settings_channel_overrides.channel_id = $3
|
||||
""", user_id, guild_id, chan_id, chan_overrides[field])
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
chan_id,
|
||||
chan_overrides[field],
|
||||
)
|
||||
|
||||
settings = await app.user_storage.get_guild_settings_one(
|
||||
user_id, guild_id)
|
||||
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, 'USER_GUILD_SETTINGS_UPDATE', settings)
|
||||
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
|
||||
|
||||
return jsonify(settings)
|
||||
|
||||
|
||||
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
||||
@bp.route("/@me/notes/<int:target_id>", methods=["PUT"])
|
||||
async def put_note(target_id: int):
|
||||
"""Put a note to a user.
|
||||
|
||||
|
|
@ -126,10 +138,11 @@ async def put_note(target_id: int):
|
|||
user_id = await token_check()
|
||||
|
||||
j = await request.get_json()
|
||||
note = str(j['note'])
|
||||
note = str(j["note"])
|
||||
|
||||
# UPSERTs are beautiful
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO notes (user_id, target_id, note)
|
||||
VALUES ($1, $2, $3)
|
||||
|
||||
|
|
@ -138,12 +151,14 @@ async def put_note(target_id: int):
|
|||
note = $3
|
||||
WHERE notes.user_id = $1
|
||||
AND notes.target_id = $2
|
||||
""", user_id, target_id, note)
|
||||
""",
|
||||
user_id,
|
||||
target_id,
|
||||
note,
|
||||
)
|
||||
|
||||
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
|
||||
'id': str(target_id),
|
||||
'note': note,
|
||||
})
|
||||
|
||||
return '', 204
|
||||
await app.dispatcher.dispatch_user(
|
||||
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
|
||||
)
|
||||
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -27,9 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized
|
|||
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
||||
|
||||
from .guilds import guild_check
|
||||
from litecord.auth import (
|
||||
token_check, hash_data, check_username_usage, roll_discrim
|
||||
)
|
||||
from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
|
||||
from litecord.blueprints.guild.mod import remove_member
|
||||
|
||||
from litecord.enums import PremiumType
|
||||
|
|
@ -39,7 +37,7 @@ from litecord.permissions import base_permissions
|
|||
from litecord.blueprints.auth import check_password
|
||||
from litecord.utils import to_update
|
||||
|
||||
bp = Blueprint('user', __name__)
|
||||
bp = Blueprint("user", __name__)
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
|
|
@ -58,8 +56,7 @@ async def mass_user_update(user_id, app_=None):
|
|||
private_user = await app_.storage.get_user(user_id, secure=True)
|
||||
|
||||
session_ids.extend(
|
||||
await app_.dispatcher.dispatch_user(
|
||||
user_id, 'USER_UPDATE', private_user)
|
||||
await app_.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
||||
)
|
||||
|
||||
guild_ids = await app_.user_storage.get_user_guilds(user_id)
|
||||
|
|
@ -67,26 +64,22 @@ async def mass_user_update(user_id, app_=None):
|
|||
|
||||
session_ids.extend(
|
||||
await app_.dispatcher.dispatch_many_filter_list(
|
||||
'guild', guild_ids, session_ids,
|
||||
'USER_UPDATE', public_user
|
||||
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
||||
)
|
||||
)
|
||||
|
||||
session_ids.extend(
|
||||
await app_.dispatcher.dispatch_many_filter_list(
|
||||
'friend', friend_ids, session_ids,
|
||||
'USER_UPDATE', public_user
|
||||
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
||||
)
|
||||
)
|
||||
|
||||
await app_.dispatcher.dispatch_many(
|
||||
'lazy_guild', guild_ids, 'update_user', user_id
|
||||
)
|
||||
await app_.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
||||
|
||||
return public_user, private_user
|
||||
|
||||
|
||||
@bp.route('/@me', methods=['GET'])
|
||||
@bp.route("/@me", methods=["GET"])
|
||||
async def get_me():
|
||||
"""Get the current user's information."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -94,18 +87,21 @@ async def get_me():
|
|||
return jsonify(user)
|
||||
|
||||
|
||||
@bp.route('/<int:target_id>', methods=['GET'])
|
||||
@bp.route("/<int:target_id>", methods=["GET"])
|
||||
async def get_other(target_id):
|
||||
"""Get any user, given the user ID."""
|
||||
user_id = await token_check()
|
||||
|
||||
bot = await app.db.fetchval("""
|
||||
bot = await app.db.fetchval(
|
||||
"""
|
||||
SELECT bot FROM users
|
||||
WHERE users.id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if not bot:
|
||||
raise Forbidden('Only bots can use this endpoint')
|
||||
raise Forbidden("Only bots can use this endpoint")
|
||||
|
||||
other = await app.storage.get_user(target_id)
|
||||
return jsonify(other)
|
||||
|
|
@ -116,66 +112,80 @@ async def _try_username_patch(user_id, new_username: str) -> str:
|
|||
discrim = None
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET username = $1
|
||||
WHERE users.id = $2
|
||||
""", new_username, user_id)
|
||||
""",
|
||||
new_username,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return await app.db.fetchval("""
|
||||
return await app.db.fetchval(
|
||||
"""
|
||||
SELECT discriminator
|
||||
FROM users
|
||||
WHERE users.id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
discrim = await roll_discrim(new_username)
|
||||
|
||||
if not discrim:
|
||||
raise BadRequest('Unable to change username', {
|
||||
'username': 'Too many people are with this username.'
|
||||
})
|
||||
raise BadRequest(
|
||||
"Unable to change username",
|
||||
{"username": "Too many people are with this username."},
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET username = $1, discriminator = $2
|
||||
WHERE users.id = $3
|
||||
""", new_username, discrim, user_id)
|
||||
""",
|
||||
new_username,
|
||||
discrim,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return discrim
|
||||
|
||||
|
||||
async def _try_discrim_patch(user_id, new_discrim: str):
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET discriminator = $1
|
||||
WHERE id = $2
|
||||
""", new_discrim, user_id)
|
||||
""",
|
||||
new_discrim,
|
||||
user_id,
|
||||
)
|
||||
except UniqueViolationError:
|
||||
raise BadRequest('Invalid discriminator', {
|
||||
'discriminator': 'Someone already used this discriminator.'
|
||||
})
|
||||
raise BadRequest(
|
||||
"Invalid discriminator",
|
||||
{"discriminator": "Someone already used this discriminator."},
|
||||
)
|
||||
|
||||
|
||||
async def _check_pass(j, user):
|
||||
# Do not do password checks on unclaimed accounts
|
||||
if user['email'] is None:
|
||||
if user["email"] is None:
|
||||
return
|
||||
|
||||
if not j['password']:
|
||||
raise BadRequest('password required', {
|
||||
'password': 'password required'
|
||||
})
|
||||
if not j["password"]:
|
||||
raise BadRequest("password required", {"password": "password required"})
|
||||
|
||||
phash = user['password_hash']
|
||||
phash = user["password_hash"]
|
||||
|
||||
if not await check_password(phash, j['password']):
|
||||
raise BadRequest('password incorrect', {
|
||||
'password': 'password does not match.'
|
||||
})
|
||||
if not await check_password(phash, j["password"]):
|
||||
raise BadRequest("password incorrect", {"password": "password does not match."})
|
||||
|
||||
|
||||
@bp.route('/@me', methods=['PATCH'])
|
||||
@bp.route("/@me", methods=["PATCH"])
|
||||
async def patch_me():
|
||||
"""Patch the current user's information."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -183,36 +193,43 @@ async def patch_me():
|
|||
j = validate(await request.get_json(), USER_UPDATE)
|
||||
user = await app.storage.get_user(user_id, True)
|
||||
|
||||
user['password_hash'] = await app.db.fetchval("""
|
||||
user["password_hash"] = await app.db.fetchval(
|
||||
"""
|
||||
SELECT password_hash
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if to_update(j, user, 'username'):
|
||||
if to_update(j, user, "username"):
|
||||
# this will take care of regenning a new discriminator
|
||||
discrim = await _try_username_patch(user_id, j['username'])
|
||||
user['username'] = j['username']
|
||||
user['discriminator'] = discrim
|
||||
discrim = await _try_username_patch(user_id, j["username"])
|
||||
user["username"] = j["username"]
|
||||
user["discriminator"] = discrim
|
||||
|
||||
if to_update(j, user, 'discriminator'):
|
||||
if to_update(j, user, "discriminator"):
|
||||
# the API treats discriminators as integers,
|
||||
# but I work with strings on the database.
|
||||
new_discrim = str(j['discriminator'])
|
||||
new_discrim = str(j["discriminator"])
|
||||
|
||||
await _try_discrim_patch(user_id, new_discrim)
|
||||
user['discriminator'] = new_discrim
|
||||
user["discriminator"] = new_discrim
|
||||
|
||||
if to_update(j, user, 'email'):
|
||||
if to_update(j, user, "email"):
|
||||
await _check_pass(j, user)
|
||||
|
||||
# TODO: reverify the new email?
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET email = $1
|
||||
WHERE id = $2
|
||||
""", j['email'], user_id)
|
||||
user['email'] = j['email']
|
||||
""",
|
||||
j["email"],
|
||||
user_id,
|
||||
)
|
||||
user["email"] = j["email"]
|
||||
|
||||
# only update if values are different
|
||||
# from what the user gave.
|
||||
|
|
@ -224,44 +241,49 @@ async def patch_me():
|
|||
|
||||
# IconManager.update will take care of validating
|
||||
# the value once put()-ing
|
||||
if to_update(j, user, 'avatar'):
|
||||
mime, _ = parse_data_uri(j['avatar'])
|
||||
if to_update(j, user, "avatar"):
|
||||
mime, _ = parse_data_uri(j["avatar"])
|
||||
|
||||
if mime == 'image/gif' and user['premium_type'] == PremiumType.NONE:
|
||||
raise BadRequest('no gif without nitro')
|
||||
if mime == "image/gif" and user["premium_type"] == PremiumType.NONE:
|
||||
raise BadRequest("no gif without nitro")
|
||||
|
||||
new_icon = await app.icons.update(
|
||||
'user', user_id, j['avatar'], size=(128, 128))
|
||||
new_icon = await app.icons.update("user", user_id, j["avatar"], size=(128, 128))
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET avatar = $1
|
||||
WHERE id = $2
|
||||
""", new_icon.icon_hash, user_id)
|
||||
""",
|
||||
new_icon.icon_hash,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if user['email'] is None and not 'new_password' in j:
|
||||
raise BadRequest('missing password', {
|
||||
'password': 'Please set a password.'
|
||||
})
|
||||
if user["email"] is None and not "new_password" in j:
|
||||
raise BadRequest("missing password", {"password": "Please set a password."})
|
||||
|
||||
if 'new_password' in j and j['new_password']:
|
||||
if "new_password" in j and j["new_password"]:
|
||||
await _check_pass(j, user)
|
||||
|
||||
new_hash = await hash_data(j['new_password'])
|
||||
new_hash = await hash_data(j["new_password"])
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET password_hash = $1
|
||||
WHERE id = $2
|
||||
""", new_hash, user_id)
|
||||
""",
|
||||
new_hash,
|
||||
user_id,
|
||||
)
|
||||
|
||||
user.pop('password_hash')
|
||||
user.pop("password_hash")
|
||||
|
||||
_, private_user = await mass_user_update(user_id, app)
|
||||
return jsonify(private_user)
|
||||
|
||||
|
||||
@bp.route('/@me/guilds', methods=['GET'])
|
||||
@bp.route("/@me/guilds", methods=["GET"])
|
||||
async def get_me_guilds():
|
||||
"""Get partial user guilds."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -270,27 +292,30 @@ async def get_me_guilds():
|
|||
partials = []
|
||||
|
||||
for guild_id in guild_ids:
|
||||
partial = await app.db.fetchrow("""
|
||||
partial = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT id::text, name, icon, owner_id
|
||||
FROM guilds
|
||||
WHERE guilds.id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
partial = dict(partial)
|
||||
|
||||
user_perms = await base_permissions(user_id, guild_id)
|
||||
partial['permissions'] = user_perms.binary
|
||||
partial["permissions"] = user_perms.binary
|
||||
|
||||
partial['owner'] = partial['owner_id'] == user_id
|
||||
partial["owner"] = partial["owner_id"] == user_id
|
||||
|
||||
partial.pop('owner_id')
|
||||
partial.pop("owner_id")
|
||||
|
||||
partials.append(partial)
|
||||
|
||||
return jsonify(partials)
|
||||
|
||||
|
||||
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
|
||||
@bp.route("/@me/guilds/<int:guild_id>", methods=["DELETE"])
|
||||
async def leave_guild(guild_id: int):
|
||||
"""Leave a guild."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -298,7 +323,7 @@ async def leave_guild(guild_id: int):
|
|||
|
||||
await remove_member(guild_id, user_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
# @bp.route('/@me/connections', methods=['GET'])
|
||||
|
|
@ -306,7 +331,7 @@ async def get_connections():
|
|||
pass
|
||||
|
||||
|
||||
@bp.route('/@me/consent', methods=['GET', 'POST'])
|
||||
@bp.route("/@me/consent", methods=["GET", "POST"])
|
||||
async def get_consent():
|
||||
"""Always disable data collection.
|
||||
|
||||
|
|
@ -314,57 +339,58 @@ async def get_consent():
|
|||
by the client and ignores them, as they
|
||||
will always be false.
|
||||
"""
|
||||
return jsonify({
|
||||
'usage_statistics': {
|
||||
'consented': False,
|
||||
},
|
||||
'personalization': {
|
||||
'consented': False,
|
||||
return jsonify(
|
||||
{
|
||||
"usage_statistics": {"consented": False},
|
||||
"personalization": {"consented": False},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/@me/harvest', methods=['GET'])
|
||||
@bp.route("/@me/harvest", methods=["GET"])
|
||||
async def get_harvest():
|
||||
"""Dummy route"""
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/@me/activities/statistics/applications', methods=['GET'])
|
||||
@bp.route("/@me/activities/statistics/applications", methods=["GET"])
|
||||
async def get_stats_applications():
|
||||
"""Dummy route for info on gameplay time and such"""
|
||||
return jsonify([])
|
||||
|
||||
|
||||
@bp.route('/@me/library', methods=['GET'])
|
||||
@bp.route("/@me/library", methods=["GET"])
|
||||
async def get_library():
|
||||
"""Probably related to Discord Store?"""
|
||||
return jsonify([])
|
||||
|
||||
|
||||
@bp.route('/<int:peer_id>/profile', methods=['GET'])
|
||||
@bp.route("/<int:peer_id>/profile", methods=["GET"])
|
||||
async def get_profile(peer_id: int):
|
||||
"""Get a user's profile."""
|
||||
user_id = await token_check()
|
||||
peer = await app.storage.get_user(peer_id)
|
||||
|
||||
if not peer:
|
||||
return '', 404
|
||||
return "", 404
|
||||
|
||||
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
||||
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
||||
|
||||
# don't return a proper card if no guilds are being shared.
|
||||
if not mutuals and not friends:
|
||||
return '', 404
|
||||
return "", 404
|
||||
|
||||
# actual premium status is determined by that
|
||||
# column being NULL or not
|
||||
peer_premium = await app.db.fetchval("""
|
||||
peer_premium = await app.db.fetchval(
|
||||
"""
|
||||
SELECT premium_since
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", peer_id)
|
||||
""",
|
||||
peer_id,
|
||||
)
|
||||
|
||||
mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
||||
mutual_res = []
|
||||
|
|
@ -372,45 +398,49 @@ async def get_profile(peer_id: int):
|
|||
# ascending sorting
|
||||
for guild_id in sorted(mutual_guilds):
|
||||
|
||||
nick = await app.db.fetchval("""
|
||||
nick = await app.db.fetchval(
|
||||
"""
|
||||
SELECT nickname
|
||||
FROM members
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, peer_id)
|
||||
""",
|
||||
guild_id,
|
||||
peer_id,
|
||||
)
|
||||
|
||||
mutual_res.append({
|
||||
'id': str(guild_id),
|
||||
'nick': nick,
|
||||
})
|
||||
mutual_res.append({"id": str(guild_id), "nick": nick})
|
||||
|
||||
return jsonify({
|
||||
'user': peer,
|
||||
'connected_accounts': [],
|
||||
'premium_since': peer_premium,
|
||||
'mutual_guilds': mutual_res,
|
||||
})
|
||||
return jsonify(
|
||||
{
|
||||
"user": peer,
|
||||
"connected_accounts": [],
|
||||
"premium_since": peer_premium,
|
||||
"mutual_guilds": mutual_res,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/@me/mentions', methods=['GET'])
|
||||
@bp.route("/@me/mentions", methods=["GET"])
|
||||
async def _get_mentions():
|
||||
user_id = await token_check()
|
||||
|
||||
j = validate(dict(request.args), GET_MENTIONS)
|
||||
|
||||
guild_query = 'AND messages.guild_id = $2' if 'guild_id' in j else ''
|
||||
role_query = "OR content LIKE '%<@&%'" if j['roles'] else ''
|
||||
everyone_query = "OR content LIKE '%@everyone%'" if j['everyone'] else ''
|
||||
mention_user = f'<@{user_id}>'
|
||||
guild_query = "AND messages.guild_id = $2" if "guild_id" in j else ""
|
||||
role_query = "OR content LIKE '%<@&%'" if j["roles"] else ""
|
||||
everyone_query = "OR content LIKE '%@everyone%'" if j["everyone"] else ""
|
||||
mention_user = f"<@{user_id}>"
|
||||
|
||||
args = [mention_user]
|
||||
|
||||
if guild_query:
|
||||
args.append(j['guild_id'])
|
||||
args.append(j["guild_id"])
|
||||
|
||||
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
||||
gids = ','.join(str(guild_id) for guild_id in guild_ids)
|
||||
gids = ",".join(str(guild_id) for guild_id in guild_ids)
|
||||
|
||||
rows = await app.db.fetch(f"""
|
||||
rows = await app.db.fetch(
|
||||
f"""
|
||||
SELECT messages.id
|
||||
FROM messages
|
||||
JOIN channels ON messages.channel_id = channels.id
|
||||
|
|
@ -423,20 +453,20 @@ async def _get_mentions():
|
|||
{guild_query}
|
||||
)
|
||||
LIMIT {j["limit"]}
|
||||
""", *args)
|
||||
""",
|
||||
*args,
|
||||
)
|
||||
|
||||
res = []
|
||||
for row in rows:
|
||||
message = await app.storage.get_message(row['id'])
|
||||
gid = int(message['guild_id'])
|
||||
message = await app.storage.get_message(row["id"])
|
||||
gid = int(message["guild_id"])
|
||||
|
||||
# ignore messages pre-messages.guild_id
|
||||
if gid not in guild_ids:
|
||||
continue
|
||||
|
||||
res.append(
|
||||
message
|
||||
)
|
||||
res.append(message)
|
||||
|
||||
return jsonify(res)
|
||||
|
||||
|
|
@ -449,18 +479,20 @@ def rand_hex(length: int = 8) -> str:
|
|||
async def _del_from_table(db, table: str, user_id: int):
|
||||
"""Delete a row from a table."""
|
||||
column = {
|
||||
'channel_overwrites': 'target_user',
|
||||
'user_settings': 'id',
|
||||
'group_dm_members': 'member_id'
|
||||
}.get(table, 'user_id')
|
||||
"channel_overwrites": "target_user",
|
||||
"user_settings": "id",
|
||||
"group_dm_members": "member_id",
|
||||
}.get(table, "user_id")
|
||||
|
||||
res = await db.execute(f"""
|
||||
res = await db.execute(
|
||||
f"""
|
||||
DELETE FROM {table}
|
||||
WHERE {column} = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
log.info('Deleting uid {} from {}, res: {!r}',
|
||||
user_id, table, res)
|
||||
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
|
||||
|
||||
|
||||
async def delete_user(user_id, *, app_=None):
|
||||
|
|
@ -470,13 +502,14 @@ async def delete_user(user_id, *, app_=None):
|
|||
|
||||
db = app_.db
|
||||
|
||||
new_username = f'Deleted User {rand_hex()}'
|
||||
new_username = f"Deleted User {rand_hex()}"
|
||||
|
||||
# by using a random hex in password_hash
|
||||
# we break attempts at using the default '123' password hash
|
||||
# to issue valid tokens for deleted users.
|
||||
|
||||
await db.execute("""
|
||||
await db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET
|
||||
username = $1,
|
||||
|
|
@ -490,32 +523,39 @@ async def delete_user(user_id, *, app_=None):
|
|||
password_hash = $2
|
||||
WHERE
|
||||
id = $3
|
||||
""", new_username, rand_hex(32), user_id)
|
||||
""",
|
||||
new_username,
|
||||
rand_hex(32),
|
||||
user_id,
|
||||
)
|
||||
|
||||
# remove the user from various tables
|
||||
await _del_from_table(db, 'user_settings', user_id)
|
||||
await _del_from_table(db, 'user_payment_sources', user_id)
|
||||
await _del_from_table(db, 'user_subscriptions', user_id)
|
||||
await _del_from_table(db, 'user_payments', user_id)
|
||||
await _del_from_table(db, 'user_read_state', user_id)
|
||||
await _del_from_table(db, 'guild_settings', user_id)
|
||||
await _del_from_table(db, 'guild_settings_channel_overrides', user_id)
|
||||
await _del_from_table(db, "user_settings", user_id)
|
||||
await _del_from_table(db, "user_payment_sources", user_id)
|
||||
await _del_from_table(db, "user_subscriptions", user_id)
|
||||
await _del_from_table(db, "user_payments", user_id)
|
||||
await _del_from_table(db, "user_read_state", user_id)
|
||||
await _del_from_table(db, "guild_settings", user_id)
|
||||
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
|
||||
|
||||
await db.execute("""
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM relationships
|
||||
WHERE user_id = $1 OR peer_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
# DMs are still maintained, but not the state.
|
||||
await _del_from_table(db, 'dm_channel_state', user_id)
|
||||
await _del_from_table(db, "dm_channel_state", user_id)
|
||||
|
||||
# NOTE: we don't delete the group dms the user is an owner of...
|
||||
# TODO: group dm owner reassign when the owner leaves a gdm
|
||||
await _del_from_table(db, 'group_dm_members', user_id)
|
||||
await _del_from_table(db, "group_dm_members", user_id)
|
||||
|
||||
await _del_from_table(db, 'members', user_id)
|
||||
await _del_from_table(db, 'member_roles', user_id)
|
||||
await _del_from_table(db, 'channel_overwrites', user_id)
|
||||
await _del_from_table(db, "members", user_id)
|
||||
await _del_from_table(db, "member_roles", user_id)
|
||||
await _del_from_table(db, "channel_overwrites", user_id)
|
||||
|
||||
# after updating the user, we send USER_UPDATE so that all the other
|
||||
# clients can refresh their caches on the now-deleted user
|
||||
|
|
@ -540,15 +580,12 @@ async def user_disconnect(user_id: int):
|
|||
await state.ws.ws.close(4000)
|
||||
|
||||
# force everyone to see the user as offline
|
||||
await app.presence.dispatch_pres(user_id, {
|
||||
'afk': False,
|
||||
'status': 'offline',
|
||||
'game': None,
|
||||
'since': 0,
|
||||
})
|
||||
await app.presence.dispatch_pres(
|
||||
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/@me/delete', methods=['POST'])
|
||||
@bp.route("/@me/delete", methods=["POST"])
|
||||
async def delete_account():
|
||||
"""Delete own account.
|
||||
|
||||
|
|
@ -560,29 +597,35 @@ async def delete_account():
|
|||
j = await request.get_json()
|
||||
|
||||
try:
|
||||
password = j['password']
|
||||
password = j["password"]
|
||||
except KeyError:
|
||||
raise BadRequest('password required')
|
||||
raise BadRequest("password required")
|
||||
|
||||
owned_guilds = await app.db.fetchval("""
|
||||
owned_guilds = await app.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM guilds
|
||||
WHERE owner_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if owned_guilds > 0:
|
||||
raise BadRequest('You still own guilds.')
|
||||
raise BadRequest("You still own guilds.")
|
||||
|
||||
pwd_hash = await app.db.fetchval("""
|
||||
pwd_hash = await app.db.fetchval(
|
||||
"""
|
||||
SELECT password_hash
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if not await check_password(pwd_hash, password):
|
||||
raise Unauthorized('password does not match')
|
||||
raise Unauthorized("password does not match")
|
||||
|
||||
await delete_user(user_id)
|
||||
await user_disconnect(user_id)
|
||||
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from quart import Blueprint, jsonify, current_app as app
|
|||
|
||||
from litecord.blueprints.auth import token_check
|
||||
|
||||
bp = Blueprint('voice', __name__)
|
||||
bp = Blueprint("voice", __name__)
|
||||
|
||||
|
||||
def _majority_region_count(regions: list) -> str:
|
||||
|
|
@ -39,12 +39,14 @@ def _majority_region_count(regions: list) -> str:
|
|||
|
||||
async def _choose_random_region() -> Optional[str]:
|
||||
"""Give a random voice region."""
|
||||
regions = await app.db.fetch("""
|
||||
regions = await app.db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM voice_regions
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
regions = [r['id'] for r in regions]
|
||||
regions = [r["id"] for r in regions]
|
||||
|
||||
if not regions:
|
||||
return None
|
||||
|
|
@ -64,11 +66,14 @@ async def _majority_region_any(user_id) -> Optional[str]:
|
|||
res = []
|
||||
|
||||
for guild_id in guilds:
|
||||
region = await app.db.fetchval("""
|
||||
region = await app.db.fetchval(
|
||||
"""
|
||||
SELECT region
|
||||
FROM guilds
|
||||
WHERE id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
res.append(region)
|
||||
|
||||
|
|
@ -83,20 +88,23 @@ async def _majority_region_any(user_id) -> Optional[str]:
|
|||
async def majority_region(user_id: int) -> Optional[str]:
|
||||
"""Given a user ID, give the most likely region for the user to be
|
||||
happy with."""
|
||||
regions = await app.db.fetch("""
|
||||
regions = await app.db.fetch(
|
||||
"""
|
||||
SELECT region
|
||||
FROM guilds
|
||||
WHERE owner_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if not regions:
|
||||
return await _majority_region_any(user_id)
|
||||
|
||||
regions = [r['region'] for r in regions]
|
||||
regions = [r["region"] for r in regions]
|
||||
return _majority_region_count(regions)
|
||||
|
||||
|
||||
@bp.route('/regions', methods=['GET'])
|
||||
@bp.route("/regions", methods=["GET"])
|
||||
async def voice_regions():
|
||||
"""Return voice regions."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -105,6 +113,6 @@ async def voice_regions():
|
|||
regions = await app.storage.all_voice_regions()
|
||||
|
||||
for region in regions:
|
||||
region['optimal'] = region['id'] == best_region
|
||||
region["optimal"] = region["id"] == best_region
|
||||
|
||||
return jsonify(regions)
|
||||
|
|
|
|||
|
|
@ -26,22 +26,28 @@ from quart import Blueprint, jsonify, current_app as app, request
|
|||
|
||||
from litecord.auth import token_check
|
||||
from litecord.blueprints.checks import (
|
||||
channel_check, channel_perm_check, guild_check, guild_perm_check
|
||||
channel_check,
|
||||
channel_perm_check,
|
||||
guild_check,
|
||||
guild_perm_check,
|
||||
)
|
||||
|
||||
from litecord.schemas import (
|
||||
validate, WEBHOOK_CREATE, WEBHOOK_UPDATE, WEBHOOK_MESSAGE_CREATE
|
||||
validate,
|
||||
WEBHOOK_CREATE,
|
||||
WEBHOOK_UPDATE,
|
||||
WEBHOOK_MESSAGE_CREATE,
|
||||
)
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.snowflake import get_snowflake
|
||||
from litecord.utils import async_map
|
||||
from litecord.errors import (
|
||||
WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
|
||||
)
|
||||
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
|
||||
|
||||
from litecord.blueprints.channel.messages import (
|
||||
msg_create_request, msg_create_check_content, msg_add_attachment,
|
||||
msg_guild_text_mentions
|
||||
msg_create_request,
|
||||
msg_create_check_content,
|
||||
msg_add_attachment,
|
||||
msg_guild_text_mentions,
|
||||
)
|
||||
from litecord.embed.sanitizer import fill_embed, fetch_raw_img
|
||||
from litecord.embed.messages import process_url_embed, is_media_url
|
||||
|
|
@ -50,30 +56,34 @@ from litecord.utils import pg_set_json
|
|||
from litecord.enums import MessageType
|
||||
from litecord.images import STATIC_IMAGE_MIMES
|
||||
|
||||
bp = Blueprint('webhooks', __name__)
|
||||
bp = Blueprint("webhooks", __name__)
|
||||
|
||||
|
||||
async def get_webhook(webhook_id: int, *,
|
||||
secure: bool=True) -> Optional[Dict[str, Any]]:
|
||||
async def get_webhook(
|
||||
webhook_id: int, *, secure: bool = True
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get a webhook data"""
|
||||
row = await app.db.fetchrow("""
|
||||
row = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT id::text, guild_id::text, channel_id::text, creator_id,
|
||||
name, avatar, token
|
||||
FROM webhooks
|
||||
WHERE id = $1
|
||||
""", webhook_id)
|
||||
""",
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
drow = dict(row)
|
||||
|
||||
drow['user'] = await app.storage.get_user(row['creator_id'])
|
||||
drow.pop('creator_id')
|
||||
drow["user"] = await app.storage.get_user(row["creator_id"])
|
||||
drow.pop("creator_id")
|
||||
|
||||
if not secure:
|
||||
drow.pop('user')
|
||||
drow.pop('guild_id')
|
||||
drow.pop("user")
|
||||
drow.pop("guild_id")
|
||||
|
||||
return drow
|
||||
|
||||
|
|
@ -82,7 +92,7 @@ async def _webhook_check(channel_id):
|
|||
user_id = await token_check()
|
||||
|
||||
await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT)
|
||||
await channel_perm_check(user_id, channel_id, 'manage_webhooks')
|
||||
await channel_perm_check(user_id, channel_id, "manage_webhooks")
|
||||
|
||||
return user_id
|
||||
|
||||
|
|
@ -91,17 +101,20 @@ async def _webhook_check_guild(guild_id):
|
|||
user_id = await token_check()
|
||||
|
||||
await guild_check(user_id, guild_id)
|
||||
await guild_perm_check(user_id, guild_id, 'manage_webhooks')
|
||||
await guild_perm_check(user_id, guild_id, "manage_webhooks")
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
async def _webhook_check_fw(webhook_id):
|
||||
"""Make a check from an incoming webhook id (fw = from webhook)."""
|
||||
guild_id = await app.db.fetchval("""
|
||||
guild_id = await app.db.fetchval(
|
||||
"""
|
||||
SELECT guild_id FROM webhooks
|
||||
WHERE id = $1
|
||||
""", webhook_id)
|
||||
""",
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
if guild_id is None:
|
||||
raise WebhookNotFound()
|
||||
|
|
@ -110,42 +123,48 @@ async def _webhook_check_fw(webhook_id):
|
|||
|
||||
|
||||
async def _webhook_many(where_clause, arg: int):
|
||||
webhook_ids = await app.db.fetch(f"""
|
||||
webhook_ids = await app.db.fetch(
|
||||
f"""
|
||||
SELECT id
|
||||
FROM webhooks
|
||||
{where_clause}
|
||||
""", arg)
|
||||
|
||||
webhook_ids = [r['id'] for r in webhook_ids]
|
||||
|
||||
return jsonify(
|
||||
await async_map(get_webhook, webhook_ids)
|
||||
""",
|
||||
arg,
|
||||
)
|
||||
|
||||
webhook_ids = [r["id"] for r in webhook_ids]
|
||||
|
||||
return jsonify(await async_map(get_webhook, webhook_ids))
|
||||
|
||||
|
||||
async def webhook_token_check(webhook_id: int, webhook_token: str):
|
||||
"""token_check() equivalent for webhooks."""
|
||||
row = await app.db.fetchrow("""
|
||||
row = await app.db.fetchrow(
|
||||
"""
|
||||
SELECT guild_id, channel_id
|
||||
FROM webhooks
|
||||
WHERE id = $1 AND token = $2
|
||||
""", webhook_id, webhook_token)
|
||||
""",
|
||||
webhook_id,
|
||||
webhook_token,
|
||||
)
|
||||
|
||||
if row is None:
|
||||
raise Unauthorized('webhook not found or unauthorized')
|
||||
raise Unauthorized("webhook not found or unauthorized")
|
||||
|
||||
return row['guild_id'], row['channel_id']
|
||||
return row["guild_id"], row["channel_id"]
|
||||
|
||||
|
||||
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
||||
await app.dispatcher.dispatch('guild', guild_id, 'WEBHOOKS_UPDATE', {
|
||||
'guild_id': str(guild_id),
|
||||
'channel_id': str(channel_id)
|
||||
})
|
||||
await app.dispatcher.dispatch(
|
||||
"guild",
|
||||
guild_id,
|
||||
"WEBHOOKS_UPDATE",
|
||||
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@bp.route('/channels/<int:channel_id>/webhooks', methods=['POST'])
|
||||
@bp.route("/channels/<int:channel_id>/webhooks", methods=["POST"])
|
||||
async def create_webhook(channel_id: int):
|
||||
"""Create a webhook given a channel."""
|
||||
user_id = await _webhook_check(channel_id)
|
||||
|
|
@ -162,8 +181,7 @@ async def create_webhook(channel_id: int):
|
|||
token = secrets.token_urlsafe(40)
|
||||
|
||||
webhook_icon = await app.icons.put(
|
||||
'user', webhook_id, j.get('avatar'),
|
||||
always_icon=True, size=(128, 128)
|
||||
"user", webhook_id, j.get("avatar"), always_icon=True, size=(128, 128)
|
||||
)
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -173,36 +191,41 @@ async def create_webhook(channel_id: int):
|
|||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7)
|
||||
""",
|
||||
webhook_id, guild_id, channel_id, user_id,
|
||||
j['name'], webhook_icon.icon_hash, token
|
||||
webhook_id,
|
||||
guild_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
j["name"],
|
||||
webhook_icon.icon_hash,
|
||||
token,
|
||||
)
|
||||
|
||||
await _dispatch_webhook_update(guild_id, channel_id)
|
||||
return jsonify(await get_webhook(webhook_id))
|
||||
|
||||
|
||||
@bp.route('/channels/<int:channel_id>/webhooks', methods=['GET'])
|
||||
@bp.route("/channels/<int:channel_id>/webhooks", methods=["GET"])
|
||||
async def get_channel_webhook(channel_id: int):
|
||||
"""Get a list of webhooks in a channel"""
|
||||
await _webhook_check(channel_id)
|
||||
return await _webhook_many('WHERE channel_id = $1', channel_id)
|
||||
return await _webhook_many("WHERE channel_id = $1", channel_id)
|
||||
|
||||
|
||||
@bp.route('/guilds/<int:guild_id>/webhooks', methods=['GET'])
|
||||
@bp.route("/guilds/<int:guild_id>/webhooks", methods=["GET"])
|
||||
async def get_guild_webhook(guild_id):
|
||||
"""Get all webhooks in a guild"""
|
||||
await _webhook_check_guild(guild_id)
|
||||
return await _webhook_many('WHERE guild_id = $1', guild_id)
|
||||
return await _webhook_many("WHERE guild_id = $1", guild_id)
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>', methods=['GET'])
|
||||
@bp.route("/webhooks/<int:webhook_id>", methods=["GET"])
|
||||
async def get_single_webhook(webhook_id):
|
||||
"""Get a single webhook's information."""
|
||||
await _webhook_check_fw(webhook_id)
|
||||
return await jsonify(await get_webhook(webhook_id))
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['GET'])
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
|
||||
async def get_tokened_webhook(webhook_id, webhook_token):
|
||||
"""Get a webhook using its token."""
|
||||
await webhook_token_check(webhook_id, webhook_token)
|
||||
|
|
@ -210,46 +233,58 @@ async def get_tokened_webhook(webhook_id, webhook_token):
|
|||
|
||||
|
||||
async def _update_webhook(webhook_id: int, j: dict):
|
||||
if 'name' in j:
|
||||
await app.db.execute("""
|
||||
if "name" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE webhooks
|
||||
SET name = $1
|
||||
WHERE id = $2
|
||||
""", j['name'], webhook_id)
|
||||
""",
|
||||
j["name"],
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
if 'channel_id' in j:
|
||||
await app.db.execute("""
|
||||
if "channel_id" in j:
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE webhooks
|
||||
SET channel_id = $1
|
||||
WHERE id = $2
|
||||
""", j['channel_id'], webhook_id)
|
||||
|
||||
if 'avatar' in j:
|
||||
new_icon = await app.icons.update(
|
||||
'user', webhook_id, j['avatar'], always_icon=True, size=(128, 128)
|
||||
""",
|
||||
j["channel_id"],
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
await app.db.execute("""
|
||||
if "avatar" in j:
|
||||
new_icon = await app.icons.update(
|
||||
"user", webhook_id, j["avatar"], always_icon=True, size=(128, 128)
|
||||
)
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE webhooks
|
||||
SET icon = $1
|
||||
WHERE id = $2
|
||||
""", new_icon.icon_hash, webhook_id)
|
||||
""",
|
||||
new_icon.icon_hash,
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>', methods=['PATCH'])
|
||||
@bp.route("/webhooks/<int:webhook_id>", methods=["PATCH"])
|
||||
async def modify_webhook(webhook_id: int):
|
||||
"""Patch a webhook."""
|
||||
_user_id, guild_id = await _webhook_check_fw(webhook_id)
|
||||
j = validate(await request.get_json(), WEBHOOK_UPDATE)
|
||||
|
||||
if 'channel_id' in j:
|
||||
if "channel_id" in j:
|
||||
# pre checks
|
||||
chan = await app.storage.get_channel(j['channel_id'])
|
||||
chan = await app.storage.get_channel(j["channel_id"])
|
||||
|
||||
# short-circuiting should ensure chan isn't none
|
||||
# by the time we do chan['guild_id']
|
||||
if chan and chan['guild_id'] != str(guild_id):
|
||||
raise ChannelNotFound('cant assign webhook to channel')
|
||||
if chan and chan["guild_id"] != str(guild_id):
|
||||
raise ChannelNotFound("cant assign webhook to channel")
|
||||
|
||||
await _update_webhook(webhook_id, j)
|
||||
|
||||
|
|
@ -257,20 +292,18 @@ async def modify_webhook(webhook_id: int):
|
|||
|
||||
# we don't need to cast channel_id to int since that isn't
|
||||
# used in the dispatcher call
|
||||
await _dispatch_webhook_update(guild_id, webhook['channel_id'])
|
||||
await _dispatch_webhook_update(guild_id, webhook["channel_id"])
|
||||
return jsonify(webhook)
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['PATCH'])
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["PATCH"])
|
||||
async def modify_webhook_tokened(webhook_id, webhook_token):
|
||||
"""Modify a webhook, using its token."""
|
||||
guild_id, channel_id = await webhook_token_check(
|
||||
webhook_id, webhook_token)
|
||||
guild_id, channel_id = await webhook_token_check(webhook_id, webhook_token)
|
||||
|
||||
# forcefully pop() the channel id out of the schema
|
||||
# instead of making another, for simplicity's sake
|
||||
j = validate(await request.get_json(),
|
||||
WEBHOOK_UPDATE.pop('channel_id'))
|
||||
j = validate(await request.get_json(), WEBHOOK_UPDATE.pop("channel_id"))
|
||||
|
||||
await _update_webhook(webhook_id, j)
|
||||
await _dispatch_webhook_update(guild_id, channel_id)
|
||||
|
|
@ -281,35 +314,36 @@ async def delete_webhook(webhook_id: int):
|
|||
"""Delete a webhook."""
|
||||
webhook = await get_webhook(webhook_id)
|
||||
|
||||
res = await app.db.execute("""
|
||||
res = await app.db.execute(
|
||||
"""
|
||||
DELETE FROM webhooks
|
||||
WHERE id = $1
|
||||
""", webhook_id)
|
||||
""",
|
||||
webhook_id,
|
||||
)
|
||||
|
||||
if res.lower() == 'delete 0':
|
||||
if res.lower() == "delete 0":
|
||||
raise WebhookNotFound()
|
||||
|
||||
# only casting the guild id since that's whats used
|
||||
# on the dispatcher call.
|
||||
await _dispatch_webhook_update(
|
||||
int(webhook['guild_id']), webhook['channel_id']
|
||||
)
|
||||
await _dispatch_webhook_update(int(webhook["guild_id"]), webhook["channel_id"])
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>', methods=['DELETE'])
|
||||
@bp.route("/webhooks/<int:webhook_id>", methods=["DELETE"])
|
||||
async def del_webhook(webhook_id):
|
||||
"""Delete a webhook."""
|
||||
await _webhook_check_fw(webhook_id)
|
||||
await delete_webhook(webhook_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['DELETE'])
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["DELETE"])
|
||||
async def del_webhook_tokened(webhook_id, webhook_token):
|
||||
"""Delete a webhook, with its token."""
|
||||
await webhook_token_check(webhook_id, webhook_token)
|
||||
await delete_webhook(webhook_id)
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
|
||||
async def create_message_webhook(guild_id, channel_id, webhook_id, data):
|
||||
|
|
@ -328,23 +362,27 @@ async def create_message_webhook(guild_id, channel_id, webhook_id, data):
|
|||
message_id,
|
||||
channel_id,
|
||||
guild_id,
|
||||
data['content'],
|
||||
|
||||
data['tts'],
|
||||
data['everyone_mention'],
|
||||
|
||||
data["content"],
|
||||
data["tts"],
|
||||
data["everyone_mention"],
|
||||
MessageType.DEFAULT.value,
|
||||
data.get('embeds', [])
|
||||
data.get("embeds", []),
|
||||
)
|
||||
|
||||
info = data['info']
|
||||
info = data["info"]
|
||||
|
||||
await conn.execute("""
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO message_webhook_info
|
||||
(message_id, webhook_id, name, avatar)
|
||||
VALUES
|
||||
($1, $2, $3, $4)
|
||||
""", message_id, webhook_id, info['name'], info['avatar'])
|
||||
""",
|
||||
message_id,
|
||||
webhook_id,
|
||||
info["name"],
|
||||
info["avatar"],
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
|
|
@ -354,10 +392,15 @@ async def _webhook_avy_redir(webhook_id: int, avatar_url: EmbedURL):
|
|||
url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest()
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir)
|
||||
VALUES ($1, $2, $3)
|
||||
""", webhook_id, url_hash, avatar_url.url)
|
||||
""",
|
||||
webhook_id,
|
||||
url_hash,
|
||||
avatar_url.url,
|
||||
)
|
||||
except asyncpg.UniqueViolationError:
|
||||
pass
|
||||
|
||||
|
|
@ -371,36 +414,36 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
|
|||
Litecord will write an URL that redirects to the given avatar_url,
|
||||
using mediaproxy.
|
||||
"""
|
||||
if avatar_url.scheme not in ('http', 'https'):
|
||||
raise BadRequest('invalid avatar url scheme')
|
||||
if avatar_url.scheme not in ("http", "https"):
|
||||
raise BadRequest("invalid avatar url scheme")
|
||||
|
||||
if not is_media_url(avatar_url):
|
||||
raise BadRequest('url is not media url')
|
||||
raise BadRequest("url is not media url")
|
||||
|
||||
# we still fetch the URL to check its validity, mimetypes, etc
|
||||
# but in the end, we will store it under the webhook_avatars table,
|
||||
# not IconManager.
|
||||
resp, raw = await fetch_raw_img(avatar_url)
|
||||
#raw_b64 = base64.b64encode(raw).decode()
|
||||
# raw_b64 = base64.b64encode(raw).decode()
|
||||
|
||||
mime = resp.headers['content-type']
|
||||
mime = resp.headers["content-type"]
|
||||
|
||||
# TODO: apng checks are missing (for this and everywhere else)
|
||||
if mime not in STATIC_IMAGE_MIMES:
|
||||
raise BadRequest('invalid mime type for given url')
|
||||
raise BadRequest("invalid mime type for given url")
|
||||
|
||||
#b64_data = f'data:{mime};base64,{raw_b64}'
|
||||
# b64_data = f'data:{mime};base64,{raw_b64}'
|
||||
|
||||
# TODO: replace this by webhook_avatars
|
||||
#icon = await app.icons.put(
|
||||
# icon = await app.icons.put(
|
||||
# 'user', webhook_id, b64_data,
|
||||
# always_icon=True, size=(128, 128)
|
||||
#)
|
||||
# )
|
||||
|
||||
return await _webhook_avy_redir(webhook_id, avatar_url)
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['POST'])
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["POST"])
|
||||
async def execute_webhook(webhook_id: int, webhook_token):
|
||||
"""Execute a webhook. Sends a message to the channel the webhook
|
||||
is tied to."""
|
||||
|
|
@ -413,41 +456,39 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
|||
# NOTE: we really pop here instead of adding a kwarg
|
||||
# to msg_create_request just because of webhooks.
|
||||
# nonce isn't allowed on WEBHOOK_MESSAGE_CREATE
|
||||
payload_json.pop('nonce')
|
||||
payload_json.pop("nonce")
|
||||
|
||||
j = validate(payload_json, WEBHOOK_MESSAGE_CREATE)
|
||||
|
||||
msg_create_check_content(j, files)
|
||||
|
||||
# webhooks don't need permissions.
|
||||
mentions_everyone = '@everyone' in j['content']
|
||||
mentions_here = '@here' in j['content']
|
||||
mentions_everyone = "@everyone" in j["content"]
|
||||
mentions_here = "@here" in j["content"]
|
||||
|
||||
given_embeds = j.get('embeds', [])
|
||||
given_embeds = j.get("embeds", [])
|
||||
|
||||
webhook = await get_webhook(webhook_id)
|
||||
|
||||
# webhooks have TWO avatars. one is from settings, the other is from
|
||||
# the json's icon_url. one can be handled gracefully by IconManager,
|
||||
# but the other can't, at all.
|
||||
avatar = webhook['avatar']
|
||||
avatar = webhook["avatar"]
|
||||
|
||||
if 'avatar_url' in j and j['avatar_url'] is not None:
|
||||
avatar = await _create_avatar(webhook_id, j['avatar_url'])
|
||||
if "avatar_url" in j and j["avatar_url"] is not None:
|
||||
avatar = await _create_avatar(webhook_id, j["avatar_url"])
|
||||
|
||||
message_id = await create_message_webhook(
|
||||
guild_id, channel_id, webhook_id, {
|
||||
'content': j.get('content', ''),
|
||||
'tts': j.get('tts', False),
|
||||
|
||||
'everyone_mention': mentions_everyone or mentions_here,
|
||||
'embeds': await async_map(fill_embed, given_embeds),
|
||||
|
||||
'info': {
|
||||
'name': j.get('username', webhook['name']),
|
||||
'avatar': avatar
|
||||
}
|
||||
}
|
||||
guild_id,
|
||||
channel_id,
|
||||
webhook_id,
|
||||
{
|
||||
"content": j.get("content", ""),
|
||||
"tts": j.get("tts", False),
|
||||
"everyone_mention": mentions_everyone or mentions_here,
|
||||
"embeds": await async_map(fill_embed, given_embeds),
|
||||
"info": {"name": j.get("username", webhook["name"]), "avatar": avatar},
|
||||
},
|
||||
)
|
||||
|
||||
for pre_attachment in files:
|
||||
|
|
@ -455,33 +496,28 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
|||
|
||||
payload = await app.storage.get_message(message_id)
|
||||
|
||||
await app.dispatcher.dispatch('channel', channel_id,
|
||||
'MESSAGE_CREATE', payload)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||
|
||||
# spawn embedder in the background, even when we're on a webhook.
|
||||
app.sched.spawn(
|
||||
process_url_embed(
|
||||
app.config, app.storage, app.dispatcher, app.session,
|
||||
payload
|
||||
)
|
||||
process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
|
||||
)
|
||||
|
||||
# we can assume its a guild text channel, so just call it
|
||||
await msg_guild_text_mentions(
|
||||
payload, guild_id, mentions_everyone, mentions_here)
|
||||
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
|
||||
|
||||
# TODO: is it really 204?
|
||||
return '', 204
|
||||
return "", 204
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/slack',
|
||||
methods=['POST'])
|
||||
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>/slack", methods=["POST"])
|
||||
async def execute_slack_webhook(webhook_id, webhook_token):
|
||||
"""Execute a webhook but expecting Slack data."""
|
||||
# TODO: know slack webhooks
|
||||
await webhook_token_check(webhook_id, webhook_token)
|
||||
|
||||
|
||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/github', methods=['POST'])
|
||||
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>/github", methods=["POST"])
|
||||
async def execute_github_webhook(webhook_id, webhook_token):
|
||||
"""Execute a webhook but expecting GitHub data."""
|
||||
# TODO: know github webhooks
|
||||
|
|
|
|||
|
|
@ -21,9 +21,14 @@ from typing import List, Any, Dict
|
|||
|
||||
from logbook import Logger
|
||||
|
||||
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
||||
UserDispatcher, ChannelDispatcher, FriendDispatcher, \
|
||||
LazyGuildDispatcher
|
||||
from .pubsub import (
|
||||
GuildDispatcher,
|
||||
MemberDispatcher,
|
||||
UserDispatcher,
|
||||
ChannelDispatcher,
|
||||
FriendDispatcher,
|
||||
LazyGuildDispatcher,
|
||||
)
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -44,17 +49,18 @@ class EventDispatcher:
|
|||
when dispatching, the backend can do its own logic, given
|
||||
its subscriber ids.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.state_manager = app.state_manager
|
||||
self.app = app
|
||||
|
||||
self.backends = {
|
||||
'guild': GuildDispatcher(self),
|
||||
'member': MemberDispatcher(self),
|
||||
'channel': ChannelDispatcher(self),
|
||||
'user': UserDispatcher(self),
|
||||
'friend': FriendDispatcher(self),
|
||||
'lazy_guild': LazyGuildDispatcher(self),
|
||||
"guild": GuildDispatcher(self),
|
||||
"member": MemberDispatcher(self),
|
||||
"channel": ChannelDispatcher(self),
|
||||
"user": UserDispatcher(self),
|
||||
"friend": FriendDispatcher(self),
|
||||
"lazy_guild": LazyGuildDispatcher(self),
|
||||
}
|
||||
|
||||
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
||||
|
|
@ -71,13 +77,13 @@ class EventDispatcher:
|
|||
|
||||
return await method(key, identifier, *args)
|
||||
|
||||
async def subscribe(self, backend: str, key: Any, identifier: Any,
|
||||
flags: Dict[str, Any] = None):
|
||||
async def subscribe(
|
||||
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
|
||||
):
|
||||
"""Subscribe a single element to the given backend."""
|
||||
flags = flags or {}
|
||||
|
||||
log.debug('SUB backend={} key={} <= id={}',
|
||||
backend, key, identifier, backend)
|
||||
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
|
||||
|
||||
# this is a hacky solution for backwards compatibility between backends
|
||||
# that implement flags and backends that don't.
|
||||
|
|
@ -85,16 +91,15 @@ class EventDispatcher:
|
|||
# passing flags to backends that don't implement flags will
|
||||
# cause errors as expected.
|
||||
if flags:
|
||||
return await self.action(backend, 'sub', key, identifier, flags)
|
||||
return await self.action(backend, "sub", key, identifier, flags)
|
||||
|
||||
return await self.action(backend, 'sub', key, identifier)
|
||||
return await self.action(backend, "sub", key, identifier)
|
||||
|
||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||
"""Unsubscribe an element from the given backend."""
|
||||
log.debug('UNSUB backend={} key={} => id={}',
|
||||
backend, key, identifier, backend)
|
||||
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
|
||||
|
||||
return await self.action(backend, 'unsub', key, identifier)
|
||||
return await self.action(backend, "unsub", key, identifier)
|
||||
|
||||
async def sub(self, backend, key, identifier):
|
||||
"""Alias to subscribe()."""
|
||||
|
|
@ -104,8 +109,13 @@ class EventDispatcher:
|
|||
"""Alias to unsubscribe()."""
|
||||
return await self.unsubscribe(backend, key, identifier)
|
||||
|
||||
async def sub_many(self, backend_str: str, identifier: Any,
|
||||
keys: list, flags: Dict[str, Any] = None):
|
||||
async def sub_many(
|
||||
self,
|
||||
backend_str: str,
|
||||
identifier: Any,
|
||||
keys: list,
|
||||
flags: Dict[str, Any] = None,
|
||||
):
|
||||
"""Subscribe to multiple channels (all in a single backend)
|
||||
at a time.
|
||||
|
||||
|
|
@ -116,8 +126,7 @@ class EventDispatcher:
|
|||
for key in keys:
|
||||
await self.subscribe(backend_str, key, identifier, flags)
|
||||
|
||||
async def mass_sub(self, identifier: Any,
|
||||
backends: List[tuple]):
|
||||
async def mass_sub(self, identifier: Any, backends: List[tuple]):
|
||||
"""Mass subscribe to many backends at once."""
|
||||
for bcall in backends:
|
||||
backend_str, keys = bcall[0], bcall[1]
|
||||
|
|
@ -128,8 +137,13 @@ class EventDispatcher:
|
|||
# we have flags
|
||||
flags = bcall[2]
|
||||
|
||||
log.debug('subscribing {} to {} keys in backend {}, flags: {}',
|
||||
identifier, len(keys), backend_str, flags)
|
||||
log.debug(
|
||||
"subscribing {} to {} keys in backend {}, flags: {}",
|
||||
identifier,
|
||||
len(keys),
|
||||
backend_str,
|
||||
flags,
|
||||
)
|
||||
|
||||
await self.sub_many(backend_str, identifier, keys, flags)
|
||||
|
||||
|
|
@ -145,17 +159,14 @@ class EventDispatcher:
|
|||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch(key, *args, **kwargs)
|
||||
|
||||
async def dispatch_many(self, backend_str: str,
|
||||
keys: List[Any], *args, **kwargs):
|
||||
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
|
||||
"""Dispatch to multiple keys in a single backend."""
|
||||
log.info('MULTI DISPATCH: {!r}, {} keys',
|
||||
backend_str, len(keys))
|
||||
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
|
||||
|
||||
for key in keys:
|
||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
||||
|
||||
async def dispatch_filter(self, backend_str: str,
|
||||
key: Any, func, *args):
|
||||
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
|
||||
"""Dispatch to a backend that only accepts
|
||||
(event, data) arguments with an optional filter
|
||||
function."""
|
||||
|
|
@ -163,9 +174,9 @@ class EventDispatcher:
|
|||
key = backend.KEY_TYPE(key)
|
||||
return await backend.dispatch_filter(key, func, *args)
|
||||
|
||||
async def dispatch_many_filter_list(self, backend_str: str,
|
||||
keys: List[Any], sess_list: List[str],
|
||||
*args):
|
||||
async def dispatch_many_filter_list(
|
||||
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
|
||||
):
|
||||
"""Make a "unique" dispatch given a list of session ids.
|
||||
|
||||
This only works for backends that have a dispatch_filter
|
||||
|
|
@ -175,9 +186,8 @@ class EventDispatcher:
|
|||
for key in keys:
|
||||
sess_list.extend(
|
||||
await self.dispatch_filter(
|
||||
backend_str, key,
|
||||
lambda sess_id: sess_id not in sess_list,
|
||||
*args)
|
||||
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
|
||||
)
|
||||
)
|
||||
|
||||
return sess_list
|
||||
|
|
@ -197,12 +207,12 @@ class EventDispatcher:
|
|||
|
||||
async def dispatch_guild(self, guild_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch('guild', guild_id, event, data)
|
||||
return await self.dispatch("guild", guild_id, event, data)
|
||||
|
||||
async def dispatch_user_guild(self, user_id, guild_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch('member', (guild_id, user_id), event, data)
|
||||
return await self.dispatch("member", (guild_id, user_id), event, data)
|
||||
|
||||
async def dispatch_user(self, user_id, event, data):
|
||||
"""Backwards compatibility with old EventDispatcher."""
|
||||
return await self.dispatch('user', user_id, event, data)
|
||||
return await self.dispatch("user", user_id, event, data)
|
||||
|
|
|
|||
|
|
@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
from .sanitizer import sanitize_embed
|
||||
|
||||
__all__ = ['sanitize_embed']
|
||||
__all__ = ["sanitize_embed"]
|
||||
|
|
|
|||
|
|
@ -30,11 +30,7 @@ from litecord.embed.schemas import EmbedURL
|
|||
log = Logger(__name__)
|
||||
|
||||
|
||||
MEDIA_EXTENSIONS = (
|
||||
'png',
|
||||
'jpg', 'jpeg',
|
||||
'gif', 'webm'
|
||||
)
|
||||
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
|
||||
|
||||
|
||||
async def insert_media_meta(url, config, session):
|
||||
|
|
@ -45,18 +41,18 @@ async def insert_media_meta(url, config, session):
|
|||
if meta is None:
|
||||
return
|
||||
|
||||
if not meta['image']:
|
||||
if not meta["image"]:
|
||||
return
|
||||
|
||||
return {
|
||||
'type': 'image',
|
||||
'url': url,
|
||||
'thumbnail': {
|
||||
'width': meta['width'],
|
||||
'height': meta['height'],
|
||||
'url': url,
|
||||
'proxy_url': img_proxy_url
|
||||
}
|
||||
"type": "image",
|
||||
"url": url,
|
||||
"thumbnail": {
|
||||
"width": meta["width"],
|
||||
"height": meta["height"],
|
||||
"url": url,
|
||||
"proxy_url": img_proxy_url,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -64,29 +60,32 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
|
|||
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
|
||||
to users."""
|
||||
|
||||
message_id = int(payload['id'])
|
||||
channel_id = int(payload['channel_id'])
|
||||
message_id = int(payload["id"])
|
||||
channel_id = int(payload["channel_id"])
|
||||
|
||||
await storage.execute_with_json("""
|
||||
await storage.execute_with_json(
|
||||
"""
|
||||
UPDATE messages
|
||||
SET embeds = $1
|
||||
WHERE messages.id = $2
|
||||
""", new_embeds, message_id)
|
||||
""",
|
||||
new_embeds,
|
||||
message_id,
|
||||
)
|
||||
|
||||
update_payload = {
|
||||
'id': str(message_id),
|
||||
'channel_id': str(channel_id),
|
||||
'embeds': new_embeds,
|
||||
"id": str(message_id),
|
||||
"channel_id": str(channel_id),
|
||||
"embeds": new_embeds,
|
||||
}
|
||||
|
||||
if 'guild_id' in payload:
|
||||
update_payload['guild_id'] = payload['guild_id']
|
||||
if "guild_id" in payload:
|
||||
update_payload["guild_id"] = payload["guild_id"]
|
||||
|
||||
if 'flags' in payload:
|
||||
update_payload['flags'] = payload['flags']
|
||||
if "flags" in payload:
|
||||
update_payload["flags"] = payload["flags"]
|
||||
|
||||
await dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_UPDATE', update_payload)
|
||||
await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
|
||||
|
||||
|
||||
def is_media_url(url) -> bool:
|
||||
|
|
@ -98,7 +97,7 @@ def is_media_url(url) -> bool:
|
|||
parsed = urllib.parse.urlparse(url)
|
||||
|
||||
path = Path(parsed.path)
|
||||
extension = path.suffix.lstrip('.')
|
||||
extension = path.suffix.lstrip(".")
|
||||
|
||||
return extension in MEDIA_EXTENSIONS
|
||||
|
||||
|
|
@ -109,20 +108,20 @@ async def insert_mp_embed(parsed, config, session):
|
|||
return embed
|
||||
|
||||
|
||||
async def process_url_embed(config, storage, dispatcher,
|
||||
session, payload: dict, *, delay=0):
|
||||
async def process_url_embed(
|
||||
config, storage, dispatcher, session, payload: dict, *, delay=0
|
||||
):
|
||||
"""Process URLs in a message and generate embeds based on that."""
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
message_id = int(payload['id'])
|
||||
message_id = int(payload["id"])
|
||||
|
||||
# if we already have embeds
|
||||
# we shouldn't add our own.
|
||||
embeds = payload['embeds']
|
||||
embeds = payload["embeds"]
|
||||
|
||||
if embeds:
|
||||
log.debug('url processor: ignoring existing embeds @ mid {}',
|
||||
message_id)
|
||||
log.debug("url processor: ignoring existing embeds @ mid {}", message_id)
|
||||
return
|
||||
|
||||
# now, we have two types of embeds:
|
||||
|
|
@ -130,7 +129,7 @@ async def process_url_embed(config, storage, dispatcher,
|
|||
# - url embeds
|
||||
|
||||
# use regex to get URLs
|
||||
urls = re.findall(r'(https?://\S+)', payload['content'])
|
||||
urls = re.findall(r"(https?://\S+)", payload["content"])
|
||||
urls = urls[:5]
|
||||
|
||||
# from there, we need to parse each found url and check its path.
|
||||
|
|
@ -159,7 +158,6 @@ async def process_url_embed(config, storage, dispatcher,
|
|||
if not new_embeds:
|
||||
return
|
||||
|
||||
log.debug('made {} embeds for mid {}',
|
||||
len(new_embeds), message_id)
|
||||
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
|
||||
|
||||
await msg_update_embeds(payload, new_embeds, storage, dispatcher)
|
||||
|
|
|
|||
|
|
@ -39,9 +39,7 @@ def sanitize_embed(embed: Embed) -> Embed:
|
|||
This is non-complex sanitization as it doesn't
|
||||
need the app object.
|
||||
"""
|
||||
return {**embed, **{
|
||||
'type': 'rich'
|
||||
}}
|
||||
return {**embed, **{"type": "rich"}}
|
||||
|
||||
|
||||
def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||
|
|
@ -55,7 +53,7 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
|||
|
||||
# get the list of components given
|
||||
if isinstance(components_in, str):
|
||||
components = components_in.split('.')
|
||||
components = components_in.split(".")
|
||||
else:
|
||||
components = list(components_in)
|
||||
|
||||
|
|
@ -77,7 +75,6 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
|||
return False
|
||||
|
||||
|
||||
|
||||
def _mk_cfg_sess(config, session) -> tuple:
|
||||
"""Return a tuple of (config, session)."""
|
||||
if config is None:
|
||||
|
|
@ -91,11 +88,11 @@ def _mk_cfg_sess(config, session) -> tuple:
|
|||
|
||||
def _md_base(config) -> Optional[tuple]:
|
||||
"""Return the protocol and base url for the mediaproxy."""
|
||||
md_base_url = config['MEDIA_PROXY']
|
||||
md_base_url = config["MEDIA_PROXY"]
|
||||
if md_base_url is None:
|
||||
return None
|
||||
|
||||
proto = 'https' if config['IS_SSL'] else 'http'
|
||||
proto = "https" if config["IS_SSL"] else "http"
|
||||
|
||||
return proto, md_base_url
|
||||
|
||||
|
|
@ -111,7 +108,7 @@ def make_md_req_url(config, scope: str, url):
|
|||
return url.url if isinstance(url, EmbedURL) else url
|
||||
|
||||
proto, base_url = base
|
||||
return f'{proto}://{base_url}/{scope}/{url.to_md_path}'
|
||||
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
|
||||
|
||||
|
||||
def proxify(url, *, config=None) -> str:
|
||||
|
|
@ -122,11 +119,12 @@ def proxify(url, *, config=None) -> str:
|
|||
if isinstance(url, str):
|
||||
url = EmbedURL(url)
|
||||
|
||||
return make_md_req_url(config, 'img', url)
|
||||
return make_md_req_url(config, "img", url)
|
||||
|
||||
|
||||
async def _md_client_req(config, session, scope: str,
|
||||
url, *, ret_resp=False) -> Optional[Union[Tuple, Dict]]:
|
||||
async def _md_client_req(
|
||||
config, session, scope: str, url, *, ret_resp=False
|
||||
) -> Optional[Union[Tuple, Dict]]:
|
||||
"""Makes a request to the mediaproxy.
|
||||
|
||||
This has common code between all the main mediaproxy request functions
|
||||
|
|
@ -172,17 +170,13 @@ async def _md_client_req(config, session, scope: str,
|
|||
return await resp.json()
|
||||
|
||||
body = await resp.text()
|
||||
log.warning('failed to call {!r}, {} {!r}',
|
||||
request_url, resp.status, body)
|
||||
log.warning("failed to call {!r}, {} {!r}", request_url, resp.status, body)
|
||||
return None
|
||||
|
||||
|
||||
async def fetch_metadata(url, *, config=None,
|
||||
session=None) -> Optional[Dict]:
|
||||
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
|
||||
"""Fetch metadata for a url (image width, mime, etc)."""
|
||||
return await _md_client_req(
|
||||
config, session, 'meta', url
|
||||
)
|
||||
return await _md_client_req(config, session, "meta", url)
|
||||
|
||||
|
||||
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
||||
|
|
@ -191,9 +185,7 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
|||
Returns a tuple containing the response object and the raw bytes given by
|
||||
the website.
|
||||
"""
|
||||
tup = await _md_client_req(
|
||||
config, session, 'img', url, ret_resp=True
|
||||
)
|
||||
tup = await _md_client_req(config, session, "img", url, ret_resp=True)
|
||||
|
||||
if not tup:
|
||||
return None
|
||||
|
|
@ -207,9 +199,7 @@ async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
|
|||
|
||||
Returns a discord embed object.
|
||||
"""
|
||||
return await _md_client_req(
|
||||
config, session, 'embed', url
|
||||
)
|
||||
return await _md_client_req(config, session, "embed", url)
|
||||
|
||||
|
||||
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
||||
|
|
@ -229,22 +219,20 @@ async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
|||
|
||||
embed = sanitize_embed(embed)
|
||||
|
||||
if path_exists(embed, 'footer.icon_url'):
|
||||
embed['footer']['proxy_icon_url'] = \
|
||||
proxify(embed['footer']['icon_url'])
|
||||
if path_exists(embed, "footer.icon_url"):
|
||||
embed["footer"]["proxy_icon_url"] = proxify(embed["footer"]["icon_url"])
|
||||
|
||||
if path_exists(embed, 'author.icon_url'):
|
||||
embed['author']['proxy_icon_url'] = \
|
||||
proxify(embed['author']['icon_url'])
|
||||
if path_exists(embed, "author.icon_url"):
|
||||
embed["author"]["proxy_icon_url"] = proxify(embed["author"]["icon_url"])
|
||||
|
||||
if path_exists(embed, 'image.url'):
|
||||
image_url = embed['image']['url']
|
||||
if path_exists(embed, "image.url"):
|
||||
image_url = embed["image"]["url"]
|
||||
|
||||
meta = await fetch_metadata(image_url)
|
||||
embed['image']['proxy_url'] = proxify(image_url)
|
||||
embed["image"]["proxy_url"] = proxify(image_url)
|
||||
|
||||
if meta and meta['image']:
|
||||
embed['image']['width'] = meta['width']
|
||||
embed['image']['height'] = meta['height']
|
||||
if meta and meta["image"]:
|
||||
embed["image"]["width"] = meta["width"]
|
||||
embed["image"]["height"] = meta["height"]
|
||||
|
||||
return embed
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ class EmbedURL:
|
|||
def __init__(self, url: str):
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
|
||||
if parsed.scheme not in ('http', 'https', 'attachment'):
|
||||
raise ValueError('Invalid URL scheme')
|
||||
if parsed.scheme not in ("http", "https", "attachment"):
|
||||
raise ValueError("Invalid URL scheme")
|
||||
|
||||
self.scheme = parsed.scheme
|
||||
self.raw_url = url
|
||||
|
|
@ -54,105 +54,61 @@ class EmbedURL:
|
|||
def to_md_path(self) -> str:
|
||||
"""Convert the EmbedURL to a mediaproxy path (post img/meta)."""
|
||||
parsed = self.parsed
|
||||
return (
|
||||
f'{parsed.scheme}/{parsed.netloc}'
|
||||
f'{parsed.path}?{parsed.query}'
|
||||
)
|
||||
return f"{parsed.scheme}/{parsed.netloc}" f"{parsed.path}?{parsed.query}"
|
||||
|
||||
|
||||
EMBED_FOOTER = {
|
||||
'text': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True},
|
||||
|
||||
'icon_url': {
|
||||
'coerce': EmbedURL, 'required': False,
|
||||
},
|
||||
|
||||
"text": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
|
||||
"icon_url": {"coerce": EmbedURL, "required": False},
|
||||
# NOTE: proxy_icon_url set by us
|
||||
}
|
||||
|
||||
EMBED_IMAGE = {
|
||||
'url': {'coerce': EmbedURL, 'required': True},
|
||||
|
||||
"url": {"coerce": EmbedURL, "required": True},
|
||||
# NOTE: proxy_url, width, height set by us
|
||||
}
|
||||
|
||||
EMBED_THUMBNAIL = EMBED_IMAGE
|
||||
|
||||
EMBED_AUTHOR = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False
|
||||
},
|
||||
'url': {
|
||||
'coerce': EmbedURL, 'required': False,
|
||||
},
|
||||
'icon_url': {
|
||||
'coerce': EmbedURL, 'required': False,
|
||||
}
|
||||
|
||||
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
|
||||
"url": {"coerce": EmbedURL, "required": False},
|
||||
"icon_url": {"coerce": EmbedURL, "required": False}
|
||||
# NOTE: proxy_icon_url set by us
|
||||
}
|
||||
|
||||
EMBED_FIELD = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True
|
||||
},
|
||||
'value': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True
|
||||
},
|
||||
'inline': {
|
||||
'type': 'boolean', 'required': False, 'default': True,
|
||||
},
|
||||
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
|
||||
"value": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
|
||||
"inline": {"type": "boolean", "required": False, "default": True},
|
||||
}
|
||||
|
||||
EMBED_OBJECT = {
|
||||
'title': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False},
|
||||
"title": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
|
||||
# NOTE: type set by us
|
||||
'description': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 2048, 'required': False,
|
||||
"description": {
|
||||
"type": "string",
|
||||
"minlength": 1,
|
||||
"maxlength": 2048,
|
||||
"required": False,
|
||||
},
|
||||
'url': {
|
||||
'coerce': EmbedURL, 'required': False,
|
||||
},
|
||||
'timestamp': {
|
||||
"url": {"coerce": EmbedURL, "required": False},
|
||||
"timestamp": {
|
||||
# TODO: an ISO 8601 type
|
||||
# TODO: maybe replace the default in here with now().isoformat?
|
||||
'type': 'string', 'required': False
|
||||
"type": "string",
|
||||
"required": False,
|
||||
},
|
||||
|
||||
'color': {
|
||||
'coerce': Color, 'required': False
|
||||
},
|
||||
|
||||
'footer': {
|
||||
'type': 'dict',
|
||||
'schema': EMBED_FOOTER,
|
||||
'required': False,
|
||||
},
|
||||
'image': {
|
||||
'type': 'dict',
|
||||
'schema': EMBED_IMAGE,
|
||||
'required': False,
|
||||
},
|
||||
'thumbnail': {
|
||||
'type': 'dict',
|
||||
'schema': EMBED_THUMBNAIL,
|
||||
'required': False,
|
||||
},
|
||||
|
||||
"color": {"coerce": Color, "required": False},
|
||||
"footer": {"type": "dict", "schema": EMBED_FOOTER, "required": False},
|
||||
"image": {"type": "dict", "schema": EMBED_IMAGE, "required": False},
|
||||
"thumbnail": {"type": "dict", "schema": EMBED_THUMBNAIL, "required": False},
|
||||
# NOTE: 'video' set by us
|
||||
# NOTE: 'provider' set by us
|
||||
|
||||
'author': {
|
||||
'type': 'dict',
|
||||
'schema': EMBED_AUTHOR,
|
||||
'required': False,
|
||||
},
|
||||
|
||||
'fields': {
|
||||
'type': 'list',
|
||||
'schema': {'type': 'dict', 'schema': EMBED_FIELD},
|
||||
'required': False,
|
||||
"author": {"type": "dict", "schema": EMBED_AUTHOR, "required": False},
|
||||
"fields": {
|
||||
"type": "list",
|
||||
"schema": {"type": "dict", "schema": EMBED_FIELD},
|
||||
"required": False,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,13 +52,14 @@ class Flags:
|
|||
>>> i2.is_field_3
|
||||
False
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **_kwargs):
|
||||
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
|
||||
|
||||
def _make_int(value):
|
||||
res = Flags()
|
||||
|
||||
setattr(res, 'value', value)
|
||||
setattr(res, "value", value)
|
||||
|
||||
for attr, val in attrs:
|
||||
# get only the ones that represent a field in the
|
||||
|
|
@ -69,7 +70,7 @@ class Flags:
|
|||
has_attr = (value & val) == val
|
||||
|
||||
# set each attribute
|
||||
setattr(res, f'is_{attr}', has_attr)
|
||||
setattr(res, f"is_{attr}", has_attr)
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -84,17 +85,16 @@ class ChannelType(EasyEnum):
|
|||
GUILD_CATEGORY = 4
|
||||
|
||||
|
||||
GUILD_CHANS = (ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE,
|
||||
ChannelType.GUILD_CATEGORY)
|
||||
|
||||
|
||||
VOICE_CHANNELS = (
|
||||
ChannelType.DM, ChannelType.GUILD_VOICE,
|
||||
ChannelType.GUILD_CATEGORY
|
||||
GUILD_CHANS = (
|
||||
ChannelType.GUILD_TEXT,
|
||||
ChannelType.GUILD_VOICE,
|
||||
ChannelType.GUILD_CATEGORY,
|
||||
)
|
||||
|
||||
|
||||
VOICE_CHANNELS = (ChannelType.DM, ChannelType.GUILD_VOICE, ChannelType.GUILD_CATEGORY)
|
||||
|
||||
|
||||
class ActivityType(EasyEnum):
|
||||
PLAYING = 0
|
||||
STREAMING = 1
|
||||
|
|
@ -120,7 +120,7 @@ SYS_MESSAGES = (
|
|||
MessageType.CHANNEL_NAME_CHANGE,
|
||||
MessageType.CHANNEL_ICON_CHANGE,
|
||||
MessageType.CHANNEL_PINNED_MESSAGE,
|
||||
MessageType.GUILD_MEMBER_JOIN
|
||||
MessageType.GUILD_MEMBER_JOIN,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -137,6 +137,7 @@ class ActivityFlags(Flags):
|
|||
|
||||
Only related to rich presence.
|
||||
"""
|
||||
|
||||
instance = 1
|
||||
join = 2
|
||||
spectate = 4
|
||||
|
|
@ -150,6 +151,7 @@ class UserFlags(Flags):
|
|||
|
||||
Used by the client to show badges.
|
||||
"""
|
||||
|
||||
staff = 1
|
||||
partner = 2
|
||||
hypesquad = 4
|
||||
|
|
@ -166,6 +168,7 @@ class UserFlags(Flags):
|
|||
|
||||
class MessageFlags(Flags):
|
||||
"""Message flags."""
|
||||
|
||||
none = 0
|
||||
|
||||
crossposted = 1 << 0
|
||||
|
|
@ -175,11 +178,12 @@ class MessageFlags(Flags):
|
|||
|
||||
class StatusType(EasyEnum):
|
||||
"""All statuses there can be in a presence."""
|
||||
ONLINE = 'online'
|
||||
DND = 'dnd'
|
||||
IDLE = 'idle'
|
||||
INVISIBLE = 'invisible'
|
||||
OFFLINE = 'offline'
|
||||
|
||||
ONLINE = "online"
|
||||
DND = "dnd"
|
||||
IDLE = "idle"
|
||||
INVISIBLE = "invisible"
|
||||
OFFLINE = "offline"
|
||||
|
||||
|
||||
class ExplicitFilter(EasyEnum):
|
||||
|
|
@ -187,6 +191,7 @@ class ExplicitFilter(EasyEnum):
|
|||
|
||||
Also applies to guilds.
|
||||
"""
|
||||
|
||||
EDGE = 0
|
||||
FRIENDS = 1
|
||||
SAFE = 2
|
||||
|
|
@ -194,6 +199,7 @@ class ExplicitFilter(EasyEnum):
|
|||
|
||||
class VerificationLevel(IntEnum):
|
||||
"""Verification level for guilds."""
|
||||
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
|
|
@ -205,6 +211,7 @@ class VerificationLevel(IntEnum):
|
|||
|
||||
class RelationshipType(EasyEnum):
|
||||
"""Relationship types between users."""
|
||||
|
||||
FRIEND = 1
|
||||
BLOCK = 2
|
||||
INCOMING = 3
|
||||
|
|
@ -213,6 +220,7 @@ class RelationshipType(EasyEnum):
|
|||
|
||||
class MessageNotifications(EasyEnum):
|
||||
"""Message notifications"""
|
||||
|
||||
ALL = 0
|
||||
MENTIONS = 1
|
||||
NOTHING = 2
|
||||
|
|
@ -220,6 +228,7 @@ class MessageNotifications(EasyEnum):
|
|||
|
||||
class PremiumType:
|
||||
"""Premium (Nitro) type."""
|
||||
|
||||
TIER_1 = 1
|
||||
TIER_2 = 2
|
||||
NONE = None
|
||||
|
|
@ -227,12 +236,13 @@ class PremiumType:
|
|||
|
||||
class Feature(EasyEnum):
|
||||
"""Guild features."""
|
||||
invite_splash = 'INVITE_SPLASH'
|
||||
vip = 'VIP_REGIONS'
|
||||
vanity = 'VANITY_URL'
|
||||
emoji = 'MORE_EMOJI'
|
||||
verified = 'VERIFIED'
|
||||
|
||||
invite_splash = "INVITE_SPLASH"
|
||||
vip = "VIP_REGIONS"
|
||||
vanity = "VANITY_URL"
|
||||
emoji = "MORE_EMOJI"
|
||||
verified = "VERIFIED"
|
||||
|
||||
# unknown
|
||||
commerce = 'COMMERCE'
|
||||
news = 'NEWS'
|
||||
commerce = "COMMERCE"
|
||||
news = "NEWS"
|
||||
|
|
|
|||
|
|
@ -18,60 +18,64 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
ERR_MSG_MAP = {
|
||||
10001: 'Unknown account',
|
||||
10002: 'Unknown application',
|
||||
10003: 'Unknown channel',
|
||||
10004: 'Unknown guild',
|
||||
10005: 'Unknown integration',
|
||||
10006: 'Unknown invite',
|
||||
10007: 'Unknown member',
|
||||
10008: 'Unknown message',
|
||||
10009: 'Unknown overwrite',
|
||||
10010: 'Unknown provider',
|
||||
10011: 'Unknown role',
|
||||
10012: 'Unknown token',
|
||||
10013: 'Unknown user',
|
||||
10014: 'Unknown Emoji',
|
||||
10015: 'Unknown Webhook',
|
||||
20001: 'Bots cannot use this endpoint',
|
||||
20002: 'Only bots can use this endpoint',
|
||||
30001: 'Maximum number of guilds reached (100)',
|
||||
30002: 'Maximum number of friends reached (1000)',
|
||||
30003: 'Maximum number of pins reached (50)',
|
||||
30005: 'Maximum number of guild roles reached (250)',
|
||||
30010: 'Maximum number of reactions reached (20)',
|
||||
30013: 'Maximum number of guild channels reached (500)',
|
||||
40001: 'Unauthorized',
|
||||
50001: 'Missing access',
|
||||
50002: 'Invalid account type',
|
||||
50003: 'Cannot execute action on a DM channel',
|
||||
50004: 'Widget Disabled',
|
||||
50005: 'Cannot edit a message authored by another user',
|
||||
50006: 'Cannot send an empty message',
|
||||
50007: 'Cannot send messages to this user',
|
||||
50008: 'Cannot send messages in a voice channel',
|
||||
50009: 'Channel verification level is too high',
|
||||
50010: 'OAuth2 application does not have a bot',
|
||||
50011: 'OAuth2 application limit reached',
|
||||
50012: 'Invalid OAuth state',
|
||||
50013: 'Missing permissions',
|
||||
50014: 'Invalid authentication token',
|
||||
50015: 'Note is too long',
|
||||
50016: ('Provided too few or too many messages to delete. Must provide at '
|
||||
'least 2 and fewer than 100 messages to delete.'),
|
||||
50019: 'A message can only be pinned to the channel it was sent in',
|
||||
50020: 'Invite code is either invalid or taken.',
|
||||
50021: 'Cannot execute action on a system message',
|
||||
50025: 'Invalid OAuth2 access token',
|
||||
50034: 'A message provided was too old to bulk delete',
|
||||
50035: 'Invalid Form Body',
|
||||
50036: 'An invite was accepted to a guild the application\'s bot is not in',
|
||||
50041: 'Invalid API version',
|
||||
90001: 'Reaction blocked',
|
||||
10001: "Unknown account",
|
||||
10002: "Unknown application",
|
||||
10003: "Unknown channel",
|
||||
10004: "Unknown guild",
|
||||
10005: "Unknown integration",
|
||||
10006: "Unknown invite",
|
||||
10007: "Unknown member",
|
||||
10008: "Unknown message",
|
||||
10009: "Unknown overwrite",
|
||||
10010: "Unknown provider",
|
||||
10011: "Unknown role",
|
||||
10012: "Unknown token",
|
||||
10013: "Unknown user",
|
||||
10014: "Unknown Emoji",
|
||||
10015: "Unknown Webhook",
|
||||
20001: "Bots cannot use this endpoint",
|
||||
20002: "Only bots can use this endpoint",
|
||||
30001: "Maximum number of guilds reached (100)",
|
||||
30002: "Maximum number of friends reached (1000)",
|
||||
30003: "Maximum number of pins reached (50)",
|
||||
30005: "Maximum number of guild roles reached (250)",
|
||||
30010: "Maximum number of reactions reached (20)",
|
||||
30013: "Maximum number of guild channels reached (500)",
|
||||
40001: "Unauthorized",
|
||||
50001: "Missing access",
|
||||
50002: "Invalid account type",
|
||||
50003: "Cannot execute action on a DM channel",
|
||||
50004: "Widget Disabled",
|
||||
50005: "Cannot edit a message authored by another user",
|
||||
50006: "Cannot send an empty message",
|
||||
50007: "Cannot send messages to this user",
|
||||
50008: "Cannot send messages in a voice channel",
|
||||
50009: "Channel verification level is too high",
|
||||
50010: "OAuth2 application does not have a bot",
|
||||
50011: "OAuth2 application limit reached",
|
||||
50012: "Invalid OAuth state",
|
||||
50013: "Missing permissions",
|
||||
50014: "Invalid authentication token",
|
||||
50015: "Note is too long",
|
||||
50016: (
|
||||
"Provided too few or too many messages to delete. Must provide at "
|
||||
"least 2 and fewer than 100 messages to delete."
|
||||
),
|
||||
50019: "A message can only be pinned to the channel it was sent in",
|
||||
50020: "Invite code is either invalid or taken.",
|
||||
50021: "Cannot execute action on a system message",
|
||||
50025: "Invalid OAuth2 access token",
|
||||
50034: "A message provided was too old to bulk delete",
|
||||
50035: "Invalid Form Body",
|
||||
50036: "An invite was accepted to a guild the application's bot is not in",
|
||||
50041: "Invalid API version",
|
||||
90001: "Reaction blocked",
|
||||
}
|
||||
|
||||
|
||||
class LitecordError(Exception):
|
||||
"""Base class for litecord errors"""
|
||||
|
||||
status_code = 500
|
||||
|
||||
def _get_err_msg(self, err_code: int) -> str:
|
||||
|
|
@ -91,7 +95,7 @@ class LitecordError(Exception):
|
|||
|
||||
return message
|
||||
except IndexError:
|
||||
return self._get_err_msg(getattr(self, 'error_code', None))
|
||||
return self._get_err_msg(getattr(self, "error_code", None))
|
||||
|
||||
@property
|
||||
def json(self):
|
||||
|
|
@ -143,7 +147,7 @@ class MissingPermissions(Forbidden):
|
|||
class WebsocketClose(Exception):
|
||||
@property
|
||||
def code(self):
|
||||
from_class = getattr(self, 'close_code', None)
|
||||
from_class = getattr(self, "close_code", None)
|
||||
|
||||
if from_class:
|
||||
return from_class
|
||||
|
|
@ -152,7 +156,7 @@ class WebsocketClose(Exception):
|
|||
|
||||
@property
|
||||
def reason(self):
|
||||
from_class = getattr(self, 'close_code', None)
|
||||
from_class = getattr(self, "close_code", None)
|
||||
|
||||
if from_class:
|
||||
return self.args[0]
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ from litecord.utils import LitecordJSONEncoder
|
|||
|
||||
def encode_json(payload) -> str:
|
||||
"""Encode a given payload to JSON."""
|
||||
return json.dumps(payload, separators=(',', ':'),
|
||||
cls=LitecordJSONEncoder)
|
||||
return json.dumps(payload, separators=(",", ":"), cls=LitecordJSONEncoder)
|
||||
|
||||
|
||||
def decode_json(data: str):
|
||||
|
|
@ -71,6 +70,7 @@ def _etf_decode_dict(data):
|
|||
|
||||
return result
|
||||
|
||||
|
||||
def decode_etf(data: bytes):
|
||||
"""Decode data in ETF to any."""
|
||||
res = earl.unpack(data)
|
||||
|
|
|
|||
|
|
@ -24,37 +24,36 @@ from litecord.gateway.websocket import GatewayWebsocket
|
|||
async def websocket_handler(app, ws, url):
|
||||
"""Main websocket handler, checks query arguments when connecting to
|
||||
the gateway and spawns a GatewayWebsocket instance for the connection."""
|
||||
args = urllib.parse.parse_qs(
|
||||
urllib.parse.urlparse(url).query
|
||||
)
|
||||
args = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
|
||||
|
||||
# pull a dict.get but in a really bad way.
|
||||
try:
|
||||
gw_version = args['v'][0]
|
||||
gw_version = args["v"][0]
|
||||
except (KeyError, IndexError):
|
||||
gw_version = '6'
|
||||
gw_version = "6"
|
||||
|
||||
try:
|
||||
gw_encoding = args['encoding'][0]
|
||||
gw_encoding = args["encoding"][0]
|
||||
except (KeyError, IndexError):
|
||||
gw_encoding = 'json'
|
||||
gw_encoding = "json"
|
||||
|
||||
if gw_version not in ('6', '7'):
|
||||
return await ws.close(1000, 'Invalid gateway version')
|
||||
if gw_version not in ("6", "7"):
|
||||
return await ws.close(1000, "Invalid gateway version")
|
||||
|
||||
if gw_encoding not in ('json', 'etf'):
|
||||
return await ws.close(1000, 'Invalid gateway encoding')
|
||||
if gw_encoding not in ("json", "etf"):
|
||||
return await ws.close(1000, "Invalid gateway encoding")
|
||||
|
||||
try:
|
||||
gw_compress = args['compress'][0]
|
||||
gw_compress = args["compress"][0]
|
||||
except (KeyError, IndexError):
|
||||
gw_compress = None
|
||||
|
||||
if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'):
|
||||
return await ws.close(1000, 'Invalid gateway compress')
|
||||
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
|
||||
return await ws.close(1000, "Invalid gateway compress")
|
||||
|
||||
gws = GatewayWebsocket(
|
||||
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress)
|
||||
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||
)
|
||||
|
||||
# this can be run with a single await since this whole coroutine
|
||||
# is already running in the background.
|
||||
|
|
|
|||
|
|
@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
|
||||
class OP:
|
||||
"""Gateway OP codes."""
|
||||
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
IDENTIFY = 2
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class PayloadStore:
|
|||
This will only store a maximum of MAX_STORE_SIZE,
|
||||
dropping the older payloads when adding new ones.
|
||||
"""
|
||||
|
||||
MAX_STORE_SIZE = 250
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -60,20 +61,20 @@ class GatewayState:
|
|||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.session_id = kwargs.get('session_id', gen_session_id())
|
||||
self.session_id = kwargs.get("session_id", gen_session_id())
|
||||
|
||||
#: event sequence number
|
||||
self.seq = kwargs.get('seq', 0)
|
||||
self.seq = kwargs.get("seq", 0)
|
||||
|
||||
#: last seq sent by us, the backend
|
||||
self.last_seq = 0
|
||||
|
||||
#: shard information about the state,
|
||||
# its id and shard count
|
||||
self.shard = kwargs.get('shard', [0, 1])
|
||||
self.shard = kwargs.get("shard", [0, 1])
|
||||
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.bot = kwargs.get('bot', False)
|
||||
self.user_id = kwargs.get("user_id")
|
||||
self.bot = kwargs.get("bot", False)
|
||||
|
||||
#: set by the gateway connection
|
||||
# on OP STATUS_UPDATE
|
||||
|
|
@ -90,5 +91,4 @@ class GatewayState:
|
|||
self.__dict__[key] = value
|
||||
|
||||
def __repr__(self):
|
||||
return (f'GatewayState<seq={self.seq} '
|
||||
f'shard={self.shard} uid={self.user_id}>')
|
||||
return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>"
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class ManagerClose(Exception):
|
|||
class StateDictWrapper:
|
||||
"""Wrap a mapping so that any kind of access to the mapping while the
|
||||
state manager is closed raises a ManagerClose error"""
|
||||
|
||||
def __init__(self, state_manager, mapping):
|
||||
self.state_manager = state_manager
|
||||
self._map = mapping
|
||||
|
|
@ -98,7 +99,7 @@ class StateManager:
|
|||
"""Insert a new state object."""
|
||||
user_states = self.states[state.user_id]
|
||||
|
||||
log.debug('inserting state: {!r}', state)
|
||||
log.debug("inserting state: {!r}", state)
|
||||
user_states[state.session_id] = state
|
||||
self.states_raw[state.session_id] = state
|
||||
|
||||
|
|
@ -128,7 +129,7 @@ class StateManager:
|
|||
pass
|
||||
|
||||
try:
|
||||
log.debug('removing state: {!r}', state)
|
||||
log.debug("removing state: {!r}", state)
|
||||
self.states[state.user_id].pop(state.session_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
|
@ -152,8 +153,7 @@ class StateManager:
|
|||
"""Fetch all states tied to a single user."""
|
||||
return list(self.states[user_id].values())
|
||||
|
||||
def guild_states(self, member_ids: List[int],
|
||||
guild_id: int) -> List[GatewayState]:
|
||||
def guild_states(self, member_ids: List[int], guild_id: int) -> List[GatewayState]:
|
||||
"""Fetch all possible states about members in a guild."""
|
||||
states = []
|
||||
|
||||
|
|
@ -164,14 +164,14 @@ class StateManager:
|
|||
# since server start, so we need to add a dummy state
|
||||
if not member_states:
|
||||
dummy_state = GatewayState(
|
||||
session_id='',
|
||||
session_id="",
|
||||
user_id=member_id,
|
||||
presence={
|
||||
'afk': False,
|
||||
'status': 'offline',
|
||||
'game': None,
|
||||
'since': 0
|
||||
}
|
||||
"afk": False,
|
||||
"status": "offline",
|
||||
"game": None,
|
||||
"since": 0,
|
||||
},
|
||||
)
|
||||
|
||||
states.append(dummy_state)
|
||||
|
|
@ -187,9 +187,7 @@ class StateManager:
|
|||
"""Send OP Reconnect to a single connection."""
|
||||
websocket = state.ws
|
||||
|
||||
await websocket.send({
|
||||
'op': OP.RECONNECT
|
||||
})
|
||||
await websocket.send({"op": OP.RECONNECT})
|
||||
|
||||
# wait 200ms
|
||||
# so that the client has time to process
|
||||
|
|
@ -198,12 +196,9 @@ class StateManager:
|
|||
|
||||
try:
|
||||
# try to close the connection ourselves
|
||||
await websocket.ws.close(
|
||||
code=4000,
|
||||
reason='litecord shutting down'
|
||||
)
|
||||
await websocket.ws.close(code=4000, reason="litecord shutting down")
|
||||
except ConnectionClosed:
|
||||
log.info('client {} already closed', state)
|
||||
log.info("client {} already closed", state)
|
||||
|
||||
def gen_close_tasks(self):
|
||||
"""Generate the tasks that will order the clients
|
||||
|
|
@ -222,11 +217,9 @@ class StateManager:
|
|||
if not state.ws:
|
||||
continue
|
||||
|
||||
tasks.append(
|
||||
self.shutdown_single(state)
|
||||
)
|
||||
tasks.append(self.shutdown_single(state))
|
||||
|
||||
log.info('made {} shutdown tasks', len(tasks))
|
||||
log.info("made {} shutdown tasks", len(tasks))
|
||||
return tasks
|
||||
|
||||
def close(self):
|
||||
|
|
|
|||
|
|
@ -19,9 +19,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import asyncio
|
||||
|
||||
|
||||
class WebsocketFileHandler:
|
||||
"""A handler around a websocket that wraps normal I/O calls into
|
||||
the websocket's respective asyncio calls via asyncio.ensure_future."""
|
||||
|
||||
def __init__(self, ws):
|
||||
self.ws = ws
|
||||
|
||||
|
|
|
|||
|
|
@ -31,23 +31,20 @@ from logbook import Logger
|
|||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType
|
||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||
from litecord.utils import (
|
||||
task_wrapper, yield_chunks, maybe_int
|
||||
)
|
||||
from litecord.utils import task_wrapper, yield_chunks, maybe_int
|
||||
from litecord.permissions import get_permissions
|
||||
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
||||
from litecord.errors import (
|
||||
WebsocketClose, Unauthorized, Forbidden, BadRequest
|
||||
)
|
||||
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
|
||||
from litecord.gateway.errors import (
|
||||
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
|
||||
)
|
||||
from litecord.gateway.encoding import (
|
||||
encode_json, decode_json, encode_etf, decode_etf
|
||||
DecodeError,
|
||||
UnknownOPCode,
|
||||
InvalidShard,
|
||||
ShardingRequired,
|
||||
)
|
||||
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
||||
|
||||
from litecord.gateway.utils import WebsocketFileHandler
|
||||
|
||||
|
|
@ -56,15 +53,22 @@ from litecord.storage import int_
|
|||
log = Logger(__name__)
|
||||
|
||||
WebsocketProperties = collections.namedtuple(
|
||||
'WebsocketProperties', 'v encoding compress zctx zsctx tasks'
|
||||
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
|
||||
)
|
||||
|
||||
WebsocketObjects = collections.namedtuple(
|
||||
'WebsocketObjects', (
|
||||
'db', 'state_manager', 'storage',
|
||||
'loop', 'dispatcher', 'presence', 'ratelimiter',
|
||||
'user_storage', 'voice'
|
||||
)
|
||||
"WebsocketObjects",
|
||||
(
|
||||
"db",
|
||||
"state_manager",
|
||||
"storage",
|
||||
"loop",
|
||||
"dispatcher",
|
||||
"presence",
|
||||
"ratelimiter",
|
||||
"user_storage",
|
||||
"voice",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -73,9 +77,15 @@ class GatewayWebsocket:
|
|||
|
||||
def __init__(self, ws, app, **kwargs):
|
||||
self.ext = WebsocketObjects(
|
||||
app.db, app.state_manager, app.storage, app.loop,
|
||||
app.dispatcher, app.presence, app.ratelimiter,
|
||||
app.user_storage, app.voice
|
||||
app.db,
|
||||
app.state_manager,
|
||||
app.storage,
|
||||
app.loop,
|
||||
app.dispatcher,
|
||||
app.presence,
|
||||
app.ratelimiter,
|
||||
app.user_storage,
|
||||
app.voice,
|
||||
)
|
||||
|
||||
self.storage = self.ext.storage
|
||||
|
|
@ -84,15 +94,15 @@ class GatewayWebsocket:
|
|||
self.ws = ws
|
||||
|
||||
self.wsp = WebsocketProperties(
|
||||
kwargs.get('v'),
|
||||
kwargs.get('encoding', 'json'),
|
||||
kwargs.get('compress', None),
|
||||
kwargs.get("v"),
|
||||
kwargs.get("encoding", "json"),
|
||||
kwargs.get("compress", None),
|
||||
zlib.compressobj(),
|
||||
zstd.ZstdCompressor(),
|
||||
{}
|
||||
{},
|
||||
)
|
||||
|
||||
log.debug('websocket properties: {!r}', self.wsp)
|
||||
log.debug("websocket properties: {!r}", self.wsp)
|
||||
|
||||
self.state = None
|
||||
|
||||
|
|
@ -102,8 +112,8 @@ class GatewayWebsocket:
|
|||
encoding = self.wsp.encoding
|
||||
|
||||
encodings = {
|
||||
'json': (encode_json, decode_json),
|
||||
'etf': (encode_etf, decode_etf),
|
||||
"json": (encode_json, decode_json),
|
||||
"etf": (encode_etf, decode_etf),
|
||||
}
|
||||
|
||||
self.encoder, self.decoder = encodings[encoding]
|
||||
|
|
@ -111,16 +121,17 @@ class GatewayWebsocket:
|
|||
async def _chunked_send(self, data: bytes, chunk_size: int):
|
||||
"""Split data in chunk_size-big chunks and send them
|
||||
over the websocket."""
|
||||
log.debug('zlib-stream: chunking {} bytes into {}-byte chunks',
|
||||
len(data), chunk_size)
|
||||
log.debug(
|
||||
"zlib-stream: chunking {} bytes into {}-byte chunks", len(data), chunk_size
|
||||
)
|
||||
|
||||
total_chunks = 0
|
||||
for chunk in yield_chunks(data, chunk_size):
|
||||
total_chunks += 1
|
||||
log.debug('zlib-stream: chunk {}', total_chunks)
|
||||
log.debug("zlib-stream: chunk {}", total_chunks)
|
||||
await self.ws.send(chunk)
|
||||
|
||||
log.debug('zlib-stream: sent {} chunks', total_chunks)
|
||||
log.debug("zlib-stream: sent {} chunks", total_chunks)
|
||||
|
||||
async def _zlib_stream_send(self, encoded):
|
||||
"""Sending a single payload across multiple compressed
|
||||
|
|
@ -130,8 +141,12 @@ class GatewayWebsocket:
|
|||
data1 = self.wsp.zctx.compress(encoded)
|
||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||
|
||||
log.debug('zlib-stream: length {} -> compressed ({} + {})',
|
||||
len(encoded), len(data1), len(data2))
|
||||
log.debug(
|
||||
"zlib-stream: length {} -> compressed ({} + {})",
|
||||
len(encoded),
|
||||
len(data1),
|
||||
len(data2),
|
||||
)
|
||||
|
||||
if not data1:
|
||||
# if data1 is nothing, that might cause problems
|
||||
|
|
@ -139,8 +154,11 @@ class GatewayWebsocket:
|
|||
data1 = bytes([data2[0]])
|
||||
data2 = data2[1:]
|
||||
|
||||
log.debug('zlib-stream: len(data1) == 0, remaking as ({} + {})',
|
||||
len(data1), len(data2))
|
||||
log.debug(
|
||||
"zlib-stream: len(data1) == 0, remaking as ({} + {})",
|
||||
len(data1),
|
||||
len(data2),
|
||||
)
|
||||
|
||||
# NOTE: the old approach was ws.send(data1 + data2).
|
||||
# I changed this to a chunked send of data1 and data2
|
||||
|
|
@ -157,8 +175,7 @@ class GatewayWebsocket:
|
|||
await self._chunked_send(data2, 1024)
|
||||
|
||||
async def _zstd_stream_send(self, encoded):
|
||||
compressor = self.wsp.zsctx.stream_writer(
|
||||
WebsocketFileHandler(self.ws))
|
||||
compressor = self.wsp.zsctx.stream_writer(WebsocketFileHandler(self.ws))
|
||||
|
||||
compressor.write(encoded)
|
||||
compressor.flush(zstd.FLUSH_FRAME)
|
||||
|
|
@ -172,21 +189,23 @@ class GatewayWebsocket:
|
|||
encoded = self.encoder(payload)
|
||||
|
||||
if len(encoded) < 2048:
|
||||
log.debug('sending\n{}', pprint.pformat(payload))
|
||||
log.debug("sending\n{}", pprint.pformat(payload))
|
||||
else:
|
||||
log.debug('sending {}', pprint.pformat(payload))
|
||||
log.debug('sending op={} s={} t={} (too big)',
|
||||
payload.get('op'),
|
||||
payload.get('s'),
|
||||
payload.get('t'))
|
||||
log.debug("sending {}", pprint.pformat(payload))
|
||||
log.debug(
|
||||
"sending op={} s={} t={} (too big)",
|
||||
payload.get("op"),
|
||||
payload.get("s"),
|
||||
payload.get("t"),
|
||||
)
|
||||
|
||||
# treat encoded as bytes
|
||||
if not isinstance(encoded, bytes):
|
||||
encoded = encoded.encode()
|
||||
|
||||
if self.wsp.compress == 'zlib-stream':
|
||||
if self.wsp.compress == "zlib-stream":
|
||||
await self._zlib_stream_send(encoded)
|
||||
elif self.wsp.compress == 'zstd-stream':
|
||||
elif self.wsp.compress == "zstd-stream":
|
||||
await self._zstd_stream_send(encoded)
|
||||
elif self.state and self.state.compress and len(encoded) > 1024:
|
||||
# TODO: should we only compress on >1KB packets? or maybe we
|
||||
|
|
@ -203,16 +222,10 @@ class GatewayWebsocket:
|
|||
|
||||
async def send_op(self, op_code: int, data: Any):
|
||||
"""Send a packet but just the OP code information is filled in."""
|
||||
await self.send({
|
||||
'op': op_code,
|
||||
'd': data,
|
||||
|
||||
't': None,
|
||||
's': None
|
||||
})
|
||||
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||
|
||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
|
||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||
return bucket.update_rate_limit()
|
||||
|
||||
|
|
@ -221,19 +234,19 @@ class GatewayWebsocket:
|
|||
# if the client heartbeats in time,
|
||||
# this task will be cancelled.
|
||||
await asyncio.sleep(interval / 1000)
|
||||
await self.ws.close(4000, 'Heartbeat expired')
|
||||
await self.ws.close(4000, "Heartbeat expired")
|
||||
|
||||
self._cleanup()
|
||||
|
||||
def _hb_start(self, interval: int):
|
||||
# always refresh the heartbeat task
|
||||
# when possible
|
||||
task = self.wsp.tasks.get('heartbeat')
|
||||
task = self.wsp.tasks.get("heartbeat")
|
||||
if task:
|
||||
task.cancel()
|
||||
|
||||
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
|
||||
task_wrapper('hb wait', self._hb_wait(interval))
|
||||
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
|
||||
task_wrapper("hb wait", self._hb_wait(interval))
|
||||
)
|
||||
|
||||
async def _send_hello(self):
|
||||
|
|
@ -241,12 +254,9 @@ class GatewayWebsocket:
|
|||
# random heartbeat intervals
|
||||
interval = randint(40, 46) * 1000
|
||||
|
||||
await self.send_op(OP.HELLO, {
|
||||
'heartbeat_interval': interval,
|
||||
'_trace': [
|
||||
'lesbian-server'
|
||||
],
|
||||
})
|
||||
await self.send_op(
|
||||
OP.HELLO, {"heartbeat_interval": interval, "_trace": ["lesbian-server"]}
|
||||
)
|
||||
|
||||
self._hb_start(interval)
|
||||
|
||||
|
|
@ -255,16 +265,15 @@ class GatewayWebsocket:
|
|||
self.state.seq += 1
|
||||
|
||||
payload = {
|
||||
'op': OP.DISPATCH,
|
||||
't': event.upper(),
|
||||
's': self.state.seq,
|
||||
'd': data,
|
||||
"op": OP.DISPATCH,
|
||||
"t": event.upper(),
|
||||
"s": self.state.seq,
|
||||
"d": data,
|
||||
}
|
||||
|
||||
self.state.store[self.state.seq] = payload
|
||||
|
||||
log.debug('sending payload {!r} sid {}',
|
||||
event.upper(), self.state.session_id)
|
||||
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
|
||||
|
||||
await self.send(payload)
|
||||
|
||||
|
|
@ -274,16 +283,14 @@ class GatewayWebsocket:
|
|||
guild_ids = await self._guild_ids()
|
||||
|
||||
if self.state.bot:
|
||||
return [{
|
||||
'id': row,
|
||||
'unavailable': True,
|
||||
} for row in guild_ids]
|
||||
return [{"id": row, "unavailable": True} for row in guild_ids]
|
||||
|
||||
return [
|
||||
{
|
||||
**await self.storage.get_guild(guild_id, user_id),
|
||||
**await self.storage.get_guild_extra(guild_id, user_id,
|
||||
self.state.large)
|
||||
**await self.storage.get_guild_extra(
|
||||
guild_id, user_id, self.state.large
|
||||
),
|
||||
}
|
||||
for guild_id in guild_ids
|
||||
]
|
||||
|
|
@ -298,13 +305,13 @@ class GatewayWebsocket:
|
|||
for guild_obj in unavailable_guilds:
|
||||
# fetch full guild object including the 'large' field
|
||||
guild = await self.storage.get_guild_full(
|
||||
int(guild_obj['id']), self.state.user_id, self.state.large
|
||||
int(guild_obj["id"]), self.state.user_id, self.state.large
|
||||
)
|
||||
|
||||
if guild is None:
|
||||
continue
|
||||
|
||||
await self.dispatch('GUILD_CREATE', guild)
|
||||
await self.dispatch("GUILD_CREATE", guild)
|
||||
|
||||
async def _user_ready(self) -> dict:
|
||||
"""Fetch information about users in the READY packet.
|
||||
|
|
@ -317,28 +324,28 @@ class GatewayWebsocket:
|
|||
|
||||
relationships = await self.user_storage.get_relationships(user_id)
|
||||
|
||||
friend_ids = [int(r['user']['id']) for r in relationships
|
||||
if r['type'] == RelationshipType.FRIEND.value]
|
||||
friend_ids = [
|
||||
int(r["user"]["id"])
|
||||
for r in relationships
|
||||
if r["type"] == RelationshipType.FRIEND.value
|
||||
]
|
||||
|
||||
friend_presences = await self.ext.presence.friend_presences(friend_ids)
|
||||
settings = await self.user_storage.get_user_settings(user_id)
|
||||
|
||||
return {
|
||||
'user_settings': settings,
|
||||
'notes': await self.user_storage.fetch_notes(user_id),
|
||||
'relationships': relationships,
|
||||
'presences': friend_presences,
|
||||
'read_state': await self.user_storage.get_read_state(user_id),
|
||||
'user_guild_settings': await self.user_storage.get_guild_settings(
|
||||
user_id),
|
||||
|
||||
'friend_suggestion_count': 0,
|
||||
|
||||
"user_settings": settings,
|
||||
"notes": await self.user_storage.fetch_notes(user_id),
|
||||
"relationships": relationships,
|
||||
"presences": friend_presences,
|
||||
"read_state": await self.user_storage.get_read_state(user_id),
|
||||
"user_guild_settings": await self.user_storage.get_guild_settings(user_id),
|
||||
"friend_suggestion_count": 0,
|
||||
# those are unused default values.
|
||||
'connected_accounts': [],
|
||||
'experiments': [],
|
||||
'guild_experiments': [],
|
||||
'analytics_token': 'transbian',
|
||||
"connected_accounts": [],
|
||||
"experiments": [],
|
||||
"guild_experiments": [],
|
||||
"analytics_token": "transbian",
|
||||
}
|
||||
|
||||
async def dispatch_ready(self):
|
||||
|
|
@ -353,24 +360,21 @@ class GatewayWebsocket:
|
|||
# user, fetch info
|
||||
user_ready = await self._user_ready()
|
||||
|
||||
private_channels = (
|
||||
await self.user_storage.get_dms(user_id) +
|
||||
await self.user_storage.get_gdms(user_id)
|
||||
)
|
||||
private_channels = await self.user_storage.get_dms(
|
||||
user_id
|
||||
) + await self.user_storage.get_gdms(user_id)
|
||||
|
||||
base_ready = {
|
||||
'v': 6,
|
||||
'user': user,
|
||||
|
||||
'private_channels': private_channels,
|
||||
|
||||
'guilds': guilds,
|
||||
'session_id': self.state.session_id,
|
||||
'_trace': ['transbian'],
|
||||
'shard': self.state.shard,
|
||||
"v": 6,
|
||||
"user": user,
|
||||
"private_channels": private_channels,
|
||||
"guilds": guilds,
|
||||
"session_id": self.state.session_id,
|
||||
"_trace": ["transbian"],
|
||||
"shard": self.state.shard,
|
||||
}
|
||||
|
||||
await self.dispatch('READY', {**base_ready, **user_ready})
|
||||
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||
|
||||
# async dispatch of guilds
|
||||
self.ext.loop.create_task(self._guild_dispatch(guilds))
|
||||
|
|
@ -380,33 +384,32 @@ class GatewayWebsocket:
|
|||
"""
|
||||
current_shard, shard_count = shard
|
||||
|
||||
guilds = await self.ext.db.fetchval("""
|
||||
guilds = await self.ext.db.fetchval(
|
||||
"""
|
||||
SELECT COUNT(*)
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
recommended = max(int(guilds / 1200), 1)
|
||||
|
||||
if shard_count < recommended:
|
||||
raise ShardingRequired('Too many guilds for shard '
|
||||
f'{current_shard}')
|
||||
raise ShardingRequired("Too many guilds for shard " f"{current_shard}")
|
||||
|
||||
if guilds > 2500 and guilds / shard_count > 0.8:
|
||||
raise ShardingRequired('Too many shards. '
|
||||
f'(g={guilds} sc={shard_count})')
|
||||
raise ShardingRequired("Too many shards. " f"(g={guilds} sc={shard_count})")
|
||||
|
||||
if current_shard > shard_count:
|
||||
raise InvalidShard('Shard count > Total shards')
|
||||
raise InvalidShard("Shard count > Total shards")
|
||||
|
||||
async def _guild_ids(self) -> list:
|
||||
"""Get a list of Guild IDs that are tied to this connection.
|
||||
|
||||
The implementation is shard-aware.
|
||||
"""
|
||||
guild_ids = await self.user_storage.get_user_guilds(
|
||||
self.state.user_id
|
||||
)
|
||||
guild_ids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||
|
||||
shard_id = self.state.current_shard
|
||||
shard_count = self.state.shard_count
|
||||
|
|
@ -414,10 +417,7 @@ class GatewayWebsocket:
|
|||
def _get_shard(guild_id):
|
||||
return (guild_id >> 22) % shard_count
|
||||
|
||||
filtered = filter(
|
||||
lambda guild_id: _get_shard(guild_id) == shard_id,
|
||||
guild_ids
|
||||
)
|
||||
filtered = filter(lambda guild_id: _get_shard(guild_id) == shard_id, guild_ids)
|
||||
|
||||
return list(filtered)
|
||||
|
||||
|
|
@ -432,13 +432,17 @@ class GatewayWebsocket:
|
|||
|
||||
# subscribe the user to all dms they have OPENED.
|
||||
dms = await self.user_storage.get_dms(user_id)
|
||||
dm_ids = [int(dm['id']) for dm in dms]
|
||||
dm_ids = [int(dm["id"]) for dm in dms]
|
||||
|
||||
# fetch all group dms the user is a member of.
|
||||
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
||||
|
||||
log.info('subscribing to {} guilds {} dms {} gdms',
|
||||
len(guild_ids), len(dm_ids), len(gdm_ids))
|
||||
log.info(
|
||||
"subscribing to {} guilds {} dms {} gdms",
|
||||
len(guild_ids),
|
||||
len(dm_ids),
|
||||
len(gdm_ids),
|
||||
)
|
||||
|
||||
# guild_subscriptions:
|
||||
# enables dispatching of guild subscription events
|
||||
|
|
@ -447,10 +451,13 @@ class GatewayWebsocket:
|
|||
# we enable processing of guild_subscriptions by adding flags
|
||||
# when subscribing to the given backend. those are optional.
|
||||
channels_to_sub = [
|
||||
('guild', guild_ids,
|
||||
{'presence': guild_subscriptions, 'typing': guild_subscriptions}),
|
||||
('channel', dm_ids),
|
||||
('channel', gdm_ids),
|
||||
(
|
||||
"guild",
|
||||
guild_ids,
|
||||
{"presence": guild_subscriptions, "typing": guild_subscriptions},
|
||||
),
|
||||
("channel", dm_ids),
|
||||
("channel", gdm_ids),
|
||||
]
|
||||
|
||||
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||
|
|
@ -460,28 +467,26 @@ class GatewayWebsocket:
|
|||
# (their friends will also subscribe back
|
||||
# when they come online)
|
||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||
log.info('subscribing to {} friends', len(friend_ids))
|
||||
await self.ext.dispatcher.sub_many('friend', user_id, friend_ids)
|
||||
log.info("subscribing to {} friends", len(friend_ids))
|
||||
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||
|
||||
async def update_status(self, status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
if self._check_ratelimit('presence', self.state.session_id):
|
||||
if self._check_ratelimit("presence", self.state.session_id):
|
||||
# Presence Updates beyond the ratelimit
|
||||
# are just silently dropped.
|
||||
return
|
||||
|
||||
default_status = {
|
||||
'afk': False,
|
||||
|
||||
"afk": False,
|
||||
# TODO: fetch status from settings
|
||||
'status': 'online',
|
||||
'game': None,
|
||||
|
||||
"status": "online",
|
||||
"game": None,
|
||||
# TODO: this
|
||||
'since': 0,
|
||||
"since": 0,
|
||||
}
|
||||
|
||||
status = {**(status or {}), **default_status}
|
||||
|
|
@ -489,39 +494,40 @@ class GatewayWebsocket:
|
|||
try:
|
||||
status = validate(status, GW_STATUS_UPDATE)
|
||||
except BadRequest as err:
|
||||
log.warning(f'Invalid status update: {err}')
|
||||
log.warning(f"Invalid status update: {err}")
|
||||
return
|
||||
|
||||
# try to extract game from activities
|
||||
# when game not provided
|
||||
if not status.get('game'):
|
||||
if not status.get("game"):
|
||||
try:
|
||||
game = status['activities'][0]
|
||||
game = status["activities"][0]
|
||||
except (KeyError, IndexError):
|
||||
game = None
|
||||
else:
|
||||
game = status['game']
|
||||
game = status["game"]
|
||||
|
||||
# construct final status
|
||||
status = {
|
||||
'afk': status.get('afk', False),
|
||||
'status': status.get('status', 'online'),
|
||||
'game': game,
|
||||
'since': status.get('since', 0),
|
||||
"afk": status.get("afk", False),
|
||||
"status": status.get("status", "online"),
|
||||
"game": game,
|
||||
"since": status.get("since", 0),
|
||||
}
|
||||
|
||||
self.state.presence = status
|
||||
log.info(f'Updating presence status={status["status"]} for '
|
||||
f'uid={self.state.user_id}')
|
||||
await self.ext.presence.dispatch_pres(self.state.user_id,
|
||||
self.state.presence)
|
||||
log.info(
|
||||
f'Updating presence status={status["status"]} for '
|
||||
f"uid={self.state.user_id}"
|
||||
)
|
||||
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
||||
|
||||
async def handle_1(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 1 Heartbeat packets."""
|
||||
# give the client 3 more seconds before we
|
||||
# close the websocket
|
||||
self._hb_start((46 + 3) * 1000)
|
||||
cliseq = payload.get('d')
|
||||
cliseq = payload.get("d")
|
||||
|
||||
if self.state:
|
||||
self.state.last_seq = cliseq
|
||||
|
|
@ -529,39 +535,42 @@ class GatewayWebsocket:
|
|||
await self.send_op(OP.HEARTBEAT_ACK, None)
|
||||
|
||||
async def _connect_ratelimit(self, user_id: int):
|
||||
if self._check_ratelimit('connect', user_id):
|
||||
if self._check_ratelimit("connect", user_id):
|
||||
await self.invalidate_session(False)
|
||||
raise WebsocketClose(4009, 'You are being ratelimited.')
|
||||
raise WebsocketClose(4009, "You are being ratelimited.")
|
||||
|
||||
if self._check_ratelimit('session', user_id):
|
||||
if self._check_ratelimit("session", user_id):
|
||||
await self.invalidate_session(False)
|
||||
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')
|
||||
raise WebsocketClose(4004, "Websocket Session Ratelimit reached.")
|
||||
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
"""Handle the OP 2 Identify packet."""
|
||||
try:
|
||||
data = payload['d']
|
||||
token = data['token']
|
||||
data = payload["d"]
|
||||
token = data["token"]
|
||||
except KeyError:
|
||||
raise DecodeError('Invalid identify parameters')
|
||||
raise DecodeError("Invalid identify parameters")
|
||||
|
||||
compress = data.get('compress', False)
|
||||
large = data.get('large_threshold', 50)
|
||||
compress = data.get("compress", False)
|
||||
large = data.get("large_threshold", 50)
|
||||
|
||||
shard = data.get('shard', [0, 1])
|
||||
presence = data.get('presence')
|
||||
shard = data.get("shard", [0, 1])
|
||||
presence = data.get("presence")
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.ext.db)
|
||||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, 'Authentication failed')
|
||||
raise WebsocketClose(4004, "Authentication failed")
|
||||
|
||||
await self._connect_ratelimit(user_id)
|
||||
|
||||
bot = await self.ext.db.fetchval("""
|
||||
bot = await self.ext.db.fetchval(
|
||||
"""
|
||||
SELECT bot FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
await self._check_shards(shard, user_id)
|
||||
|
||||
|
|
@ -574,19 +583,19 @@ class GatewayWebsocket:
|
|||
shard=shard,
|
||||
current_shard=shard[0],
|
||||
shard_count=shard[1],
|
||||
ws=self
|
||||
ws=self,
|
||||
)
|
||||
|
||||
# link the state to the user
|
||||
self.ext.state_manager.insert(self.state)
|
||||
|
||||
await self.update_status(presence)
|
||||
await self.subscribe_all(data.get('guild_subscriptions', True))
|
||||
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||
await self.dispatch_ready()
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 3 Status Update."""
|
||||
presence = payload['d']
|
||||
presence = payload["d"]
|
||||
|
||||
# update_status will take care of validation and
|
||||
# setting new presence to state
|
||||
|
|
@ -597,27 +606,27 @@ class GatewayWebsocket:
|
|||
user settings."""
|
||||
try:
|
||||
# TODO: fetch from settings if not provided
|
||||
self_deaf = bool(data['self_deaf'])
|
||||
self_mute = bool(data['self_mute'])
|
||||
self_deaf = bool(data["self_deaf"])
|
||||
self_mute = bool(data["self_mute"])
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
return {
|
||||
'deaf': state.deaf,
|
||||
'mute': state.mute,
|
||||
'self_deaf': self_deaf,
|
||||
'self_mute': self_mute,
|
||||
"deaf": state.deaf,
|
||||
"mute": state.mute,
|
||||
"self_deaf": self_deaf,
|
||||
"self_mute": self_mute,
|
||||
}
|
||||
|
||||
async def handle_4(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 4 Voice Status Update."""
|
||||
data = payload['d']
|
||||
data = payload["d"]
|
||||
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
channel_id = int_(data.get('channel_id'))
|
||||
guild_id = int_(data.get('guild_id'))
|
||||
channel_id = int_(data.get("channel_id"))
|
||||
guild_id = int_(data.get("guild_id"))
|
||||
|
||||
# if its null and null, disconnect the user from any voice
|
||||
# TODO: maybe just leave from DMs? idk...
|
||||
|
|
@ -630,9 +639,7 @@ class GatewayWebsocket:
|
|||
return await self.ext.voice.leave(guild_id, self.state.user_id)
|
||||
|
||||
# fetch an existing state given user and guild OR user and channel
|
||||
chan_type = ChannelType(
|
||||
await self.storage.get_chan_type(channel_id)
|
||||
)
|
||||
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
|
||||
|
||||
state_id2 = channel_id
|
||||
|
||||
|
|
@ -704,39 +711,38 @@ class GatewayWebsocket:
|
|||
# ignore unknown seqs
|
||||
continue
|
||||
|
||||
payload_t = payload.get('t')
|
||||
payload_t = payload.get("t")
|
||||
|
||||
# presence resumption happens
|
||||
# on a separate event, PRESENCE_REPLACE.
|
||||
if payload_t == 'PRESENCE_UPDATE':
|
||||
presences.append(payload.get('d'))
|
||||
if payload_t == "PRESENCE_UPDATE":
|
||||
presences.append(payload.get("d"))
|
||||
continue
|
||||
|
||||
await self.send(payload)
|
||||
except Exception:
|
||||
log.exception('error while resuming')
|
||||
log.exception("error while resuming")
|
||||
await self.invalidate_session(False)
|
||||
return
|
||||
|
||||
if presences:
|
||||
await self.dispatch('PRESENCE_REPLACE', presences)
|
||||
await self.dispatch("PRESENCE_REPLACE", presences)
|
||||
|
||||
await self.dispatch('RESUMED', {})
|
||||
await self.dispatch("RESUMED", {})
|
||||
|
||||
async def handle_6(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 6 Resume."""
|
||||
data = payload['d']
|
||||
data = payload["d"]
|
||||
|
||||
try:
|
||||
token, sess_id, seq = data['token'], \
|
||||
data['session_id'], data['seq']
|
||||
token, sess_id, seq = data["token"], data["session_id"], data["seq"]
|
||||
except KeyError:
|
||||
raise DecodeError('Invalid resume payload')
|
||||
raise DecodeError("Invalid resume payload")
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.ext.db)
|
||||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, 'Invalid token')
|
||||
raise WebsocketClose(4004, "Invalid token")
|
||||
|
||||
try:
|
||||
state = self.ext.state_manager.fetch(user_id, sess_id)
|
||||
|
|
@ -744,11 +750,11 @@ class GatewayWebsocket:
|
|||
return await self.invalidate_session(False)
|
||||
|
||||
if seq > state.seq:
|
||||
raise WebsocketClose(4007, 'Invalid seq')
|
||||
raise WebsocketClose(4007, "Invalid seq")
|
||||
|
||||
# check if a websocket isnt on that state already
|
||||
if state.ws is not None:
|
||||
log.info('Resuming failed, websocket already connected')
|
||||
log.info("Resuming failed, websocket already connected")
|
||||
return await self.invalidate_session(False)
|
||||
|
||||
# relink this connection
|
||||
|
|
@ -757,8 +763,9 @@ class GatewayWebsocket:
|
|||
|
||||
await self._resume(range(seq, state.seq))
|
||||
|
||||
async def _req_guild_members(self, guild_id, user_ids: List[int],
|
||||
query: str, limit: int):
|
||||
async def _req_guild_members(
|
||||
self, guild_id, user_ids: List[int], query: str, limit: int
|
||||
):
|
||||
try:
|
||||
guild_id = int(guild_id)
|
||||
except (TypeError, ValueError):
|
||||
|
|
@ -778,32 +785,32 @@ class GatewayWebsocket:
|
|||
# ASSUMPTION: requesting user_ids means we don't do query.
|
||||
if user_ids:
|
||||
members = await self.storage.get_member_multi(guild_id, user_ids)
|
||||
mids = [m['user']['id'] for m in members]
|
||||
mids = [m["user"]["id"] for m in members]
|
||||
not_found = [uid for uid in user_ids if uid not in mids]
|
||||
|
||||
await self.dispatch('GUILD_MEMBERS_CHUNK', {
|
||||
'guild_id': str(guild_id),
|
||||
'members': members,
|
||||
'not_found': not_found,
|
||||
})
|
||||
await self.dispatch(
|
||||
"GUILD_MEMBERS_CHUNK",
|
||||
{"guild_id": str(guild_id), "members": members, "not_found": not_found},
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# do the search
|
||||
result = await self.storage.query_members(guild_id, query, limit)
|
||||
await self.dispatch('GUILD_MEMBERS_CHUNK', {
|
||||
'guild_id': str(guild_id),
|
||||
'members': result
|
||||
})
|
||||
await self.dispatch(
|
||||
"GUILD_MEMBERS_CHUNK", {"guild_id": str(guild_id), "members": result}
|
||||
)
|
||||
|
||||
async def handle_8(self, payload: Dict):
|
||||
"""Handle OP 8 Request Guild Members."""
|
||||
data = payload['d']
|
||||
gids = data['guild_id']
|
||||
data = payload["d"]
|
||||
gids = data["guild_id"]
|
||||
|
||||
uids, query, limit = data.get('user_ids', []), \
|
||||
data.get('query', ''), \
|
||||
data.get('limit', 0)
|
||||
uids, query, limit = (
|
||||
data.get("user_ids", []),
|
||||
data.get("query", ""),
|
||||
data.get("limit", 0),
|
||||
)
|
||||
|
||||
if isinstance(gids, str):
|
||||
await self._req_guild_members(gids, uids, query, limit)
|
||||
|
|
@ -820,23 +827,21 @@ class GatewayWebsocket:
|
|||
GUILD_SYNC event with that info.
|
||||
"""
|
||||
members = await self.storage.get_member_data(guild_id)
|
||||
member_ids = [int(m['user']['id']) for m in members]
|
||||
member_ids = [int(m["user"]["id"]) for m in members]
|
||||
|
||||
log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members')
|
||||
log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
|
||||
presences = await self.presence.guild_presences(member_ids, guild_id)
|
||||
|
||||
await self.dispatch('GUILD_SYNC', {
|
||||
'id': str(guild_id),
|
||||
'presences': presences,
|
||||
'members': members,
|
||||
})
|
||||
await self.dispatch(
|
||||
"GUILD_SYNC",
|
||||
{"id": str(guild_id), "presences": presences, "members": members},
|
||||
)
|
||||
|
||||
async def handle_12(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 12 Guild Sync."""
|
||||
data = payload['d']
|
||||
data = payload["d"]
|
||||
|
||||
gids = await self.user_storage.get_user_guilds(
|
||||
self.state.user_id)
|
||||
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||
|
||||
for guild_id in data:
|
||||
try:
|
||||
|
|
@ -931,35 +936,33 @@ class GatewayWebsocket:
|
|||
]
|
||||
}
|
||||
"""
|
||||
data = payload['d']
|
||||
data = payload["d"]
|
||||
|
||||
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||
guild_id = int(data['guild_id'])
|
||||
guild_id = int(data["guild_id"])
|
||||
|
||||
# make sure to not extract info you shouldn't get
|
||||
if guild_id not in gids:
|
||||
return
|
||||
|
||||
log.debug('lazy request: members: {}',
|
||||
data.get('members', []))
|
||||
log.debug("lazy request: members: {}", data.get("members", []))
|
||||
|
||||
# make shard query
|
||||
lazy_guilds = self.ext.dispatcher.backends['lazy_guild']
|
||||
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
|
||||
|
||||
for chan_id, ranges in data.get('channels', {}).items():
|
||||
for chan_id, ranges in data.get("channels", {}).items():
|
||||
chan_id = int(chan_id)
|
||||
member_list = await lazy_guilds.get_gml(chan_id)
|
||||
|
||||
perms = await get_permissions(
|
||||
self.state.user_id, chan_id, storage=self.storage)
|
||||
self.state.user_id, chan_id, storage=self.storage
|
||||
)
|
||||
|
||||
if not perms.bits.read_messages:
|
||||
# ignore requests to unknown channels
|
||||
return
|
||||
|
||||
await member_list.shard_query(
|
||||
self.state.session_id, ranges
|
||||
)
|
||||
await member_list.shard_query(self.state.session_id, ranges)
|
||||
|
||||
async def _handle_23(self, payload):
|
||||
# TODO reverse-engineer opcode 23, sent by client
|
||||
|
|
@ -968,21 +971,21 @@ class GatewayWebsocket:
|
|||
async def _process_message(self, payload):
|
||||
"""Process a single message coming in from the client."""
|
||||
try:
|
||||
op_code = payload['op']
|
||||
op_code = payload["op"]
|
||||
except KeyError:
|
||||
raise UnknownOPCode('No OP code')
|
||||
raise UnknownOPCode("No OP code")
|
||||
|
||||
try:
|
||||
handler = getattr(self, f'handle_{op_code}')
|
||||
handler = getattr(self, f"handle_{op_code}")
|
||||
except AttributeError:
|
||||
log.warning('Payload with bad op: {}', pprint.pformat(payload))
|
||||
raise UnknownOPCode(f'Bad OP code: {op_code}')
|
||||
log.warning("Payload with bad op: {}", pprint.pformat(payload))
|
||||
raise UnknownOPCode(f"Bad OP code: {op_code}")
|
||||
|
||||
await handler(payload)
|
||||
|
||||
async def _msg_ratelimit(self):
|
||||
if self._check_ratelimit('messages', self.state.session_id):
|
||||
raise WebsocketClose(4008, 'You are being ratelimited.')
|
||||
if self._check_ratelimit("messages", self.state.session_id):
|
||||
raise WebsocketClose(4008, "You are being ratelimited.")
|
||||
|
||||
async def _listen_messages(self):
|
||||
"""Listen for messages coming in from the websocket."""
|
||||
|
|
@ -990,15 +993,15 @@ class GatewayWebsocket:
|
|||
# close anyone trying to login while the
|
||||
# server is shutting down
|
||||
if self.ext.state_manager.closed:
|
||||
raise WebsocketClose(4000, 'state manager closed')
|
||||
raise WebsocketClose(4000, "state manager closed")
|
||||
|
||||
if not self.ext.state_manager.accept_new:
|
||||
raise WebsocketClose(4000, 'state manager closed for new')
|
||||
raise WebsocketClose(4000, "state manager closed for new")
|
||||
|
||||
while True:
|
||||
message = await self.ws.recv()
|
||||
if len(message) > 4096:
|
||||
raise DecodeError('Payload length exceeded')
|
||||
raise DecodeError("Payload length exceeded")
|
||||
|
||||
if self.state:
|
||||
await self._msg_ratelimit()
|
||||
|
|
@ -1033,17 +1036,9 @@ class GatewayWebsocket:
|
|||
|
||||
# there arent any other states with websocket
|
||||
if not with_ws:
|
||||
offline = {
|
||||
'afk': False,
|
||||
'status': 'offline',
|
||||
'game': None,
|
||||
'since': 0,
|
||||
}
|
||||
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||
|
||||
await self.ext.presence.dispatch_pres(
|
||||
user_id,
|
||||
offline
|
||||
)
|
||||
await self.ext.presence.dispatch_pres(user_id, offline)
|
||||
|
||||
async def run(self):
|
||||
"""Wrap :meth:`listen_messages` inside
|
||||
|
|
@ -1052,12 +1047,12 @@ class GatewayWebsocket:
|
|||
await self._send_hello()
|
||||
await self._listen_messages()
|
||||
except websockets.exceptions.ConnectionClosed as err:
|
||||
log.warning('conn close, state={}, err={}', self.state, err)
|
||||
log.warning("conn close, state={}, err={}", self.state, err)
|
||||
except WebsocketClose as err:
|
||||
log.warning('ws close, state={} err={}', self.state, err)
|
||||
log.warning("ws close, state={} err={}", self.state, err)
|
||||
await self.ws.close(code=err.code, reason=err.reason)
|
||||
except Exception as err:
|
||||
log.exception('An exception has occoured. state={}', self.state)
|
||||
log.exception("An exception has occoured. state={}", self.state)
|
||||
await self.ws.close(code=4000, reason=repr(err))
|
||||
finally:
|
||||
user_id = self.state.user_id if self.state else None
|
||||
|
|
|
|||
|
|
@ -17,19 +17,21 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
|
||||
class GuildMemoryStore:
|
||||
"""Store in-memory properties about guilds.
|
||||
|
||||
I could have just used Redis... probably too overkill to add
|
||||
aioredis to the already long depedency list, plus, I don't need
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
|
||||
def get(self, guild_id: int, attribute: str, default=None):
|
||||
"""get a key"""
|
||||
return self._store.get(f'{guild_id}:{attribute}', default)
|
||||
return self._store.get(f"{guild_id}:{attribute}", default)
|
||||
|
||||
def set(self, guild_id: int, attribute: str, value):
|
||||
"""set a key"""
|
||||
self._store[f'{guild_id}:{attribute}'] = value
|
||||
self._store[f"{guild_id}:{attribute}"] = value
|
||||
|
|
|
|||
|
|
@ -33,47 +33,42 @@ from logbook import Logger
|
|||
from PIL import Image
|
||||
|
||||
|
||||
IMAGE_FOLDER = Path('./images')
|
||||
IMAGE_FOLDER = Path("./images")
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
EXTENSIONS = {
|
||||
'image/jpeg': 'jpeg',
|
||||
'image/webp': 'webp'
|
||||
}
|
||||
EXTENSIONS = {"image/jpeg": "jpeg", "image/webp": "webp"}
|
||||
|
||||
|
||||
MIMES = {
|
||||
'jpg': 'image/jpeg',
|
||||
'jpe': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'webp': 'image/webp',
|
||||
"jpg": "image/jpeg",
|
||||
"jpe": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"webp": "image/webp",
|
||||
}
|
||||
|
||||
STATIC_IMAGE_MIMES = [
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/webp'
|
||||
]
|
||||
STATIC_IMAGE_MIMES = ["image/png", "image/jpeg", "image/webp"]
|
||||
|
||||
|
||||
def get_ext(mime: str) -> str:
|
||||
if mime in EXTENSIONS:
|
||||
return EXTENSIONS[mime]
|
||||
|
||||
extensions = mimetypes.guess_all_extensions(mime)
|
||||
return extensions[0].strip('.')
|
||||
return extensions[0].strip(".")
|
||||
|
||||
|
||||
def get_mime(ext: str):
|
||||
if ext in MIMES:
|
||||
return MIMES[ext]
|
||||
|
||||
return mimetypes.types_map[f'.{ext}']
|
||||
return mimetypes.types_map[f".{ext}"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Icon:
|
||||
"""Main icon class"""
|
||||
|
||||
key: Optional[str]
|
||||
icon_hash: Optional[str]
|
||||
mime: Optional[str]
|
||||
|
|
@ -85,7 +80,7 @@ class Icon:
|
|||
return None
|
||||
|
||||
ext = get_ext(self.mime)
|
||||
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
|
||||
return str(IMAGE_FOLDER / f"{self.key}_{self.icon_hash}.{ext}")
|
||||
|
||||
@property
|
||||
def as_pathlib(self) -> Optional[Path]:
|
||||
|
|
@ -106,13 +101,14 @@ class Icon:
|
|||
|
||||
class ImageError(Exception):
|
||||
"""Image error class."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def to_raw(data_type: str, data: str) -> Optional[bytes]:
|
||||
"""Given a data type in the data URI and data,
|
||||
give the raw bytes being encoded."""
|
||||
if data_type == 'base64':
|
||||
if data_type == "base64":
|
||||
return base64.b64decode(data)
|
||||
|
||||
return None
|
||||
|
|
@ -136,7 +132,7 @@ def _calculate_hash(fhandler) -> str:
|
|||
"""
|
||||
hash_obj = sha256()
|
||||
|
||||
for chunk in iter(lambda: fhandler.read(4096), b''):
|
||||
for chunk in iter(lambda: fhandler.read(4096), b""):
|
||||
hash_obj.update(chunk)
|
||||
|
||||
# so that we can reuse the same handler
|
||||
|
|
@ -162,39 +158,36 @@ async def calculate_hash(fhandle, loop=None) -> str:
|
|||
def parse_data_uri(string) -> tuple:
|
||||
"""Extract image data."""
|
||||
try:
|
||||
header, headered_data = string.split(';')
|
||||
header, headered_data = string.split(";")
|
||||
|
||||
_, given_mime = header.split(':')
|
||||
data_type, data = headered_data.split(',')
|
||||
_, given_mime = header.split(":")
|
||||
data_type, data = headered_data.split(",")
|
||||
|
||||
raw_data = to_raw(data_type, data)
|
||||
if raw_data is None:
|
||||
raise ImageError('Unknown data header')
|
||||
raise ImageError("Unknown data header")
|
||||
|
||||
return given_mime, raw_data
|
||||
except ValueError:
|
||||
raise ImageError('data URI invalid syntax')
|
||||
raise ImageError("data URI invalid syntax")
|
||||
|
||||
|
||||
def _gen_update_sql(scope: str) -> str:
|
||||
# match a scope to (table, field)
|
||||
field = {
|
||||
'user': 'avatar',
|
||||
'guild': 'icon',
|
||||
'splash': 'splash',
|
||||
'banner': 'banner',
|
||||
|
||||
'channel-icons': 'icon',
|
||||
"user": "avatar",
|
||||
"guild": "icon",
|
||||
"splash": "splash",
|
||||
"banner": "banner",
|
||||
"channel-icons": "icon",
|
||||
}[scope]
|
||||
|
||||
table = {
|
||||
'user': 'users',
|
||||
|
||||
'guild': 'guilds',
|
||||
'splash': 'guilds',
|
||||
'banner': 'guilds',
|
||||
|
||||
'channel-icons': 'group_dm_channels'
|
||||
"user": "users",
|
||||
"guild": "guilds",
|
||||
"splash": "guilds",
|
||||
"banner": "guilds",
|
||||
"channel-icons": "group_dm_channels",
|
||||
}[scope]
|
||||
|
||||
return f"""
|
||||
|
|
@ -204,10 +197,10 @@ def _gen_update_sql(scope: str) -> str:
|
|||
|
||||
def _invalid(kwargs: dict) -> Optional[Icon]:
|
||||
"""Send an invalid value."""
|
||||
if not kwargs.get('always_icon', False):
|
||||
if not kwargs.get("always_icon", False):
|
||||
return None
|
||||
|
||||
return Icon(None, None, '')
|
||||
return Icon(None, None, "")
|
||||
|
||||
|
||||
def try_unlink(path: Union[Path, str]):
|
||||
|
|
@ -225,18 +218,17 @@ def try_unlink(path: Union[Path, str]):
|
|||
async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
||||
"""Resize a GIF image."""
|
||||
# generate a temporary file to call gifsticle to and from.
|
||||
input_fd, input_path = tempfile.mkstemp(suffix='.gif')
|
||||
_, output_path = tempfile.mkstemp(suffix='.gif')
|
||||
input_fd, input_path = tempfile.mkstemp(suffix=".gif")
|
||||
_, output_path = tempfile.mkstemp(suffix=".gif")
|
||||
|
||||
input_handler = os.fdopen(input_fd, 'wb')
|
||||
input_handler = os.fdopen(input_fd, "wb")
|
||||
|
||||
# make sure its valid image data
|
||||
data_fd = BytesIO(raw_data)
|
||||
image = Image.open(data_fd)
|
||||
image.close()
|
||||
|
||||
log.info('resizing a GIF from {} to {}',
|
||||
image.size, target)
|
||||
log.info("resizing a GIF from {} to {}", image.size, target)
|
||||
|
||||
# insert image info on input_handler
|
||||
# close it to make it ready for consumption by gifsicle
|
||||
|
|
@ -244,12 +236,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
|||
input_handler.close()
|
||||
|
||||
# call gifsicle under subprocess
|
||||
log.debug('input: {}', input_path)
|
||||
log.debug('output: {}', output_path)
|
||||
log.debug("input: {}", input_path)
|
||||
log.debug("output: {}", output_path)
|
||||
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
f'gifsicle --resize {target[0]}x{target[1]} '
|
||||
f'{input_path} > {output_path}',
|
||||
f"gifsicle --resize {target[0]}x{target[1]} " f"{input_path} > {output_path}",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
|
@ -257,11 +248,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
|||
# run it, etc.
|
||||
out, err = await process.communicate()
|
||||
|
||||
log.debug('out + err from gifsicle: {}', out + err)
|
||||
log.debug("out + err from gifsicle: {}", out + err)
|
||||
|
||||
# write over an empty data_fd
|
||||
data_fd = BytesIO()
|
||||
output_handler = open(output_path, 'rb')
|
||||
output_handler = open(output_path, "rb")
|
||||
data_fd.write(output_handler.read())
|
||||
|
||||
# close unused handlers
|
||||
|
|
@ -283,40 +274,40 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
|||
|
||||
class IconManager:
|
||||
"""Main icon manager."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.storage = app.storage
|
||||
|
||||
async def _convert_ext(self, icon: Icon, target: str):
|
||||
target = 'jpeg' if target == 'jpg' else target
|
||||
target = "jpeg" if target == "jpg" else target
|
||||
|
||||
target_mime = get_mime(target)
|
||||
log.info('converting from {} to {}', icon.mime, target_mime)
|
||||
log.info("converting from {} to {}", icon.mime, target_mime)
|
||||
|
||||
target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}'
|
||||
target_path = IMAGE_FOLDER / f"{icon.key}_{icon.icon_hash}.{target}"
|
||||
|
||||
if target_path.exists():
|
||||
return Icon(icon.key, icon.icon_hash, target_mime)
|
||||
|
||||
image = Image.open(icon.as_path)
|
||||
target_fd = target_path.open('wb')
|
||||
target_fd = target_path.open("wb")
|
||||
|
||||
if target == 'jpeg':
|
||||
image = image.convert('RGB')
|
||||
if target == "jpeg":
|
||||
image = image.convert("RGB")
|
||||
|
||||
image.save(target_fd, format=target)
|
||||
target_fd.close()
|
||||
|
||||
return Icon(icon.key, icon.icon_hash, target_mime)
|
||||
|
||||
async def generic_get(self, scope, key, icon_hash,
|
||||
**kwargs) -> Optional[Icon]:
|
||||
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Optional[Icon]:
|
||||
"""Get any icon."""
|
||||
|
||||
log.debug('GET {} {} {}', scope, key, icon_hash)
|
||||
log.debug("GET {} {} {}", scope, key, icon_hash)
|
||||
key = str(key)
|
||||
|
||||
hash_query = 'AND hash = $3' if icon_hash else ''
|
||||
hash_query = "AND hash = $3" if icon_hash else ""
|
||||
|
||||
# hacky solution to only add icon_hash
|
||||
# when needed.
|
||||
|
|
@ -325,18 +316,21 @@ class IconManager:
|
|||
if icon_hash:
|
||||
args.append(icon_hash)
|
||||
|
||||
icon_row = await self.storage.db.fetchrow(f"""
|
||||
icon_row = await self.storage.db.fetchrow(
|
||||
f"""
|
||||
SELECT key, hash, mime
|
||||
FROM icons
|
||||
WHERE scope = $1
|
||||
AND key = $2
|
||||
{hash_query}
|
||||
""", *args)
|
||||
""",
|
||||
*args,
|
||||
)
|
||||
|
||||
if not icon_row:
|
||||
return None
|
||||
|
||||
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
|
||||
icon = Icon(icon_row["key"], icon_row["hash"], icon_row["mime"])
|
||||
|
||||
# ensure we aren't messing with NULLs everywhere.
|
||||
if icon.as_pathlib is None:
|
||||
|
|
@ -349,18 +343,16 @@ class IconManager:
|
|||
if icon.extension is None:
|
||||
return None
|
||||
|
||||
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
|
||||
return await self._convert_ext(icon, kwargs['ext'])
|
||||
if "ext" in kwargs and kwargs["ext"] != icon.extension:
|
||||
return await self._convert_ext(icon, kwargs["ext"])
|
||||
|
||||
return icon
|
||||
|
||||
async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs):
|
||||
"""Get an icon for a guild."""
|
||||
return await self.generic_get(
|
||||
'guild', guild_id, icon_hash, **kwargs)
|
||||
return await self.generic_get("guild", guild_id, icon_hash, **kwargs)
|
||||
|
||||
async def put(self, scope: str, key: str,
|
||||
b64_data: str, **kwargs) -> Icon:
|
||||
async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon:
|
||||
"""Insert an icon."""
|
||||
if b64_data is None:
|
||||
return _invalid(kwargs)
|
||||
|
|
@ -373,23 +365,22 @@ class IconManager:
|
|||
# get an extension for the given data uri
|
||||
extension = get_ext(mime)
|
||||
|
||||
if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']:
|
||||
if "bsize" in kwargs and len(raw_data) > kwargs["bsize"]:
|
||||
return _invalid(kwargs)
|
||||
|
||||
# size management is different for gif files
|
||||
# as they're composed of multiple frames.
|
||||
if 'size' in kwargs and mime == 'image/gif':
|
||||
data_fd, raw_data = await resize_gif(raw_data, kwargs['size'])
|
||||
elif 'size' in kwargs:
|
||||
if "size" in kwargs and mime == "image/gif":
|
||||
data_fd, raw_data = await resize_gif(raw_data, kwargs["size"])
|
||||
elif "size" in kwargs:
|
||||
image = Image.open(data_fd)
|
||||
|
||||
if mime == 'image/jpeg':
|
||||
if mime == "image/jpeg":
|
||||
image = image.convert("RGB")
|
||||
|
||||
want = kwargs['size']
|
||||
want = kwargs["size"]
|
||||
|
||||
log.info('resizing from {} to {}',
|
||||
image.size, want)
|
||||
log.info("resizing from {} to {}", image.size, want)
|
||||
|
||||
resized = image.resize(want, resample=Image.LANCZOS)
|
||||
|
||||
|
|
@ -404,23 +395,26 @@ class IconManager:
|
|||
|
||||
# calculate sha256
|
||||
# ignore icon hashes if we're talking about emoji
|
||||
icon_hash = (await calculate_hash(data_fd)
|
||||
if scope != 'emoji'
|
||||
else None)
|
||||
icon_hash = await calculate_hash(data_fd) if scope != "emoji" else None
|
||||
|
||||
if scope == 'user' and mime == 'image/gif':
|
||||
icon_hash = f'a_{icon_hash}'
|
||||
if scope == "user" and mime == "image/gif":
|
||||
icon_hash = f"a_{icon_hash}"
|
||||
|
||||
log.debug('PUT icon {!r} {!r} {!r} {!r}',
|
||||
scope, key, icon_hash, mime)
|
||||
log.debug("PUT icon {!r} {!r} {!r} {!r}", scope, key, icon_hash, mime)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
INSERT INTO icons (scope, key, hash, mime)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
""", scope, str(key), icon_hash, mime)
|
||||
""",
|
||||
scope,
|
||||
str(key),
|
||||
icon_hash,
|
||||
mime,
|
||||
)
|
||||
|
||||
# write it off to fs
|
||||
icon_path = IMAGE_FOLDER / f'{key}_{icon_hash}.{extension}'
|
||||
icon_path = IMAGE_FOLDER / f"{key}_{icon_hash}.{extension}"
|
||||
icon_path.write_bytes(raw_data)
|
||||
|
||||
# copy from data_fd to icon_fd
|
||||
|
|
@ -434,57 +428,80 @@ class IconManager:
|
|||
if not icon:
|
||||
return
|
||||
|
||||
log.debug('DEL {}',
|
||||
icon)
|
||||
log.debug("DEL {}", icon)
|
||||
|
||||
# dereference
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET avatar = NULL
|
||||
WHERE avatar = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
UPDATE group_dm_channels
|
||||
SET icon = NULL
|
||||
WHERE icon = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
DELETE FROM guild_emoji
|
||||
WHERE image = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET icon = NULL
|
||||
WHERE icon = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET splash = NULL
|
||||
WHERE splash = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
UPDATE guilds
|
||||
SET banner = NULL
|
||||
WHERE banner = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
UPDATE group_dm_channels
|
||||
SET icon = NULL
|
||||
WHERE icon = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
await self.storage.db.execute("""
|
||||
await self.storage.db.execute(
|
||||
"""
|
||||
DELETE FROM icons
|
||||
WHERE hash = $1
|
||||
""", icon.icon_hash)
|
||||
""",
|
||||
icon.icon_hash,
|
||||
)
|
||||
|
||||
paths = IMAGE_FOLDER.glob(f'{icon.key}_{icon.icon_hash}.*')
|
||||
paths = IMAGE_FOLDER.glob(f"{icon.key}_{icon.icon_hash}.*")
|
||||
|
||||
for path in paths:
|
||||
try:
|
||||
|
|
@ -492,11 +509,9 @@ class IconManager:
|
|||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
async def update(self, scope: str, key: str,
|
||||
new_icon_data: str, **kwargs) -> Icon:
|
||||
async def update(self, scope: str, key: str, new_icon_data: str, **kwargs) -> Icon:
|
||||
"""Update an icon on a key."""
|
||||
old_icon_hash = await self.storage.db.fetchval(
|
||||
_gen_update_sql(scope), key)
|
||||
old_icon_hash = await self.storage.db.fetchval(_gen_update_sql(scope), key)
|
||||
|
||||
# converting key to str only here since from here onwards
|
||||
# its operations on the icons table (or a dereference with
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
import asyncio
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ class JobManager:
|
|||
use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses
|
||||
its own internal list of jobs.
|
||||
"""
|
||||
|
||||
def __init__(self, loop=None):
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
self.jobs = []
|
||||
|
|
@ -41,13 +43,11 @@ class JobManager:
|
|||
try:
|
||||
await coro
|
||||
except Exception:
|
||||
log.exception('Error while running job')
|
||||
log.exception("Error while running job")
|
||||
|
||||
def spawn(self, coro):
|
||||
"""Spawn a given future or coroutine in the background."""
|
||||
task = self.loop.create_task(
|
||||
self._wrapper(coro)
|
||||
)
|
||||
task = self.loop.create_task(self._wrapper(coro))
|
||||
|
||||
self.jobs.append(task)
|
||||
|
||||
|
|
|
|||
|
|
@ -26,40 +26,42 @@ from quart import current_app as app
|
|||
# type for all the fields
|
||||
_i = ctypes.c_uint8
|
||||
|
||||
|
||||
class _RawPermsBits(ctypes.LittleEndianStructure):
|
||||
"""raw bitfield for discord's permission number."""
|
||||
|
||||
_fields_ = [
|
||||
('create_invites', _i, 1),
|
||||
('kick_members', _i, 1),
|
||||
('ban_members', _i, 1),
|
||||
('administrator', _i, 1),
|
||||
('manage_channels', _i, 1),
|
||||
('manage_guild', _i, 1),
|
||||
('add_reactions', _i, 1),
|
||||
('view_audit_log', _i, 1),
|
||||
('priority_speaker', _i, 1),
|
||||
('stream', _i, 1),
|
||||
('read_messages', _i, 1),
|
||||
('send_messages', _i, 1),
|
||||
('send_tts', _i, 1),
|
||||
('manage_messages', _i, 1),
|
||||
('embed_links', _i, 1),
|
||||
('attach_files', _i, 1),
|
||||
('read_history', _i, 1),
|
||||
('mention_everyone', _i, 1),
|
||||
('external_emojis', _i, 1),
|
||||
('_unused2', _i, 1),
|
||||
('connect', _i, 1),
|
||||
('speak', _i, 1),
|
||||
('mute_members', _i, 1),
|
||||
('deafen_members', _i, 1),
|
||||
('move_members', _i, 1),
|
||||
('use_voice_activation', _i, 1),
|
||||
('change_nickname', _i, 1),
|
||||
('manage_nicknames', _i, 1),
|
||||
('manage_roles', _i, 1),
|
||||
('manage_webhooks', _i, 1),
|
||||
('manage_emojis', _i, 1),
|
||||
("create_invites", _i, 1),
|
||||
("kick_members", _i, 1),
|
||||
("ban_members", _i, 1),
|
||||
("administrator", _i, 1),
|
||||
("manage_channels", _i, 1),
|
||||
("manage_guild", _i, 1),
|
||||
("add_reactions", _i, 1),
|
||||
("view_audit_log", _i, 1),
|
||||
("priority_speaker", _i, 1),
|
||||
("stream", _i, 1),
|
||||
("read_messages", _i, 1),
|
||||
("send_messages", _i, 1),
|
||||
("send_tts", _i, 1),
|
||||
("manage_messages", _i, 1),
|
||||
("embed_links", _i, 1),
|
||||
("attach_files", _i, 1),
|
||||
("read_history", _i, 1),
|
||||
("mention_everyone", _i, 1),
|
||||
("external_emojis", _i, 1),
|
||||
("_unused2", _i, 1),
|
||||
("connect", _i, 1),
|
||||
("speak", _i, 1),
|
||||
("mute_members", _i, 1),
|
||||
("deafen_members", _i, 1),
|
||||
("move_members", _i, 1),
|
||||
("use_voice_activation", _i, 1),
|
||||
("change_nickname", _i, 1),
|
||||
("manage_nicknames", _i, 1),
|
||||
("manage_roles", _i, 1),
|
||||
("manage_webhooks", _i, 1),
|
||||
("manage_emojis", _i, 1),
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -72,16 +74,14 @@ class Permissions(ctypes.Union):
|
|||
val
|
||||
The permissions value as an integer.
|
||||
"""
|
||||
_fields_ = [
|
||||
('bits', _RawPermsBits),
|
||||
('binary', ctypes.c_uint64),
|
||||
]
|
||||
|
||||
_fields_ = [("bits", _RawPermsBits), ("binary", ctypes.c_uint64)]
|
||||
|
||||
def __init__(self, val: int):
|
||||
self.binary = val
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Permissions binary={self.binary}>'
|
||||
return f"<Permissions binary={self.binary}>"
|
||||
|
||||
def __int__(self):
|
||||
return self.binary
|
||||
|
|
@ -95,11 +95,15 @@ async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
|
|||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
perms = await storage.db.fetchval("""
|
||||
perms = await storage.db.fetchval(
|
||||
"""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE guild_id = $1 AND id = $2
|
||||
""", guild_id, role_id)
|
||||
""",
|
||||
guild_id,
|
||||
role_id,
|
||||
)
|
||||
|
||||
return Permissions(perms)
|
||||
|
||||
|
|
@ -118,11 +122,14 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
|||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
owner_id = await storage.db.fetchval("""
|
||||
owner_id = await storage.db.fetchval(
|
||||
"""
|
||||
SELECT owner_id
|
||||
FROM guilds
|
||||
WHERE id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if owner_id == member_id:
|
||||
return ALL_PERMISSIONS
|
||||
|
|
@ -130,20 +137,27 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
|||
# get permissions for @everyone
|
||||
permissions = await get_role_perms(guild_id, guild_id, storage)
|
||||
|
||||
role_ids = await storage.db.fetch("""
|
||||
role_ids = await storage.db.fetch(
|
||||
"""
|
||||
SELECT role_id
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
""",
|
||||
guild_id,
|
||||
member_id,
|
||||
)
|
||||
|
||||
role_perms = []
|
||||
|
||||
for row in role_ids:
|
||||
rperm = await storage.db.fetchval("""
|
||||
rperm = await storage.db.fetchval(
|
||||
"""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE id = $1
|
||||
""", row['role_id'])
|
||||
""",
|
||||
row["role_id"],
|
||||
)
|
||||
|
||||
role_perms.append(rperm)
|
||||
|
||||
|
|
@ -164,16 +178,17 @@ def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
|
|||
result = perms.binary
|
||||
|
||||
# negate the permissions that are denied
|
||||
result &= ~overwrite['deny']
|
||||
result &= ~overwrite["deny"]
|
||||
|
||||
# combine the permissions that are allowed
|
||||
result |= overwrite['allow']
|
||||
result |= overwrite["allow"]
|
||||
|
||||
return Permissions(result)
|
||||
|
||||
|
||||
def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
||||
target_id: int) -> Permissions:
|
||||
def overwrite_find_mix(
|
||||
perms: Permissions, overwrites: dict, target_id: int
|
||||
) -> Permissions:
|
||||
"""Mix a given permission with a given overwrite.
|
||||
|
||||
Returns the given permission if an overwrite is not found.
|
||||
|
|
@ -201,19 +216,25 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
|||
return perms
|
||||
|
||||
|
||||
async def role_permissions(guild_id: int, role_id: int,
|
||||
channel_id: int, storage=None) -> Permissions:
|
||||
async def role_permissions(
|
||||
guild_id: int, role_id: int, channel_id: int, storage=None
|
||||
) -> Permissions:
|
||||
"""Get the permissions for a role, in relation to a channel"""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
perms = await get_role_perms(guild_id, role_id, storage)
|
||||
|
||||
overwrite = await storage.db.fetchrow("""
|
||||
overwrite = await storage.db.fetchrow(
|
||||
"""
|
||||
SELECT allow, deny
|
||||
FROM channel_overwrites
|
||||
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
|
||||
""", channel_id, 1, role_id)
|
||||
""",
|
||||
channel_id,
|
||||
1,
|
||||
role_id,
|
||||
)
|
||||
|
||||
if overwrite:
|
||||
perms = overwrite_mix(perms, overwrite)
|
||||
|
|
@ -221,10 +242,13 @@ async def role_permissions(guild_id: int, role_id: int,
|
|||
return perms
|
||||
|
||||
|
||||
async def compute_overwrites(base_perms: Permissions,
|
||||
user_id, channel_id: int,
|
||||
guild_id: Optional[int] = None,
|
||||
storage=None):
|
||||
async def compute_overwrites(
|
||||
base_perms: Permissions,
|
||||
user_id,
|
||||
channel_id: int,
|
||||
guild_id: Optional[int] = None,
|
||||
storage=None,
|
||||
):
|
||||
"""Compute the permissions in the context of a channel."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
|
@ -245,7 +269,7 @@ async def compute_overwrites(base_perms: Permissions,
|
|||
return ALL_PERMISSIONS
|
||||
|
||||
# make it a map for better usage
|
||||
overwrites = {int(o['id']): o for o in overwrites}
|
||||
overwrites = {int(o["id"]): o for o in overwrites}
|
||||
|
||||
perms = overwrite_find_mix(perms, overwrites, guild_id)
|
||||
|
||||
|
|
@ -260,14 +284,11 @@ async def compute_overwrites(base_perms: Permissions,
|
|||
for role_id in role_ids:
|
||||
overwrite = overwrites.get(role_id)
|
||||
if overwrite:
|
||||
allow |= overwrite['allow']
|
||||
deny |= overwrite['deny']
|
||||
allow |= overwrite["allow"]
|
||||
deny |= overwrite["deny"]
|
||||
|
||||
# final step for roles: mix
|
||||
perms = overwrite_mix(perms, {
|
||||
'allow': allow,
|
||||
'deny': deny
|
||||
})
|
||||
perms = overwrite_mix(perms, {"allow": allow, "deny": deny})
|
||||
|
||||
# apply member specific overwrites
|
||||
perms = overwrite_find_mix(perms, overwrites, user_id)
|
||||
|
|
@ -275,8 +296,7 @@ async def compute_overwrites(base_perms: Permissions,
|
|||
return perms
|
||||
|
||||
|
||||
async def get_permissions(member_id: int, channel_id,
|
||||
*, storage=None) -> Permissions:
|
||||
async def get_permissions(member_id: int, channel_id, *, storage=None) -> Permissions:
|
||||
"""Get the permissions for a user in a channel."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
|
@ -290,4 +310,5 @@ async def get_permissions(member_id: int, channel_id,
|
|||
base_perms = await base_permissions(member_id, guild_id, storage)
|
||||
|
||||
return await compute_overwrites(
|
||||
base_perms, member_id, channel_id, guild_id, storage)
|
||||
base_perms, member_id, channel_id, guild_id, storage
|
||||
)
|
||||
|
|
|
|||
|
|
@ -32,62 +32,56 @@ def status_cmp(status: str, other_status: str) -> bool:
|
|||
in the status hierarchy.
|
||||
"""
|
||||
|
||||
hierarchy = {
|
||||
'online': 3,
|
||||
'idle': 2,
|
||||
'dnd': 1,
|
||||
'offline': 0,
|
||||
None: -1,
|
||||
}
|
||||
hierarchy = {"online": 3, "idle": 2, "dnd": 1, "offline": 0, None: -1}
|
||||
|
||||
return hierarchy[status] > hierarchy[other_status]
|
||||
|
||||
|
||||
def _best_presence(shards):
|
||||
"""Find the 'best' presence given a list of GatewayState."""
|
||||
best = {'status': None, 'game': None}
|
||||
best = {"status": None, "game": None}
|
||||
|
||||
for state in shards:
|
||||
presence = state.presence
|
||||
|
||||
status = presence['status']
|
||||
status = presence["status"]
|
||||
|
||||
if not presence:
|
||||
continue
|
||||
|
||||
# shards with a better status
|
||||
# in the hierarchy are treated as best
|
||||
if status_cmp(status, best['status']):
|
||||
best['status'] = status
|
||||
if status_cmp(status, best["status"]):
|
||||
best["status"] = status
|
||||
|
||||
# if we have any game, use it
|
||||
if presence['game'] is not None:
|
||||
best['game'] = presence['game']
|
||||
if presence["game"] is not None:
|
||||
best["game"] = presence["game"]
|
||||
|
||||
# best['status'] is None when no
|
||||
# status was good enough.
|
||||
return None if not best['status'] else best
|
||||
return None if not best["status"] else best
|
||||
|
||||
|
||||
def fill_presence(presence: dict, *, game=None) -> dict:
|
||||
"""Fill a given presence object with some specific fields."""
|
||||
presence['client_status'] = {}
|
||||
presence['mobile'] = False
|
||||
presence["client_status"] = {}
|
||||
presence["mobile"] = False
|
||||
|
||||
if 'since' not in presence:
|
||||
presence['since'] = 0
|
||||
if "since" not in presence:
|
||||
presence["since"] = 0
|
||||
|
||||
# fill game and activities array depending if game
|
||||
# is there or not
|
||||
game = game or presence.get('game')
|
||||
game = game or presence.get("game")
|
||||
|
||||
# casting to bool since a game of {} is still invalid
|
||||
if game:
|
||||
presence['game'] = game
|
||||
presence['activities'] = [game]
|
||||
presence["game"] = game
|
||||
presence["activities"] = [game]
|
||||
else:
|
||||
presence['game'] = None
|
||||
presence['activities'] = []
|
||||
presence["game"] = None
|
||||
presence["activities"] = []
|
||||
|
||||
return presence
|
||||
|
||||
|
|
@ -96,14 +90,13 @@ async def _pres(storage, user_id: int, status_obj: dict) -> dict:
|
|||
"""Convert a given status into a presence, given the User ID and the
|
||||
:class:`Storage` instance."""
|
||||
ext = {
|
||||
'user': await storage.get_user(user_id),
|
||||
'activities': [],
|
||||
|
||||
"user": await storage.get_user(user_id),
|
||||
"activities": [],
|
||||
# NOTE: we are purposefully overwriting the fields, as there
|
||||
# isn't any push for us to actually implement mobile detection, or
|
||||
# web detection, etc.
|
||||
'client_status': {},
|
||||
'mobile': False,
|
||||
"client_status": {},
|
||||
"mobile": False,
|
||||
}
|
||||
|
||||
return fill_presence({**status_obj, **ext})
|
||||
|
|
@ -115,14 +108,16 @@ class PresenceManager:
|
|||
Has common functions to deal with fetching or updating presences, including
|
||||
side-effects (events).
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.storage = app.storage
|
||||
self.user_storage = app.user_storage
|
||||
self.state_manager = app.state_manager
|
||||
self.dispatcher = app.dispatcher
|
||||
|
||||
async def guild_presences(self, member_ids: List[int],
|
||||
guild_id: int) -> List[Dict[Any, str]]:
|
||||
async def guild_presences(
|
||||
self, member_ids: List[int], guild_id: int
|
||||
) -> List[Dict[Any, str]]:
|
||||
"""Fetch all presences in a guild."""
|
||||
# this works via fetching all connected GatewayState on a guild
|
||||
# then fetching its respective member and merging that info with
|
||||
|
|
@ -132,34 +127,36 @@ class PresenceManager:
|
|||
presences = []
|
||||
|
||||
for state in states:
|
||||
member = await self.storage.get_member_data_one(
|
||||
guild_id, state.user_id)
|
||||
member = await self.storage.get_member_data_one(guild_id, state.user_id)
|
||||
|
||||
game = state.presence.get('game', None)
|
||||
game = state.presence.get("game", None)
|
||||
|
||||
# only use the data we need.
|
||||
presences.append(fill_presence({
|
||||
'user': member['user'],
|
||||
'roles': member['roles'],
|
||||
'guild_id': str(guild_id),
|
||||
|
||||
# if a state is connected to the guild
|
||||
# we assume its online.
|
||||
'status': state.presence.get('status', 'online'),
|
||||
}, game=game))
|
||||
presences.append(
|
||||
fill_presence(
|
||||
{
|
||||
"user": member["user"],
|
||||
"roles": member["roles"],
|
||||
"guild_id": str(guild_id),
|
||||
# if a state is connected to the guild
|
||||
# we assume its online.
|
||||
"status": state.presence.get("status", "online"),
|
||||
},
|
||||
game=game,
|
||||
)
|
||||
)
|
||||
|
||||
return presences
|
||||
|
||||
async def dispatch_guild_pres(self, guild_id: int,
|
||||
user_id: int, new_state: dict):
|
||||
async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
|
||||
"""Dispatch a Presence update to an entire guild."""
|
||||
state = dict(new_state)
|
||||
|
||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||
|
||||
game = state['game']
|
||||
game = state["game"]
|
||||
|
||||
lazy_guild_store = self.dispatcher.backends['lazy_guild']
|
||||
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||
|
||||
# shards that are in lazy guilds with 'everyone'
|
||||
|
|
@ -168,49 +165,44 @@ class PresenceManager:
|
|||
|
||||
for member_list in lists:
|
||||
session_ids = await member_list.pres_update(
|
||||
int(member['user']['id']),
|
||||
{
|
||||
'roles': member['roles'],
|
||||
'status': state['status'],
|
||||
'game': game
|
||||
}
|
||||
int(member["user"]["id"]),
|
||||
{"roles": member["roles"], "status": state["status"], "game": game},
|
||||
)
|
||||
|
||||
log.debug('Lazy Dispatch to {}',
|
||||
len(session_ids))
|
||||
log.debug("Lazy Dispatch to {}", len(session_ids))
|
||||
|
||||
# if we are on the 'everyone' member list, we don't
|
||||
# dispatch a PRESENCE_UPDATE for those shards.
|
||||
if member_list.channel_id == member_list.guild_id:
|
||||
in_lazy.extend(session_ids)
|
||||
|
||||
pres_update_payload = fill_presence({
|
||||
'guild_id': str(guild_id),
|
||||
'user': member['user'],
|
||||
'roles': member['roles'],
|
||||
'status': state['status'],
|
||||
}, game=game)
|
||||
pres_update_payload = fill_presence(
|
||||
{
|
||||
"guild_id": str(guild_id),
|
||||
"user": member["user"],
|
||||
"roles": member["roles"],
|
||||
"status": state["status"],
|
||||
},
|
||||
game=game,
|
||||
)
|
||||
|
||||
# given a session id, return if the session id actually connects to
|
||||
# a given user, and if the state has not been dispatched via lazy guild.
|
||||
def _session_check(session_id):
|
||||
state = self.state_manager.fetch_raw(session_id)
|
||||
uid = int(member['user']['id'])
|
||||
uid = int(member["user"]["id"])
|
||||
|
||||
if not state:
|
||||
return False
|
||||
|
||||
# we don't want to send a presence update
|
||||
# to the same user
|
||||
return (state.user_id != uid and
|
||||
session_id not in in_lazy)
|
||||
return state.user_id != uid and session_id not in in_lazy
|
||||
|
||||
# everyone not in lazy guild mode
|
||||
# gets a PRESENCE_UPDATE
|
||||
await self.dispatcher.dispatch_filter(
|
||||
'guild', guild_id,
|
||||
_session_check,
|
||||
'PRESENCE_UPDATE', pres_update_payload
|
||||
"guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
|
||||
)
|
||||
|
||||
return in_lazy
|
||||
|
|
@ -220,25 +212,25 @@ class PresenceManager:
|
|||
|
||||
Also dispatches the presence to all the users' friends
|
||||
"""
|
||||
if state['status'] == 'invisible':
|
||||
state['status'] = 'offline'
|
||||
if state["status"] == "invisible":
|
||||
state["status"] = "offline"
|
||||
|
||||
# TODO: shard-aware
|
||||
guild_ids = await self.user_storage.get_user_guilds(user_id)
|
||||
|
||||
for guild_id in guild_ids:
|
||||
await self.dispatch_guild_pres(
|
||||
guild_id, user_id, state)
|
||||
await self.dispatch_guild_pres(guild_id, user_id, state)
|
||||
|
||||
# dispatch to all friends that are subscribed to them
|
||||
user = await self.storage.get_user(user_id)
|
||||
game = state['game']
|
||||
game = state["game"]
|
||||
|
||||
await self.dispatcher.dispatch(
|
||||
'friend', user_id, 'PRESENCE_UPDATE', fill_presence({
|
||||
'user': user,
|
||||
'status': state['status'],
|
||||
}, game=game))
|
||||
"friend",
|
||||
user_id,
|
||||
"PRESENCE_UPDATE",
|
||||
fill_presence({"user": user, "status": state["status"]}, game=game),
|
||||
)
|
||||
|
||||
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
|
||||
"""Fetch presences for a group of users.
|
||||
|
|
@ -254,22 +246,25 @@ class PresenceManager:
|
|||
|
||||
if not friend_states:
|
||||
# append offline
|
||||
res.append(await _pres(storage, friend_id, {
|
||||
'afk': False,
|
||||
'status': 'offline',
|
||||
'game': None,
|
||||
'since': 0
|
||||
}))
|
||||
res.append(
|
||||
await _pres(
|
||||
storage,
|
||||
friend_id,
|
||||
{"afk": False, "status": "offline", "game": None, "since": 0},
|
||||
)
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
# filter the best shards:
|
||||
# - all with id 0 (are the first shards in the collection) or
|
||||
# - all shards with count = 1 (single shards)
|
||||
good_shards = list(filter(
|
||||
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
|
||||
friend_states
|
||||
))
|
||||
good_shards = list(
|
||||
filter(
|
||||
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
|
||||
friend_states,
|
||||
)
|
||||
)
|
||||
|
||||
if good_shards:
|
||||
best_pres = _best_presence(good_shards)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,11 @@ from .channel import ChannelDispatcher
|
|||
from .friend import FriendDispatcher
|
||||
from .lazy_guild import LazyGuildDispatcher
|
||||
|
||||
__all__ = ['GuildDispatcher', 'MemberDispatcher',
|
||||
'UserDispatcher', 'ChannelDispatcher',
|
||||
'FriendDispatcher', 'LazyGuildDispatcher']
|
||||
__all__ = [
|
||||
"GuildDispatcher",
|
||||
"MemberDispatcher",
|
||||
"UserDispatcher",
|
||||
"ChannelDispatcher",
|
||||
"FriendDispatcher",
|
||||
"LazyGuildDispatcher",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -38,23 +38,20 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
|||
# make a copy or the original channel object
|
||||
data = dict(orig)
|
||||
|
||||
idx = index_by_func(
|
||||
lambda user: user['id'] == str(user_id),
|
||||
data['recipients']
|
||||
)
|
||||
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
|
||||
|
||||
data['recipients'].pop(idx)
|
||||
data["recipients"].pop(idx)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class ChannelDispatcher(DispatcherWithFlags):
|
||||
"""Main channel Pub/Sub logic."""
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def dispatch(self, channel_id,
|
||||
event: str, data: Any) -> List[str]:
|
||||
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
|
||||
"""Dispatch an event to a channel."""
|
||||
# get everyone who is subscribed
|
||||
# and store the number of states we dispatched the event to
|
||||
|
|
@ -75,9 +72,11 @@ class ChannelDispatcher(DispatcherWithFlags):
|
|||
# TODO: make a fetch_states that fetches shards
|
||||
# - with id 0 (count any) OR
|
||||
# - single shards (id=0, count=1)
|
||||
states = (self.sm.fetch_states(user_id, guild_id)
|
||||
if guild_id else
|
||||
self.sm.user_states(user_id))
|
||||
states = (
|
||||
self.sm.fetch_states(user_id, guild_id)
|
||||
if guild_id
|
||||
else self.sm.user_states(user_id)
|
||||
)
|
||||
|
||||
# unsub people who don't have any states tied to the channel.
|
||||
if not states:
|
||||
|
|
@ -85,28 +84,28 @@ class ChannelDispatcher(DispatcherWithFlags):
|
|||
continue
|
||||
|
||||
# skip typing events for users that don't want it
|
||||
if event.startswith('TYPING_') and \
|
||||
not self.flags_get(channel_id, user_id, 'typing', True):
|
||||
if event.startswith("TYPING_") and not self.flags_get(
|
||||
channel_id, user_id, "typing", True
|
||||
):
|
||||
continue
|
||||
|
||||
cur_sess = []
|
||||
|
||||
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
||||
and data.get('type') == ChannelType.GROUP_DM.value:
|
||||
if (
|
||||
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||
and data.get("type") == ChannelType.GROUP_DM.value
|
||||
):
|
||||
# we edit the channel payload so it doesn't show
|
||||
# the user as a recipient
|
||||
|
||||
new_data = gdm_recipient_view(data, user_id)
|
||||
cur_sess = await self._dispatch_states(
|
||||
states, event, new_data)
|
||||
cur_sess = await self._dispatch_states(states, event, new_data)
|
||||
else:
|
||||
cur_sess = await self._dispatch_states(
|
||||
states, event, data)
|
||||
cur_sess = await self._dispatch_states(states, event, data)
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
|
||||
log.info('Dispatched chan={} {!r} to {} states',
|
||||
channel_id, event, dispatched)
|
||||
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
|
||||
|
||||
return sessions
|
||||
|
|
|
|||
|
|
@ -80,8 +80,7 @@ class Dispatcher:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _dispatch_states(self, states: list, event: str,
|
||||
data) -> List[str]:
|
||||
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
|
||||
"""Dispatch an event to a list of states."""
|
||||
res = []
|
||||
|
||||
|
|
@ -90,7 +89,7 @@ class Dispatcher:
|
|||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except Exception:
|
||||
log.exception('error while dispatching')
|
||||
log.exception("error while dispatching")
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -102,6 +101,7 @@ class DispatcherWithState(Dispatcher):
|
|||
of boilerplate code on Pub/Sub backends
|
||||
that have that dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class FriendDispatcher(DispatcherWithState):
|
|||
channels. If that friend updates their presence, it will be
|
||||
broadcasted through that channel to basically all their friends.
|
||||
"""
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
|
|
@ -44,17 +45,13 @@ class FriendDispatcher(DispatcherWithState):
|
|||
# since relationships broadcast to all shards.
|
||||
sessions.extend(
|
||||
await self.main_dispatcher.dispatch_filter(
|
||||
'user', peer_id, func, event, data)
|
||||
"user", peer_id, func, event, data
|
||||
)
|
||||
)
|
||||
|
||||
log.info('dispatched uid={} {!r} to {} states',
|
||||
user_id, event, len(sessions))
|
||||
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
|
||||
|
||||
return sessions
|
||||
|
||||
async def dispatch(self, user_id, event, data):
|
||||
return await self.dispatch_filter(
|
||||
user_id,
|
||||
lambda sess_id: True,
|
||||
event, data,
|
||||
)
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ log = Logger(__name__)
|
|||
|
||||
class GuildDispatcher(DispatcherWithFlags):
|
||||
"""Guild backend for Pub/Sub"""
|
||||
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def _chan_action(self, action: str,
|
||||
guild_id: int, user_id: int, flags=None):
|
||||
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
|
||||
"""Send an action to all channels of the guild."""
|
||||
flags = flags or {}
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
|
|
@ -43,33 +43,31 @@ class GuildDispatcher(DispatcherWithFlags):
|
|||
# only do an action for users that can
|
||||
# actually read the channel to start with.
|
||||
chan_perms = await get_permissions(
|
||||
user_id, chan_id,
|
||||
storage=self.main_dispatcher.app.storage)
|
||||
user_id, chan_id, storage=self.main_dispatcher.app.storage
|
||||
)
|
||||
|
||||
if not chan_perms.bits.read_messages:
|
||||
log.debug('skipping cid={}, no read messages',
|
||||
chan_id)
|
||||
log.debug("skipping cid={}, no read messages", chan_id)
|
||||
continue
|
||||
|
||||
log.debug('sending raw action {!r} to chan={}',
|
||||
action, chan_id)
|
||||
log.debug("sending raw action {!r} to chan={}", action, chan_id)
|
||||
|
||||
# for now, only sub() has support for flags.
|
||||
# it is an idea to have flags support for other actions
|
||||
args = []
|
||||
if action == 'sub':
|
||||
if action == "sub":
|
||||
chanflags = dict(flags)
|
||||
|
||||
# channels don't need presence flags
|
||||
try:
|
||||
chanflags.pop('presence')
|
||||
chanflags.pop("presence")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
args.append(chanflags)
|
||||
|
||||
await self.main_dispatcher.action(
|
||||
'channel', action, chan_id, user_id, *args
|
||||
"channel", action, chan_id, user_id, *args
|
||||
)
|
||||
|
||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||
|
|
@ -77,26 +75,24 @@ class GuildDispatcher(DispatcherWithFlags):
|
|||
in the guild."""
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
|
||||
chan_dispatcher = self.main_dispatcher.backends['channel']
|
||||
chan_dispatcher = self.main_dispatcher.backends["channel"]
|
||||
method = getattr(chan_dispatcher, meth)
|
||||
|
||||
for chan_id in chan_ids:
|
||||
log.debug('calling {} to chan={}',
|
||||
meth, chan_id)
|
||||
log.debug("calling {} to chan={}", meth, chan_id)
|
||||
await method(chan_id, *args)
|
||||
|
||||
async def sub(self, guild_id: int, user_id: int, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(guild_id, user_id, flags)
|
||||
await self._chan_action('sub', guild_id, user_id, flags)
|
||||
await self._chan_action("sub", guild_id, user_id, flags)
|
||||
|
||||
async def unsub(self, guild_id: int, user_id: int):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(guild_id, user_id)
|
||||
await self._chan_action('unsub', guild_id, user_id)
|
||||
await self._chan_action("unsub", guild_id, user_id)
|
||||
|
||||
async def dispatch_filter(self, guild_id: int, func,
|
||||
event: str, data: Any):
|
||||
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
|
||||
"""Selectively dispatch to session ids that have
|
||||
func(session_id) true."""
|
||||
user_ids = self.state[guild_id]
|
||||
|
|
@ -121,31 +117,23 @@ class GuildDispatcher(DispatcherWithFlags):
|
|||
|
||||
# note that this does not equate to any unsubscription
|
||||
# of the channel.
|
||||
if event.startswith('PRESENCE_') and \
|
||||
not self.flags_get(guild_id, user_id, 'presence', True):
|
||||
if event.startswith("PRESENCE_") and not self.flags_get(
|
||||
guild_id, user_id, "presence", True
|
||||
):
|
||||
continue
|
||||
|
||||
# filter the ones that matter
|
||||
states = list(filter(
|
||||
lambda state: func(state.session_id), states
|
||||
))
|
||||
states = list(filter(lambda state: func(state.session_id), states))
|
||||
|
||||
cur_sess = await self._dispatch_states(
|
||||
states, event, data)
|
||||
cur_sess = await self._dispatch_states(states, event, data)
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
|
||||
log.info('Dispatched {} {!r} to {} states',
|
||||
guild_id, event, dispatched)
|
||||
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
|
||||
|
||||
return sessions
|
||||
|
||||
async def dispatch(self, guild_id: int,
|
||||
event: str, data: Any):
|
||||
async def dispatch(self, guild_id: int, event: str, data: Any):
|
||||
"""Dispatch an event to all subscribers of the guild."""
|
||||
return await self.dispatch_filter(
|
||||
guild_id,
|
||||
lambda sess_id: True,
|
||||
event, data,
|
||||
)
|
||||
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -22,6 +22,7 @@ from .dispatcher import Dispatcher
|
|||
|
||||
class MemberDispatcher(Dispatcher):
|
||||
"""Member backend for Pub/Sub."""
|
||||
|
||||
KEY_TYPE = tuple
|
||||
|
||||
async def dispatch(self, key, event, data):
|
||||
|
|
@ -39,7 +40,7 @@ class MemberDispatcher(Dispatcher):
|
|||
# if no states were found, we should
|
||||
# unsub the user from the GUILD channel
|
||||
if not states:
|
||||
await self.main_dispatcher.unsub('guild', guild_id, user_id)
|
||||
await self.main_dispatcher.unsub("guild", guild_id, user_id)
|
||||
return
|
||||
|
||||
return await self._dispatch_states(states, event, data)
|
||||
|
|
|
|||
|
|
@ -22,22 +22,18 @@ from .dispatcher import Dispatcher
|
|||
|
||||
class UserDispatcher(Dispatcher):
|
||||
"""User backend for Pub/Sub."""
|
||||
|
||||
KEY_TYPE = int
|
||||
|
||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||
"""Dispatch an event to all shards of a user."""
|
||||
|
||||
# filter only states where func() gives true
|
||||
states = list(filter(
|
||||
lambda state: func(state.session_id),
|
||||
self.sm.user_states(user_id)
|
||||
))
|
||||
states = list(
|
||||
filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
|
||||
)
|
||||
|
||||
return await self._dispatch_states(states, event, data)
|
||||
|
||||
async def dispatch(self, user_id: int, event, data):
|
||||
return await self.dispatch_filter(
|
||||
user_id,
|
||||
lambda sess_id: True,
|
||||
event, data,
|
||||
)
|
||||
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import time
|
|||
|
||||
class RatelimitBucket:
|
||||
"""Main ratelimit bucket class."""
|
||||
|
||||
def __init__(self, tokens, second):
|
||||
self.requests = tokens
|
||||
self.second = second
|
||||
|
|
@ -88,17 +89,19 @@ class RatelimitBucket:
|
|||
|
||||
Used to manage multiple ratelimits to users.
|
||||
"""
|
||||
return RatelimitBucket(self.requests,
|
||||
self.second)
|
||||
return RatelimitBucket(self.requests, self.second)
|
||||
|
||||
def __repr__(self):
|
||||
return (f'<RatelimitBucket requests={self.requests} '
|
||||
f'second={self.second} window: {self._window} '
|
||||
f'tokens={self._tokens}>')
|
||||
return (
|
||||
f"<RatelimitBucket requests={self.requests} "
|
||||
f"second={self.second} window: {self._window} "
|
||||
f"tokens={self._tokens}>"
|
||||
)
|
||||
|
||||
|
||||
class Ratelimit:
|
||||
"""Manages buckets."""
|
||||
|
||||
def __init__(self, tokens, second, keys=None):
|
||||
self._cache = {}
|
||||
if keys is None:
|
||||
|
|
@ -107,12 +110,11 @@ class Ratelimit:
|
|||
self._cooldown = RatelimitBucket(tokens, second)
|
||||
|
||||
def __repr__(self):
|
||||
return (f'<Ratelimit cooldown={self._cooldown}>')
|
||||
return f"<Ratelimit cooldown={self._cooldown}>"
|
||||
|
||||
def _verify_cache(self):
|
||||
current = time.time()
|
||||
dead_keys = [k for k, v in self._cache.items()
|
||||
if current > v._last + v.second]
|
||||
dead_keys = [k for k, v in self._cache.items() if current > v._last + v.second]
|
||||
|
||||
for k in dead_keys:
|
||||
del self._cache[k]
|
||||
|
|
|
|||
|
|
@ -31,10 +31,10 @@ async def _check_bucket(bucket):
|
|||
if retry_after:
|
||||
request.retry_after = retry_after
|
||||
|
||||
raise Ratelimited('You are being rate limited.', {
|
||||
'retry_after': int(retry_after * 1000),
|
||||
'global': request.bucket_global,
|
||||
})
|
||||
raise Ratelimited(
|
||||
"You are being rate limited.",
|
||||
{"retry_after": int(retry_after * 1000), "global": request.bucket_global},
|
||||
)
|
||||
|
||||
|
||||
async def _handle_global(ratelimit):
|
||||
|
|
@ -59,13 +59,13 @@ async def _handle_specific(ratelimit):
|
|||
keys = ratelimit.keys
|
||||
|
||||
# base key is the user id
|
||||
key_components = [f'user_id:{user_id}']
|
||||
key_components = [f"user_id:{user_id}"]
|
||||
|
||||
for key in keys:
|
||||
val = request.view_args[key]
|
||||
key_components.append(f'{key}:{val}')
|
||||
key_components.append(f"{key}:{val}")
|
||||
|
||||
bucket_key = ':'.join(key_components)
|
||||
bucket_key = ":".join(key_components)
|
||||
bucket = ratelimit.get_bucket(bucket_key)
|
||||
await _check_bucket(bucket)
|
||||
|
||||
|
|
@ -78,9 +78,7 @@ async def ratelimit_handler():
|
|||
rule = request.url_rule
|
||||
|
||||
if rule is None:
|
||||
return await _handle_global(
|
||||
app.ratelimiter.global_bucket
|
||||
)
|
||||
return await _handle_global(app.ratelimiter.global_bucket)
|
||||
|
||||
# rule.endpoint is composed of '<blueprint>.<function>'
|
||||
# and so we can use that to make routes with different
|
||||
|
|
@ -97,6 +95,4 @@ async def ratelimit_handler():
|
|||
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
||||
await _handle_specific(ratelimit)
|
||||
except KeyError:
|
||||
await _handle_global(
|
||||
app.ratelimiter.global_bucket
|
||||
)
|
||||
await _handle_global(app.ratelimiter.global_bucket)
|
||||
|
|
|
|||
|
|
@ -34,33 +34,30 @@ WS:
|
|||
|All Sent Messages| | 120/60s | per-session
|
||||
"""
|
||||
|
||||
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
|
||||
REACTION_BUCKET = Ratelimit(1, 0.25, ("channel_id"))
|
||||
|
||||
RATELIMITS = {
|
||||
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
|
||||
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
|
||||
|
||||
"channel_messages.create_message": Ratelimit(5, 5, ("channel_id")),
|
||||
"channel_messages.delete_message": Ratelimit(5, 1, ("channel_id")),
|
||||
# all of those share the same bucket.
|
||||
'channel_reactions.add_reaction': REACTION_BUCKET,
|
||||
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
|
||||
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
|
||||
|
||||
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
|
||||
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
|
||||
|
||||
"channel_reactions.add_reaction": REACTION_BUCKET,
|
||||
"channel_reactions.remove_own_reaction": REACTION_BUCKET,
|
||||
"channel_reactions.remove_user_reaction": REACTION_BUCKET,
|
||||
"guild_members.modify_guild_member": Ratelimit(10, 10, ("guild_id")),
|
||||
"guild_members.update_nickname": Ratelimit(1, 1, ("guild_id")),
|
||||
# this only applies to username.
|
||||
# 'users.patch_me': Ratelimit(2, 3600),
|
||||
|
||||
'_ws.connect': Ratelimit(1, 5),
|
||||
'_ws.presence': Ratelimit(5, 60),
|
||||
'_ws.messages': Ratelimit(120, 60),
|
||||
|
||||
"_ws.connect": Ratelimit(1, 5),
|
||||
"_ws.presence": Ratelimit(5, 60),
|
||||
"_ws.messages": Ratelimit(120, 60),
|
||||
# 1000 / 4h for new session issuing
|
||||
'_ws.session': Ratelimit(1000, 14400)
|
||||
"_ws.session": Ratelimit(1000, 14400),
|
||||
}
|
||||
|
||||
|
||||
class RatelimitManager:
|
||||
"""Manager for the bucket managers"""
|
||||
|
||||
def __init__(self, testing_flag=False):
|
||||
self._ratelimiters = {}
|
||||
self._test = testing_flag
|
||||
|
|
@ -74,9 +71,7 @@ class RatelimitManager:
|
|||
|
||||
# NOTE: this is a bad way to do it, but
|
||||
# we only need to change that one for now.
|
||||
rtl = (Ratelimit(10, 1)
|
||||
if self._test and path == '_ws.connect'
|
||||
else rtl)
|
||||
rtl = Ratelimit(10, 1) if self._test and path == "_ws.connect" else rtl
|
||||
|
||||
self._ratelimiters[path] = rtl
|
||||
|
||||
|
|
|
|||
|
|
@ -28,26 +28,30 @@ from .errors import BadRequest
|
|||
from .permissions import Permissions
|
||||
from .types import Color
|
||||
from .enums import (
|
||||
ActivityType, StatusType, ExplicitFilter, RelationshipType,
|
||||
MessageNotifications, ChannelType, VerificationLevel
|
||||
ActivityType,
|
||||
StatusType,
|
||||
ExplicitFilter,
|
||||
RelationshipType,
|
||||
MessageNotifications,
|
||||
ChannelType,
|
||||
VerificationLevel,
|
||||
)
|
||||
|
||||
from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_ ]{2,30}$', re.A)
|
||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
||||
re.A)
|
||||
DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A)
|
||||
USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
|
||||
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
|
||||
DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
|
||||
|
||||
|
||||
# collection of regexes
|
||||
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M)
|
||||
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M)
|
||||
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M)
|
||||
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
|
||||
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
|
||||
USER_MENTION = re.compile(r"<@!?(\d+)>", re.A | re.M)
|
||||
CHAN_MENTION = re.compile(r"<#(\d+)>", re.A | re.M)
|
||||
ROLE_MENTION = re.compile(r"<@&(\d+)>", re.A | re.M)
|
||||
EMOJO_MENTION = re.compile(r"<:(\.+):(\d+)>", re.A | re.M)
|
||||
ANIMOJI_MENTION = re.compile(r"<a:(\.+):(\d+)>", re.A | re.M)
|
||||
|
||||
|
||||
def _in_enum(enum, value) -> bool:
|
||||
|
|
@ -61,6 +65,7 @@ def _in_enum(enum, value) -> bool:
|
|||
|
||||
class LitecordValidator(Validator):
|
||||
"""Main validator class for Litecord, containing custom types."""
|
||||
|
||||
def _validate_type_username(self, value: str) -> bool:
|
||||
"""Validate against the username regex."""
|
||||
return bool(USERNAME_REGEX.match(value))
|
||||
|
|
@ -130,8 +135,7 @@ class LitecordValidator(Validator):
|
|||
return False
|
||||
|
||||
# nobody is allowed to use the INCOMING and OUTGOING rel types
|
||||
return val in (RelationshipType.FRIEND.value,
|
||||
RelationshipType.BLOCK.value)
|
||||
return val in (RelationshipType.FRIEND.value, RelationshipType.BLOCK.value)
|
||||
|
||||
def _validate_type_msg_notifications(self, value: str):
|
||||
try:
|
||||
|
|
@ -152,14 +156,15 @@ class LitecordValidator(Validator):
|
|||
return self._validate_type_guild_name(value)
|
||||
|
||||
def _validate_type_theme(self, value: str) -> bool:
|
||||
return value in ['light', 'dark']
|
||||
return value in ["light", "dark"]
|
||||
|
||||
def _validate_type_nickname(self, value: str) -> bool:
|
||||
return isinstance(value, str) and (len(value) < 32)
|
||||
|
||||
|
||||
def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
|
||||
raise_err: bool = True) -> Dict:
|
||||
def validate(
|
||||
reqjson: Optional[Union[Dict, List]], schema: Dict, raise_err: bool = True
|
||||
) -> Dict:
|
||||
"""Validate the given user-given data against a schema, giving the
|
||||
"correct" version of the document, with all defaults applied.
|
||||
|
||||
|
|
@ -176,20 +181,20 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
|
|||
validator = LitecordValidator(schema)
|
||||
|
||||
if reqjson is None:
|
||||
raise BadRequest('No JSON provided')
|
||||
raise BadRequest("No JSON provided")
|
||||
|
||||
try:
|
||||
valid = validator.validate(reqjson)
|
||||
except Exception:
|
||||
log.exception('Error while validating')
|
||||
raise Exception(f'Error while validating: {reqjson}')
|
||||
log.exception("Error while validating")
|
||||
raise Exception(f"Error while validating: {reqjson}")
|
||||
|
||||
if not valid:
|
||||
errs = validator.errors
|
||||
log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
|
||||
log.warning("Error validating doc {!r}: {!r}", reqjson, errs)
|
||||
|
||||
if raise_err:
|
||||
raise BadRequest('bad payload', errs)
|
||||
raise BadRequest("bad payload", errs)
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -197,554 +202,441 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
|
|||
|
||||
|
||||
REGISTER = {
|
||||
'username': {'type': 'username', 'required': True},
|
||||
'email': {'type': 'email', 'required': False},
|
||||
'password': {'type': 'password', 'required': False},
|
||||
|
||||
"username": {"type": "username", "required": True},
|
||||
"email": {"type": "email", "required": False},
|
||||
"password": {"type": "password", "required": False},
|
||||
# invite stands for a guild invite, not an instance invite (that's on
|
||||
# the register_with_invite handler).
|
||||
'invite': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
"invite": {"type": "string", "required": False, "nullable": True},
|
||||
# following fields only sent by official client, unused by us
|
||||
'fingerprint': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'captcha_key': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'gift_code_sku_id': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'consent': {'type': 'boolean', 'required': False},
|
||||
"fingerprint": {"type": "string", "required": False, "nullable": True},
|
||||
"captcha_key": {"type": "string", "required": False, "nullable": True},
|
||||
"gift_code_sku_id": {"type": "string", "required": False, "nullable": True},
|
||||
"consent": {"type": "boolean", "required": False},
|
||||
}
|
||||
|
||||
# only used by us, not discord, hence 'invcode' (to separate from discord)
|
||||
REGISTER_WITH_INVITE = {**REGISTER, **{
|
||||
'invcode': {'type': 'string', 'required': True}
|
||||
}}
|
||||
REGISTER_WITH_INVITE = {**REGISTER, **{"invcode": {"type": "string", "required": True}}}
|
||||
|
||||
|
||||
USER_UPDATE = {
|
||||
'username': {
|
||||
'type': 'username', 'minlength': 2,
|
||||
'maxlength': 30, 'required': False},
|
||||
|
||||
'discriminator': {
|
||||
'type': 'discriminator',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
"username": {
|
||||
"type": "username",
|
||||
"minlength": 2,
|
||||
"maxlength": 30,
|
||||
"required": False,
|
||||
},
|
||||
|
||||
'password': {
|
||||
'type': 'password', 'required': False,
|
||||
"discriminator": {"type": "discriminator", "required": False, "nullable": True},
|
||||
"password": {"type": "password", "required": False},
|
||||
"new_password": {
|
||||
"type": "password",
|
||||
"required": False,
|
||||
"dependencies": "password",
|
||||
"nullable": True,
|
||||
},
|
||||
|
||||
'new_password': {
|
||||
'type': 'password', 'required': False,
|
||||
'dependencies': 'password', 'nullable': True
|
||||
},
|
||||
|
||||
'email': {
|
||||
'type': 'email', 'required': False, 'dependencies': 'password',
|
||||
},
|
||||
|
||||
'avatar': {
|
||||
"email": {"type": "email", "required": False, "dependencies": "password"},
|
||||
"avatar": {
|
||||
# can be both b64_icon or string (just the hash)
|
||||
'type': 'string', 'required': False,
|
||||
'nullable': True
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"nullable": True,
|
||||
},
|
||||
|
||||
}
|
||||
|
||||
PARTIAL_ROLE_GUILD_CREATE = {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'name': {'type': 'role_name'},
|
||||
'color': {'type': 'number', 'default': 0},
|
||||
'hoist': {'type': 'boolean', 'default': False},
|
||||
|
||||
"type": "dict",
|
||||
"schema": {
|
||||
"name": {"type": "role_name"},
|
||||
"color": {"type": "number", "default": 0},
|
||||
"hoist": {"type": "boolean", "default": False},
|
||||
# NOTE: no position on partial role (on guild create)
|
||||
|
||||
'permissions': {'coerce': Permissions, 'required': False},
|
||||
'mentionable': {'type': 'boolean', 'default': False},
|
||||
}
|
||||
"permissions": {"coerce": Permissions, "required": False},
|
||||
"mentionable": {"type": "boolean", "default": False},
|
||||
},
|
||||
}
|
||||
|
||||
PARTIAL_CHANNEL_GUILD_CREATE = {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'name': {'type': 'channel_name'},
|
||||
'type': {'type': 'channel_type'},
|
||||
}
|
||||
"type": "dict",
|
||||
"schema": {"name": {"type": "channel_name"}, "type": {"type": "channel_type"}},
|
||||
}
|
||||
|
||||
GUILD_CREATE = {
|
||||
'name': {'type': 'guild_name'},
|
||||
'region': {'type': 'voice_region', 'nullable': True},
|
||||
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
||||
|
||||
'verification_level': {
|
||||
'type': 'verification_level', 'default': 0},
|
||||
'default_message_notifications': {
|
||||
'type': 'msg_notifications', 'default': 0},
|
||||
'explicit_content_filter': {
|
||||
'type': 'explicit', 'default': 0},
|
||||
|
||||
'roles': {
|
||||
'type': 'list', 'required': False,
|
||||
'schema': PARTIAL_ROLE_GUILD_CREATE},
|
||||
'channels': {
|
||||
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
|
||||
"name": {"type": "guild_name"},
|
||||
"region": {"type": "voice_region", "nullable": True},
|
||||
"icon": {"type": "b64_icon", "required": False, "nullable": True},
|
||||
"verification_level": {"type": "verification_level", "default": 0},
|
||||
"default_message_notifications": {"type": "msg_notifications", "default": 0},
|
||||
"explicit_content_filter": {"type": "explicit", "default": 0},
|
||||
"roles": {"type": "list", "required": False, "schema": PARTIAL_ROLE_GUILD_CREATE},
|
||||
"channels": {"type": "list", "default": [], "schema": PARTIAL_CHANNEL_GUILD_CREATE},
|
||||
}
|
||||
|
||||
|
||||
GUILD_UPDATE = {
|
||||
'name': {
|
||||
'type': 'guild_name',
|
||||
'required': False
|
||||
},
|
||||
'region': {'type': 'voice_region', 'required': False, 'nullable': True},
|
||||
|
||||
"name": {"type": "guild_name", "required": False},
|
||||
"region": {"type": "voice_region", "required": False, "nullable": True},
|
||||
# all three can have hashes
|
||||
'icon': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'banner': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'splash': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'description': {
|
||||
'type': 'string', 'required': False,
|
||||
'minlength': 1, 'maxlength': 120,
|
||||
'nullable': True
|
||||
"icon": {"type": "string", "required": False, "nullable": True},
|
||||
"banner": {"type": "string", "required": False, "nullable": True},
|
||||
"splash": {"type": "string", "required": False, "nullable": True},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"minlength": 1,
|
||||
"maxlength": 120,
|
||||
"nullable": True,
|
||||
},
|
||||
|
||||
'verification_level': {
|
||||
'type': 'verification_level', 'required': False},
|
||||
'default_message_notifications': {
|
||||
'type': 'msg_notifications', 'required': False},
|
||||
'explicit_content_filter': {'type': 'explicit', 'required': False},
|
||||
|
||||
'afk_channel_id': {
|
||||
'type': 'snowflake', 'required': False, 'nullable': True},
|
||||
'afk_timeout': {'type': 'number', 'required': False},
|
||||
|
||||
'owner_id': {'type': 'snowflake', 'required': False},
|
||||
|
||||
'system_channel_id': {
|
||||
'type': 'snowflake', 'required': False, 'nullable': True},
|
||||
"verification_level": {"type": "verification_level", "required": False},
|
||||
"default_message_notifications": {"type": "msg_notifications", "required": False},
|
||||
"explicit_content_filter": {"type": "explicit", "required": False},
|
||||
"afk_channel_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||
"afk_timeout": {"type": "number", "required": False},
|
||||
"owner_id": {"type": "snowflake", "required": False},
|
||||
"system_channel_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||
}
|
||||
|
||||
|
||||
CHAN_OVERWRITE = {
|
||||
'id': {'coerce': int},
|
||||
'type': {'type': 'string', 'allowed': ['role', 'member']},
|
||||
'allow': {'coerce': Permissions},
|
||||
'deny': {'coerce': Permissions}
|
||||
"id": {"coerce": int},
|
||||
"type": {"type": "string", "allowed": ["role", "member"]},
|
||||
"allow": {"coerce": Permissions},
|
||||
"deny": {"coerce": Permissions},
|
||||
}
|
||||
|
||||
|
||||
CHAN_CREATE = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 2,
|
||||
'maxlength': 100, 'required': True
|
||||
},
|
||||
|
||||
'type': {'type': 'channel_type',
|
||||
'default': ChannelType.GUILD_TEXT.value},
|
||||
|
||||
'position': {'coerce': int, 'required': False},
|
||||
|
||||
'topic': {
|
||||
'type': 'string', 'minlength': 0,
|
||||
'maxlength': 1024, 'required': False},
|
||||
|
||||
'nsfw': {'type': 'boolean', 'required': False},
|
||||
'rate_limit_per_user': {
|
||||
'coerce': int, 'min': 0,
|
||||
'max': 120, 'required': False},
|
||||
|
||||
'bitrate': {
|
||||
'coerce': int, 'min': 8000,
|
||||
|
||||
"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": True},
|
||||
"type": {"type": "channel_type", "default": ChannelType.GUILD_TEXT.value},
|
||||
"position": {"coerce": int, "required": False},
|
||||
"topic": {"type": "string", "minlength": 0, "maxlength": 1024, "required": False},
|
||||
"nsfw": {"type": "boolean", "required": False},
|
||||
"rate_limit_per_user": {"coerce": int, "min": 0, "max": 120, "required": False},
|
||||
"bitrate": {
|
||||
"coerce": int,
|
||||
"min": 8000,
|
||||
# NOTE: 'max' is 96000 for non-vip guilds
|
||||
'max': 128000, 'required': False},
|
||||
|
||||
'user_limit': {
|
||||
"max": 128000,
|
||||
"required": False,
|
||||
},
|
||||
"user_limit": {
|
||||
# user_limit being 0 means infinite.
|
||||
'coerce': int, 'min': 0,
|
||||
'max': 99, 'required': False
|
||||
"coerce": int,
|
||||
"min": 0,
|
||||
"max": 99,
|
||||
"required": False,
|
||||
},
|
||||
|
||||
'permission_overwrites': {
|
||||
'type': 'list',
|
||||
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE},
|
||||
'required': False
|
||||
"permission_overwrites": {
|
||||
"type": "list",
|
||||
"schema": {"type": "dict", "schema": CHAN_OVERWRITE},
|
||||
"required": False,
|
||||
},
|
||||
|
||||
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
|
||||
"parent_id": {"coerce": int, "required": False, "nullable": True},
|
||||
}
|
||||
|
||||
|
||||
CHAN_UPDATE = {**CHAN_CREATE, **{
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 2,
|
||||
'maxlength': 100, 'required': False},
|
||||
|
||||
}}
|
||||
CHAN_UPDATE = {
|
||||
**CHAN_CREATE,
|
||||
**{"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": False}},
|
||||
}
|
||||
|
||||
|
||||
ROLE_CREATE = {
|
||||
'name': {'type': 'string', 'default': 'new role'},
|
||||
'permissions': {'coerce': Permissions, 'nullable': True},
|
||||
'color': {'coerce': Color, 'default': 0},
|
||||
'hoist': {'type': 'boolean', 'default': False},
|
||||
'mentionable': {'type': 'boolean', 'default': False},
|
||||
"name": {"type": "string", "default": "new role"},
|
||||
"permissions": {"coerce": Permissions, "nullable": True},
|
||||
"color": {"coerce": Color, "default": 0},
|
||||
"hoist": {"type": "boolean", "default": False},
|
||||
"mentionable": {"type": "boolean", "default": False},
|
||||
}
|
||||
|
||||
ROLE_UPDATE = {
|
||||
'name': {'type': 'string', 'required': False},
|
||||
'permissions': {'coerce': Permissions, 'required': False},
|
||||
'color': {'coerce': Color, 'required': False},
|
||||
'hoist': {'type': 'boolean', 'required': False},
|
||||
'mentionable': {'type': 'boolean', 'required': False},
|
||||
"name": {"type": "string", "required": False},
|
||||
"permissions": {"coerce": Permissions, "required": False},
|
||||
"color": {"coerce": Color, "required": False},
|
||||
"hoist": {"type": "boolean", "required": False},
|
||||
"mentionable": {"type": "boolean", "required": False},
|
||||
}
|
||||
|
||||
|
||||
ROLE_UPDATE_POSITION = {
|
||||
'roles': {
|
||||
'type': 'list',
|
||||
'schema': {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'id': {'coerce': int},
|
||||
'position': {'coerce': int},
|
||||
},
|
||||
}
|
||||
"roles": {
|
||||
"type": "list",
|
||||
"schema": {
|
||||
"type": "dict",
|
||||
"schema": {"id": {"coerce": int}, "position": {"coerce": int}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
MEMBER_UPDATE = {
|
||||
'nick': {
|
||||
'type': 'nickname', 'required': False},
|
||||
'roles': {'type': 'list', 'required': False,
|
||||
'schema': {'coerce': int}},
|
||||
'mute': {'type': 'boolean', 'required': False},
|
||||
'deaf': {'type': 'boolean', 'required': False},
|
||||
'channel_id': {'type': 'snowflake', 'required': False},
|
||||
"nick": {"type": "nickname", "required": False},
|
||||
"roles": {"type": "list", "required": False, "schema": {"coerce": int}},
|
||||
"mute": {"type": "boolean", "required": False},
|
||||
"deaf": {"type": "boolean", "required": False},
|
||||
"channel_id": {"type": "snowflake", "required": False},
|
||||
}
|
||||
|
||||
|
||||
# NOTE: things such as payload_json are parsed at the handler
|
||||
# for creating a message.
|
||||
MESSAGE_CREATE = {
|
||||
'content': {'type': 'string', 'minlength': 0, 'maxlength': 2000},
|
||||
'nonce': {'type': 'snowflake', 'required': False},
|
||||
'tts': {'type': 'boolean', 'required': False},
|
||||
|
||||
'embed': {
|
||||
'type': 'dict',
|
||||
'schema': EMBED_OBJECT,
|
||||
'required': False,
|
||||
'nullable': True
|
||||
}
|
||||
"content": {"type": "string", "minlength": 0, "maxlength": 2000},
|
||||
"nonce": {"type": "snowflake", "required": False},
|
||||
"tts": {"type": "boolean", "required": False},
|
||||
"embed": {
|
||||
"type": "dict",
|
||||
"schema": EMBED_OBJECT,
|
||||
"required": False,
|
||||
"nullable": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
GW_ACTIVITY = {
|
||||
'name': {'type': 'string', 'required': True},
|
||||
'type': {'type': 'activity_type', 'required': True},
|
||||
|
||||
'url': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'timestamps': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'start': {'type': 'number', 'required': False},
|
||||
'end': {'type': 'number', 'required': False},
|
||||
"name": {"type": "string", "required": True},
|
||||
"type": {"type": "activity_type", "required": True},
|
||||
"url": {"type": "string", "required": False, "nullable": True},
|
||||
"timestamps": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"schema": {
|
||||
"start": {"type": "number", "required": False},
|
||||
"end": {"type": "number", "required": False},
|
||||
},
|
||||
},
|
||||
|
||||
'application_id': {'type': 'snowflake', 'required': False,
|
||||
'nullable': False},
|
||||
'details': {'type': 'string', 'required': False, 'nullable': True},
|
||||
'state': {'type': 'string', 'required': False, 'nullable': True},
|
||||
|
||||
'party': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'id': {'type': 'snowflake', 'required': False},
|
||||
'size': {'type': 'list', 'required': False},
|
||||
}
|
||||
"application_id": {"type": "snowflake", "required": False, "nullable": False},
|
||||
"details": {"type": "string", "required": False, "nullable": True},
|
||||
"state": {"type": "string", "required": False, "nullable": True},
|
||||
"party": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"schema": {
|
||||
"id": {"type": "snowflake", "required": False},
|
||||
"size": {"type": "list", "required": False},
|
||||
},
|
||||
},
|
||||
|
||||
'assets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'large_image': {'type': 'snowflake', 'required': False},
|
||||
'large_text': {'type': 'string', 'required': False},
|
||||
'small_image': {'type': 'snowflake', 'required': False},
|
||||
'small_text': {'type': 'string', 'required': False},
|
||||
}
|
||||
"assets": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"schema": {
|
||||
"large_image": {"type": "snowflake", "required": False},
|
||||
"large_text": {"type": "string", "required": False},
|
||||
"small_image": {"type": "snowflake", "required": False},
|
||||
"small_text": {"type": "string", "required": False},
|
||||
},
|
||||
},
|
||||
|
||||
'secrets': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'join': {'type': 'string', 'required': False},
|
||||
'spectate': {'type': 'string', 'required': False},
|
||||
'match': {'type': 'string', 'required': False},
|
||||
}
|
||||
"secrets": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"schema": {
|
||||
"join": {"type": "string", "required": False},
|
||||
"spectate": {"type": "string", "required": False},
|
||||
"match": {"type": "string", "required": False},
|
||||
},
|
||||
},
|
||||
|
||||
'instance': {'type': 'boolean', 'required': False},
|
||||
'flags': {'type': 'number', 'required': False},
|
||||
"instance": {"type": "boolean", "required": False},
|
||||
"flags": {"type": "number", "required": False},
|
||||
}
|
||||
|
||||
GW_STATUS_UPDATE = {
|
||||
'status': {'type': 'status_external', 'required': False,
|
||||
'default': 'online'},
|
||||
'activities': {
|
||||
'type': 'list', 'required': False,
|
||||
'schema': {'type': 'dict', 'schema': GW_ACTIVITY}
|
||||
"status": {"type": "status_external", "required": False, "default": "online"},
|
||||
"activities": {
|
||||
"type": "list",
|
||||
"required": False,
|
||||
"schema": {"type": "dict", "schema": GW_ACTIVITY},
|
||||
},
|
||||
'afk': {'type': 'boolean', 'required': False},
|
||||
|
||||
'since': {'type': 'number', 'required': False, 'nullable': True},
|
||||
'game': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'nullable': True,
|
||||
'schema': GW_ACTIVITY,
|
||||
"afk": {"type": "boolean", "required": False},
|
||||
"since": {"type": "number", "required": False, "nullable": True},
|
||||
"game": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"nullable": True,
|
||||
"schema": GW_ACTIVITY,
|
||||
},
|
||||
}
|
||||
|
||||
INVITE = {
|
||||
# max_age in seconds
|
||||
# 0 for infinite
|
||||
'max_age': {
|
||||
'type': 'number',
|
||||
'min': 0,
|
||||
'max': 86400,
|
||||
|
||||
"max_age": {
|
||||
"type": "number",
|
||||
"min": 0,
|
||||
"max": 86400,
|
||||
# a day
|
||||
'default': 86400
|
||||
"default": 86400,
|
||||
},
|
||||
|
||||
# max invite uses
|
||||
'max_uses': {
|
||||
'type': 'number',
|
||||
'min': 0,
|
||||
|
||||
"max_uses": {
|
||||
"type": "number",
|
||||
"min": 0,
|
||||
# idk
|
||||
'max': 1000,
|
||||
|
||||
"max": 1000,
|
||||
# default infinite
|
||||
'default': 0
|
||||
"default": 0,
|
||||
},
|
||||
|
||||
'temporary': {'type': 'boolean', 'required': False, 'default': False},
|
||||
'unique': {'type': 'boolean', 'required': False, 'default': True},
|
||||
'validate': {'type': 'string', 'required': False, 'nullable': True} # discord client sends invite code there
|
||||
"temporary": {"type": "boolean", "required": False, "default": False},
|
||||
"unique": {"type": "boolean", "required": False, "default": True},
|
||||
"validate": {
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"nullable": True,
|
||||
}, # discord client sends invite code there
|
||||
}
|
||||
|
||||
USER_SETTINGS = {
|
||||
'afk_timeout': {
|
||||
'type': 'number', 'required': False, 'min': 0, 'max': 3000},
|
||||
|
||||
'animate_emoji': {'type': 'boolean', 'required': False},
|
||||
'convert_emoticons': {'type': 'boolean', 'required': False},
|
||||
'default_guilds_restricted': {'type': 'boolean', 'required': False},
|
||||
'detect_platform_accounts': {'type': 'boolean', 'required': False},
|
||||
'developer_mode': {'type': 'boolean', 'required': False},
|
||||
'disable_games_tab': {'type': 'boolean', 'required': False},
|
||||
'enable_tts_command': {'type': 'boolean', 'required': False},
|
||||
|
||||
'explicit_content_filter': {'type': 'explicit', 'required': False},
|
||||
|
||||
'friend_source': {
|
||||
'type': 'dict',
|
||||
'required': False,
|
||||
'schema': {
|
||||
'all': {'type': 'boolean', 'required': False},
|
||||
'mutual_guilds': {'type': 'boolean', 'required': False},
|
||||
'mutual_friends': {'type': 'boolean', 'required': False},
|
||||
}
|
||||
"afk_timeout": {"type": "number", "required": False, "min": 0, "max": 3000},
|
||||
"animate_emoji": {"type": "boolean", "required": False},
|
||||
"convert_emoticons": {"type": "boolean", "required": False},
|
||||
"default_guilds_restricted": {"type": "boolean", "required": False},
|
||||
"detect_platform_accounts": {"type": "boolean", "required": False},
|
||||
"developer_mode": {"type": "boolean", "required": False},
|
||||
"disable_games_tab": {"type": "boolean", "required": False},
|
||||
"enable_tts_command": {"type": "boolean", "required": False},
|
||||
"explicit_content_filter": {"type": "explicit", "required": False},
|
||||
"friend_source": {
|
||||
"type": "dict",
|
||||
"required": False,
|
||||
"schema": {
|
||||
"all": {"type": "boolean", "required": False},
|
||||
"mutual_guilds": {"type": "boolean", "required": False},
|
||||
"mutual_friends": {"type": "boolean", "required": False},
|
||||
},
|
||||
},
|
||||
'guild_positions': {
|
||||
'type': 'list',
|
||||
'required': False,
|
||||
'schema': {'type': 'snowflake'}
|
||||
"guild_positions": {
|
||||
"type": "list",
|
||||
"required": False,
|
||||
"schema": {"type": "snowflake"},
|
||||
},
|
||||
'restricted_guilds': {
|
||||
'type': 'list',
|
||||
'required': False,
|
||||
'schema': {'type': 'snowflake'}
|
||||
"restricted_guilds": {
|
||||
"type": "list",
|
||||
"required": False,
|
||||
"schema": {"type": "snowflake"},
|
||||
},
|
||||
|
||||
'gif_auto_play': {'type': 'boolean', 'required': False},
|
||||
'inline_attachment_media': {'type': 'boolean', 'required': False},
|
||||
'inline_embed_media': {'type': 'boolean', 'required': False},
|
||||
'message_display_compact': {'type': 'boolean', 'required': False},
|
||||
'render_embeds': {'type': 'boolean', 'required': False},
|
||||
'render_reactions': {'type': 'boolean', 'required': False},
|
||||
'show_current_game': {'type': 'boolean', 'required': False},
|
||||
|
||||
'timezone_offset': {'type': 'number', 'required': False},
|
||||
|
||||
'status': {'type': 'status_external', 'required': False},
|
||||
'theme': {'type': 'theme', 'required': False}
|
||||
"gif_auto_play": {"type": "boolean", "required": False},
|
||||
"inline_attachment_media": {"type": "boolean", "required": False},
|
||||
"inline_embed_media": {"type": "boolean", "required": False},
|
||||
"message_display_compact": {"type": "boolean", "required": False},
|
||||
"render_embeds": {"type": "boolean", "required": False},
|
||||
"render_reactions": {"type": "boolean", "required": False},
|
||||
"show_current_game": {"type": "boolean", "required": False},
|
||||
"timezone_offset": {"type": "number", "required": False},
|
||||
"status": {"type": "status_external", "required": False},
|
||||
"theme": {"type": "theme", "required": False},
|
||||
}
|
||||
|
||||
RELATIONSHIP = {
|
||||
'type': {
|
||||
'type': 'rel_type',
|
||||
'required': False,
|
||||
'default': RelationshipType.FRIEND.value
|
||||
"type": {
|
||||
"type": "rel_type",
|
||||
"required": False,
|
||||
"default": RelationshipType.FRIEND.value,
|
||||
}
|
||||
}
|
||||
|
||||
CREATE_DM = {
|
||||
'recipient_id': {
|
||||
'type': 'snowflake',
|
||||
'required': True
|
||||
}
|
||||
}
|
||||
CREATE_DM = {"recipient_id": {"type": "snowflake", "required": True}}
|
||||
|
||||
CREATE_GROUP_DM = {
|
||||
'recipients': {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'schema': {'type': 'snowflake'}
|
||||
},
|
||||
"recipients": {"type": "list", "required": True, "schema": {"type": "snowflake"}}
|
||||
}
|
||||
|
||||
GROUP_DM_UPDATE = {
|
||||
'name': {
|
||||
'type': 'guild_name',
|
||||
'required': False
|
||||
},
|
||||
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
||||
"name": {"type": "guild_name", "required": False},
|
||||
"icon": {"type": "b64_icon", "required": False, "nullable": True},
|
||||
}
|
||||
|
||||
SPECIFIC_FRIEND = {
|
||||
'username': {'type': 'username'},
|
||||
'discriminator': {'type': 'discriminator'}
|
||||
"username": {"type": "username"},
|
||||
"discriminator": {"type": "discriminator"},
|
||||
}
|
||||
|
||||
GUILD_SETTINGS_CHAN_OVERRIDE = {
|
||||
'type': 'dict',
|
||||
'schema': {
|
||||
'muted': {
|
||||
'type': 'boolean', 'required': False},
|
||||
'message_notifications': {
|
||||
'type': 'msg_notifications',
|
||||
'required': False,
|
||||
}
|
||||
}
|
||||
"type": "dict",
|
||||
"schema": {
|
||||
"muted": {"type": "boolean", "required": False},
|
||||
"message_notifications": {"type": "msg_notifications", "required": False},
|
||||
},
|
||||
}
|
||||
|
||||
GUILD_SETTINGS = {
|
||||
'channel_overrides': {
|
||||
'type': 'dict',
|
||||
'valueschema': GUILD_SETTINGS_CHAN_OVERRIDE,
|
||||
'keyschema': {'type': 'snowflake'},
|
||||
'required': False,
|
||||
"channel_overrides": {
|
||||
"type": "dict",
|
||||
"valueschema": GUILD_SETTINGS_CHAN_OVERRIDE,
|
||||
"keyschema": {"type": "snowflake"},
|
||||
"required": False,
|
||||
},
|
||||
'suppress_everyone': {
|
||||
'type': 'boolean', 'required': False},
|
||||
'muted': {
|
||||
'type': 'boolean', 'required': False},
|
||||
'mobile_push': {
|
||||
'type': 'boolean', 'required': False},
|
||||
'message_notifications': {
|
||||
'type': 'msg_notifications',
|
||||
'required': False,
|
||||
}
|
||||
"suppress_everyone": {"type": "boolean", "required": False},
|
||||
"muted": {"type": "boolean", "required": False},
|
||||
"mobile_push": {"type": "boolean", "required": False},
|
||||
"message_notifications": {"type": "msg_notifications", "required": False},
|
||||
}
|
||||
|
||||
GUILD_PRUNE = {
|
||||
'days': {'type': 'number', 'coerce': int, 'min': 1, 'max': 30, 'default': 7},
|
||||
'compute_prune_count': {'type': 'string', 'default': 'true'}
|
||||
"days": {"type": "number", "coerce": int, "min": 1, "max": 30, "default": 7},
|
||||
"compute_prune_count": {"type": "string", "default": "true"},
|
||||
}
|
||||
|
||||
NEW_EMOJI = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True},
|
||||
'image': {'type': 'b64_icon', 'required': True},
|
||||
'roles': {'type': 'list', 'schema': {'coerce': int}}
|
||||
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
|
||||
"image": {"type": "b64_icon", "required": True},
|
||||
"roles": {"type": "list", "schema": {"coerce": int}},
|
||||
}
|
||||
|
||||
PATCH_EMOJI = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True},
|
||||
'roles': {'type': 'list', 'schema': {'coerce': int}}
|
||||
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
|
||||
"roles": {"type": "list", "schema": {"coerce": int}},
|
||||
}
|
||||
|
||||
|
||||
SEARCH_CHANNEL = {
|
||||
'content': {'type': 'string', 'minlength': 1, 'required': True},
|
||||
'include_nsfw': {'coerce': bool, 'default': False},
|
||||
'offset': {'coerce': int, 'default': 0}
|
||||
"content": {"type": "string", "minlength": 1, "required": True},
|
||||
"include_nsfw": {"coerce": bool, "default": False},
|
||||
"offset": {"coerce": int, "default": 0},
|
||||
}
|
||||
|
||||
|
||||
GET_MENTIONS = {
|
||||
'limit': {'coerce': int, 'default': 25},
|
||||
'roles': {'coerce': bool, 'default': True},
|
||||
'everyone': {'coerce': bool, 'default': True},
|
||||
'guild_id': {'coerce': int, 'required': False}
|
||||
"limit": {"coerce": int, "default": 25},
|
||||
"roles": {"coerce": bool, "default": True},
|
||||
"everyone": {"coerce": bool, "default": True},
|
||||
"guild_id": {"coerce": int, "required": False},
|
||||
}
|
||||
|
||||
|
||||
VANITY_URL_PATCH = {
|
||||
# TODO: put proper values in maybe an invite data type
|
||||
'code': {'type': 'string', 'minlength': 5, 'maxlength': 30}
|
||||
"code": {"type": "string", "minlength": 5, "maxlength": 30}
|
||||
}
|
||||
|
||||
WEBHOOK_CREATE = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 2, 'maxlength': 32,
|
||||
'required': True
|
||||
},
|
||||
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False}
|
||||
"name": {"type": "string", "minlength": 2, "maxlength": 32, "required": True},
|
||||
"avatar": {"type": "b64_icon", "required": False, "nullable": False},
|
||||
}
|
||||
|
||||
WEBHOOK_UPDATE = {
|
||||
'name': {
|
||||
'type': 'string', 'minlength': 2, 'maxlength': 32,
|
||||
'required': False
|
||||
},
|
||||
|
||||
"name": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
|
||||
# TODO: check if its b64_icon or string since the client
|
||||
# could pass an icon hash instead.
|
||||
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False},
|
||||
'channel_id': {'coerce': int, 'required': False, 'nullable': False}
|
||||
"avatar": {"type": "b64_icon", "required": False, "nullable": False},
|
||||
"channel_id": {"coerce": int, "required": False, "nullable": False},
|
||||
}
|
||||
|
||||
WEBHOOK_MESSAGE_CREATE = {
|
||||
'content': {
|
||||
'type': 'string',
|
||||
'minlength': 0, 'maxlength': 2000, 'required': False
|
||||
"content": {"type": "string", "minlength": 0, "maxlength": 2000, "required": False},
|
||||
"tts": {"type": "boolean", "required": False},
|
||||
"username": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
|
||||
"avatar_url": {"coerce": EmbedURL, "required": False},
|
||||
"embeds": {
|
||||
"type": "list",
|
||||
"required": False,
|
||||
"schema": {"type": "dict", "schema": EMBED_OBJECT},
|
||||
},
|
||||
'tts': {'type': 'boolean', 'required': False},
|
||||
|
||||
'username': {
|
||||
'type': 'string',
|
||||
'minlength': 2, 'maxlength': 32, 'required': False
|
||||
},
|
||||
|
||||
'avatar_url': {
|
||||
'coerce': EmbedURL, 'required': False
|
||||
},
|
||||
|
||||
'embeds': {
|
||||
'type': 'list',
|
||||
'required': False,
|
||||
'schema': {'type': 'dict', 'schema': EMBED_OBJECT}
|
||||
}
|
||||
}
|
||||
|
||||
BULK_DELETE = {
|
||||
'messages': {
|
||||
'type': 'list', 'required': True,
|
||||
'minlength': 2, 'maxlength': 100,
|
||||
'schema': {'coerce': int}
|
||||
"messages": {
|
||||
"type": "list",
|
||||
"required": True,
|
||||
"minlength": 2,
|
||||
"maxlength": 100,
|
||||
"schema": {"coerce": int},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,19 +61,19 @@ def _snowflake(timestamp: int) -> Snowflake:
|
|||
# bits 0-12 encode _generated_ids (size 12)
|
||||
|
||||
# modulo'd to prevent overflows
|
||||
genid_b = '{0:012b}'.format(_generated_ids % 4096)
|
||||
genid_b = "{0:012b}".format(_generated_ids % 4096)
|
||||
|
||||
# bits 12-17 encode PROCESS_ID (size 5)
|
||||
procid_b = '{0:05b}'.format(PROCESS_ID)
|
||||
procid_b = "{0:05b}".format(PROCESS_ID)
|
||||
|
||||
# bits 17-22 encode WORKER_ID (size 5)
|
||||
workid_b = '{0:05b}'.format(WORKER_ID)
|
||||
workid_b = "{0:05b}".format(WORKER_ID)
|
||||
|
||||
# bits 22-64 encode (timestamp - EPOCH) (size 42)
|
||||
epochized = timestamp - EPOCH
|
||||
epoch_b = '{0:042b}'.format(epochized)
|
||||
epoch_b = "{0:042b}".format(epochized)
|
||||
|
||||
snowflake_b = f'{epoch_b}{workid_b}{procid_b}{genid_b}'
|
||||
snowflake_b = f"{epoch_b}{workid_b}{procid_b}{genid_b}"
|
||||
_generated_ids += 1
|
||||
|
||||
return int(snowflake_b, 2)
|
||||
|
|
@ -87,7 +87,7 @@ def snowflake_time(snowflake: Snowflake) -> float:
|
|||
# the total size for a snowflake is 64 bits,
|
||||
# considering it is a string, position 0 to 42 will give us
|
||||
# the `epochized` variable
|
||||
snowflake_b = '{0:064b}'.format(snowflake)
|
||||
snowflake_b = "{0:064b}".format(snowflake)
|
||||
epochized_b = snowflake_b[:42]
|
||||
epochized = int(epochized_b, 2)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -24,6 +24,7 @@ from litecord.enums import MessageType
|
|||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
||||
"""Handle a message pin."""
|
||||
new_id = get_snowflake()
|
||||
|
|
@ -37,8 +38,10 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
|||
($1, $2, NULL, $3, NULL, '',
|
||||
$4)
|
||||
""",
|
||||
new_id, channel_id, author_id,
|
||||
MessageType.CHANNEL_PINNED_MESSAGE.value
|
||||
new_id,
|
||||
channel_id,
|
||||
author_id,
|
||||
MessageType.CHANNEL_PINNED_MESSAGE.value,
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
|
@ -56,15 +59,16 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
|
|||
VALUES
|
||||
($1, $2, $3, NULL, $4, $5)
|
||||
""",
|
||||
new_id, channel_id, author_id,
|
||||
f'<@{peer_id}>',
|
||||
MessageType.RECIPIENT_ADD.value
|
||||
new_id,
|
||||
channel_id,
|
||||
author_id,
|
||||
f"<@{peer_id}>",
|
||||
MessageType.RECIPIENT_ADD.value,
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
|
||||
|
||||
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
||||
new_id = get_snowflake()
|
||||
|
||||
|
|
@ -76,9 +80,11 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
|||
VALUES
|
||||
($1, $2, $3, NULL, $4, $5)
|
||||
""",
|
||||
new_id, channel_id, author_id,
|
||||
f'<@{peer_id}>',
|
||||
MessageType.RECIPIENT_REMOVE.value
|
||||
new_id,
|
||||
channel_id,
|
||||
author_id,
|
||||
f"<@{peer_id}>",
|
||||
MessageType.RECIPIENT_REMOVE.value,
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
|
@ -87,13 +93,16 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
|||
async def _handle_gdm_name_edit(app, channel_id, author_id):
|
||||
new_id = get_snowflake()
|
||||
|
||||
gdm_name = await app.db.fetchval("""
|
||||
gdm_name = await app.db.fetchval(
|
||||
"""
|
||||
SELECT name FROM group_dm_channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
""",
|
||||
channel_id,
|
||||
)
|
||||
|
||||
if not gdm_name:
|
||||
log.warning('no gdm name found for sys message')
|
||||
log.warning("no gdm name found for sys message")
|
||||
return
|
||||
|
||||
await app.db.execute(
|
||||
|
|
@ -104,9 +113,11 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
|
|||
VALUES
|
||||
($1, $2, $3, NULL, $4, $5)
|
||||
""",
|
||||
new_id, channel_id, author_id,
|
||||
new_id,
|
||||
channel_id,
|
||||
author_id,
|
||||
gdm_name,
|
||||
MessageType.CHANNEL_NAME_CHANGE.value
|
||||
MessageType.CHANNEL_NAME_CHANGE.value,
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
|
@ -123,16 +134,19 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
|
|||
VALUES
|
||||
($1, $2, $3, NULL, $4, $5)
|
||||
""",
|
||||
new_id, channel_id, author_id,
|
||||
'',
|
||||
MessageType.CHANNEL_ICON_CHANGE.value
|
||||
new_id,
|
||||
channel_id,
|
||||
author_id,
|
||||
"",
|
||||
MessageType.CHANNEL_ICON_CHANGE.value,
|
||||
)
|
||||
|
||||
return new_id
|
||||
|
||||
|
||||
async def send_sys_message(app, channel_id: int, m_type: MessageType,
|
||||
*args, **kwargs) -> int:
|
||||
async def send_sys_message(
|
||||
app, channel_id: int, m_type: MessageType, *args, **kwargs
|
||||
) -> int:
|
||||
"""Send a system message.
|
||||
|
||||
The handler for a given message type MUST return an integer, that integer
|
||||
|
|
@ -156,22 +170,19 @@ async def send_sys_message(app, channel_id: int, m_type: MessageType,
|
|||
try:
|
||||
handler = {
|
||||
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
|
||||
|
||||
# gdm specific
|
||||
MessageType.RECIPIENT_ADD: _handle_recp_add,
|
||||
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
|
||||
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
|
||||
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit
|
||||
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit,
|
||||
}[m_type]
|
||||
except KeyError:
|
||||
raise ValueError('Invalid system message type')
|
||||
raise ValueError("Invalid system message type")
|
||||
|
||||
message_id = await handler(app, channel_id, *args, **kwargs)
|
||||
|
||||
message = await app.storage.get_message(message_id)
|
||||
|
||||
await app.dispatcher.dispatch(
|
||||
'channel', channel_id, 'MESSAGE_CREATE', message
|
||||
)
|
||||
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
|
||||
|
||||
return message_id
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ HOURS = 60 * MINUTES
|
|||
|
||||
class Color:
|
||||
"""Custom color class"""
|
||||
|
||||
def __init__(self, val: int):
|
||||
self.blue = val & 255
|
||||
self.green = (val >> 8) & 255
|
||||
|
|
@ -37,7 +38,7 @@ class Color:
|
|||
@property
|
||||
def value(self):
|
||||
"""Give the actual RGB integer encoding this color."""
|
||||
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16)
|
||||
return int("%02x%02x%02x" % (self.red, self.green, self.blue), 16)
|
||||
|
||||
@property
|
||||
def to_json(self):
|
||||
|
|
@ -49,4 +50,4 @@ class Color:
|
|||
|
||||
def timestamp_(dt) -> Optional[str]:
|
||||
"""safer version for dt.isoformat()"""
|
||||
return f'{dt.isoformat()}+00:00' if dt else None
|
||||
return f"{dt.isoformat()}+00:00" if dt else None
|
||||
|
|
|
|||
|
|
@ -27,43 +27,52 @@ log = Logger(__name__)
|
|||
|
||||
class UserStorage:
|
||||
"""Storage functions related to a single user."""
|
||||
|
||||
def __init__(self, storage):
|
||||
self.storage = storage
|
||||
self.db = storage.db
|
||||
|
||||
async def fetch_notes(self, user_id: int) -> dict:
|
||||
"""Fetch a users' notes"""
|
||||
note_rows = await self.db.fetch("""
|
||||
note_rows = await self.db.fetch(
|
||||
"""
|
||||
SELECT target_id, note
|
||||
FROM notes
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return {str(row['target_id']): row['note']
|
||||
for row in note_rows}
|
||||
return {str(row["target_id"]): row["note"] for row in note_rows}
|
||||
|
||||
async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
|
||||
"""Get current user settings."""
|
||||
row = await self.storage.fetchrow_with_json("""
|
||||
row = await self.storage.fetchrow_with_json(
|
||||
"""
|
||||
SELECT *
|
||||
FROM user_settings
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
if not row:
|
||||
log.info('Generating user settings for {}', user_id)
|
||||
log.info("Generating user settings for {}", user_id)
|
||||
|
||||
await self.db.execute("""
|
||||
await self.db.execute(
|
||||
"""
|
||||
INSERT INTO user_settings (id)
|
||||
VALUES ($1)
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
# recalling get_user_settings
|
||||
# should work after adding
|
||||
return await self.get_user_settings(user_id)
|
||||
|
||||
drow = dict(row)
|
||||
drow.pop('id')
|
||||
drow.pop("id")
|
||||
return drow
|
||||
|
||||
async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
|
|
@ -76,11 +85,15 @@ class UserStorage:
|
|||
_outgoing = RelationshipType.OUTGOING.value
|
||||
|
||||
# check all outgoing friends
|
||||
friends = await self.db.fetch("""
|
||||
friends = await self.db.fetch(
|
||||
"""
|
||||
SELECT user_id, peer_id, rel_type
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND rel_type = $2
|
||||
""", user_id, _friend)
|
||||
""",
|
||||
user_id,
|
||||
_friend,
|
||||
)
|
||||
friends = list(map(dict, friends))
|
||||
|
||||
# mutuals is a list of ints
|
||||
|
|
@ -95,66 +108,80 @@ class UserStorage:
|
|||
SELECT user_id, peer_id
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
||||
""", row['peer_id'], row['user_id'],
|
||||
_friend)
|
||||
""",
|
||||
row["peer_id"],
|
||||
row["user_id"],
|
||||
_friend,
|
||||
)
|
||||
|
||||
if is_friend is not None:
|
||||
mutuals.append(row['peer_id'])
|
||||
mutuals.append(row["peer_id"])
|
||||
|
||||
# fetch friend requests directed at us
|
||||
incoming_friends = await self.db.fetch("""
|
||||
incoming_friends = await self.db.fetch(
|
||||
"""
|
||||
SELECT user_id, peer_id
|
||||
FROM relationships
|
||||
WHERE peer_id = $1 AND rel_type = $2
|
||||
""", user_id, _friend)
|
||||
""",
|
||||
user_id,
|
||||
_friend,
|
||||
)
|
||||
|
||||
# only need their ids
|
||||
incoming_friends = [r['user_id'] for r in incoming_friends
|
||||
if r['user_id'] not in mutuals]
|
||||
incoming_friends = [
|
||||
r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals
|
||||
]
|
||||
|
||||
# only fetch blocks we did,
|
||||
# not fetching the ones people did to us
|
||||
blocks = await self.db.fetch("""
|
||||
blocks = await self.db.fetch(
|
||||
"""
|
||||
SELECT user_id, peer_id, rel_type
|
||||
FROM relationships
|
||||
WHERE user_id = $1 AND rel_type = $2
|
||||
""", user_id, _block)
|
||||
""",
|
||||
user_id,
|
||||
_block,
|
||||
)
|
||||
blocks = list(map(dict, blocks))
|
||||
|
||||
res = []
|
||||
|
||||
for drow in friends:
|
||||
drow['type'] = drow['rel_type']
|
||||
drow['id'] = str(drow['peer_id'])
|
||||
drow.pop('rel_type')
|
||||
drow["type"] = drow["rel_type"]
|
||||
drow["id"] = str(drow["peer_id"])
|
||||
drow.pop("rel_type")
|
||||
|
||||
# check if the receiver is a mutual
|
||||
# if it isnt, its still on a friend request stage
|
||||
if drow['peer_id'] not in mutuals:
|
||||
drow['type'] = _outgoing
|
||||
if drow["peer_id"] not in mutuals:
|
||||
drow["type"] = _outgoing
|
||||
|
||||
drow['user'] = await self.storage.get_user(drow['peer_id'])
|
||||
drow["user"] = await self.storage.get_user(drow["peer_id"])
|
||||
|
||||
drow.pop('user_id')
|
||||
drow.pop('peer_id')
|
||||
drow.pop("user_id")
|
||||
drow.pop("peer_id")
|
||||
res.append(drow)
|
||||
|
||||
for peer_id in incoming_friends:
|
||||
res.append({
|
||||
'id': str(peer_id),
|
||||
'user': await self.storage.get_user(peer_id),
|
||||
'type': _incoming,
|
||||
})
|
||||
res.append(
|
||||
{
|
||||
"id": str(peer_id),
|
||||
"user": await self.storage.get_user(peer_id),
|
||||
"type": _incoming,
|
||||
}
|
||||
)
|
||||
|
||||
for drow in blocks:
|
||||
drow['type'] = drow['rel_type']
|
||||
drow.pop('rel_type')
|
||||
drow["type"] = drow["rel_type"]
|
||||
drow.pop("rel_type")
|
||||
|
||||
drow['id'] = str(drow['peer_id'])
|
||||
drow['user'] = await self.storage.get_user(drow['peer_id'])
|
||||
drow["id"] = str(drow["peer_id"])
|
||||
drow["user"] = await self.storage.get_user(drow["peer_id"])
|
||||
|
||||
drow.pop('user_id')
|
||||
drow.pop('peer_id')
|
||||
drow.pop("user_id")
|
||||
drow.pop("peer_id")
|
||||
res.append(drow)
|
||||
|
||||
return res
|
||||
|
|
@ -163,9 +190,11 @@ class UserStorage:
|
|||
"""Get all friend IDs for a user."""
|
||||
rels = await self.get_relationships(user_id)
|
||||
|
||||
return [int(r['user']['id'])
|
||||
for r in rels
|
||||
if r['type'] == RelationshipType.FRIEND.value]
|
||||
return [
|
||||
int(r["user"]["id"])
|
||||
for r in rels
|
||||
if r["type"] == RelationshipType.FRIEND.value
|
||||
]
|
||||
|
||||
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get all DM channels for a user, including group DMs.
|
||||
|
|
@ -173,13 +202,16 @@ class UserStorage:
|
|||
This will only fetch channels the user has in their state,
|
||||
which is different than the whole list of DM channels.
|
||||
"""
|
||||
dm_ids = await self.db.fetch("""
|
||||
dm_ids = await self.db.fetch(
|
||||
"""
|
||||
SELECT dm_id
|
||||
FROM dm_channel_state
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
dm_ids = [r['dm_id'] for r in dm_ids]
|
||||
dm_ids = [r["dm_id"] for r in dm_ids]
|
||||
|
||||
res = []
|
||||
|
||||
|
|
@ -191,21 +223,24 @@ class UserStorage:
|
|||
|
||||
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get the read state for a user."""
|
||||
rows = await self.db.fetch("""
|
||||
rows = await self.db.fetch(
|
||||
"""
|
||||
SELECT channel_id, last_message_id, mention_count
|
||||
FROM user_read_state
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
res = []
|
||||
|
||||
for row in rows:
|
||||
drow = dict(row)
|
||||
|
||||
drow['id'] = str(drow['channel_id'])
|
||||
drow.pop('channel_id')
|
||||
drow["id"] = str(drow["channel_id"])
|
||||
drow.pop("channel_id")
|
||||
|
||||
drow['last_message_id'] = str(drow['last_message_id'])
|
||||
drow["last_message_id"] = str(drow["last_message_id"])
|
||||
|
||||
res.append(drow)
|
||||
|
||||
|
|
@ -214,13 +249,17 @@ class UserStorage:
|
|||
async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List:
|
||||
chan_overrides = []
|
||||
|
||||
overrides = await self.db.fetch("""
|
||||
overrides = await self.db.fetch(
|
||||
"""
|
||||
SELECT channel_id::text, muted, message_notifications
|
||||
FROM guild_settings_channel_overrides
|
||||
WHERE
|
||||
user_id = $1
|
||||
AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
for chan_row in overrides:
|
||||
dcrow = dict(chan_row)
|
||||
|
|
@ -228,30 +267,35 @@ class UserStorage:
|
|||
|
||||
return chan_overrides
|
||||
|
||||
async def get_guild_settings_one(self, user_id: int,
|
||||
guild_id: int) -> dict:
|
||||
async def get_guild_settings_one(self, user_id: int, guild_id: int) -> dict:
|
||||
"""Get guild settings information for a single guild."""
|
||||
row = await self.db.fetchrow("""
|
||||
row = await self.db.fetchrow(
|
||||
"""
|
||||
SELECT guild_id::text, suppress_everyone, muted,
|
||||
message_notifications, mobile_push
|
||||
FROM guild_settings
|
||||
WHERE user_id = $1 AND guild_id = $2
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
if not row:
|
||||
await self.db.execute("""
|
||||
await self.db.execute(
|
||||
"""
|
||||
INSERT INTO guild_settings (user_id, guild_id)
|
||||
VALUES ($1, $2)
|
||||
""", user_id, guild_id)
|
||||
""",
|
||||
user_id,
|
||||
guild_id,
|
||||
)
|
||||
|
||||
return await self.get_guild_settings_one(user_id, guild_id)
|
||||
|
||||
gid = int(row['guild_id'])
|
||||
gid = int(row["guild_id"])
|
||||
drow = dict(row)
|
||||
chan_overrides = await self._get_chan_overrides(user_id, gid)
|
||||
return {**drow, **{
|
||||
'channel_overrides': chan_overrides
|
||||
}}
|
||||
return {**drow, **{"channel_overrides": chan_overrides}}
|
||||
|
||||
async def get_guild_settings(self, user_id: int):
|
||||
"""Get the specific User Guild Settings,
|
||||
|
|
@ -259,34 +303,38 @@ class UserStorage:
|
|||
|
||||
res = []
|
||||
|
||||
settings = await self.db.fetch("""
|
||||
settings = await self.db.fetch(
|
||||
"""
|
||||
SELECT guild_id::text, suppress_everyone, muted,
|
||||
message_notifications, mobile_push
|
||||
FROM guild_settings
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
for row in settings:
|
||||
gid = int(row['guild_id'])
|
||||
gid = int(row["guild_id"])
|
||||
drow = dict(row)
|
||||
|
||||
chan_overrides = await self._get_chan_overrides(user_id, gid)
|
||||
|
||||
res.append({**drow, **{
|
||||
'channel_overrides': chan_overrides
|
||||
}})
|
||||
res.append({**drow, **{"channel_overrides": chan_overrides}})
|
||||
|
||||
return res
|
||||
|
||||
async def get_user_guilds(self, user_id: int) -> List[int]:
|
||||
"""Get all guild IDs a user is on."""
|
||||
guild_ids = await self.db.fetch("""
|
||||
guild_ids = await self.db.fetch(
|
||||
"""
|
||||
SELECT guild_id
|
||||
FROM members
|
||||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return [row['guild_id'] for row in guild_ids]
|
||||
return [row["guild_id"] for row in guild_ids]
|
||||
|
||||
async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]:
|
||||
"""Get a list of guilds two separate users
|
||||
|
|
@ -301,13 +349,17 @@ class UserStorage:
|
|||
|
||||
return await self.get_user_guilds(user_id) or [0]
|
||||
|
||||
mutual_guilds = await self.db.fetch("""
|
||||
mutual_guilds = await self.db.fetch(
|
||||
"""
|
||||
SELECT guild_id FROM members WHERE user_id = $1
|
||||
INTERSECT
|
||||
SELECT guild_id FROM members WHERE user_id = $2
|
||||
""", user_id, peer_id)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
)
|
||||
|
||||
mutual_guilds = [r['guild_id'] for r in mutual_guilds]
|
||||
mutual_guilds = [r["guild_id"] for r in mutual_guilds]
|
||||
|
||||
return mutual_guilds
|
||||
|
||||
|
|
@ -316,7 +368,8 @@ class UserStorage:
|
|||
|
||||
This returns false even if there is a friend request.
|
||||
"""
|
||||
return await self.db.fetchval("""
|
||||
return await self.db.fetchval(
|
||||
"""
|
||||
SELECT
|
||||
(
|
||||
SELECT EXISTS(
|
||||
|
|
@ -337,17 +390,23 @@ class UserStorage:
|
|||
AND rel_type = 1
|
||||
)
|
||||
)
|
||||
""", user_id, peer_id)
|
||||
""",
|
||||
user_id,
|
||||
peer_id,
|
||||
)
|
||||
|
||||
async def get_gdms_internal(self, user_id) -> List[int]:
|
||||
"""Return a list of Group DM IDs the user is a member of."""
|
||||
rows = await self.db.fetch("""
|
||||
rows = await self.db.fetch(
|
||||
"""
|
||||
SELECT id
|
||||
FROM group_dm_members
|
||||
WHERE member_id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
return [r['id'] for r in rows]
|
||||
return [r["id"] for r in rows]
|
||||
|
||||
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
|
||||
"""Get list of group DMs a user is in."""
|
||||
|
|
@ -356,8 +415,6 @@ class UserStorage:
|
|||
res = []
|
||||
|
||||
for gdm_id in gdm_ids:
|
||||
res.append(
|
||||
await self.storage.get_channel(gdm_id, user_id=user_id)
|
||||
)
|
||||
res.append(await self.storage.get_channel(gdm_id, user_id=user_id))
|
||||
|
||||
return res
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ async def task_wrapper(name: str, coro):
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
except:
|
||||
log.exception('{} task error', name)
|
||||
log.exception("{} task error", name)
|
||||
|
||||
|
||||
def dict_get(mapping, key, default):
|
||||
|
|
@ -84,54 +84,66 @@ def mmh3(inp_str: str, seed: int = 0):
|
|||
h1 = seed
|
||||
|
||||
# mm3 constants
|
||||
c1 = 0xcc9e2d51
|
||||
c2 = 0x1b873593
|
||||
c1 = 0xCC9E2D51
|
||||
c2 = 0x1B873593
|
||||
i = 0
|
||||
|
||||
while i < bytecount:
|
||||
k1 = (
|
||||
(key[i] & 0xff) |
|
||||
((key[i + 1] & 0xff) << 8) |
|
||||
((key[i + 2] & 0xff) << 16) |
|
||||
((key[i + 3] & 0xff) << 24)
|
||||
(key[i] & 0xFF)
|
||||
| ((key[i + 1] & 0xFF) << 8)
|
||||
| ((key[i + 2] & 0xFF) << 16)
|
||||
| ((key[i + 3] & 0xFF) << 24)
|
||||
)
|
||||
|
||||
i += 4
|
||||
|
||||
k1 = ((((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16))) & 0xffffffff
|
||||
k1 = (
|
||||
(((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16))
|
||||
) & 0xFFFFFFFF
|
||||
k1 = (k1 << 15) | (_u(k1) >> 17)
|
||||
k1 = ((((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16))) & 0xffffffff;
|
||||
k1 = (
|
||||
(((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16))
|
||||
) & 0xFFFFFFFF
|
||||
|
||||
h1 ^= k1
|
||||
h1 = (h1 << 13) | (_u(h1) >> 19);
|
||||
h1b = ((((h1 & 0xffff) * 5) + ((((_u(h1) >> 16) * 5) & 0xffff) << 16))) & 0xffffffff;
|
||||
h1 = (((h1b & 0xffff) + 0x6b64) + ((((_u(h1b) >> 16) + 0xe654) & 0xffff) << 16))
|
||||
|
||||
h1 = (h1 << 13) | (_u(h1) >> 19)
|
||||
h1b = (
|
||||
(((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16))
|
||||
) & 0xFFFFFFFF
|
||||
h1 = ((h1b & 0xFFFF) + 0x6B64) + ((((_u(h1b) >> 16) + 0xE654) & 0xFFFF) << 16)
|
||||
|
||||
k1 = 0
|
||||
v = None
|
||||
|
||||
if remainder == 3:
|
||||
v = (key[i + 2] & 0xff) << 16
|
||||
v = (key[i + 2] & 0xFF) << 16
|
||||
elif remainder == 2:
|
||||
v = (key[i + 1] & 0xff) << 8
|
||||
v = (key[i + 1] & 0xFF) << 8
|
||||
elif remainder == 1:
|
||||
v = (key[i] & 0xff)
|
||||
v = key[i] & 0xFF
|
||||
|
||||
if v is not None:
|
||||
k1 ^= v
|
||||
|
||||
k1 = (((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16)) & 0xffffffff
|
||||
k1 = (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) & 0xFFFFFFFF
|
||||
k1 = (k1 << 15) | (_u(k1) >> 17)
|
||||
k1 = (((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16)) & 0xffffffff
|
||||
k1 = (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) & 0xFFFFFFFF
|
||||
h1 ^= k1
|
||||
|
||||
h1 ^= len(key)
|
||||
|
||||
h1 ^= _u(h1) >> 16
|
||||
h1 = (((h1 & 0xffff) * 0x85ebca6b) + ((((_u(h1) >> 16) * 0x85ebca6b) & 0xffff) << 16)) & 0xffffffff
|
||||
h1 = (
|
||||
((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16)
|
||||
) & 0xFFFFFFFF
|
||||
h1 ^= _u(h1) >> 13
|
||||
h1 = ((((h1 & 0xffff) * 0xc2b2ae35) + ((((_u(h1) >> 16) * 0xc2b2ae35) & 0xffff) << 16))) & 0xffffffff
|
||||
h1 = (
|
||||
(
|
||||
((h1 & 0xFFFF) * 0xC2B2AE35)
|
||||
+ ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16)
|
||||
)
|
||||
) & 0xFFFFFFFF
|
||||
h1 ^= _u(h1) >> 16
|
||||
|
||||
return _u(h1) >> 0
|
||||
|
|
@ -139,6 +151,7 @@ def mmh3(inp_str: str, seed: int = 0):
|
|||
|
||||
class LitecordJSONEncoder(JSONEncoder):
|
||||
"""Custom JSON encoder for Litecord."""
|
||||
|
||||
def default(self, value: Any):
|
||||
"""By default, this will try to get the to_json attribute of a given
|
||||
value being JSON encoded."""
|
||||
|
|
@ -151,17 +164,17 @@ class LitecordJSONEncoder(JSONEncoder):
|
|||
async def pg_set_json(con):
|
||||
"""Set JSON and JSONB codecs for an asyncpg connection."""
|
||||
await con.set_type_codec(
|
||||
'json',
|
||||
"json",
|
||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||
decoder=json.loads,
|
||||
schema='pg_catalog'
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
await con.set_type_codec(
|
||||
'jsonb',
|
||||
"jsonb",
|
||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||
decoder=json.loads,
|
||||
schema='pg_catalog'
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -177,7 +190,8 @@ def yield_chunks(input_list: Sequence[Any], chunk_size: int):
|
|||
# range accepts step param, so we use that to
|
||||
# make the chunks
|
||||
for idx in range(0, len(input_list), chunk_size):
|
||||
yield input_list[idx:idx + chunk_size]
|
||||
yield input_list[idx : idx + chunk_size]
|
||||
|
||||
|
||||
def to_update(j: dict, orig: dict, field: str) -> bool:
|
||||
"""Compare values to check if j[field] is actually updating
|
||||
|
|
@ -193,27 +207,23 @@ async def search_result_from_list(rows: List) -> Dict[str, Any]:
|
|||
- An int (?) on `total_results`
|
||||
- Two bigint[], each on `before` and `after` respectively.
|
||||
"""
|
||||
results = 0 if not rows else rows[0]['total_results']
|
||||
results = 0 if not rows else rows[0]["total_results"]
|
||||
res = []
|
||||
|
||||
for row in rows:
|
||||
before, after = [], []
|
||||
|
||||
for before_id in reversed(row['before']):
|
||||
for before_id in reversed(row["before"]):
|
||||
before.append(await app.storage.get_message(before_id))
|
||||
|
||||
for after_id in row['after']:
|
||||
for after_id in row["after"]:
|
||||
after.append(await app.storage.get_message(after_id))
|
||||
|
||||
msg = await app.storage.get_message(row['current_id'])
|
||||
msg['hit'] = True
|
||||
msg = await app.storage.get_message(row["current_id"])
|
||||
msg["hit"] = True
|
||||
res.append(before + [msg] + after)
|
||||
|
||||
return {
|
||||
'total_results': results,
|
||||
'messages': res,
|
||||
'analytics_id': '',
|
||||
}
|
||||
return {"total_results": results, "messages": res, "analytics_id": ""}
|
||||
|
||||
|
||||
def maybe_int(val: Any) -> Union[int, Any]:
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ log = Logger(__name__)
|
|||
|
||||
class LVSPConnection:
|
||||
"""Represents a single LVSP connection."""
|
||||
|
||||
def __init__(self, lvsp, region: str, hostname: str):
|
||||
self.lvsp = lvsp
|
||||
self.app = lvsp.app
|
||||
|
|
@ -46,7 +47,7 @@ class LVSPConnection:
|
|||
|
||||
@property
|
||||
def _log_id(self):
|
||||
return f'region={self.region} hostname={self.hostname}'
|
||||
return f"region={self.region} hostname={self.hostname}"
|
||||
|
||||
async def send(self, payload):
|
||||
"""Send a payload down the websocket."""
|
||||
|
|
@ -61,50 +62,42 @@ class LVSPConnection:
|
|||
|
||||
async def send_op(self, opcode: int, data: dict):
|
||||
"""Send a message with an OP code included"""
|
||||
await self.send({
|
||||
'op': opcode,
|
||||
'd': data
|
||||
})
|
||||
await self.send({"op": opcode, "d": data})
|
||||
|
||||
async def send_info(self, info_type: str, info_data: Dict):
|
||||
"""Send an INFO message down the websocket."""
|
||||
await self.send({
|
||||
'op': OP.info,
|
||||
'd': {
|
||||
'type': InfoTable[info_type.upper()],
|
||||
'data': info_data
|
||||
await self.send(
|
||||
{
|
||||
"op": OP.info,
|
||||
"d": {"type": InfoTable[info_type.upper()], "data": info_data},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
async def _heartbeater(self, hb_interval: int):
|
||||
try:
|
||||
await asyncio.sleep(hb_interval)
|
||||
|
||||
# TODO: add self._seq
|
||||
await self.send_op(OP.heartbeat, {
|
||||
's': 0
|
||||
})
|
||||
await self.send_op(OP.heartbeat, {"s": 0})
|
||||
|
||||
# give the server 300 milliseconds to reply.
|
||||
await asyncio.sleep(300)
|
||||
await self.conn.close(4000, 'heartbeat timeout')
|
||||
await self.conn.close(4000, "heartbeat timeout")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def _start_hb(self):
|
||||
self._hb_task = self.app.loop.create_task(
|
||||
self._heartbeater(self._hb_interval)
|
||||
)
|
||||
self._hb_task = self.app.loop.create_task(self._heartbeater(self._hb_interval))
|
||||
|
||||
def _stop_hb(self):
|
||||
self._hb_task.cancel()
|
||||
|
||||
async def _handle_0(self, msg):
|
||||
"""Handle HELLO message."""
|
||||
data = msg['d']
|
||||
data = msg["d"]
|
||||
|
||||
# nonce = data['nonce']
|
||||
self._hb_interval = data['heartbeat_interval']
|
||||
self._hb_interval = data["heartbeat_interval"]
|
||||
|
||||
# TODO: send identify
|
||||
|
||||
|
|
@ -112,48 +105,52 @@ class LVSPConnection:
|
|||
"""Update the health value of a given voice server."""
|
||||
self.health = new_health
|
||||
|
||||
await self.app.db.execute("""
|
||||
await self.app.db.execute(
|
||||
"""
|
||||
UPDATE voice_servers
|
||||
SET health = $1
|
||||
WHERE hostname = $2
|
||||
""", new_health, self.hostname)
|
||||
""",
|
||||
new_health,
|
||||
self.hostname,
|
||||
)
|
||||
|
||||
async def _handle_3(self, msg):
|
||||
"""Handle READY message.
|
||||
|
||||
We only start heartbeating after READY.
|
||||
"""
|
||||
await self._update_health(msg['health'])
|
||||
await self._update_health(msg["health"])
|
||||
self._start_hb()
|
||||
|
||||
async def _handle_5(self, msg):
|
||||
"""Handle HEARTBEAT_ACK."""
|
||||
self._stop_hb()
|
||||
await self._update_health(msg['health'])
|
||||
await self._update_health(msg["health"])
|
||||
self._start_hb()
|
||||
|
||||
async def _handle_6(self, msg):
|
||||
"""Handle INFO messages."""
|
||||
info = msg['d']
|
||||
info_type_str = InfoReverse[info['type']].lower()
|
||||
info = msg["d"]
|
||||
info_type_str = InfoReverse[info["type"]].lower()
|
||||
|
||||
try:
|
||||
info_handler = getattr(self, f'_handle_info_{info_type_str}')
|
||||
info_handler = getattr(self, f"_handle_info_{info_type_str}")
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
await info_handler(info['data'])
|
||||
await info_handler(info["data"])
|
||||
|
||||
async def _handle_info_channel_assign(self, data: dict):
|
||||
"""called by the server once we got a channel assign."""
|
||||
try:
|
||||
channel_id = data['channel_id']
|
||||
channel_id = data["channel_id"]
|
||||
channel_id = int(channel_id)
|
||||
except (TypeError, ValueError):
|
||||
return
|
||||
|
||||
try:
|
||||
guild_id = data['guild_id']
|
||||
guild_id = data["guild_id"]
|
||||
guild_id = int(guild_id)
|
||||
except (TypeError, ValueError):
|
||||
guild_id = None
|
||||
|
|
@ -166,19 +163,19 @@ class LVSPConnection:
|
|||
msg = await self.recv()
|
||||
|
||||
try:
|
||||
opcode = msg['op']
|
||||
handler = getattr(self, f'_handle_{opcode}')
|
||||
opcode = msg["op"]
|
||||
handler = getattr(self, f"_handle_{opcode}")
|
||||
await handler(msg)
|
||||
except (KeyError, AttributeError):
|
||||
# TODO: error codes in LVSP
|
||||
raise Exception('invalid op code')
|
||||
raise Exception("invalid op code")
|
||||
|
||||
async def start(self):
|
||||
"""Try to start a websocket connection."""
|
||||
try:
|
||||
self.conn = await websockets.connect(f'wss://{self.hostname}')
|
||||
self.conn = await websockets.connect(f"wss://{self.hostname}")
|
||||
except Exception:
|
||||
log.exception('failed to start lvsp conn to {}', self.hostname)
|
||||
log.exception("failed to start lvsp conn to {}", self.hostname)
|
||||
|
||||
async def run(self):
|
||||
"""Start the websocket."""
|
||||
|
|
@ -186,15 +183,15 @@ class LVSPConnection:
|
|||
|
||||
try:
|
||||
if not self.conn:
|
||||
log.error('failed to start lvsp connection, stopping')
|
||||
log.error("failed to start lvsp connection, stopping")
|
||||
return
|
||||
|
||||
await self._loop()
|
||||
except websockets.exceptions.ConnectionClosed as err:
|
||||
log.warning('conn close, {}, err={}', self._log_id, err)
|
||||
log.warning("conn close, {}, err={}", self._log_id, err)
|
||||
# except WebsocketClose as err:
|
||||
# log.warning('ws close, state={} err={}', self.state, err)
|
||||
# await self.conn.close(code=err.code, reason=err.reason)
|
||||
except Exception as err:
|
||||
log.exception('An exception has occoured. {}', self._log_id)
|
||||
log.exception("An exception has occoured. {}", self._log_id)
|
||||
await self.conn.close(code=4000, reason=repr(err))
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ log = Logger(__name__)
|
|||
@dataclass
|
||||
class Region:
|
||||
"""Voice region data."""
|
||||
|
||||
id: str
|
||||
vip: bool
|
||||
|
||||
|
|
@ -40,6 +41,7 @@ class LVSPManager:
|
|||
|
||||
Spawns :class:`LVSPConnection` as needed, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, app, voice):
|
||||
self.app = app
|
||||
self.voice = voice
|
||||
|
|
@ -61,49 +63,50 @@ class LVSPManager:
|
|||
async def _spawn(self):
|
||||
"""Spawn LVSPConnection for each region."""
|
||||
|
||||
regions = await self.app.db.fetch("""
|
||||
regions = await self.app.db.fetch(
|
||||
"""
|
||||
SELECT id, vip
|
||||
FROM voice_regions
|
||||
WHERE deprecated = false
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
regions = [Region(r['id'], r['vip']) for r in regions]
|
||||
regions = [Region(r["id"], r["vip"]) for r in regions]
|
||||
|
||||
if not regions:
|
||||
log.warning('no regions are setup')
|
||||
log.warning("no regions are setup")
|
||||
return
|
||||
|
||||
for region in regions:
|
||||
# store it locally for region() function
|
||||
self.regions[region.id] = region
|
||||
|
||||
self.app.loop.create_task(
|
||||
self._spawn_region(region)
|
||||
)
|
||||
self.app.loop.create_task(self._spawn_region(region))
|
||||
|
||||
async def _spawn_region(self, region: Region):
|
||||
"""Spawn a region. Involves fetching all the hostnames
|
||||
for the regions and spawning a LVSPConnection for each."""
|
||||
servers = await self.app.db.fetch("""
|
||||
servers = await self.app.db.fetch(
|
||||
"""
|
||||
SELECT hostname
|
||||
FROM voice_servers
|
||||
WHERE region_id = $1
|
||||
""", region.id)
|
||||
""",
|
||||
region.id,
|
||||
)
|
||||
|
||||
if not servers:
|
||||
log.warning('region {} does not have servers', region)
|
||||
log.warning("region {} does not have servers", region)
|
||||
return
|
||||
|
||||
servers = [r['hostname'] for r in servers]
|
||||
servers = [r["hostname"] for r in servers]
|
||||
self.servers[region.id] = servers
|
||||
|
||||
for hostname in servers:
|
||||
conn = LVSPConnection(self, region.id, hostname)
|
||||
self.conns[hostname] = conn
|
||||
|
||||
self.app.loop.create_task(
|
||||
conn.run()
|
||||
)
|
||||
self.app.loop.create_task(conn.run())
|
||||
|
||||
async def del_conn(self, conn):
|
||||
"""Delete a connection from the connection pool."""
|
||||
|
|
@ -119,11 +122,14 @@ class LVSPManager:
|
|||
|
||||
async def guild_region(self, guild_id: int) -> Optional[str]:
|
||||
"""Return the voice region of a guild."""
|
||||
return await self.app.db.fetchval("""
|
||||
return await self.app.db.fetchval(
|
||||
"""
|
||||
SELECT region
|
||||
FROM guilds
|
||||
WHERE id = $1
|
||||
""", guild_id)
|
||||
""",
|
||||
guild_id,
|
||||
)
|
||||
|
||||
def get_health(self, hostname: str) -> float:
|
||||
"""Get voice server health, given hostname."""
|
||||
|
|
@ -144,10 +150,7 @@ class LVSPManager:
|
|||
region = await self.guild_region(guild_id)
|
||||
|
||||
# sort connected servers by health
|
||||
sorted_servers = sorted(
|
||||
self.servers[region],
|
||||
key=self.get_health
|
||||
)
|
||||
sorted_servers = sorted(self.servers[region], key=self.get_health)
|
||||
|
||||
try:
|
||||
hostname = sorted_servers[0]
|
||||
|
|
|
|||
|
|
@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
|
||||
class OPCodes:
|
||||
"""LVSP OP codes."""
|
||||
|
||||
hello = 0
|
||||
identify = 1
|
||||
resume = 2
|
||||
|
|
@ -29,13 +31,13 @@ class OPCodes:
|
|||
|
||||
|
||||
InfoTable = {
|
||||
'CHANNEL_REQ': 0,
|
||||
'CHANNEL_ASSIGN': 1,
|
||||
'CHANNEL_UPDATE': 2,
|
||||
'CHANNEL_DESTROY': 3,
|
||||
'VST_CREATE': 4,
|
||||
'VST_UPDATE': 5,
|
||||
'VST_LEAVE': 6,
|
||||
"CHANNEL_REQ": 0,
|
||||
"CHANNEL_ASSIGN": 1,
|
||||
"CHANNEL_UPDATE": 2,
|
||||
"CHANNEL_DESTROY": 3,
|
||||
"VST_CREATE": 4,
|
||||
"VST_UPDATE": 5,
|
||||
"VST_LEAVE": 6,
|
||||
}
|
||||
|
||||
InfoReverse = {v: k for k, v in InfoTable.items()}
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ def _construct_state(state_dict: dict) -> VoiceState:
|
|||
|
||||
class VoiceManager:
|
||||
"""Main voice manager class."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
|
|
@ -56,7 +57,7 @@ class VoiceManager:
|
|||
"""Return if a user can join a channel."""
|
||||
|
||||
channel = await self.app.storage.get_channel(channel_id)
|
||||
ctype = ChannelType(channel['type'])
|
||||
ctype = ChannelType(channel["type"])
|
||||
|
||||
if ctype not in VOICE_CHANNELS:
|
||||
return
|
||||
|
|
@ -65,14 +66,12 @@ class VoiceManager:
|
|||
|
||||
# get_permissions returns ALL_PERMISSIONS when
|
||||
# the channel isn't from a guild
|
||||
perms = await get_permissions(
|
||||
user_id, channel_id, storage=self.app.storage
|
||||
)
|
||||
perms = await get_permissions(user_id, channel_id, storage=self.app.storage)
|
||||
|
||||
# hacky user_limit but should work, as channels not
|
||||
# in guilds won't have that field.
|
||||
is_full = states >= channel.get('user_limit', 100)
|
||||
is_bot = (await self.app.storage.get_user(user_id))['bot']
|
||||
is_full = states >= channel.get("user_limit", 100)
|
||||
is_bot = (await self.app.storage.get_user(user_id))["bot"]
|
||||
is_manager = perms.bits.manage_channels
|
||||
|
||||
# if the channel is full AND:
|
||||
|
|
@ -140,8 +139,8 @@ class VoiceManager:
|
|||
|
||||
for field in prop:
|
||||
# NOTE: this should not happen, ever.
|
||||
if field in ('channel_id', 'user_id'):
|
||||
raise ValueError('properties are updating channel or user')
|
||||
if field in ("channel_id", "user_id"):
|
||||
raise ValueError("properties are updating channel or user")
|
||||
|
||||
new_state_dict[field] = prop[field]
|
||||
|
||||
|
|
@ -153,27 +152,28 @@ class VoiceManager:
|
|||
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
||||
"""Move a user between channels."""
|
||||
await self.del_state(old_voice_key)
|
||||
await self.create_state(old_voice_key, {'channel_id': channel_id})
|
||||
await self.create_state(old_voice_key, {"channel_id": channel_id})
|
||||
|
||||
async def _lvsp_info_guild(self, guild_id, info_type, info_data):
|
||||
hostname = await self.lvsp.get_guild_server(guild_id)
|
||||
if hostname is None:
|
||||
log.error('no voice server for guild id {}', guild_id)
|
||||
log.error("no voice server for guild id {}", guild_id)
|
||||
return
|
||||
|
||||
conn = self.lvsp.get_conn(hostname)
|
||||
await conn.send_info(info_type, info_data)
|
||||
|
||||
async def _create_ctx_guild(self, guild_id, channel_id):
|
||||
await self._lvsp_info_guild(guild_id, 'CHANNEL_REQ', {
|
||||
'guild_id': str(guild_id),
|
||||
'channel_id': str(channel_id),
|
||||
})
|
||||
await self._lvsp_info_guild(
|
||||
guild_id,
|
||||
"CHANNEL_REQ",
|
||||
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||
)
|
||||
|
||||
async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
|
||||
"""Start a voice context in a guild."""
|
||||
user_id, guild_id = voice_key
|
||||
channel_id = int(data['channel_id'])
|
||||
channel_id = int(data["channel_id"])
|
||||
|
||||
existing_states = self.states[voice_key]
|
||||
channel_exists = any(
|
||||
|
|
@ -183,11 +183,15 @@ class VoiceManager:
|
|||
if not channel_exists:
|
||||
await self._create_ctx_guild(guild_id, channel_id)
|
||||
|
||||
await self._lvsp_info_guild(guild_id, 'VST_CREATE', {
|
||||
'user_id': str(user_id),
|
||||
'guild_id': str(guild_id),
|
||||
'channel_id': str(channel_id),
|
||||
})
|
||||
await self._lvsp_info_guild(
|
||||
guild_id,
|
||||
"VST_CREATE",
|
||||
{
|
||||
"user_id": str(user_id),
|
||||
"guild_id": str(guild_id),
|
||||
"channel_id": str(channel_id),
|
||||
},
|
||||
)
|
||||
|
||||
async def create_state(self, voice_key: VoiceKey, data: dict):
|
||||
"""Creates (or tries to create) a voice state.
|
||||
|
|
@ -249,10 +253,13 @@ class VoiceManager:
|
|||
|
||||
async def voice_server_list(self, region: str) -> List[dict]:
|
||||
"""Get a list of voice server objects"""
|
||||
rows = await self.app.db.fetch("""
|
||||
rows = await self.app.db.fetch(
|
||||
"""
|
||||
SELECT hostname, last_health
|
||||
FROM voice_servers
|
||||
WHERE region_id = $1
|
||||
""", region)
|
||||
""",
|
||||
region,
|
||||
)
|
||||
|
||||
return list(map(dict, rows))
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict
|
|||
@dataclass
|
||||
class VoiceState:
|
||||
"""Represents a voice state."""
|
||||
|
||||
guild_id: int
|
||||
channel_id: int
|
||||
user_id: int
|
||||
|
|
@ -55,7 +56,7 @@ class VoiceState:
|
|||
|
||||
# a better approach would be actually using
|
||||
# the suppressed_by field for backend efficiency.
|
||||
self_dict['suppress'] = user_id == self.suppressed_by
|
||||
self_dict.pop('suppressed_by')
|
||||
self_dict["suppress"] = user_id == self.suppressed_by
|
||||
self_dict.pop("suppressed_by")
|
||||
|
||||
return self_dict
|
||||
|
|
|
|||
|
|
@ -27,5 +27,5 @@ import config
|
|||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(config))
|
||||
|
|
|
|||
|
|
@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
|
|||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
|
|||
|
||||
async def _gen_inv() -> str:
|
||||
"""Generate an invite code"""
|
||||
return ''.join(choice(ALPHABET) for _ in range(6))
|
||||
return "".join(choice(ALPHABET) for _ in range(6))
|
||||
|
||||
|
||||
async def gen_inv(ctx) -> str:
|
||||
|
|
@ -34,11 +34,14 @@ async def gen_inv(ctx) -> str:
|
|||
for _ in range(10):
|
||||
possible_inv = await _gen_inv()
|
||||
|
||||
created_at = await ctx.db.fetchval("""
|
||||
created_at = await ctx.db.fetchval(
|
||||
"""
|
||||
SELECT created_at
|
||||
FROM instance_invites
|
||||
WHERE code = $1
|
||||
""", possible_inv)
|
||||
""",
|
||||
possible_inv,
|
||||
)
|
||||
|
||||
if created_at is None:
|
||||
return possible_inv
|
||||
|
|
@ -51,27 +54,32 @@ async def make_inv(ctx, args):
|
|||
|
||||
max_uses = args.max_uses
|
||||
|
||||
await ctx.db.execute("""
|
||||
await ctx.db.execute(
|
||||
"""
|
||||
INSERT INTO instance_invites (code, max_uses)
|
||||
VALUES ($1, $2)
|
||||
""", code, max_uses)
|
||||
""",
|
||||
code,
|
||||
max_uses,
|
||||
)
|
||||
|
||||
print(f'invite created with {max_uses} max uses', code)
|
||||
print(f"invite created with {max_uses} max uses", code)
|
||||
|
||||
|
||||
async def list_invs(ctx, args):
|
||||
rows = await ctx.db.fetch("""
|
||||
rows = await ctx.db.fetch(
|
||||
"""
|
||||
SELECT code, created_at, uses, max_uses
|
||||
FROM instance_invites
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
print(len(rows), 'invites')
|
||||
print(len(rows), "invites")
|
||||
|
||||
for row in rows:
|
||||
max_uses = row['max_uses']
|
||||
delta = datetime.datetime.utcnow() - row['created_at']
|
||||
usage = ('infinite uses' if max_uses == -1
|
||||
else f'{row["uses"]} / {max_uses}')
|
||||
max_uses = row["max_uses"]
|
||||
delta = datetime.datetime.utcnow() - row["created_at"]
|
||||
usage = "infinite uses" if max_uses == -1 else f'{row["uses"]} / {max_uses}'
|
||||
|
||||
print(f'\t{row["code"]}, {usage}, made {delta} ago')
|
||||
|
||||
|
|
@ -79,40 +87,37 @@ async def list_invs(ctx, args):
|
|||
async def delete_inv(ctx, args):
|
||||
inv = args.invite_code
|
||||
|
||||
res = await ctx.db.execute("""
|
||||
res = await ctx.db.execute(
|
||||
"""
|
||||
DELETE FROM instance_invites
|
||||
WHERE code = $1
|
||||
""", inv)
|
||||
""",
|
||||
inv,
|
||||
)
|
||||
|
||||
if res == 'DELETE 0':
|
||||
print('NOT FOUND')
|
||||
if res == "DELETE 0":
|
||||
print("NOT FOUND")
|
||||
return
|
||||
|
||||
print('OK')
|
||||
print("OK")
|
||||
|
||||
|
||||
def setup(subparser):
|
||||
makeinv_parser = subparser.add_parser(
|
||||
'makeinv',
|
||||
help='create an invite',
|
||||
)
|
||||
makeinv_parser = subparser.add_parser("makeinv", help="create an invite")
|
||||
|
||||
makeinv_parser.add_argument(
|
||||
'max_uses', nargs='?', type=int, default=-1,
|
||||
help='Maximum amount of uses before the invite is unavailable',
|
||||
"max_uses",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum amount of uses before the invite is unavailable",
|
||||
)
|
||||
|
||||
makeinv_parser.set_defaults(func=make_inv)
|
||||
|
||||
listinv_parser = subparser.add_parser(
|
||||
'listinv',
|
||||
help='list all invites',
|
||||
)
|
||||
listinv_parser = subparser.add_parser("listinv", help="list all invites")
|
||||
listinv_parser.set_defaults(func=list_invs)
|
||||
|
||||
delinv_parser = subparser.add_parser(
|
||||
'delinv',
|
||||
help='delete an invite',
|
||||
)
|
||||
delinv_parser.add_argument('invite_code')
|
||||
delinv_parser = subparser.add_parser("delinv", help="delete an invite")
|
||||
delinv_parser.add_argument("invite_code")
|
||||
delinv_parser.set_defaults(func=delete_inv)
|
||||
|
|
|
|||
|
|
@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
from .command import setup as migration
|
||||
|
||||
__all__ = ['migration']
|
||||
__all__ = ["migration"]
|
||||
|
|
|
|||
|
|
@ -32,18 +32,19 @@ from logbook import Logger
|
|||
log = Logger(__name__)
|
||||
|
||||
|
||||
Migration = namedtuple('Migration', 'id name path')
|
||||
Migration = namedtuple("Migration", "id name path")
|
||||
|
||||
# line of change, 4 april 2019, at 1am (gmt+0)
|
||||
BREAK = datetime.datetime(2019, 4, 4, 1)
|
||||
|
||||
# if a database has those tables, it ran 0_base.sql.
|
||||
HAS_BASE = ['users', 'guilds', 'e']
|
||||
HAS_BASE = ["users", "guilds", "e"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationContext:
|
||||
"""Hold information about migration."""
|
||||
|
||||
migration_folder: Path
|
||||
scripts: Dict[int, Migration]
|
||||
|
||||
|
|
@ -60,22 +61,21 @@ def make_migration_ctx() -> MigrationContext:
|
|||
script_folder = os.sep.join(script_path.split(os.sep)[:-1])
|
||||
script_folder = Path(script_folder)
|
||||
|
||||
migration_folder = script_folder / 'scripts'
|
||||
migration_folder = script_folder / "scripts"
|
||||
|
||||
mctx = MigrationContext(migration_folder, {})
|
||||
|
||||
for mig_path in migration_folder.glob('*.sql'):
|
||||
for mig_path in migration_folder.glob("*.sql"):
|
||||
mig_path_str = str(mig_path)
|
||||
|
||||
# extract migration script id and name
|
||||
mig_filename = mig_path_str.split(os.sep)[-1].split('.')[0]
|
||||
name_fragments = mig_filename.split('_')
|
||||
mig_filename = mig_path_str.split(os.sep)[-1].split(".")[0]
|
||||
name_fragments = mig_filename.split("_")
|
||||
|
||||
mig_id = int(name_fragments[0])
|
||||
mig_name = '_'.join(name_fragments[1:])
|
||||
mig_name = "_".join(name_fragments[1:])
|
||||
|
||||
mctx.scripts[mig_id] = Migration(
|
||||
mig_id, mig_name, mig_path)
|
||||
mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path)
|
||||
|
||||
return mctx
|
||||
|
||||
|
|
@ -83,7 +83,8 @@ def make_migration_ctx() -> MigrationContext:
|
|||
async def _ensure_changelog(app, ctx):
|
||||
# make sure we have the migration table up
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
CREATE TABLE migration_log (
|
||||
change_num bigint NOT NULL,
|
||||
|
||||
|
|
@ -94,43 +95,56 @@ async def _ensure_changelog(app, ctx):
|
|||
|
||||
PRIMARY KEY (change_num)
|
||||
);
|
||||
""")
|
||||
"""
|
||||
)
|
||||
except asyncpg.DuplicateTableError:
|
||||
log.debug('existing migration table')
|
||||
log.debug("existing migration table")
|
||||
|
||||
# NOTE: this is a migration breakage,
|
||||
# only applying to databases that had their first migration
|
||||
# before 4 april 2019 (more on BREAK)
|
||||
|
||||
# if migration_log is empty, just assume this is new
|
||||
first = await app.db.fetchval("""
|
||||
first = (
|
||||
await app.db.fetchval(
|
||||
"""
|
||||
SELECT apply_ts FROM migration_log
|
||||
ORDER BY apply_ts ASC
|
||||
LIMIT 1
|
||||
""") or BREAK
|
||||
"""
|
||||
)
|
||||
or BREAK
|
||||
)
|
||||
if first < BREAK:
|
||||
log.info('deleting migration_log due to migration structure change')
|
||||
log.info("deleting migration_log due to migration structure change")
|
||||
await app.db.execute("DROP TABLE migration_log")
|
||||
await _ensure_changelog(app, ctx)
|
||||
|
||||
|
||||
async def _insert_log(app, migration_id: int, description) -> bool:
|
||||
try:
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO migration_log (change_num, description)
|
||||
VALUES ($1, $2)
|
||||
""", migration_id, description)
|
||||
""",
|
||||
migration_id,
|
||||
description,
|
||||
)
|
||||
|
||||
return True
|
||||
except asyncpg.UniqueViolationError:
|
||||
log.warning('already inserted {}', migration_id)
|
||||
log.warning("already inserted {}", migration_id)
|
||||
return False
|
||||
|
||||
|
||||
async def _delete_log(app, migration_id: int):
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
DELETE FROM migration_log WHERE change_num = $1
|
||||
""", migration_id)
|
||||
""",
|
||||
migration_id,
|
||||
)
|
||||
|
||||
|
||||
async def apply_migration(app, migration: Migration) -> bool:
|
||||
|
|
@ -144,21 +158,20 @@ async def apply_migration(app, migration: Migration) -> bool:
|
|||
|
||||
Returns a boolean signaling if this failed or not.
|
||||
"""
|
||||
migration_sql = migration.path.read_text(encoding='utf-8')
|
||||
migration_sql = migration.path.read_text(encoding="utf-8")
|
||||
|
||||
res = await _insert_log(
|
||||
app, migration.id, f'migration: {migration.name}')
|
||||
res = await _insert_log(app, migration.id, f"migration: {migration.name}")
|
||||
|
||||
if not res:
|
||||
return False
|
||||
|
||||
try:
|
||||
await app.db.execute(migration_sql)
|
||||
log.info('applied {} {}', migration.id, migration.name)
|
||||
log.info("applied {} {}", migration.id, migration.name)
|
||||
|
||||
return True
|
||||
except:
|
||||
log.exception('failed to run migration, rollbacking log')
|
||||
log.exception("failed to run migration, rollbacking log")
|
||||
await _delete_log(app, migration.id)
|
||||
|
||||
return False
|
||||
|
|
@ -169,9 +182,11 @@ async def _check_base(app) -> bool:
|
|||
file."""
|
||||
try:
|
||||
for table in HAS_BASE:
|
||||
await app.db.execute(f"""
|
||||
await app.db.execute(
|
||||
f"""
|
||||
SELECT * FROM {table} LIMIT 0
|
||||
""")
|
||||
"""
|
||||
)
|
||||
except asyncpg.UndefinedTableError:
|
||||
return False
|
||||
|
||||
|
|
@ -197,14 +212,16 @@ async def migrate_cmd(app, _args):
|
|||
has_base = await _check_base(app)
|
||||
|
||||
# fetch latest local migration that has been run on this database
|
||||
local_change = await app.db.fetchval("""
|
||||
local_change = await app.db.fetchval(
|
||||
"""
|
||||
SELECT max(change_num)
|
||||
FROM migration_log
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
# if base exists, add it to logs, if not, apply (and add to logs)
|
||||
if has_base:
|
||||
await _insert_log(app, 0, 'migration setup (from existing)')
|
||||
await _insert_log(app, 0, "migration setup (from existing)")
|
||||
else:
|
||||
await apply_migration(app, ctx.scripts[0])
|
||||
|
||||
|
|
@ -215,10 +232,10 @@ async def migrate_cmd(app, _args):
|
|||
local_change = local_change or 0
|
||||
latest_change = ctx.latest
|
||||
|
||||
log.debug('local: {}, latest: {}', local_change, latest_change)
|
||||
log.debug("local: {}, latest: {}", local_change, latest_change)
|
||||
|
||||
if local_change == latest_change:
|
||||
print('no changes to do, exiting')
|
||||
print("no changes to do, exiting")
|
||||
return
|
||||
|
||||
# we do local_change + 1 so we start from the
|
||||
|
|
@ -227,15 +244,13 @@ async def migrate_cmd(app, _args):
|
|||
for idx in range(local_change + 1, latest_change + 1):
|
||||
migration = ctx.scripts.get(idx)
|
||||
|
||||
print('applying', migration.id, migration.name)
|
||||
print("applying", migration.id, migration.name)
|
||||
await apply_migration(app, migration)
|
||||
|
||||
|
||||
def setup(subparser):
|
||||
migrate_parser = subparser.add_parser(
|
||||
'migrate',
|
||||
help='Run migration tasks',
|
||||
description=migrate_cmd.__doc__
|
||||
"migrate", help="Run migration tasks", description=migrate_cmd.__doc__
|
||||
)
|
||||
|
||||
migrate_parser.set_defaults(func=migrate_cmd)
|
||||
|
|
|
|||
|
|
@ -24,39 +24,51 @@ from litecord.enums import UserFlags
|
|||
|
||||
async def find_user(username, discrim, ctx) -> int:
|
||||
"""Get a user ID via the username/discrim pair."""
|
||||
return await ctx.db.fetchval("""
|
||||
return await ctx.db.fetchval(
|
||||
"""
|
||||
SELECT id
|
||||
FROM users
|
||||
WHERE username = $1 AND discriminator = $2
|
||||
""", username, discrim)
|
||||
""",
|
||||
username,
|
||||
discrim,
|
||||
)
|
||||
|
||||
|
||||
async def set_user_staff(user_id, ctx):
|
||||
"""Give a single user staff status."""
|
||||
old_flags = await ctx.db.fetchval("""
|
||||
old_flags = await ctx.db.fetchval(
|
||||
"""
|
||||
SELECT flags
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
new_flags = old_flags | UserFlags.staff
|
||||
|
||||
await ctx.db.execute("""
|
||||
await ctx.db.execute(
|
||||
"""
|
||||
UPDATE users
|
||||
SET flags=$1
|
||||
WHERE id = $2
|
||||
""", new_flags, user_id)
|
||||
""",
|
||||
new_flags,
|
||||
user_id,
|
||||
)
|
||||
|
||||
|
||||
async def adduser(ctx, args):
|
||||
"""Create a single user."""
|
||||
uid, _ = await create_user(args.username, args.email,
|
||||
args.password, ctx.db, ctx.loop)
|
||||
uid, _ = await create_user(
|
||||
args.username, args.email, args.password, ctx.db, ctx.loop
|
||||
)
|
||||
|
||||
user = await ctx.storage.get_user(uid)
|
||||
|
||||
print('created!')
|
||||
print(f'\tuid: {uid}')
|
||||
print("created!")
|
||||
print(f"\tuid: {uid}")
|
||||
print(f'\tusername: {user["username"]}')
|
||||
print(f'\tdiscrim: {user["discriminator"]}')
|
||||
|
||||
|
|
@ -72,22 +84,26 @@ async def make_staff(ctx, args):
|
|||
uid = await find_user(args.username, args.discrim, ctx)
|
||||
|
||||
if not uid:
|
||||
return print('user not found')
|
||||
return print("user not found")
|
||||
|
||||
await set_user_staff(uid, ctx)
|
||||
print('OK: set staff')
|
||||
print("OK: set staff")
|
||||
|
||||
|
||||
async def generate_bot_token(ctx, args):
|
||||
"""Generate a token for specified bot."""
|
||||
|
||||
password_hash = await ctx.db.fetchval("""
|
||||
password_hash = await ctx.db.fetchval(
|
||||
"""
|
||||
SELECT password_hash
|
||||
FROM users
|
||||
WHERE id = $1 AND bot = 'true'
|
||||
""", int(args.user_id))
|
||||
""",
|
||||
int(args.user_id),
|
||||
)
|
||||
|
||||
if not password_hash:
|
||||
return print('cannot find a bot with specified id')
|
||||
return print("cannot find a bot with specified id")
|
||||
|
||||
print(make_token(args.user_id, password_hash))
|
||||
|
||||
|
|
@ -97,7 +113,7 @@ async def del_user(ctx, args):
|
|||
uid = await find_user(args.username, args.discrim, ctx)
|
||||
|
||||
if uid is None:
|
||||
print('user not found')
|
||||
print("user not found")
|
||||
return
|
||||
|
||||
user = await ctx.storage.get_user(uid)
|
||||
|
|
@ -106,57 +122,48 @@ async def del_user(ctx, args):
|
|||
print(f'\tuname: {user["username"]}')
|
||||
print(f'\tdiscrim: {user["discriminator"]}')
|
||||
|
||||
print('\n you sure you want to delete user? press Y (uppercase)')
|
||||
print("\n you sure you want to delete user? press Y (uppercase)")
|
||||
confirm = input()
|
||||
|
||||
if confirm != 'Y':
|
||||
print('not confirmed')
|
||||
if confirm != "Y":
|
||||
print("not confirmed")
|
||||
return
|
||||
|
||||
await delete_user(uid, app_=ctx)
|
||||
print('ok')
|
||||
print("ok")
|
||||
|
||||
|
||||
def setup(subparser):
|
||||
setup_test_parser = subparser.add_parser(
|
||||
'adduser',
|
||||
help='create a user',
|
||||
)
|
||||
setup_test_parser = subparser.add_parser("adduser", help="create a user")
|
||||
|
||||
setup_test_parser.add_argument(
|
||||
'username', help='username of the user')
|
||||
setup_test_parser.add_argument(
|
||||
'email', help='email of the user')
|
||||
setup_test_parser.add_argument(
|
||||
'password', help='password of the user')
|
||||
setup_test_parser.add_argument("username", help="username of the user")
|
||||
setup_test_parser.add_argument("email", help="email of the user")
|
||||
setup_test_parser.add_argument("password", help="password of the user")
|
||||
|
||||
setup_test_parser.set_defaults(func=adduser)
|
||||
|
||||
staff_parser = subparser.add_parser(
|
||||
'make_staff',
|
||||
help='make a user staff',
|
||||
description=make_staff.__doc__
|
||||
"make_staff", help="make a user staff", description=make_staff.__doc__
|
||||
)
|
||||
|
||||
staff_parser.add_argument('username')
|
||||
staff_parser.add_argument(
|
||||
'discrim', help='the discriminator of the user')
|
||||
staff_parser.add_argument("username")
|
||||
staff_parser.add_argument("discrim", help="the discriminator of the user")
|
||||
|
||||
staff_parser.set_defaults(func=make_staff)
|
||||
|
||||
del_user_parser = subparser.add_parser(
|
||||
'deluser', help='delete a single user')
|
||||
del_user_parser = subparser.add_parser("deluser", help="delete a single user")
|
||||
|
||||
del_user_parser.add_argument('username')
|
||||
del_user_parser.add_argument('discrim')
|
||||
del_user_parser.add_argument("username")
|
||||
del_user_parser.add_argument("discrim")
|
||||
|
||||
del_user_parser.set_defaults(func=del_user)
|
||||
|
||||
token_parser = subparser.add_parser(
|
||||
'generate_token',
|
||||
help='generate a token for specified bot',
|
||||
description=generate_bot_token.__doc__)
|
||||
"generate_token",
|
||||
help="generate a token for specified bot",
|
||||
description=generate_bot_token.__doc__,
|
||||
)
|
||||
|
||||
token_parser.add_argument('user_id')
|
||||
token_parser.add_argument("user_id")
|
||||
|
||||
token_parser.set_defaults(func=generate_bot_token)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ log = Logger(__name__)
|
|||
@dataclass
|
||||
class FakeApp:
|
||||
"""Fake app instance."""
|
||||
|
||||
config: dict
|
||||
db = None
|
||||
loop: asyncio.BaseEventLoop = None
|
||||
|
|
@ -50,7 +51,7 @@ class FakeApp:
|
|||
|
||||
def init_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
subparser = parser.add_subparsers(help='operations')
|
||||
subparser = parser.add_subparsers(help="operations")
|
||||
|
||||
migration(subparser)
|
||||
users.setup(subparser)
|
||||
|
|
@ -78,12 +79,12 @@ def main(config):
|
|||
# only init app managers when we aren't migrating
|
||||
# as the managers require it
|
||||
# and the migrate command also sets the db up
|
||||
if argv[1] != 'migrate':
|
||||
if argv[1] != "migrate":
|
||||
init_app_managers(app, voice=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
loop.run_until_complete(args.func(app, args))
|
||||
except Exception:
|
||||
log.exception('error while running command')
|
||||
log.exception("error while running command")
|
||||
finally:
|
||||
loop.run_until_complete(app.db.close())
|
||||
|
|
|
|||
236
run.py
236
run.py
|
|
@ -33,32 +33,51 @@ from aiohttp import ClientSession
|
|||
import config
|
||||
|
||||
from litecord.blueprints import (
|
||||
gateway, auth, users, guilds, channels, webhooks, science,
|
||||
voice, invites, relationships, dms, icons, nodeinfo, static,
|
||||
attachments, dm_channels
|
||||
gateway,
|
||||
auth,
|
||||
users,
|
||||
guilds,
|
||||
channels,
|
||||
webhooks,
|
||||
science,
|
||||
voice,
|
||||
invites,
|
||||
relationships,
|
||||
dms,
|
||||
icons,
|
||||
nodeinfo,
|
||||
static,
|
||||
attachments,
|
||||
dm_channels,
|
||||
)
|
||||
|
||||
# those blueprints are separated from the "main" ones
|
||||
# for code readability if people want to dig through
|
||||
# the codebase.
|
||||
from litecord.blueprints.guild import (
|
||||
guild_roles, guild_members, guild_channels, guild_mod,
|
||||
guild_emoji
|
||||
guild_roles,
|
||||
guild_members,
|
||||
guild_channels,
|
||||
guild_mod,
|
||||
guild_emoji,
|
||||
)
|
||||
|
||||
from litecord.blueprints.channel import (
|
||||
channel_messages, channel_reactions, channel_pins
|
||||
channel_messages,
|
||||
channel_reactions,
|
||||
channel_pins,
|
||||
)
|
||||
|
||||
from litecord.blueprints.user import (
|
||||
user_settings, user_billing, fake_store
|
||||
)
|
||||
from litecord.blueprints.user import user_settings, user_billing, fake_store
|
||||
|
||||
from litecord.blueprints.user.billing_job import payment_job
|
||||
|
||||
from litecord.blueprints.admin_api import (
|
||||
voice as voice_admin, features as features_admin,
|
||||
guilds as guilds_admin, users as users_admin, instance_invites
|
||||
voice as voice_admin,
|
||||
features as features_admin,
|
||||
guilds as guilds_admin,
|
||||
users as users_admin,
|
||||
instance_invites,
|
||||
)
|
||||
|
||||
from litecord.blueprints.admin_api.voice import guild_region_check
|
||||
|
|
@ -84,23 +103,23 @@ from litecord.utils import LitecordJSONEncoder
|
|||
# setup logbook
|
||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||
handler.push_application()
|
||||
log = Logger('litecord.boot')
|
||||
log = Logger("litecord.boot")
|
||||
redirect_logging()
|
||||
|
||||
|
||||
def make_app():
|
||||
app = Quart(__name__)
|
||||
app.config.from_object(f'config.{config.MODE}')
|
||||
is_debug = app.config.get('DEBUG', False)
|
||||
app.config.from_object(f"config.{config.MODE}")
|
||||
is_debug = app.config.get("DEBUG", False)
|
||||
app.debug = is_debug
|
||||
|
||||
if is_debug:
|
||||
log.info('on debug')
|
||||
log.info("on debug")
|
||||
handler.level = logbook.DEBUG
|
||||
app.logger.level = logbook.DEBUG
|
||||
|
||||
# always keep websockets on INFO
|
||||
logging.getLogger('websockets').setLevel(logbook.INFO)
|
||||
logging.getLogger("websockets").setLevel(logbook.INFO)
|
||||
|
||||
# use our custom json encoder for custom data types
|
||||
app.json_encoder = LitecordJSONEncoder
|
||||
|
|
@ -112,51 +131,44 @@ def set_blueprints(app_):
|
|||
"""Set the blueprints for a given app instance"""
|
||||
bps = {
|
||||
gateway: None,
|
||||
auth: '/auth',
|
||||
|
||||
users: '/users',
|
||||
user_settings: '/users',
|
||||
user_billing: '/users',
|
||||
relationships: '/users',
|
||||
|
||||
guilds: '/guilds',
|
||||
guild_roles: '/guilds',
|
||||
guild_members: '/guilds',
|
||||
guild_channels: '/guilds',
|
||||
guild_mod: '/guilds',
|
||||
guild_emoji: '/guilds',
|
||||
|
||||
channels: '/channels',
|
||||
channel_messages: '/channels',
|
||||
channel_reactions: '/channels',
|
||||
channel_pins: '/channels',
|
||||
|
||||
auth: "/auth",
|
||||
users: "/users",
|
||||
user_settings: "/users",
|
||||
user_billing: "/users",
|
||||
relationships: "/users",
|
||||
guilds: "/guilds",
|
||||
guild_roles: "/guilds",
|
||||
guild_members: "/guilds",
|
||||
guild_channels: "/guilds",
|
||||
guild_mod: "/guilds",
|
||||
guild_emoji: "/guilds",
|
||||
channels: "/channels",
|
||||
channel_messages: "/channels",
|
||||
channel_reactions: "/channels",
|
||||
channel_pins: "/channels",
|
||||
webhooks: None,
|
||||
science: None,
|
||||
voice: '/voice',
|
||||
voice: "/voice",
|
||||
invites: None,
|
||||
dms: '/users',
|
||||
dm_channels: '/channels',
|
||||
|
||||
dms: "/users",
|
||||
dm_channels: "/channels",
|
||||
fake_store: None,
|
||||
|
||||
icons: -1,
|
||||
attachments: -1,
|
||||
nodeinfo: -1,
|
||||
static: -1,
|
||||
|
||||
voice_admin: '/admin/voice',
|
||||
features_admin: '/admin/guilds',
|
||||
guilds_admin: '/admin/guilds',
|
||||
users_admin: '/admin/users',
|
||||
instance_invites: '/admin/instance/invites'
|
||||
voice_admin: "/admin/voice",
|
||||
features_admin: "/admin/guilds",
|
||||
guilds_admin: "/admin/guilds",
|
||||
users_admin: "/admin/users",
|
||||
instance_invites: "/admin/instance/invites",
|
||||
}
|
||||
|
||||
for bp, suffix in bps.items():
|
||||
url_prefix = f'/api/v6{suffix or ""}'
|
||||
|
||||
if suffix == -1:
|
||||
url_prefix = ''
|
||||
url_prefix = ""
|
||||
|
||||
app_.register_blueprint(bp, url_prefix=url_prefix)
|
||||
|
||||
|
|
@ -175,37 +187,35 @@ async def app_before_request():
|
|||
@app.after_request
|
||||
async def app_after_request(resp):
|
||||
"""Handle CORS headers."""
|
||||
origin = request.headers.get('Origin', '*')
|
||||
resp.headers['Access-Control-Allow-Origin'] = origin
|
||||
resp.headers['Access-Control-Allow-Headers'] = (
|
||||
'*, X-Super-Properties, '
|
||||
'X-Fingerprint, '
|
||||
'X-Context-Properties, '
|
||||
'X-Failed-Requests, '
|
||||
'X-Debug-Options, '
|
||||
'Content-Type, '
|
||||
'Authorization, '
|
||||
'Origin, '
|
||||
'If-None-Match'
|
||||
origin = request.headers.get("Origin", "*")
|
||||
resp.headers["Access-Control-Allow-Origin"] = origin
|
||||
resp.headers["Access-Control-Allow-Headers"] = (
|
||||
"*, X-Super-Properties, "
|
||||
"X-Fingerprint, "
|
||||
"X-Context-Properties, "
|
||||
"X-Failed-Requests, "
|
||||
"X-Debug-Options, "
|
||||
"Content-Type, "
|
||||
"Authorization, "
|
||||
"Origin, "
|
||||
"If-None-Match"
|
||||
)
|
||||
resp.headers['Access-Control-Allow-Methods'] = \
|
||||
resp.headers.get('allow', '*')
|
||||
resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*")
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def _set_rtl_reset(bucket, resp):
|
||||
reset = bucket._window + bucket.second
|
||||
precision = request.headers.get('x-ratelimit-precision', 'second')
|
||||
precision = request.headers.get("x-ratelimit-precision", "second")
|
||||
|
||||
if precision == 'second':
|
||||
resp.headers['X-RateLimit-Reset'] = str(round(reset))
|
||||
elif precision == 'millisecond':
|
||||
resp.headers['X-RateLimit-Reset'] = str(reset)
|
||||
if precision == "second":
|
||||
resp.headers["X-RateLimit-Reset"] = str(round(reset))
|
||||
elif precision == "millisecond":
|
||||
resp.headers["X-RateLimit-Reset"] = str(reset)
|
||||
else:
|
||||
resp.headers['X-RateLimit-Reset'] = (
|
||||
'Invalid X-RateLimit-Precision, '
|
||||
'valid options are (second, millisecond)'
|
||||
resp.headers["X-RateLimit-Reset"] = (
|
||||
"Invalid X-RateLimit-Precision, " "valid options are (second, millisecond)"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -218,15 +228,15 @@ async def app_set_ratelimit_headers(resp):
|
|||
if bucket is None:
|
||||
raise AttributeError()
|
||||
|
||||
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
|
||||
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
|
||||
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower()
|
||||
resp.headers["X-RateLimit-Limit"] = str(bucket.requests)
|
||||
resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens)
|
||||
resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower()
|
||||
_set_rtl_reset(bucket, resp)
|
||||
|
||||
# only add Retry-After if we actually hit a ratelimit
|
||||
retry_after = request.retry_after
|
||||
if request.retry_after:
|
||||
resp.headers['Retry-After'] = str(retry_after)
|
||||
resp.headers["Retry-After"] = str(retry_after)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
|
@ -238,8 +248,8 @@ async def init_app_db(app_):
|
|||
|
||||
Also spawns the job scheduler.
|
||||
"""
|
||||
log.info('db connect')
|
||||
app_.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
||||
log.info("db connect")
|
||||
app_.db = await asyncpg.create_pool(**app.config["POSTGRES"])
|
||||
|
||||
app_.sched = JobManager()
|
||||
|
||||
|
|
@ -247,7 +257,7 @@ async def init_app_db(app_):
|
|||
def init_app_managers(app_, *, voice=True):
|
||||
"""Initialize singleton classes."""
|
||||
app_.loop = asyncio.get_event_loop()
|
||||
app_.ratelimiter = RatelimitManager(app_.config.get('_testing'))
|
||||
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
|
||||
app_.state_manager = StateManager()
|
||||
|
||||
app_.storage = Storage(app_)
|
||||
|
|
@ -274,15 +284,12 @@ async def api_index(app_):
|
|||
to_find = {}
|
||||
found = []
|
||||
|
||||
with open('discord_endpoints.txt') as fd:
|
||||
with open("discord_endpoints.txt") as fd:
|
||||
for line in fd.readlines():
|
||||
components = line.split(' ')
|
||||
components = list(filter(
|
||||
bool,
|
||||
components
|
||||
))
|
||||
components = line.split(" ")
|
||||
components = list(filter(bool, components))
|
||||
name, method, path = components
|
||||
path = f'/api/v6{path.strip()}'
|
||||
path = f"/api/v6{path.strip()}"
|
||||
method = method.strip()
|
||||
to_find[(path, method)] = name
|
||||
|
||||
|
|
@ -290,17 +297,17 @@ async def api_index(app_):
|
|||
path = rule.rule
|
||||
|
||||
# convert the path to the discord_endpoints file's style
|
||||
path = path.replace('_', '.')
|
||||
path = path.replace('<', '{')
|
||||
path = path.replace('>', '}')
|
||||
path = path.replace('int:', '')
|
||||
path = path.replace("_", ".")
|
||||
path = path.replace("<", "{")
|
||||
path = path.replace(">", "}")
|
||||
path = path.replace("int:", "")
|
||||
|
||||
# change our parameters into user.id
|
||||
path = path.replace('member.id', 'user.id')
|
||||
path = path.replace('banned.id', 'user.id')
|
||||
path = path.replace('target.id', 'user.id')
|
||||
path = path.replace('other.id', 'user.id')
|
||||
path = path.replace('peer.id', 'user.id')
|
||||
path = path.replace("member.id", "user.id")
|
||||
path = path.replace("banned.id", "user.id")
|
||||
path = path.replace("target.id", "user.id")
|
||||
path = path.replace("other.id", "user.id")
|
||||
path = path.replace("peer.id", "user.id")
|
||||
|
||||
methods = rule.methods
|
||||
|
||||
|
|
@ -317,10 +324,15 @@ async def api_index(app_):
|
|||
percentage = (len(found) / len(api)) * 100
|
||||
percentage = round(percentage, 2)
|
||||
|
||||
log.debug('API compliance: {} out of {} ({} missing), {}% compliant',
|
||||
len(found), len(api), len(missing), percentage)
|
||||
log.debug(
|
||||
"API compliance: {} out of {} ({} missing), {}% compliant",
|
||||
len(found),
|
||||
len(api),
|
||||
len(missing),
|
||||
percentage,
|
||||
)
|
||||
|
||||
log.debug('missing: {}', missing)
|
||||
log.debug("missing: {}", missing)
|
||||
|
||||
|
||||
async def post_app_start(app_):
|
||||
|
|
@ -332,7 +344,7 @@ async def post_app_start(app_):
|
|||
|
||||
def start_websocket(host, port, ws_handler) -> asyncio.Future:
|
||||
"""Start a websocket. Returns the websocket future"""
|
||||
log.info(f'starting websocket at {host} {port}')
|
||||
log.info(f"starting websocket at {host} {port}")
|
||||
|
||||
async def _wrapper(ws, url):
|
||||
# We wrap the main websocket_handler
|
||||
|
|
@ -348,7 +360,7 @@ async def app_before_serving():
|
|||
|
||||
Also sets up the websocket handlers.
|
||||
"""
|
||||
log.info('opening db')
|
||||
log.info("opening db")
|
||||
await init_app_db(app)
|
||||
|
||||
app.session = ClientSession()
|
||||
|
|
@ -359,8 +371,7 @@ async def app_before_serving():
|
|||
# start gateway websocket
|
||||
# voice websocket is handled by the voice server
|
||||
ws_fut = start_websocket(
|
||||
app.config['WS_HOST'], app.config['WS_PORT'],
|
||||
websocket_handler
|
||||
app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler
|
||||
)
|
||||
|
||||
await ws_fut
|
||||
|
|
@ -379,7 +390,7 @@ async def app_after_serving():
|
|||
|
||||
app.sched.close()
|
||||
|
||||
log.info('closing db')
|
||||
log.info("closing db")
|
||||
await app.db.close()
|
||||
|
||||
|
||||
|
|
@ -391,24 +402,23 @@ async def handle_litecord_err(err):
|
|||
ejson = {}
|
||||
|
||||
try:
|
||||
ejson['code'] = err.error_code
|
||||
ejson["code"] = err.error_code
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
log.warning('error: {} {!r}', err.status_code, err.message)
|
||||
log.warning("error: {} {!r}", err.status_code, err.message)
|
||||
|
||||
return jsonify({
|
||||
'error': True,
|
||||
'status': err.status_code,
|
||||
'message': err.message,
|
||||
**ejson
|
||||
}), err.status_code
|
||||
return (
|
||||
jsonify(
|
||||
{"error": True, "status": err.status_code, "message": err.message, **ejson}
|
||||
),
|
||||
err.status_code,
|
||||
)
|
||||
|
||||
|
||||
@app.errorhandler(500)
|
||||
async def handle_500(err):
|
||||
return jsonify({
|
||||
'error': True,
|
||||
'message': repr(err),
|
||||
'internal_server_error': True,
|
||||
}), 500
|
||||
return (
|
||||
jsonify({"error": True, "message": repr(err), "internal_server_error": True}),
|
||||
500,
|
||||
)
|
||||
|
|
|
|||
12
setup.py
12
setup.py
|
|
@ -20,10 +20,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name='litecord',
|
||||
version='0.0.1',
|
||||
description='Implementation of the Discord API',
|
||||
url='https://litecord.top',
|
||||
author='Luna Mendes',
|
||||
python_requires='>=3.7'
|
||||
name="litecord",
|
||||
version="0.0.1",
|
||||
description="Implementation of the Discord API",
|
||||
url="https://litecord.top",
|
||||
author="Luna Mendes",
|
||||
python_requires=">=3.7",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,13 +19,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import secrets
|
||||
|
||||
|
||||
def email() -> str:
|
||||
return f'{secrets.token_hex(5)}@{secrets.token_hex(5)}.com'
|
||||
return f"{secrets.token_hex(5)}@{secrets.token_hex(5)}.com"
|
||||
|
||||
|
||||
class TestClient:
|
||||
"""Test client that wraps pytest-sanic's TestClient and a test
|
||||
user and adds authorization headers to test requests."""
|
||||
|
||||
def __init__(self, test_cli, test_user):
|
||||
self.cli = test_cli
|
||||
self.app = test_cli.app
|
||||
|
|
@ -37,31 +39,31 @@ class TestClient:
|
|||
def _inject_auth(self, kwargs: dict) -> list:
|
||||
"""Inject the test user's API key into the test request before
|
||||
passing the request on to the underlying TestClient."""
|
||||
headers = kwargs.get('headers', {})
|
||||
headers['authorization'] = self.user['token']
|
||||
headers = kwargs.get("headers", {})
|
||||
headers["authorization"] = self.user["token"]
|
||||
return headers
|
||||
|
||||
async def get(self, *args, **kwargs):
|
||||
"""Send a GET request."""
|
||||
kwargs['headers'] = self._inject_auth(kwargs)
|
||||
kwargs["headers"] = self._inject_auth(kwargs)
|
||||
return await self.cli.get(*args, **kwargs)
|
||||
|
||||
async def post(self, *args, **kwargs):
|
||||
"""Send a POST request."""
|
||||
kwargs['headers'] = self._inject_auth(kwargs)
|
||||
kwargs["headers"] = self._inject_auth(kwargs)
|
||||
return await self.cli.post(*args, **kwargs)
|
||||
|
||||
async def put(self, *args, **kwargs):
|
||||
"""Send a POST request."""
|
||||
kwargs['headers'] = self._inject_auth(kwargs)
|
||||
kwargs["headers"] = self._inject_auth(kwargs)
|
||||
return await self.cli.put(*args, **kwargs)
|
||||
|
||||
async def patch(self, *args, **kwargs):
|
||||
"""Send a PATCH request."""
|
||||
kwargs['headers'] = self._inject_auth(kwargs)
|
||||
kwargs["headers"] = self._inject_auth(kwargs)
|
||||
return await self.cli.patch(*args, **kwargs)
|
||||
|
||||
async def delete(self, *args, **kwargs):
|
||||
"""Send a DELETE request."""
|
||||
kwargs['headers'] = self._inject_auth(kwargs)
|
||||
kwargs["headers"] = self._inject_auth(kwargs)
|
||||
return await self.cli.delete(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -36,22 +36,22 @@ from litecord.blueprints.auth import make_token
|
|||
from litecord.blueprints.users import delete_user
|
||||
|
||||
|
||||
@pytest.fixture(name='app')
|
||||
@pytest.fixture(name="app")
|
||||
def _test_app(unused_tcp_port, event_loop):
|
||||
set_blueprints(main_app)
|
||||
main_app.config['_testing'] = True
|
||||
main_app.config["_testing"] = True
|
||||
|
||||
# reassign an unused tcp port for websockets
|
||||
# since the config might give a used one.
|
||||
ws_port = unused_tcp_port
|
||||
|
||||
main_app.config['IS_SSL'] = False
|
||||
main_app.config['WS_PORT'] = ws_port
|
||||
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
|
||||
main_app.config["IS_SSL"] = False
|
||||
main_app.config["WS_PORT"] = ws_port
|
||||
main_app.config["WEBSOCKET_URL"] = f"localhost:{ws_port}"
|
||||
|
||||
# testing user creations requires hardcoding this to true
|
||||
# on testing
|
||||
main_app.config['REGISTRATIONS'] = True
|
||||
main_app.config["REGISTRATIONS"] = True
|
||||
|
||||
# make sure we're calling the before_serving hooks
|
||||
event_loop.run_until_complete(main_app.startup())
|
||||
|
|
@ -63,11 +63,12 @@ def _test_app(unused_tcp_port, event_loop):
|
|||
event_loop.run_until_complete(main_app.shutdown())
|
||||
|
||||
|
||||
@pytest.fixture(name='test_cli')
|
||||
@pytest.fixture(name="test_cli")
|
||||
def _test_cli(app):
|
||||
"""Give a test client."""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
# code shamelessly stolen from my elixire mr
|
||||
# https://gitlab.com/elixire/elixire/merge_requests/52
|
||||
async def _user_fixture_setup(app):
|
||||
|
|
@ -76,21 +77,26 @@ async def _user_fixture_setup(app):
|
|||
user_email = email()
|
||||
|
||||
user_id, pwd_hash = await create_user(
|
||||
username, user_email, password, app.db, app.loop)
|
||||
username, user_email, password, app.db, app.loop
|
||||
)
|
||||
|
||||
# generate a token for api access
|
||||
user_token = make_token(user_id, pwd_hash)
|
||||
|
||||
return {'id': user_id, 'token': user_token,
|
||||
'email': user_email, 'username': username,
|
||||
'password': password}
|
||||
return {
|
||||
"id": user_id,
|
||||
"token": user_token,
|
||||
"email": user_email,
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
|
||||
|
||||
async def _user_fixture_teardown(app, udata: dict):
|
||||
await delete_user(udata['id'], app_=app)
|
||||
await delete_user(udata["id"], app_=app)
|
||||
|
||||
|
||||
@pytest.fixture(name='test_user')
|
||||
@pytest.fixture(name="test_user")
|
||||
async def test_user_fixture(app):
|
||||
"""Yield a randomly generated test user."""
|
||||
udata = await _user_fixture_setup(app)
|
||||
|
|
@ -113,18 +119,25 @@ async def test_cli_staff(test_cli):
|
|||
# same test_cli_user, which isn't acceptable.
|
||||
app = test_cli.app
|
||||
test_user = await _user_fixture_setup(app)
|
||||
user_id = test_user['id']
|
||||
user_id = test_user["id"]
|
||||
|
||||
# copied from manage.cmd.users.set_user_staff.
|
||||
old_flags = await app.db.fetchval("""
|
||||
old_flags = await app.db.fetchval(
|
||||
"""
|
||||
SELECT flags FROM users WHERE id = $1
|
||||
""", user_id)
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
new_flags = old_flags | UserFlags.staff
|
||||
|
||||
await app.db.execute("""
|
||||
await app.db.execute(
|
||||
"""
|
||||
UPDATE users SET flags = $1 WHERE id = $2
|
||||
""", new_flags, user_id)
|
||||
""",
|
||||
new_flags,
|
||||
user_id,
|
||||
)
|
||||
|
||||
yield TestClient(test_cli, test_user)
|
||||
await _user_fixture_teardown(test_cli.app, test_user)
|
||||
|
|
|
|||
|
|
@ -24,24 +24,24 @@ import pytest
|
|||
from litecord.blueprints.guilds import delete_guild
|
||||
from litecord.errors import GuildNotFound
|
||||
|
||||
|
||||
async def _create_guild(test_cli_staff):
|
||||
genned_name = secrets.token_hex(6)
|
||||
|
||||
resp = await test_cli_staff.post('/api/v6/guilds', json={
|
||||
'name': genned_name,
|
||||
'region': None
|
||||
})
|
||||
resp = await test_cli_staff.post(
|
||||
"/api/v6/guilds", json={"name": genned_name, "region": None}
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
rjson = await resp.json
|
||||
assert isinstance(rjson, dict)
|
||||
assert rjson['name'] == genned_name
|
||||
assert rjson["name"] == genned_name
|
||||
|
||||
return rjson
|
||||
|
||||
|
||||
async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
||||
resp = await test_cli_staff.get(f'/api/v6/admin/guilds/{guild_id}')
|
||||
resp = await test_cli_staff.get(f"/api/v6/admin/guilds/{guild_id}")
|
||||
|
||||
if ret_early:
|
||||
return resp
|
||||
|
|
@ -49,7 +49,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
|||
assert resp.status_code == 200
|
||||
rjson = await resp.json
|
||||
assert isinstance(rjson, dict)
|
||||
assert rjson['id'] == guild_id
|
||||
assert rjson["id"] == guild_id
|
||||
|
||||
return rjson
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
|||
async def test_guild_fetch(test_cli_staff):
|
||||
"""Test the creation and fetching of a guild via the Admin API."""
|
||||
rjson = await _create_guild(test_cli_staff)
|
||||
guild_id = rjson['id']
|
||||
guild_id = rjson["id"]
|
||||
|
||||
try:
|
||||
await _fetch_guild(test_cli_staff, guild_id)
|
||||
|
|
@ -70,8 +70,8 @@ async def test_guild_fetch(test_cli_staff):
|
|||
async def test_guild_update(test_cli_staff):
|
||||
"""Test the update of a guild via the Admin API."""
|
||||
rjson = await _create_guild(test_cli_staff)
|
||||
guild_id = rjson['id']
|
||||
assert not rjson['unavailable']
|
||||
guild_id = rjson["id"]
|
||||
assert not rjson["unavailable"]
|
||||
|
||||
try:
|
||||
# I believe setting up an entire gateway client registered to the guild
|
||||
|
|
@ -79,19 +79,17 @@ async def test_guild_update(test_cli_staff):
|
|||
# testing them. Yes, I know its a bad idea, but if someone has an easier
|
||||
# way to write that, do send an MR.
|
||||
resp = await test_cli_staff.patch(
|
||||
f'/api/v6/admin/guilds/{guild_id}',
|
||||
json={
|
||||
'unavailable': True
|
||||
})
|
||||
f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True}
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
rjson = await resp.json
|
||||
assert isinstance(rjson, dict)
|
||||
assert rjson['id'] == guild_id
|
||||
assert rjson['unavailable']
|
||||
assert rjson["id"] == guild_id
|
||||
assert rjson["unavailable"]
|
||||
|
||||
rjson = await _fetch_guild(test_cli_staff, guild_id)
|
||||
assert rjson['unavailable']
|
||||
assert rjson["unavailable"]
|
||||
finally:
|
||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
||||
|
||||
|
|
@ -100,20 +98,19 @@ async def test_guild_update(test_cli_staff):
|
|||
async def test_guild_delete(test_cli_staff):
|
||||
"""Test the update of a guild via the Admin API."""
|
||||
rjson = await _create_guild(test_cli_staff)
|
||||
guild_id = rjson['id']
|
||||
guild_id = rjson["id"]
|
||||
|
||||
try:
|
||||
resp = await test_cli_staff.delete(f'/api/v6/admin/guilds/{guild_id}')
|
||||
resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}")
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
resp = await _fetch_guild(
|
||||
test_cli_staff, guild_id, ret_early=True)
|
||||
resp = await _fetch_guild(test_cli_staff, guild_id, ret_early=True)
|
||||
|
||||
assert resp.status_code == 404
|
||||
rjson = await resp.json
|
||||
assert isinstance(rjson, dict)
|
||||
assert rjson['error']
|
||||
assert rjson['code'] == GuildNotFound.error_code
|
||||
assert rjson["error"]
|
||||
assert rjson["code"] == GuildNotFound.error_code
|
||||
finally:
|
||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import pytest
|
|||
|
||||
|
||||
async def _get_invs(test_cli):
|
||||
resp = await test_cli.get('/api/v6/admin/instance/invites')
|
||||
resp = await test_cli.get("/api/v6/admin/instance/invites")
|
||||
|
||||
assert resp.status_code == 200
|
||||
rjson = await resp.json
|
||||
|
|
@ -39,7 +39,7 @@ async def test_get_invites(test_cli_staff):
|
|||
async def test_inv_delete_invalid(test_cli_staff):
|
||||
"""Test errors happen when trying to delete a
|
||||
non-existing instance invite."""
|
||||
resp = await test_cli_staff.delete('/api/v6/admin/instance/invites/aaaaaa')
|
||||
resp = await test_cli_staff.delete("/api/v6/admin/instance/invites/aaaaaa")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
|
@ -48,21 +48,20 @@ async def test_inv_delete_invalid(test_cli_staff):
|
|||
async def test_create_invite(test_cli_staff):
|
||||
"""Test the creation of an instance invite, then listing it,
|
||||
then deleting it."""
|
||||
resp = await test_cli_staff.put('/api/v6/admin/instance/invites', json={
|
||||
'max_uses': 1
|
||||
})
|
||||
resp = await test_cli_staff.put(
|
||||
"/api/v6/admin/instance/invites", json={"max_uses": 1}
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
rjson = await resp.json
|
||||
assert isinstance(rjson, dict)
|
||||
code = rjson['code']
|
||||
code = rjson["code"]
|
||||
|
||||
# assert that the invite is in the list
|
||||
invites = await _get_invs(test_cli_staff)
|
||||
assert any(inv['code'] == code for inv in invites)
|
||||
assert any(inv["code"] == code for inv in invites)
|
||||
|
||||
# delete it, and assert it worked
|
||||
resp = await test_cli_staff.delete(
|
||||
f'/api/v6/admin/instance/invites/{code}')
|
||||
resp = await test_cli_staff.delete(f"/api/v6/admin/instance/invites/{code}")
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue