diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index 7d625a2..2a83c54 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -38,41 +38,47 @@ async def try_dm_state(user_id: int, dm_id: int): """, user_id, dm_id) +async def jsonify_dm(dm_id: int, user_id: int): + dm_chan = await app.storage.get_dm(dm_id, user_id) + return jsonify(dm_chan) + + async def create_dm(user_id, recipient_id): """Create a new dm with a user, or get the existing DM id if it already exists.""" + + dm_id = await app.db.fetchval(""" + SELECT id + FROM dm_channels + WHERE (party1_id = $1 OR party2_id = $1) AND + (party1_id = $2 OR party2_id = $2) + """, user_id, recipient_id) + + if dm_id: + return await jsonify_dm(dm_id, user_id) + + # if no dm was found, create a new one + dm_id = get_snowflake() + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, dm_id, ChannelType.DM.value) - try: - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, dm_id, ChannelType.DM.value) + await app.db.execute(""" + INSERT INTO dm_channels (id, party1_id, party2_id) + VALUES ($1, $2, $3) + """, dm_id, user_id, recipient_id) - await app.db.execute(""" - INSERT INTO dm_channels (id, party1_id, party2_id) - VALUES ($1, $2, $3) - """, dm_id, user_id, recipient_id) + # the dm state is something we use + # to give the currently "open dms" + # on the client. - # the dm state is something we use - # to give the currently "open dms" - # on the client. + # we don't open a dm for the peer/recipient + # until the user sends a message. + await try_dm_state(user_id, dm_id) - # we don't open a dm for the peer/recipient - # until the user sends a message. - await try_dm_state(user_id, dm_id) - - except UniqueViolationError: - # the dm already exists - dm_id = await app.db.fetchval(""" - SELECT id - FROM dm_channels - WHERE (party1_id = $1 OR party2_id = $1) AND - (party2_id = $2 OR party2_id = $2) - """, user_id, recipient_id) - - dm = await app.storage.get_dm(dm_id, user_id) - return jsonify(dm) + return await jsonify_dm(dm_id, user_id) @bp.route('/@me/channels', methods=['POST'])