mirror of https://gitlab.com/litecord/litecord.git
various fixes to dms
- blueprints.checks: fix party fetching - blueprints.users: fix try_dm_state - blueprints.users: fix create_dm - blueprints.users: fix create_group_dm being used for 1-on-1 dm - gateway.websocket: add support for pure zlib - schemas: fix CREATE_GROUP_DM - storage: fix _filter_recipients, get_channel and get_dms
This commit is contained in:
parent
737129bd20
commit
e210c20d0f
|
|
@ -37,12 +37,14 @@ async def channel_check(user_id, channel_id):
|
|||
return guild_id
|
||||
|
||||
if ctype == ChannelType.DM:
|
||||
parties = await app.db.fetchval("""
|
||||
parties = await app.db.fetchrow("""
|
||||
SELECT party1_id, party2_id
|
||||
FROM dm_channels
|
||||
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||
""", channel_id, user_id)
|
||||
|
||||
parties = [parties['party1_id'], parties['party2_id']]
|
||||
|
||||
# get the id of the other party
|
||||
parties.remove(user_id)
|
||||
return parties[0]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from ..auth import token_check
|
|||
from ..snowflake import get_snowflake
|
||||
from ..errors import Forbidden, BadRequest
|
||||
from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
|
||||
from ..enums import ChannelType
|
||||
|
||||
from .guilds import guild_check
|
||||
|
||||
|
|
@ -133,7 +134,7 @@ async def try_dm_state(user_id, dm_id):
|
|||
for the given DM."""
|
||||
try:
|
||||
await app.db.execute("""
|
||||
INSERT INTO dm_channel_state (id, dm_id)
|
||||
INSERT INTO dm_channel_state (user_id, dm_id)
|
||||
VALUES ($1, $2)
|
||||
""", user_id, dm_id)
|
||||
except UniqueViolationError:
|
||||
|
|
@ -145,6 +146,11 @@ async def create_dm(user_id, recipient_id):
|
|||
dm_id = get_snowflake()
|
||||
|
||||
try:
|
||||
await app.db.execute("""
|
||||
INSERT INTO channels (id, channel_type)
|
||||
VALUES ($1, $2)
|
||||
""", dm_id, ChannelType.DM.value)
|
||||
|
||||
await app.db.execute("""
|
||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||
VALUES ($1, $2, $3)
|
||||
|
|
@ -175,7 +181,7 @@ async def start_dm():
|
|||
return await create_dm(user_id, recipient_id)
|
||||
|
||||
|
||||
@bp.route('/<int:user_id>/channels', methods=['POST'])
|
||||
@bp.route('/<int:p_user_id>/channels', methods=['POST'])
|
||||
async def create_group_dm(p_user_id: int):
|
||||
"""Create a DM or a Group DM with user(s)."""
|
||||
user_id = await token_check()
|
||||
|
|
@ -184,12 +190,12 @@ async def create_group_dm(p_user_id: int):
|
|||
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
||||
recipients = j['recipients']
|
||||
|
||||
if list(recipients) == 1:
|
||||
if len(recipients) == 1:
|
||||
# its a group dm with 1 user... a dm!
|
||||
return await create_dm(user_id, int(recipients[0]))
|
||||
|
||||
# TODO: group dms
|
||||
return '', 500
|
||||
return 'group dms not implemented', 500
|
||||
|
||||
|
||||
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
||||
|
|
|
|||
|
|
@ -97,13 +97,15 @@ class GatewayWebsocket:
|
|||
if not isinstance(encoded, bytes):
|
||||
encoded = encoded.encode()
|
||||
|
||||
# handle zlib-stream, pure zlib or plain
|
||||
if self.wsp.compress == 'zlib-stream':
|
||||
data1 = self.wsp.zctx.compress(encoded)
|
||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||
|
||||
await self.ws.send(data1 + data2)
|
||||
elif self.state and self.state.compress:
|
||||
await self.ws.send(zlib.compress(encoded))
|
||||
else:
|
||||
# TODO: pure zlib
|
||||
await self.ws.send(encoded.decode())
|
||||
|
||||
async def _hb_wait(self, interval: int):
|
||||
|
|
|
|||
|
|
@ -288,9 +288,11 @@ CREATE_DM = {
|
|||
}
|
||||
|
||||
CREATE_GROUP_DM = {
|
||||
'recipients': {
|
||||
'type': 'list',
|
||||
'required': True,
|
||||
'schema': {'type': 'snowflake'}
|
||||
},
|
||||
}
|
||||
|
||||
SPECIFIC_FRIEND = {
|
||||
|
|
|
|||
|
|
@ -41,9 +41,9 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
|
|||
the one that is reundant (ourselves)."""
|
||||
user_id = str(user_id)
|
||||
|
||||
return filter(
|
||||
return list(filter(
|
||||
lambda recipient: recipient['id'] != user_id,
|
||||
recipients)
|
||||
recipients))
|
||||
|
||||
|
||||
class Storage:
|
||||
|
|
@ -336,7 +336,7 @@ class Storage:
|
|||
return res
|
||||
elif ctype == ChannelType.DM:
|
||||
dm_row = await self.db.fetchrow("""
|
||||
SELECT party1_id, party2_id
|
||||
SELECT id, party1_id, party2_id
|
||||
FROM dm_channels
|
||||
WHERE id = $1
|
||||
""", channel_id)
|
||||
|
|
@ -783,6 +783,8 @@ class Storage:
|
|||
WHERE user_id = $1
|
||||
""", user_id)
|
||||
|
||||
dm_ids = [r['dm_id'] for r in dm_ids]
|
||||
|
||||
res = []
|
||||
|
||||
for dm_id in dm_ids:
|
||||
|
|
|
|||
Loading…
Reference in New Issue