diff --git a/litecord/blueprints/guild/emoji.py b/litecord/blueprints/guild/emoji.py index 925adac..cc48dfe 100644 --- a/litecord/blueprints/guild/emoji.py +++ b/litecord/blueprints/guild/emoji.py @@ -24,6 +24,8 @@ from litecord.blueprints.checks import guild_check, guild_perm_check from litecord.schemas import validate, NEW_EMOJI, PATCH_EMOJI from litecord.snowflake import get_snowflake from litecord.types import KILOBYTES +from litecord.images import parse_data_uri +from litecord.errors import BadRequest bp = Blueprint('guild.emoji', __name__) @@ -54,6 +56,24 @@ async def _get_guild_emoji_one(guild_id, emoji_id): ) +async def _guild_emoji_size_check(guild_id: int, mime: str): + limit = 50 + if await app.storage.has_feature(guild_id, 'MORE_EMOJI'): + limit = 200 + + # NOTE: I'm assuming you can have 200 animated emojis. + select_animated = mime == 'image/gif' + + total_emoji = await app.db.fetchval(""" + SELECT COUNT(*) FROM guild_emoji + WHERE guild_id = $1 AND animated = $2 + """, guild_id, select_animated) + + if total_emoji >= limit: + # TODO: really return a BadRequest? needs more looking. + raise BadRequest(f'too many emoji ({limit})') + + @bp.route('//emojis', methods=['POST']) async def _put_emoji(guild_id): user_id = await token_check() @@ -63,6 +83,11 @@ async def _put_emoji(guild_id): j = validate(await request.get_json(), NEW_EMOJI) + # we have to parse it before passing on so that we know which + # size to check. + mime, _ = parse_data_uri(j['image']) + await _guild_emoji_size_check(guild_id, mime) + emoji_id = get_snowflake() icon = await app.icons.put( @@ -75,6 +100,8 @@ async def _put_emoji(guild_id): if not icon: return '', 400 + # TODO: better way to detect animated emoji rather than just gifs, + # maybe a list perhaps? await app.db.execute( """ INSERT INTO guild_emoji @@ -85,7 +112,8 @@ async def _put_emoji(guild_id): emoji_id, guild_id, user_id, j['name'], icon.icon_hash, - icon.mime == 'image/gif') + icon.mime == 'image/gif' + ) await _dispatch_emojis(guild_id) diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index ffcfc77..8d2b669 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -19,7 +19,6 @@ along with this program. If not, see . from typing import Optional -from asyncpg import UniqueViolationError from quart import Blueprint, request, current_app as app, jsonify from litecord.blueprints.guild.channels import create_guild_channel diff --git a/litecord/storage.py b/litecord/storage.py index 2f4279b..04d9473 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -1076,3 +1076,12 @@ class Storage: """) return list(map(dict, rows)) + + async def has_feature(self, guild_id: int, feature: str) -> bool: + """Return if a certain guild has a certain feature.""" + features = await self.db.fetchval(""" + SELECT features FROM guilds + WHERE id = $1 + """, guild_id) + + return feature.upper() in features