various fixes to dms

- blueprints.checks: fix party fetching
 - blueprints.users: fix try_dm_state
 - blueprints.users: fix create_dm
 - blueprints.users: fix create_group_dm being used for 1-on-1 dm
 - gateway.websocket: add support for pure zlib
 - schemas: fix CREATE_GROUP_DM
 - storage: fix _filter_recipients, get_channel and get_dms
This commit is contained in:
Luna Mendes 2018-10-10 17:53:31 -03:00
parent 737129bd20
commit e210c20d0f
5 changed files with 26 additions and 12 deletions

View File

@ -37,12 +37,14 @@ async def channel_check(user_id, channel_id):
return guild_id return guild_id
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
parties = await app.db.fetchval(""" parties = await app.db.fetchrow("""
SELECT party1_id, party2_id SELECT party1_id, party2_id
FROM dm_channels FROM dm_channels
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
""", channel_id, user_id) """, channel_id, user_id)
parties = [parties['party1_id'], parties['party2_id']]
# get the id of the other party # get the id of the other party
parties.remove(user_id) parties.remove(user_id)
return parties[0] return parties[0]

View File

@ -5,6 +5,7 @@ from ..auth import token_check
from ..snowflake import get_snowflake from ..snowflake import get_snowflake
from ..errors import Forbidden, BadRequest from ..errors import Forbidden, BadRequest
from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
from ..enums import ChannelType
from .guilds import guild_check from .guilds import guild_check
@ -133,7 +134,7 @@ async def try_dm_state(user_id, dm_id):
for the given DM.""" for the given DM."""
try: try:
await app.db.execute(""" await app.db.execute("""
INSERT INTO dm_channel_state (id, dm_id) INSERT INTO dm_channel_state (user_id, dm_id)
VALUES ($1, $2) VALUES ($1, $2)
""", user_id, dm_id) """, user_id, dm_id)
except UniqueViolationError: except UniqueViolationError:
@ -145,6 +146,11 @@ async def create_dm(user_id, recipient_id):
dm_id = get_snowflake() dm_id = get_snowflake()
try: try:
await app.db.execute("""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", dm_id, ChannelType.DM.value)
await app.db.execute(""" await app.db.execute("""
INSERT INTO dm_channels (id, party1_id, party2_id) INSERT INTO dm_channels (id, party1_id, party2_id)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
@ -175,7 +181,7 @@ async def start_dm():
return await create_dm(user_id, recipient_id) return await create_dm(user_id, recipient_id)
@bp.route('/<int:user_id>/channels', methods=['POST']) @bp.route('/<int:p_user_id>/channels', methods=['POST'])
async def create_group_dm(p_user_id: int): async def create_group_dm(p_user_id: int):
"""Create a DM or a Group DM with user(s).""" """Create a DM or a Group DM with user(s)."""
user_id = await token_check() user_id = await token_check()
@ -184,12 +190,12 @@ async def create_group_dm(p_user_id: int):
j = validate(await request.get_json(), CREATE_GROUP_DM) j = validate(await request.get_json(), CREATE_GROUP_DM)
recipients = j['recipients'] recipients = j['recipients']
if list(recipients) == 1: if len(recipients) == 1:
# its a group dm with 1 user... a dm! # its a group dm with 1 user... a dm!
return await create_dm(user_id, int(recipients[0])) return await create_dm(user_id, int(recipients[0]))
# TODO: group dms # TODO: group dms
return '', 500 return 'group dms not implemented', 500
@bp.route('/@me/notes/<int:target_id>', methods=['PUT']) @bp.route('/@me/notes/<int:target_id>', methods=['PUT'])

View File

@ -97,13 +97,15 @@ class GatewayWebsocket:
if not isinstance(encoded, bytes): if not isinstance(encoded, bytes):
encoded = encoded.encode() encoded = encoded.encode()
# handle zlib-stream, pure zlib or plain
if self.wsp.compress == 'zlib-stream': if self.wsp.compress == 'zlib-stream':
data1 = self.wsp.zctx.compress(encoded) data1 = self.wsp.zctx.compress(encoded)
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
await self.ws.send(data1 + data2) await self.ws.send(data1 + data2)
elif self.state and self.state.compress:
await self.ws.send(zlib.compress(encoded))
else: else:
# TODO: pure zlib
await self.ws.send(encoded.decode()) await self.ws.send(encoded.decode())
async def _hb_wait(self, interval: int): async def _hb_wait(self, interval: int):

View File

@ -288,9 +288,11 @@ CREATE_DM = {
} }
CREATE_GROUP_DM = { CREATE_GROUP_DM = {
'type': 'list', 'recipients': {
'required': True, 'type': 'list',
'schema': {'type': 'snowflake'} 'required': True,
'schema': {'type': 'snowflake'}
},
} }
SPECIFIC_FRIEND = { SPECIFIC_FRIEND = {

View File

@ -41,9 +41,9 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
the one that is reundant (ourselves).""" the one that is reundant (ourselves)."""
user_id = str(user_id) user_id = str(user_id)
return filter( return list(filter(
lambda recipient: recipient['id'] != user_id, lambda recipient: recipient['id'] != user_id,
recipients) recipients))
class Storage: class Storage:
@ -336,7 +336,7 @@ class Storage:
return res return res
elif ctype == ChannelType.DM: elif ctype == ChannelType.DM:
dm_row = await self.db.fetchrow(""" dm_row = await self.db.fetchrow("""
SELECT party1_id, party2_id SELECT id, party1_id, party2_id
FROM dm_channels FROM dm_channels
WHERE id = $1 WHERE id = $1
""", channel_id) """, channel_id)
@ -783,6 +783,8 @@ class Storage:
WHERE user_id = $1 WHERE user_id = $1
""", user_id) """, user_id)
dm_ids = [r['dm_id'] for r in dm_ids]
res = [] res = []
for dm_id in dm_ids: for dm_id in dm_ids: