typing, episode 2

This commit is contained in:
Luna 2019-03-04 15:48:51 -03:00
parent ed3c436b6d
commit 9dab5b20ae
8 changed files with 68 additions and 36 deletions

View File

@ -58,11 +58,11 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
user_settings = await app.user_storage.get_user_settings(user_id) user_settings = await app.user_storage.get_user_settings(user_id)
peer_settings = await app.user_storage.get_user_settings(peer_id) peer_settings = await app.user_storage.get_user_settings(peer_id)
restricted_user = [int(v) for v in user_settings['restricted_guilds']] restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
restricted_peer = [int(v) for v in peer_settings['restricted_guilds']] restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
restricted_user = set(restricted_user) restricted_user = set(restricted_user_)
restricted_peer = set(restricted_peer) restricted_peer = set(restricted_peer_)
mutual_guilds -= restricted_user mutual_guilds -= restricted_user
mutual_guilds -= restricted_peer mutual_guilds -= restricted_peer

View File

@ -22,7 +22,7 @@ litecord.embed.sanitizer
sanitize embeds by giving common values sanitize embeds by giving common values
such as type: rich such as type: rich
""" """
from typing import Dict, Any from typing import Dict, Any, Optional, Union, List
from logbook import Logger from logbook import Logger
from quart import current_app as app from quart import current_app as app
@ -44,7 +44,7 @@ def sanitize_embed(embed: Embed) -> Embed:
}} }}
def path_exists(embed: Embed, components: str): def path_exists(embed: Embed, components_in: Union[List[str], str]):
"""Tell if a given path exists in an embed (or any dictionary). """Tell if a given path exists in an embed (or any dictionary).
The components string is formatted like this: The components string is formatted like this:
@ -54,10 +54,10 @@ def path_exists(embed: Embed, components: str):
""" """
# get the list of components given # get the list of components given
if isinstance(components, str): if isinstance(components_in, str):
components = components.split('.') components = components_in.split('.')
else: else:
components = list(components) components = list(components_in)
# if there are no components, we reached the end of recursion # if there are no components, we reached the end of recursion
# and can return true # and can return true
@ -96,7 +96,7 @@ def proxify(url, *, config=None) -> str:
) )
async def fetch_metadata(url, *, config=None, session=None) -> dict: async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
"""Fetch metadata for a url.""" """Fetch metadata for a url."""
if session is None: if session is None:
@ -123,7 +123,7 @@ async def fetch_metadata(url, *, config=None, session=None) -> dict:
log.warning('failed to generate meta for {!r}: {} {!r}', log.warning('failed to generate meta for {!r}: {} {!r}',
url, resp.status, body) url, resp.status, body)
return return None
return await resp.json() return await resp.json()

View File

@ -22,6 +22,7 @@ import mimetypes
import asyncio import asyncio
import base64 import base64
import tempfile import tempfile
from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
from hashlib import sha256 from hashlib import sha256
@ -67,22 +68,33 @@ def get_mime(ext: str):
@dataclass @dataclass
class Icon: class Icon:
"""Main icon class""" """Main icon class"""
key: str key: Optional[str]
icon_hash: str icon_hash: Optional[str]
mime: str mime: Optional[str]
@property @property
def as_path(self) -> str: def as_path(self) -> Optional[str]:
"""Return a filesystem path for the given icon.""" """Return a filesystem path for the given icon."""
if self.mime is None:
return None
ext = get_ext(self.mime) ext = get_ext(self.mime)
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}') return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
@property @property
def as_pathlib(self) -> str: def as_pathlib(self) -> Optional[Path]:
"""Get a Path instance of this icon."""
if self.as_path is None:
return None
return Path(self.as_path) return Path(self.as_path)
@property @property
def extension(self) -> str: def extension(self) -> Optional[str]:
"""Get the extension of this icon."""
if self.mime is None:
return None
return get_ext(self.mime) return get_ext(self.mime)
@ -91,7 +103,7 @@ class ImageError(Exception):
pass pass
def to_raw(data_type: str, data: str) -> bytes: def to_raw(data_type: str, data: str) -> Optional[bytes]:
"""Given a data type in the data URI and data, """Given a data type in the data URI and data,
give the raw bytes being encoded.""" give the raw bytes being encoded."""
if data_type == 'base64': if data_type == 'base64':
@ -176,7 +188,7 @@ def _gen_update_sql(scope: str) -> str:
""" """
def _invalid(kwargs: dict): def _invalid(kwargs: dict) -> Optional[Icon]:
"""Send an invalid value.""" """Send an invalid value."""
if not kwargs.get('always_icon', False): if not kwargs.get('always_icon', False):
return None return None
@ -272,7 +284,8 @@ class IconManager:
return Icon(icon.key, icon.icon_hash, target_mime) return Icon(icon.key, icon.icon_hash, target_mime)
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon: async def generic_get(self, scope, key, icon_hash,
**kwargs) -> Optional[Icon]:
"""Get any icon.""" """Get any icon."""
if icon_hash is None: if icon_hash is None:
return None return None
@ -302,10 +315,17 @@ class IconManager:
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime']) icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
# ensure we aren't messing with NULLs everywhere.
if icon.as_pathlib is None:
return None
if not icon.as_pathlib.exists(): if not icon.as_pathlib.exists():
await self.delete(icon) await self.delete(icon)
return None return None
if icon.extension is None:
return None
if 'ext' in kwargs and kwargs['ext'] != icon.extension: if 'ext' in kwargs and kwargs['ext'] != icon.extension:
return await self._convert_ext(icon, kwargs['ext']) return await self._convert_ext(icon, kwargs['ext'])

View File

@ -198,7 +198,8 @@ async def role_permissions(guild_id: int, role_id: int,
async def compute_overwrites(base_perms: Permissions, async def compute_overwrites(base_perms: Permissions,
user_id, channel_id: int, user_id, channel_id: int,
guild_id: int = None, storage=None): guild_id: Optional[int] = None,
storage=None):
"""Compute the permissions in the context of a channel.""" """Compute the permissions in the context of a channel."""
if not storage: if not storage:
storage = app.storage storage = app.storage
@ -211,8 +212,12 @@ async def compute_overwrites(base_perms: Permissions,
# list of overwrites # list of overwrites
overwrites = await storage.chan_overwrites(channel_id) overwrites = await storage.chan_overwrites(channel_id)
# if the channel isn't a guild, we should just return
# ALL_PERMISSIONS. the old approach was calling guild_from_channel
# again, but it is already passed by get_permissions(), so its
# redundant.
if not guild_id: if not guild_id:
guild_id = await storage.guild_from_channel(channel_id) return ALL_PERMISSIONS
# make it a map for better usage # make it a map for better usage
overwrites = {int(o['id']): o for o in overwrites} overwrites = {int(o['id']): o for o in overwrites}

View File

@ -127,7 +127,7 @@ class PresenceManager:
# shards that are in lazy guilds with 'everyone' # shards that are in lazy guilds with 'everyone'
# enabled # enabled
in_lazy = [] in_lazy: List[str] = []
for member_list in lists: for member_list in lists:
session_ids = await member_list.pres_update( session_ids = await member_list.pres_update(

View File

@ -84,7 +84,7 @@ class ChannelDispatcher(DispatcherWithState):
await self.unsub(channel_id, user_id) await self.unsub(channel_id, user_id)
continue continue
cur_sess = 0 cur_sess = []
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \ if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
and data.get('type') == ChannelType.GROUP_DM.value: and data.get('type') == ChannelType.GROUP_DM.value:

View File

@ -28,7 +28,9 @@ lazy guilds:
import asyncio import asyncio
from collections import defaultdict from collections import defaultdict
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple from typing import (
Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
)
from dataclasses import dataclass, asdict, field from dataclasses import dataclass, asdict, field
from logbook import Logger from logbook import Logger
@ -265,7 +267,7 @@ class GuildMemberList:
#: store the states that are subscribed to the list. #: store the states that are subscribed to the list.
# type is {session_id: set[list]} # type is {session_id: set[list]}
self.state = defaultdict(set) self.state: Dict[str, Set[List[int, int]]] = defaultdict(set)
self._list_lock = asyncio.Lock() self._list_lock = asyncio.Lock()
@ -589,7 +591,7 @@ class GuildMemberList:
except KeyError: except KeyError:
return None return None
async def _dispatch_sess(self, session_ids: List[str], async def _dispatch_sess(self, session_ids: Iterable[str],
operations: List[Operation]): operations: List[Operation]):
"""Dispatch a GUILD_MEMBER_LIST_UPDATE to the """Dispatch a GUILD_MEMBER_LIST_UPDATE to the
given session ids.""" given session ids."""
@ -613,11 +615,12 @@ class GuildMemberList:
} }
states = map(self._get_state, session_ids) states = map(self._get_state, session_ids)
states = filter(lambda state: state is not None, states)
dispatched = [] dispatched = []
for state in states: for state in states:
if state is None:
continue
await state.ws.dispatch( await state.ws.dispatch(
'GUILD_MEMBER_LIST_UPDATE', payload) 'GUILD_MEMBER_LIST_UPDATE', payload)
@ -625,7 +628,7 @@ class GuildMemberList:
return dispatched return dispatched
async def _resync(self, session_ids: List[int], async def _resync(self, session_ids: List[str],
item_index: int) -> List[str]: item_index: int) -> List[str]:
"""Send a SYNC event to all states that are subscribed to an item. """Send a SYNC event to all states that are subscribed to an item.
@ -661,7 +664,7 @@ class GuildMemberList:
return result return result
async def _resync_by_item(self, item_index: int): async def _resync_by_item(self, item_index: Optional[int]):
"""Resync but only giving the item index.""" """Resync but only giving the item index."""
if item_index is None: if item_index is None:
return [] return []
@ -1339,8 +1342,11 @@ class GuildMemberList:
log.debug('there are {} session ids to resync (for item {})', log.debug('there are {} session ids to resync (for item {})',
len(sess_ids_resync), role_item_index) len(sess_ids_resync), role_item_index)
if role_item_index is not None:
return await self._resync(sess_ids_resync, role_item_index) return await self._resync(sess_ids_resync, role_item_index)
return []
async def chan_update(self): async def chan_update(self):
"""Called then a channel's data has been updated.""" """Called then a channel's data has been updated."""
await self._fetch_overwrites() await self._fetch_overwrites()

View File

@ -19,8 +19,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import json import json
from typing import Any, Iterable, Optional, Indexable
from logbook import Logger from logbook import Logger
from typing import Any, Iterable
from quart.json import JSONEncoder from quart.json import JSONEncoder
log = Logger(__name__) log = Logger(__name__)
@ -51,7 +52,7 @@ def dict_get(mapping, key, default):
return mapping.get(key) or default return mapping.get(key) or default
def index_by_func(function, indexable: iter) -> int: def index_by_func(function, indexable: Indexable) -> Optional[int]:
"""Search in an idexable and return the index number """Search in an idexable and return the index number
for an iterm that has func(item) = True.""" for an iterm that has func(item) = True."""
for index, item in enumerate(indexable): for index, item in enumerate(indexable):
@ -66,7 +67,7 @@ def _u(val):
return val % 0x100000000 return val % 0x100000000
def mmh3(key: str, seed: int = 0): def mmh3(inp_str: str, seed: int = 0):
"""MurMurHash3 implementation. """MurMurHash3 implementation.
This seems to match Discord's JavaScript implementaiton. This seems to match Discord's JavaScript implementaiton.
@ -74,7 +75,7 @@ def mmh3(key: str, seed: int = 0):
Based off Based off
https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js
""" """
key = [ord(c) for c in key] key = [ord(c) for c in inp_str]
remainder = len(key) & 3 remainder = len(key) & 3
bytecount = len(key) - remainder bytecount = len(key) - remainder