user_storage: fix channel_overrides being a dict

This commit is contained in:
Luna 2018-12-10 00:51:03 -03:00
parent c5ccc5a9d0
commit 5c38198137
2 changed files with 22 additions and 40 deletions

View File

@ -18,7 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import urllib.parse
from .websocket import GatewayWebsocket
from litecord.gateway.websocket import GatewayWebsocket
async def websocket_handler(app, ws, url):
@ -54,6 +54,8 @@ async def websocket_handler(app, ws, url):
if gw_compress and gw_compress not in ('zlib-stream',):
return await ws.close(1000, 'Invalid gateway compress')
print('encoding', gw_encoding, 'compression', gw_compress)
gws = GatewayWebsocket(ws, app, v=gw_version,
encoding=gw_encoding, compress=gw_compress)
await gws.run()

View File

@ -211,6 +211,23 @@ class UserStorage:
return res
async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List:
chan_overrides = []
overrides = await self.db.fetch("""
SELECT channel_id::text, muted, message_notifications
FROM guild_settings_channel_overrides
WHERE
user_id = $1
AND guild_id = $2
""", user_id, guild_id)
for chan_row in overrides:
dcrow = dict(chan_row)
chan_overrides.append(dcrow)
return chan_overrides
async def get_guild_settings_one(self, user_id: int,
guild_id: int) -> dict:
"""Get guild settings information for a single guild."""
@ -231,25 +248,7 @@ class UserStorage:
gid = int(row['guild_id'])
drow = dict(row)
chan_overrides = {}
overrides = await self.db.fetch("""
SELECT channel_id::text, muted, message_notifications
FROM guild_settings_channel_overrides
WHERE
user_id = $1
AND guild_id = $2
""", user_id, gid)
for chan_row in overrides:
dcrow = dict(chan_row)
chan_id = dcrow['channel_id']
dcrow.pop('channel_id')
chan_overrides[chan_id] = dcrow
chan_overrides = await self._get_chan_overrides(user_id, gid)
return {**drow, **{
'channel_overrides': chan_overrides
}}
@ -271,26 +270,7 @@ class UserStorage:
gid = int(row['guild_id'])
drow = dict(row)
chan_overrides = {}
overrides = await self.db.fetch("""
SELECT channel_id::text, muted, message_notifications
FROM guild_settings_channel_overrides
WHERE
user_id = $1
AND guild_id = $2
""", user_id, gid)
for chan_row in overrides:
dcrow = dict(chan_row)
# channel_id isn't on the value of the dict
# so we query it (for the key) then pop
# from the value
chan_id = dcrow['channel_id']
dcrow.pop('channel_id')
chan_overrides[chan_id] = dcrow
chan_overrides = await self._get_chan_overrides(user_id, gid)
res.append({**drow, **{
'channel_overrides': chan_overrides