mirror of https://gitlab.com/litecord/litecord.git
parent
506bd8afbe
commit
d91030a2c1
|
|
@ -217,7 +217,8 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
|
|||
# for the users that have a state
|
||||
# in the channel.
|
||||
if mentions_here:
|
||||
uids = []
|
||||
uids = set()
|
||||
|
||||
await app.db.execute("""
|
||||
UPDATE user_read_state
|
||||
SET mention_count = mention_count + 1
|
||||
|
|
@ -229,7 +230,7 @@ async def _guild_text_mentions(payload: dict, guild_id: int,
|
|||
# that might not have read permissions
|
||||
# to the channel.
|
||||
if mentions_everyone:
|
||||
uids = []
|
||||
uids = set()
|
||||
|
||||
member_ids = await app.storage.get_member_ids(guild_id)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
|
@ -262,7 +263,7 @@ async def _update_pos(channel_id, pos: int):
|
|||
""", pos, channel_id)
|
||||
|
||||
|
||||
async def _mass_chan_update(guild_id, channel_ids: int):
|
||||
async def _mass_chan_update(guild_id, channel_ids: List[int]):
|
||||
for channel_id in channel_ids:
|
||||
chan = await app.storage.get_channel(channel_id)
|
||||
await app.dispatcher.dispatch(
|
||||
|
|
@ -337,7 +338,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
|||
if 'position' in j:
|
||||
channel_data = await app.storage.get_channel_data(guild_id)
|
||||
|
||||
chans = [None * len(channel_data)]
|
||||
chans = [None] * len(channel_data)
|
||||
for chandata in channel_data:
|
||||
chans.insert(chandata['position'], int(chandata['id']))
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ async def get_members(guild_id):
|
|||
|
||||
|
||||
async def _update_member_roles(guild_id: int, member_id: int,
|
||||
wanted_roles: list):
|
||||
wanted_roles: set):
|
||||
"""Update the roles a member has."""
|
||||
|
||||
# first, fetch all current roles
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
|
||||
from quart import Blueprint, request, current_app as app, jsonify
|
||||
from logbook import Logger
|
||||
|
|
@ -184,10 +184,11 @@ async def _role_pairs_update(guild_id: int, pairs: list):
|
|||
await _role_update_dispatch(role_1, guild_id)
|
||||
await _role_update_dispatch(role_2, guild_id)
|
||||
|
||||
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
|
||||
|
||||
def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||
current_state: Dict[int, int],
|
||||
blacklist: List[int] = None) -> List[tuple]:
|
||||
blacklist: List[int] = None) -> PairList:
|
||||
"""Generate a list of pairs that, when applied to the database,
|
||||
will generate the desired state given in list_of_changes.
|
||||
|
||||
|
|
@ -262,7 +263,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
|||
|
||||
# if its being swapped to leave space, add it
|
||||
# to the pairs list
|
||||
if new_pos_2:
|
||||
if element_2 and new_pos_2:
|
||||
pairs.append(
|
||||
((element_1, new_pos_1), (element_2, new_pos_2))
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from collections import Counter
|
||||
from random import choice
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ def _majority_region_count(regions: list) -> str:
|
|||
return region
|
||||
|
||||
|
||||
async def _choose_random_region() -> str:
|
||||
async def _choose_random_region() -> Optional[str]:
|
||||
"""Give a random voice region."""
|
||||
regions = await app.db.fetch("""
|
||||
SELECT id
|
||||
|
|
@ -51,7 +52,7 @@ async def _choose_random_region() -> str:
|
|||
return choice(regions)
|
||||
|
||||
|
||||
async def _majority_region_any(user_id) -> str:
|
||||
async def _majority_region_any(user_id) -> Optional[str]:
|
||||
"""Calculate the most likely region to make the user happy, but
|
||||
this is based on the guilds the user is IN, instead of the guilds
|
||||
the user owns."""
|
||||
|
|
@ -79,7 +80,7 @@ async def _majority_region_any(user_id) -> str:
|
|||
return most_common
|
||||
|
||||
|
||||
async def majority_region(user_id) -> str:
|
||||
async def majority_region(user_id: int) -> Optional[str]:
|
||||
"""Given a user ID, give the most likely region for the user to be
|
||||
happy with."""
|
||||
regions = await app.db.fetch("""
|
||||
|
|
|
|||
|
|
@ -235,7 +235,7 @@ class GatewayWebsocket:
|
|||
's': None
|
||||
})
|
||||
|
||||
def _check_ratelimit(self, key: str, ratelimit_key: str):
|
||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||
return bucket.update_rate_limit()
|
||||
|
|
@ -292,7 +292,7 @@ class GatewayWebsocket:
|
|||
|
||||
await self.send(payload)
|
||||
|
||||
async def _make_guild_list(self) -> List[int]:
|
||||
async def _make_guild_list(self) -> List[Dict[str, Any]]:
|
||||
user_id = self.state.user_id
|
||||
|
||||
guild_ids = await self._guild_ids()
|
||||
|
|
@ -772,7 +772,7 @@ class GatewayWebsocket:
|
|||
|
||||
await self._resume(range(seq, state.seq))
|
||||
|
||||
async def _req_guild_members(self, guild_id: str, user_ids: List[int],
|
||||
async def _req_guild_members(self, guild_id, user_ids: List[int],
|
||||
query: str, limit: int):
|
||||
try:
|
||||
guild_id = int(guild_id)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -54,13 +54,13 @@ class ChannelDispatcher(DispatcherWithState):
|
|||
VAL_TYPE = int
|
||||
|
||||
async def dispatch(self, channel_id,
|
||||
event: str, data: Any):
|
||||
event: str, data: Any) -> List[str]:
|
||||
"""Dispatch an event to a channel."""
|
||||
# get everyone who is subscribed
|
||||
# and store the number of states we dispatched the event to
|
||||
user_ids = self.state[channel_id]
|
||||
dispatched = 0
|
||||
sessions = []
|
||||
sessions: List[str] = []
|
||||
|
||||
# making a copy of user_ids since
|
||||
# we'll modify it later on.
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
"""
|
||||
litecord.pubsub.dispatcher: main dispatcher class
|
||||
"""
|
||||
from typing import List
|
||||
from collections import defaultdict
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -82,7 +80,8 @@ class Dispatcher:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _dispatch_states(self, states: list, event: str, data) -> int:
|
||||
async def _dispatch_states(self, states: list, event: str,
|
||||
data) -> List[str]:
|
||||
"""Dispatch an event to a list of states."""
|
||||
res = []
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ lazy guilds:
|
|||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Dict, Union
|
||||
from typing import Any, List, Dict, Union, Optional, Iterable, Tuple
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -39,7 +39,7 @@ from litecord.permissions import (
|
|||
)
|
||||
from litecord.utils import index_by_func
|
||||
from litecord.utils import mmh3
|
||||
|
||||
from litecord.gateway.state import GatewayState
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -113,7 +113,7 @@ class MemberList:
|
|||
yield group, self.data[group.gid]
|
||||
|
||||
@property
|
||||
def iter_non_empty(self) -> tuple:
|
||||
def iter_non_empty(self) -> Generator[Tuple[GroupInfo, List[int]]]:
|
||||
"""Only iterate through non-empty groups.
|
||||
|
||||
Note that while the offline group can be empty, it is always
|
||||
|
|
@ -359,7 +359,7 @@ class GuildMemberList:
|
|||
# then the final perms for that role if
|
||||
# any overwrite exists in the channel
|
||||
final_perms = overwrite_find_mix(
|
||||
role_perms, self.list.overwrites, group.gid)
|
||||
role_perms, self.list.overwrites, int(group.gid))
|
||||
|
||||
# update the group's permissions
|
||||
# with the mixed ones
|
||||
|
|
@ -423,7 +423,7 @@ class GuildMemberList:
|
|||
|
||||
async def _get_group_for_member(self, member_id: int,
|
||||
roles: List[Union[str, int]],
|
||||
status: str) -> GroupID:
|
||||
status: str) -> Optional[GroupID]:
|
||||
"""Return a fitting group ID for the member."""
|
||||
member_roles = list(map(int, roles))
|
||||
|
||||
|
|
@ -463,15 +463,15 @@ class GuildMemberList:
|
|||
self.list.members[member_id] = member
|
||||
self.list.data[group_id].append(member_id)
|
||||
|
||||
def _display_name(self, member_id: int) -> str:
|
||||
def _display_name(self, member_id: int) -> Optional[str]:
|
||||
"""Get the display name for a given member.
|
||||
|
||||
This is more efficient than the old function (not method) of same
|
||||
name, as we dont need to pass nickname information to it.
|
||||
"""
|
||||
member = self.list.members.get(member_id)
|
||||
|
||||
if not member_id:
|
||||
try:
|
||||
member = self.list.members[member_id]
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
username = member['user']['username']
|
||||
|
|
@ -578,7 +578,7 @@ class GuildMemberList:
|
|||
if not self.state:
|
||||
self._set_empty_list()
|
||||
|
||||
def _get_state(self, session_id: str):
|
||||
def _get_state(self, session_id: str) -> Optional[GatewayState]:
|
||||
"""Get the state for a session id.
|
||||
|
||||
Wrapper for :meth:`StateManager.fetch_raw`
|
||||
|
|
@ -625,7 +625,8 @@ class GuildMemberList:
|
|||
|
||||
return dispatched
|
||||
|
||||
async def _resync(self, session_ids: int, item_index: int) -> List[str]:
|
||||
async def _resync(self, session_ids: List[int],
|
||||
item_index: int) -> List[str]:
|
||||
"""Send a SYNC event to all states that are subscribed to an item.
|
||||
|
||||
Returns
|
||||
|
|
@ -729,7 +730,7 @@ class GuildMemberList:
|
|||
# send SYNCs to the state that requested
|
||||
await self._dispatch_sess([session_id], ops)
|
||||
|
||||
def _get_item_index(self, user_id: Union[str, int]) -> int:
|
||||
def _get_item_index(self, user_id: Union[str, int]) -> Optional[int]:
|
||||
"""Get the item index a user is on."""
|
||||
# NOTE: this is inefficient
|
||||
user_id = int(user_id)
|
||||
|
|
@ -749,7 +750,7 @@ class GuildMemberList:
|
|||
|
||||
return None
|
||||
|
||||
def _get_group_item_index(self, group_id: GroupID) -> int:
|
||||
def _get_group_item_index(self, group_id: GroupID) -> Optional[int]:
|
||||
"""Get the item index a group is on."""
|
||||
index = 0
|
||||
|
||||
|
|
@ -773,7 +774,7 @@ class GuildMemberList:
|
|||
|
||||
return False
|
||||
|
||||
def _get_subs(self, item_index: int) -> filter:
|
||||
def _get_subs(self, item_index: int) -> Iterable[str]:
|
||||
"""Get the list of subscribed states to a given item."""
|
||||
return filter(
|
||||
lambda sess_id: self._is_subbed(item_index, sess_id),
|
||||
|
|
@ -1141,7 +1142,7 @@ class GuildMemberList:
|
|||
# when bots come along.
|
||||
self.list.data[new_group.gid] = []
|
||||
|
||||
def _get_role_as_group_idx(self, role_id: int) -> int:
|
||||
def _get_role_as_group_idx(self, role_id: int) -> Optional[int]:
|
||||
"""Get a group index representing the given role id.
|
||||
|
||||
Returns
|
||||
|
|
|
|||
|
|
@ -147,7 +147,7 @@ class LitecordValidator(Validator):
|
|||
|
||||
|
||||
def validate(reqjson: Union[Dict, List], schema: Dict,
|
||||
raise_err: bool = True) -> Union[Dict, List]:
|
||||
raise_err: bool = True) -> Dict:
|
||||
"""Validate a given document (user-input) and give
|
||||
the correct document as a result.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -77,7 +77,7 @@ class Storage:
|
|||
self.db = app.db
|
||||
self.presence = None
|
||||
|
||||
async def fetchrow_with_json(self, query: str, *args):
|
||||
async def fetchrow_with_json(self, query: str, *args) -> Any:
|
||||
"""Fetch a single row with JSON/JSONB support."""
|
||||
# the pool by itself doesn't have
|
||||
# set_type_codec, so we must set it manually
|
||||
|
|
@ -86,19 +86,19 @@ class Storage:
|
|||
await pg_set_json(con)
|
||||
return await con.fetchrow(query, *args)
|
||||
|
||||
async def fetch_with_json(self, query: str, *args):
|
||||
async def fetch_with_json(self, query: str, *args) -> List[Any]:
|
||||
"""Fetch many rows with JSON/JSONB support."""
|
||||
async with self.db.acquire() as con:
|
||||
await pg_set_json(con)
|
||||
return await con.fetch(query, *args)
|
||||
|
||||
async def execute_with_json(self, query: str, *args):
|
||||
async def execute_with_json(self, query: str, *args) -> str:
|
||||
"""Execute a SQL statement with JSON/JSONB support."""
|
||||
async with self.db.acquire() as con:
|
||||
await pg_set_json(con)
|
||||
return await con.execute(query, *args)
|
||||
|
||||
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
||||
async def get_user(self, user_id, secure=False) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single user payload."""
|
||||
user_id = int(user_id)
|
||||
|
||||
|
|
@ -115,7 +115,7 @@ class Storage:
|
|||
""", user_id)
|
||||
|
||||
if not user_row:
|
||||
return
|
||||
return None
|
||||
|
||||
duser = dict(user_row)
|
||||
|
||||
|
|
@ -141,7 +141,7 @@ class Storage:
|
|||
"""Search a user"""
|
||||
if len(discriminator) < 4:
|
||||
# how do we do this in f-strings again..?
|
||||
discriminator = '%04d' % discriminator
|
||||
discriminator = '%04d' % int(discriminator)
|
||||
|
||||
return await self.db.fetchval("""
|
||||
SELECT id FROM users
|
||||
|
|
@ -219,12 +219,12 @@ class Storage:
|
|||
}
|
||||
|
||||
async def get_member_data_one(self, guild_id: int,
|
||||
member_id: int) -> Dict[str, Any]:
|
||||
member_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get data about one member in a guild."""
|
||||
basic = await self._member_basic(guild_id, member_id)
|
||||
|
||||
if not basic:
|
||||
return
|
||||
return None
|
||||
|
||||
return await self._member_dict(basic, guild_id, member_id)
|
||||
|
||||
|
|
@ -376,7 +376,7 @@ class Storage:
|
|||
return [r['member_id'] for r in user_ids]
|
||||
|
||||
async def _gdm_recipients(self, channel_id: int,
|
||||
reference_id: int = None) -> List[int]:
|
||||
reference_id: int = None) -> List[Dict]:
|
||||
"""Get the list of users that are recipients of the
|
||||
given Group DM."""
|
||||
recipients = await self.gdm_recipient_ids(channel_id)
|
||||
|
|
@ -392,7 +392,8 @@ class Storage:
|
|||
|
||||
return res
|
||||
|
||||
async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]:
|
||||
async def get_channel(self, channel_id: int,
|
||||
**kwargs) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch a single channel's information."""
|
||||
chan_type = await self.get_chan_type(channel_id)
|
||||
ctype = ChannelType(chan_type)
|
||||
|
|
@ -501,7 +502,7 @@ class Storage:
|
|||
return channels
|
||||
|
||||
async def get_role(self, role_id: int,
|
||||
guild_id: int = None) -> Dict[str, Any]:
|
||||
guild_id: int = None) -> Optional[Dict[str, Any]]:
|
||||
"""get a single role's information."""
|
||||
|
||||
guild_field = 'AND guild_id = $2' if guild_id else ''
|
||||
|
|
@ -519,7 +520,7 @@ class Storage:
|
|||
""", *args)
|
||||
|
||||
if not row:
|
||||
return
|
||||
return None
|
||||
|
||||
return dict(row)
|
||||
|
||||
|
|
@ -769,7 +770,8 @@ class Storage:
|
|||
|
||||
res.pop('author_id')
|
||||
|
||||
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
||||
async def get_message(self, message_id: int,
|
||||
user_id: Optional[int] = None) -> Optional[Dict]:
|
||||
"""Get a single message's payload."""
|
||||
row = await self.fetchrow_with_json("""
|
||||
SELECT id::text, channel_id::text, author_id, webhook_id, content,
|
||||
|
|
@ -780,7 +782,7 @@ class Storage:
|
|||
""", message_id)
|
||||
|
||||
if not row:
|
||||
return
|
||||
return None
|
||||
|
||||
res = dict(row)
|
||||
res['nonce'] = str(res['nonce'])
|
||||
|
|
@ -915,7 +917,8 @@ class Storage:
|
|||
'approximate_member_count': len(mids),
|
||||
}
|
||||
|
||||
async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]:
|
||||
async def get_invite_metadata(self,
|
||||
invite_code: str) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch invite metadata (max_age and friends)."""
|
||||
invite = await self.db.fetchrow("""
|
||||
SELECT code, inviter, created_at, uses,
|
||||
|
|
@ -925,7 +928,7 @@ class Storage:
|
|||
""", invite_code)
|
||||
|
||||
if invite is None:
|
||||
return
|
||||
return None
|
||||
|
||||
dinv = dict_(invite)
|
||||
inviter = await self.get_user(invite['inviter'])
|
||||
|
|
@ -966,7 +969,7 @@ class Storage:
|
|||
|
||||
return parties[0]
|
||||
|
||||
async def get_emoji(self, emoji_id: int) -> Dict:
|
||||
async def get_emoji(self, emoji_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single emoji."""
|
||||
row = await self.db.fetchrow("""
|
||||
SELECT id::text, name, animated, managed,
|
||||
|
|
@ -976,7 +979,7 @@ class Storage:
|
|||
""", emoji_id)
|
||||
|
||||
if not row:
|
||||
return
|
||||
return None
|
||||
|
||||
drow = dict(row)
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
import asyncio
|
||||
import json
|
||||
from logbook import Logger
|
||||
from typing import Any
|
||||
from typing import Any, Iterable
|
||||
from quart.json import JSONEncoder
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -160,7 +160,7 @@ async def pg_set_json(con):
|
|||
)
|
||||
|
||||
|
||||
def yield_chunks(input_list: list, chunk_size: int):
|
||||
def yield_chunks(input_list: Iterable, chunk_size: int):
|
||||
"""Yield successive n-sized chunks from l.
|
||||
|
||||
Taken from https://stackoverflow.com/a/312464.
|
||||
|
|
|
|||
Loading…
Reference in New Issue