diff --git a/litecord/blueprints/__init__.py b/litecord/blueprints/__init__.py index 10a104b..c94aaa9 100644 --- a/litecord/blueprints/__init__.py +++ b/litecord/blueprints/__init__.py @@ -8,3 +8,4 @@ from .science import bp as science from .voice import bp as voice from .invites import bp as invites from .relationships import bp as relationships +from .dms import bp as dms diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 5b3310c..2ab084f 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -10,7 +10,7 @@ from ..errors import Forbidden, ChannelNotFound, MessageNotFound from ..schemas import validate, MESSAGE_CREATE from .checks import channel_check, guild_check -from .users import try_dm_state +from .dms import try_dm_state log = Logger(__name__) bp = Blueprint('channels', __name__) diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py new file mode 100644 index 0000000..7d625a2 --- /dev/null +++ b/litecord/blueprints/dms.py @@ -0,0 +1,102 @@ +""" +blueprint for direct messages +""" + +from asyncpg import UniqueViolationError +from quart import Blueprint, request, current_app as app, jsonify +from logbook import Logger + +from ..schemas import validate, CREATE_DM, CREATE_GROUP_DM +from ..enums import ChannelType +from ..snowflake import get_snowflake + +from .auth import token_check + +log = Logger(__name__) +bp = Blueprint('dms', __name__) + + +@bp.route('/@me/channels', methods=['GET']) +async def get_dms(): + """Get the open DMs for the user.""" + user_id = await token_check() + dms = await app.storage.get_dms(user_id) + return jsonify(dms) + + +async def try_dm_state(user_id: int, dm_id: int): + """Try inserting the user into the dm state + for the given DM. + + Does not do anything if the user is already + in the dm state. + """ + await app.db.execute(""" + INSERT INTO dm_channel_state (user_id, dm_id) + VALUES ($1, $2) + ON CONFLICT DO NOTHING + """, user_id, dm_id) + + +async def create_dm(user_id, recipient_id): + """Create a new dm with a user, + or get the existing DM id if it already exists.""" + dm_id = get_snowflake() + + try: + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, dm_id, ChannelType.DM.value) + + await app.db.execute(""" + INSERT INTO dm_channels (id, party1_id, party2_id) + VALUES ($1, $2, $3) + """, dm_id, user_id, recipient_id) + + # the dm state is something we use + # to give the currently "open dms" + # on the client. + + # we don't open a dm for the peer/recipient + # until the user sends a message. + await try_dm_state(user_id, dm_id) + + except UniqueViolationError: + # the dm already exists + dm_id = await app.db.fetchval(""" + SELECT id + FROM dm_channels + WHERE (party1_id = $1 OR party2_id = $1) AND + (party2_id = $2 OR party2_id = $2) + """, user_id, recipient_id) + + dm = await app.storage.get_dm(dm_id, user_id) + return jsonify(dm) + + +@bp.route('/@me/channels', methods=['POST']) +async def start_dm(): + """Create a DM with a user.""" + user_id = await token_check() + j = validate(await request.get_json(), CREATE_DM) + recipient_id = j['recipient_id'] + + return await create_dm(user_id, recipient_id) + + +@bp.route('//channels', methods=['POST']) +async def create_group_dm(p_user_id: int): + """Create a DM or a Group DM with user(s).""" + user_id = await token_check() + assert user_id == p_user_id + + j = validate(await request.get_json(), CREATE_GROUP_DM) + recipients = j['recipients'] + + if len(recipients) == 1: + # its a group dm with 1 user... a dm! + return await create_dm(user_id, int(recipients[0])) + + # TODO: group dms + return 'group dms not implemented', 500 diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 7bc762d..0b37e54 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -1,14 +1,12 @@ import random -from quart import Blueprint, jsonify, request, current_app as app from asyncpg import UniqueViolationError +from quart import Blueprint, jsonify, request, current_app as app from ..auth import token_check -from ..snowflake import get_snowflake -from ..errors import Forbidden, BadRequest, Unauthorized +from ..errors import Forbidden, BadRequest from ..schemas import validate, USER_SETTINGS, \ - CREATE_DM, CREATE_GROUP_DM, USER_UPDATE, GUILD_SETTINGS -from ..enums import ChannelType, RelationshipType + USER_UPDATE, GUILD_SETTINGS from .guilds import guild_check from .auth import hash_data, check_password, check_username_usage @@ -268,90 +266,6 @@ async def get_connections(): pass -@bp.route('/@me/channels', methods=['GET']) -async def get_dms(): - user_id = await token_check() - dms = await app.storage.get_dms(user_id) - return jsonify(dms) - - -async def try_dm_state(user_id, dm_id): - """Try insertin the user into the dm state - for the given DM.""" - try: - await app.db.execute(""" - INSERT INTO dm_channel_state (user_id, dm_id) - VALUES ($1, $2) - """, user_id, dm_id) - except UniqueViolationError: - # if already in state, ignore - pass - - -async def create_dm(user_id, recipient_id): - """Create a new dm with a user, - or get the existing DM id if it already exists.""" - dm_id = get_snowflake() - - try: - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, dm_id, ChannelType.DM.value) - - await app.db.execute(""" - INSERT INTO dm_channels (id, party1_id, party2_id) - VALUES ($1, $2, $3) - """, dm_id, user_id, recipient_id) - - # the dm state is something we use - # to give the currently "open dms" - # on the client. - - # we don't open a dm for the peer/recipient - # until the user sends a message. - await try_dm_state(user_id, dm_id) - - except UniqueViolationError: - # the dm already exists - dm_id = await app.db.fetchval(""" - SELECT id - FROM dm_channels - WHERE (party1_id = $1 OR party2_id = $1) AND - (party2_id = $2 OR party2_id = $2) - """, user_id, recipient_id) - - dm = await app.storage.get_dm(dm_id, user_id) - return jsonify(dm) - - -@bp.route('/@me/channels', methods=['POST']) -async def start_dm(): - """Create a DM with a user.""" - user_id = await token_check() - j = validate(await request.get_json(), CREATE_DM) - recipient_id = j['recipient_id'] - - return await create_dm(user_id, recipient_id) - - -@bp.route('//channels', methods=['POST']) -async def create_group_dm(p_user_id: int): - """Create a DM or a Group DM with user(s).""" - user_id = await token_check() - assert user_id == p_user_id - - j = validate(await request.get_json(), CREATE_GROUP_DM) - recipients = j['recipients'] - - if len(recipients) == 1: - # its a group dm with 1 user... a dm! - return await create_dm(user_id, int(recipients[0])) - - # TODO: group dms - return 'group dms not implemented', 500 - - @bp.route('/@me/notes/', methods=['PUT']) async def put_note(target_id: int): """Put a note to a user.""" @@ -536,7 +450,7 @@ async def patch_guild_settings(guild_id: int): settings = await app.storage.get_guild_settings_one(user_id, guild_id) - await app.dispatcher.dispatch_user(user_id, 'GUILD_SETTINGS_UPDATE', { + await app.dispatcher.dispatch_user(user_id, 'USER_GUILD_SETTINGS_UPDATE', { **settings, **{'guild_id': guild_id} }) diff --git a/litecord/storage.py b/litecord/storage.py index 4edb591..7804d4c 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -913,7 +913,7 @@ class Storage: """Get the specific User Guild Settings, for all guilds a user is on.""" - res = {} + res = [] settings = await self.db.fetch(""" SELECT guild_id, suppress_everyone, muted @@ -925,7 +925,6 @@ class Storage: for row in settings: gid = row['guild_id'] drow = dict(row) - drow.pop('guild_id') chan_ids = await self.get_channel_ids(gid) @@ -943,9 +942,9 @@ class Storage: chan_overrides[str(chan_id)] = dict(chan_row) - res[str(gid)] = {**drow, **{ + res.append({**drow, **{ 'channel_overrides': chan_overrides - }} + }}) return res diff --git a/run.py b/run.py index c48646b..98f674d 100644 --- a/run.py +++ b/run.py @@ -11,7 +11,7 @@ from logbook.compat import redirect_logging import config from litecord.blueprints import gateway, auth, users, guilds, channels, \ - webhooks, science, voice, invites, relationships + webhooks, science, voice, invites, relationships, dms from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -56,6 +56,7 @@ bps = { science: None, voice: '/voice', invites: None, + dms: '/users' } for bp, suffix in bps.items(): diff --git a/schema.sql b/schema.sql index 30c5a1d..e43ba45 100644 --- a/schema.sql +++ b/schema.sql @@ -283,11 +283,11 @@ CREATE TABLE IF NOT EXISTS guild_settings ( CREATE TABLE IF NOT EXISTS guild_settings_channel_overrides ( user_id bigint REFERENCES users (id) ON DELETE CASCADE, - guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE + guild_id bigint REFERENCES guilds (id) ON DELETE CASCADE, channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, muted bool DEFAULT false, - message_notifications bool DEFAULT 0, + message_notifications int DEFAULT 0, PRIMARY KEY (user_id, guild_id, channel_id) );