mirror of https://gitlab.com/litecord/litecord.git
various fixes to dms
- blueprints.checks: fix party fetching - blueprints.users: fix try_dm_state - blueprints.users: fix create_dm - blueprints.users: fix create_group_dm being used for 1-on-1 dm - gateway.websocket: add support for pure zlib - schemas: fix CREATE_GROUP_DM - storage: fix _filter_recipients, get_channel and get_dms
This commit is contained in:
parent
737129bd20
commit
e210c20d0f
|
|
@ -37,12 +37,14 @@ async def channel_check(user_id, channel_id):
|
||||||
return guild_id
|
return guild_id
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
parties = await app.db.fetchval("""
|
parties = await app.db.fetchrow("""
|
||||||
SELECT party1_id, party2_id
|
SELECT party1_id, party2_id
|
||||||
FROM dm_channels
|
FROM dm_channels
|
||||||
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||||
""", channel_id, user_id)
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
parties = [parties['party1_id'], parties['party2_id']]
|
||||||
|
|
||||||
# get the id of the other party
|
# get the id of the other party
|
||||||
parties.remove(user_id)
|
parties.remove(user_id)
|
||||||
return parties[0]
|
return parties[0]
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ from ..auth import token_check
|
||||||
from ..snowflake import get_snowflake
|
from ..snowflake import get_snowflake
|
||||||
from ..errors import Forbidden, BadRequest
|
from ..errors import Forbidden, BadRequest
|
||||||
from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
|
from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
|
||||||
|
from ..enums import ChannelType
|
||||||
|
|
||||||
from .guilds import guild_check
|
from .guilds import guild_check
|
||||||
|
|
||||||
|
|
@ -133,7 +134,7 @@ async def try_dm_state(user_id, dm_id):
|
||||||
for the given DM."""
|
for the given DM."""
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
INSERT INTO dm_channel_state (id, dm_id)
|
INSERT INTO dm_channel_state (user_id, dm_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", user_id, dm_id)
|
""", user_id, dm_id)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
|
|
@ -145,6 +146,11 @@ async def create_dm(user_id, recipient_id):
|
||||||
dm_id = get_snowflake()
|
dm_id = get_snowflake()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO channels (id, channel_type)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", dm_id, ChannelType.DM.value)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
|
|
@ -175,7 +181,7 @@ async def start_dm():
|
||||||
return await create_dm(user_id, recipient_id)
|
return await create_dm(user_id, recipient_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:user_id>/channels', methods=['POST'])
|
@bp.route('/<int:p_user_id>/channels', methods=['POST'])
|
||||||
async def create_group_dm(p_user_id: int):
|
async def create_group_dm(p_user_id: int):
|
||||||
"""Create a DM or a Group DM with user(s)."""
|
"""Create a DM or a Group DM with user(s)."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -184,12 +190,12 @@ async def create_group_dm(p_user_id: int):
|
||||||
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
||||||
recipients = j['recipients']
|
recipients = j['recipients']
|
||||||
|
|
||||||
if list(recipients) == 1:
|
if len(recipients) == 1:
|
||||||
# its a group dm with 1 user... a dm!
|
# its a group dm with 1 user... a dm!
|
||||||
return await create_dm(user_id, int(recipients[0]))
|
return await create_dm(user_id, int(recipients[0]))
|
||||||
|
|
||||||
# TODO: group dms
|
# TODO: group dms
|
||||||
return '', 500
|
return 'group dms not implemented', 500
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
||||||
|
|
|
||||||
|
|
@ -97,13 +97,15 @@ class GatewayWebsocket:
|
||||||
if not isinstance(encoded, bytes):
|
if not isinstance(encoded, bytes):
|
||||||
encoded = encoded.encode()
|
encoded = encoded.encode()
|
||||||
|
|
||||||
|
# handle zlib-stream, pure zlib or plain
|
||||||
if self.wsp.compress == 'zlib-stream':
|
if self.wsp.compress == 'zlib-stream':
|
||||||
data1 = self.wsp.zctx.compress(encoded)
|
data1 = self.wsp.zctx.compress(encoded)
|
||||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||||
|
|
||||||
await self.ws.send(data1 + data2)
|
await self.ws.send(data1 + data2)
|
||||||
|
elif self.state and self.state.compress:
|
||||||
|
await self.ws.send(zlib.compress(encoded))
|
||||||
else:
|
else:
|
||||||
# TODO: pure zlib
|
|
||||||
await self.ws.send(encoded.decode())
|
await self.ws.send(encoded.decode())
|
||||||
|
|
||||||
async def _hb_wait(self, interval: int):
|
async def _hb_wait(self, interval: int):
|
||||||
|
|
|
||||||
|
|
@ -288,9 +288,11 @@ CREATE_DM = {
|
||||||
}
|
}
|
||||||
|
|
||||||
CREATE_GROUP_DM = {
|
CREATE_GROUP_DM = {
|
||||||
|
'recipients': {
|
||||||
'type': 'list',
|
'type': 'list',
|
||||||
'required': True,
|
'required': True,
|
||||||
'schema': {'type': 'snowflake'}
|
'schema': {'type': 'snowflake'}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
SPECIFIC_FRIEND = {
|
SPECIFIC_FRIEND = {
|
||||||
|
|
|
||||||
|
|
@ -41,9 +41,9 @@ def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
|
||||||
the one that is reundant (ourselves)."""
|
the one that is reundant (ourselves)."""
|
||||||
user_id = str(user_id)
|
user_id = str(user_id)
|
||||||
|
|
||||||
return filter(
|
return list(filter(
|
||||||
lambda recipient: recipient['id'] != user_id,
|
lambda recipient: recipient['id'] != user_id,
|
||||||
recipients)
|
recipients))
|
||||||
|
|
||||||
|
|
||||||
class Storage:
|
class Storage:
|
||||||
|
|
@ -336,7 +336,7 @@ class Storage:
|
||||||
return res
|
return res
|
||||||
elif ctype == ChannelType.DM:
|
elif ctype == ChannelType.DM:
|
||||||
dm_row = await self.db.fetchrow("""
|
dm_row = await self.db.fetchrow("""
|
||||||
SELECT party1_id, party2_id
|
SELECT id, party1_id, party2_id
|
||||||
FROM dm_channels
|
FROM dm_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""", channel_id)
|
||||||
|
|
@ -783,6 +783,8 @@ class Storage:
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""", user_id)
|
||||||
|
|
||||||
|
dm_ids = [r['dm_id'] for r in dm_ids]
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for dm_id in dm_ids:
|
for dm_id in dm_ids:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue