From 6c992588e984cbdf23fe3ce4f87cdac812940130 Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 8 Dec 2018 23:38:52 -0300 Subject: [PATCH] channel.messages: fix insertion into attachments table - channel.messages: add file write - images: make get_ext and get_mime public functions - storage: receive app instance instead of db --- litecord/blueprints/channel/messages.py | 31 +++++++++++++++---------- litecord/images.py | 12 +++++----- litecord/storage.py | 27 +++++++++++++++++---- run.py | 2 +- 4 files changed, 49 insertions(+), 23 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index b3f9eee..acdec46 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -35,6 +35,7 @@ from litecord.utils import pg_set_json from litecord.embed.sanitizer import fill_embed, proxify, fetch_metadata from litecord.blueprints.channel.dm_checks import dm_pre_check +from litecord.images import get_ext log = Logger(__name__) @@ -347,7 +348,8 @@ def _check_content(payload: dict, files: list): raise BadRequest('No content has been provided.') -async def _add_attachment(message_id: int, attachment_file) -> int: +async def _add_attachment(message_id: int, channel_id: int, + attachment_file) -> int: """Add an attachment to a message. Parameters @@ -359,9 +361,11 @@ async def _add_attachment(message_id: int, attachment_file) -> int: """ attachment_id = get_snowflake() + filename = attachment_file.filename # understand file info - is_image = attachment_file.mimetype.startswith('image/') + mime = attachment_file.mimetype + is_image = mime.startswith('image/') img_width, img_height = None, None @@ -381,20 +385,23 @@ async def _add_attachment(message_id: int, attachment_file) -> int: await app.db.execute( """ INSERT INTO attachments - (id, filename, filesize, image, height, width) + (id, channel_id, message_id, + filename, filesize, + image, height, width) VALUES - ($1, $2, $3, $4, $5, $6) + ($1, $2, $3, $4, $5, $6, $7, $8) """, - attachment_id, attachment_file.filename, file_size, + attachment_id, channel_id, message_id, + filename, file_size, is_image, img_width, img_height) - # TODO: save a file + ext = filename.split('.')[-1] - # add the newly created attachment to the message - await app.db.execute(""" - INSERT INTO message_attachments (message_id, attachment_id) - VALUES ($1, $2) - """, message_id, attachment_id) + with open(f'attachments/{attachment_id}.{ext}') as attach_file: + attach_file.write(attachment_file.stream.read()) + + log.debug('written {} bytes for attachment id {}', + file_size, attachment_id) return attachment_id @@ -450,7 +457,7 @@ async def _create_message(channel_id): # for each file given, we add it as an attachment for pre_attachment in files: - await _add_attachment(message_id, pre_attachment) + await _add_attachment(message_id, channel_id, pre_attachment) payload = await app.storage.get_message(message_id, user_id) diff --git a/litecord/images.py b/litecord/images.py index 1cb9eb2..d0cf305 100644 --- a/litecord/images.py +++ b/litecord/images.py @@ -49,7 +49,7 @@ MIMES = { } -def _get_ext(mime: str) -> str: +def get_ext(mime: str) -> str: if mime in EXTENSIONS: return EXTENSIONS[mime] @@ -57,7 +57,7 @@ def _get_ext(mime: str) -> str: return extensions[0].strip('.') -def _get_mime(ext: str): +def get_mime(ext: str): if ext in MIMES: return MIMES[ext] @@ -74,7 +74,7 @@ class Icon: @property def as_path(self) -> str: """Return a filesystem path for the given icon.""" - ext = _get_ext(self.mime) + ext = get_ext(self.mime) return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}') @property @@ -83,7 +83,7 @@ class Icon: @property def extension(self) -> str: - return _get_ext(self.mime) + return get_ext(self.mime) class ImageError(Exception): @@ -201,7 +201,7 @@ class IconManager: self.storage = app.storage async def _convert_ext(self, icon: Icon, target: str): - target_mime = _get_mime(target) + target_mime = get_mime(target) log.info('converting from {} to {}', icon.mime, target_mime) target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}' @@ -328,7 +328,7 @@ class IconManager: data_fd = BytesIO(raw_data) # get an extension for the given data uri - extension = _get_ext(mime) + extension = get_ext(mime) if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']: return _invalid(kwargs) diff --git a/litecord/storage.py b/litecord/storage.py index 04fe411..66629c2 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -31,7 +31,6 @@ from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE from litecord.types import timestamp_ from litecord.utils import pg_set_json -from litecord.embed.sanitizer import proxify log = Logger(__name__) @@ -65,8 +64,9 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): class Storage: """Class for common SQL statements.""" - def __init__(self, db): - self.db = db + def __init__(self, app): + self.app = app + self.db = app.db self.presence = None async def fetchrow_with_json(self, query: str, *args): @@ -649,16 +649,35 @@ class Storage: for attachment_id in attachment_ids: row = await self.db.fetchrow(""" - SELECT id::text, filename, filesize, image, height, width + SELECT id::text, message_id, channel_id, mime + filename, filesize, image, height, width FROM attachments WHERE id = $1 """, attachment_id) drow = dict(row) + drow.pop('message_id') + drow.pop('channel_id') + drow.pop('mime') + drow['size'] = drow['filesize'] drow.pop('size') + # construct attachment url + proto = 'https' if self.app.config['IS_SSL'] else 'http' + main_url = self.app.config['MAIN_URL'] + + attachment_ext = get_ext(row['mime']) + + drow['url'] = (f'{proto}://{main_url}/attachments/' + f'{row["channel_id"]}/{row["message_id"]}/' + f'{row["filename"]}.{attachment_ext}') + + # NOTE: since the url comes from the instance itself + # i think proxy_url=url is valid. + drow['proxy_url'] = drow['url'] + # TODO: url, proxy_url res.append(drow) diff --git a/run.py b/run.py index 03e1a95..adb01f4 100644 --- a/run.py +++ b/run.py @@ -220,7 +220,7 @@ def init_app_managers(app): app.ratelimiter = RatelimitManager(app.config.get('_testing')) app.state_manager = StateManager() - app.storage = Storage(app.db) + app.storage = Storage(app) app.user_storage = UserStorage(app.storage) app.icons = IconManager(app)