mirror of https://gitlab.com/litecord/litecord.git
typing, episode 2
This commit is contained in:
parent
ed3c436b6d
commit
9dab5b20ae
|
|
@ -58,11 +58,11 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
|||
user_settings = await app.user_storage.get_user_settings(user_id)
|
||||
peer_settings = await app.user_storage.get_user_settings(peer_id)
|
||||
|
||||
restricted_user = [int(v) for v in user_settings['restricted_guilds']]
|
||||
restricted_peer = [int(v) for v in peer_settings['restricted_guilds']]
|
||||
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
|
||||
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
|
||||
|
||||
restricted_user = set(restricted_user)
|
||||
restricted_peer = set(restricted_peer)
|
||||
restricted_user = set(restricted_user_)
|
||||
restricted_peer = set(restricted_peer_)
|
||||
|
||||
mutual_guilds -= restricted_user
|
||||
mutual_guilds -= restricted_peer
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ litecord.embed.sanitizer
|
|||
sanitize embeds by giving common values
|
||||
such as type: rich
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
|
||||
from logbook import Logger
|
||||
from quart import current_app as app
|
||||
|
|
@ -44,7 +44,7 @@ def sanitize_embed(embed: Embed) -> Embed:
|
|||
}}
|
||||
|
||||
|
||||
def path_exists(embed: Embed, components: str):
|
||||
def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||
"""Tell if a given path exists in an embed (or any dictionary).
|
||||
|
||||
The components string is formatted like this:
|
||||
|
|
@ -54,10 +54,10 @@ def path_exists(embed: Embed, components: str):
|
|||
"""
|
||||
|
||||
# get the list of components given
|
||||
if isinstance(components, str):
|
||||
components = components.split('.')
|
||||
if isinstance(components_in, str):
|
||||
components = components_in.split('.')
|
||||
else:
|
||||
components = list(components)
|
||||
components = list(components_in)
|
||||
|
||||
# if there are no components, we reached the end of recursion
|
||||
# and can return true
|
||||
|
|
@ -96,7 +96,7 @@ def proxify(url, *, config=None) -> str:
|
|||
)
|
||||
|
||||
|
||||
async def fetch_metadata(url, *, config=None, session=None) -> dict:
|
||||
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
|
||||
"""Fetch metadata for a url."""
|
||||
|
||||
if session is None:
|
||||
|
|
@ -123,7 +123,7 @@ async def fetch_metadata(url, *, config=None, session=None) -> dict:
|
|||
|
||||
log.warning('failed to generate meta for {!r}: {} {!r}',
|
||||
url, resp.status, body)
|
||||
return
|
||||
return None
|
||||
|
||||
return await resp.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import mimetypes
|
|||
import asyncio
|
||||
import base64
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
from dataclasses import dataclass
|
||||
from hashlib import sha256
|
||||
|
|
@ -67,22 +68,33 @@ def get_mime(ext: str):
|
|||
@dataclass
|
||||
class Icon:
|
||||
"""Main icon class"""
|
||||
key: str
|
||||
icon_hash: str
|
||||
mime: str
|
||||
key: Optional[str]
|
||||
icon_hash: Optional[str]
|
||||
mime: Optional[str]
|
||||
|
||||
@property
|
||||
def as_path(self) -> str:
|
||||
def as_path(self) -> Optional[str]:
|
||||
"""Return a filesystem path for the given icon."""
|
||||
if self.mime is None:
|
||||
return None
|
||||
|
||||
ext = get_ext(self.mime)
|
||||
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
|
||||
|
||||
@property
|
||||
def as_pathlib(self) -> str:
|
||||
def as_pathlib(self) -> Optional[Path]:
|
||||
"""Get a Path instance of this icon."""
|
||||
if self.as_path is None:
|
||||
return None
|
||||
|
||||
return Path(self.as_path)
|
||||
|
||||
@property
|
||||
def extension(self) -> str:
|
||||
def extension(self) -> Optional[str]:
|
||||
"""Get the extension of this icon."""
|
||||
if self.mime is None:
|
||||
return None
|
||||
|
||||
return get_ext(self.mime)
|
||||
|
||||
|
||||
|
|
@ -91,7 +103,7 @@ class ImageError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def to_raw(data_type: str, data: str) -> bytes:
|
||||
def to_raw(data_type: str, data: str) -> Optional[bytes]:
|
||||
"""Given a data type in the data URI and data,
|
||||
give the raw bytes being encoded."""
|
||||
if data_type == 'base64':
|
||||
|
|
@ -176,7 +188,7 @@ def _gen_update_sql(scope: str) -> str:
|
|||
"""
|
||||
|
||||
|
||||
def _invalid(kwargs: dict):
|
||||
def _invalid(kwargs: dict) -> Optional[Icon]:
|
||||
"""Send an invalid value."""
|
||||
if not kwargs.get('always_icon', False):
|
||||
return None
|
||||
|
|
@ -272,7 +284,8 @@ class IconManager:
|
|||
|
||||
return Icon(icon.key, icon.icon_hash, target_mime)
|
||||
|
||||
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
|
||||
async def generic_get(self, scope, key, icon_hash,
|
||||
**kwargs) -> Optional[Icon]:
|
||||
"""Get any icon."""
|
||||
if icon_hash is None:
|
||||
return None
|
||||
|
|
@ -302,10 +315,17 @@ class IconManager:
|
|||
|
||||
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
|
||||
|
||||
# ensure we aren't messing with NULLs everywhere.
|
||||
if icon.as_pathlib is None:
|
||||
return None
|
||||
|
||||
if not icon.as_pathlib.exists():
|
||||
await self.delete(icon)
|
||||
return None
|
||||
|
||||
if icon.extension is None:
|
||||
return None
|
||||
|
||||
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
|
||||
return await self._convert_ext(icon, kwargs['ext'])
|
||||
|
||||
|
|
|
|||
|
|
@ -198,7 +198,8 @@ async def role_permissions(guild_id: int, role_id: int,
|
|||
|
||||
async def compute_overwrites(base_perms: Permissions,
|
||||
user_id, channel_id: int,
|
||||
guild_id: int = None, storage=None):
|
||||
guild_id: Optional[int] = None,
|
||||
storage=None):
|
||||
"""Compute the permissions in the context of a channel."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
|
@ -211,8 +212,12 @@ async def compute_overwrites(base_perms: Permissions,
|
|||
# list of overwrites
|
||||
overwrites = await storage.chan_overwrites(channel_id)
|
||||
|
||||
# if the channel isn't a guild, we should just return
|
||||
# ALL_PERMISSIONS. the old approach was calling guild_from_channel
|
||||
# again, but it is already passed by get_permissions(), so its
|
||||
# redundant.
|
||||
if not guild_id:
|
||||
guild_id = await storage.guild_from_channel(channel_id)
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
# make it a map for better usage
|
||||
overwrites = {int(o['id']): o for o in overwrites}
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ class PresenceManager:
|
|||
|
||||
# shards that are in lazy guilds with 'everyone'
|
||||
# enabled
|
||||
in_lazy = []
|
||||
in_lazy: List[str] = []
|
||||
|
||||
for member_list in lists:
|
||||
session_ids = await member_list.pres_update(
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class ChannelDispatcher(DispatcherWithState):
|
|||
await self.unsub(channel_id, user_id)
|
||||
continue
|
||||
|
||||
cur_sess = 0
|
||||
cur_sess = []
|
||||
|
||||
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
||||
and data.get('type') == ChannelType.GROUP_DM.value:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,9 @@ lazy guilds:
|
|||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple
|
||||
from typing import (
|
||||
Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
|
||||
)
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
from logbook import Logger
|
||||
|
|
@ -265,7 +267,7 @@ class GuildMemberList:
|
|||
|
||||
#: store the states that are subscribed to the list.
|
||||
# type is {session_id: set[list]}
|
||||
self.state = defaultdict(set)
|
||||
self.state: Dict[str, Set[List[int, int]]] = defaultdict(set)
|
||||
|
||||
self._list_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -589,7 +591,7 @@ class GuildMemberList:
|
|||
except KeyError:
|
||||
return None
|
||||
|
||||
async def _dispatch_sess(self, session_ids: List[str],
|
||||
async def _dispatch_sess(self, session_ids: Iterable[str],
|
||||
operations: List[Operation]):
|
||||
"""Dispatch a GUILD_MEMBER_LIST_UPDATE to the
|
||||
given session ids."""
|
||||
|
|
@ -613,11 +615,12 @@ class GuildMemberList:
|
|||
}
|
||||
|
||||
states = map(self._get_state, session_ids)
|
||||
states = filter(lambda state: state is not None, states)
|
||||
|
||||
dispatched = []
|
||||
|
||||
for state in states:
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
await state.ws.dispatch(
|
||||
'GUILD_MEMBER_LIST_UPDATE', payload)
|
||||
|
||||
|
|
@ -625,7 +628,7 @@ class GuildMemberList:
|
|||
|
||||
return dispatched
|
||||
|
||||
async def _resync(self, session_ids: List[int],
|
||||
async def _resync(self, session_ids: List[str],
|
||||
item_index: int) -> List[str]:
|
||||
"""Send a SYNC event to all states that are subscribed to an item.
|
||||
|
||||
|
|
@ -661,7 +664,7 @@ class GuildMemberList:
|
|||
|
||||
return result
|
||||
|
||||
async def _resync_by_item(self, item_index: int):
|
||||
async def _resync_by_item(self, item_index: Optional[int]):
|
||||
"""Resync but only giving the item index."""
|
||||
if item_index is None:
|
||||
return []
|
||||
|
|
@ -1339,7 +1342,10 @@ class GuildMemberList:
|
|||
log.debug('there are {} session ids to resync (for item {})',
|
||||
len(sess_ids_resync), role_item_index)
|
||||
|
||||
return await self._resync(sess_ids_resync, role_item_index)
|
||||
if role_item_index is not None:
|
||||
return await self._resync(sess_ids_resync, role_item_index)
|
||||
|
||||
return []
|
||||
|
||||
async def chan_update(self):
|
||||
"""Called then a channel's data has been updated."""
|
||||
|
|
|
|||
|
|
@ -19,8 +19,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Iterable, Optional, Indexable
|
||||
|
||||
from logbook import Logger
|
||||
from typing import Any, Iterable
|
||||
from quart.json import JSONEncoder
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -51,7 +52,7 @@ def dict_get(mapping, key, default):
|
|||
return mapping.get(key) or default
|
||||
|
||||
|
||||
def index_by_func(function, indexable: iter) -> int:
|
||||
def index_by_func(function, indexable: Indexable) -> Optional[int]:
|
||||
"""Search in an idexable and return the index number
|
||||
for an iterm that has func(item) = True."""
|
||||
for index, item in enumerate(indexable):
|
||||
|
|
@ -66,7 +67,7 @@ def _u(val):
|
|||
return val % 0x100000000
|
||||
|
||||
|
||||
def mmh3(key: str, seed: int = 0):
|
||||
def mmh3(inp_str: str, seed: int = 0):
|
||||
"""MurMurHash3 implementation.
|
||||
|
||||
This seems to match Discord's JavaScript implementaiton.
|
||||
|
|
@ -74,7 +75,7 @@ def mmh3(key: str, seed: int = 0):
|
|||
Based off
|
||||
https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js
|
||||
"""
|
||||
key = [ord(c) for c in key]
|
||||
key = [ord(c) for c in inp_str]
|
||||
|
||||
remainder = len(key) & 3
|
||||
bytecount = len(key) - remainder
|
||||
|
|
|
|||
Loading…
Reference in New Issue