mirror of https://gitlab.com/litecord/litecord.git
typing, episode 2
This commit is contained in:
parent
ed3c436b6d
commit
9dab5b20ae
|
|
@ -58,11 +58,11 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||||
user_settings = await app.user_storage.get_user_settings(user_id)
|
user_settings = await app.user_storage.get_user_settings(user_id)
|
||||||
peer_settings = await app.user_storage.get_user_settings(peer_id)
|
peer_settings = await app.user_storage.get_user_settings(peer_id)
|
||||||
|
|
||||||
restricted_user = [int(v) for v in user_settings['restricted_guilds']]
|
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
|
||||||
restricted_peer = [int(v) for v in peer_settings['restricted_guilds']]
|
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
|
||||||
|
|
||||||
restricted_user = set(restricted_user)
|
restricted_user = set(restricted_user_)
|
||||||
restricted_peer = set(restricted_peer)
|
restricted_peer = set(restricted_peer_)
|
||||||
|
|
||||||
mutual_guilds -= restricted_user
|
mutual_guilds -= restricted_user
|
||||||
mutual_guilds -= restricted_peer
|
mutual_guilds -= restricted_peer
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ litecord.embed.sanitizer
|
||||||
sanitize embeds by giving common values
|
sanitize embeds by giving common values
|
||||||
such as type: rich
|
such as type: rich
|
||||||
"""
|
"""
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any, Optional, Union, List
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
@ -44,7 +44,7 @@ def sanitize_embed(embed: Embed) -> Embed:
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
|
||||||
def path_exists(embed: Embed, components: str):
|
def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||||
"""Tell if a given path exists in an embed (or any dictionary).
|
"""Tell if a given path exists in an embed (or any dictionary).
|
||||||
|
|
||||||
The components string is formatted like this:
|
The components string is formatted like this:
|
||||||
|
|
@ -54,10 +54,10 @@ def path_exists(embed: Embed, components: str):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# get the list of components given
|
# get the list of components given
|
||||||
if isinstance(components, str):
|
if isinstance(components_in, str):
|
||||||
components = components.split('.')
|
components = components_in.split('.')
|
||||||
else:
|
else:
|
||||||
components = list(components)
|
components = list(components_in)
|
||||||
|
|
||||||
# if there are no components, we reached the end of recursion
|
# if there are no components, we reached the end of recursion
|
||||||
# and can return true
|
# and can return true
|
||||||
|
|
@ -96,7 +96,7 @@ def proxify(url, *, config=None) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def fetch_metadata(url, *, config=None, session=None) -> dict:
|
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
|
||||||
"""Fetch metadata for a url."""
|
"""Fetch metadata for a url."""
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
|
|
@ -123,7 +123,7 @@ async def fetch_metadata(url, *, config=None, session=None) -> dict:
|
||||||
|
|
||||||
log.warning('failed to generate meta for {!r}: {} {!r}',
|
log.warning('failed to generate meta for {!r}: {} {!r}',
|
||||||
url, resp.status, body)
|
url, resp.status, body)
|
||||||
return
|
return None
|
||||||
|
|
||||||
return await resp.json()
|
return await resp.json()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ import mimetypes
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
|
@ -67,22 +68,33 @@ def get_mime(ext: str):
|
||||||
@dataclass
|
@dataclass
|
||||||
class Icon:
|
class Icon:
|
||||||
"""Main icon class"""
|
"""Main icon class"""
|
||||||
key: str
|
key: Optional[str]
|
||||||
icon_hash: str
|
icon_hash: Optional[str]
|
||||||
mime: str
|
mime: Optional[str]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def as_path(self) -> str:
|
def as_path(self) -> Optional[str]:
|
||||||
"""Return a filesystem path for the given icon."""
|
"""Return a filesystem path for the given icon."""
|
||||||
|
if self.mime is None:
|
||||||
|
return None
|
||||||
|
|
||||||
ext = get_ext(self.mime)
|
ext = get_ext(self.mime)
|
||||||
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
|
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def as_pathlib(self) -> str:
|
def as_pathlib(self) -> Optional[Path]:
|
||||||
|
"""Get a Path instance of this icon."""
|
||||||
|
if self.as_path is None:
|
||||||
|
return None
|
||||||
|
|
||||||
return Path(self.as_path)
|
return Path(self.as_path)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def extension(self) -> str:
|
def extension(self) -> Optional[str]:
|
||||||
|
"""Get the extension of this icon."""
|
||||||
|
if self.mime is None:
|
||||||
|
return None
|
||||||
|
|
||||||
return get_ext(self.mime)
|
return get_ext(self.mime)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -91,7 +103,7 @@ class ImageError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def to_raw(data_type: str, data: str) -> bytes:
|
def to_raw(data_type: str, data: str) -> Optional[bytes]:
|
||||||
"""Given a data type in the data URI and data,
|
"""Given a data type in the data URI and data,
|
||||||
give the raw bytes being encoded."""
|
give the raw bytes being encoded."""
|
||||||
if data_type == 'base64':
|
if data_type == 'base64':
|
||||||
|
|
@ -176,7 +188,7 @@ def _gen_update_sql(scope: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _invalid(kwargs: dict):
|
def _invalid(kwargs: dict) -> Optional[Icon]:
|
||||||
"""Send an invalid value."""
|
"""Send an invalid value."""
|
||||||
if not kwargs.get('always_icon', False):
|
if not kwargs.get('always_icon', False):
|
||||||
return None
|
return None
|
||||||
|
|
@ -272,7 +284,8 @@ class IconManager:
|
||||||
|
|
||||||
return Icon(icon.key, icon.icon_hash, target_mime)
|
return Icon(icon.key, icon.icon_hash, target_mime)
|
||||||
|
|
||||||
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
|
async def generic_get(self, scope, key, icon_hash,
|
||||||
|
**kwargs) -> Optional[Icon]:
|
||||||
"""Get any icon."""
|
"""Get any icon."""
|
||||||
if icon_hash is None:
|
if icon_hash is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -302,10 +315,17 @@ class IconManager:
|
||||||
|
|
||||||
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
|
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
|
||||||
|
|
||||||
|
# ensure we aren't messing with NULLs everywhere.
|
||||||
|
if icon.as_pathlib is None:
|
||||||
|
return None
|
||||||
|
|
||||||
if not icon.as_pathlib.exists():
|
if not icon.as_pathlib.exists():
|
||||||
await self.delete(icon)
|
await self.delete(icon)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if icon.extension is None:
|
||||||
|
return None
|
||||||
|
|
||||||
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
|
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
|
||||||
return await self._convert_ext(icon, kwargs['ext'])
|
return await self._convert_ext(icon, kwargs['ext'])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -198,7 +198,8 @@ async def role_permissions(guild_id: int, role_id: int,
|
||||||
|
|
||||||
async def compute_overwrites(base_perms: Permissions,
|
async def compute_overwrites(base_perms: Permissions,
|
||||||
user_id, channel_id: int,
|
user_id, channel_id: int,
|
||||||
guild_id: int = None, storage=None):
|
guild_id: Optional[int] = None,
|
||||||
|
storage=None):
|
||||||
"""Compute the permissions in the context of a channel."""
|
"""Compute the permissions in the context of a channel."""
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = app.storage
|
storage = app.storage
|
||||||
|
|
@ -211,8 +212,12 @@ async def compute_overwrites(base_perms: Permissions,
|
||||||
# list of overwrites
|
# list of overwrites
|
||||||
overwrites = await storage.chan_overwrites(channel_id)
|
overwrites = await storage.chan_overwrites(channel_id)
|
||||||
|
|
||||||
|
# if the channel isn't a guild, we should just return
|
||||||
|
# ALL_PERMISSIONS. the old approach was calling guild_from_channel
|
||||||
|
# again, but it is already passed by get_permissions(), so its
|
||||||
|
# redundant.
|
||||||
if not guild_id:
|
if not guild_id:
|
||||||
guild_id = await storage.guild_from_channel(channel_id)
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
# make it a map for better usage
|
# make it a map for better usage
|
||||||
overwrites = {int(o['id']): o for o in overwrites}
|
overwrites = {int(o['id']): o for o in overwrites}
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,7 @@ class PresenceManager:
|
||||||
|
|
||||||
# shards that are in lazy guilds with 'everyone'
|
# shards that are in lazy guilds with 'everyone'
|
||||||
# enabled
|
# enabled
|
||||||
in_lazy = []
|
in_lazy: List[str] = []
|
||||||
|
|
||||||
for member_list in lists:
|
for member_list in lists:
|
||||||
session_ids = await member_list.pres_update(
|
session_ids = await member_list.pres_update(
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ class ChannelDispatcher(DispatcherWithState):
|
||||||
await self.unsub(channel_id, user_id)
|
await self.unsub(channel_id, user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cur_sess = 0
|
cur_sess = []
|
||||||
|
|
||||||
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
||||||
and data.get('type') == ChannelType.GROUP_DM.value:
|
and data.get('type') == ChannelType.GROUP_DM.value:
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,9 @@ lazy guilds:
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple
|
from typing import (
|
||||||
|
Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
|
||||||
|
)
|
||||||
from dataclasses import dataclass, asdict, field
|
from dataclasses import dataclass, asdict, field
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
@ -265,7 +267,7 @@ class GuildMemberList:
|
||||||
|
|
||||||
#: store the states that are subscribed to the list.
|
#: store the states that are subscribed to the list.
|
||||||
# type is {session_id: set[list]}
|
# type is {session_id: set[list]}
|
||||||
self.state = defaultdict(set)
|
self.state: Dict[str, Set[List[int, int]]] = defaultdict(set)
|
||||||
|
|
||||||
self._list_lock = asyncio.Lock()
|
self._list_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
@ -589,7 +591,7 @@ class GuildMemberList:
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _dispatch_sess(self, session_ids: List[str],
|
async def _dispatch_sess(self, session_ids: Iterable[str],
|
||||||
operations: List[Operation]):
|
operations: List[Operation]):
|
||||||
"""Dispatch a GUILD_MEMBER_LIST_UPDATE to the
|
"""Dispatch a GUILD_MEMBER_LIST_UPDATE to the
|
||||||
given session ids."""
|
given session ids."""
|
||||||
|
|
@ -613,11 +615,12 @@ class GuildMemberList:
|
||||||
}
|
}
|
||||||
|
|
||||||
states = map(self._get_state, session_ids)
|
states = map(self._get_state, session_ids)
|
||||||
states = filter(lambda state: state is not None, states)
|
|
||||||
|
|
||||||
dispatched = []
|
dispatched = []
|
||||||
|
|
||||||
for state in states:
|
for state in states:
|
||||||
|
if state is None:
|
||||||
|
continue
|
||||||
|
|
||||||
await state.ws.dispatch(
|
await state.ws.dispatch(
|
||||||
'GUILD_MEMBER_LIST_UPDATE', payload)
|
'GUILD_MEMBER_LIST_UPDATE', payload)
|
||||||
|
|
||||||
|
|
@ -625,7 +628,7 @@ class GuildMemberList:
|
||||||
|
|
||||||
return dispatched
|
return dispatched
|
||||||
|
|
||||||
async def _resync(self, session_ids: List[int],
|
async def _resync(self, session_ids: List[str],
|
||||||
item_index: int) -> List[str]:
|
item_index: int) -> List[str]:
|
||||||
"""Send a SYNC event to all states that are subscribed to an item.
|
"""Send a SYNC event to all states that are subscribed to an item.
|
||||||
|
|
||||||
|
|
@ -661,7 +664,7 @@ class GuildMemberList:
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _resync_by_item(self, item_index: int):
|
async def _resync_by_item(self, item_index: Optional[int]):
|
||||||
"""Resync but only giving the item index."""
|
"""Resync but only giving the item index."""
|
||||||
if item_index is None:
|
if item_index is None:
|
||||||
return []
|
return []
|
||||||
|
|
@ -1339,7 +1342,10 @@ class GuildMemberList:
|
||||||
log.debug('there are {} session ids to resync (for item {})',
|
log.debug('there are {} session ids to resync (for item {})',
|
||||||
len(sess_ids_resync), role_item_index)
|
len(sess_ids_resync), role_item_index)
|
||||||
|
|
||||||
return await self._resync(sess_ids_resync, role_item_index)
|
if role_item_index is not None:
|
||||||
|
return await self._resync(sess_ids_resync, role_item_index)
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
async def chan_update(self):
|
async def chan_update(self):
|
||||||
"""Called then a channel's data has been updated."""
|
"""Called then a channel's data has been updated."""
|
||||||
|
|
|
||||||
|
|
@ -19,8 +19,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Iterable, Optional, Indexable
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from typing import Any, Iterable
|
|
||||||
from quart.json import JSONEncoder
|
from quart.json import JSONEncoder
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -51,7 +52,7 @@ def dict_get(mapping, key, default):
|
||||||
return mapping.get(key) or default
|
return mapping.get(key) or default
|
||||||
|
|
||||||
|
|
||||||
def index_by_func(function, indexable: iter) -> int:
|
def index_by_func(function, indexable: Indexable) -> Optional[int]:
|
||||||
"""Search in an idexable and return the index number
|
"""Search in an idexable and return the index number
|
||||||
for an iterm that has func(item) = True."""
|
for an iterm that has func(item) = True."""
|
||||||
for index, item in enumerate(indexable):
|
for index, item in enumerate(indexable):
|
||||||
|
|
@ -66,7 +67,7 @@ def _u(val):
|
||||||
return val % 0x100000000
|
return val % 0x100000000
|
||||||
|
|
||||||
|
|
||||||
def mmh3(key: str, seed: int = 0):
|
def mmh3(inp_str: str, seed: int = 0):
|
||||||
"""MurMurHash3 implementation.
|
"""MurMurHash3 implementation.
|
||||||
|
|
||||||
This seems to match Discord's JavaScript implementaiton.
|
This seems to match Discord's JavaScript implementaiton.
|
||||||
|
|
@ -74,7 +75,7 @@ def mmh3(key: str, seed: int = 0):
|
||||||
Based off
|
Based off
|
||||||
https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js
|
https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js
|
||||||
"""
|
"""
|
||||||
key = [ord(c) for c in key]
|
key = [ord(c) for c in inp_str]
|
||||||
|
|
||||||
remainder = len(key) & 3
|
remainder = len(key) & 3
|
||||||
bytecount = len(key) - remainder
|
bytecount = len(key) - remainder
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue