blueprints.channels: add channel and guild ack routes

SQL for instances: Rerun `schema.sql` for the new table.

 - gateway.websocket: add get_read_state to read_state's ready
 - gateway.websocket: add get_dms on private_channels' ready
 - storage: fix get_dms()
 - storage: add Storage.get_channel_ids()
 - storage: add Storage.get_read_state()
 - schema.sql: add user_read_state table
This commit is contained in:
Luna Mendes 2018-10-09 22:52:12 -03:00
parent 5edcc62be4
commit adc6a58179
6 changed files with 184 additions and 56 deletions

View File

@ -6,47 +6,15 @@ from logbook import Logger
from ..auth import token_check
from ..snowflake import get_snowflake, snowflake_datetime
from ..enums import ChannelType, MessageType, GUILD_CHANS
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
from ..errors import Forbidden, ChannelNotFound, MessageNotFound
from ..schemas import validate, MESSAGE_CREATE
from .guilds import guild_check
from .checks import channel_check, guild_check
log = Logger(__name__)
bp = Blueprint('channels', __name__)
async def channel_check(user_id, channel_id):
"""Check if the current user is authorized
to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None:
raise ChannelNotFound(f'channel type not found')
ctype = ChannelType(chan_type)
if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval("""
SELECT guild_id
FROM guild_channels
WHERE guild_channels.id = $1
""", channel_id)
await guild_check(user_id, guild_id)
return guild_id
if ctype == ChannelType.DM:
parties = await app.db.fetchval("""
SELECT party1_id, party2_id
FROM dm_channels
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
""", channel_id, user_id)
# get the id of the other party
parties.remove(user_id)
return parties[0]
@bp.route('/<int:channel_id>', methods=['GET'])
async def get_channel(channel_id):
"""Get a single channel's information"""
@ -263,6 +231,17 @@ async def create_message(channel_id):
payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
for str_uid in payload['mentions']:
uid = int(str_uid)
await app.db.execute("""
UPDATE user_read_state
SET mention_count += 1
WHERE user_id = $1 AND channel_id = $2
""", uid, channel_id)
return jsonify(payload)
@ -431,3 +410,55 @@ async def trigger_typing(channel_id):
})
return '', 204
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
"""ACK a channel."""
if not message_id:
message_id = await app.storage.chan_last_message(channel_id)
res = await app.db.execute("""
UPDATE user_read_state
SET last_message_id = $1,
mention_count = 0
WHERE user_id = $2 AND channel_id = $3
""", message_id, user_id, channel_id)
if res == 'UPDATE 0':
await app.db.execute("""
INSERT INTO user_read_state
(user_id, channel_id, last_message_id, mention_count)
VALUES ($1, $2, $3, $4)
""", user_id, channel_id, message_id, 0)
if guild_id:
await app.dispatcher.dispatch_user_guild(
user_id, guild_id, 'MESSAGE_ACK', {
'message_id': str(message_id),
'channel_id': str(channel_id)
})
else:
# TODO: use ChannelDispatcher
await app.dispatcher.dispatch_user(
user_id, 'MESSAGE_ACK', {
'message_id': str(message_id),
'channel_id': str(channel_id)
})
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
async def ack_channel(channel_id, message_id):
user_id = await token_check()
guild_id = await channel_check(user_id, channel_id)
await channel_ack(user_id, guild_id, channel_id, message_id)
return jsonify({
# token seems to be used for
# data collection activities,
# so we never use it.
'token': None
})

View File

@ -0,0 +1,48 @@
from quart import current_app as app
from ..enums import ChannelType, GUILD_CHANS
from ..errors import GuildNotFound, ChannelNotFound
async def guild_check(user_id: int, guild_id: int):
"""Check if a user is in a guild."""
joined_at = await app.db.execute("""
SELECT joined_at
FROM members
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
if not joined_at:
raise GuildNotFound('guild not found')
async def channel_check(user_id, channel_id):
"""Check if the current user is authorized
to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None:
raise ChannelNotFound(f'channel type not found')
ctype = ChannelType(chan_type)
if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval("""
SELECT guild_id
FROM guild_channels
WHERE guild_channels.id = $1
""", channel_id)
await guild_check(user_id, guild_id)
return guild_id
if ctype == ChannelType.DM:
parties = await app.db.fetchval("""
SELECT party1_id, party2_id
FROM dm_channels
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
""", channel_id, user_id)
# get the id of the other party
parties.remove(user_id)
return parties[0]

View File

@ -5,22 +5,12 @@ from ..snowflake import get_snowflake
from ..enums import ChannelType
from ..errors import Forbidden, GuildNotFound, BadRequest
from ..schemas import validate, GUILD_UPDATE
from .channels import channel_ack
from .checks import guild_check, channel_check
bp = Blueprint('guilds', __name__)
async def guild_check(user_id: int, guild_id: int):
"""Check if a user is in a guild."""
joined_at = await app.db.execute("""
SELECT joined_at
FROM members
WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id)
if not joined_at:
raise GuildNotFound('guild not found')
async def guild_owner_check(user_id: int, guild_id: int):
"""Check if a user is the owner of the guild."""
owner_id = await app.db.fetchval("""
@ -469,3 +459,16 @@ async def search_messages(guild_id):
'messages': [],
'analytics_id': 'ass',
})
@bp.route('/<int:guild_id>/ack', methods=['POST'])
async def ack_guild(guild_id):
user_id = await token_check()
await guild_check(user_id, guild_id)
chan_ids = await app.storage.get_channel_ids(guild_id)
for chan_id in chan_ids:
await channel_ack(user_id, guild_id, chan_id)
return '', 204

View File

@ -207,6 +207,8 @@ class GatewayWebsocket:
'user_settings': await self.storage.get_user_settings(user_id),
'notes': await self.storage.fetch_notes(user_id),
'relationships': await self.storage.get_relationships(user_id),
'read_state': await self.storage.get_read_state(user_id),
'friend_suggestion_count': 0,
# TODO
@ -215,9 +217,6 @@ class GatewayWebsocket:
# TODO
'presences': [],
# TODO
'read_state': [],
# TODO
'connected_accounts': [],
@ -229,7 +228,9 @@ class GatewayWebsocket:
async def dispatch_ready(self):
"""Dispatch the READY packet for a connecting account."""
guilds = await self._make_guild_list()
user = await self.storage.get_user(self.state.user_id, True)
user_id = self.state.user_id
user = await self.storage.get_user(user_id, True)
uready = {}
if not self.state.bot:
@ -240,8 +241,8 @@ class GatewayWebsocket:
'v': 6,
'user': user,
# TODO: dms
'private_channels': [],
'private_channels': await self.storage.get_dms(user_id),
'guilds': guilds,
'session_id': self.state.session_id,
'_trace': ['transbian']

View File

@ -225,7 +225,7 @@ class Storage:
members = await self.get_member_multi(guild_id, mids)
return members
async def _chan_last_message(self, channel_id: int):
async def chan_last_message(self, channel_id: int):
return await self.db.fetchval("""
SELECT MAX(id)
FROM messages
@ -247,7 +247,7 @@ class Storage:
return {**row, **{
'topic': topic,
'last_message_id': str(
await self._chan_last_message(row['id']))
await self.chan_last_message(row['id']))
}}
elif chan_type == ChannelType.GUILD_VOICE:
vrow = await self.db.fetchval("""
@ -329,7 +329,7 @@ class Storage:
drow['type'] = chan_type
drow['last_message_id'] = str(
await self._chan_last_message(channel_id))
await self.chan_last_message(channel_id))
# dms have just two recipients.
drow['recipients'] = [
@ -347,6 +347,15 @@ class Storage:
return None
async def get_channel_ids(self, guild_id: int) -> List[int]:
rows = await self.db.fetch("""
SELECT id
FROM guild_channels
WHERE guild_id = $1
""", guild_id)
return [r['id'] for r in rows]
async def get_channel_data(self, guild_id) -> List[Dict]:
"""Get channel information on a guild"""
channel_basics = await self.db.fetch("""
@ -753,7 +762,7 @@ class Storage:
which is different than the whole list of DM channels.
"""
dm_ids = await self.db.fetch("""
SELECT id
SELECT dm_id
FROM dm_channel_state
WHERE user_id = $1
""", user_id)
@ -781,3 +790,25 @@ class Storage:
res.append(dm_chan)
return res
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
"""Get the read state for a user."""
rows = await self.db.fetch("""
SELECT channel_id, last_message_id, mention_count
FROM user_read_state
WHERE user_id = $1
""", user_id)
res = []
for row in rows:
drow = dict(row)
drow['id'] = str(drow['channel_id'])
drow.pop('channel_id')
drow['last_message_id'] = str(drow['last_message_id'])
res.append(drow)
return res

View File

@ -180,6 +180,20 @@ CREATE TABLE IF NOT EXISTS channels (
channel_type int NOT NULL
);
CREATE TABLE IF NOT EXISTS user_read_state (
user_id bigint REFERENCES users (id),
channel_id bigint REFERENCES channels (id),
-- we don't really need to link
-- this column to the messages table
last_message_id bigint,
-- counts are always positive
mention_count bigint CHECK (mention_count > -1),
PRIMARY KEY (user_id, channel_id)
);
CREATE TABLE IF NOT EXISTS guilds (
id bigint PRIMARY KEY NOT NULL,