channel.messages: fix insertion into attachments table

- channel.messages: add file write
 - images: make get_ext and get_mime public functions
 - storage: receive app instance instead of db
This commit is contained in:
Luna 2018-12-08 23:38:52 -03:00
parent 57f6e530f5
commit 6c992588e9
4 changed files with 49 additions and 23 deletions

View File

@ -35,6 +35,7 @@ from litecord.utils import pg_set_json
from litecord.embed.sanitizer import fill_embed, proxify, fetch_metadata from litecord.embed.sanitizer import fill_embed, proxify, fetch_metadata
from litecord.blueprints.channel.dm_checks import dm_pre_check from litecord.blueprints.channel.dm_checks import dm_pre_check
from litecord.images import get_ext
log = Logger(__name__) log = Logger(__name__)
@ -347,7 +348,8 @@ def _check_content(payload: dict, files: list):
raise BadRequest('No content has been provided.') raise BadRequest('No content has been provided.')
async def _add_attachment(message_id: int, attachment_file) -> int: async def _add_attachment(message_id: int, channel_id: int,
attachment_file) -> int:
"""Add an attachment to a message. """Add an attachment to a message.
Parameters Parameters
@ -359,9 +361,11 @@ async def _add_attachment(message_id: int, attachment_file) -> int:
""" """
attachment_id = get_snowflake() attachment_id = get_snowflake()
filename = attachment_file.filename
# understand file info # understand file info
is_image = attachment_file.mimetype.startswith('image/') mime = attachment_file.mimetype
is_image = mime.startswith('image/')
img_width, img_height = None, None img_width, img_height = None, None
@ -381,20 +385,23 @@ async def _add_attachment(message_id: int, attachment_file) -> int:
await app.db.execute( await app.db.execute(
""" """
INSERT INTO attachments INSERT INTO attachments
(id, filename, filesize, image, height, width) (id, channel_id, message_id,
filename, filesize,
image, height, width)
VALUES VALUES
($1, $2, $3, $4, $5, $6) ($1, $2, $3, $4, $5, $6, $7, $8)
""", """,
attachment_id, attachment_file.filename, file_size, attachment_id, channel_id, message_id,
filename, file_size,
is_image, img_width, img_height) is_image, img_width, img_height)
# TODO: save a file ext = filename.split('.')[-1]
# add the newly created attachment to the message with open(f'attachments/{attachment_id}.{ext}') as attach_file:
await app.db.execute(""" attach_file.write(attachment_file.stream.read())
INSERT INTO message_attachments (message_id, attachment_id)
VALUES ($1, $2) log.debug('written {} bytes for attachment id {}',
""", message_id, attachment_id) file_size, attachment_id)
return attachment_id return attachment_id
@ -450,7 +457,7 @@ async def _create_message(channel_id):
# for each file given, we add it as an attachment # for each file given, we add it as an attachment
for pre_attachment in files: for pre_attachment in files:
await _add_attachment(message_id, pre_attachment) await _add_attachment(message_id, channel_id, pre_attachment)
payload = await app.storage.get_message(message_id, user_id) payload = await app.storage.get_message(message_id, user_id)

View File

@ -49,7 +49,7 @@ MIMES = {
} }
def _get_ext(mime: str) -> str: def get_ext(mime: str) -> str:
if mime in EXTENSIONS: if mime in EXTENSIONS:
return EXTENSIONS[mime] return EXTENSIONS[mime]
@ -57,7 +57,7 @@ def _get_ext(mime: str) -> str:
return extensions[0].strip('.') return extensions[0].strip('.')
def _get_mime(ext: str): def get_mime(ext: str):
if ext in MIMES: if ext in MIMES:
return MIMES[ext] return MIMES[ext]
@ -74,7 +74,7 @@ class Icon:
@property @property
def as_path(self) -> str: def as_path(self) -> str:
"""Return a filesystem path for the given icon.""" """Return a filesystem path for the given icon."""
ext = _get_ext(self.mime) ext = get_ext(self.mime)
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}') return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
@property @property
@ -83,7 +83,7 @@ class Icon:
@property @property
def extension(self) -> str: def extension(self) -> str:
return _get_ext(self.mime) return get_ext(self.mime)
class ImageError(Exception): class ImageError(Exception):
@ -201,7 +201,7 @@ class IconManager:
self.storage = app.storage self.storage = app.storage
async def _convert_ext(self, icon: Icon, target: str): async def _convert_ext(self, icon: Icon, target: str):
target_mime = _get_mime(target) target_mime = get_mime(target)
log.info('converting from {} to {}', icon.mime, target_mime) log.info('converting from {} to {}', icon.mime, target_mime)
target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}' target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}'
@ -328,7 +328,7 @@ class IconManager:
data_fd = BytesIO(raw_data) data_fd = BytesIO(raw_data)
# get an extension for the given data uri # get an extension for the given data uri
extension = _get_ext(mime) extension = get_ext(mime)
if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']: if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']:
return _invalid(kwargs) return _invalid(kwargs)

View File

@ -31,7 +31,6 @@ from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
from litecord.types import timestamp_ from litecord.types import timestamp_
from litecord.utils import pg_set_json from litecord.utils import pg_set_json
from litecord.embed.sanitizer import proxify
log = Logger(__name__) log = Logger(__name__)
@ -65,8 +64,9 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
class Storage: class Storage:
"""Class for common SQL statements.""" """Class for common SQL statements."""
def __init__(self, db): def __init__(self, app):
self.db = db self.app = app
self.db = app.db
self.presence = None self.presence = None
async def fetchrow_with_json(self, query: str, *args): async def fetchrow_with_json(self, query: str, *args):
@ -649,16 +649,35 @@ class Storage:
for attachment_id in attachment_ids: for attachment_id in attachment_ids:
row = await self.db.fetchrow(""" row = await self.db.fetchrow("""
SELECT id::text, filename, filesize, image, height, width SELECT id::text, message_id, channel_id, mime
filename, filesize, image, height, width
FROM attachments FROM attachments
WHERE id = $1 WHERE id = $1
""", attachment_id) """, attachment_id)
drow = dict(row) drow = dict(row)
drow.pop('message_id')
drow.pop('channel_id')
drow.pop('mime')
drow['size'] = drow['filesize'] drow['size'] = drow['filesize']
drow.pop('size') drow.pop('size')
# construct attachment url
proto = 'https' if self.app.config['IS_SSL'] else 'http'
main_url = self.app.config['MAIN_URL']
attachment_ext = get_ext(row['mime'])
drow['url'] = (f'{proto}://{main_url}/attachments/'
f'{row["channel_id"]}/{row["message_id"]}/'
f'{row["filename"]}.{attachment_ext}')
# NOTE: since the url comes from the instance itself
# i think proxy_url=url is valid.
drow['proxy_url'] = drow['url']
# TODO: url, proxy_url # TODO: url, proxy_url
res.append(drow) res.append(drow)

2
run.py
View File

@ -220,7 +220,7 @@ def init_app_managers(app):
app.ratelimiter = RatelimitManager(app.config.get('_testing')) app.ratelimiter = RatelimitManager(app.config.get('_testing'))
app.state_manager = StateManager() app.state_manager = StateManager()
app.storage = Storage(app.db) app.storage = Storage(app)
app.user_storage = UserStorage(app.storage) app.user_storage = UserStorage(app.storage)
app.icons = IconManager(app) app.icons = IconManager(app)