diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index aa92b7a..ebd0700 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -37,12 +37,14 @@ async def channel_check(user_id, channel_id): return guild_id if ctype == ChannelType.DM: - parties = await app.db.fetchval(""" + parties = await app.db.fetchrow(""" SELECT party1_id, party2_id FROM dm_channels WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) """, channel_id, user_id) + parties = [parties['party1_id'], parties['party2_id']] + # get the id of the other party parties.remove(user_id) return parties[0] diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 117b9fd..74e72cb 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -5,6 +5,7 @@ from ..auth import token_check from ..snowflake import get_snowflake from ..errors import Forbidden, BadRequest from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM +from ..enums import ChannelType from .guilds import guild_check @@ -133,7 +134,7 @@ async def try_dm_state(user_id, dm_id): for the given DM.""" try: await app.db.execute(""" - INSERT INTO dm_channel_state (id, dm_id) + INSERT INTO dm_channel_state (user_id, dm_id) VALUES ($1, $2) """, user_id, dm_id) except UniqueViolationError: @@ -145,6 +146,11 @@ async def create_dm(user_id, recipient_id): dm_id = get_snowflake() try: + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, dm_id, ChannelType.DM.value) + await app.db.execute(""" INSERT INTO dm_channels (id, party1_id, party2_id) VALUES ($1, $2, $3) @@ -175,7 +181,7 @@ async def start_dm(): return await create_dm(user_id, recipient_id) -@bp.route('//channels', methods=['POST']) +@bp.route('//channels', methods=['POST']) async def create_group_dm(p_user_id: int): """Create a DM or a Group DM with user(s).""" user_id = await token_check() @@ -184,12 +190,12 @@ async def create_group_dm(p_user_id: int): j = validate(await request.get_json(), CREATE_GROUP_DM) recipients = j['recipients'] - if list(recipients) == 1: + if len(recipients) == 1: # its a group dm with 1 user... a dm! return await create_dm(user_id, int(recipients[0])) # TODO: group dms - return '', 500 + return 'group dms not implemented', 500 @bp.route('/@me/notes/', methods=['PUT']) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index fed011d..7e9eb77 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -97,13 +97,15 @@ class GatewayWebsocket: if not isinstance(encoded, bytes): encoded = encoded.encode() + # handle zlib-stream, pure zlib or plain if self.wsp.compress == 'zlib-stream': data1 = self.wsp.zctx.compress(encoded) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) await self.ws.send(data1 + data2) + elif self.state and self.state.compress: + await self.ws.send(zlib.compress(encoded)) else: - # TODO: pure zlib await self.ws.send(encoded.decode()) async def _hb_wait(self, interval: int): diff --git a/litecord/schemas.py b/litecord/schemas.py index b7476e3..a6f30dd 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -288,9 +288,11 @@ CREATE_DM = { } CREATE_GROUP_DM = { - 'type': 'list', - 'required': True, - 'schema': {'type': 'snowflake'} + 'recipients': { + 'type': 'list', + 'required': True, + 'schema': {'type': 'snowflake'} + }, } SPECIFIC_FRIEND = { diff --git a/litecord/storage.py b/litecord/storage.py index 0a1a83c..71839dd 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -41,9 +41,9 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): the one that is reundant (ourselves).""" user_id = str(user_id) - return filter( + return list(filter( lambda recipient: recipient['id'] != user_id, - recipients) + recipients)) class Storage: @@ -336,7 +336,7 @@ class Storage: return res elif ctype == ChannelType.DM: dm_row = await self.db.fetchrow(""" - SELECT party1_id, party2_id + SELECT id, party1_id, party2_id FROM dm_channels WHERE id = $1 """, channel_id) @@ -783,6 +783,8 @@ class Storage: WHERE user_id = $1 """, user_id) + dm_ids = [r['dm_id'] for r in dm_ids] + res = [] for dm_id in dm_ids: