typing, episode 2

This commit is contained in:
Luna 2019-03-04 15:48:51 -03:00
parent ed3c436b6d
commit 9dab5b20ae
8 changed files with 68 additions and 36 deletions

View File

@ -58,11 +58,11 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
user_settings = await app.user_storage.get_user_settings(user_id)
peer_settings = await app.user_storage.get_user_settings(peer_id)
restricted_user = [int(v) for v in user_settings['restricted_guilds']]
restricted_peer = [int(v) for v in peer_settings['restricted_guilds']]
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
restricted_user = set(restricted_user)
restricted_peer = set(restricted_peer)
restricted_user = set(restricted_user_)
restricted_peer = set(restricted_peer_)
mutual_guilds -= restricted_user
mutual_guilds -= restricted_peer

View File

@ -22,7 +22,7 @@ litecord.embed.sanitizer
sanitize embeds by giving common values
such as type: rich
"""
from typing import Dict, Any
from typing import Dict, Any, Optional, Union, List
from logbook import Logger
from quart import current_app as app
@ -44,7 +44,7 @@ def sanitize_embed(embed: Embed) -> Embed:
}}
def path_exists(embed: Embed, components: str):
def path_exists(embed: Embed, components_in: Union[List[str], str]):
"""Tell if a given path exists in an embed (or any dictionary).
The components string is formatted like this:
@ -54,10 +54,10 @@ def path_exists(embed: Embed, components: str):
"""
# get the list of components given
if isinstance(components, str):
components = components.split('.')
if isinstance(components_in, str):
components = components_in.split('.')
else:
components = list(components)
components = list(components_in)
# if there are no components, we reached the end of recursion
# and can return true
@ -96,7 +96,7 @@ def proxify(url, *, config=None) -> str:
)
async def fetch_metadata(url, *, config=None, session=None) -> dict:
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
"""Fetch metadata for a url."""
if session is None:
@ -123,7 +123,7 @@ async def fetch_metadata(url, *, config=None, session=None) -> dict:
log.warning('failed to generate meta for {!r}: {} {!r}',
url, resp.status, body)
return
return None
return await resp.json()

View File

@ -22,6 +22,7 @@ import mimetypes
import asyncio
import base64
import tempfile
from typing import Optional
from dataclasses import dataclass
from hashlib import sha256
@ -67,22 +68,33 @@ def get_mime(ext: str):
@dataclass
class Icon:
"""Main icon class"""
key: str
icon_hash: str
mime: str
key: Optional[str]
icon_hash: Optional[str]
mime: Optional[str]
@property
def as_path(self) -> str:
def as_path(self) -> Optional[str]:
"""Return a filesystem path for the given icon."""
if self.mime is None:
return None
ext = get_ext(self.mime)
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
@property
def as_pathlib(self) -> str:
def as_pathlib(self) -> Optional[Path]:
"""Get a Path instance of this icon."""
if self.as_path is None:
return None
return Path(self.as_path)
@property
def extension(self) -> str:
def extension(self) -> Optional[str]:
"""Get the extension of this icon."""
if self.mime is None:
return None
return get_ext(self.mime)
@ -91,7 +103,7 @@ class ImageError(Exception):
pass
def to_raw(data_type: str, data: str) -> bytes:
def to_raw(data_type: str, data: str) -> Optional[bytes]:
"""Given a data type in the data URI and data,
give the raw bytes being encoded."""
if data_type == 'base64':
@ -176,7 +188,7 @@ def _gen_update_sql(scope: str) -> str:
"""
def _invalid(kwargs: dict):
def _invalid(kwargs: dict) -> Optional[Icon]:
"""Send an invalid value."""
if not kwargs.get('always_icon', False):
return None
@ -272,7 +284,8 @@ class IconManager:
return Icon(icon.key, icon.icon_hash, target_mime)
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
async def generic_get(self, scope, key, icon_hash,
**kwargs) -> Optional[Icon]:
"""Get any icon."""
if icon_hash is None:
return None
@ -302,10 +315,17 @@ class IconManager:
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
# ensure we aren't messing with NULLs everywhere.
if icon.as_pathlib is None:
return None
if not icon.as_pathlib.exists():
await self.delete(icon)
return None
if icon.extension is None:
return None
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
return await self._convert_ext(icon, kwargs['ext'])

View File

@ -198,7 +198,8 @@ async def role_permissions(guild_id: int, role_id: int,
async def compute_overwrites(base_perms: Permissions,
user_id, channel_id: int,
guild_id: int = None, storage=None):
guild_id: Optional[int] = None,
storage=None):
"""Compute the permissions in the context of a channel."""
if not storage:
storage = app.storage
@ -211,8 +212,12 @@ async def compute_overwrites(base_perms: Permissions,
# list of overwrites
overwrites = await storage.chan_overwrites(channel_id)
# if the channel isn't a guild, we should just return
# ALL_PERMISSIONS. the old approach was calling guild_from_channel
# again, but it is already passed by get_permissions(), so its
# redundant.
if not guild_id:
guild_id = await storage.guild_from_channel(channel_id)
return ALL_PERMISSIONS
# make it a map for better usage
overwrites = {int(o['id']): o for o in overwrites}

View File

@ -127,7 +127,7 @@ class PresenceManager:
# shards that are in lazy guilds with 'everyone'
# enabled
in_lazy = []
in_lazy: List[str] = []
for member_list in lists:
session_ids = await member_list.pres_update(

View File

@ -84,7 +84,7 @@ class ChannelDispatcher(DispatcherWithState):
await self.unsub(channel_id, user_id)
continue
cur_sess = 0
cur_sess = []
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
and data.get('type') == ChannelType.GROUP_DM.value:

View File

@ -28,7 +28,9 @@ lazy guilds:
import asyncio
from collections import defaultdict
from typing import Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple
from typing import (
Any, List, Dict, Union, Optional, Iterable, Iterator, Tuple, Set
)
from dataclasses import dataclass, asdict, field
from logbook import Logger
@ -265,7 +267,7 @@ class GuildMemberList:
#: store the states that are subscribed to the list.
# type is {session_id: set[list]}
self.state = defaultdict(set)
self.state: Dict[str, Set[List[int, int]]] = defaultdict(set)
self._list_lock = asyncio.Lock()
@ -589,7 +591,7 @@ class GuildMemberList:
except KeyError:
return None
async def _dispatch_sess(self, session_ids: List[str],
async def _dispatch_sess(self, session_ids: Iterable[str],
operations: List[Operation]):
"""Dispatch a GUILD_MEMBER_LIST_UPDATE to the
given session ids."""
@ -613,11 +615,12 @@ class GuildMemberList:
}
states = map(self._get_state, session_ids)
states = filter(lambda state: state is not None, states)
dispatched = []
for state in states:
if state is None:
continue
await state.ws.dispatch(
'GUILD_MEMBER_LIST_UPDATE', payload)
@ -625,7 +628,7 @@ class GuildMemberList:
return dispatched
async def _resync(self, session_ids: List[int],
async def _resync(self, session_ids: List[str],
item_index: int) -> List[str]:
"""Send a SYNC event to all states that are subscribed to an item.
@ -661,7 +664,7 @@ class GuildMemberList:
return result
async def _resync_by_item(self, item_index: int):
async def _resync_by_item(self, item_index: Optional[int]):
"""Resync but only giving the item index."""
if item_index is None:
return []
@ -1339,7 +1342,10 @@ class GuildMemberList:
log.debug('there are {} session ids to resync (for item {})',
len(sess_ids_resync), role_item_index)
return await self._resync(sess_ids_resync, role_item_index)
if role_item_index is not None:
return await self._resync(sess_ids_resync, role_item_index)
return []
async def chan_update(self):
"""Called then a channel's data has been updated."""

View File

@ -19,8 +19,9 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import json
from typing import Any, Iterable, Optional, Indexable
from logbook import Logger
from typing import Any, Iterable
from quart.json import JSONEncoder
log = Logger(__name__)
@ -51,7 +52,7 @@ def dict_get(mapping, key, default):
return mapping.get(key) or default
def index_by_func(function, indexable: iter) -> int:
def index_by_func(function, indexable: Indexable) -> Optional[int]:
"""Search in an idexable and return the index number
for an iterm that has func(item) = True."""
for index, item in enumerate(indexable):
@ -66,7 +67,7 @@ def _u(val):
return val % 0x100000000
def mmh3(key: str, seed: int = 0):
def mmh3(inp_str: str, seed: int = 0):
"""MurMurHash3 implementation.
This seems to match Discord's JavaScript implementaiton.
@ -74,7 +75,7 @@ def mmh3(key: str, seed: int = 0):
Based off
https://github.com/garycourt/murmurhash-js/blob/master/murmurhash3_gc.js
"""
key = [ord(c) for c in key]
key = [ord(c) for c in inp_str]
remainder = len(key) & 3
bytecount = len(key) - remainder