mirror of https://gitlab.com/litecord/litecord.git
black fmt pass
This commit is contained in:
parent
0bc4b1ba3f
commit
83a1c1ae29
25
config.ci.py
25
config.ci.py
|
|
@ -17,13 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MODE = 'CI'
|
MODE = "CI"
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Default configuration values for litecord."""
|
"""Default configuration values for litecord."""
|
||||||
MAIN_URL = 'localhost:1'
|
|
||||||
NAME = 'gitlab ci'
|
MAIN_URL = "localhost:1"
|
||||||
|
NAME = "gitlab ci"
|
||||||
|
|
||||||
# Enable debug logging?
|
# Enable debug logging?
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
|
@ -37,11 +38,11 @@ class Config:
|
||||||
# Set this url to somewhere *your users*
|
# Set this url to somewhere *your users*
|
||||||
# will hit the websocket.
|
# will hit the websocket.
|
||||||
# e.g 'gateway.example.com' for reverse proxies.
|
# e.g 'gateway.example.com' for reverse proxies.
|
||||||
WEBSOCKET_URL = 'localhost:5001'
|
WEBSOCKET_URL = "localhost:5001"
|
||||||
|
|
||||||
# Where to host the websocket?
|
# Where to host the websocket?
|
||||||
# (a local address the server will bind to)
|
# (a local address the server will bind to)
|
||||||
WS_HOST = 'localhost'
|
WS_HOST = "localhost"
|
||||||
WS_PORT = 5001
|
WS_PORT = 5001
|
||||||
|
|
||||||
# Postgres credentials
|
# Postgres credentials
|
||||||
|
|
@ -51,10 +52,10 @@ class Config:
|
||||||
class Development(Config):
|
class Development(Config):
|
||||||
DEBUG = True
|
DEBUG = True
|
||||||
POSTGRES = {
|
POSTGRES = {
|
||||||
'host': 'localhost',
|
"host": "localhost",
|
||||||
'user': 'litecord',
|
"user": "litecord",
|
||||||
'password': '123',
|
"password": "123",
|
||||||
'database': 'litecord',
|
"database": "litecord",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,8 +67,4 @@ class Production(Config):
|
||||||
class CI(Config):
|
class CI(Config):
|
||||||
DEBUG = True
|
DEBUG = True
|
||||||
|
|
||||||
POSTGRES = {
|
POSTGRES = {"host": "postgres", "user": "postgres", "password": ""}
|
||||||
'host': 'postgres',
|
|
||||||
'user': 'postgres',
|
|
||||||
'password': ''
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -17,16 +17,17 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MODE = 'Development'
|
MODE = "Development"
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Default configuration values for litecord."""
|
"""Default configuration values for litecord."""
|
||||||
|
|
||||||
#: Main URL of the instance.
|
#: Main URL of the instance.
|
||||||
MAIN_URL = 'discordapp.io'
|
MAIN_URL = "discordapp.io"
|
||||||
|
|
||||||
#: Name of the instance
|
#: Name of the instance
|
||||||
NAME = 'Litecord/Nya'
|
NAME = "Litecord/Nya"
|
||||||
|
|
||||||
#: Enable debug logging?
|
#: Enable debug logging?
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
|
@ -45,17 +46,17 @@ class Config:
|
||||||
# Set this url to somewhere *your users*
|
# Set this url to somewhere *your users*
|
||||||
# will hit the websocket.
|
# will hit the websocket.
|
||||||
# e.g 'gateway.example.com' for reverse proxies.
|
# e.g 'gateway.example.com' for reverse proxies.
|
||||||
WEBSOCKET_URL = 'localhost:5001'
|
WEBSOCKET_URL = "localhost:5001"
|
||||||
|
|
||||||
#: Where to host the websocket?
|
#: Where to host the websocket?
|
||||||
# (a local address the server will bind to)
|
# (a local address the server will bind to)
|
||||||
WS_HOST = '0.0.0.0'
|
WS_HOST = "0.0.0.0"
|
||||||
WS_PORT = 5001
|
WS_PORT = 5001
|
||||||
|
|
||||||
#: Mediaproxy URL on the internet
|
#: Mediaproxy URL on the internet
|
||||||
# mediaproxy is made to prevent client IPs being leaked.
|
# mediaproxy is made to prevent client IPs being leaked.
|
||||||
# None is a valid value if you don't want to deploy mediaproxy.
|
# None is a valid value if you don't want to deploy mediaproxy.
|
||||||
MEDIA_PROXY = 'localhost:5002'
|
MEDIA_PROXY = "localhost:5002"
|
||||||
|
|
||||||
#: Postgres credentials
|
#: Postgres credentials
|
||||||
POSTGRES = {}
|
POSTGRES = {}
|
||||||
|
|
@ -65,10 +66,10 @@ class Development(Config):
|
||||||
DEBUG = True
|
DEBUG = True
|
||||||
|
|
||||||
POSTGRES = {
|
POSTGRES = {
|
||||||
'host': 'localhost',
|
"host": "localhost",
|
||||||
'user': 'litecord',
|
"user": "litecord",
|
||||||
'password': '123',
|
"password": "123",
|
||||||
'database': 'litecord',
|
"database": "litecord",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -77,8 +78,8 @@ class Production(Config):
|
||||||
IS_SSL = True
|
IS_SSL = True
|
||||||
|
|
||||||
POSTGRES = {
|
POSTGRES = {
|
||||||
'host': 'some_production_postgres',
|
"host": "some_production_postgres",
|
||||||
'user': 'some_production_user',
|
"user": "some_production_user",
|
||||||
'password': 'some_production_password',
|
"password": "some_production_password",
|
||||||
'database': 'litecord_or_anything_else_really',
|
"database": "litecord_or_anything_else_really",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,42 +19,33 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from litecord.enums import Feature, UserFlags
|
from litecord.enums import Feature, UserFlags
|
||||||
|
|
||||||
VOICE_SERVER = {
|
VOICE_SERVER = {"hostname": {"type": "string", "maxlength": 255, "required": True}}
|
||||||
'hostname': {'type': 'string', 'maxlength': 255, 'required': True}
|
|
||||||
}
|
|
||||||
|
|
||||||
VOICE_REGION = {
|
VOICE_REGION = {
|
||||||
'id': {'type': 'string', 'maxlength': 255, 'required': True},
|
"id": {"type": "string", "maxlength": 255, "required": True},
|
||||||
'name': {'type': 'string', 'maxlength': 255, 'required': True},
|
"name": {"type": "string", "maxlength": 255, "required": True},
|
||||||
|
"vip": {"type": "boolean", "default": False},
|
||||||
'vip': {'type': 'boolean', 'default': False},
|
"deprecated": {"type": "boolean", "default": False},
|
||||||
'deprecated': {'type': 'boolean', 'default': False},
|
"custom": {"type": "boolean", "default": False},
|
||||||
'custom': {'type': 'boolean', 'default': False},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
FEATURES = {
|
FEATURES = {
|
||||||
'features': {
|
"features": {
|
||||||
'type': 'list', 'required': True,
|
"type": "list",
|
||||||
|
"required": True,
|
||||||
# using Feature doesn't seem to work with a "not callable" error.
|
# using Feature doesn't seem to work with a "not callable" error.
|
||||||
'schema': {'coerce': lambda x: Feature(x)}
|
"schema": {"coerce": lambda x: Feature(x)},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
USER_CREATE = {
|
USER_CREATE = {
|
||||||
'username': {'type': 'username', 'required': True},
|
"username": {"type": "username", "required": True},
|
||||||
'email': {'type': 'email', 'required': True},
|
"email": {"type": "email", "required": True},
|
||||||
'password': {'type': 'string', 'minlength': 5, 'required': True},
|
"password": {"type": "string", "minlength": 5, "required": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANCE_INVITE = {
|
INSTANCE_INVITE = {"max_uses": {"type": "integer", "required": True}}
|
||||||
'max_uses': {'type': 'integer', 'required': True}
|
|
||||||
}
|
|
||||||
|
|
||||||
GUILD_UPDATE = {
|
GUILD_UPDATE = {"unavailable": {"type": "boolean", "required": False}}
|
||||||
'unavailable': {'type': 'boolean', 'required': False}
|
|
||||||
}
|
|
||||||
|
|
||||||
USER_UPDATE = {
|
USER_UPDATE = {"flags": {"required": False, "coerce": UserFlags.from_int}}
|
||||||
'flags': {'required': False, 'coerce': UserFlags.from_int}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
100
litecord/auth.py
100
litecord/auth.py
|
|
@ -55,44 +55,50 @@ async def raw_token_check(token: str, db=None) -> int:
|
||||||
|
|
||||||
# just try by fragments instead of
|
# just try by fragments instead of
|
||||||
# unpacking
|
# unpacking
|
||||||
fragments = token.split('.')
|
fragments = token.split(".")
|
||||||
user_id = fragments[0]
|
user_id = fragments[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = base64.b64decode(user_id.encode())
|
user_id = base64.b64decode(user_id.encode())
|
||||||
user_id = int(user_id)
|
user_id = int(user_id)
|
||||||
except (ValueError, binascii.Error):
|
except (ValueError, binascii.Error):
|
||||||
raise Unauthorized('Invalid user ID type')
|
raise Unauthorized("Invalid user ID type")
|
||||||
|
|
||||||
pwd_hash = await db.fetchval("""
|
pwd_hash = await db.fetchval(
|
||||||
|
"""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not pwd_hash:
|
if not pwd_hash:
|
||||||
raise Unauthorized('User ID not found')
|
raise Unauthorized("User ID not found")
|
||||||
|
|
||||||
signer = TimestampSigner(pwd_hash)
|
signer = TimestampSigner(pwd_hash)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
signer.unsign(token)
|
signer.unsign(token)
|
||||||
log.debug('login for uid {} successful', user_id)
|
log.debug("login for uid {} successful", user_id)
|
||||||
|
|
||||||
# update the user's last_session field
|
# update the user's last_session field
|
||||||
# so that we can keep an exact track of activity,
|
# so that we can keep an exact track of activity,
|
||||||
# even on long-lived single sessions (that can happen
|
# even on long-lived single sessions (that can happen
|
||||||
# with people leaving their clients open forever)
|
# with people leaving their clients open forever)
|
||||||
await db.execute("""
|
await db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET last_session = (now() at time zone 'utc')
|
SET last_session = (now() at time zone 'utc')
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
except BadSignature:
|
except BadSignature:
|
||||||
log.warning('token failed for uid {}', user_id)
|
log.warning("token failed for uid {}", user_id)
|
||||||
raise Forbidden('Invalid token')
|
raise Forbidden("Invalid token")
|
||||||
|
|
||||||
|
|
||||||
async def token_check() -> int:
|
async def token_check() -> int:
|
||||||
|
|
@ -104,12 +110,12 @@ async def token_check() -> int:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token = request.headers['Authorization']
|
token = request.headers["Authorization"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise Unauthorized('No token provided')
|
raise Unauthorized("No token provided")
|
||||||
|
|
||||||
if token.startswith('Bot '):
|
if token.startswith("Bot "):
|
||||||
token = token.replace('Bot ', '')
|
token = token.replace("Bot ", "")
|
||||||
|
|
||||||
user_id = await raw_token_check(token)
|
user_id = await raw_token_check(token)
|
||||||
request.user_id = user_id
|
request.user_id = user_id
|
||||||
|
|
@ -120,15 +126,18 @@ async def admin_check() -> int:
|
||||||
"""Check if the user is an admin."""
|
"""Check if the user is an admin."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
flags = await app.db.fetchval("""
|
flags = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT flags
|
SELECT flags
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
flags = UserFlags.from_int(flags)
|
flags = UserFlags.from_int(flags)
|
||||||
if not flags.is_staff:
|
if not flags.is_staff:
|
||||||
raise Unauthorized('you are not staff')
|
raise Unauthorized("you are not staff")
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
@ -138,9 +147,7 @@ async def hash_data(data: str, loop=None) -> str:
|
||||||
loop = loop or app.loop
|
loop = loop or app.loop
|
||||||
buf = data.encode()
|
buf = data.encode()
|
||||||
|
|
||||||
hashed = await loop.run_in_executor(
|
hashed = await loop.run_in_executor(None, bcrypt.hashpw, buf, bcrypt.gensalt(14))
|
||||||
None, bcrypt.hashpw, buf, bcrypt.gensalt(14)
|
|
||||||
)
|
|
||||||
|
|
||||||
return hashed.decode()
|
return hashed.decode()
|
||||||
|
|
||||||
|
|
@ -148,22 +155,28 @@ async def hash_data(data: str, loop=None) -> str:
|
||||||
async def check_username_usage(username: str, db=None):
|
async def check_username_usage(username: str, db=None):
|
||||||
"""Raise an error if too many people are with the same username."""
|
"""Raise an error if too many people are with the same username."""
|
||||||
db = db or app.db
|
db = db or app.db
|
||||||
same_username = await db.fetchval("""
|
same_username = await db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM users
|
FROM users
|
||||||
WHERE username = $1
|
WHERE username = $1
|
||||||
""", username)
|
""",
|
||||||
|
username,
|
||||||
|
)
|
||||||
|
|
||||||
if same_username > 9000:
|
if same_username > 9000:
|
||||||
raise BadRequest('Too many people.', {
|
raise BadRequest(
|
||||||
'username': 'Too many people used the same username. '
|
"Too many people.",
|
||||||
'Please choose another'
|
{
|
||||||
})
|
"username": "Too many people used the same username. "
|
||||||
|
"Please choose another"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _raw_discrim() -> str:
|
def _raw_discrim() -> str:
|
||||||
new_discrim = randint(1, 9999)
|
new_discrim = randint(1, 9999)
|
||||||
new_discrim = '%04d' % new_discrim
|
new_discrim = "%04d" % new_discrim
|
||||||
return new_discrim
|
return new_discrim
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -186,11 +199,15 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
|
||||||
discrim = _raw_discrim()
|
discrim = _raw_discrim()
|
||||||
|
|
||||||
# check if anyone is with it
|
# check if anyone is with it
|
||||||
res = await db.fetchval("""
|
res = await db.fetchval(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM users
|
FROM users
|
||||||
WHERE username = $1 AND discriminator = $2
|
WHERE username = $1 AND discriminator = $2
|
||||||
""", username, discrim)
|
""",
|
||||||
|
username,
|
||||||
|
discrim,
|
||||||
|
)
|
||||||
|
|
||||||
# if no user is found with the (username, discrim)
|
# if no user is found with the (username, discrim)
|
||||||
# pair, then this is unique! return it.
|
# pair, then this is unique! return it.
|
||||||
|
|
@ -200,8 +217,9 @@ async def roll_discrim(username: str, *, db=None) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def create_user(username: str, email: str, password: str,
|
async def create_user(
|
||||||
db=None, loop=None) -> Tuple[int, str]:
|
username: str, email: str, password: str, db=None, loop=None
|
||||||
|
) -> Tuple[int, str]:
|
||||||
"""Create a single user.
|
"""Create a single user.
|
||||||
|
|
||||||
Generates a distriminator and other information. You can fetch the user
|
Generates a distriminator and other information. You can fetch the user
|
||||||
|
|
@ -214,20 +232,28 @@ async def create_user(username: str, email: str, password: str,
|
||||||
new_discrim = await roll_discrim(username, db=db)
|
new_discrim = await roll_discrim(username, db=db)
|
||||||
|
|
||||||
if new_discrim is None:
|
if new_discrim is None:
|
||||||
raise BadRequest('Unable to register.', {
|
raise BadRequest(
|
||||||
'username': 'Too many people are with this username.'
|
"Unable to register.",
|
||||||
})
|
{"username": "Too many people are with this username."},
|
||||||
|
)
|
||||||
|
|
||||||
pwd_hash = await hash_data(password, loop)
|
pwd_hash = await hash_data(password, loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db.execute("""
|
await db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO users
|
INSERT INTO users
|
||||||
(id, email, username, discriminator, password_hash)
|
(id, email, username, discriminator, password_hash)
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, $4, $5)
|
($1, $2, $3, $4, $5)
|
||||||
""", new_id, email, username, new_discrim, pwd_hash)
|
""",
|
||||||
|
new_id,
|
||||||
|
email,
|
||||||
|
username,
|
||||||
|
new_discrim,
|
||||||
|
pwd_hash,
|
||||||
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
raise BadRequest('Email already used.')
|
raise BadRequest("Email already used.")
|
||||||
|
|
||||||
return new_id, pwd_hash
|
return new_id, pwd_hash
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,21 @@ from .static import bp as static
|
||||||
from .attachments import bp as attachments
|
from .attachments import bp as attachments
|
||||||
from .dm_channels import bp as dm_channels
|
from .dm_channels import bp as dm_channels
|
||||||
|
|
||||||
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
|
__all__ = [
|
||||||
'webhooks', 'science', 'voice', 'invites', 'relationships',
|
"gateway",
|
||||||
'dms', 'icons', 'nodeinfo', 'static', 'attachments',
|
"auth",
|
||||||
'dm_channels']
|
"users",
|
||||||
|
"guilds",
|
||||||
|
"channels",
|
||||||
|
"webhooks",
|
||||||
|
"science",
|
||||||
|
"voice",
|
||||||
|
"invites",
|
||||||
|
"relationships",
|
||||||
|
"dms",
|
||||||
|
"icons",
|
||||||
|
"nodeinfo",
|
||||||
|
"static",
|
||||||
|
"attachments",
|
||||||
|
"dm_channels",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -23,4 +23,4 @@ from .guilds import bp as guilds
|
||||||
from .users import bp as users
|
from .users import bp as users
|
||||||
from .instance_invites import bp as instance_invites
|
from .instance_invites import bp as instance_invites
|
||||||
|
|
||||||
__all__ = ['voice', 'features', 'guilds', 'users', 'instance_invites']
|
__all__ = ["voice", "features", "guilds", "users", "instance_invites"]
|
||||||
|
|
|
||||||
|
|
@ -25,45 +25,53 @@ from litecord.errors import BadRequest
|
||||||
from litecord.schemas import validate
|
from litecord.schemas import validate
|
||||||
from litecord.admin_schemas import FEATURES
|
from litecord.admin_schemas import FEATURES
|
||||||
|
|
||||||
bp = Blueprint('features_admin', __name__)
|
bp = Blueprint("features_admin", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def _features_from_req() -> List[str]:
|
async def _features_from_req() -> List[str]:
|
||||||
j = validate(await request.get_json(), FEATURES)
|
j = validate(await request.get_json(), FEATURES)
|
||||||
return [feature.value for feature in j['features']]
|
return [feature.value for feature in j["features"]]
|
||||||
|
|
||||||
|
|
||||||
async def _features(guild_id: int):
|
async def _features(guild_id: int):
|
||||||
return jsonify({
|
return jsonify({"features": await app.storage.guild_features(guild_id)})
|
||||||
'features': await app.storage.guild_features(guild_id)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_features(guild_id: int, features: list):
|
async def _update_features(guild_id: int, features: list):
|
||||||
if 'VANITY_URL' not in features:
|
if "VANITY_URL" not in features:
|
||||||
existing_inv = await app.storage.vanity_invite(guild_id)
|
existing_inv = await app.storage.vanity_invite(guild_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM vanity_invites
|
DELETE FROM vanity_invites
|
||||||
WHERE guild_id = $1
|
WHERE guild_id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM invites
|
DELETE FROM invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", existing_inv)
|
""",
|
||||||
|
existing_inv,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET features = $1
|
SET features = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", features, guild_id)
|
""",
|
||||||
|
features,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
guild = await app.storage.get_guild_full(guild_id)
|
guild = await app.storage.get_guild_full(guild_id)
|
||||||
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_UPDATE', guild)
|
await app.dispatcher.dispatch("guild", guild_id, "GUILD_UPDATE", guild)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/features', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/features", methods=["PATCH"])
|
||||||
async def replace_features(guild_id: int):
|
async def replace_features(guild_id: int):
|
||||||
"""Replace the feature list in a guild"""
|
"""Replace the feature list in a guild"""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
@ -76,7 +84,7 @@ async def replace_features(guild_id: int):
|
||||||
return await _features(guild_id)
|
return await _features(guild_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/features', methods=['PUT'])
|
@bp.route("/<int:guild_id>/features", methods=["PUT"])
|
||||||
async def insert_features(guild_id: int):
|
async def insert_features(guild_id: int):
|
||||||
"""Insert a feature on a guild."""
|
"""Insert a feature on a guild."""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
@ -93,7 +101,7 @@ async def insert_features(guild_id: int):
|
||||||
return await _features(guild_id)
|
return await _features(guild_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/features', methods=['DELETE'])
|
@bp.route("/<int:guild_id>/features", methods=["DELETE"])
|
||||||
async def remove_features(guild_id: int):
|
async def remove_features(guild_id: int):
|
||||||
"""Remove a feature from a guild"""
|
"""Remove a feature from a guild"""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
@ -104,7 +112,7 @@ async def remove_features(guild_id: int):
|
||||||
try:
|
try:
|
||||||
features.remove(feature)
|
features.remove(feature)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise BadRequest('Trying to remove already removed feature.')
|
raise BadRequest("Trying to remove already removed feature.")
|
||||||
|
|
||||||
await _update_features(guild_id, features)
|
await _update_features(guild_id, features)
|
||||||
return await _features(guild_id)
|
return await _features(guild_id)
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,10 @@ from litecord.admin_schemas import GUILD_UPDATE
|
||||||
from litecord.blueprints.guilds import delete_guild
|
from litecord.blueprints.guilds import delete_guild
|
||||||
from litecord.errors import GuildNotFound
|
from litecord.errors import GuildNotFound
|
||||||
|
|
||||||
bp = Blueprint('guilds_admin', __name__)
|
bp = Blueprint("guilds_admin", __name__)
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['GET'])
|
|
||||||
|
@bp.route("/<int:guild_id>", methods=["GET"])
|
||||||
async def get_guild(guild_id: int):
|
async def get_guild(guild_id: int):
|
||||||
"""Get a basic guild payload."""
|
"""Get a basic guild payload."""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
@ -40,7 +41,7 @@ async def get_guild(guild_id: int):
|
||||||
return jsonify(guild)
|
return jsonify(guild)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['PATCH'])
|
@bp.route("/<int:guild_id>", methods=["PATCH"])
|
||||||
async def update_guild(guild_id: int):
|
async def update_guild(guild_id: int):
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
|
|
@ -48,13 +49,13 @@ async def update_guild(guild_id: int):
|
||||||
|
|
||||||
# TODO: what happens to the other guild attributes when its
|
# TODO: what happens to the other guild attributes when its
|
||||||
# unavailable? do they vanish?
|
# unavailable? do they vanish?
|
||||||
old_unavailable = app.guild_store.get(guild_id, 'unavailable')
|
old_unavailable = app.guild_store.get(guild_id, "unavailable")
|
||||||
new_unavailable = j.get('unavailable', old_unavailable)
|
new_unavailable = j.get("unavailable", old_unavailable)
|
||||||
|
|
||||||
# always set unavailable status since new_unavailable will be
|
# always set unavailable status since new_unavailable will be
|
||||||
# old_unavailable when not provided, so we don't need to check if
|
# old_unavailable when not provided, so we don't need to check if
|
||||||
# j.unavailable is there
|
# j.unavailable is there
|
||||||
app.guild_store.set(guild_id, 'unavailable', j['unavailable'])
|
app.guild_store.set(guild_id, "unavailable", j["unavailable"])
|
||||||
|
|
||||||
guild = await app.storage.get_guild(guild_id)
|
guild = await app.storage.get_guild(guild_id)
|
||||||
|
|
||||||
|
|
@ -62,17 +63,17 @@ async def update_guild(guild_id: int):
|
||||||
|
|
||||||
if old_unavailable and not new_unavailable:
|
if old_unavailable and not new_unavailable:
|
||||||
# guild became available
|
# guild became available
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild)
|
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild)
|
||||||
else:
|
else:
|
||||||
# guild became unavailable
|
# guild became unavailable
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_DELETE', guild)
|
await app.dispatcher.dispatch_guild(guild_id, "GUILD_DELETE", guild)
|
||||||
|
|
||||||
return jsonify(guild)
|
return jsonify(guild)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
@bp.route("/<int:guild_id>", methods=["DELETE"])
|
||||||
async def delete_guild_as_admin(guild_id):
|
async def delete_guild_as_admin(guild_id):
|
||||||
"""Delete a single guild via the admin API without ownership checks."""
|
"""Delete a single guild via the admin API without ownership checks."""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
await delete_guild(guild_id)
|
await delete_guild(guild_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -27,13 +27,13 @@ from litecord.types import timestamp_
|
||||||
from litecord.schemas import validate
|
from litecord.schemas import validate
|
||||||
from litecord.admin_schemas import INSTANCE_INVITE
|
from litecord.admin_schemas import INSTANCE_INVITE
|
||||||
|
|
||||||
bp = Blueprint('instance_invites', __name__)
|
bp = Blueprint("instance_invites", __name__)
|
||||||
ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
|
ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
|
||||||
|
|
||||||
|
|
||||||
async def _gen_inv() -> str:
|
async def _gen_inv() -> str:
|
||||||
"""Generate an invite code"""
|
"""Generate an invite code"""
|
||||||
return ''.join(choice(ALPHABET) for _ in range(6))
|
return "".join(choice(ALPHABET) for _ in range(6))
|
||||||
|
|
||||||
|
|
||||||
async def gen_inv(ctx) -> str:
|
async def gen_inv(ctx) -> str:
|
||||||
|
|
@ -41,11 +41,14 @@ async def gen_inv(ctx) -> str:
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
possible_inv = await _gen_inv()
|
possible_inv = await _gen_inv()
|
||||||
|
|
||||||
created_at = await ctx.db.fetchval("""
|
created_at = await ctx.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT created_at
|
SELECT created_at
|
||||||
FROM instance_invites
|
FROM instance_invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", possible_inv)
|
""",
|
||||||
|
possible_inv,
|
||||||
|
)
|
||||||
|
|
||||||
if created_at is None:
|
if created_at is None:
|
||||||
return possible_inv
|
return possible_inv
|
||||||
|
|
@ -53,57 +56,71 @@ async def gen_inv(ctx) -> str:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@bp.route('', methods=['GET'])
|
@bp.route("", methods=["GET"])
|
||||||
async def _all_instance_invites():
|
async def _all_instance_invites():
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
rows = await app.db.fetch("""
|
rows = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT code, created_at, uses, max_uses
|
SELECT code, created_at, uses, max_uses
|
||||||
FROM instance_invites
|
FROM instance_invites
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
rows = [dict(row) for row in rows]
|
rows = [dict(row) for row in rows]
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
row['created_at'] = timestamp_(row['created_at'])
|
row["created_at"] = timestamp_(row["created_at"])
|
||||||
|
|
||||||
return jsonify(rows)
|
return jsonify(rows)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('', methods=['PUT'])
|
@bp.route("", methods=["PUT"])
|
||||||
async def _create_invite():
|
async def _create_invite():
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
code = await gen_inv(app)
|
code = await gen_inv(app)
|
||||||
if code is None:
|
if code is None:
|
||||||
return 'failed to make invite', 500
|
return "failed to make invite", 500
|
||||||
|
|
||||||
j = validate(await request.get_json(), INSTANCE_INVITE)
|
j = validate(await request.get_json(), INSTANCE_INVITE)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO instance_invites (code, max_uses)
|
INSERT INTO instance_invites (code, max_uses)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", code, j['max_uses'])
|
""",
|
||||||
|
code,
|
||||||
|
j["max_uses"],
|
||||||
|
)
|
||||||
|
|
||||||
inv = dict(await app.db.fetchrow("""
|
inv = dict(
|
||||||
|
await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT code, created_at, uses, max_uses
|
SELECT code, created_at, uses, max_uses
|
||||||
FROM instance_invites
|
FROM instance_invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", code))
|
""",
|
||||||
|
code,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(dict(inv))
|
return jsonify(dict(inv))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<invite>', methods=['DELETE'])
|
@bp.route("/<invite>", methods=["DELETE"])
|
||||||
async def _del_invite(invite: str):
|
async def _del_invite(invite: str):
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
res = await app.db.execute("""
|
res = await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM instance_invites
|
DELETE FROM instance_invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invite)
|
""",
|
||||||
|
invite,
|
||||||
|
)
|
||||||
|
|
||||||
if res.lower() == 'delete 0':
|
if res.lower() == "delete 0":
|
||||||
return 'invite not found', 404
|
return "invite not found", 404
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -25,24 +25,21 @@ from litecord.schemas import validate
|
||||||
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
|
from litecord.admin_schemas import USER_CREATE, USER_UPDATE
|
||||||
from litecord.errors import BadRequest, Forbidden
|
from litecord.errors import BadRequest, Forbidden
|
||||||
from litecord.utils import async_map
|
from litecord.utils import async_map
|
||||||
from litecord.blueprints.users import (
|
from litecord.blueprints.users import delete_user, user_disconnect, mass_user_update
|
||||||
delete_user, user_disconnect, mass_user_update
|
|
||||||
)
|
|
||||||
from litecord.enums import UserFlags
|
from litecord.enums import UserFlags
|
||||||
|
|
||||||
bp = Blueprint('users_admin', __name__)
|
bp = Blueprint("users_admin", __name__)
|
||||||
|
|
||||||
@bp.route('', methods=['POST'])
|
|
||||||
|
@bp.route("", methods=["POST"])
|
||||||
async def _create_user():
|
async def _create_user():
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
j = validate(await request.get_json(), USER_CREATE)
|
j = validate(await request.get_json(), USER_CREATE)
|
||||||
|
|
||||||
user_id, _ = await create_user(j['username'], j['email'], j['password'])
|
user_id, _ = await create_user(j["username"], j["email"], j["password"])
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_user(user_id))
|
||||||
await app.storage.get_user(user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def args_try(args: dict, typ, field: str, default):
|
def args_try(args: dict, typ, field: str, default):
|
||||||
|
|
@ -51,29 +48,29 @@ def args_try(args: dict, typ, field: str, default):
|
||||||
try:
|
try:
|
||||||
return typ(args.get(field, default))
|
return typ(args.get(field, default))
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
raise BadRequest(f'invalid {field} value')
|
raise BadRequest(f"invalid {field} value")
|
||||||
|
|
||||||
|
|
||||||
@bp.route('', methods=['GET'])
|
@bp.route("", methods=["GET"])
|
||||||
async def _search_users():
|
async def _search_users():
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
args = request.args
|
args = request.args
|
||||||
|
|
||||||
username, discrim = args.get('username'), args.get('discriminator')
|
username, discrim = args.get("username"), args.get("discriminator")
|
||||||
|
|
||||||
per_page = args_try(args, int, 'per_page', 20)
|
per_page = args_try(args, int, "per_page", 20)
|
||||||
page = args_try(args, int, 'page', 0)
|
page = args_try(args, int, "page", 0)
|
||||||
|
|
||||||
if page < 0:
|
if page < 0:
|
||||||
raise BadRequest('invalid page number')
|
raise BadRequest("invalid page number")
|
||||||
|
|
||||||
if per_page > 50:
|
if per_page > 50:
|
||||||
raise BadRequest('invalid per page number')
|
raise BadRequest("invalid per page number")
|
||||||
|
|
||||||
# any of those must be available.
|
# any of those must be available.
|
||||||
if not any((username, discrim)):
|
if not any((username, discrim)):
|
||||||
raise BadRequest('must insert username or discrim')
|
raise BadRequest("must insert username or discrim")
|
||||||
|
|
||||||
wheres, args = [], []
|
wheres, args = [], []
|
||||||
|
|
||||||
|
|
@ -82,29 +79,31 @@ async def _search_users():
|
||||||
args.append(username)
|
args.append(username)
|
||||||
|
|
||||||
if discrim:
|
if discrim:
|
||||||
wheres.append(f'discriminator = ${len(args) + 2}')
|
wheres.append(f"discriminator = ${len(args) + 2}")
|
||||||
args.append(discrim)
|
args.append(discrim)
|
||||||
|
|
||||||
where_tot = 'WHERE ' if args else ''
|
where_tot = "WHERE " if args else ""
|
||||||
where_tot += ' AND '.join(wheres)
|
where_tot += " AND ".join(wheres)
|
||||||
|
|
||||||
rows = await app.db.fetch(f"""
|
rows = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM users
|
FROM users
|
||||||
{where_tot}
|
{where_tot}
|
||||||
ORDER BY id ASC
|
ORDER BY id ASC
|
||||||
LIMIT {per_page}
|
LIMIT {per_page}
|
||||||
OFFSET ($1 * {per_page})
|
OFFSET ($1 * {per_page})
|
||||||
""", page, *args)
|
""",
|
||||||
|
page,
|
||||||
rows = [r['id'] for r in rows]
|
*args,
|
||||||
|
|
||||||
return jsonify(
|
|
||||||
await async_map(app.storage.get_user, rows)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
rows = [r["id"] for r in rows]
|
||||||
|
|
||||||
@bp.route('/<int:user_id>', methods=['DELETE'])
|
return jsonify(await async_map(app.storage.get_user, rows))
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/<int:user_id>", methods=["DELETE"])
|
||||||
async def _delete_single_user(user_id: int):
|
async def _delete_single_user(user_id: int):
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
|
|
@ -115,13 +114,10 @@ async def _delete_single_user(user_id: int):
|
||||||
|
|
||||||
new_user = await app.storage.get_user(user_id)
|
new_user = await app.storage.get_user(user_id)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"old": old_user, "new": new_user})
|
||||||
'old': old_user,
|
|
||||||
'new': new_user
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:user_id>', methods=['PATCH'])
|
@bp.route("/<int:user_id>", methods=["PATCH"])
|
||||||
async def patch_user(user_id: int):
|
async def patch_user(user_id: int):
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
||||||
|
|
@ -129,21 +125,25 @@ async def patch_user(user_id: int):
|
||||||
|
|
||||||
# get the original user for flags checking
|
# get the original user for flags checking
|
||||||
user = await app.storage.get_user(user_id)
|
user = await app.storage.get_user(user_id)
|
||||||
old_flags = UserFlags.from_int(user['flags'])
|
old_flags = UserFlags.from_int(user["flags"])
|
||||||
|
|
||||||
# j.flags is already a UserFlags since we coerce it.
|
# j.flags is already a UserFlags since we coerce it.
|
||||||
if 'flags' in j:
|
if "flags" in j:
|
||||||
new_flags = j['flags']
|
new_flags = j["flags"]
|
||||||
|
|
||||||
# disallow any changes to the staff badge
|
# disallow any changes to the staff badge
|
||||||
if new_flags.is_staff != old_flags.is_staff:
|
if new_flags.is_staff != old_flags.is_staff:
|
||||||
raise Forbidden('you can not change a users staff badge')
|
raise Forbidden("you can not change a users staff badge")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET flags = $1
|
SET flags = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_flags.value, user_id)
|
""",
|
||||||
|
new_flags.value,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
public_user, _ = await mass_user_update(user_id, app)
|
public_user, _ = await mass_user_update(user_id, app)
|
||||||
return jsonify(public_user)
|
return jsonify(public_user)
|
||||||
|
|
|
||||||
|
|
@ -27,10 +27,10 @@ from litecord.admin_schemas import VOICE_SERVER, VOICE_REGION
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('voice_admin', __name__)
|
bp = Blueprint("voice_admin", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/regions/<region>', methods=['GET'])
|
@bp.route("/regions/<region>", methods=["GET"])
|
||||||
async def get_region_servers(region):
|
async def get_region_servers(region):
|
||||||
"""Return a list of all servers for a region."""
|
"""Return a list of all servers for a region."""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
@ -38,18 +38,25 @@ async def get_region_servers(region):
|
||||||
return jsonify(servers)
|
return jsonify(servers)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/regions', methods=['PUT'])
|
@bp.route("/regions", methods=["PUT"])
|
||||||
async def insert_new_region():
|
async def insert_new_region():
|
||||||
"""Create a voice region."""
|
"""Create a voice region."""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
j = validate(await request.get_json(), VOICE_REGION)
|
j = validate(await request.get_json(), VOICE_REGION)
|
||||||
|
|
||||||
j['id'] = j['id'].lower()
|
j["id"] = j["id"].lower()
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO voice_regions (id, name, vip, deprecated, custom)
|
INSERT INTO voice_regions (id, name, vip, deprecated, custom)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
""", j['id'], j['name'], j['vip'], j['deprecated'], j['custom'])
|
""",
|
||||||
|
j["id"],
|
||||||
|
j["name"],
|
||||||
|
j["vip"],
|
||||||
|
j["deprecated"],
|
||||||
|
j["custom"],
|
||||||
|
)
|
||||||
|
|
||||||
regions = await app.storage.all_voice_regions()
|
regions = await app.storage.all_voice_regions()
|
||||||
region_count = len(regions)
|
region_count = len(regions)
|
||||||
|
|
@ -57,34 +64,41 @@ async def insert_new_region():
|
||||||
# if region count is 1, this is the first region to be created,
|
# if region count is 1, this is the first region to be created,
|
||||||
# so we should update all guilds to that region
|
# so we should update all guilds to that region
|
||||||
if region_count == 1:
|
if region_count == 1:
|
||||||
res = await app.db.execute("""
|
res = await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET region = $1
|
SET region = $1
|
||||||
""", j['id'])
|
""",
|
||||||
|
j["id"],
|
||||||
|
)
|
||||||
|
|
||||||
log.info('updating guilds to first voice region: {}', res)
|
log.info("updating guilds to first voice region: {}", res)
|
||||||
|
|
||||||
return jsonify(regions)
|
return jsonify(regions)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/regions/<region>/servers', methods=['PUT'])
|
@bp.route("/regions/<region>/servers", methods=["PUT"])
|
||||||
async def put_region_server(region):
|
async def put_region_server(region):
|
||||||
"""Insert a voice server to a region"""
|
"""Insert a voice server to a region"""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
j = validate(await request.get_json(), VOICE_SERVER)
|
j = validate(await request.get_json(), VOICE_SERVER)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO voice_servers (hostname, region_id)
|
INSERT INTO voice_servers (hostname, region_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", j['hostname'], region)
|
""",
|
||||||
|
j["hostname"],
|
||||||
|
region,
|
||||||
|
)
|
||||||
except asyncpg.UniqueViolationError:
|
except asyncpg.UniqueViolationError:
|
||||||
raise BadRequest('voice server already exists with given hostname')
|
raise BadRequest("voice server already exists with given hostname")
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/regions/<region>/deprecate', methods=['PUT'])
|
@bp.route("/regions/<region>/deprecate", methods=["PUT"])
|
||||||
async def deprecate_region(region):
|
async def deprecate_region(region):
|
||||||
"""Deprecate a voice region."""
|
"""Deprecate a voice region."""
|
||||||
await admin_check()
|
await admin_check()
|
||||||
|
|
@ -92,13 +106,16 @@ async def deprecate_region(region):
|
||||||
# TODO: write this
|
# TODO: write this
|
||||||
await app.voice.disable_region(region)
|
await app.voice.disable_region(region)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE voice_regions
|
UPDATE voice_regions
|
||||||
SET deprecated = true
|
SET deprecated = true
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", region)
|
""",
|
||||||
|
region,
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def guild_region_check(app_):
|
async def guild_region_check(app_):
|
||||||
|
|
@ -112,10 +129,11 @@ async def guild_region_check(app_):
|
||||||
regions = await app_.storage.all_voice_regions()
|
regions = await app_.storage.all_voice_regions()
|
||||||
|
|
||||||
if not regions:
|
if not regions:
|
||||||
log.info('region check: no regions to move guilds to')
|
log.info("region check: no regions to move guilds to")
|
||||||
return
|
return
|
||||||
|
|
||||||
res = await app_.db.execute("""
|
res = await app_.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET region = (
|
SET region = (
|
||||||
SELECT id
|
SELECT id
|
||||||
|
|
@ -124,6 +142,8 @@ async def guild_region_check(app_):
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
)
|
)
|
||||||
WHERE region = NULL
|
WHERE region = NULL
|
||||||
""", len(regions))
|
""",
|
||||||
|
len(regions),
|
||||||
|
)
|
||||||
|
|
||||||
log.info('region check: updating guild.region=null: {!r}', res)
|
log.info("region check: updating guild.region=null: {!r}", res)
|
||||||
|
|
|
||||||
|
|
@ -24,16 +24,17 @@ from PIL import Image
|
||||||
|
|
||||||
from litecord.images import resize_gif
|
from litecord.images import resize_gif
|
||||||
|
|
||||||
bp = Blueprint('attachments', __name__)
|
bp = Blueprint("attachments", __name__)
|
||||||
ATTACHMENTS = Path.cwd() / 'attachments'
|
ATTACHMENTS = Path.cwd() / "attachments"
|
||||||
|
|
||||||
|
|
||||||
async def _resize_gif(attach_id: int, resized_path: Path,
|
async def _resize_gif(
|
||||||
width: int, height: int) -> str:
|
attach_id: int, resized_path: Path, width: int, height: int
|
||||||
|
) -> str:
|
||||||
"""Resize a GIF attachment."""
|
"""Resize a GIF attachment."""
|
||||||
|
|
||||||
# get original gif bytes
|
# get original gif bytes
|
||||||
orig_path = ATTACHMENTS / f'{attach_id}.gif'
|
orig_path = ATTACHMENTS / f"{attach_id}.gif"
|
||||||
orig_bytes = orig_path.read_bytes()
|
orig_bytes = orig_path.read_bytes()
|
||||||
|
|
||||||
# give them and the target size to the
|
# give them and the target size to the
|
||||||
|
|
@ -47,10 +48,7 @@ async def _resize_gif(attach_id: int, resized_path: Path,
|
||||||
return str(resized_path)
|
return str(resized_path)
|
||||||
|
|
||||||
|
|
||||||
FORMAT_HARDCODE = {
|
FORMAT_HARDCODE = {"jpg": "jpeg", "jpe": "jpeg"}
|
||||||
'jpg': 'jpeg',
|
|
||||||
'jpe': 'jpeg'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def to_format(ext: str) -> str:
|
def to_format(ext: str) -> str:
|
||||||
|
|
@ -63,11 +61,10 @@ def to_format(ext: str) -> str:
|
||||||
return ext
|
return ext
|
||||||
|
|
||||||
|
|
||||||
async def _resize(image, attach_id: int, ext: str,
|
async def _resize(image, attach_id: int, ext: str, width: int, height: int) -> str:
|
||||||
width: int, height: int) -> str:
|
|
||||||
"""Resize an image."""
|
"""Resize an image."""
|
||||||
# check if we have it on the folder
|
# check if we have it on the folder
|
||||||
resized_path = ATTACHMENTS / f'{attach_id}_{width}_{height}.{ext}'
|
resized_path = ATTACHMENTS / f"{attach_id}_{width}_{height}.{ext}"
|
||||||
|
|
||||||
# keep a str-fied instance since that is what
|
# keep a str-fied instance since that is what
|
||||||
# we'll return.
|
# we'll return.
|
||||||
|
|
@ -81,7 +78,7 @@ async def _resize(image, attach_id: int, ext: str,
|
||||||
|
|
||||||
# the process is different for gif files because we need
|
# the process is different for gif files because we need
|
||||||
# gifsicle. doing it manually is too troublesome.
|
# gifsicle. doing it manually is too troublesome.
|
||||||
if ext == 'gif':
|
if ext == "gif":
|
||||||
return await _resize_gif(attach_id, resized_path, width, height)
|
return await _resize_gif(attach_id, resized_path, width, height)
|
||||||
|
|
||||||
# NOTE: this is the same resize mode for icons.
|
# NOTE: this is the same resize mode for icons.
|
||||||
|
|
@ -91,38 +88,42 @@ async def _resize(image, attach_id: int, ext: str,
|
||||||
return resized_path_s
|
return resized_path_s
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/attachments'
|
@bp.route(
|
||||||
'/<int:channel_id>/<int:message_id>/<filename>',
|
"/attachments" "/<int:channel_id>/<int:message_id>/<filename>", methods=["GET"]
|
||||||
methods=['GET'])
|
)
|
||||||
async def _get_attachment(channel_id: int, message_id: int,
|
async def _get_attachment(channel_id: int, message_id: int, filename: str):
|
||||||
filename: str):
|
|
||||||
|
|
||||||
attach_id = await app.db.fetchval("""
|
attach_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM attachments
|
FROM attachments
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
AND message_id = $2
|
AND message_id = $2
|
||||||
AND filename = $3
|
AND filename = $3
|
||||||
""", channel_id, message_id, filename)
|
""",
|
||||||
|
channel_id,
|
||||||
|
message_id,
|
||||||
|
filename,
|
||||||
|
)
|
||||||
|
|
||||||
if attach_id is None:
|
if attach_id is None:
|
||||||
return '', 404
|
return "", 404
|
||||||
|
|
||||||
ext = filename.split('.')[-1]
|
ext = filename.split(".")[-1]
|
||||||
filepath = f'./attachments/{attach_id}.{ext}'
|
filepath = f"./attachments/{attach_id}.{ext}"
|
||||||
|
|
||||||
image = Image.open(filepath)
|
image = Image.open(filepath)
|
||||||
im_width, im_height = image.size
|
im_width, im_height = image.size
|
||||||
|
|
||||||
try:
|
try:
|
||||||
width = int(request.args.get('width', 0)) or im_width
|
width = int(request.args.get("width", 0)) or im_width
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return '', 400
|
return "", 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
height = int(request.args.get('height', 0)) or im_height
|
height = int(request.args.get("height", 0)) or im_height
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return '', 400
|
return "", 400
|
||||||
|
|
||||||
# if width and height are the same (happens if they weren't provided)
|
# if width and height are the same (happens if they weren't provided)
|
||||||
if width == im_width and height == im_height:
|
if width == im_width and height == im_height:
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ from litecord.snowflake import get_snowflake
|
||||||
from .invites import use_invite
|
from .invites import use_invite
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('auth', __name__)
|
bp = Blueprint("auth", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def check_password(pwd_hash: str, given_password: str) -> bool:
|
async def check_password(pwd_hash: str, given_password: str) -> bool:
|
||||||
|
|
@ -53,141 +53,139 @@ def make_token(user_id, user_pwd_hash) -> str:
|
||||||
return signer.sign(user_id).decode()
|
return signer.sign(user_id).decode()
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/register', methods=['POST'])
|
@bp.route("/register", methods=["POST"])
|
||||||
async def register():
|
async def register():
|
||||||
"""Register a single user."""
|
"""Register a single user."""
|
||||||
enabled = app.config.get('REGISTRATIONS')
|
enabled = app.config.get("REGISTRATIONS")
|
||||||
if not enabled:
|
if not enabled:
|
||||||
raise BadRequest('Registrations disabled', {
|
raise BadRequest(
|
||||||
'email': 'Registrations are disabled.'
|
"Registrations disabled", {"email": "Registrations are disabled."}
|
||||||
})
|
)
|
||||||
|
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
|
|
||||||
if not 'password' in j:
|
if not "password" in j:
|
||||||
# we need a password to generate a token.
|
# we need a password to generate a token.
|
||||||
# passwords are optional, so
|
# passwords are optional, so
|
||||||
j['password'] = 'default_password'
|
j["password"] = "default_password"
|
||||||
|
|
||||||
j = validate(j, REGISTER)
|
j = validate(j, REGISTER)
|
||||||
|
|
||||||
# they're optional
|
# they're optional
|
||||||
email = j.get('email')
|
email = j.get("email")
|
||||||
invite = j.get('invite')
|
invite = j.get("invite")
|
||||||
|
|
||||||
username, password = j['username'], j['password']
|
username, password = j["username"], j["password"]
|
||||||
|
|
||||||
new_id, pwd_hash = await create_user(
|
new_id, pwd_hash = await create_user(username, email, password, app.db)
|
||||||
username, email, password, app.db
|
|
||||||
)
|
|
||||||
|
|
||||||
if invite:
|
if invite:
|
||||||
try:
|
try:
|
||||||
await use_invite(new_id, invite)
|
await use_invite(new_id, invite)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('failed to use invite for register {} {!r}',
|
log.exception("failed to use invite for register {} {!r}", new_id, invite)
|
||||||
new_id, invite)
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"token": make_token(new_id, pwd_hash)})
|
||||||
'token': make_token(new_id, pwd_hash)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/register_inv', methods=['POST'])
|
@bp.route("/register_inv", methods=["POST"])
|
||||||
async def _register_with_invite():
|
async def _register_with_invite():
|
||||||
data = await request.form
|
data = await request.form
|
||||||
data = validate(await request.form, REGISTER_WITH_INVITE)
|
data = validate(await request.form, REGISTER_WITH_INVITE)
|
||||||
|
|
||||||
invcode = data['invcode']
|
invcode = data["invcode"]
|
||||||
|
|
||||||
row = await app.db.fetchrow("""
|
row = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT uses, max_uses
|
SELECT uses, max_uses
|
||||||
FROM instance_invites
|
FROM instance_invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invcode)
|
""",
|
||||||
|
invcode,
|
||||||
|
)
|
||||||
|
|
||||||
if row is None:
|
if row is None:
|
||||||
raise BadRequest('unknown instance invite')
|
raise BadRequest("unknown instance invite")
|
||||||
|
|
||||||
if row['max_uses'] != -1 and row['uses'] >= row['max_uses']:
|
if row["max_uses"] != -1 and row["uses"] >= row["max_uses"]:
|
||||||
raise BadRequest('invite expired')
|
raise BadRequest("invite expired")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE instance_invites
|
UPDATE instance_invites
|
||||||
SET uses = uses + 1
|
SET uses = uses + 1
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invcode)
|
""",
|
||||||
|
invcode,
|
||||||
|
)
|
||||||
|
|
||||||
user_id, pwd_hash = await create_user(
|
user_id, pwd_hash = await create_user(
|
||||||
data['username'], data['email'], data['password'], app.db)
|
data["username"], data["email"], data["password"], app.db
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"token": make_token(user_id, pwd_hash), "user_id": str(user_id)})
|
||||||
'token': make_token(user_id, pwd_hash),
|
|
||||||
'user_id': str(user_id),
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/login', methods=['POST'])
|
@bp.route("/login", methods=["POST"])
|
||||||
async def login():
|
async def login():
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
email, password = j['email'], j['password']
|
email, password = j["email"], j["password"]
|
||||||
|
|
||||||
row = await app.db.fetchrow("""
|
row = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT id, password_hash
|
SELECT id, password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE email = $1
|
WHERE email = $1
|
||||||
""", email)
|
""",
|
||||||
|
email,
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return jsonify({'email': ['User not found.']}), 401
|
return jsonify({"email": ["User not found."]}), 401
|
||||||
|
|
||||||
user_id, pwd_hash = row
|
user_id, pwd_hash = row
|
||||||
|
|
||||||
if not await check_password(pwd_hash, password):
|
if not await check_password(pwd_hash, password):
|
||||||
return jsonify({'password': ['Password does not match.']}), 401
|
return jsonify({"password": ["Password does not match."]}), 401
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"token": make_token(user_id, pwd_hash)})
|
||||||
'token': make_token(user_id, pwd_hash)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/consent-required', methods=['GET'])
|
@bp.route("/consent-required", methods=["GET"])
|
||||||
async def consent_required():
|
async def consent_required():
|
||||||
return jsonify({
|
return jsonify({"required": True})
|
||||||
'required': True,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/verify/resend', methods=['POST'])
|
@bp.route("/verify/resend", methods=["POST"])
|
||||||
async def verify_user():
|
async def verify_user():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
# TODO: actually verify a user by sending an email
|
# TODO: actually verify a user by sending an email
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET verified = true
|
SET verified = true
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
new_user = await app.storage.get_user(user_id, True)
|
new_user = await app.storage.get_user(user_id, True)
|
||||||
await app.dispatcher.dispatch_user(
|
await app.dispatcher.dispatch_user(user_id, "USER_UPDATE", new_user)
|
||||||
user_id, 'USER_UPDATE', new_user)
|
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/logout', methods=['POST'])
|
@bp.route("/logout", methods=["POST"])
|
||||||
async def _logout():
|
async def _logout():
|
||||||
"""Called by the client to logout."""
|
"""Called by the client to logout."""
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/fingerprint', methods=['POST'])
|
@bp.route("/fingerprint", methods=["POST"])
|
||||||
async def _fingerprint():
|
async def _fingerprint():
|
||||||
"""No idea what this route is about."""
|
"""No idea what this route is about."""
|
||||||
fingerprint_id = get_snowflake()
|
fingerprint_id = get_snowflake()
|
||||||
fingerprint = f'{fingerprint_id}.{secrets.token_urlsafe(32)}'
|
fingerprint = f"{fingerprint_id}.{secrets.token_urlsafe(32)}"
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"fingerprint": fingerprint})
|
||||||
'fingerprint': fingerprint
|
|
||||||
})
|
|
||||||
|
|
|
||||||
|
|
@ -21,4 +21,4 @@ from .messages import bp as channel_messages
|
||||||
from .reactions import bp as channel_reactions
|
from .reactions import bp as channel_reactions
|
||||||
from .pins import bp as channel_pins
|
from .pins import bp as channel_pins
|
||||||
|
|
||||||
__all__ = ['channel_messages', 'channel_reactions', 'channel_pins']
|
__all__ = ["channel_messages", "channel_reactions", "channel_pins"]
|
||||||
|
|
|
||||||
|
|
@ -30,13 +30,18 @@ class ForbiddenDM(Forbidden):
|
||||||
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||||
"""Check if the user can DM the peer."""
|
"""Check if the user can DM the peer."""
|
||||||
# first step is checking if there is a block in any direction
|
# first step is checking if there is a block in any direction
|
||||||
blockrow = await app.db.fetchrow("""
|
blockrow = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT rel_type
|
SELECT rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE rel_type = $3
|
WHERE rel_type = $3
|
||||||
AND user_id IN ($1, $2)
|
AND user_id IN ($1, $2)
|
||||||
AND peer_id IN ($1, $2)
|
AND peer_id IN ($1, $2)
|
||||||
""", user_id, peer_id, RelationshipType.BLOCK.value)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
RelationshipType.BLOCK.value,
|
||||||
|
)
|
||||||
|
|
||||||
if blockrow is not None:
|
if blockrow is not None:
|
||||||
raise ForbiddenDM()
|
raise ForbiddenDM()
|
||||||
|
|
@ -58,8 +63,8 @@ async def dm_pre_check(user_id: int, channel_id: int, peer_id: int):
|
||||||
user_settings = await app.user_storage.get_user_settings(user_id)
|
user_settings = await app.user_storage.get_user_settings(user_id)
|
||||||
peer_settings = await app.user_storage.get_user_settings(peer_id)
|
peer_settings = await app.user_storage.get_user_settings(peer_id)
|
||||||
|
|
||||||
restricted_user_ = [int(v) for v in user_settings['restricted_guilds']]
|
restricted_user_ = [int(v) for v in user_settings["restricted_guilds"]]
|
||||||
restricted_peer_ = [int(v) for v in peer_settings['restricted_guilds']]
|
restricted_peer_ = [int(v) for v in peer_settings["restricted_guilds"]]
|
||||||
|
|
||||||
restricted_user = set(restricted_user_)
|
restricted_user = set(restricted_user_)
|
||||||
restricted_peer = set(restricted_peer_)
|
restricted_peer = set(restricted_peer_)
|
||||||
|
|
|
||||||
|
|
@ -41,18 +41,18 @@ from litecord.images import try_unlink
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channel_messages', __name__)
|
bp = Blueprint("channel_messages", __name__)
|
||||||
|
|
||||||
|
|
||||||
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
def extract_limit(request_, default: int = 50, max_val: int = 100):
|
||||||
"""Extract a limit kwarg."""
|
"""Extract a limit kwarg."""
|
||||||
try:
|
try:
|
||||||
limit = int(request_.args.get('limit', default))
|
limit = int(request_.args.get("limit", default))
|
||||||
|
|
||||||
if limit not in range(0, max_val + 1):
|
if limit not in range(0, max_val + 1):
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
raise BadRequest('limit not int')
|
raise BadRequest("limit not int")
|
||||||
|
|
||||||
return limit
|
return limit
|
||||||
|
|
||||||
|
|
@ -61,27 +61,27 @@ def query_tuple_from_args(args: dict, limit: int) -> tuple:
|
||||||
"""Extract a 2-tuple out of request arguments."""
|
"""Extract a 2-tuple out of request arguments."""
|
||||||
before, after = None, None
|
before, after = None, None
|
||||||
|
|
||||||
if 'around' in request.args:
|
if "around" in request.args:
|
||||||
average = int(limit / 2)
|
average = int(limit / 2)
|
||||||
around = int(args['around'])
|
around = int(args["around"])
|
||||||
|
|
||||||
after = around - average
|
after = around - average
|
||||||
before = around + average
|
before = around + average
|
||||||
|
|
||||||
elif 'before' in args:
|
elif "before" in args:
|
||||||
before = int(args['before'])
|
before = int(args["before"])
|
||||||
elif 'after' in args:
|
elif "after" in args:
|
||||||
before = int(args['after'])
|
before = int(args["after"])
|
||||||
|
|
||||||
return before, after
|
return before, after
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['GET'])
|
@bp.route("/<int:channel_id>/messages", methods=["GET"])
|
||||||
async def get_messages(channel_id):
|
async def get_messages(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
ctype, peer_id = await channel_check(user_id, channel_id)
|
ctype, peer_id = await channel_check(user_id, channel_id)
|
||||||
await channel_perm_check(user_id, channel_id, 'read_history')
|
await channel_perm_check(user_id, channel_id, "read_history")
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
# make sure both parties will be subbed
|
# make sure both parties will be subbed
|
||||||
|
|
@ -91,42 +91,45 @@ async def get_messages(channel_id):
|
||||||
|
|
||||||
limit = extract_limit(request, 50)
|
limit = extract_limit(request, 50)
|
||||||
|
|
||||||
where_clause = ''
|
where_clause = ""
|
||||||
before, after = query_tuple_from_args(request.args, limit)
|
before, after = query_tuple_from_args(request.args, limit)
|
||||||
|
|
||||||
if before:
|
if before:
|
||||||
where_clause += f'AND id < {before}'
|
where_clause += f"AND id < {before}"
|
||||||
|
|
||||||
if after:
|
if after:
|
||||||
where_clause += f'AND id > {after}'
|
where_clause += f"AND id > {after}"
|
||||||
|
|
||||||
message_ids = await app.db.fetch(f"""
|
message_ids = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE channel_id = $1 {where_clause}
|
WHERE channel_id = $1 {where_clause}
|
||||||
ORDER BY id DESC
|
ORDER BY id DESC
|
||||||
LIMIT {limit}
|
LIMIT {limit}
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for message_id in message_ids:
|
for message_id in message_ids:
|
||||||
msg = await app.storage.get_message(message_id['id'], user_id)
|
msg = await app.storage.get_message(message_id["id"], user_id)
|
||||||
|
|
||||||
if msg is None:
|
if msg is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
result.append(msg)
|
result.append(msg)
|
||||||
|
|
||||||
log.info('Fetched {} messages', len(result))
|
log.info("Fetched {} messages", len(result))
|
||||||
return jsonify(result)
|
return jsonify(result)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
|
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["GET"])
|
||||||
async def get_single_message(channel_id, message_id):
|
async def get_single_message(channel_id, message_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
await channel_perm_check(user_id, channel_id, 'read_history')
|
await channel_perm_check(user_id, channel_id, "read_history")
|
||||||
|
|
||||||
message = await app.storage.get_message(message_id, user_id)
|
message = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
||||||
|
|
@ -142,11 +145,15 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
||||||
|
|
||||||
# check the other party's dm_channel_state
|
# check the other party's dm_channel_state
|
||||||
|
|
||||||
dm_state = await app.db.fetchval("""
|
dm_state = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT dm_id
|
SELECT dm_id
|
||||||
FROM dm_channel_state
|
FROM dm_channel_state
|
||||||
WHERE user_id = $1 AND dm_id = $2
|
WHERE user_id = $1 AND dm_id = $2
|
||||||
""", peer_id, channel_id)
|
""",
|
||||||
|
peer_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
if dm_state:
|
if dm_state:
|
||||||
# the peer already has the channel
|
# the peer already has the channel
|
||||||
|
|
@ -157,18 +164,19 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
||||||
|
|
||||||
# dispatch CHANNEL_CREATE so the client knows which
|
# dispatch CHANNEL_CREATE so the client knows which
|
||||||
# channel the future event is about
|
# channel the future event is about
|
||||||
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
await app.dispatcher.dispatch_user(peer_id, "CHANNEL_CREATE", dm_chan)
|
||||||
|
|
||||||
# subscribe the peer to the channel
|
# subscribe the peer to the channel
|
||||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||||
|
|
||||||
# insert it on dm_channel_state so the client
|
# insert it on dm_channel_state so the client
|
||||||
# is subscribed on the future
|
# is subscribed on the future
|
||||||
await try_dm_state(peer_id, channel_id)
|
await try_dm_state(peer_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
async def create_message(channel_id: int, actual_guild_id: int,
|
async def create_message(
|
||||||
author_id: int, data: dict) -> int:
|
channel_id: int, actual_guild_id: int, author_id: int, data: dict
|
||||||
|
) -> int:
|
||||||
message_id = get_snowflake()
|
message_id = get_snowflake()
|
||||||
|
|
||||||
async with app.db.acquire() as conn:
|
async with app.db.acquire() as conn:
|
||||||
|
|
@ -185,32 +193,32 @@ async def create_message(channel_id: int, actual_guild_id: int,
|
||||||
channel_id,
|
channel_id,
|
||||||
actual_guild_id,
|
actual_guild_id,
|
||||||
author_id,
|
author_id,
|
||||||
data['content'],
|
data["content"],
|
||||||
|
data["tts"],
|
||||||
data['tts'],
|
data["everyone_mention"],
|
||||||
data['everyone_mention'],
|
data["nonce"],
|
||||||
|
|
||||||
data['nonce'],
|
|
||||||
MessageType.DEFAULT.value,
|
MessageType.DEFAULT.value,
|
||||||
data.get('embeds') or []
|
data.get("embeds") or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
async def msg_guild_text_mentions(payload: dict, guild_id: int,
|
|
||||||
mentions_everyone: bool, mentions_here: bool):
|
async def msg_guild_text_mentions(
|
||||||
|
payload: dict, guild_id: int, mentions_everyone: bool, mentions_here: bool
|
||||||
|
):
|
||||||
"""Calculates mention data side-effects."""
|
"""Calculates mention data side-effects."""
|
||||||
channel_id = int(payload['channel_id'])
|
channel_id = int(payload["channel_id"])
|
||||||
|
|
||||||
# calculate the user ids we'll bump the mention count for
|
# calculate the user ids we'll bump the mention count for
|
||||||
uids = set()
|
uids = set()
|
||||||
|
|
||||||
# first is extracting user mentions
|
# first is extracting user mentions
|
||||||
for mention in payload['mentions']:
|
for mention in payload["mentions"]:
|
||||||
uids.add(int(mention['id']))
|
uids.add(int(mention["id"]))
|
||||||
|
|
||||||
# then role mentions
|
# then role mentions
|
||||||
for role_mention in payload['mention_roles']:
|
for role_mention in payload["mention_roles"]:
|
||||||
role_id = int(role_mention)
|
role_id = int(role_mention)
|
||||||
member_ids = await app.storage.get_role_members(role_id)
|
member_ids = await app.storage.get_role_members(role_id)
|
||||||
|
|
||||||
|
|
@ -223,11 +231,14 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
|
||||||
if mentions_here:
|
if mentions_here:
|
||||||
uids = set()
|
uids = set()
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE user_read_state
|
UPDATE user_read_state
|
||||||
SET mention_count = mention_count + 1
|
SET mention_count = mention_count + 1
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
# at-here updates the read state
|
# at-here updates the read state
|
||||||
# for all users, including the ones
|
# for all users, including the ones
|
||||||
|
|
@ -238,19 +249,26 @@ async def msg_guild_text_mentions(payload: dict, guild_id: int,
|
||||||
|
|
||||||
member_ids = await app.storage.get_member_ids(guild_id)
|
member_ids = await app.storage.get_member_ids(guild_id)
|
||||||
|
|
||||||
await app.db.executemany("""
|
await app.db.executemany(
|
||||||
|
"""
|
||||||
UPDATE user_read_state
|
UPDATE user_read_state
|
||||||
SET mention_count = mention_count + 1
|
SET mention_count = mention_count + 1
|
||||||
WHERE channel_id = $1 AND user_id = $2
|
WHERE channel_id = $1 AND user_id = $2
|
||||||
""", [(channel_id, uid) for uid in member_ids])
|
""",
|
||||||
|
[(channel_id, uid) for uid in member_ids],
|
||||||
|
)
|
||||||
|
|
||||||
for user_id in uids:
|
for user_id in uids:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE user_read_state
|
UPDATE user_read_state
|
||||||
SET mention_count = mention_count + 1
|
SET mention_count = mention_count + 1
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
AND channel_id = $2
|
AND channel_id = $2
|
||||||
""", user_id, channel_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def msg_create_request() -> tuple:
|
async def msg_create_request() -> tuple:
|
||||||
|
|
@ -264,12 +282,12 @@ async def msg_create_request() -> tuple:
|
||||||
|
|
||||||
# NOTE: embed isn't set on form data
|
# NOTE: embed isn't set on form data
|
||||||
json_from_form = {
|
json_from_form = {
|
||||||
'content': form.get('content', ''),
|
"content": form.get("content", ""),
|
||||||
'nonce': form.get('nonce', '0'),
|
"nonce": form.get("nonce", "0"),
|
||||||
'tts': json.loads(form.get('tts', 'false')),
|
"tts": json.loads(form.get("tts", "false")),
|
||||||
}
|
}
|
||||||
|
|
||||||
payload_json = json.loads(form.get('payload_json', '{}'))
|
payload_json = json.loads(form.get("payload_json", "{}"))
|
||||||
|
|
||||||
json_from_form.update(request_json)
|
json_from_form.update(request_json)
|
||||||
json_from_form.update(payload_json)
|
json_from_form.update(payload_json)
|
||||||
|
|
@ -283,20 +301,19 @@ async def msg_create_request() -> tuple:
|
||||||
|
|
||||||
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
def msg_create_check_content(payload: dict, files: list, *, use_embeds=False):
|
||||||
"""Check if there is actually any content being sent to us."""
|
"""Check if there is actually any content being sent to us."""
|
||||||
has_content = bool(payload.get('content', ''))
|
has_content = bool(payload.get("content", ""))
|
||||||
has_files = len(files) > 0
|
has_files = len(files) > 0
|
||||||
|
|
||||||
embed_field = 'embeds' if use_embeds else 'embed'
|
embed_field = "embeds" if use_embeds else "embed"
|
||||||
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
has_embed = embed_field in payload and payload.get(embed_field) is not None
|
||||||
|
|
||||||
has_total_content = has_content or has_embed or has_files
|
has_total_content = has_content or has_embed or has_files
|
||||||
|
|
||||||
if not has_total_content:
|
if not has_total_content:
|
||||||
raise BadRequest('No content has been provided.')
|
raise BadRequest("No content has been provided.")
|
||||||
|
|
||||||
|
|
||||||
async def msg_add_attachment(message_id: int, channel_id: int,
|
async def msg_add_attachment(message_id: int, channel_id: int, attachment_file) -> int:
|
||||||
attachment_file) -> int:
|
|
||||||
"""Add an attachment to a message.
|
"""Add an attachment to a message.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
|
@ -318,7 +335,7 @@ async def msg_add_attachment(message_id: int, channel_id: int,
|
||||||
|
|
||||||
# understand file info
|
# understand file info
|
||||||
mime = attachment_file.mimetype
|
mime = attachment_file.mimetype
|
||||||
is_image = mime.startswith('image/')
|
is_image = mime.startswith("image/")
|
||||||
|
|
||||||
img_width, img_height = None, None
|
img_width, img_height = None, None
|
||||||
|
|
||||||
|
|
@ -346,17 +363,22 @@ async def msg_add_attachment(message_id: int, channel_id: int,
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
""",
|
""",
|
||||||
attachment_id, channel_id, message_id,
|
attachment_id,
|
||||||
filename, file_size,
|
channel_id,
|
||||||
is_image, img_width, img_height)
|
message_id,
|
||||||
|
filename,
|
||||||
|
file_size,
|
||||||
|
is_image,
|
||||||
|
img_width,
|
||||||
|
img_height,
|
||||||
|
)
|
||||||
|
|
||||||
ext = filename.split('.')[-1]
|
ext = filename.split(".")[-1]
|
||||||
|
|
||||||
with open(f'attachments/{attachment_id}.{ext}', 'wb') as attach_file:
|
with open(f"attachments/{attachment_id}.{ext}", "wb") as attach_file:
|
||||||
attach_file.write(attachment_file.stream.read())
|
attach_file.write(attachment_file.stream.read())
|
||||||
|
|
||||||
log.debug('written {} bytes for attachment id {}',
|
log.debug("written {} bytes for attachment id {}", file_size, attachment_id)
|
||||||
file_size, attachment_id)
|
|
||||||
|
|
||||||
return attachment_id
|
return attachment_id
|
||||||
|
|
||||||
|
|
@ -364,12 +386,12 @@ async def msg_add_attachment(message_id: int, channel_id: int,
|
||||||
async def _spawn_embed(app_, payload, **kwargs):
|
async def _spawn_embed(app_, payload, **kwargs):
|
||||||
app_.sched.spawn(
|
app_.sched.spawn(
|
||||||
process_url_embed(
|
process_url_embed(
|
||||||
app_.config, app_.storage, app_.dispatcher, app_.session,
|
app_.config, app_.storage, app_.dispatcher, app_.session, payload, **kwargs
|
||||||
payload, **kwargs)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
@bp.route("/<int:channel_id>/messages", methods=["POST"])
|
||||||
async def _create_message(channel_id):
|
async def _create_message(channel_id):
|
||||||
"""Create a message."""
|
"""Create a message."""
|
||||||
|
|
||||||
|
|
@ -379,7 +401,7 @@ async def _create_message(channel_id):
|
||||||
actual_guild_id = None
|
actual_guild_id = None
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
await channel_perm_check(user_id, channel_id, 'send_messages')
|
await channel_perm_check(user_id, channel_id, "send_messages")
|
||||||
actual_guild_id = guild_id
|
actual_guild_id = guild_id
|
||||||
|
|
||||||
payload_json, files = await msg_create_request()
|
payload_json, files = await msg_create_request()
|
||||||
|
|
@ -394,29 +416,31 @@ async def _create_message(channel_id):
|
||||||
await dm_pre_check(user_id, channel_id, guild_id)
|
await dm_pre_check(user_id, channel_id, guild_id)
|
||||||
|
|
||||||
can_everyone = await channel_perm_check(
|
can_everyone = await channel_perm_check(
|
||||||
user_id, channel_id, 'mention_everyone', False
|
user_id, channel_id, "mention_everyone", False
|
||||||
)
|
)
|
||||||
|
|
||||||
mentions_everyone = ('@everyone' in j['content']) and can_everyone
|
mentions_everyone = ("@everyone" in j["content"]) and can_everyone
|
||||||
mentions_here = ('@here' in j['content']) and can_everyone
|
mentions_here = ("@here" in j["content"]) and can_everyone
|
||||||
|
|
||||||
is_tts = (j.get('tts', False) and
|
is_tts = j.get("tts", False) and await channel_perm_check(
|
||||||
await channel_perm_check(
|
user_id, channel_id, "send_tts_messages", False
|
||||||
user_id, channel_id, 'send_tts_messages', False
|
)
|
||||||
))
|
|
||||||
|
|
||||||
message_id = await create_message(
|
message_id = await create_message(
|
||||||
channel_id, actual_guild_id, user_id, {
|
channel_id,
|
||||||
'content': j['content'],
|
actual_guild_id,
|
||||||
'tts': is_tts,
|
user_id,
|
||||||
'nonce': int(j.get('nonce', 0)),
|
{
|
||||||
'everyone_mention': mentions_everyone or mentions_here,
|
"content": j["content"],
|
||||||
|
"tts": is_tts,
|
||||||
|
"nonce": int(j.get("nonce", 0)),
|
||||||
|
"everyone_mention": mentions_everyone or mentions_here,
|
||||||
# fill_embed takes care of filling proxy and width/height
|
# fill_embed takes care of filling proxy and width/height
|
||||||
'embeds': ([await fill_embed(j['embed'])]
|
"embeds": (
|
||||||
if j.get('embed') is not None
|
[await fill_embed(j["embed"])] if j.get("embed") is not None else []
|
||||||
else []),
|
),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# for each file given, we add it as an attachment
|
# for each file given, we add it as an attachment
|
||||||
for pre_attachment in files:
|
for pre_attachment in files:
|
||||||
|
|
@ -429,8 +453,7 @@ async def _create_message(channel_id):
|
||||||
await _dm_pre_dispatch(channel_id, user_id)
|
await _dm_pre_dispatch(channel_id, user_id)
|
||||||
await _dm_pre_dispatch(channel_id, guild_id)
|
await _dm_pre_dispatch(channel_id, guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||||
'MESSAGE_CREATE', payload)
|
|
||||||
|
|
||||||
# spawn url processor for embedding of images
|
# spawn url processor for embedding of images
|
||||||
perms = await get_permissions(user_id, channel_id)
|
perms = await get_permissions(user_id, channel_id)
|
||||||
|
|
@ -438,54 +461,71 @@ async def _create_message(channel_id):
|
||||||
await _spawn_embed(app, payload)
|
await _spawn_embed(app, payload)
|
||||||
|
|
||||||
# update read state for the author
|
# update read state for the author
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE user_read_state
|
UPDATE user_read_state
|
||||||
SET last_message_id = $1
|
SET last_message_id = $1
|
||||||
WHERE channel_id = $2 AND user_id = $3
|
WHERE channel_id = $2 AND user_id = $3
|
||||||
""", message_id, channel_id, user_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if ctype == ChannelType.GUILD_TEXT:
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
await msg_guild_text_mentions(
|
await msg_guild_text_mentions(
|
||||||
payload, guild_id, mentions_everyone, mentions_here)
|
payload, guild_id, mentions_everyone, mentions_here
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(payload)
|
return jsonify(payload)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['PATCH'])
|
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["PATCH"])
|
||||||
async def edit_message(channel_id, message_id):
|
async def edit_message(channel_id, message_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, _guild_id = await channel_check(user_id, channel_id)
|
_ctype, _guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
author_id = await app.db.fetchval("""
|
author_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT author_id FROM messages
|
SELECT author_id FROM messages
|
||||||
WHERE messages.id = $1
|
WHERE messages.id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not author_id == user_id:
|
if not author_id == user_id:
|
||||||
raise Forbidden('You can not edit this message')
|
raise Forbidden("You can not edit this message")
|
||||||
|
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
updated = 'content' in j or 'embed' in j
|
updated = "content" in j or "embed" in j
|
||||||
old_message = await app.storage.get_message(message_id)
|
old_message = await app.storage.get_message(message_id)
|
||||||
|
|
||||||
if 'content' in j:
|
if "content" in j:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE messages
|
UPDATE messages
|
||||||
SET content=$1
|
SET content=$1
|
||||||
WHERE messages.id = $2
|
WHERE messages.id = $2
|
||||||
""", j['content'], message_id)
|
""",
|
||||||
|
j["content"],
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'embed' in j:
|
if "embed" in j:
|
||||||
embeds = [await fill_embed(j['embed'])]
|
embeds = [await fill_embed(j["embed"])]
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE messages
|
UPDATE messages
|
||||||
SET embeds=$1
|
SET embeds=$1
|
||||||
WHERE messages.id = $2
|
WHERE messages.id = $2
|
||||||
""", embeds, message_id)
|
""",
|
||||||
|
embeds,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
# do not spawn process_url_embed since we already have embeds.
|
# do not spawn process_url_embed since we already have embeds.
|
||||||
elif 'content' in j:
|
elif "content" in j:
|
||||||
# if there weren't any embed changes BUT
|
# if there weren't any embed changes BUT
|
||||||
# we had a content change, we dispatch process_url_embed but with
|
# we had a content change, we dispatch process_url_embed but with
|
||||||
# an artificial delay.
|
# an artificial delay.
|
||||||
|
|
@ -495,46 +535,55 @@ async def edit_message(channel_id, message_id):
|
||||||
# BEFORE the MESSAGE_UPDATE with the new embeds (based on content)
|
# BEFORE the MESSAGE_UPDATE with the new embeds (based on content)
|
||||||
perms = await get_permissions(user_id, channel_id)
|
perms = await get_permissions(user_id, channel_id)
|
||||||
if perms.bits.embed_links:
|
if perms.bits.embed_links:
|
||||||
await _spawn_embed(app, {
|
await _spawn_embed(
|
||||||
'id': message_id,
|
app,
|
||||||
'channel_id': channel_id,
|
{
|
||||||
'content': j['content'],
|
"id": message_id,
|
||||||
'embeds': old_message['embeds']
|
"channel_id": channel_id,
|
||||||
}, delay=0.2)
|
"content": j["content"],
|
||||||
|
"embeds": old_message["embeds"],
|
||||||
|
},
|
||||||
|
delay=0.2,
|
||||||
|
)
|
||||||
|
|
||||||
# only set new timestamp upon actual update
|
# only set new timestamp upon actual update
|
||||||
if updated:
|
if updated:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE messages
|
UPDATE messages
|
||||||
SET edited_at = (now() at time zone 'utc')
|
SET edited_at = (now() at time zone 'utc')
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
message = await app.storage.get_message(message_id, user_id)
|
message = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
||||||
# only dispatch MESSAGE_UPDATE if any update
|
# only dispatch MESSAGE_UPDATE if any update
|
||||||
# actually happened
|
# actually happened
|
||||||
if updated:
|
if updated:
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", message)
|
||||||
'MESSAGE_UPDATE', message)
|
|
||||||
|
|
||||||
return jsonify(message)
|
return jsonify(message)
|
||||||
|
|
||||||
|
|
||||||
async def _del_msg_fkeys(message_id: int):
|
async def _del_msg_fkeys(message_id: int):
|
||||||
attachs = await app.db.fetch("""
|
attachs = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id FROM attachments
|
SELECT id FROM attachments
|
||||||
WHERE message_id = $1
|
WHERE message_id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
attachs = [r['id'] for r in attachs]
|
attachs = [r["id"] for r in attachs]
|
||||||
|
|
||||||
attachments = Path('./attachments')
|
attachments = Path("./attachments")
|
||||||
for attach_id in attachs:
|
for attach_id in attachs:
|
||||||
# anything starting with the given attachment shall be
|
# anything starting with the given attachment shall be
|
||||||
# deleted, because there may be resizes of the original
|
# deleted, because there may be resizes of the original
|
||||||
# attachment laying around.
|
# attachment laying around.
|
||||||
for filepath in attachments.glob(f'{attach_id}*'):
|
for filepath in attachments.glob(f"{attach_id}*"):
|
||||||
try_unlink(filepath)
|
try_unlink(filepath)
|
||||||
|
|
||||||
# after trying to delete all available attachments, delete
|
# after trying to delete all available attachments, delete
|
||||||
|
|
@ -542,51 +591,64 @@ async def _del_msg_fkeys(message_id: int):
|
||||||
|
|
||||||
# take the chance and delete all the data from the other tables too!
|
# take the chance and delete all the data from the other tables too!
|
||||||
|
|
||||||
tables = ['attachments', 'message_webhook_info',
|
tables = [
|
||||||
'message_reactions', 'channel_pins']
|
"attachments",
|
||||||
|
"message_webhook_info",
|
||||||
|
"message_reactions",
|
||||||
|
"channel_pins",
|
||||||
|
]
|
||||||
|
|
||||||
for table in tables:
|
for table in tables:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
DELETE FROM {table}
|
DELETE FROM {table}
|
||||||
WHERE message_id = $1
|
WHERE message_id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['DELETE'])
|
@bp.route("/<int:channel_id>/messages/<int:message_id>", methods=["DELETE"])
|
||||||
async def delete_message(channel_id, message_id):
|
async def delete_message(channel_id, message_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
author_id = await app.db.fetchval("""
|
author_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT author_id FROM messages
|
SELECT author_id FROM messages
|
||||||
WHERE messages.id = $1
|
WHERE messages.id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
by_perm = await channel_perm_check(
|
|
||||||
user_id, channel_id, 'manage_messages', False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
by_perm = await channel_perm_check(user_id, channel_id, "manage_messages", False)
|
||||||
|
|
||||||
by_ownership = author_id == user_id
|
by_ownership = author_id == user_id
|
||||||
|
|
||||||
can_delete = by_perm or by_ownership
|
can_delete = by_perm or by_ownership
|
||||||
if not can_delete:
|
if not can_delete:
|
||||||
raise Forbidden('You can not delete this message')
|
raise Forbidden("You can not delete this message")
|
||||||
|
|
||||||
await _del_msg_fkeys(message_id)
|
await _del_msg_fkeys(message_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM messages
|
DELETE FROM messages
|
||||||
WHERE messages.id = $1
|
WHERE messages.id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'channel', channel_id,
|
"channel",
|
||||||
'MESSAGE_DELETE', {
|
channel_id,
|
||||||
'id': str(message_id),
|
"MESSAGE_DELETE",
|
||||||
'channel_id': str(channel_id),
|
{
|
||||||
|
"id": str(message_id),
|
||||||
|
"channel_id": str(channel_id),
|
||||||
# for lazy guilds
|
# for lazy guilds
|
||||||
'guild_id': str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -28,28 +28,32 @@ from litecord.system_messages import send_sys_message
|
||||||
from litecord.enums import MessageType, SYS_MESSAGES
|
from litecord.enums import MessageType, SYS_MESSAGES
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
|
|
||||||
bp = Blueprint('channel_pins', __name__)
|
bp = Blueprint("channel_pins", __name__)
|
||||||
|
|
||||||
|
|
||||||
class SysMsgInvalidAction(BadRequest):
|
class SysMsgInvalidAction(BadRequest):
|
||||||
"""Invalid action on a system message."""
|
"""Invalid action on a system message."""
|
||||||
|
|
||||||
error_code = 50021
|
error_code = 50021
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/pins', methods=['GET'])
|
@bp.route("/<int:channel_id>/pins", methods=["GET"])
|
||||||
async def get_pins(channel_id):
|
async def get_pins(channel_id):
|
||||||
"""Get the pins for a channel"""
|
"""Get the pins for a channel"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
ids = await app.db.fetch("""
|
ids = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT message_id
|
SELECT message_id
|
||||||
FROM channel_pins
|
FROM channel_pins
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
ORDER BY message_id DESC
|
ORDER BY message_id DESC
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
ids = [r['message_id'] for r in ids]
|
ids = [r["message_id"] for r in ids]
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for message_id in ids:
|
for message_id in ids:
|
||||||
|
|
@ -60,80 +64,96 @@ async def get_pins(channel_id):
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['PUT'])
|
@bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["PUT"])
|
||||||
async def add_pin(channel_id, message_id):
|
async def add_pin(channel_id, message_id):
|
||||||
"""Add a pin to a channel"""
|
"""Add a pin to a channel"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||||
|
|
||||||
mtype = await app.db.fetchval("""
|
mtype = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT message_type
|
SELECT message_type
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
if mtype in SYS_MESSAGES:
|
if mtype in SYS_MESSAGES:
|
||||||
raise SysMsgInvalidAction(
|
raise SysMsgInvalidAction("Cannot execute action on a system message")
|
||||||
'Cannot execute action on a system message')
|
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO channel_pins (channel_id, message_id)
|
INSERT INTO channel_pins (channel_id, message_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", channel_id, message_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
row = await app.db.fetchrow("""
|
row = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT message_id
|
SELECT message_id
|
||||||
FROM channel_pins
|
FROM channel_pins
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
ORDER BY message_id ASC
|
ORDER BY message_id ASC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
timestamp = snowflake_datetime(row['message_id'])
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
|
||||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE',
|
|
||||||
{
|
|
||||||
'channel_id': str(channel_id),
|
|
||||||
'last_pin_timestamp': timestamp_(timestamp)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await send_sys_message(app, channel_id,
|
timestamp = snowflake_datetime(row["message_id"])
|
||||||
MessageType.CHANNEL_PINNED_MESSAGE,
|
|
||||||
message_id, user_id)
|
|
||||||
|
|
||||||
return '', 204
|
await app.dispatcher.dispatch(
|
||||||
|
"channel",
|
||||||
|
channel_id,
|
||||||
|
"CHANNEL_PINS_UPDATE",
|
||||||
|
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp_(timestamp)},
|
||||||
|
)
|
||||||
|
|
||||||
|
await send_sys_message(
|
||||||
|
app, channel_id, MessageType.CHANNEL_PINNED_MESSAGE, message_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/pins/<int:message_id>', methods=['DELETE'])
|
@bp.route("/<int:channel_id>/pins/<int:message_id>", methods=["DELETE"])
|
||||||
async def delete_pin(channel_id, message_id):
|
async def delete_pin(channel_id, message_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM channel_pins
|
DELETE FROM channel_pins
|
||||||
WHERE channel_id = $1 AND message_id = $2
|
WHERE channel_id = $1 AND message_id = $2
|
||||||
""", channel_id, message_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
row = await app.db.fetchrow("""
|
row = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT message_id
|
SELECT message_id
|
||||||
FROM channel_pins
|
FROM channel_pins
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
ORDER BY message_id ASC
|
ORDER BY message_id ASC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
timestamp = snowflake_datetime(row['message_id'])
|
timestamp = snowflake_datetime(row["message_id"])
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'channel', channel_id, 'CHANNEL_PINS_UPDATE', {
|
"channel",
|
||||||
'channel_id': str(channel_id),
|
channel_id,
|
||||||
'last_pin_timestamp': timestamp.isoformat()
|
"CHANNEL_PINS_UPDATE",
|
||||||
})
|
{"channel_id": str(channel_id), "last_pin_timestamp": timestamp.isoformat()},
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -26,17 +26,15 @@ from logbook import Logger
|
||||||
from litecord.utils import async_map
|
from litecord.utils import async_map
|
||||||
from litecord.blueprints.auth import token_check
|
from litecord.blueprints.auth import token_check
|
||||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
from litecord.blueprints.channel.messages import (
|
from litecord.blueprints.channel.messages import query_tuple_from_args, extract_limit
|
||||||
query_tuple_from_args, extract_limit
|
|
||||||
)
|
|
||||||
|
|
||||||
from litecord.enums import GUILD_CHANS
|
from litecord.enums import GUILD_CHANS
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channel_reactions', __name__)
|
bp = Blueprint("channel_reactions", __name__)
|
||||||
|
|
||||||
BASEPATH = '/<int:channel_id>/messages/<int:message_id>/reactions'
|
BASEPATH = "/<int:channel_id>/messages/<int:message_id>/reactions"
|
||||||
|
|
||||||
|
|
||||||
class EmojiType(IntEnum):
|
class EmojiType(IntEnum):
|
||||||
|
|
@ -51,16 +49,14 @@ def emoji_info_from_str(emoji: str) -> tuple:
|
||||||
# unicode emoji just have the raw unicode.
|
# unicode emoji just have the raw unicode.
|
||||||
|
|
||||||
# try checking if the emoji is custom or unicode
|
# try checking if the emoji is custom or unicode
|
||||||
emoji_type = 0 if ':' in emoji else 1
|
emoji_type = 0 if ":" in emoji else 1
|
||||||
emoji_type = EmojiType(emoji_type)
|
emoji_type = EmojiType(emoji_type)
|
||||||
|
|
||||||
# extract the emoji id OR the unicode value of the emoji
|
# extract the emoji id OR the unicode value of the emoji
|
||||||
# depending if it is custom or not
|
# depending if it is custom or not
|
||||||
emoji_id = (int(emoji.split(':')[1])
|
emoji_id = int(emoji.split(":")[1]) if emoji_type == EmojiType.CUSTOM else emoji
|
||||||
if emoji_type == EmojiType.CUSTOM
|
|
||||||
else emoji)
|
|
||||||
|
|
||||||
emoji_name = emoji.split(':')[0]
|
emoji_name = emoji.split(":")[0]
|
||||||
|
|
||||||
return emoji_type, emoji_id, emoji_name
|
return emoji_type, emoji_id, emoji_name
|
||||||
|
|
||||||
|
|
@ -68,27 +64,27 @@ def emoji_info_from_str(emoji: str) -> tuple:
|
||||||
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
|
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
|
||||||
print(emoji_type, emoji_id, emoji_name)
|
print(emoji_type, emoji_id, emoji_name)
|
||||||
return {
|
return {
|
||||||
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
|
"id": None if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||||
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
|
"name": emoji_name if emoji_type == EmojiType.UNICODE else emoji_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _make_payload(user_id, channel_id, message_id, partial):
|
def _make_payload(user_id, channel_id, message_id, partial):
|
||||||
return {
|
return {
|
||||||
'user_id': str(user_id),
|
"user_id": str(user_id),
|
||||||
'channel_id': str(channel_id),
|
"channel_id": str(channel_id),
|
||||||
'message_id': str(message_id),
|
"message_id": str(message_id),
|
||||||
'emoji': partial
|
"emoji": partial,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['PUT'])
|
@bp.route(f"{BASEPATH}/<emoji>/@me", methods=["PUT"])
|
||||||
async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
||||||
"""Put a reaction."""
|
"""Put a reaction."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
await channel_perm_check(user_id, channel_id, 'read_history')
|
await channel_perm_check(user_id, channel_id, "read_history")
|
||||||
|
|
||||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||||
|
|
||||||
|
|
@ -97,52 +93,64 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
|
||||||
|
|
||||||
# ADD_REACTIONS is only checked when this is the first
|
# ADD_REACTIONS is only checked when this is the first
|
||||||
# reaction in a message.
|
# reaction in a message.
|
||||||
reaction_count = await app.db.fetchval("""
|
reaction_count = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM message_reactions
|
FROM message_reactions
|
||||||
WHERE message_id = $1
|
WHERE message_id = $1
|
||||||
AND emoji_type = $2
|
AND emoji_type = $2
|
||||||
AND emoji_id = $3
|
AND emoji_id = $3
|
||||||
AND emoji_text = $4
|
AND emoji_text = $4
|
||||||
""", message_id, emoji_type, emoji_id, emoji_text)
|
""",
|
||||||
|
message_id,
|
||||||
|
emoji_type,
|
||||||
|
emoji_id,
|
||||||
|
emoji_text,
|
||||||
|
)
|
||||||
|
|
||||||
if reaction_count == 0:
|
if reaction_count == 0:
|
||||||
await channel_perm_check(user_id, channel_id, 'add_reactions')
|
await channel_perm_check(user_id, channel_id, "add_reactions")
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO message_reactions (message_id, user_id,
|
INSERT INTO message_reactions (message_id, user_id,
|
||||||
emoji_type, emoji_id, emoji_text)
|
emoji_type, emoji_id, emoji_text)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
""", message_id, user_id, emoji_type,
|
""",
|
||||||
|
message_id,
|
||||||
|
user_id,
|
||||||
|
emoji_type,
|
||||||
# if it is custom, we put the emoji_id on emoji_id
|
# if it is custom, we put the emoji_id on emoji_id
|
||||||
# column, if it isn't, we put it on emoji_text
|
# column, if it isn't, we put it on emoji_text
|
||||||
# column.
|
# column.
|
||||||
emoji_id, emoji_text
|
emoji_id,
|
||||||
|
emoji_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||||
payload = _make_payload(user_id, channel_id, message_id, partial)
|
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
payload['guild_id'] = str(guild_id)
|
payload["guild_id"] = str(guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'channel', channel_id, 'MESSAGE_REACTION_ADD', payload)
|
"channel", channel_id, "MESSAGE_REACTION_ADD", payload
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
|
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
|
||||||
"""Extract SQL clauses to search for specific emoji
|
"""Extract SQL clauses to search for specific emoji
|
||||||
in the message_reactions table."""
|
in the message_reactions table."""
|
||||||
param = f'${param}'
|
param = f"${param}"
|
||||||
|
|
||||||
# know which column to filter with
|
# know which column to filter with
|
||||||
where_ext = (f'AND emoji_id = {param}'
|
where_ext = (
|
||||||
if emoji_type == EmojiType.CUSTOM else
|
f"AND emoji_id = {param}"
|
||||||
f'AND emoji_text = {param}')
|
if emoji_type == EmojiType.CUSTOM
|
||||||
|
else f"AND emoji_text = {param}"
|
||||||
|
)
|
||||||
|
|
||||||
# which emoji to remove (custom or unicode)
|
# which emoji to remove (custom or unicode)
|
||||||
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
|
main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name
|
||||||
|
|
@ -157,8 +165,7 @@ def _emoji_sql_simple(emoji: str, param=4):
|
||||||
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
|
||||||
|
|
||||||
|
|
||||||
async def remove_reaction(channel_id: int, message_id: int,
|
async def remove_reaction(channel_id: int, message_id: int, user_id: int, emoji: str):
|
||||||
user_id: int, emoji: str):
|
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
|
||||||
|
|
@ -171,39 +178,45 @@ async def remove_reaction(channel_id: int, message_id: int,
|
||||||
AND user_id = $2
|
AND user_id = $2
|
||||||
AND emoji_type = $3
|
AND emoji_type = $3
|
||||||
{where_ext}
|
{where_ext}
|
||||||
""", message_id, user_id, emoji_type, main_emoji)
|
""",
|
||||||
|
message_id,
|
||||||
|
user_id,
|
||||||
|
emoji_type,
|
||||||
|
main_emoji,
|
||||||
|
)
|
||||||
|
|
||||||
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
|
||||||
payload = _make_payload(user_id, channel_id, message_id, partial)
|
payload = _make_payload(user_id, channel_id, message_id, partial)
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
payload['guild_id'] = str(guild_id)
|
payload["guild_id"] = str(guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload)
|
"channel", channel_id, "MESSAGE_REACTION_REMOVE", payload
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route(f'{BASEPATH}/<emoji>/@me', methods=['DELETE'])
|
@bp.route(f"{BASEPATH}/<emoji>/@me", methods=["DELETE"])
|
||||||
async def remove_own_reaction(channel_id, message_id, emoji):
|
async def remove_own_reaction(channel_id, message_id, emoji):
|
||||||
"""Remove a reaction."""
|
"""Remove a reaction."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await remove_reaction(channel_id, message_id, user_id, emoji)
|
await remove_reaction(channel_id, message_id, user_id, emoji)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route(f'{BASEPATH}/<emoji>/<int:other_id>', methods=['DELETE'])
|
@bp.route(f"{BASEPATH}/<emoji>/<int:other_id>", methods=["DELETE"])
|
||||||
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
|
async def remove_user_reaction(channel_id, message_id, emoji, other_id):
|
||||||
"""Remove a reaction made by another user."""
|
"""Remove a reaction made by another user."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||||
|
|
||||||
await remove_reaction(channel_id, message_id, other_id, emoji)
|
await remove_reaction(channel_id, message_id, other_id, emoji)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route(f'{BASEPATH}/<emoji>', methods=['GET'])
|
@bp.route(f"{BASEPATH}/<emoji>", methods=["GET"])
|
||||||
async def list_users_reaction(channel_id, message_id, emoji):
|
async def list_users_reaction(channel_id, message_id, emoji):
|
||||||
"""Get the list of all users who reacted with a certain emoji."""
|
"""Get the list of all users who reacted with a certain emoji."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -215,42 +228,49 @@ async def list_users_reaction(channel_id, message_id, emoji):
|
||||||
limit = extract_limit(request, 25)
|
limit = extract_limit(request, 25)
|
||||||
before, after = query_tuple_from_args(request.args, limit)
|
before, after = query_tuple_from_args(request.args, limit)
|
||||||
|
|
||||||
before_clause = 'AND user_id < $2' if before else ''
|
before_clause = "AND user_id < $2" if before else ""
|
||||||
after_clause = 'AND user_id > $3' if after else ''
|
after_clause = "AND user_id > $3" if after else ""
|
||||||
|
|
||||||
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
|
where_ext, main_emoji = _emoji_sql_simple(emoji, 4)
|
||||||
|
|
||||||
rows = await app.db.fetch(f"""
|
rows = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT user_id
|
SELECT user_id
|
||||||
FROM message_reactions
|
FROM message_reactions
|
||||||
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
|
WHERE message_id = $1 {before_clause} {after_clause} {where_ext}
|
||||||
""", message_id, before, after, main_emoji)
|
""",
|
||||||
|
message_id,
|
||||||
|
before,
|
||||||
|
after,
|
||||||
|
main_emoji,
|
||||||
|
)
|
||||||
|
|
||||||
user_ids = [r['user_id'] for r in rows]
|
user_ids = [r["user_id"] for r in rows]
|
||||||
users = await async_map(app.storage.get_user, user_ids)
|
users = await async_map(app.storage.get_user, user_ids)
|
||||||
return jsonify(users)
|
return jsonify(users)
|
||||||
|
|
||||||
|
|
||||||
@bp.route(f'{BASEPATH}', methods=['DELETE'])
|
@bp.route(f"{BASEPATH}", methods=["DELETE"])
|
||||||
async def remove_all_reactions(channel_id, message_id):
|
async def remove_all_reactions(channel_id, message_id):
|
||||||
"""Remove all reactions in a message."""
|
"""Remove all reactions in a message."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM message_reactions
|
DELETE FROM message_reactions
|
||||||
WHERE message_id = $1
|
WHERE message_id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
payload = {
|
payload = {"channel_id": str(channel_id), "message_id": str(message_id)}
|
||||||
'channel_id': str(channel_id),
|
|
||||||
'message_id': str(message_id),
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
payload['guild_id'] = str(guild_id)
|
payload["guild_id"] = str(guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload)
|
"channel", channel_id, "MESSAGE_REACTION_REMOVE_ALL", payload
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -28,24 +28,26 @@ from litecord.auth import token_check
|
||||||
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
from litecord.enums import ChannelType, GUILD_CHANS, MessageType, MessageFlags
|
||||||
from litecord.errors import ChannelNotFound, Forbidden, BadRequest
|
from litecord.errors import ChannelNotFound, Forbidden, BadRequest
|
||||||
from litecord.schemas import (
|
from litecord.schemas import (
|
||||||
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE,
|
validate,
|
||||||
|
CHAN_UPDATE,
|
||||||
|
CHAN_OVERWRITE,
|
||||||
|
SEARCH_CHANNEL,
|
||||||
|
GROUP_DM_UPDATE,
|
||||||
BULK_DELETE,
|
BULK_DELETE,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
from litecord.system_messages import send_sys_message
|
from litecord.system_messages import send_sys_message
|
||||||
from litecord.blueprints.dm_channels import (
|
from litecord.blueprints.dm_channels import gdm_remove_recipient, gdm_destroy
|
||||||
gdm_remove_recipient, gdm_destroy
|
|
||||||
)
|
|
||||||
from litecord.utils import search_result_from_list
|
from litecord.utils import search_result_from_list
|
||||||
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
from litecord.embed.messages import process_url_embed, msg_update_embeds
|
||||||
from litecord.snowflake import snowflake_datetime
|
from litecord.snowflake import snowflake_datetime
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channels', __name__)
|
bp = Blueprint("channels", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['GET'])
|
@bp.route("/<int:channel_id>", methods=["GET"])
|
||||||
async def get_channel(channel_id):
|
async def get_channel(channel_id):
|
||||||
"""Get a single channel's information"""
|
"""Get a single channel's information"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -56,7 +58,7 @@ async def get_channel(channel_id):
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
if not chan:
|
if not chan:
|
||||||
raise ChannelNotFound('single channel not found')
|
raise ChannelNotFound("single channel not found")
|
||||||
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
@ -64,106 +66,129 @@ async def get_channel(channel_id):
|
||||||
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
|
||||||
"""Update a guild's channel id field to NULL,
|
"""Update a guild's channel id field to NULL,
|
||||||
if it was set to the given channel id before."""
|
if it was set to the given channel id before."""
|
||||||
return await app.db.execute(f"""
|
return await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET {field} = NULL
|
SET {field} = NULL
|
||||||
WHERE guilds.id = $1 AND {field} = $2
|
WHERE guilds.id = $1 AND {field} = $2
|
||||||
""", guild_id, channel_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
async def _update_guild_chan_text(guild_id: int, channel_id: int):
|
||||||
res_embed = await __guild_chan_sql(
|
res_embed = await __guild_chan_sql(guild_id, channel_id, "embed_channel_id")
|
||||||
guild_id, channel_id, 'embed_channel_id')
|
|
||||||
|
|
||||||
res_widget = await __guild_chan_sql(
|
res_widget = await __guild_chan_sql(guild_id, channel_id, "widget_channel_id")
|
||||||
guild_id, channel_id, 'widget_channel_id')
|
|
||||||
|
|
||||||
res_system = await __guild_chan_sql(
|
res_system = await __guild_chan_sql(guild_id, channel_id, "system_channel_id")
|
||||||
guild_id, channel_id, 'system_channel_id')
|
|
||||||
|
|
||||||
# if none of them were actually updated,
|
# if none of them were actually updated,
|
||||||
# ignore and dont dispatch anything
|
# ignore and dont dispatch anything
|
||||||
if 'UPDATE 1' not in (res_embed, res_widget, res_system):
|
if "UPDATE 1" not in (res_embed, res_widget, res_system):
|
||||||
return
|
return
|
||||||
|
|
||||||
# at least one of the fields were updated,
|
# at least one of the fields were updated,
|
||||||
# dispatch GUILD_UPDATE
|
# dispatch GUILD_UPDATE
|
||||||
guild = await app.storage.get_guild(guild_id)
|
guild = await app.storage.get_guild(guild_id)
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||||
guild_id, 'GUILD_UPDATE', guild)
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
|
||||||
res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id')
|
res = await __guild_chan_sql(guild_id, channel_id, "afk_channel_id")
|
||||||
|
|
||||||
# guild didnt update
|
# guild didnt update
|
||||||
if res == 'UPDATE 0':
|
if res == "UPDATE 0":
|
||||||
return
|
return
|
||||||
|
|
||||||
guild = await app.storage.get_guild(guild_id)
|
guild = await app.storage.get_guild(guild_id)
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||||
guild_id, 'GUILD_UPDATE', guild)
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
|
||||||
# get all channels that were childs of the category
|
# get all channels that were childs of the category
|
||||||
childs = await app.db.fetch("""
|
childs = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
WHERE guild_id = $1 AND parent_id = $2
|
WHERE guild_id = $1 AND parent_id = $2
|
||||||
""", guild_id, channel_id)
|
""",
|
||||||
childs = [c['id'] for c in childs]
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
childs = [c["id"] for c in childs]
|
||||||
|
|
||||||
# update every child channel to parent_id = NULL
|
# update every child channel to parent_id = NULL
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guild_channels
|
UPDATE guild_channels
|
||||||
SET parent_id = NULL
|
SET parent_id = NULL
|
||||||
WHERE guild_id = $1 AND parent_id = $2
|
WHERE guild_id = $1 AND parent_id = $2
|
||||||
""", guild_id, channel_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
# tell all people in the guild of the category removal
|
# tell all people in the guild of the category removal
|
||||||
for child_id in childs:
|
for child_id in childs:
|
||||||
child = await app.storage.get_channel(child_id)
|
child = await app.storage.get_channel(child_id)
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", child)
|
||||||
guild_id, 'CHANNEL_UPDATE', child
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_messages(channel_id):
|
async def delete_messages(channel_id):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM channel_pins
|
DELETE FROM channel_pins
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM user_read_state
|
DELETE FROM user_read_state
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM messages
|
DELETE FROM messages
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def guild_cleanup(channel_id):
|
async def guild_cleanup(channel_id):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM channel_overwrites
|
DELETE FROM channel_overwrites
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM invites
|
DELETE FROM invites
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM webhooks
|
DELETE FROM webhooks
|
||||||
WHERE channel_id = $1
|
WHERE channel_id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['DELETE'])
|
@bp.route("/<int:channel_id>", methods=["DELETE"])
|
||||||
async def close_channel(channel_id):
|
async def close_channel(channel_id):
|
||||||
"""Close or delete a channel."""
|
"""Close or delete a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -184,9 +209,8 @@ async def close_channel(channel_id):
|
||||||
}[ctype]
|
}[ctype]
|
||||||
|
|
||||||
main_tbl = {
|
main_tbl = {
|
||||||
ChannelType.GUILD_TEXT: 'guild_text_channels',
|
ChannelType.GUILD_TEXT: "guild_text_channels",
|
||||||
ChannelType.GUILD_VOICE: 'guild_voice_channels',
|
ChannelType.GUILD_VOICE: "guild_voice_channels",
|
||||||
|
|
||||||
# TODO: categories?
|
# TODO: categories?
|
||||||
}[ctype]
|
}[ctype]
|
||||||
|
|
||||||
|
|
@ -199,29 +223,37 @@ async def close_channel(channel_id):
|
||||||
await delete_messages(channel_id)
|
await delete_messages(channel_id)
|
||||||
await guild_cleanup(channel_id)
|
await guild_cleanup(channel_id)
|
||||||
|
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
DELETE FROM {main_tbl}
|
DELETE FROM {main_tbl}
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM guild_channels
|
DELETE FROM guild_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM channels
|
DELETE FROM channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
# clean its member list representation
|
# clean its member list representation
|
||||||
lazy_guilds = app.dispatcher.backends['lazy_guild']
|
lazy_guilds = app.dispatcher.backends["lazy_guild"]
|
||||||
lazy_guilds.remove_channel(channel_id)
|
lazy_guilds.remove_channel(channel_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_DELETE", chan)
|
||||||
guild_id, 'CHANNEL_DELETE', chan)
|
|
||||||
|
|
||||||
await app.dispatcher.remove('channel', channel_id)
|
await app.dispatcher.remove("channel", channel_id)
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
|
|
@ -231,27 +263,34 @@ async def close_channel(channel_id):
|
||||||
# instead, we close the channel for the user that is making
|
# instead, we close the channel for the user that is making
|
||||||
# the request via removing the link between them and
|
# the request via removing the link between them and
|
||||||
# the channel on dm_channel_state
|
# the channel on dm_channel_state
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM dm_channel_state
|
DELETE FROM dm_channel_state
|
||||||
WHERE user_id = $1 AND dm_id = $2
|
WHERE user_id = $1 AND dm_id = $2
|
||||||
""", user_id, channel_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
# unsubscribe
|
# unsubscribe
|
||||||
await app.dispatcher.unsub('channel', channel_id, user_id)
|
await app.dispatcher.unsub("channel", channel_id, user_id)
|
||||||
|
|
||||||
# nothing happens to the other party of the dm channel
|
# nothing happens to the other party of the dm channel
|
||||||
await app.dispatcher.dispatch_user(user_id, 'CHANNEL_DELETE', chan)
|
await app.dispatcher.dispatch_user(user_id, "CHANNEL_DELETE", chan)
|
||||||
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
if ctype == ChannelType.GROUP_DM:
|
if ctype == ChannelType.GROUP_DM:
|
||||||
await gdm_remove_recipient(channel_id, user_id)
|
await gdm_remove_recipient(channel_id, user_id)
|
||||||
|
|
||||||
gdm_count = await app.db.fetchval("""
|
gdm_count = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM group_dm_members
|
FROM group_dm_members
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
if gdm_count == 0:
|
if gdm_count == 0:
|
||||||
# destroy dm
|
# destroy dm
|
||||||
|
|
@ -261,11 +300,15 @@ async def close_channel(channel_id):
|
||||||
|
|
||||||
|
|
||||||
async def _update_pos(channel_id, pos: int):
|
async def _update_pos(channel_id, pos: int):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guild_channels
|
UPDATE guild_channels
|
||||||
SET position = $1
|
SET position = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", pos, channel_id)
|
""",
|
||||||
|
pos,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
||||||
|
|
@ -274,20 +317,19 @@ async def _mass_chan_update(guild_id, channel_ids: List[Optional[int]]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||||
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_overwrites(channel_id: int, overwrites: list):
|
async def _process_overwrites(channel_id: int, overwrites: list):
|
||||||
for overwrite in overwrites:
|
for overwrite in overwrites:
|
||||||
|
|
||||||
# 0 for member overwrite, 1 for role overwrite
|
# 0 for member overwrite, 1 for role overwrite
|
||||||
target_type = 0 if overwrite['type'] == 'member' else 1
|
target_type = 0 if overwrite["type"] == "member" else 1
|
||||||
target_role = None if target_type == 0 else overwrite['id']
|
target_role = None if target_type == 0 else overwrite["id"]
|
||||||
target_user = overwrite['id'] if target_type == 0 else None
|
target_user = overwrite["id"] if target_type == 0 else None
|
||||||
|
|
||||||
col_name = 'target_user' if target_type == 0 else 'target_role'
|
col_name = "target_user" if target_type == 0 else "target_role"
|
||||||
constraint_name = f'channel_overwrites_{col_name}_uniq'
|
constraint_name = f"channel_overwrites_{col_name}_uniq"
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
f"""
|
f"""
|
||||||
|
|
@ -301,53 +343,66 @@ async def _process_overwrites(channel_id: int, overwrites: list):
|
||||||
UPDATE
|
UPDATE
|
||||||
SET allow = $5, deny = $6
|
SET allow = $5, deny = $6
|
||||||
""",
|
""",
|
||||||
channel_id, target_type,
|
channel_id,
|
||||||
target_role, target_user,
|
target_type,
|
||||||
overwrite['allow'], overwrite['deny'])
|
target_role,
|
||||||
|
target_user,
|
||||||
|
overwrite["allow"],
|
||||||
|
overwrite["deny"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/permissions/<int:overwrite_id>', methods=['PUT'])
|
@bp.route("/<int:channel_id>/permissions/<int:overwrite_id>", methods=["PUT"])
|
||||||
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
async def put_channel_overwrite(channel_id: int, overwrite_id: int):
|
||||||
"""Insert or modify a channel overwrite."""
|
"""Insert or modify a channel overwrite."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
if ctype not in GUILD_CHANS:
|
if ctype not in GUILD_CHANS:
|
||||||
raise ChannelNotFound('Only usable for guild channels.')
|
raise ChannelNotFound("Only usable for guild channels.")
|
||||||
|
|
||||||
await channel_perm_check(user_id, guild_id, 'manage_roles')
|
await channel_perm_check(user_id, guild_id, "manage_roles")
|
||||||
|
|
||||||
j = validate(
|
j = validate(
|
||||||
# inserting a fake id on the payload so validation passes through
|
# inserting a fake id on the payload so validation passes through
|
||||||
{**await request.get_json(), **{'id': -1}},
|
{**await request.get_json(), **{"id": -1}},
|
||||||
CHAN_OVERWRITE
|
CHAN_OVERWRITE,
|
||||||
)
|
)
|
||||||
|
|
||||||
await _process_overwrites(channel_id, [{
|
await _process_overwrites(
|
||||||
'allow': j['allow'],
|
channel_id,
|
||||||
'deny': j['deny'],
|
[
|
||||||
'type': j['type'],
|
{
|
||||||
'id': overwrite_id
|
"allow": j["allow"],
|
||||||
}])
|
"deny": j["deny"],
|
||||||
|
"type": j["type"],
|
||||||
|
"id": overwrite_id,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
await _mass_chan_update(guild_id, [channel_id])
|
await _mass_chan_update(guild_id, [channel_id])
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||||
if 'name' in j:
|
if "name" in j:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guild_channels
|
UPDATE guild_channels
|
||||||
SET name = $1
|
SET name = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['name'], channel_id)
|
""",
|
||||||
|
j["name"],
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'position' in j:
|
if "position" in j:
|
||||||
channel_data = await app.storage.get_channel_data(guild_id)
|
channel_data = await app.storage.get_channel_data(guild_id)
|
||||||
|
|
||||||
chans = [None] * len(channel_data)
|
chans = [None] * len(channel_data)
|
||||||
for chandata in channel_data:
|
for chandata in channel_data:
|
||||||
chans.insert(chandata['position'], int(chandata['id']))
|
chans.insert(chandata["position"], int(chandata["id"]))
|
||||||
|
|
||||||
# are we changing to the left or to the right?
|
# are we changing to the left or to the right?
|
||||||
|
|
||||||
|
|
@ -358,7 +413,7 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||||
# channelN-1 going to the position channel2
|
# channelN-1 going to the position channel2
|
||||||
# was occupying.
|
# was occupying.
|
||||||
current_pos = chans.index(channel_id)
|
current_pos = chans.index(channel_id)
|
||||||
new_pos = j['position']
|
new_pos = j["position"]
|
||||||
|
|
||||||
# if the new position is bigger than the current one,
|
# if the new position is bigger than the current one,
|
||||||
# we're making a left shift of all the channels that are
|
# we're making a left shift of all the channels that are
|
||||||
|
|
@ -366,113 +421,136 @@ async def _update_channel_common(channel_id, guild_id: int, j: dict):
|
||||||
left_shift = new_pos > current_pos
|
left_shift = new_pos > current_pos
|
||||||
|
|
||||||
# find all channels that we'll have to shift
|
# find all channels that we'll have to shift
|
||||||
shift_block = (chans[current_pos:new_pos]
|
shift_block = (
|
||||||
if left_shift else
|
chans[current_pos:new_pos] if left_shift else chans[new_pos:current_pos]
|
||||||
chans[new_pos:current_pos]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
shift = -1 if left_shift else 1
|
shift = -1 if left_shift else 1
|
||||||
|
|
||||||
# do the shift (to the left or to the right)
|
# do the shift (to the left or to the right)
|
||||||
await app.db.executemany("""
|
await app.db.executemany(
|
||||||
|
"""
|
||||||
UPDATE guild_channels
|
UPDATE guild_channels
|
||||||
SET position = position + $1
|
SET position = position + $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", [(shift, chan_id) for chan_id in shift_block])
|
""",
|
||||||
|
[(shift, chan_id) for chan_id in shift_block],
|
||||||
|
)
|
||||||
|
|
||||||
await _mass_chan_update(guild_id, shift_block)
|
await _mass_chan_update(guild_id, shift_block)
|
||||||
|
|
||||||
# since theres now an empty slot, move current channel to it
|
# since theres now an empty slot, move current channel to it
|
||||||
await _update_pos(channel_id, new_pos)
|
await _update_pos(channel_id, new_pos)
|
||||||
|
|
||||||
if 'channel_overwrites' in j:
|
if "channel_overwrites" in j:
|
||||||
overwrites = j['channel_overwrites']
|
overwrites = j["channel_overwrites"]
|
||||||
await _process_overwrites(channel_id, overwrites)
|
await _process_overwrites(channel_id, overwrites)
|
||||||
|
|
||||||
|
|
||||||
async def _common_guild_chan(channel_id, j: dict):
|
async def _common_guild_chan(channel_id, j: dict):
|
||||||
# common updates to the guild_channels table
|
# common updates to the guild_channels table
|
||||||
for field in [field for field in j.keys()
|
for field in [field for field in j.keys() if field in ("nsfw", "parent_id")]:
|
||||||
if field in ('nsfw', 'parent_id')]:
|
await app.db.execute(
|
||||||
await app.db.execute(f"""
|
f"""
|
||||||
UPDATE guild_channels
|
UPDATE guild_channels
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j[field], channel_id)
|
""",
|
||||||
|
j[field],
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
|
async def _update_text_channel(channel_id: int, j: dict, _user_id: int):
|
||||||
# first do the specific ones related to guild_text_channels
|
# first do the specific ones related to guild_text_channels
|
||||||
for field in [field for field in j.keys()
|
for field in [
|
||||||
if field in ('topic', 'rate_limit_per_user')]:
|
field for field in j.keys() if field in ("topic", "rate_limit_per_user")
|
||||||
await app.db.execute(f"""
|
]:
|
||||||
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guild_text_channels
|
UPDATE guild_text_channels
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j[field], channel_id)
|
""",
|
||||||
|
j[field],
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await _common_guild_chan(channel_id, j)
|
await _common_guild_chan(channel_id, j)
|
||||||
|
|
||||||
|
|
||||||
async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
|
async def _update_voice_channel(channel_id: int, j: dict, _user_id: int):
|
||||||
# first do the specific ones in guild_voice_channels
|
# first do the specific ones in guild_voice_channels
|
||||||
for field in [field for field in j.keys()
|
for field in [field for field in j.keys() if field in ("bitrate", "user_limit")]:
|
||||||
if field in ('bitrate', 'user_limit')]:
|
await app.db.execute(
|
||||||
await app.db.execute(f"""
|
f"""
|
||||||
UPDATE guild_voice_channels
|
UPDATE guild_voice_channels
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j[field], channel_id)
|
""",
|
||||||
|
j[field],
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
# yes, i'm letting voice channels have nsfw, you cant stop me
|
# yes, i'm letting voice channels have nsfw, you cant stop me
|
||||||
await _common_guild_chan(channel_id, j)
|
await _common_guild_chan(channel_id, j)
|
||||||
|
|
||||||
|
|
||||||
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
|
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
|
||||||
if 'name' in j:
|
if "name" in j:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE group_dm_channels
|
UPDATE group_dm_channels
|
||||||
SET name = $1
|
SET name = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['name'], channel_id)
|
""",
|
||||||
|
j["name"],
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await send_sys_message(
|
await send_sys_message(
|
||||||
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
|
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'icon' in j:
|
if "icon" in j:
|
||||||
new_icon = await app.icons.update(
|
new_icon = await app.icons.update(
|
||||||
'channel-icons', channel_id, j['icon'], always_icon=True
|
"channel-icons", channel_id, j["icon"], always_icon=True
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE group_dm_channels
|
UPDATE group_dm_channels
|
||||||
SET icon = $1
|
SET icon = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_icon.icon_hash, channel_id)
|
""",
|
||||||
|
new_icon.icon_hash,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await send_sys_message(
|
await send_sys_message(
|
||||||
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
|
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
|
@bp.route("/<int:channel_id>", methods=["PUT", "PATCH"])
|
||||||
async def update_channel(channel_id):
|
async def update_channel(channel_id):
|
||||||
"""Update a channel's information"""
|
"""Update a channel's information"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
|
if ctype not in (
|
||||||
ChannelType.GROUP_DM):
|
ChannelType.GUILD_TEXT,
|
||||||
raise ChannelNotFound('unable to edit unsupported chan type')
|
ChannelType.GUILD_VOICE,
|
||||||
|
ChannelType.GROUP_DM,
|
||||||
|
):
|
||||||
|
raise ChannelNotFound("unable to edit unsupported chan type")
|
||||||
|
|
||||||
is_guild = ctype in GUILD_CHANS
|
is_guild = ctype in GUILD_CHANS
|
||||||
|
|
||||||
if is_guild:
|
if is_guild:
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_channels')
|
await channel_perm_check(user_id, channel_id, "manage_channels")
|
||||||
|
|
||||||
j = validate(await request.get_json(),
|
j = validate(await request.get_json(), CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
|
||||||
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
|
|
||||||
|
|
||||||
# TODO: categories
|
# TODO: categories
|
||||||
update_handler = {
|
update_handler = {
|
||||||
|
|
@ -489,30 +567,32 @@ async def update_channel(channel_id):
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
if is_guild:
|
if is_guild:
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("guild", guild_id, "CHANNEL_UPDATE", chan)
|
||||||
'guild', guild_id, 'CHANNEL_UPDATE', chan)
|
|
||||||
else:
|
else:
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||||
'channel', channel_id, 'CHANNEL_UPDATE', chan)
|
|
||||||
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/typing', methods=['POST'])
|
@bp.route("/<int:channel_id>/typing", methods=["POST"])
|
||||||
async def trigger_typing(channel_id):
|
async def trigger_typing(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch('channel', channel_id, 'TYPING_START', {
|
await app.dispatcher.dispatch(
|
||||||
'channel_id': str(channel_id),
|
"channel",
|
||||||
'user_id': str(user_id),
|
channel_id,
|
||||||
'timestamp': int(time.time()),
|
"TYPING_START",
|
||||||
|
{
|
||||||
|
"channel_id": str(channel_id),
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"timestamp": int(time.time()),
|
||||||
# guild_id for lazy guilds
|
# guild_id for lazy guilds
|
||||||
'guild_id': str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
"guild_id": str(guild_id) if ctype == ChannelType.GUILD_TEXT else None,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||||
|
|
@ -521,7 +601,8 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||||
if not message_id:
|
if not message_id:
|
||||||
message_id = await app.storage.chan_last_message(channel_id)
|
message_id = await app.storage.chan_last_message(channel_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO user_read_state
|
INSERT INTO user_read_state
|
||||||
(user_id, channel_id, last_message_id, mention_count)
|
(user_id, channel_id, last_message_id, mention_count)
|
||||||
VALUES
|
VALUES
|
||||||
|
|
@ -532,26 +613,31 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None):
|
||||||
SET last_message_id = $3, mention_count = 0
|
SET last_message_id = $3, mention_count = 0
|
||||||
WHERE user_read_state.user_id = $1
|
WHERE user_read_state.user_id = $1
|
||||||
AND user_read_state.channel_id = $2
|
AND user_read_state.channel_id = $2
|
||||||
""", user_id, channel_id, message_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
if guild_id:
|
if guild_id:
|
||||||
await app.dispatcher.dispatch_user_guild(
|
await app.dispatcher.dispatch_user_guild(
|
||||||
user_id, guild_id, 'MESSAGE_ACK', {
|
user_id,
|
||||||
'message_id': str(message_id),
|
guild_id,
|
||||||
'channel_id': str(channel_id)
|
"MESSAGE_ACK",
|
||||||
})
|
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# we don't use ChannelDispatcher here because since
|
# we don't use ChannelDispatcher here because since
|
||||||
# guild_id is None, all user devices are already subscribed
|
# guild_id is None, all user devices are already subscribed
|
||||||
# to the given channel (a dm or a group dm)
|
# to the given channel (a dm or a group dm)
|
||||||
await app.dispatcher.dispatch_user(
|
await app.dispatcher.dispatch_user(
|
||||||
user_id, 'MESSAGE_ACK', {
|
user_id,
|
||||||
'message_id': str(message_id),
|
"MESSAGE_ACK",
|
||||||
'channel_id': str(channel_id)
|
{"message_id": str(message_id), "channel_id": str(channel_id)},
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/ack', methods=['POST'])
|
@bp.route("/<int:channel_id>/messages/<int:message_id>/ack", methods=["POST"])
|
||||||
async def ack_channel(channel_id, message_id):
|
async def ack_channel(channel_id, message_id):
|
||||||
"""Acknowledge a channel."""
|
"""Acknowledge a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -562,40 +648,47 @@ async def ack_channel(channel_id, message_id):
|
||||||
|
|
||||||
await channel_ack(user_id, guild_id, channel_id, message_id)
|
await channel_ack(user_id, guild_id, channel_id, message_id)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify(
|
||||||
|
{
|
||||||
# token seems to be used for
|
# token seems to be used for
|
||||||
# data collection activities,
|
# data collection activities,
|
||||||
# so we never use it.
|
# so we never use it.
|
||||||
'token': None
|
"token": None
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/ack', methods=['DELETE'])
|
@bp.route("/<int:channel_id>/messages/ack", methods=["DELETE"])
|
||||||
async def delete_read_state(channel_id):
|
async def delete_read_state(channel_id):
|
||||||
"""Delete the read state of a channel."""
|
"""Delete the read state of a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM user_read_state
|
DELETE FROM user_read_state
|
||||||
WHERE user_id = $1 AND channel_id = $2
|
WHERE user_id = $1 AND channel_id = $2
|
||||||
""", user_id, channel_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/search', methods=['GET'])
|
@bp.route("/<int:channel_id>/messages/search", methods=["GET"])
|
||||||
async def _search_channel(channel_id):
|
async def _search_channel(channel_id):
|
||||||
"""Search in DMs or group DMs"""
|
"""Search in DMs or group DMs"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await channel_check(user_id, channel_id)
|
await channel_check(user_id, channel_id)
|
||||||
await channel_perm_check(user_id, channel_id, 'read_messages')
|
await channel_perm_check(user_id, channel_id, "read_messages")
|
||||||
|
|
||||||
j = validate(dict(request.args), SEARCH_CHANNEL)
|
j = validate(dict(request.args), SEARCH_CHANNEL)
|
||||||
|
|
||||||
# main search query
|
# main search query
|
||||||
# the context (before/after) columns are copied from the guilds blueprint.
|
# the context (before/after) columns are copied from the guilds blueprint.
|
||||||
rows = await app.db.fetch(f"""
|
rows = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT orig.id AS current_id,
|
SELECT orig.id AS current_id,
|
||||||
COUNT(*) OVER() AS total_results,
|
COUNT(*) OVER() AS total_results,
|
||||||
array((SELECT messages.id AS before_id
|
array((SELECT messages.id AS before_id
|
||||||
|
|
@ -611,28 +704,40 @@ async def _search_channel(channel_id):
|
||||||
ORDER BY orig.id DESC
|
ORDER BY orig.id DESC
|
||||||
LIMIT 50
|
LIMIT 50
|
||||||
OFFSET $2
|
OFFSET $2
|
||||||
""", channel_id, j['offset'], j['content'])
|
""",
|
||||||
|
channel_id,
|
||||||
|
j["offset"],
|
||||||
|
j["content"],
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(await search_result_from_list(rows))
|
return jsonify(await search_result_from_list(rows))
|
||||||
|
|
||||||
|
|
||||||
# NOTE that those functions stay here until some other
|
# NOTE that those functions stay here until some other
|
||||||
# route or code wants it.
|
# route or code wants it.
|
||||||
|
|
||||||
|
|
||||||
async def _msg_update_flags(message_id: int, flags: int):
|
async def _msg_update_flags(message_id: int, flags: int):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE messages
|
UPDATE messages
|
||||||
SET flags = $1
|
SET flags = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", flags, message_id)
|
""",
|
||||||
|
flags,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _msg_get_flags(message_id: int):
|
async def _msg_get_flags(message_id: int):
|
||||||
return await app.db.fetchval("""
|
return await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT flags
|
SELECT flags
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _msg_set_flags(message_id: int, new_flags: int):
|
async def _msg_set_flags(message_id: int, new_flags: int):
|
||||||
|
|
@ -647,8 +752,9 @@ async def _msg_unset_flags(message_id: int, unset_flags: int):
|
||||||
await _msg_update_flags(message_id, flags)
|
await _msg_update_flags(message_id, flags)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/<int:message_id>/suppress-embeds',
|
@bp.route(
|
||||||
methods=['POST'])
|
"/<int:channel_id>/messages/<int:message_id>/suppress-embeds", methods=["POST"]
|
||||||
|
)
|
||||||
async def suppress_embeds(channel_id: int, message_id: int):
|
async def suppress_embeds(channel_id: int, message_id: int):
|
||||||
"""Toggle the embeds in a message.
|
"""Toggle the embeds in a message.
|
||||||
|
|
||||||
|
|
@ -661,29 +767,27 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
||||||
# the checks here have been copied from the delete_message()
|
# the checks here have been copied from the delete_message()
|
||||||
# handler on blueprints.channel.messages. maybe we can combine
|
# handler on blueprints.channel.messages. maybe we can combine
|
||||||
# them someday?
|
# them someday?
|
||||||
author_id = await app.db.fetchval("""
|
author_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT author_id FROM messages
|
SELECT author_id FROM messages
|
||||||
WHERE messages.id = $1
|
WHERE messages.id = $1
|
||||||
""", message_id)
|
""",
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
by_perms = await channel_perm_check(
|
by_perms = await channel_perm_check(user_id, channel_id, "manage_messages", False)
|
||||||
user_id, channel_id, 'manage_messages', False)
|
|
||||||
|
|
||||||
by_author = author_id == user_id
|
by_author = author_id == user_id
|
||||||
|
|
||||||
can_suppress = by_perms or by_author
|
can_suppress = by_perms or by_author
|
||||||
if not can_suppress:
|
if not can_suppress:
|
||||||
raise Forbidden('Not enough permissions.')
|
raise Forbidden("Not enough permissions.")
|
||||||
|
|
||||||
j = validate(
|
j = validate(await request.get_json(), {"suppress": {"type": "boolean"}})
|
||||||
await request.get_json(),
|
|
||||||
{'suppress': {'type': 'boolean'}},
|
|
||||||
)
|
|
||||||
|
|
||||||
suppress = j['suppress']
|
suppress = j["suppress"]
|
||||||
message = await app.storage.get_message(message_id)
|
message = await app.storage.get_message(message_id)
|
||||||
url_embeds = sum(
|
url_embeds = sum(1 for embed in message["embeds"] if embed["type"] == "url")
|
||||||
1 for embed in message['embeds'] if embed['type'] == 'url')
|
|
||||||
|
|
||||||
# NOTE for any future self. discord doing flags an optional thing instead
|
# NOTE for any future self. discord doing flags an optional thing instead
|
||||||
# of just giving 0 is a pretty bad idea because now i have to deal with
|
# of just giving 0 is a pretty bad idea because now i have to deal with
|
||||||
|
|
@ -693,8 +797,7 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
||||||
# delete all embeds then dispatch an update
|
# delete all embeds then dispatch an update
|
||||||
await _msg_set_flags(message_id, MessageFlags.suppress_embeds)
|
await _msg_set_flags(message_id, MessageFlags.suppress_embeds)
|
||||||
|
|
||||||
message['flags'] = \
|
message["flags"] = message.get("flags", 0) | MessageFlags.suppress_embeds
|
||||||
message.get('flags', 0) | MessageFlags.suppress_embeds
|
|
||||||
|
|
||||||
await msg_update_embeds(message, [], app.storage, app.dispatcher)
|
await msg_update_embeds(message, [], app.storage, app.dispatcher)
|
||||||
elif not suppress and not url_embeds:
|
elif not suppress and not url_embeds:
|
||||||
|
|
@ -702,30 +805,29 @@ async def suppress_embeds(channel_id: int, message_id: int):
|
||||||
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
|
await _msg_unset_flags(message_id, MessageFlags.suppress_embeds)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message.pop('flags')
|
message.pop("flags")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
app.sched.spawn(
|
app.sched.spawn(
|
||||||
process_url_embed(
|
process_url_embed(
|
||||||
app.config, app.storage, app.dispatcher, app.session,
|
app.config, app.storage, app.dispatcher, app.session, message
|
||||||
message
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages/bulk-delete', methods=['POST'])
|
@bp.route("/<int:channel_id>/messages/bulk-delete", methods=["POST"])
|
||||||
async def bulk_delete(channel_id: int):
|
async def bulk_delete(channel_id: int):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
guild_id = guild_id if ctype in GUILD_CHANS else None
|
guild_id = guild_id if ctype in GUILD_CHANS else None
|
||||||
|
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_messages')
|
await channel_perm_check(user_id, channel_id, "manage_messages")
|
||||||
|
|
||||||
j = validate(await request.get_json(), BULK_DELETE)
|
j = validate(await request.get_json(), BULK_DELETE)
|
||||||
message_ids = set(j['messages'])
|
message_ids = set(j["messages"])
|
||||||
|
|
||||||
# as per discord behavior, if any id here is older than two weeks,
|
# as per discord behavior, if any id here is older than two weeks,
|
||||||
# we must error. a cuter behavior would be returning the message ids
|
# we must error. a cuter behavior would be returning the message ids
|
||||||
|
|
@ -738,25 +840,28 @@ async def bulk_delete(channel_id: int):
|
||||||
raise BadRequest(50034)
|
raise BadRequest(50034)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
'guild_id': str(guild_id),
|
"guild_id": str(guild_id),
|
||||||
'channel_id': str(channel_id),
|
"channel_id": str(channel_id),
|
||||||
'ids': list(map(str, message_ids)),
|
"ids": list(map(str, message_ids)),
|
||||||
}
|
}
|
||||||
|
|
||||||
# payload.guild_id is optional in the event, not nullable.
|
# payload.guild_id is optional in the event, not nullable.
|
||||||
if guild_id is None:
|
if guild_id is None:
|
||||||
payload.pop('guild_id')
|
payload.pop("guild_id")
|
||||||
|
|
||||||
res = await app.db.execute("""
|
res = await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM messages
|
DELETE FROM messages
|
||||||
WHERE
|
WHERE
|
||||||
channel_id = $1
|
channel_id = $1
|
||||||
AND ARRAY[id] <@ $2::bigint[]
|
AND ARRAY[id] <@ $2::bigint[]
|
||||||
""", channel_id, list(message_ids))
|
""",
|
||||||
|
channel_id,
|
||||||
|
list(message_ids),
|
||||||
|
)
|
||||||
|
|
||||||
if res == 'DELETE 0':
|
if res == "DELETE 0":
|
||||||
raise BadRequest('No messages were removed')
|
raise BadRequest("No messages were removed")
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_DELETE_BULK", payload)
|
||||||
'channel', channel_id, 'MESSAGE_DELETE_BULK', payload)
|
return "", 204
|
||||||
return '', 204
|
|
||||||
|
|
|
||||||
|
|
@ -23,46 +23,57 @@ from quart import current_app as app
|
||||||
|
|
||||||
from litecord.enums import ChannelType, GUILD_CHANS
|
from litecord.enums import ChannelType, GUILD_CHANS
|
||||||
from litecord.errors import (
|
from litecord.errors import (
|
||||||
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
|
GuildNotFound,
|
||||||
|
ChannelNotFound,
|
||||||
|
Forbidden,
|
||||||
|
MissingPermissions,
|
||||||
)
|
)
|
||||||
from litecord.permissions import base_permissions, get_permissions
|
from litecord.permissions import base_permissions, get_permissions
|
||||||
|
|
||||||
|
|
||||||
async def guild_check(user_id: int, guild_id: int):
|
async def guild_check(user_id: int, guild_id: int):
|
||||||
"""Check if a user is in a guild."""
|
"""Check if a user is in a guild."""
|
||||||
joined_at = await app.db.fetchval("""
|
joined_at = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT joined_at
|
SELECT joined_at
|
||||||
FROM members
|
FROM members
|
||||||
WHERE user_id = $1 AND guild_id = $2
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not joined_at:
|
if not joined_at:
|
||||||
raise GuildNotFound('guild not found')
|
raise GuildNotFound("guild not found")
|
||||||
|
|
||||||
|
|
||||||
async def guild_owner_check(user_id: int, guild_id: int):
|
async def guild_owner_check(user_id: int, guild_id: int):
|
||||||
"""Check if a user is the owner of the guild."""
|
"""Check if a user is the owner of the guild."""
|
||||||
owner_id = await app.db.fetchval("""
|
owner_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT owner_id
|
SELECT owner_id
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE guilds.id = $1
|
WHERE guilds.id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not owner_id:
|
if not owner_id:
|
||||||
raise GuildNotFound()
|
raise GuildNotFound()
|
||||||
|
|
||||||
if user_id != owner_id:
|
if user_id != owner_id:
|
||||||
raise Forbidden('You are not the owner of the guild')
|
raise Forbidden("You are not the owner of the guild")
|
||||||
|
|
||||||
|
|
||||||
async def channel_check(user_id, channel_id, *,
|
async def channel_check(
|
||||||
only: Union[ChannelType, List[ChannelType]] = None):
|
user_id, channel_id, *, only: Union[ChannelType, List[ChannelType]] = None
|
||||||
|
):
|
||||||
"""Check if the current user is authorized
|
"""Check if the current user is authorized
|
||||||
to read the channel's information."""
|
to read the channel's information."""
|
||||||
chan_type = await app.storage.get_chan_type(channel_id)
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
|
||||||
if chan_type is None:
|
if chan_type is None:
|
||||||
raise ChannelNotFound('channel type not found')
|
raise ChannelNotFound("channel type not found")
|
||||||
|
|
||||||
ctype = ChannelType(chan_type)
|
ctype = ChannelType(chan_type)
|
||||||
|
|
||||||
|
|
@ -70,14 +81,17 @@ async def channel_check(user_id, channel_id, *,
|
||||||
only = [only]
|
only = [only]
|
||||||
|
|
||||||
if only and ctype not in only:
|
if only and ctype not in only:
|
||||||
raise ChannelNotFound('invalid channel type')
|
raise ChannelNotFound("invalid channel type")
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
guild_id = await app.db.fetchval("""
|
guild_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
WHERE guild_channels.id = $1
|
WHERE guild_channels.id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
return ctype, guild_id
|
return ctype, guild_id
|
||||||
|
|
@ -87,11 +101,14 @@ async def channel_check(user_id, channel_id, *,
|
||||||
return ctype, peer_id
|
return ctype, peer_id
|
||||||
|
|
||||||
if ctype == ChannelType.GROUP_DM:
|
if ctype == ChannelType.GROUP_DM:
|
||||||
owner_id = await app.db.fetchval("""
|
owner_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT owner_id
|
SELECT owner_id
|
||||||
FROM group_dm_channels
|
FROM group_dm_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
return ctype, owner_id
|
return ctype, owner_id
|
||||||
|
|
||||||
|
|
@ -102,18 +119,17 @@ async def guild_perm_check(user_id, guild_id, permission: str):
|
||||||
hasperm = getattr(base_perms.bits, permission)
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
if not hasperm:
|
if not hasperm:
|
||||||
raise MissingPermissions('Missing permissions.')
|
raise MissingPermissions("Missing permissions.")
|
||||||
|
|
||||||
return bool(hasperm)
|
return bool(hasperm)
|
||||||
|
|
||||||
|
|
||||||
async def channel_perm_check(user_id, channel_id,
|
async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
|
||||||
permission: str, raise_err=True):
|
|
||||||
"""Check channel permissions for a user."""
|
"""Check channel permissions for a user."""
|
||||||
base_perms = await get_permissions(user_id, channel_id)
|
base_perms = await get_permissions(user_id, channel_id)
|
||||||
hasperm = getattr(base_perms.bits, permission)
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
if not hasperm and raise_err:
|
if not hasperm and raise_err:
|
||||||
raise MissingPermissions('Missing permissions.')
|
raise MissingPermissions("Missing permissions.")
|
||||||
|
|
||||||
return bool(hasperm)
|
return bool(hasperm)
|
||||||
|
|
|
||||||
|
|
@ -29,21 +29,29 @@ from litecord.system_messages import send_sys_message
|
||||||
from litecord.pubsub.channel import gdm_recipient_view
|
from litecord.pubsub.channel import gdm_recipient_view
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('dm_channels', __name__)
|
bp = Blueprint("dm_channels", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def _raw_gdm_add(channel_id, user_id):
|
async def _raw_gdm_add(channel_id, user_id):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO group_dm_members (id, member_id)
|
INSERT INTO group_dm_members (id, member_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", channel_id, user_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _raw_gdm_remove(channel_id, user_id):
|
async def _raw_gdm_remove(channel_id, user_id):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM group_dm_members
|
DELETE FROM group_dm_members
|
||||||
WHERE id = $1 AND member_id = $2
|
WHERE id = $1 AND member_id = $2
|
||||||
""", channel_id, user_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def gdm_create(user_id, peer_id) -> int:
|
async def gdm_create(user_id, peer_id) -> int:
|
||||||
|
|
@ -53,24 +61,32 @@ async def gdm_create(user_id, peer_id) -> int:
|
||||||
"""
|
"""
|
||||||
channel_id = get_snowflake()
|
channel_id = get_snowflake()
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO channels (id, channel_type)
|
INSERT INTO channels (id, channel_type)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", channel_id, ChannelType.GROUP_DM.value)
|
""",
|
||||||
|
channel_id,
|
||||||
|
ChannelType.GROUP_DM.value,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO group_dm_channels (id, owner_id, name, icon)
|
INSERT INTO group_dm_channels (id, owner_id, name, icon)
|
||||||
VALUES ($1, $2, NULL, NULL)
|
VALUES ($1, $2, NULL, NULL)
|
||||||
""", channel_id, user_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
await _raw_gdm_add(channel_id, user_id)
|
await _raw_gdm_add(channel_id, user_id)
|
||||||
await _raw_gdm_add(channel_id, peer_id)
|
await _raw_gdm_add(channel_id, peer_id)
|
||||||
|
|
||||||
await app.dispatcher.sub('channel', channel_id, user_id)
|
await app.dispatcher.sub("channel", channel_id, user_id)
|
||||||
await app.dispatcher.sub('channel', channel_id, peer_id)
|
await app.dispatcher.sub("channel", channel_id, peer_id)
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
|
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_CREATE", chan)
|
||||||
|
|
||||||
return channel_id
|
return channel_id
|
||||||
|
|
||||||
|
|
@ -89,17 +105,16 @@ async def gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||||
|
|
||||||
# the reasoning behind gdm_recipient_view is in its docstring.
|
# the reasoning behind gdm_recipient_view is in its docstring.
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
|
"user", peer_id, "CHANNEL_CREATE", gdm_recipient_view(chan, peer_id)
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_UPDATE", chan)
|
||||||
'channel', channel_id, 'CHANNEL_UPDATE', chan)
|
|
||||||
|
|
||||||
await app.dispatcher.sub('channel', peer_id)
|
await app.dispatcher.sub("channel", peer_id)
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
await send_sys_message(
|
await send_sys_message(
|
||||||
app, channel_id, MessageType.RECIPIENT_ADD,
|
app, channel_id, MessageType.RECIPIENT_ADD, user_id, peer_id
|
||||||
user_id, peer_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -116,22 +131,22 @@ async def gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
|
||||||
|
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
|
"user", peer_id, "CHANNEL_DELETE", gdm_recipient_view(chan, user_id)
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.unsub('channel', peer_id)
|
await app.dispatcher.unsub("channel", peer_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
|
"channel",
|
||||||
'channel_id': str(channel_id),
|
channel_id,
|
||||||
'user': await app.storage.get_user(peer_id)
|
"CHANNEL_RECIPIENT_REMOVE",
|
||||||
}
|
{"channel_id": str(channel_id), "user": await app.storage.get_user(peer_id)},
|
||||||
)
|
)
|
||||||
|
|
||||||
author_id = peer_id if user_id is None else user_id
|
author_id = peer_id if user_id is None else user_id
|
||||||
|
|
||||||
await send_sys_message(
|
await send_sys_message(
|
||||||
app, channel_id, MessageType.RECIPIENT_REMOVE,
|
app, channel_id, MessageType.RECIPIENT_REMOVE, author_id, peer_id
|
||||||
author_id, peer_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -139,40 +154,51 @@ async def gdm_destroy(channel_id):
|
||||||
"""Destroy a Group DM."""
|
"""Destroy a Group DM."""
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM group_dm_members
|
DELETE FROM group_dm_members
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
await app.db.execute("""
|
|
||||||
DELETE FROM group_dm_channels
|
|
||||||
WHERE id = $1
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
|
||||||
DELETE FROM channels
|
|
||||||
WHERE id = $1
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
|
||||||
'channel', channel_id, 'CHANNEL_DELETE', chan
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.remove('channel', channel_id)
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM group_dm_channels
|
||||||
|
WHERE id = $1
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
DELETE FROM channels
|
||||||
|
WHERE id = $1
|
||||||
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.dispatcher.dispatch("channel", channel_id, "CHANNEL_DELETE", chan)
|
||||||
|
|
||||||
|
await app.dispatcher.remove("channel", channel_id)
|
||||||
|
|
||||||
|
|
||||||
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
async def gdm_is_member(channel_id: int, user_id: int) -> bool:
|
||||||
"""Return if the given user is a member of the Group DM."""
|
"""Return if the given user is a member of the Group DM."""
|
||||||
row = await app.db.fetchval("""
|
row = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM group_dm_members
|
FROM group_dm_members
|
||||||
WHERE id = $1 AND member_id = $2
|
WHERE id = $1 AND member_id = $2
|
||||||
""", channel_id, user_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return row is not None
|
return row is not None
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT'])
|
@bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["PUT"])
|
||||||
async def add_to_group_dm(dm_chan, peer_id):
|
async def add_to_group_dm(dm_chan, peer_id):
|
||||||
"""Adds a member to a group dm OR creates a group dm."""
|
"""Adds a member to a group dm OR creates a group dm."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -182,8 +208,7 @@ async def add_to_group_dm(dm_chan, peer_id):
|
||||||
|
|
||||||
# other_id is the peer of the dm if the given channel is a dm
|
# other_id is the peer of the dm if the given channel is a dm
|
||||||
ctype, other_id = await channel_check(
|
ctype, other_id = await channel_check(
|
||||||
user_id, dm_chan,
|
user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM]
|
||||||
only=[ChannelType.DM, ChannelType.GROUP_DM]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# check relationship with the given user id
|
# check relationship with the given user id
|
||||||
|
|
@ -191,30 +216,24 @@ async def add_to_group_dm(dm_chan, peer_id):
|
||||||
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
||||||
|
|
||||||
if not friends:
|
if not friends:
|
||||||
raise BadRequest('Cant insert peer into dm')
|
raise BadRequest("Cant insert peer into dm")
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
dm_chan = await gdm_create(
|
dm_chan = await gdm_create(user_id, other_id)
|
||||||
user_id, other_id
|
|
||||||
)
|
|
||||||
|
|
||||||
await gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
|
await gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_channel(dm_chan))
|
||||||
await app.storage.get_channel(dm_chan)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE'])
|
@bp.route("/<int:dm_chan>/recipients/<int:peer_id>", methods=["DELETE"])
|
||||||
async def remove_from_group_dm(dm_chan, peer_id):
|
async def remove_from_group_dm(dm_chan, peer_id):
|
||||||
"""Remove users from group dm."""
|
"""Remove users from group dm."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, owner_id = await channel_check(
|
_ctype, owner_id = await channel_check(user_id, dm_chan, only=ChannelType.GROUP_DM)
|
||||||
user_id, dm_chan, only=ChannelType.GROUP_DM
|
|
||||||
)
|
|
||||||
|
|
||||||
if owner_id != user_id:
|
if owner_id != user_id:
|
||||||
raise Forbidden('You are now the owner of the group DM')
|
raise Forbidden("You are now the owner of the group DM")
|
||||||
|
|
||||||
await gdm_remove_recipient(dm_chan, peer_id)
|
await gdm_remove_recipient(dm_chan, peer_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -30,15 +30,13 @@ from ..snowflake import get_snowflake
|
||||||
|
|
||||||
from .auth import token_check
|
from .auth import token_check
|
||||||
|
|
||||||
from litecord.blueprints.dm_channels import (
|
from litecord.blueprints.dm_channels import gdm_create, gdm_add_recipient
|
||||||
gdm_create, gdm_add_recipient
|
|
||||||
)
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('dms', __name__)
|
bp = Blueprint("dms", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/channels', methods=['GET'])
|
@bp.route("/@me/channels", methods=["GET"])
|
||||||
async def get_dms():
|
async def get_dms():
|
||||||
"""Get the open DMs for the user."""
|
"""Get the open DMs for the user."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -53,11 +51,15 @@ async def try_dm_state(user_id: int, dm_id: int):
|
||||||
Does not do anything if the user is already
|
Does not do anything if the user is already
|
||||||
in the dm state.
|
in the dm state.
|
||||||
"""
|
"""
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO dm_channel_state (user_id, dm_id)
|
INSERT INTO dm_channel_state (user_id, dm_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
ON CONFLICT DO NOTHING
|
ON CONFLICT DO NOTHING
|
||||||
""", user_id, dm_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
dm_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def jsonify_dm(dm_id: int, user_id: int):
|
async def jsonify_dm(dm_id: int, user_id: int):
|
||||||
|
|
@ -69,12 +71,16 @@ async def create_dm(user_id, recipient_id):
|
||||||
"""Create a new dm with a user,
|
"""Create a new dm with a user,
|
||||||
or get the existing DM id if it already exists."""
|
or get the existing DM id if it already exists."""
|
||||||
|
|
||||||
dm_id = await app.db.fetchval("""
|
dm_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM dm_channels
|
FROM dm_channels
|
||||||
WHERE (party1_id = $1 OR party2_id = $1) AND
|
WHERE (party1_id = $1 OR party2_id = $1) AND
|
||||||
(party1_id = $2 OR party2_id = $2)
|
(party1_id = $2 OR party2_id = $2)
|
||||||
""", user_id, recipient_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
recipient_id,
|
||||||
|
)
|
||||||
|
|
||||||
if dm_id:
|
if dm_id:
|
||||||
return await jsonify_dm(dm_id, user_id)
|
return await jsonify_dm(dm_id, user_id)
|
||||||
|
|
@ -82,15 +88,24 @@ async def create_dm(user_id, recipient_id):
|
||||||
# if no dm was found, create a new one
|
# if no dm was found, create a new one
|
||||||
|
|
||||||
dm_id = get_snowflake()
|
dm_id = get_snowflake()
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO channels (id, channel_type)
|
INSERT INTO channels (id, channel_type)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", dm_id, ChannelType.DM.value)
|
""",
|
||||||
|
dm_id,
|
||||||
|
ChannelType.DM.value,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO dm_channels (id, party1_id, party2_id)
|
INSERT INTO dm_channels (id, party1_id, party2_id)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", dm_id, user_id, recipient_id)
|
""",
|
||||||
|
dm_id,
|
||||||
|
user_id,
|
||||||
|
recipient_id,
|
||||||
|
)
|
||||||
|
|
||||||
# the dm state is something we use
|
# the dm state is something we use
|
||||||
# to give the currently "open dms"
|
# to give the currently "open dms"
|
||||||
|
|
@ -103,24 +118,24 @@ async def create_dm(user_id, recipient_id):
|
||||||
return await jsonify_dm(dm_id, user_id)
|
return await jsonify_dm(dm_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/channels', methods=['POST'])
|
@bp.route("/@me/channels", methods=["POST"])
|
||||||
async def start_dm():
|
async def start_dm():
|
||||||
"""Create a DM with a user."""
|
"""Create a DM with a user."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
j = validate(await request.get_json(), CREATE_DM)
|
j = validate(await request.get_json(), CREATE_DM)
|
||||||
recipient_id = j['recipient_id']
|
recipient_id = j["recipient_id"]
|
||||||
|
|
||||||
return await create_dm(user_id, recipient_id)
|
return await create_dm(user_id, recipient_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:p_user_id>/channels', methods=['POST'])
|
@bp.route("/<int:p_user_id>/channels", methods=["POST"])
|
||||||
async def create_group_dm(p_user_id: int):
|
async def create_group_dm(p_user_id: int):
|
||||||
"""Create a DM or a Group DM with user(s)."""
|
"""Create a DM or a Group DM with user(s)."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
assert user_id == p_user_id
|
assert user_id == p_user_id
|
||||||
|
|
||||||
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
j = validate(await request.get_json(), CREATE_GROUP_DM)
|
||||||
recipients = j['recipients']
|
recipients = j["recipients"]
|
||||||
|
|
||||||
if len(recipients) == 1:
|
if len(recipients) == 1:
|
||||||
# its a group dm with 1 user... a dm!
|
# its a group dm with 1 user... a dm!
|
||||||
|
|
|
||||||
|
|
@ -23,37 +23,38 @@ from quart import Blueprint, jsonify, current_app as app
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
|
|
||||||
bp = Blueprint('gateway', __name__)
|
bp = Blueprint("gateway", __name__)
|
||||||
|
|
||||||
|
|
||||||
def get_gw():
|
def get_gw():
|
||||||
"""Get the gateway's web"""
|
"""Get the gateway's web"""
|
||||||
proto = 'wss://' if app.config['IS_SSL'] else 'ws://'
|
proto = "wss://" if app.config["IS_SSL"] else "ws://"
|
||||||
return f'{proto}{app.config["WEBSOCKET_URL"]}'
|
return f'{proto}{app.config["WEBSOCKET_URL"]}'
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/gateway')
|
@bp.route("/gateway")
|
||||||
def api_gateway():
|
def api_gateway():
|
||||||
"""Get the raw URL."""
|
"""Get the raw URL."""
|
||||||
return jsonify({
|
return jsonify({"url": get_gw()})
|
||||||
'url': get_gw()
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/gateway/bot')
|
@bp.route("/gateway/bot")
|
||||||
async def api_gateway_bot():
|
async def api_gateway_bot():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
guild_count = await app.db.fetchval("""
|
guild_count = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM members
|
FROM members
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
shards = max(int(guild_count / 1000), 1)
|
shards = max(int(guild_count / 1000), 1)
|
||||||
|
|
||||||
# get _ws.session ratelimit
|
# get _ws.session ratelimit
|
||||||
ratelimit = app.ratelimiter.get_ratelimit('_ws.session')
|
ratelimit = app.ratelimiter.get_ratelimit("_ws.session")
|
||||||
bucket = ratelimit.get_bucket(user_id)
|
bucket = ratelimit.get_bucket(user_id)
|
||||||
|
|
||||||
# timestamp of bucket reset
|
# timestamp of bucket reset
|
||||||
|
|
@ -62,13 +63,14 @@ async def api_gateway_bot():
|
||||||
# how many seconds until bucket reset
|
# how many seconds until bucket reset
|
||||||
reset_after_ts = reset_ts - time.time()
|
reset_after_ts = reset_ts - time.time()
|
||||||
|
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'url': get_gw(),
|
{
|
||||||
'shards': shards,
|
"url": get_gw(),
|
||||||
|
"shards": shards,
|
||||||
'session_start_limit': {
|
"session_start_limit": {
|
||||||
'total': bucket.requests,
|
"total": bucket.requests,
|
||||||
'remaining': bucket._tokens,
|
"remaining": bucket._tokens,
|
||||||
'reset_after': int(reset_after_ts * 1000),
|
"reset_after": int(reset_after_ts * 1000),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,5 +23,4 @@ from .channels import bp as guild_channels
|
||||||
from .mod import bp as guild_mod
|
from .mod import bp as guild_mod
|
||||||
from .emoji import bp as guild_emoji
|
from .emoji import bp as guild_emoji
|
||||||
|
|
||||||
__all__ = ['guild_roles', 'guild_members', 'guild_channels', 'guild_mod',
|
__all__ = ["guild_roles", "guild_members", "guild_channels", "guild_mod", "guild_emoji"]
|
||||||
'guild_emoji']
|
|
||||||
|
|
|
||||||
|
|
@ -25,23 +25,23 @@ from litecord.errors import BadRequest
|
||||||
from litecord.enums import ChannelType
|
from litecord.enums import ChannelType
|
||||||
from litecord.blueprints.guild.roles import gen_pairs
|
from litecord.blueprints.guild.roles import gen_pairs
|
||||||
|
|
||||||
from litecord.schemas import (
|
from litecord.schemas import validate, ROLE_UPDATE_POSITION, CHAN_CREATE
|
||||||
validate, ROLE_UPDATE_POSITION, CHAN_CREATE
|
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
|
||||||
)
|
|
||||||
from litecord.blueprints.checks import (
|
|
||||||
guild_check, guild_owner_check, guild_perm_check
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint('guild_channels', __name__)
|
bp = Blueprint("guild_channels", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||||
if ctype == ChannelType.GUILD_TEXT:
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO guild_text_channels (id, topic)
|
INSERT INTO guild_text_channels (id, topic)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", channel_id, kwargs.get('topic', ''))
|
""",
|
||||||
|
channel_id,
|
||||||
|
kwargs.get("topic", ""),
|
||||||
|
)
|
||||||
elif ctype == ChannelType.GUILD_VOICE:
|
elif ctype == ChannelType.GUILD_VOICE:
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
|
|
@ -49,34 +49,48 @@ async def _specific_chan_create(channel_id, ctype, **kwargs):
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""",
|
""",
|
||||||
channel_id,
|
channel_id,
|
||||||
kwargs.get('bitrate', 64),
|
kwargs.get("bitrate", 64),
|
||||||
kwargs.get('user_limit', 0)
|
kwargs.get("user_limit", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_channel(guild_id: int, channel_id: int,
|
async def create_guild_channel(
|
||||||
ctype: ChannelType, **kwargs):
|
guild_id: int, channel_id: int, ctype: ChannelType, **kwargs
|
||||||
|
):
|
||||||
"""Create a channel in a guild."""
|
"""Create a channel in a guild."""
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO channels (id, channel_type)
|
INSERT INTO channels (id, channel_type)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", channel_id, ctype.value)
|
""",
|
||||||
|
channel_id,
|
||||||
|
ctype.value,
|
||||||
|
)
|
||||||
|
|
||||||
# calc new pos
|
# calc new pos
|
||||||
max_pos = await app.db.fetchval("""
|
max_pos = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT MAX(position)
|
SELECT MAX(position)
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
WHERE guild_id = $1
|
WHERE guild_id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
# account for the first channel in a guild too
|
# account for the first channel in a guild too
|
||||||
max_pos = max_pos or 0
|
max_pos = max_pos or 0
|
||||||
|
|
||||||
# all channels go to guild_channels
|
# all channels go to guild_channels
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO guild_channels (id, guild_id, name, position)
|
INSERT INTO guild_channels (id, guild_id, name, position)
|
||||||
VALUES ($1, $2, $3, $4)
|
VALUES ($1, $2, $3, $4)
|
||||||
""", channel_id, guild_id, kwargs['name'], max_pos + 1)
|
""",
|
||||||
|
channel_id,
|
||||||
|
guild_id,
|
||||||
|
kwargs["name"],
|
||||||
|
max_pos + 1,
|
||||||
|
)
|
||||||
|
|
||||||
# the rest of sql magic is dependant on the channel
|
# the rest of sql magic is dependant on the channel
|
||||||
# we're creating (a text or voice or category),
|
# we're creating (a text or voice or category),
|
||||||
|
|
@ -84,35 +98,32 @@ async def create_guild_channel(guild_id: int, channel_id: int,
|
||||||
await _specific_chan_create(channel_id, ctype, **kwargs)
|
await _specific_chan_create(channel_id, ctype, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/channels', methods=['GET'])
|
@bp.route("/<int:guild_id>/channels", methods=["GET"])
|
||||||
async def get_guild_channels(guild_id):
|
async def get_guild_channels(guild_id):
|
||||||
"""Get the list of channels in a guild."""
|
"""Get the list of channels in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_channel_data(guild_id))
|
||||||
await app.storage.get_channel_data(guild_id))
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/channels', methods=['POST'])
|
@bp.route("/<int:guild_id>/channels", methods=["POST"])
|
||||||
async def create_channel(guild_id):
|
async def create_channel(guild_id):
|
||||||
"""Create a channel in a guild."""
|
"""Create a channel in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
j = validate(await request.get_json(), CHAN_CREATE)
|
j = validate(await request.get_json(), CHAN_CREATE)
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||||
|
|
||||||
channel_type = j.get('type', ChannelType.GUILD_TEXT)
|
channel_type = j.get("type", ChannelType.GUILD_TEXT)
|
||||||
channel_type = ChannelType(channel_type)
|
channel_type = ChannelType(channel_type)
|
||||||
|
|
||||||
if channel_type not in (ChannelType.GUILD_TEXT,
|
if channel_type not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE):
|
||||||
ChannelType.GUILD_VOICE):
|
raise BadRequest("Invalid channel type")
|
||||||
raise BadRequest('Invalid channel type')
|
|
||||||
|
|
||||||
new_channel_id = get_snowflake()
|
new_channel_id = get_snowflake()
|
||||||
await create_guild_channel(
|
await create_guild_channel(guild_id, new_channel_id, channel_type, **j)
|
||||||
guild_id, new_channel_id, channel_type, **j)
|
|
||||||
|
|
||||||
# TODO: do a better method
|
# TODO: do a better method
|
||||||
# subscribe the currently subscribed users to the new channel
|
# subscribe the currently subscribed users to the new channel
|
||||||
|
|
@ -120,14 +131,13 @@ async def create_channel(guild_id):
|
||||||
|
|
||||||
# since GuildDispatcher calls Storage.get_channel_ids,
|
# since GuildDispatcher calls Storage.get_channel_ids,
|
||||||
# it will subscribe all users to the newly created channel.
|
# it will subscribe all users to the newly created channel.
|
||||||
guild_pubsub = app.dispatcher.backends['guild']
|
guild_pubsub = app.dispatcher.backends["guild"]
|
||||||
user_ids = guild_pubsub.state[guild_id]
|
user_ids = guild_pubsub.state[guild_id]
|
||||||
for uid in user_ids:
|
for uid in user_ids:
|
||||||
await app.dispatcher.sub('guild', guild_id, uid)
|
await app.dispatcher.sub("guild", guild_id, uid)
|
||||||
|
|
||||||
chan = await app.storage.get_channel(new_channel_id)
|
chan = await app.storage.get_channel(new_channel_id)
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_CREATE", chan)
|
||||||
guild_id, 'CHANNEL_CREATE', chan)
|
|
||||||
return jsonify(chan)
|
return jsonify(chan)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -135,7 +145,7 @@ async def _chan_update_dispatch(guild_id: int, channel_id: int):
|
||||||
"""Fetch new information about the channel and dispatch
|
"""Fetch new information about the channel and dispatch
|
||||||
a single CHANNEL_UPDATE event to the guild."""
|
a single CHANNEL_UPDATE event to the guild."""
|
||||||
chan = await app.storage.get_channel(channel_id)
|
chan = await app.storage.get_channel(channel_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan)
|
await app.dispatcher.dispatch_guild(guild_id, "CHANNEL_UPDATE", chan)
|
||||||
|
|
||||||
|
|
||||||
async def _do_single_swap(guild_id: int, pair: tuple):
|
async def _do_single_swap(guild_id: int, pair: tuple):
|
||||||
|
|
@ -149,13 +159,14 @@ async def _do_single_swap(guild_id: int, pair: tuple):
|
||||||
conn = await app.db.acquire()
|
conn = await app.db.acquire()
|
||||||
|
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
await conn.executemany("""
|
await conn.executemany(
|
||||||
|
"""
|
||||||
UPDATE guild_channels
|
UPDATE guild_channels
|
||||||
SET position = $1
|
SET position = $1
|
||||||
WHERE id = $2 AND guild_id = $3
|
WHERE id = $2 AND guild_id = $3
|
||||||
""", [
|
""",
|
||||||
(new_pos_1, channel_1, guild_id),
|
[(new_pos_1, channel_1, guild_id), (new_pos_2, channel_2, guild_id)],
|
||||||
(new_pos_2, channel_2, guild_id)])
|
)
|
||||||
|
|
||||||
await app.db.release(conn)
|
await app.db.release(conn)
|
||||||
|
|
||||||
|
|
@ -173,30 +184,26 @@ async def _do_channel_swaps(guild_id: int, swap_pairs: list):
|
||||||
await _do_single_swap(guild_id, pair)
|
await _do_single_swap(guild_id, pair)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/channels', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/channels", methods=["PATCH"])
|
||||||
async def modify_channel_pos(guild_id):
|
async def modify_channel_pos(guild_id):
|
||||||
"""Change positions of channels in a guild."""
|
"""Change positions of channels in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_owner_check(user_id, guild_id)
|
await guild_owner_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||||
|
|
||||||
# same thing as guild.roles, so we use
|
# same thing as guild.roles, so we use
|
||||||
# the same schema and all.
|
# the same schema and all.
|
||||||
raw_j = await request.get_json()
|
raw_j = await request.get_json()
|
||||||
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
|
||||||
j = j['roles']
|
j = j["roles"]
|
||||||
|
|
||||||
channels = await app.storage.get_channel_data(guild_id)
|
channels = await app.storage.get_channel_data(guild_id)
|
||||||
|
|
||||||
channel_positions = {chan['position']: int(chan['id'])
|
channel_positions = {chan["position"]: int(chan["id"]) for chan in channels}
|
||||||
for chan in channels}
|
|
||||||
|
|
||||||
swap_pairs = gen_pairs(
|
swap_pairs = gen_pairs(j, channel_positions)
|
||||||
j,
|
|
||||||
channel_positions
|
|
||||||
)
|
|
||||||
|
|
||||||
await _do_channel_swaps(guild_id, swap_pairs)
|
await _do_channel_swaps(guild_id, swap_pairs)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -27,78 +27,85 @@ from litecord.types import KILOBYTES
|
||||||
from litecord.images import parse_data_uri
|
from litecord.images import parse_data_uri
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
|
|
||||||
bp = Blueprint('guild.emoji', __name__)
|
bp = Blueprint("guild.emoji", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch_emojis(guild_id):
|
async def _dispatch_emojis(guild_id):
|
||||||
"""Dispatch a Guild Emojis Update payload to a guild."""
|
"""Dispatch a Guild Emojis Update payload to a guild."""
|
||||||
await app.dispatcher.dispatch('guild', guild_id, 'GUILD_EMOJIS_UPDATE', {
|
await app.dispatcher.dispatch(
|
||||||
'guild_id': str(guild_id),
|
"guild",
|
||||||
'emojis': await app.storage.get_guild_emojis(guild_id)
|
guild_id,
|
||||||
})
|
"GUILD_EMOJIS_UPDATE",
|
||||||
|
{
|
||||||
|
"guild_id": str(guild_id),
|
||||||
|
"emojis": await app.storage.get_guild_emojis(guild_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/emojis', methods=['GET'])
|
@bp.route("/<int:guild_id>/emojis", methods=["GET"])
|
||||||
async def _get_guild_emoji(guild_id):
|
async def _get_guild_emoji(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_guild_emojis(guild_id))
|
||||||
await app.storage.get_guild_emojis(guild_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['GET'])
|
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["GET"])
|
||||||
async def _get_guild_emoji_one(guild_id, emoji_id):
|
async def _get_guild_emoji_one(guild_id, emoji_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_emoji(emoji_id))
|
||||||
await app.storage.get_emoji(emoji_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _guild_emoji_size_check(guild_id: int, mime: str):
|
async def _guild_emoji_size_check(guild_id: int, mime: str):
|
||||||
limit = 50
|
limit = 50
|
||||||
if await app.storage.has_feature(guild_id, 'MORE_EMOJI'):
|
if await app.storage.has_feature(guild_id, "MORE_EMOJI"):
|
||||||
limit = 200
|
limit = 200
|
||||||
|
|
||||||
# NOTE: I'm assuming you can have 200 animated emojis.
|
# NOTE: I'm assuming you can have 200 animated emojis.
|
||||||
select_animated = mime == 'image/gif'
|
select_animated = mime == "image/gif"
|
||||||
|
|
||||||
total_emoji = await app.db.fetchval("""
|
total_emoji = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*) FROM guild_emoji
|
SELECT COUNT(*) FROM guild_emoji
|
||||||
WHERE guild_id = $1 AND animated = $2
|
WHERE guild_id = $1 AND animated = $2
|
||||||
""", guild_id, select_animated)
|
""",
|
||||||
|
guild_id,
|
||||||
|
select_animated,
|
||||||
|
)
|
||||||
|
|
||||||
if total_emoji >= limit:
|
if total_emoji >= limit:
|
||||||
# TODO: really return a BadRequest? needs more looking.
|
# TODO: really return a BadRequest? needs more looking.
|
||||||
raise BadRequest(f'too many emoji ({limit})')
|
raise BadRequest(f"too many emoji ({limit})")
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/emojis', methods=['POST'])
|
@bp.route("/<int:guild_id>/emojis", methods=["POST"])
|
||||||
async def _put_emoji(guild_id):
|
async def _put_emoji(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_emojis')
|
await guild_perm_check(user_id, guild_id, "manage_emojis")
|
||||||
|
|
||||||
j = validate(await request.get_json(), NEW_EMOJI)
|
j = validate(await request.get_json(), NEW_EMOJI)
|
||||||
|
|
||||||
# we have to parse it before passing on so that we know which
|
# we have to parse it before passing on so that we know which
|
||||||
# size to check.
|
# size to check.
|
||||||
mime, _ = parse_data_uri(j['image'])
|
mime, _ = parse_data_uri(j["image"])
|
||||||
await _guild_emoji_size_check(guild_id, mime)
|
await _guild_emoji_size_check(guild_id, mime)
|
||||||
|
|
||||||
emoji_id = get_snowflake()
|
emoji_id = get_snowflake()
|
||||||
|
|
||||||
icon = await app.icons.put(
|
icon = await app.icons.put(
|
||||||
'emoji', emoji_id, j['image'],
|
"emoji",
|
||||||
|
emoji_id,
|
||||||
|
j["image"],
|
||||||
# limits to emojis
|
# limits to emojis
|
||||||
bsize=128 * KILOBYTES, size=(128, 128)
|
bsize=128 * KILOBYTES,
|
||||||
|
size=(128, 128),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not icon:
|
if not icon:
|
||||||
return '', 400
|
return "", 400
|
||||||
|
|
||||||
# TODO: better way to detect animated emoji rather than just gifs,
|
# TODO: better way to detect animated emoji rather than just gifs,
|
||||||
# maybe a list perhaps?
|
# maybe a list perhaps?
|
||||||
|
|
@ -109,25 +116,25 @@ async def _put_emoji(guild_id):
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, $4, $5, $6)
|
($1, $2, $3, $4, $5, $6)
|
||||||
""",
|
""",
|
||||||
emoji_id, guild_id, user_id,
|
emoji_id,
|
||||||
j['name'],
|
guild_id,
|
||||||
|
user_id,
|
||||||
|
j["name"],
|
||||||
icon.icon_hash,
|
icon.icon_hash,
|
||||||
icon.mime == 'image/gif'
|
icon.mime == "image/gif",
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch_emojis(guild_id)
|
await _dispatch_emojis(guild_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_emoji(emoji_id))
|
||||||
await app.storage.get_emoji(emoji_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["PATCH"])
|
||||||
async def _patch_emoji(guild_id, emoji_id):
|
async def _patch_emoji(guild_id, emoji_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_emojis')
|
await guild_perm_check(user_id, guild_id, "manage_emojis")
|
||||||
|
|
||||||
j = validate(await request.get_json(), PATCH_EMOJI)
|
j = validate(await request.get_json(), PATCH_EMOJI)
|
||||||
emoji = await app.storage.get_emoji(emoji_id)
|
emoji = await app.storage.get_emoji(emoji_id)
|
||||||
|
|
@ -135,34 +142,39 @@ async def _patch_emoji(guild_id, emoji_id):
|
||||||
# if emoji.name is still the same, we don't update anything
|
# if emoji.name is still the same, we don't update anything
|
||||||
# or send ane events, just return the same emoji we'd send
|
# or send ane events, just return the same emoji we'd send
|
||||||
# as if we updated it.
|
# as if we updated it.
|
||||||
if j['name'] == emoji['name']:
|
if j["name"] == emoji["name"]:
|
||||||
return jsonify(emoji)
|
return jsonify(emoji)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guild_emoji
|
UPDATE guild_emoji
|
||||||
SET name = $1
|
SET name = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['name'], emoji_id)
|
""",
|
||||||
|
j["name"],
|
||||||
|
emoji_id,
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch_emojis(guild_id)
|
await _dispatch_emojis(guild_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_emoji(emoji_id))
|
||||||
await app.storage.get_emoji(emoji_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/emojis/<int:emoji_id>', methods=['DELETE'])
|
@bp.route("/<int:guild_id>/emojis/<int:emoji_id>", methods=["DELETE"])
|
||||||
async def _del_emoji(guild_id, emoji_id):
|
async def _del_emoji(guild_id, emoji_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_emojis')
|
await guild_perm_check(user_id, guild_id, "manage_emojis")
|
||||||
|
|
||||||
# TODO: check if actually deleted
|
# TODO: check if actually deleted
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM guild_emoji
|
DELETE FROM guild_emoji
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", emoji_id)
|
""",
|
||||||
|
emoji_id,
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch_emojis(guild_id)
|
await _dispatch_emojis(guild_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -22,18 +22,14 @@ from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from litecord.blueprints.auth import token_check
|
from litecord.blueprints.auth import token_check
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
|
|
||||||
from litecord.schemas import (
|
from litecord.schemas import validate, MEMBER_UPDATE
|
||||||
validate, MEMBER_UPDATE
|
|
||||||
)
|
|
||||||
|
|
||||||
from litecord.blueprints.checks import (
|
from litecord.blueprints.checks import guild_check, guild_owner_check, guild_perm_check
|
||||||
guild_check, guild_owner_check, guild_perm_check
|
|
||||||
)
|
|
||||||
|
|
||||||
bp = Blueprint('guild_members', __name__)
|
bp = Blueprint("guild_members", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['GET'])
|
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["GET"])
|
||||||
async def get_guild_member(guild_id, member_id):
|
async def get_guild_member(guild_id, member_id):
|
||||||
"""Get a member's information in a guild."""
|
"""Get a member's information in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -42,7 +38,7 @@ async def get_guild_member(guild_id, member_id):
|
||||||
return jsonify(member)
|
return jsonify(member)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members', methods=['GET'])
|
@bp.route("/<int:guild_id>/members", methods=["GET"])
|
||||||
async def get_members(guild_id):
|
async def get_members(guild_id):
|
||||||
"""Get members inside a guild."""
|
"""Get members inside a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -50,34 +46,41 @@ async def get_members(guild_id):
|
||||||
|
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
|
|
||||||
limit, after = int(j.get('limit', 1)), j.get('after', 0)
|
limit, after = int(j.get("limit", 1)), j.get("after", 0)
|
||||||
|
|
||||||
if limit < 1 or limit > 1000:
|
if limit < 1 or limit > 1000:
|
||||||
raise BadRequest('limit not in 1-1000 range')
|
raise BadRequest("limit not in 1-1000 range")
|
||||||
|
|
||||||
user_ids = await app.db.fetch(f"""
|
user_ids = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT user_id
|
SELECT user_id
|
||||||
WHERE guild_id = $1, user_id > $2
|
WHERE guild_id = $1, user_id > $2
|
||||||
LIMIT {limit}
|
LIMIT {limit}
|
||||||
ORDER BY user_id ASC
|
ORDER BY user_id ASC
|
||||||
""", guild_id, after)
|
""",
|
||||||
|
guild_id,
|
||||||
|
after,
|
||||||
|
)
|
||||||
|
|
||||||
user_ids = [r[0] for r in user_ids]
|
user_ids = [r[0] for r in user_ids]
|
||||||
members = await app.storage.get_member_multi(guild_id, user_ids)
|
members = await app.storage.get_member_multi(guild_id, user_ids)
|
||||||
return jsonify(members)
|
return jsonify(members)
|
||||||
|
|
||||||
|
|
||||||
async def _update_member_roles(guild_id: int, member_id: int,
|
async def _update_member_roles(guild_id: int, member_id: int, wanted_roles: set):
|
||||||
wanted_roles: set):
|
|
||||||
"""Update the roles a member has."""
|
"""Update the roles a member has."""
|
||||||
|
|
||||||
# first, fetch all current roles
|
# first, fetch all current roles
|
||||||
roles = await app.db.fetch("""
|
roles = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT role_id from member_roles
|
SELECT role_id from member_roles
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, member_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
member_id,
|
||||||
|
)
|
||||||
|
|
||||||
roles = [r['role_id'] for r in roles]
|
roles = [r["role_id"] for r in roles]
|
||||||
|
|
||||||
roles = set(roles)
|
roles = set(roles)
|
||||||
wanted_roles = set(wanted_roles)
|
wanted_roles = set(wanted_roles)
|
||||||
|
|
@ -96,26 +99,30 @@ async def _update_member_roles(guild_id: int, member_id: int,
|
||||||
|
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
# add roles
|
# add roles
|
||||||
await app.db.executemany("""
|
await app.db.executemany(
|
||||||
|
"""
|
||||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", [(member_id, guild_id, role_id)
|
""",
|
||||||
for role_id in added_roles])
|
[(member_id, guild_id, role_id) for role_id in added_roles],
|
||||||
|
)
|
||||||
|
|
||||||
# remove roles
|
# remove roles
|
||||||
await app.db.executemany("""
|
await app.db.executemany(
|
||||||
|
"""
|
||||||
DELETE FROM member_roles
|
DELETE FROM member_roles
|
||||||
WHERE
|
WHERE
|
||||||
user_id = $1
|
user_id = $1
|
||||||
AND guild_id = $2
|
AND guild_id = $2
|
||||||
AND role_id = $3
|
AND role_id = $3
|
||||||
""", [(member_id, guild_id, role_id)
|
""",
|
||||||
for role_id in removed_roles])
|
[(member_id, guild_id, role_id) for role_id in removed_roles],
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.release(conn)
|
await app.db.release(conn)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["PATCH"])
|
||||||
async def modify_guild_member(guild_id, member_id):
|
async def modify_guild_member(guild_id, member_id):
|
||||||
"""Modify a members' information in a guild."""
|
"""Modify a members' information in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -124,96 +131,112 @@ async def modify_guild_member(guild_id, member_id):
|
||||||
j = validate(await request.get_json(), MEMBER_UPDATE)
|
j = validate(await request.get_json(), MEMBER_UPDATE)
|
||||||
nick_flag = False
|
nick_flag = False
|
||||||
|
|
||||||
if 'nick' in j:
|
if "nick" in j:
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_nicknames')
|
await guild_perm_check(user_id, guild_id, "manage_nicknames")
|
||||||
|
|
||||||
nick = j['nick'] or None
|
nick = j["nick"] or None
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE members
|
UPDATE members
|
||||||
SET nickname = $1
|
SET nickname = $1
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
""", nick, member_id, guild_id)
|
""",
|
||||||
|
nick,
|
||||||
|
member_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
nick_flag = True
|
nick_flag = True
|
||||||
|
|
||||||
if 'mute' in j:
|
if "mute" in j:
|
||||||
await guild_perm_check(user_id, guild_id, 'mute_members')
|
await guild_perm_check(user_id, guild_id, "mute_members")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE members
|
UPDATE members
|
||||||
SET muted = $1
|
SET muted = $1
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
""", j['mute'], member_id, guild_id)
|
""",
|
||||||
|
j["mute"],
|
||||||
|
member_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'deaf' in j:
|
if "deaf" in j:
|
||||||
await guild_perm_check(user_id, guild_id, 'deafen_members')
|
await guild_perm_check(user_id, guild_id, "deafen_members")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE members
|
UPDATE members
|
||||||
SET deafened = $1
|
SET deafened = $1
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
""", j['deaf'], member_id, guild_id)
|
""",
|
||||||
|
j["deaf"],
|
||||||
|
member_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'channel_id' in j:
|
if "channel_id" in j:
|
||||||
# TODO: check MOVE_MEMBERS and CONNECT to the channel
|
# TODO: check MOVE_MEMBERS and CONNECT to the channel
|
||||||
# TODO: change the member's voice channel
|
# TODO: change the member's voice channel
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if 'roles' in j:
|
if "roles" in j:
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||||
await _update_member_roles(guild_id, member_id, j['roles'])
|
await _update_member_roles(guild_id, member_id, j["roles"])
|
||||||
|
|
||||||
member = await app.storage.get_member_data_one(guild_id, member_id)
|
member = await app.storage.get_member_data_one(guild_id, member_id)
|
||||||
member.pop('joined_at')
|
member.pop("joined_at")
|
||||||
|
|
||||||
# call pres_update for role and nick changes.
|
# call pres_update for role and nick changes.
|
||||||
partial = {
|
partial = {"roles": member["roles"]}
|
||||||
'roles': member['roles']
|
|
||||||
}
|
|
||||||
|
|
||||||
if nick_flag:
|
if nick_flag:
|
||||||
partial['nick'] = j['nick']
|
partial["nick"] = j["nick"]
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'lazy_guild', guild_id, 'pres_update', user_id, partial)
|
"lazy_guild", guild_id, "pres_update", user_id, partial
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id)
|
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||||
}, **member})
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/@me/nick', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/members/@me/nick", methods=["PATCH"])
|
||||||
async def update_nickname(guild_id):
|
async def update_nickname(guild_id):
|
||||||
"""Update a member's nickname in a guild."""
|
"""Update a member's nickname in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
j = validate(await request.get_json(), {
|
j = validate(await request.get_json(), {"nick": {"type": "nickname"}})
|
||||||
'nick': {'type': 'nickname'}
|
|
||||||
})
|
|
||||||
|
|
||||||
nick = j['nick'] or None
|
nick = j["nick"] or None
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE members
|
UPDATE members
|
||||||
SET nickname = $1
|
SET nickname = $1
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
""", nick, user_id, guild_id)
|
""",
|
||||||
|
nick,
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||||
member.pop('joined_at')
|
member.pop("joined_at")
|
||||||
|
|
||||||
# call pres_update for nick changes, etc.
|
# call pres_update for nick changes, etc.
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch(
|
||||||
'lazy_guild', guild_id, 'pres_update', user_id, {
|
"lazy_guild", guild_id, "pres_update", user_id, {"nick": j["nick"]}
|
||||||
'nick': j['nick']
|
)
|
||||||
})
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id)
|
guild_id, "GUILD_MEMBER_UPDATE", {**{"guild_id": str(guild_id)}, **member}
|
||||||
}, **member})
|
)
|
||||||
|
|
||||||
return j['nick']
|
return j["nick"]
|
||||||
|
|
|
||||||
|
|
@ -24,33 +24,38 @@ from litecord.blueprints.checks import guild_perm_check
|
||||||
|
|
||||||
from litecord.schemas import validate, GUILD_PRUNE
|
from litecord.schemas import validate, GUILD_PRUNE
|
||||||
|
|
||||||
bp = Blueprint('guild_moderation', __name__)
|
bp = Blueprint("guild_moderation", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def remove_member(guild_id: int, member_id: int):
|
async def remove_member(guild_id: int, member_id: int):
|
||||||
"""Do common tasks related to deleting a member from the guild,
|
"""Do common tasks related to deleting a member from the guild,
|
||||||
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE."""
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM members
|
DELETE FROM members
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, member_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
member_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user_guild(
|
await app.dispatcher.dispatch_user_guild(
|
||||||
member_id, guild_id, 'GUILD_DELETE', {
|
member_id,
|
||||||
'guild_id': str(guild_id),
|
guild_id,
|
||||||
'unavailable': False,
|
"GUILD_DELETE",
|
||||||
})
|
{"guild_id": str(guild_id), "unavailable": False},
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.unsub('guild', guild_id, member_id)
|
await app.dispatcher.unsub("guild", guild_id, member_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("lazy_guild", guild_id, "remove_member", member_id)
|
||||||
'lazy_guild', guild_id, 'remove_member', member_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', {
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id),
|
guild_id,
|
||||||
'user': await app.storage.get_user(member_id),
|
"GUILD_MEMBER_REMOVE",
|
||||||
})
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def remove_member_multi(guild_id: int, members: list):
|
async def remove_member_multi(guild_id: int, members: list):
|
||||||
|
|
@ -59,84 +64,100 @@ async def remove_member_multi(guild_id: int, members: list):
|
||||||
await remove_member(guild_id, member_id)
|
await remove_member(guild_id, member_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/members/<int:member_id>', methods=['DELETE'])
|
@bp.route("/<int:guild_id>/members/<int:member_id>", methods=["DELETE"])
|
||||||
async def kick_guild_member(guild_id, member_id):
|
async def kick_guild_member(guild_id, member_id):
|
||||||
"""Remove a member from a guild."""
|
"""Remove a member from a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'kick_members')
|
await guild_perm_check(user_id, guild_id, "kick_members")
|
||||||
await remove_member(guild_id, member_id)
|
await remove_member(guild_id, member_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/bans', methods=['GET'])
|
@bp.route("/<int:guild_id>/bans", methods=["GET"])
|
||||||
async def get_bans(guild_id):
|
async def get_bans(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'ban_members')
|
await guild_perm_check(user_id, guild_id, "ban_members")
|
||||||
|
|
||||||
bans = await app.db.fetch("""
|
bans = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT user_id, reason
|
SELECT user_id, reason
|
||||||
FROM bans
|
FROM bans
|
||||||
WHERE bans.guild_id = $1
|
WHERE bans.guild_id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for ban in bans:
|
for ban in bans:
|
||||||
res.append({
|
res.append(
|
||||||
'reason': ban['reason'],
|
{
|
||||||
'user': await app.storage.get_user(ban['user_id'])
|
"reason": ban["reason"],
|
||||||
})
|
"user": await app.storage.get_user(ban["user_id"]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/bans/<int:member_id>', methods=['PUT'])
|
@bp.route("/<int:guild_id>/bans/<int:member_id>", methods=["PUT"])
|
||||||
async def create_ban(guild_id, member_id):
|
async def create_ban(guild_id, member_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'ban_members')
|
await guild_perm_check(user_id, guild_id, "ban_members")
|
||||||
|
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO bans (guild_id, user_id, reason)
|
INSERT INTO bans (guild_id, user_id, reason)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", guild_id, member_id, j.get('reason', ''))
|
""",
|
||||||
|
guild_id,
|
||||||
|
member_id,
|
||||||
|
j.get("reason", ""),
|
||||||
|
)
|
||||||
|
|
||||||
await remove_member(guild_id, member_id)
|
await remove_member(guild_id, member_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id),
|
guild_id,
|
||||||
'user': await app.storage.get_user(member_id)
|
"GUILD_BAN_ADD",
|
||||||
})
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(member_id)},
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/bans/<int:banned_id>', methods=['DELETE'])
|
@bp.route("/<int:guild_id>/bans/<int:banned_id>", methods=["DELETE"])
|
||||||
async def remove_ban(guild_id, banned_id):
|
async def remove_ban(guild_id, banned_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'ban_members')
|
await guild_perm_check(user_id, guild_id, "ban_members")
|
||||||
|
|
||||||
res = await app.db.execute("""
|
res = await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM bans
|
DELETE FROM bans
|
||||||
WHERE guild_id = $1 AND user_id = $@
|
WHERE guild_id = $1 AND user_id = $@
|
||||||
""", guild_id, banned_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
banned_id,
|
||||||
|
)
|
||||||
|
|
||||||
# we don't really need to dispatch GUILD_BAN_REMOVE
|
# we don't really need to dispatch GUILD_BAN_REMOVE
|
||||||
# when no bans were actually removed.
|
# when no bans were actually removed.
|
||||||
if res == 'DELETE 0':
|
if res == "DELETE 0":
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', {
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id),
|
guild_id,
|
||||||
'user': await app.storage.get_user(banned_id)
|
"GUILD_BAN_REMOVE",
|
||||||
})
|
{"guild_id": str(guild_id), "user": await app.storage.get_user(banned_id)},
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def get_prune(guild_id: int, days: int) -> list:
|
async def get_prune(guild_id: int, days: int) -> list:
|
||||||
|
|
@ -146,23 +167,30 @@ async def get_prune(guild_id: int, days: int) -> list:
|
||||||
- don't have any roles.
|
- don't have any roles.
|
||||||
"""
|
"""
|
||||||
# a good solution would be in pure sql.
|
# a good solution would be in pure sql.
|
||||||
member_ids = await app.db.fetch(f"""
|
member_ids = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM users
|
FROM users
|
||||||
JOIN members
|
JOIN members
|
||||||
ON members.guild_id = $1 AND members.user_id = users.id
|
ON members.guild_id = $1 AND members.user_id = users.id
|
||||||
WHERE users.last_session < (now() - (interval '{days} days'))
|
WHERE users.last_session < (now() - (interval '{days} days'))
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
member_ids = [r['id'] for r in member_ids]
|
member_ids = [r["id"] for r in member_ids]
|
||||||
members = []
|
members = []
|
||||||
|
|
||||||
for member_id in member_ids:
|
for member_id in member_ids:
|
||||||
role_count = await app.db.fetchval("""
|
role_count = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM member_roles
|
FROM member_roles
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, member_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
member_id,
|
||||||
|
)
|
||||||
|
|
||||||
if role_count == 0:
|
if role_count == 0:
|
||||||
members.append(member_id)
|
members.append(member_id)
|
||||||
|
|
@ -170,33 +198,29 @@ async def get_prune(guild_id: int, days: int) -> list:
|
||||||
return members
|
return members
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/prune', methods=['GET'])
|
@bp.route("/<int:guild_id>/prune", methods=["GET"])
|
||||||
async def get_guild_prune_count(guild_id):
|
async def get_guild_prune_count(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'kick_members')
|
await guild_perm_check(user_id, guild_id, "kick_members")
|
||||||
|
|
||||||
j = validate(request.args, GUILD_PRUNE)
|
j = validate(request.args, GUILD_PRUNE)
|
||||||
days = j['days']
|
days = j["days"]
|
||||||
member_ids = await get_prune(guild_id, days)
|
member_ids = await get_prune(guild_id, days)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"pruned": len(member_ids)})
|
||||||
'pruned': len(member_ids),
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/prune', methods=['POST'])
|
@bp.route("/<int:guild_id>/prune", methods=["POST"])
|
||||||
async def begin_guild_prune(guild_id):
|
async def begin_guild_prune(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'kick_members')
|
await guild_perm_check(user_id, guild_id, "kick_members")
|
||||||
|
|
||||||
j = validate(request.args, GUILD_PRUNE)
|
j = validate(request.args, GUILD_PRUNE)
|
||||||
days = j['days']
|
days = j["days"]
|
||||||
member_ids = await get_prune(guild_id, days)
|
member_ids = await get_prune(guild_id, days)
|
||||||
|
|
||||||
app.loop.create_task(remove_member_multi(guild_id, member_ids))
|
app.loop.create_task(remove_member_multi(guild_id, member_ids))
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({"pruned": len(member_ids)})
|
||||||
'pruned': len(member_ids)
|
|
||||||
})
|
|
||||||
|
|
|
||||||
|
|
@ -24,12 +24,8 @@ from logbook import Logger
|
||||||
|
|
||||||
from litecord.auth import token_check
|
from litecord.auth import token_check
|
||||||
|
|
||||||
from litecord.blueprints.checks import (
|
from litecord.blueprints.checks import guild_check, guild_perm_check
|
||||||
guild_check, guild_perm_check
|
from litecord.schemas import validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
||||||
)
|
|
||||||
from litecord.schemas import (
|
|
||||||
validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION
|
|
||||||
)
|
|
||||||
|
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
from litecord.utils import dict_get
|
from litecord.utils import dict_get
|
||||||
|
|
@ -37,22 +33,19 @@ from litecord.permissions import get_role_perms
|
||||||
|
|
||||||
DEFAULT_EVERYONE_PERMS = 104324161
|
DEFAULT_EVERYONE_PERMS = 104324161
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('guild_roles', __name__)
|
bp = Blueprint("guild_roles", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/roles', methods=['GET'])
|
@bp.route("/<int:guild_id>/roles", methods=["GET"])
|
||||||
async def get_guild_roles(guild_id):
|
async def get_guild_roles(guild_id):
|
||||||
"""Get all roles in a guild."""
|
"""Get all roles in a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_role_data(guild_id))
|
||||||
await app.storage.get_role_data(guild_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_lg(guild_id: int, event: str,
|
async def _maybe_lg(guild_id: int, event: str, role, force: bool = False):
|
||||||
role, force: bool = False):
|
|
||||||
# sometimes we want to dispatch an event
|
# sometimes we want to dispatch an event
|
||||||
# even if the role isn't hoisted
|
# even if the role isn't hoisted
|
||||||
|
|
||||||
|
|
@ -61,11 +54,10 @@ async def _maybe_lg(guild_id: int, event: str,
|
||||||
|
|
||||||
# check if is a dict first because role_delete
|
# check if is a dict first because role_delete
|
||||||
# only receives the role id.
|
# only receives the role id.
|
||||||
if isinstance(role, dict) and not role['hoist'] and not force:
|
if isinstance(role, dict) and not role["hoist"] and not force:
|
||||||
return
|
return
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("lazy_guild", guild_id, event, role)
|
||||||
'lazy_guild', guild_id, event, role)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_role(guild_id, name: str, **kwargs):
|
async def create_role(guild_id, name: str, **kwargs):
|
||||||
|
|
@ -73,18 +65,20 @@ async def create_role(guild_id, name: str, **kwargs):
|
||||||
new_role_id = get_snowflake()
|
new_role_id = get_snowflake()
|
||||||
|
|
||||||
everyone_perms = await get_role_perms(guild_id, guild_id)
|
everyone_perms = await get_role_perms(guild_id, guild_id)
|
||||||
default_perms = dict_get(kwargs, 'default_perms',
|
default_perms = dict_get(kwargs, "default_perms", everyone_perms.binary)
|
||||||
everyone_perms.binary)
|
|
||||||
|
|
||||||
# update all roles so that we have space for pos 1, but without
|
# update all roles so that we have space for pos 1, but without
|
||||||
# sending GUILD_ROLE_UPDATE for everyone
|
# sending GUILD_ROLE_UPDATE for everyone
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE roles
|
UPDATE roles
|
||||||
SET
|
SET
|
||||||
position = position + 1
|
position = position + 1
|
||||||
WHERE guild_id = $1
|
WHERE guild_id = $1
|
||||||
AND NOT (position = 0)
|
AND NOT (position = 0)
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
|
|
@ -95,42 +89,39 @@ async def create_role(guild_id, name: str, **kwargs):
|
||||||
new_role_id,
|
new_role_id,
|
||||||
guild_id,
|
guild_id,
|
||||||
name,
|
name,
|
||||||
dict_get(kwargs, 'color', 0),
|
dict_get(kwargs, "color", 0),
|
||||||
dict_get(kwargs, 'hoist', False),
|
dict_get(kwargs, "hoist", False),
|
||||||
|
|
||||||
# always set ourselves on position 1
|
# always set ourselves on position 1
|
||||||
1,
|
1,
|
||||||
int(dict_get(kwargs, 'permissions', default_perms)),
|
int(dict_get(kwargs, "permissions", default_perms)),
|
||||||
False,
|
False,
|
||||||
dict_get(kwargs, 'mentionable', False)
|
dict_get(kwargs, "mentionable", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
role = await app.storage.get_role(new_role_id, guild_id)
|
role = await app.storage.get_role(new_role_id, guild_id)
|
||||||
|
|
||||||
# we need to update the lazy guild handlers for the newly created group
|
# we need to update the lazy guild handlers for the newly created group
|
||||||
await _maybe_lg(guild_id, 'new_role', role)
|
await _maybe_lg(guild_id, "new_role", role)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
await app.dispatcher.dispatch_guild(
|
||||||
guild_id, 'GUILD_ROLE_CREATE', {
|
guild_id, "GUILD_ROLE_CREATE", {"guild_id": str(guild_id), "role": role}
|
||||||
'guild_id': str(guild_id),
|
)
|
||||||
'role': role,
|
|
||||||
})
|
|
||||||
|
|
||||||
return role
|
return role
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/roles', methods=['POST'])
|
@bp.route("/<int:guild_id>/roles", methods=["POST"])
|
||||||
async def create_guild_role(guild_id: int):
|
async def create_guild_role(guild_id: int):
|
||||||
"""Add a role to a guild"""
|
"""Add a role to a guild"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||||
|
|
||||||
# client can just send null
|
# client can just send null
|
||||||
j = validate(await request.get_json() or {}, ROLE_CREATE)
|
j = validate(await request.get_json() or {}, ROLE_CREATE)
|
||||||
|
|
||||||
role_name = j['name']
|
role_name = j["name"]
|
||||||
j.pop('name')
|
j.pop("name")
|
||||||
|
|
||||||
role = await create_role(guild_id, role_name, **j)
|
role = await create_role(guild_id, role_name, **j)
|
||||||
|
|
||||||
|
|
@ -141,12 +132,11 @@ async def _role_update_dispatch(role_id: int, guild_id: int):
|
||||||
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
"""Dispatch a GUILD_ROLE_UPDATE with updated information on a role."""
|
||||||
role = await app.storage.get_role(role_id, guild_id)
|
role = await app.storage.get_role(role_id, guild_id)
|
||||||
|
|
||||||
await _maybe_lg(guild_id, 'role_pos_upd', role)
|
await _maybe_lg(guild_id, "role_pos_upd", role)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', {
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id),
|
guild_id, "GUILD_ROLE_UPDATE", {"guild_id": str(guild_id), "role": role}
|
||||||
'role': role,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return role
|
return role
|
||||||
|
|
||||||
|
|
@ -166,17 +156,25 @@ async def _role_pairs_update(guild_id: int, pairs: list):
|
||||||
async with conn.transaction():
|
async with conn.transaction():
|
||||||
# update happens in a transaction
|
# update happens in a transaction
|
||||||
# so we don't fuck it up
|
# so we don't fuck it up
|
||||||
await conn.execute("""
|
await conn.execute(
|
||||||
|
"""
|
||||||
UPDATE roles
|
UPDATE roles
|
||||||
SET position = $1
|
SET position = $1
|
||||||
WHERE roles.id = $2
|
WHERE roles.id = $2
|
||||||
""", new_pos_1, role_1)
|
""",
|
||||||
|
new_pos_1,
|
||||||
|
role_1,
|
||||||
|
)
|
||||||
|
|
||||||
await conn.execute("""
|
await conn.execute(
|
||||||
|
"""
|
||||||
UPDATE roles
|
UPDATE roles
|
||||||
SET position = $1
|
SET position = $1
|
||||||
WHERE roles.id = $2
|
WHERE roles.id = $2
|
||||||
""", new_pos_2, role_2)
|
""",
|
||||||
|
new_pos_2,
|
||||||
|
role_2,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.release(conn)
|
await app.db.release(conn)
|
||||||
|
|
||||||
|
|
@ -184,11 +182,15 @@ async def _role_pairs_update(guild_id: int, pairs: list):
|
||||||
await _role_update_dispatch(role_1, guild_id)
|
await _role_update_dispatch(role_1, guild_id)
|
||||||
await _role_update_dispatch(role_2, guild_id)
|
await _role_update_dispatch(role_2, guild_id)
|
||||||
|
|
||||||
|
|
||||||
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
|
PairList = List[Tuple[Tuple[int, int], Tuple[int, int]]]
|
||||||
|
|
||||||
def gen_pairs(list_of_changes: List[Dict[str, int]],
|
|
||||||
|
def gen_pairs(
|
||||||
|
list_of_changes: List[Dict[str, int]],
|
||||||
current_state: Dict[int, int],
|
current_state: Dict[int, int],
|
||||||
blacklist: List[int] = None) -> PairList:
|
blacklist: List[int] = None,
|
||||||
|
) -> PairList:
|
||||||
"""Generate a list of pairs that, when applied to the database,
|
"""Generate a list of pairs that, when applied to the database,
|
||||||
will generate the desired state given in list_of_changes.
|
will generate the desired state given in list_of_changes.
|
||||||
|
|
||||||
|
|
@ -230,8 +232,9 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||||
pairs = []
|
pairs = []
|
||||||
blacklist = blacklist or []
|
blacklist = blacklist or []
|
||||||
|
|
||||||
preferred_state = {element['id']: element['position']
|
preferred_state = {
|
||||||
for element in list_of_changes}
|
element["id"]: element["position"] for element in list_of_changes
|
||||||
|
}
|
||||||
|
|
||||||
for blacklisted_id in blacklist:
|
for blacklisted_id in blacklist:
|
||||||
preferred_state.pop(blacklisted_id)
|
preferred_state.pop(blacklisted_id)
|
||||||
|
|
@ -239,7 +242,7 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||||
# for each change, we must find a matching change
|
# for each change, we must find a matching change
|
||||||
# in the same list, so we can make a swap pair
|
# in the same list, so we can make a swap pair
|
||||||
for change in list_of_changes:
|
for change in list_of_changes:
|
||||||
element_1, new_pos_1 = change['id'], change['position']
|
element_1, new_pos_1 = change["id"], change["position"]
|
||||||
|
|
||||||
# check current pairs
|
# check current pairs
|
||||||
# so we don't repeat an element
|
# so we don't repeat an element
|
||||||
|
|
@ -267,36 +270,34 @@ def gen_pairs(list_of_changes: List[Dict[str, int]],
|
||||||
# if its being swapped to leave space, add it
|
# if its being swapped to leave space, add it
|
||||||
# to the pairs list
|
# to the pairs list
|
||||||
if new_pos_2 is not None:
|
if new_pos_2 is not None:
|
||||||
pairs.append(
|
pairs.append(((element_1, new_pos_1), (element_2, new_pos_2)))
|
||||||
((element_1, new_pos_1), (element_2, new_pos_2))
|
|
||||||
)
|
|
||||||
|
|
||||||
return pairs
|
return pairs
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/roles', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/roles", methods=["PATCH"])
|
||||||
async def update_guild_role_positions(guild_id):
|
async def update_guild_role_positions(guild_id):
|
||||||
"""Update the positions for a bunch of roles."""
|
"""Update the positions for a bunch of roles."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||||
|
|
||||||
raw_j = await request.get_json()
|
raw_j = await request.get_json()
|
||||||
|
|
||||||
# we need to do this hackiness because thats
|
# we need to do this hackiness because thats
|
||||||
# cerberus for ya.
|
# cerberus for ya.
|
||||||
j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION)
|
j = validate({"roles": raw_j}, ROLE_UPDATE_POSITION)
|
||||||
|
|
||||||
# extract the list out
|
# extract the list out
|
||||||
j = j['roles']
|
j = j["roles"]
|
||||||
|
|
||||||
log.debug('role stuff: {!r}', j)
|
log.debug("role stuff: {!r}", j)
|
||||||
|
|
||||||
all_roles = await app.storage.get_role_data(guild_id)
|
all_roles = await app.storage.get_role_data(guild_id)
|
||||||
|
|
||||||
# we'll have to calculate pairs of changing roles,
|
# we'll have to calculate pairs of changing roles,
|
||||||
# then do the changes, etc.
|
# then do the changes, etc.
|
||||||
roles_pos = {role['position']: int(role['id']) for role in all_roles}
|
roles_pos = {role["position"]: int(role["id"]) for role in all_roles}
|
||||||
|
|
||||||
# TODO: check if the user can even change the roles in the first place,
|
# TODO: check if the user can even change the roles in the first place,
|
||||||
# preferrably when we have a proper perms system.
|
# preferrably when we have a proper perms system.
|
||||||
|
|
@ -306,10 +307,9 @@ async def update_guild_role_positions(guild_id):
|
||||||
pairs = gen_pairs(
|
pairs = gen_pairs(
|
||||||
j,
|
j,
|
||||||
roles_pos,
|
roles_pos,
|
||||||
|
|
||||||
# always ignore people trying to change
|
# always ignore people trying to change
|
||||||
# the @everyone's role position
|
# the @everyone's role position
|
||||||
[guild_id]
|
[guild_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
await _role_pairs_update(guild_id, pairs)
|
await _role_pairs_update(guild_id, pairs)
|
||||||
|
|
@ -318,31 +318,36 @@ async def update_guild_role_positions(guild_id):
|
||||||
return jsonify(await app.storage.get_role_data(guild_id))
|
return jsonify(await app.storage.get_role_data(guild_id))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["PATCH"])
|
||||||
async def update_guild_role(guild_id, role_id):
|
async def update_guild_role(guild_id, role_id):
|
||||||
"""Update a single role's information."""
|
"""Update a single role's information."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||||
|
|
||||||
j = validate(await request.get_json(), ROLE_UPDATE)
|
j = validate(await request.get_json(), ROLE_UPDATE)
|
||||||
|
|
||||||
# we only update ints on the db, not Permissions
|
# we only update ints on the db, not Permissions
|
||||||
j['permissions'] = int(j['permissions'])
|
j["permissions"] = int(j["permissions"])
|
||||||
|
|
||||||
for field in j:
|
for field in j:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE roles
|
UPDATE roles
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE roles.id = $2 AND roles.guild_id = $3
|
WHERE roles.id = $2 AND roles.guild_id = $3
|
||||||
""", j[field], role_id, guild_id)
|
""",
|
||||||
|
j[field],
|
||||||
|
role_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
role = await _role_update_dispatch(role_id, guild_id)
|
role = await _role_update_dispatch(role_id, guild_id)
|
||||||
await _maybe_lg(guild_id, 'role_update', role, True)
|
await _maybe_lg(guild_id, "role_update", role, True)
|
||||||
return jsonify(role)
|
return jsonify(role)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/roles/<int:role_id>', methods=['DELETE'])
|
@bp.route("/<int:guild_id>/roles/<int:role_id>", methods=["DELETE"])
|
||||||
async def delete_guild_role(guild_id, role_id):
|
async def delete_guild_role(guild_id, role_id):
|
||||||
"""Delete a role.
|
"""Delete a role.
|
||||||
|
|
||||||
|
|
@ -350,21 +355,26 @@ async def delete_guild_role(guild_id, role_id):
|
||||||
"""
|
"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_roles')
|
await guild_perm_check(user_id, guild_id, "manage_roles")
|
||||||
|
|
||||||
res = await app.db.execute("""
|
res = await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM roles
|
DELETE FROM roles
|
||||||
WHERE guild_id = $1 AND id = $2
|
WHERE guild_id = $1 AND id = $2
|
||||||
""", guild_id, role_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
role_id,
|
||||||
|
)
|
||||||
|
|
||||||
if res == 'DELETE 0':
|
if res == "DELETE 0":
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
await _maybe_lg(guild_id, 'role_delete', role_id, True)
|
await _maybe_lg(guild_id, "role_delete", role_id, True)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', {
|
await app.dispatcher.dispatch_guild(
|
||||||
'guild_id': str(guild_id),
|
guild_id,
|
||||||
'role_id': str(role_id),
|
"GUILD_ROLE_DELETE",
|
||||||
})
|
{"guild_id": str(guild_id), "role_id": str(role_id)},
|
||||||
|
)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -22,16 +22,17 @@ from typing import Optional, List
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
|
|
||||||
from litecord.blueprints.guild.channels import create_guild_channel
|
from litecord.blueprints.guild.channels import create_guild_channel
|
||||||
from litecord.blueprints.guild.roles import (
|
from litecord.blueprints.guild.roles import create_role, DEFAULT_EVERYONE_PERMS
|
||||||
create_role, DEFAULT_EVERYONE_PERMS
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..auth import token_check
|
from ..auth import token_check
|
||||||
from ..snowflake import get_snowflake
|
from ..snowflake import get_snowflake
|
||||||
from ..enums import ChannelType
|
from ..enums import ChannelType
|
||||||
from ..schemas import (
|
from ..schemas import (
|
||||||
validate, GUILD_CREATE, GUILD_UPDATE, SEARCH_CHANNEL,
|
validate,
|
||||||
VANITY_URL_PATCH
|
GUILD_CREATE,
|
||||||
|
GUILD_UPDATE,
|
||||||
|
SEARCH_CHANNEL,
|
||||||
|
VANITY_URL_PATCH,
|
||||||
)
|
)
|
||||||
from .channels import channel_ack
|
from .channels import channel_ack
|
||||||
from .checks import guild_check, guild_owner_check, guild_perm_check
|
from .checks import guild_check, guild_owner_check, guild_perm_check
|
||||||
|
|
@ -40,7 +41,7 @@ from litecord.errors import BadRequest
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint('guilds', __name__)
|
bp = Blueprint("guilds", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def create_guild_settings(guild_id: int, user_id: int):
|
async def create_guild_settings(guild_id: int, user_id: int):
|
||||||
|
|
@ -49,26 +50,38 @@ async def create_guild_settings(guild_id: int, user_id: int):
|
||||||
|
|
||||||
# new guild_settings are based off the currently
|
# new guild_settings are based off the currently
|
||||||
# set guild settings (for the guild)
|
# set guild settings (for the guild)
|
||||||
m_notifs = await app.db.fetchval("""
|
m_notifs = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT default_message_notifications
|
SELECT default_message_notifications
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO guild_settings
|
INSERT INTO guild_settings
|
||||||
(user_id, guild_id, message_notifications)
|
(user_id, guild_id, message_notifications)
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3)
|
($1, $2, $3)
|
||||||
""", user_id, guild_id, m_notifs)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
m_notifs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def add_member(guild_id: int, user_id: int):
|
async def add_member(guild_id: int, user_id: int):
|
||||||
"""Add a user to a guild."""
|
"""Add a user to a guild."""
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO members (user_id, guild_id)
|
INSERT INTO members (user_id, guild_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
await create_guild_settings(guild_id, user_id)
|
await create_guild_settings(guild_id, user_id)
|
||||||
|
|
||||||
|
|
@ -83,28 +96,28 @@ async def guild_create_roles_prep(guild_id: int, roles: list):
|
||||||
# are patches to the @everyone role
|
# are patches to the @everyone role
|
||||||
everyone_patches = roles[0]
|
everyone_patches = roles[0]
|
||||||
for field in everyone_patches:
|
for field in everyone_patches:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE roles
|
UPDATE roles
|
||||||
SET {field}={everyone_patches[field]}
|
SET {field}={everyone_patches[field]}
|
||||||
WHERE roles.id = $1
|
WHERE roles.id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
default_perms = (everyone_patches.get('permissions')
|
default_perms = everyone_patches.get("permissions") or DEFAULT_EVERYONE_PERMS
|
||||||
or DEFAULT_EVERYONE_PERMS)
|
|
||||||
|
|
||||||
# from the 2nd and forward,
|
# from the 2nd and forward,
|
||||||
# should be treated as new roles
|
# should be treated as new roles
|
||||||
for role in roles[1:]:
|
for role in roles[1:]:
|
||||||
await create_role(
|
await create_role(guild_id, role["name"], default_perms=default_perms, **role)
|
||||||
guild_id, role['name'], default_perms=default_perms, **role
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def guild_create_channels_prep(guild_id: int, channels: list):
|
async def guild_create_channels_prep(guild_id: int, channels: list):
|
||||||
"""Create channels pre-guild create"""
|
"""Create channels pre-guild create"""
|
||||||
for channel_raw in channels:
|
for channel_raw in channels:
|
||||||
channel_id = get_snowflake()
|
channel_id = get_snowflake()
|
||||||
ctype = ChannelType(channel_raw['type'])
|
ctype = ChannelType(channel_raw["type"])
|
||||||
|
|
||||||
await create_guild_channel(guild_id, channel_id, ctype)
|
await create_guild_channel(guild_id, channel_id, ctype)
|
||||||
|
|
||||||
|
|
@ -114,37 +127,29 @@ def sanitize_icon(icon: Optional[str]) -> Optional[str]:
|
||||||
|
|
||||||
Defaults to a jpeg icon when the header isn't given.
|
Defaults to a jpeg icon when the header isn't given.
|
||||||
"""
|
"""
|
||||||
if icon and icon.startswith('data'):
|
if icon and icon.startswith("data"):
|
||||||
return icon
|
return icon
|
||||||
|
|
||||||
return (f'data:image/jpeg;base64,{icon}'
|
return f"data:image/jpeg;base64,{icon}" if icon else None
|
||||||
if icon
|
|
||||||
else None)
|
|
||||||
|
|
||||||
|
|
||||||
async def _general_guild_icon(scope: str, guild_id: int,
|
async def _general_guild_icon(scope: str, guild_id: int, icon: str, **kwargs):
|
||||||
icon: str, **kwargs):
|
|
||||||
encoded = sanitize_icon(icon)
|
encoded = sanitize_icon(icon)
|
||||||
|
|
||||||
icon_kwargs = {
|
icon_kwargs = {"always_icon": True}
|
||||||
'always_icon': True
|
|
||||||
}
|
|
||||||
|
|
||||||
if 'size' in kwargs:
|
if "size" in kwargs:
|
||||||
icon_kwargs['size'] = kwargs['size']
|
icon_kwargs["size"] = kwargs["size"]
|
||||||
|
|
||||||
return await app.icons.put(
|
return await app.icons.put(scope, guild_id, encoded, **icon_kwargs)
|
||||||
scope, guild_id, encoded,
|
|
||||||
**icon_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def put_guild_icon(guild_id: int, icon: Optional[str]):
|
async def put_guild_icon(guild_id: int, icon: Optional[str]):
|
||||||
"""Insert a guild icon on the icon database."""
|
"""Insert a guild icon on the icon database."""
|
||||||
return await _general_guild_icon('guild', guild_id, icon, size=(128, 128))
|
return await _general_guild_icon("guild", guild_id, icon, size=(128, 128))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('', methods=['POST'])
|
@bp.route("", methods=["POST"])
|
||||||
async def create_guild():
|
async def create_guild():
|
||||||
"""Create a new guild, assigning
|
"""Create a new guild, assigning
|
||||||
the user creating it as the owner and
|
the user creating it as the owner and
|
||||||
|
|
@ -154,8 +159,8 @@ async def create_guild():
|
||||||
|
|
||||||
guild_id = get_snowflake()
|
guild_id = get_snowflake()
|
||||||
|
|
||||||
if 'icon' in j:
|
if "icon" in j:
|
||||||
image = await put_guild_icon(guild_id, j['icon'])
|
image = await put_guild_icon(guild_id, j["icon"])
|
||||||
image = image.icon_hash
|
image = image.icon_hash
|
||||||
else:
|
else:
|
||||||
image = None
|
image = None
|
||||||
|
|
@ -166,10 +171,16 @@ async def create_guild():
|
||||||
verification_level, default_message_notifications,
|
verification_level, default_message_notifications,
|
||||||
explicit_content_filter)
|
explicit_content_filter)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
""", guild_id, j['name'], j['region'], image, user_id,
|
""",
|
||||||
j.get('verification_level', 0),
|
guild_id,
|
||||||
j.get('default_message_notifications', 0),
|
j["name"],
|
||||||
j.get('explicit_content_filter', 0))
|
j["region"],
|
||||||
|
image,
|
||||||
|
user_id,
|
||||||
|
j.get("verification_level", 0),
|
||||||
|
j.get("default_message_notifications", 0),
|
||||||
|
j.get("explicit_content_filter", 0),
|
||||||
|
)
|
||||||
|
|
||||||
await add_member(guild_id, user_id)
|
await add_member(guild_id, user_id)
|
||||||
|
|
||||||
|
|
@ -179,107 +190,127 @@ async def create_guild():
|
||||||
# we also don't use create_role because the id of the role
|
# we also don't use create_role because the id of the role
|
||||||
# is the same as the id of the guild, and create_role
|
# is the same as the id of the guild, and create_role
|
||||||
# generates a new snowflake.
|
# generates a new snowflake.
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO roles (id, guild_id, name, position, permissions)
|
INSERT INTO roles (id, guild_id, name, position, permissions)
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
""", guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS)
|
""",
|
||||||
|
guild_id,
|
||||||
|
guild_id,
|
||||||
|
"@everyone",
|
||||||
|
0,
|
||||||
|
DEFAULT_EVERYONE_PERMS,
|
||||||
|
)
|
||||||
|
|
||||||
# add the @everyone role to the guild creator
|
# add the @everyone role to the guild creator
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", user_id, guild_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
# create a single #general channel.
|
# create a single #general channel.
|
||||||
general_id = get_snowflake()
|
general_id = get_snowflake()
|
||||||
|
|
||||||
await create_guild_channel(
|
await create_guild_channel(
|
||||||
guild_id, general_id, ChannelType.GUILD_TEXT,
|
guild_id, general_id, ChannelType.GUILD_TEXT, name="general"
|
||||||
name='general')
|
)
|
||||||
|
|
||||||
if j.get('roles'):
|
if j.get("roles"):
|
||||||
await guild_create_roles_prep(guild_id, j['roles'])
|
await guild_create_roles_prep(guild_id, j["roles"])
|
||||||
|
|
||||||
if j.get('channels'):
|
if j.get("channels"):
|
||||||
await guild_create_channels_prep(guild_id, j['channels'])
|
await guild_create_channels_prep(guild_id, j["channels"])
|
||||||
|
|
||||||
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
guild_total = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||||
|
|
||||||
await app.dispatcher.sub('guild', guild_id, user_id)
|
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_CREATE', guild_total)
|
await app.dispatcher.dispatch_guild(guild_id, "GUILD_CREATE", guild_total)
|
||||||
return jsonify(guild_total)
|
return jsonify(guild_total)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['GET'])
|
@bp.route("/<int:guild_id>", methods=["GET"])
|
||||||
async def get_guild(guild_id):
|
async def get_guild(guild_id):
|
||||||
"""Get a single guilds' information."""
|
"""Get a single guilds' information."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_guild_full(guild_id, user_id, 250))
|
||||||
await app.storage.get_guild_full(guild_id, user_id, 250)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _guild_update_icon(scope: str, guild_id: int,
|
async def _guild_update_icon(scope: str, guild_id: int, icon: Optional[str], **kwargs):
|
||||||
icon: Optional[str], **kwargs):
|
|
||||||
"""Update icon."""
|
"""Update icon."""
|
||||||
new_icon = await app.icons.update(
|
new_icon = await app.icons.update(scope, guild_id, icon, always_icon=True, **kwargs)
|
||||||
scope, guild_id, icon, always_icon=True, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
table = {
|
table = {"guild": "icon"}.get(scope, scope)
|
||||||
'guild': 'icon',
|
|
||||||
}.get(scope, scope)
|
|
||||||
|
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET {table} = $1
|
SET {table} = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_icon.icon_hash, guild_id)
|
""",
|
||||||
|
new_icon.icon_hash,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _guild_update_region(guild_id, region):
|
async def _guild_update_region(guild_id, region):
|
||||||
is_vip = region.vip
|
is_vip = region.vip
|
||||||
can_vip = await app.storage.has_feature(guild_id, 'VIP_REGIONS')
|
can_vip = await app.storage.has_feature(guild_id, "VIP_REGIONS")
|
||||||
|
|
||||||
if is_vip and not can_vip:
|
if is_vip and not can_vip:
|
||||||
raise BadRequest('can not assign guild to vip-only region')
|
raise BadRequest("can not assign guild to vip-only region")
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET region = $1
|
SET region = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", region.id, guild_id)
|
""",
|
||||||
|
region.id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/<int:guild_id>", methods=["PATCH"])
|
||||||
@bp.route('/<int:guild_id>', methods=['PATCH'])
|
|
||||||
async def _update_guild(guild_id):
|
async def _update_guild(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||||
j = validate(await request.get_json(), GUILD_UPDATE)
|
j = validate(await request.get_json(), GUILD_UPDATE)
|
||||||
|
|
||||||
if 'owner_id' in j:
|
if "owner_id" in j:
|
||||||
await guild_owner_check(user_id, guild_id)
|
await guild_owner_check(user_id, guild_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET owner_id = $1
|
SET owner_id = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", int(j['owner_id']), guild_id)
|
""",
|
||||||
|
int(j["owner_id"]),
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'name' in j:
|
if "name" in j:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET name = $1
|
SET name = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['name'], guild_id)
|
""",
|
||||||
|
j["name"],
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'region' in j:
|
if "region" in j:
|
||||||
region = app.voice.lvsp.region(j['region'])
|
region = app.voice.lvsp.region(j["region"])
|
||||||
|
|
||||||
if region is not None:
|
if region is not None:
|
||||||
await _guild_update_region(guild_id, region)
|
await _guild_update_region(guild_id, region)
|
||||||
|
|
@ -287,65 +318,77 @@ async def _update_guild(guild_id):
|
||||||
# small guild to work with to_update()
|
# small guild to work with to_update()
|
||||||
guild = await app.storage.get_guild(guild_id)
|
guild = await app.storage.get_guild(guild_id)
|
||||||
|
|
||||||
if to_update(j, guild, 'icon'):
|
if to_update(j, guild, "icon"):
|
||||||
await _guild_update_icon(
|
await _guild_update_icon("guild", guild_id, j["icon"], size=(128, 128))
|
||||||
'guild', guild_id, j['icon'], size=(128, 128))
|
|
||||||
|
|
||||||
if to_update(j, guild, 'splash'):
|
if to_update(j, guild, "splash"):
|
||||||
if not await app.storage.has_feature(guild_id, 'INVITE_SPLASH'):
|
if not await app.storage.has_feature(guild_id, "INVITE_SPLASH"):
|
||||||
raise BadRequest('guild does not have INVITE_SPLASH feature')
|
raise BadRequest("guild does not have INVITE_SPLASH feature")
|
||||||
|
|
||||||
await _guild_update_icon('splash', guild_id, j['splash'])
|
await _guild_update_icon("splash", guild_id, j["splash"])
|
||||||
|
|
||||||
if to_update(j, guild, 'banner'):
|
if to_update(j, guild, "banner"):
|
||||||
if not await app.storage.has_feature(guild_id, 'VERIFIED'):
|
if not await app.storage.has_feature(guild_id, "VERIFIED"):
|
||||||
raise BadRequest('guild is not verified')
|
raise BadRequest("guild is not verified")
|
||||||
|
|
||||||
await _guild_update_icon('banner', guild_id, j['banner'])
|
await _guild_update_icon("banner", guild_id, j["banner"])
|
||||||
|
|
||||||
fields = ['verification_level', 'default_message_notifications',
|
fields = [
|
||||||
'explicit_content_filter', 'afk_timeout', 'description']
|
"verification_level",
|
||||||
|
"default_message_notifications",
|
||||||
|
"explicit_content_filter",
|
||||||
|
"afk_timeout",
|
||||||
|
"description",
|
||||||
|
]
|
||||||
|
|
||||||
for field in [f for f in fields if f in j]:
|
for field in [f for f in fields if f in j]:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j[field], guild_id)
|
""",
|
||||||
|
j[field],
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
channel_fields = ['afk_channel_id', 'system_channel_id']
|
channel_fields = ["afk_channel_id", "system_channel_id"]
|
||||||
for field in [f for f in channel_fields if f in j]:
|
for field in [f for f in channel_fields if f in j]:
|
||||||
# setting to null should remove the link between the afk/sys channel
|
# setting to null should remove the link between the afk/sys channel
|
||||||
# to the guild.
|
# to the guild.
|
||||||
if j[field] is None:
|
if j[field] is None:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET {field} = NULL
|
SET {field} = NULL
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chan = await app.storage.get_channel(int(j[field]))
|
chan = await app.storage.get_channel(int(j[field]))
|
||||||
|
|
||||||
if chan is None:
|
if chan is None:
|
||||||
raise BadRequest('invalid channel id')
|
raise BadRequest("invalid channel id")
|
||||||
|
|
||||||
if chan['guild_id'] != str(guild_id):
|
if chan["guild_id"] != str(guild_id):
|
||||||
raise BadRequest('channel id not linked to guild')
|
raise BadRequest("channel id not linked to guild")
|
||||||
|
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j[field], guild_id)
|
""",
|
||||||
|
j[field],
|
||||||
guild = await app.storage.get_guild_full(
|
guild_id,
|
||||||
guild_id, user_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_guild(
|
guild = await app.storage.get_guild_full(guild_id, user_id)
|
||||||
guild_id, 'GUILD_UPDATE', guild)
|
|
||||||
|
await app.dispatcher.dispatch_guild(guild_id, "GUILD_UPDATE", guild)
|
||||||
|
|
||||||
return jsonify(guild)
|
return jsonify(guild)
|
||||||
|
|
||||||
|
|
@ -354,33 +397,41 @@ async def delete_guild(guild_id: int, *, app_=None):
|
||||||
"""Delete a single guild."""
|
"""Delete a single guild."""
|
||||||
app_ = app_ or app
|
app_ = app_ or app
|
||||||
|
|
||||||
await app_.db.execute("""
|
await app_.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM guilds
|
DELETE FROM guilds
|
||||||
WHERE guilds.id = $1
|
WHERE guilds.id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
# Discord's client expects IDs being string
|
# Discord's client expects IDs being string
|
||||||
await app_.dispatcher.dispatch('guild', guild_id, 'GUILD_DELETE', {
|
await app_.dispatcher.dispatch(
|
||||||
'guild_id': str(guild_id),
|
"guild",
|
||||||
'id': str(guild_id),
|
guild_id,
|
||||||
|
"GUILD_DELETE",
|
||||||
|
{
|
||||||
|
"guild_id": str(guild_id),
|
||||||
|
"id": str(guild_id),
|
||||||
# 'unavailable': False,
|
# 'unavailable': False,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# remove from the dispatcher so nobody
|
# remove from the dispatcher so nobody
|
||||||
# becomes the little memer that tries to fuck up with
|
# becomes the little memer that tries to fuck up with
|
||||||
# everybody's gateway
|
# everybody's gateway
|
||||||
await app_.dispatcher.remove('guild', guild_id)
|
await app_.dispatcher.remove("guild", guild_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>', methods=['DELETE'])
|
@bp.route("/<int:guild_id>", methods=["DELETE"])
|
||||||
# this endpoint is not documented, but used by the official client.
|
# this endpoint is not documented, but used by the official client.
|
||||||
@bp.route('/<int:guild_id>/delete', methods=['POST'])
|
@bp.route("/<int:guild_id>/delete", methods=["POST"])
|
||||||
async def delete_guild_handler(guild_id):
|
async def delete_guild_handler(guild_id):
|
||||||
"""Delete a guild."""
|
"""Delete a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_owner_check(user_id, guild_id)
|
await guild_owner_check(user_id, guild_id)
|
||||||
await delete_guild(guild_id)
|
await delete_guild(guild_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
|
async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
|
||||||
|
|
@ -397,7 +448,7 @@ async def fetch_readable_channels(guild_id: int, user_id: int) -> List[int]:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/messages/search', methods=['GET'])
|
@bp.route("/<int:guild_id>/messages/search", methods=["GET"])
|
||||||
async def search_messages(guild_id):
|
async def search_messages(guild_id):
|
||||||
"""Search messages in a guild.
|
"""Search messages in a guild.
|
||||||
|
|
||||||
|
|
@ -415,7 +466,8 @@ async def search_messages(guild_id):
|
||||||
# use that list on the main search query.
|
# use that list on the main search query.
|
||||||
can_read = await fetch_readable_channels(guild_id, user_id)
|
can_read = await fetch_readable_channels(guild_id, user_id)
|
||||||
|
|
||||||
rows = await app.db.fetch(f"""
|
rows = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT orig.id AS current_id,
|
SELECT orig.id AS current_id,
|
||||||
COUNT(*) OVER() as total_results,
|
COUNT(*) OVER() as total_results,
|
||||||
array((SELECT messages.id AS before_id
|
array((SELECT messages.id AS before_id
|
||||||
|
|
@ -432,12 +484,17 @@ async def search_messages(guild_id):
|
||||||
ORDER BY orig.id DESC
|
ORDER BY orig.id DESC
|
||||||
LIMIT 50
|
LIMIT 50
|
||||||
OFFSET $3
|
OFFSET $3
|
||||||
""", guild_id, j['content'], j['offset'], can_read)
|
""",
|
||||||
|
guild_id,
|
||||||
|
j["content"],
|
||||||
|
j["offset"],
|
||||||
|
can_read,
|
||||||
|
)
|
||||||
|
|
||||||
return jsonify(await search_result_from_list(rows))
|
return jsonify(await search_result_from_list(rows))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/ack', methods=['POST'])
|
@bp.route("/<int:guild_id>/ack", methods=["POST"])
|
||||||
async def ack_guild(guild_id):
|
async def ack_guild(guild_id):
|
||||||
"""ACKnowledge all messages in the guild."""
|
"""ACKnowledge all messages in the guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -448,45 +505,43 @@ async def ack_guild(guild_id):
|
||||||
for chan_id in chan_ids:
|
for chan_id in chan_ids:
|
||||||
await channel_ack(user_id, guild_id, chan_id)
|
await channel_ack(user_id, guild_id, chan_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/vanity-url', methods=['GET'])
|
@bp.route("/<int:guild_id>/vanity-url", methods=["GET"])
|
||||||
async def get_vanity_url(guild_id: int):
|
async def get_vanity_url(guild_id: int):
|
||||||
"""Get the vanity url of a guild."""
|
"""Get the vanity url of a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||||
|
|
||||||
inv_code = await app.storage.vanity_invite(guild_id)
|
inv_code = await app.storage.vanity_invite(guild_id)
|
||||||
|
|
||||||
if inv_code is None:
|
if inv_code is None:
|
||||||
return jsonify({'code': None})
|
return jsonify({"code": None})
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await app.storage.get_invite(inv_code))
|
||||||
await app.storage.get_invite(inv_code)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:guild_id>/vanity-url', methods=['PATCH'])
|
@bp.route("/<int:guild_id>/vanity-url", methods=["PATCH"])
|
||||||
async def change_vanity_url(guild_id: int):
|
async def change_vanity_url(guild_id: int):
|
||||||
"""Get the vanity url of a guild."""
|
"""Get the vanity url of a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
if not await app.storage.has_feature(guild_id, 'VANITY_URL'):
|
if not await app.storage.has_feature(guild_id, "VANITY_URL"):
|
||||||
# TODO: is this the right error
|
# TODO: is this the right error
|
||||||
raise BadRequest('guild has no vanity url support')
|
raise BadRequest("guild has no vanity url support")
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||||
|
|
||||||
j = validate(await request.get_json(), VANITY_URL_PATCH)
|
j = validate(await request.get_json(), VANITY_URL_PATCH)
|
||||||
inv_code = j['code']
|
inv_code = j["code"]
|
||||||
|
|
||||||
# store old vanity in a variable to delete it from
|
# store old vanity in a variable to delete it from
|
||||||
# invites table
|
# invites table
|
||||||
old_vanity = await app.storage.vanity_invite(guild_id)
|
old_vanity = await app.storage.vanity_invite(guild_id)
|
||||||
|
|
||||||
if old_vanity == inv_code:
|
if old_vanity == inv_code:
|
||||||
raise BadRequest('can not change to same invite')
|
raise BadRequest("can not change to same invite")
|
||||||
|
|
||||||
# this is sad because we don't really use the things
|
# this is sad because we don't really use the things
|
||||||
# sql gives us, but i havent really found a way to put
|
# sql gives us, but i havent really found a way to put
|
||||||
|
|
@ -494,19 +549,22 @@ async def change_vanity_url(guild_id: int):
|
||||||
# guild_id_fkey fails but INSERT when code_fkey fails..
|
# guild_id_fkey fails but INSERT when code_fkey fails..
|
||||||
inv = await app.storage.get_invite(inv_code)
|
inv = await app.storage.get_invite(inv_code)
|
||||||
if inv:
|
if inv:
|
||||||
raise BadRequest('invite already exists')
|
raise BadRequest("invite already exists")
|
||||||
|
|
||||||
# TODO: this is bad, what if a guild has no channels?
|
# TODO: this is bad, what if a guild has no channels?
|
||||||
# we should probably choose the first channel that has
|
# we should probably choose the first channel that has
|
||||||
# @everyone read messages
|
# @everyone read messages
|
||||||
channels = await app.storage.get_channel_data(guild_id)
|
channels = await app.storage.get_channel_data(guild_id)
|
||||||
channel_id = int(channels[0]['id'])
|
channel_id = int(channels[0]["id"])
|
||||||
|
|
||||||
# delete the old invite, insert new one
|
# delete the old invite, insert new one
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM invites
|
DELETE FROM invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", old_vanity)
|
""",
|
||||||
|
old_vanity,
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
|
|
@ -515,21 +573,27 @@ async def change_vanity_url(guild_id: int):
|
||||||
max_age, temporary)
|
max_age, temporary)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
""",
|
""",
|
||||||
inv_code, guild_id, channel_id, user_id,
|
inv_code,
|
||||||
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
# sane defaults for vanity urls.
|
# sane defaults for vanity urls.
|
||||||
0, 0, False,
|
0,
|
||||||
|
0,
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO vanity_invites (guild_id, code)
|
INSERT INTO vanity_invites (guild_id, code)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO
|
ON CONFLICT ON CONSTRAINT vanity_invites_pkey DO
|
||||||
UPDATE
|
UPDATE
|
||||||
SET code = $2
|
SET code = $2
|
||||||
WHERE vanity_invites.guild_id = $1
|
WHERE vanity_invites.guild_id = $1
|
||||||
""", guild_id, inv_code)
|
""",
|
||||||
|
guild_id,
|
||||||
return jsonify(
|
inv_code,
|
||||||
await app.storage.get_invite(inv_code)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return jsonify(await app.storage.get_invite(inv_code))
|
||||||
|
|
|
||||||
|
|
@ -24,41 +24,39 @@ from quart import Blueprint, current_app as app, send_file, redirect
|
||||||
from litecord.embed.sanitizer import make_md_req_url
|
from litecord.embed.sanitizer import make_md_req_url
|
||||||
from litecord.embed.schemas import EmbedURL
|
from litecord.embed.schemas import EmbedURL
|
||||||
|
|
||||||
bp = Blueprint('images', __name__)
|
bp = Blueprint("images", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def send_icon(scope, key, icon_hash, **kwargs):
|
async def send_icon(scope, key, icon_hash, **kwargs):
|
||||||
"""Send an icon."""
|
"""Send an icon."""
|
||||||
icon = await app.icons.generic_get(
|
icon = await app.icons.generic_get(scope, key, icon_hash, **kwargs)
|
||||||
scope, key, icon_hash, **kwargs)
|
|
||||||
|
|
||||||
if not icon:
|
if not icon:
|
||||||
return '', 404
|
return "", 404
|
||||||
|
|
||||||
return await send_file(icon.as_path)
|
return await send_file(icon.as_path)
|
||||||
|
|
||||||
|
|
||||||
def splitext_(filepath):
|
def splitext_(filepath):
|
||||||
name, ext = splitext(filepath)
|
name, ext = splitext(filepath)
|
||||||
return name, ext.strip('.')
|
return name, ext.strip(".")
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/emojis/<emoji_file>', methods=['GET'])
|
@bp.route("/emojis/<emoji_file>", methods=["GET"])
|
||||||
async def _get_raw_emoji(emoji_file):
|
async def _get_raw_emoji(emoji_file):
|
||||||
# emoji = app.icons.get_emoji(emoji_id, ext=ext)
|
# emoji = app.icons.get_emoji(emoji_id, ext=ext)
|
||||||
# just a test file for now
|
# just a test file for now
|
||||||
emoji_id, ext = splitext_(emoji_file)
|
emoji_id, ext = splitext_(emoji_file)
|
||||||
return await send_icon(
|
return await send_icon("emoji", emoji_id, None, ext=ext)
|
||||||
'emoji', emoji_id, None, ext=ext)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/icons/<int:guild_id>/<icon_file>', methods=['GET'])
|
@bp.route("/icons/<int:guild_id>/<icon_file>", methods=["GET"])
|
||||||
async def _get_guild_icon(guild_id: int, icon_file: str):
|
async def _get_guild_icon(guild_id: int, icon_file: str):
|
||||||
icon_hash, ext = splitext_(icon_file)
|
icon_hash, ext = splitext_(icon_file)
|
||||||
return await send_icon('guild', guild_id, icon_hash, ext=ext)
|
return await send_icon("guild", guild_id, icon_hash, ext=ext)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/embed/avatars/<int:default_id>.png')
|
@bp.route("/embed/avatars/<int:default_id>.png")
|
||||||
async def _get_default_user_avatar(default_id: int):
|
async def _get_default_user_avatar(default_id: int):
|
||||||
# TODO: how do we determine which assets to use for this?
|
# TODO: how do we determine which assets to use for this?
|
||||||
# I don't think we can use discord assets.
|
# I don't think we can use discord assets.
|
||||||
|
|
@ -66,25 +64,29 @@ async def _get_default_user_avatar(default_id: int):
|
||||||
|
|
||||||
|
|
||||||
async def _handle_webhook_avatar(md_url_redir: str):
|
async def _handle_webhook_avatar(md_url_redir: str):
|
||||||
md_url = make_md_req_url(app.config, 'img', EmbedURL(md_url_redir))
|
md_url = make_md_req_url(app.config, "img", EmbedURL(md_url_redir))
|
||||||
return redirect(md_url)
|
return redirect(md_url)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/avatars/<int:user_id>/<avatar_file>')
|
@bp.route("/avatars/<int:user_id>/<avatar_file>")
|
||||||
async def _get_user_avatar(user_id, avatar_file):
|
async def _get_user_avatar(user_id, avatar_file):
|
||||||
avatar_hash, ext = splitext_(avatar_file)
|
avatar_hash, ext = splitext_(avatar_file)
|
||||||
|
|
||||||
# first, check if this is a webhook avatar to redir to
|
# first, check if this is a webhook avatar to redir to
|
||||||
md_url_redir = await app.db.fetchval("""
|
md_url_redir = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT md_url_redir
|
SELECT md_url_redir
|
||||||
FROM webhook_avatars
|
FROM webhook_avatars
|
||||||
WHERE webhook_id = $1 AND hash = $2
|
WHERE webhook_id = $1 AND hash = $2
|
||||||
""", user_id, avatar_hash)
|
""",
|
||||||
|
user_id,
|
||||||
|
avatar_hash,
|
||||||
|
)
|
||||||
|
|
||||||
if md_url_redir:
|
if md_url_redir:
|
||||||
return await _handle_webhook_avatar(md_url_redir)
|
return await _handle_webhook_avatar(md_url_redir)
|
||||||
|
|
||||||
return await send_icon('user', user_id, avatar_hash, ext=ext)
|
return await send_icon("user", user_id, avatar_hash, ext=ext)
|
||||||
|
|
||||||
|
|
||||||
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
|
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
|
||||||
|
|
@ -92,19 +94,19 @@ async def get_app_icon(application_id, icon_hash, ext):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET'])
|
@bp.route("/channel-icons/<int:channel_id>/<icon_file>", methods=["GET"])
|
||||||
async def _get_gdm_icon(channel_id: int, icon_file: str):
|
async def _get_gdm_icon(channel_id: int, icon_file: str):
|
||||||
icon_hash, ext = splitext_(icon_file)
|
icon_hash, ext = splitext_(icon_file)
|
||||||
return await send_icon('channel-icons', channel_id, icon_hash, ext=ext)
|
return await send_icon("channel-icons", channel_id, icon_hash, ext=ext)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/splashes/<int:guild_id>/<icon_file>', methods=['GET'])
|
@bp.route("/splashes/<int:guild_id>/<icon_file>", methods=["GET"])
|
||||||
async def _get_guild_splash(guild_id: int, icon_file: str):
|
async def _get_guild_splash(guild_id: int, icon_file: str):
|
||||||
icon_hash, ext = splitext_(icon_file)
|
icon_hash, ext = splitext_(icon_file)
|
||||||
return await send_icon('splash', guild_id, icon_hash, ext=ext)
|
return await send_icon("splash", guild_id, icon_hash, ext=ext)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/banners/<int:guild_id>/<icon_file>', methods=['GET'])
|
@bp.route("/banners/<int:guild_id>/<icon_file>", methods=["GET"])
|
||||||
async def _get_guild_banner(guild_id: int, icon_file: str):
|
async def _get_guild_banner(guild_id: int, icon_file: str):
|
||||||
icon_hash, ext = splitext_(icon_file)
|
icon_hash, ext = splitext_(icon_file)
|
||||||
return await send_icon('banner', guild_id, icon_hash, ext=ext)
|
return await send_icon("banner", guild_id, icon_hash, ext=ext)
|
||||||
|
|
|
||||||
|
|
@ -32,13 +32,16 @@ from .guilds import create_guild_settings
|
||||||
from ..utils import async_map
|
from ..utils import async_map
|
||||||
|
|
||||||
from litecord.blueprints.checks import (
|
from litecord.blueprints.checks import (
|
||||||
channel_check, channel_perm_check, guild_check, guild_perm_check
|
channel_check,
|
||||||
|
channel_perm_check,
|
||||||
|
guild_check,
|
||||||
|
guild_perm_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
|
from litecord.blueprints.dm_channels import gdm_is_member, gdm_add_recipient
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('invites', __name__)
|
bp = Blueprint("invites", __name__)
|
||||||
|
|
||||||
|
|
||||||
class UnknownInvite(BadRequest):
|
class UnknownInvite(BadRequest):
|
||||||
|
|
@ -48,16 +51,18 @@ class UnknownInvite(BadRequest):
|
||||||
class InvalidInvite(Forbidden):
|
class InvalidInvite(Forbidden):
|
||||||
error_code = 50020
|
error_code = 50020
|
||||||
|
|
||||||
|
|
||||||
class AlreadyInvited(BaseException):
|
class AlreadyInvited(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def gen_inv_code() -> str:
|
def gen_inv_code() -> str:
|
||||||
"""Generate an invite code.
|
"""Generate an invite code.
|
||||||
|
|
||||||
This is a primitive and does not guarantee uniqueness.
|
This is a primitive and does not guarantee uniqueness.
|
||||||
"""
|
"""
|
||||||
raw = secrets.token_urlsafe(10)
|
raw = secrets.token_urlsafe(10)
|
||||||
raw = re.sub(r'\/|\+|\-|\_', '', raw)
|
raw = re.sub(r"\/|\+|\-|\_", "", raw)
|
||||||
|
|
||||||
return raw[:7]
|
return raw[:7]
|
||||||
|
|
||||||
|
|
@ -65,23 +70,31 @@ def gen_inv_code() -> str:
|
||||||
async def invite_precheck(user_id: int, guild_id: int):
|
async def invite_precheck(user_id: int, guild_id: int):
|
||||||
"""pre-check invite use in the context of a guild."""
|
"""pre-check invite use in the context of a guild."""
|
||||||
|
|
||||||
joined = await app.db.fetchval("""
|
joined = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT joined_at
|
SELECT joined_at
|
||||||
FROM members
|
FROM members
|
||||||
WHERE user_id = $1 AND guild_id = $2
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if joined is not None:
|
if joined is not None:
|
||||||
raise AlreadyInvited('You are already in the guild')
|
raise AlreadyInvited("You are already in the guild")
|
||||||
|
|
||||||
banned = await app.db.fetchval("""
|
banned = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT reason
|
SELECT reason
|
||||||
FROM bans
|
FROM bans
|
||||||
WHERE user_id = $1 AND guild_id = $2
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if banned is not None:
|
if banned is not None:
|
||||||
raise InvalidInvite('You are banned.')
|
raise InvalidInvite("You are banned.")
|
||||||
|
|
||||||
|
|
||||||
async def invite_precheck_gdm(user_id: int, channel_id: int):
|
async def invite_precheck_gdm(user_id: int, channel_id: int):
|
||||||
|
|
@ -89,23 +102,23 @@ async def invite_precheck_gdm(user_id: int, channel_id: int):
|
||||||
is_member = await gdm_is_member(channel_id, user_id)
|
is_member = await gdm_is_member(channel_id, user_id)
|
||||||
|
|
||||||
if is_member:
|
if is_member:
|
||||||
raise AlreadyInvited('You are already in the Group DM')
|
raise AlreadyInvited("You are already in the Group DM")
|
||||||
|
|
||||||
|
|
||||||
async def _inv_check_age(inv: dict):
|
async def _inv_check_age(inv: dict):
|
||||||
if inv['max_age'] == 0:
|
if inv["max_age"] == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.datetime.utcnow()
|
||||||
delta_sec = (now - inv['created_at']).total_seconds()
|
delta_sec = (now - inv["created_at"]).total_seconds()
|
||||||
|
|
||||||
if delta_sec > inv['max_age']:
|
if delta_sec > inv["max_age"]:
|
||||||
await delete_invite(inv['code'])
|
await delete_invite(inv["code"])
|
||||||
raise InvalidInvite('Invite is expired')
|
raise InvalidInvite("Invite is expired")
|
||||||
|
|
||||||
if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']:
|
if inv["max_uses"] is not -1 and inv["uses"] > inv["max_uses"]:
|
||||||
await delete_invite(inv['code'])
|
await delete_invite(inv["code"])
|
||||||
raise InvalidInvite('Too many uses')
|
raise InvalidInvite("Too many uses")
|
||||||
|
|
||||||
|
|
||||||
async def _guild_add_member(guild_id: int, user_id: int):
|
async def _guild_add_member(guild_id: int, user_id: int):
|
||||||
|
|
@ -119,78 +132,89 @@ async def _guild_add_member(guild_id: int, user_id: int):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: system message for member join
|
# TODO: system message for member join
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO members (user_id, guild_id)
|
INSERT INTO members (user_id, guild_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
await create_guild_settings(guild_id, user_id)
|
await create_guild_settings(guild_id, user_id)
|
||||||
|
|
||||||
# add the @everyone role to the invited member
|
# add the @everyone role to the invited member
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO member_roles (user_id, guild_id, role_id)
|
INSERT INTO member_roles (user_id, guild_id, role_id)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", user_id, guild_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
# tell current members a new member came up
|
# tell current members a new member came up
|
||||||
member = await app.storage.get_member_data_one(guild_id, user_id)
|
member = await app.storage.get_member_data_one(guild_id, user_id)
|
||||||
await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_ADD', {
|
await app.dispatcher.dispatch_guild(
|
||||||
**member,
|
guild_id, "GUILD_MEMBER_ADD", {**member, **{"guild_id": str(guild_id)}}
|
||||||
**{
|
)
|
||||||
'guild_id': str(guild_id),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
# update member lists for the new member
|
# update member lists for the new member
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("lazy_guild", guild_id, "new_member", user_id)
|
||||||
'lazy_guild', guild_id, 'new_member', user_id)
|
|
||||||
|
|
||||||
# subscribe new member to guild, so they get events n stuff
|
# subscribe new member to guild, so they get events n stuff
|
||||||
await app.dispatcher.sub('guild', guild_id, user_id)
|
await app.dispatcher.sub("guild", guild_id, user_id)
|
||||||
|
|
||||||
# tell the new member that theres the guild it just joined.
|
# tell the new member that theres the guild it just joined.
|
||||||
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
# we use dispatch_user_guild so that we send the GUILD_CREATE
|
||||||
# just to the shards that are actually tied to it.
|
# just to the shards that are actually tied to it.
|
||||||
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||||
await app.dispatcher.dispatch_user_guild(
|
await app.dispatcher.dispatch_user_guild(user_id, guild_id, "GUILD_CREATE", guild)
|
||||||
user_id, guild_id, 'GUILD_CREATE', guild)
|
|
||||||
|
|
||||||
|
|
||||||
async def use_invite(user_id, invite_code):
|
async def use_invite(user_id, invite_code):
|
||||||
"""Try using an invite"""
|
"""Try using an invite"""
|
||||||
inv = await app.db.fetchrow("""
|
inv = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT code, channel_id, guild_id, created_at,
|
SELECT code, channel_id, guild_id, created_at,
|
||||||
max_age, uses, max_uses
|
max_age, uses, max_uses
|
||||||
FROM invites
|
FROM invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invite_code)
|
""",
|
||||||
|
invite_code,
|
||||||
|
)
|
||||||
|
|
||||||
if inv is None:
|
if inv is None:
|
||||||
raise UnknownInvite('Unknown invite')
|
raise UnknownInvite("Unknown invite")
|
||||||
|
|
||||||
await _inv_check_age(inv)
|
await _inv_check_age(inv)
|
||||||
|
|
||||||
# NOTE: if group dm invite, guild_id is null.
|
# NOTE: if group dm invite, guild_id is null.
|
||||||
guild_id = inv['guild_id']
|
guild_id = inv["guild_id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if guild_id is None:
|
if guild_id is None:
|
||||||
channel_id = inv['channel_id']
|
channel_id = inv["channel_id"]
|
||||||
await invite_precheck_gdm(user_id, inv['channel_id'])
|
await invite_precheck_gdm(user_id, inv["channel_id"])
|
||||||
await gdm_add_recipient(channel_id, user_id)
|
await gdm_add_recipient(channel_id, user_id)
|
||||||
else:
|
else:
|
||||||
await invite_precheck(user_id, guild_id)
|
await invite_precheck(user_id, guild_id)
|
||||||
await _guild_add_member(guild_id, user_id)
|
await _guild_add_member(guild_id, user_id)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE invites
|
UPDATE invites
|
||||||
SET uses = uses + 1
|
SET uses = uses + 1
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invite_code)
|
""",
|
||||||
|
invite_code,
|
||||||
|
)
|
||||||
except AlreadyInvited:
|
except AlreadyInvited:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@bp.route('/channels/<int:channel_id>/invites', methods=['POST'])
|
|
||||||
|
@bp.route("/channels/<int:channel_id>/invites", methods=["POST"])
|
||||||
async def create_invite(channel_id):
|
async def create_invite(channel_id):
|
||||||
"""Create an invite to a channel."""
|
"""Create an invite to a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -201,12 +225,14 @@ async def create_invite(channel_id):
|
||||||
|
|
||||||
# NOTE: this works on group dms, since it returns ALL_PERMISSIONS on
|
# NOTE: this works on group dms, since it returns ALL_PERMISSIONS on
|
||||||
# non-guild channels.
|
# non-guild channels.
|
||||||
await channel_perm_check(user_id, channel_id, 'create_invites')
|
await channel_perm_check(user_id, channel_id, "create_invites")
|
||||||
|
|
||||||
if chantype not in (ChannelType.GUILD_TEXT,
|
if chantype not in (
|
||||||
|
ChannelType.GUILD_TEXT,
|
||||||
ChannelType.GUILD_VOICE,
|
ChannelType.GUILD_VOICE,
|
||||||
ChannelType.GROUP_DM):
|
ChannelType.GROUP_DM,
|
||||||
raise BadRequest('Invalid channel type')
|
):
|
||||||
|
raise BadRequest("Invalid channel type")
|
||||||
|
|
||||||
invite_code = gen_inv_code()
|
invite_code = gen_inv_code()
|
||||||
|
|
||||||
|
|
@ -222,101 +248,122 @@ async def create_invite(channel_id):
|
||||||
max_age, temporary)
|
max_age, temporary)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
""",
|
""",
|
||||||
invite_code, guild_id, channel_id, user_id,
|
invite_code,
|
||||||
j['max_uses'], j['max_age'], j['temporary']
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
j["max_uses"],
|
||||||
|
j["max_age"],
|
||||||
|
j["temporary"],
|
||||||
)
|
)
|
||||||
|
|
||||||
invite = await app.storage.get_invite(invite_code)
|
invite = await app.storage.get_invite(invite_code)
|
||||||
return jsonify(invite)
|
return jsonify(invite)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/invite/<invite_code>', methods=['GET'])
|
@bp.route("/invite/<invite_code>", methods=["GET"])
|
||||||
@bp.route('/invites/<invite_code>', methods=['GET'])
|
@bp.route("/invites/<invite_code>", methods=["GET"])
|
||||||
async def get_invite(invite_code: str):
|
async def get_invite(invite_code: str):
|
||||||
inv = await app.storage.get_invite(invite_code)
|
inv = await app.storage.get_invite(invite_code)
|
||||||
|
|
||||||
if not inv:
|
if not inv:
|
||||||
return '', 404
|
return "", 404
|
||||||
|
|
||||||
if request.args.get('with_counts'):
|
if request.args.get("with_counts"):
|
||||||
extra = await app.storage.get_invite_extra(invite_code)
|
extra = await app.storage.get_invite_extra(invite_code)
|
||||||
inv.update(extra)
|
inv.update(extra)
|
||||||
|
|
||||||
return jsonify(inv)
|
return jsonify(inv)
|
||||||
|
|
||||||
|
|
||||||
async def delete_invite(invite_code: str):
|
async def delete_invite(invite_code: str):
|
||||||
"""Delete an invite."""
|
"""Delete an invite."""
|
||||||
await app.db.fetchval("""
|
await app.db.fetchval(
|
||||||
|
"""
|
||||||
DELETE FROM invites
|
DELETE FROM invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invite_code)
|
""",
|
||||||
|
invite_code,
|
||||||
|
)
|
||||||
|
|
||||||
@bp.route('/invite/<invite_code>', methods=['DELETE'])
|
|
||||||
@bp.route('/invites/<invite_code>', methods=['DELETE'])
|
@bp.route("/invite/<invite_code>", methods=["DELETE"])
|
||||||
|
@bp.route("/invites/<invite_code>", methods=["DELETE"])
|
||||||
async def _delete_invite(invite_code: str):
|
async def _delete_invite(invite_code: str):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
guild_id = await app.db.fetchval("""
|
guild_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
FROM invites
|
FROM invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", invite_code)
|
""",
|
||||||
|
invite_code,
|
||||||
|
)
|
||||||
|
|
||||||
if guild_id is None:
|
if guild_id is None:
|
||||||
raise BadRequest('Unknown invite')
|
raise BadRequest("Unknown invite")
|
||||||
|
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||||
|
|
||||||
inv = await app.storage.get_invite(invite_code)
|
inv = await app.storage.get_invite(invite_code)
|
||||||
await delete_invite(invite_code)
|
await delete_invite(invite_code)
|
||||||
return jsonify(inv)
|
return jsonify(inv)
|
||||||
|
|
||||||
|
|
||||||
async def _get_inv(code):
|
async def _get_inv(code):
|
||||||
inv = await app.storage.get_invite(code)
|
inv = await app.storage.get_invite(code)
|
||||||
meta = await app.storage.get_invite_metadata(code)
|
meta = await app.storage.get_invite_metadata(code)
|
||||||
return {**inv, **meta}
|
return {**inv, **meta}
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/guilds/<int:guild_id>/invites', methods=['GET'])
|
@bp.route("/guilds/<int:guild_id>/invites", methods=["GET"])
|
||||||
async def get_guild_invites(guild_id: int):
|
async def get_guild_invites(guild_id: int):
|
||||||
"""Get all invites for a guild."""
|
"""Get all invites for a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_guild')
|
await guild_perm_check(user_id, guild_id, "manage_guild")
|
||||||
|
|
||||||
inv_codes = await app.db.fetch("""
|
inv_codes = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT code
|
SELECT code
|
||||||
FROM invites
|
FROM invites
|
||||||
WHERE guild_id = $1
|
WHERE guild_id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
inv_codes = [r['code'] for r in inv_codes]
|
inv_codes = [r["code"] for r in inv_codes]
|
||||||
invs = await async_map(_get_inv, inv_codes)
|
invs = await async_map(_get_inv, inv_codes)
|
||||||
return jsonify(invs)
|
return jsonify(invs)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/channels/<int:channel_id>/invites', methods=['GET'])
|
@bp.route("/channels/<int:channel_id>/invites", methods=["GET"])
|
||||||
async def get_channel_invites(channel_id: int):
|
async def get_channel_invites(channel_id: int):
|
||||||
"""Get all invites for a channel."""
|
"""Get all invites for a channel."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
_ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_channels')
|
await guild_perm_check(user_id, guild_id, "manage_channels")
|
||||||
|
|
||||||
inv_codes = await app.db.fetch("""
|
inv_codes = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT code
|
SELECT code
|
||||||
FROM invites
|
FROM invites
|
||||||
WHERE guild_id = $1 AND channel_id = $2
|
WHERE guild_id = $1 AND channel_id = $2
|
||||||
""", guild_id, channel_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
inv_codes = [r['code'] for r in inv_codes]
|
inv_codes = [r["code"] for r in inv_codes]
|
||||||
invs = await async_map(_get_inv, inv_codes)
|
invs = await async_map(_get_inv, inv_codes)
|
||||||
return jsonify(invs)
|
return jsonify(invs)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/invite/<invite_code>', methods=['POST'])
|
@bp.route("/invite/<invite_code>", methods=["POST"])
|
||||||
@bp.route('/invites/<invite_code>', methods=['POST'])
|
@bp.route("/invites/<invite_code>", methods=["POST"])
|
||||||
async def _use_invite(invite_code):
|
async def _use_invite(invite_code):
|
||||||
"""Use an invite."""
|
"""Use an invite."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -327,9 +374,4 @@ async def _use_invite(invite_code):
|
||||||
inv = await app.storage.get_invite(invite_code)
|
inv = await app.storage.get_invite(invite_code)
|
||||||
inv_meta = await app.storage.get_invite_metadata(invite_code)
|
inv_meta = await app.storage.get_invite_metadata(invite_code)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({**inv, **{"inviter": inv_meta["inviter"]}})
|
||||||
**inv,
|
|
||||||
**{
|
|
||||||
'inviter': inv_meta['inviter']
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
|
||||||
|
|
@ -19,83 +19,75 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from quart import Blueprint, current_app as app, jsonify, request
|
from quart import Blueprint, current_app as app, jsonify, request
|
||||||
|
|
||||||
bp = Blueprint('nodeinfo', __name__)
|
bp = Blueprint("nodeinfo", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/.well-known/nodeinfo')
|
@bp.route("/.well-known/nodeinfo")
|
||||||
async def _dummy_nodeinfo_index():
|
async def _dummy_nodeinfo_index():
|
||||||
proto = 'http' if not app.config['IS_SSL'] else 'https'
|
proto = "http" if not app.config["IS_SSL"] else "https"
|
||||||
main_url = app.config.get('MAIN_URL', request.host)
|
main_url = app.config.get("MAIN_URL", request.host)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'links': [{
|
{
|
||||||
'href': f'{proto}://{main_url}/nodeinfo/2.0.json',
|
"links": [
|
||||||
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.0'
|
{
|
||||||
}, {
|
"href": f"{proto}://{main_url}/nodeinfo/2.0.json",
|
||||||
'href': f'{proto}://{main_url}/nodeinfo/2.1.json',
|
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.0",
|
||||||
'rel': 'http://nodeinfo.diaspora.software/ns/schema/2.1'
|
},
|
||||||
}]
|
{
|
||||||
})
|
"href": f"{proto}://{main_url}/nodeinfo/2.1.json",
|
||||||
|
"rel": "http://nodeinfo.diaspora.software/ns/schema/2.1",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def fetch_nodeinfo_20():
|
async def fetch_nodeinfo_20():
|
||||||
usercount = await app.db.fetchval("""
|
usercount = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM users
|
FROM users
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
message_count = await app.db.fetchval("""
|
message_count = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM messages
|
FROM messages
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'metadata': {
|
"metadata": {
|
||||||
'features': [
|
"features": ["discord_api"],
|
||||||
'discord_api'
|
"nodeDescription": "A Litecord instance",
|
||||||
],
|
"nodeName": "Litecord/Nya",
|
||||||
|
"private": False,
|
||||||
'nodeDescription': 'A Litecord instance',
|
"federation": {},
|
||||||
'nodeName': 'Litecord/Nya',
|
|
||||||
'private': False,
|
|
||||||
|
|
||||||
'federation': {}
|
|
||||||
},
|
},
|
||||||
'openRegistrations': app.config['REGISTRATIONS'],
|
"openRegistrations": app.config["REGISTRATIONS"],
|
||||||
'protocols': [],
|
"protocols": [],
|
||||||
'software': {
|
"software": {"name": "litecord", "version": "litecord v0"},
|
||||||
'name': 'litecord',
|
"services": {"inbound": [], "outbound": []},
|
||||||
'version': 'litecord v0',
|
"usage": {"localPosts": message_count, "users": {"total": usercount}},
|
||||||
},
|
"version": "2.0",
|
||||||
|
|
||||||
'services': {
|
|
||||||
'inbound': [],
|
|
||||||
'outbound': [],
|
|
||||||
},
|
|
||||||
|
|
||||||
'usage': {
|
|
||||||
'localPosts': message_count,
|
|
||||||
'users': {
|
|
||||||
'total': usercount
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'version': '2.0',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/nodeinfo/2.0.json')
|
@bp.route("/nodeinfo/2.0.json")
|
||||||
async def _nodeinfo_20():
|
async def _nodeinfo_20():
|
||||||
"""Handler for nodeinfo 2.0."""
|
"""Handler for nodeinfo 2.0."""
|
||||||
raw_nodeinfo = await fetch_nodeinfo_20()
|
raw_nodeinfo = await fetch_nodeinfo_20()
|
||||||
return jsonify(raw_nodeinfo)
|
return jsonify(raw_nodeinfo)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/nodeinfo/2.1.json')
|
@bp.route("/nodeinfo/2.1.json")
|
||||||
async def _nodeinfo_21():
|
async def _nodeinfo_21():
|
||||||
"""Handler for nodeinfo 2.1."""
|
"""Handler for nodeinfo 2.1."""
|
||||||
raw_nodeinfo = await fetch_nodeinfo_20()
|
raw_nodeinfo = await fetch_nodeinfo_20()
|
||||||
|
|
||||||
raw_nodeinfo['software']['repository'] = 'https://gitlab.com/litecord/litecord'
|
raw_nodeinfo["software"]["repository"] = "https://gitlab.com/litecord/litecord"
|
||||||
raw_nodeinfo['version'] = '2.1'
|
raw_nodeinfo["version"] = "2.1"
|
||||||
|
|
||||||
return jsonify(raw_nodeinfo)
|
return jsonify(raw_nodeinfo)
|
||||||
|
|
|
||||||
|
|
@ -26,76 +26,89 @@ from ..enums import RelationshipType
|
||||||
from litecord.errors import BadRequest
|
from litecord.errors import BadRequest
|
||||||
|
|
||||||
|
|
||||||
bp = Blueprint('relationship', __name__)
|
bp = Blueprint("relationship", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/relationships', methods=['GET'])
|
@bp.route("/@me/relationships", methods=["GET"])
|
||||||
async def get_me_relationships():
|
async def get_me_relationships():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
return jsonify(
|
return jsonify(await app.user_storage.get_relationships(user_id))
|
||||||
await app.user_storage.get_relationships(user_id))
|
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch_single_pres(user_id, presence: dict):
|
async def _dispatch_single_pres(user_id, presence: dict):
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("user", user_id, "PRESENCE_UPDATE", presence)
|
||||||
'user', user_id, 'PRESENCE_UPDATE', presence
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _unsub_friend(user_id, peer_id):
|
async def _unsub_friend(user_id, peer_id):
|
||||||
await app.dispatcher.unsub('friend', user_id, peer_id)
|
await app.dispatcher.unsub("friend", user_id, peer_id)
|
||||||
await app.dispatcher.unsub('friend', peer_id, user_id)
|
await app.dispatcher.unsub("friend", peer_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
async def _sub_friend(user_id, peer_id):
|
async def _sub_friend(user_id, peer_id):
|
||||||
await app.dispatcher.sub('friend', user_id, peer_id)
|
await app.dispatcher.sub("friend", user_id, peer_id)
|
||||||
await app.dispatcher.sub('friend', peer_id, user_id)
|
await app.dispatcher.sub("friend", peer_id, user_id)
|
||||||
|
|
||||||
# dispatch presence update to the user and peer about
|
# dispatch presence update to the user and peer about
|
||||||
# eachother's presence.
|
# eachother's presence.
|
||||||
user_pres, peer_pres = await app.presence.friend_presences(
|
user_pres, peer_pres = await app.presence.friend_presences([user_id, peer_id])
|
||||||
[user_id, peer_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
await _dispatch_single_pres(user_id, peer_pres)
|
await _dispatch_single_pres(user_id, peer_pres)
|
||||||
await _dispatch_single_pres(peer_id, user_pres)
|
await _dispatch_single_pres(peer_id, user_pres)
|
||||||
|
|
||||||
|
|
||||||
async def make_friend(user_id: int, peer_id: int,
|
async def make_friend(
|
||||||
rel_type=RelationshipType.FRIEND.value):
|
user_id: int, peer_id: int, rel_type=RelationshipType.FRIEND.value
|
||||||
|
):
|
||||||
_friend = RelationshipType.FRIEND.value
|
_friend = RelationshipType.FRIEND.value
|
||||||
_block = RelationshipType.BLOCK.value
|
_block = RelationshipType.BLOCK.value
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO relationships (user_id, peer_id, rel_type)
|
INSERT INTO relationships (user_id, peer_id, rel_type)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", user_id, peer_id, rel_type)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
rel_type,
|
||||||
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
# try to update rel_type
|
# try to update rel_type
|
||||||
old_rel_type = await app.db.fetchval("""
|
old_rel_type = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT rel_type
|
SELECT rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND peer_id = $2
|
WHERE user_id = $1 AND peer_id = $2
|
||||||
""", user_id, peer_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
if old_rel_type == _friend and rel_type == _block:
|
if old_rel_type == _friend and rel_type == _block:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE relationships
|
UPDATE relationships
|
||||||
SET rel_type = $1
|
SET rel_type = $1
|
||||||
WHERE user_id = $2 AND peer_id = $3
|
WHERE user_id = $2 AND peer_id = $3
|
||||||
""", rel_type, user_id, peer_id)
|
""",
|
||||||
|
rel_type,
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
# remove any existing friendship before the block
|
# remove any existing friendship before the block
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM relationships
|
DELETE FROM relationships
|
||||||
WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3
|
WHERE peer_id = $1 AND user_id = $2 AND rel_type = $3
|
||||||
""", peer_id, user_id, _friend)
|
""",
|
||||||
|
peer_id,
|
||||||
|
user_id,
|
||||||
|
_friend,
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(
|
await app.dispatcher.dispatch_user(
|
||||||
peer_id, 'RELATIONSHIP_REMOVE', {
|
peer_id, "RELATIONSHIP_REMOVE", {"type": _friend, "id": str(user_id)}
|
||||||
'type': _friend,
|
|
||||||
'id': str(user_id)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
@ -106,95 +119,118 @@ async def make_friend(user_id: int, peer_id: int,
|
||||||
|
|
||||||
# check if this is an acceptance
|
# check if this is an acceptance
|
||||||
# of a friend request
|
# of a friend request
|
||||||
existing = await app.db.fetchrow("""
|
existing = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT user_id, peer_id
|
SELECT user_id, peer_id
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
||||||
""", peer_id, user_id, _friend)
|
""",
|
||||||
|
peer_id,
|
||||||
|
user_id,
|
||||||
|
_friend,
|
||||||
|
)
|
||||||
|
|
||||||
_dispatch = app.dispatcher.dispatch_user
|
_dispatch = app.dispatcher.dispatch_user
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# accepted a friend request, dispatch respective
|
# accepted a friend request, dispatch respective
|
||||||
# relationship events
|
# relationship events
|
||||||
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
|
await _dispatch(
|
||||||
'type': RelationshipType.INCOMING.value,
|
user_id,
|
||||||
'id': str(peer_id)
|
"RELATIONSHIP_REMOVE",
|
||||||
})
|
{"type": RelationshipType.INCOMING.value, "id": str(peer_id)},
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch(user_id, 'RELATIONSHIP_ADD', {
|
await _dispatch(
|
||||||
'type': _friend,
|
user_id,
|
||||||
'id': str(peer_id),
|
"RELATIONSHIP_ADD",
|
||||||
'user': await app.storage.get_user(peer_id)
|
{
|
||||||
})
|
"type": _friend,
|
||||||
|
"id": str(peer_id),
|
||||||
|
"user": await app.storage.get_user(peer_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
|
await _dispatch(
|
||||||
'type': _friend,
|
peer_id,
|
||||||
'id': str(user_id),
|
"RELATIONSHIP_ADD",
|
||||||
'user': await app.storage.get_user(user_id)
|
{
|
||||||
})
|
"type": _friend,
|
||||||
|
"id": str(user_id),
|
||||||
|
"user": await app.storage.get_user(user_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
await _sub_friend(user_id, peer_id)
|
await _sub_friend(user_id, peer_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
# check if friend AND not acceptance of fr
|
# check if friend AND not acceptance of fr
|
||||||
if rel_type == _friend:
|
if rel_type == _friend:
|
||||||
await _dispatch(user_id, 'RELATIONSHIP_ADD', {
|
await _dispatch(
|
||||||
'id': str(peer_id),
|
user_id,
|
||||||
'type': RelationshipType.OUTGOING.value,
|
"RELATIONSHIP_ADD",
|
||||||
'user': await app.storage.get_user(peer_id),
|
{
|
||||||
})
|
"id": str(peer_id),
|
||||||
|
"type": RelationshipType.OUTGOING.value,
|
||||||
|
"user": await app.storage.get_user(peer_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch(peer_id, 'RELATIONSHIP_ADD', {
|
await _dispatch(
|
||||||
'id': str(user_id),
|
peer_id,
|
||||||
'type': RelationshipType.INCOMING.value,
|
"RELATIONSHIP_ADD",
|
||||||
'user': await app.storage.get_user(user_id)
|
{
|
||||||
})
|
"id": str(user_id),
|
||||||
|
"type": RelationshipType.INCOMING.value,
|
||||||
|
"user": await app.storage.get_user(user_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# we don't make the pubsub link
|
# we don't make the pubsub link
|
||||||
# until the peer accepts the friend request
|
# until the peer accepts the friend request
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
class RelationshipFailed(BadRequest):
|
class RelationshipFailed(BadRequest):
|
||||||
"""Exception for general relationship errors."""
|
"""Exception for general relationship errors."""
|
||||||
|
|
||||||
error_code = 80004
|
error_code = 80004
|
||||||
|
|
||||||
|
|
||||||
class RelationshipBlocked(BadRequest):
|
class RelationshipBlocked(BadRequest):
|
||||||
"""Exception for when the peer has blocked the user."""
|
"""Exception for when the peer has blocked the user."""
|
||||||
|
|
||||||
error_code = 80001
|
error_code = 80001
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/relationships', methods=['POST'])
|
@bp.route("/@me/relationships", methods=["POST"])
|
||||||
async def post_relationship():
|
async def post_relationship():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
j = validate(await request.get_json(), SPECIFIC_FRIEND)
|
j = validate(await request.get_json(), SPECIFIC_FRIEND)
|
||||||
|
|
||||||
uid = await app.storage.search_user(j['username'],
|
uid = await app.storage.search_user(j["username"], str(j["discriminator"]))
|
||||||
str(j['discriminator']))
|
|
||||||
|
|
||||||
if not uid:
|
if not uid:
|
||||||
raise RelationshipFailed('No users with DiscordTag exist')
|
raise RelationshipFailed("No users with DiscordTag exist")
|
||||||
|
|
||||||
res = await make_friend(user_id, uid)
|
res = await make_friend(user_id, uid)
|
||||||
|
|
||||||
if res is None:
|
if res is None:
|
||||||
raise RelationshipBlocked('Can not friend user due to block')
|
raise RelationshipBlocked("Can not friend user due to block")
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/relationships/<int:peer_id>', methods=['PUT'])
|
@bp.route("/@me/relationships/<int:peer_id>", methods=["PUT"])
|
||||||
async def add_relationship(peer_id: int):
|
async def add_relationship(peer_id: int):
|
||||||
"""Add a relationship to the peer."""
|
"""Add a relationship to the peer."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
payload = validate(await request.get_json(), RELATIONSHIP)
|
payload = validate(await request.get_json(), RELATIONSHIP)
|
||||||
rel_type = payload['type']
|
rel_type = payload["type"]
|
||||||
|
|
||||||
res = await make_friend(user_id, peer_id, rel_type)
|
res = await make_friend(user_id, peer_id, rel_type)
|
||||||
|
|
||||||
|
|
@ -204,18 +240,22 @@ async def add_relationship(peer_id: int):
|
||||||
# make_friend did not succeed, so we
|
# make_friend did not succeed, so we
|
||||||
# assume it is a block and dispatch
|
# assume it is a block and dispatch
|
||||||
# the respective RELATIONSHIP_ADD.
|
# the respective RELATIONSHIP_ADD.
|
||||||
await app.dispatcher.dispatch_user(user_id, 'RELATIONSHIP_ADD', {
|
await app.dispatcher.dispatch_user(
|
||||||
'id': str(peer_id),
|
user_id,
|
||||||
'type': RelationshipType.BLOCK.value,
|
"RELATIONSHIP_ADD",
|
||||||
'user': await app.storage.get_user(peer_id)
|
{
|
||||||
})
|
"id": str(peer_id),
|
||||||
|
"type": RelationshipType.BLOCK.value,
|
||||||
|
"user": await app.storage.get_user(peer_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/relationships/<int:peer_id>', methods=['DELETE'])
|
@bp.route("/@me/relationships/<int:peer_id>", methods=["DELETE"])
|
||||||
async def remove_relationship(peer_id: int):
|
async def remove_relationship(peer_id: int):
|
||||||
"""Remove an existing relationship"""
|
"""Remove an existing relationship"""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -223,69 +263,86 @@ async def remove_relationship(peer_id: int):
|
||||||
_block = RelationshipType.BLOCK.value
|
_block = RelationshipType.BLOCK.value
|
||||||
_dispatch = app.dispatcher.dispatch_user
|
_dispatch = app.dispatcher.dispatch_user
|
||||||
|
|
||||||
rel_type = await app.db.fetchval("""
|
rel_type = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT rel_type
|
SELECT rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND peer_id = $2
|
WHERE user_id = $1 AND peer_id = $2
|
||||||
""", user_id, peer_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
incoming_rel_type = await app.db.fetchval("""
|
incoming_rel_type = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT rel_type
|
SELECT rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND peer_id = $2
|
WHERE user_id = $1 AND peer_id = $2
|
||||||
""", peer_id, user_id)
|
""",
|
||||||
|
peer_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# if any of those are friend
|
# if any of those are friend
|
||||||
if _friend in (rel_type, incoming_rel_type):
|
if _friend in (rel_type, incoming_rel_type):
|
||||||
# closing the friendship, have to delete both rows
|
# closing the friendship, have to delete both rows
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM relationships
|
DELETE FROM relationships
|
||||||
WHERE (
|
WHERE (
|
||||||
(user_id = $1 AND peer_id = $2) OR
|
(user_id = $1 AND peer_id = $2) OR
|
||||||
(user_id = $2 AND peer_id = $1)
|
(user_id = $2 AND peer_id = $1)
|
||||||
) AND rel_type = $3
|
) AND rel_type = $3
|
||||||
""", user_id, peer_id, _friend)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
_friend,
|
||||||
|
)
|
||||||
|
|
||||||
# if there wasnt any mutual friendship before,
|
# if there wasnt any mutual friendship before,
|
||||||
# assume they were requests of INCOMING
|
# assume they were requests of INCOMING
|
||||||
# and OUTGOING.
|
# and OUTGOING.
|
||||||
user_del_type = RelationshipType.OUTGOING.value if \
|
user_del_type = (
|
||||||
incoming_rel_type != _friend else _friend
|
RelationshipType.OUTGOING.value if incoming_rel_type != _friend else _friend
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
|
await _dispatch(
|
||||||
'id': str(peer_id),
|
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": user_del_type}
|
||||||
'type': user_del_type,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
peer_del_type = RelationshipType.INCOMING.value if \
|
peer_del_type = (
|
||||||
incoming_rel_type != _friend else _friend
|
RelationshipType.INCOMING.value if incoming_rel_type != _friend else _friend
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch(peer_id, 'RELATIONSHIP_REMOVE', {
|
await _dispatch(
|
||||||
'id': str(user_id),
|
peer_id, "RELATIONSHIP_REMOVE", {"id": str(user_id), "type": peer_del_type}
|
||||||
'type': peer_del_type,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
# was a block!
|
# was a block!
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM relationships
|
DELETE FROM relationships
|
||||||
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
||||||
""", user_id, peer_id, _block)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
_block,
|
||||||
|
)
|
||||||
|
|
||||||
await _dispatch(user_id, 'RELATIONSHIP_REMOVE', {
|
await _dispatch(
|
||||||
'id': str(peer_id),
|
user_id, "RELATIONSHIP_REMOVE", {"id": str(peer_id), "type": _block}
|
||||||
'type': _block,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
await _unsub_friend(user_id, peer_id)
|
await _unsub_friend(user_id, peer_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:peer_id>/relationships', methods=['GET'])
|
@bp.route("/<int:peer_id>/relationships", methods=["GET"])
|
||||||
async def get_mutual_friends(peer_id: int):
|
async def get_mutual_friends(peer_id: int):
|
||||||
"""Fetch a users' mutual friends with the current user."""
|
"""Fetch a users' mutual friends with the current user."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -294,17 +351,15 @@ async def get_mutual_friends(peer_id: int):
|
||||||
peer = await app.storage.get_user(peer_id)
|
peer = await app.storage.get_user(peer_id)
|
||||||
|
|
||||||
if not peer:
|
if not peer:
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
# NOTE: maybe this could be better with pure SQL calculations
|
# NOTE: maybe this could be better with pure SQL calculations
|
||||||
# but it would be beyond my current SQL knowledge, so...
|
# but it would be beyond my current SQL knowledge, so...
|
||||||
user_rels = await app.user_storage.get_relationships(user_id)
|
user_rels = await app.user_storage.get_relationships(user_id)
|
||||||
peer_rels = await app.user_storage.get_relationships(peer_id)
|
peer_rels = await app.user_storage.get_relationships(peer_id)
|
||||||
|
|
||||||
user_friends = {rel['user']['id']
|
user_friends = {rel["user"]["id"] for rel in user_rels if rel["type"] == _friend}
|
||||||
for rel in user_rels if rel['type'] == _friend}
|
peer_friends = {rel["user"]["id"] for rel in peer_rels if rel["type"] == _friend}
|
||||||
peer_friends = {rel['user']['id']
|
|
||||||
for rel in peer_rels if rel['type'] == _friend}
|
|
||||||
|
|
||||||
# get the intersection, then map them to Storage.get_user() calls
|
# get the intersection, then map them to Storage.get_user() calls
|
||||||
mutual_ids = user_friends & peer_friends
|
mutual_ids = user_friends & peer_friends
|
||||||
|
|
@ -312,8 +367,6 @@ async def get_mutual_friends(peer_id: int):
|
||||||
mutual_friends = []
|
mutual_friends = []
|
||||||
|
|
||||||
for friend_id in mutual_ids:
|
for friend_id in mutual_ids:
|
||||||
mutual_friends.append(
|
mutual_friends.append(await app.storage.get_user(int(friend_id)))
|
||||||
await app.storage.get_user(int(friend_id))
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify(mutual_friends)
|
return jsonify(mutual_friends)
|
||||||
|
|
|
||||||
|
|
@ -19,21 +19,19 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from quart import Blueprint, jsonify
|
from quart import Blueprint, jsonify
|
||||||
|
|
||||||
bp = Blueprint('science', __name__)
|
bp = Blueprint("science", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/science', methods=['POST'])
|
@bp.route("/science", methods=["POST"])
|
||||||
async def science():
|
async def science():
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/applications', methods=['GET'])
|
@bp.route("/applications", methods=["GET"])
|
||||||
async def applications():
|
async def applications():
|
||||||
return jsonify([])
|
return jsonify([])
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/experiments', methods=['GET'])
|
@bp.route("/experiments", methods=["GET"])
|
||||||
async def experiments():
|
async def experiments():
|
||||||
return jsonify({
|
return jsonify({"assignments": []})
|
||||||
'assignments': []
|
|
||||||
})
|
|
||||||
|
|
|
||||||
|
|
@ -20,23 +20,24 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
from quart import Blueprint, current_app as app, render_template_string
|
from quart import Blueprint, current_app as app, render_template_string
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
bp = Blueprint('static', __name__)
|
bp = Blueprint("static", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<path:path>')
|
@bp.route("/<path:path>")
|
||||||
async def static_pages(path):
|
async def static_pages(path):
|
||||||
"""Map requests from / to /static."""
|
"""Map requests from / to /static."""
|
||||||
if '..' in path:
|
if ".." in path:
|
||||||
return 'no', 404
|
return "no", 404
|
||||||
|
|
||||||
static_path = Path.cwd() / Path('static') / path
|
static_path = Path.cwd() / Path("static") / path
|
||||||
return await app.send_static_file(str(static_path))
|
return await app.send_static_file(str(static_path))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/')
|
@bp.route("/")
|
||||||
@bp.route('/api')
|
@bp.route("/api")
|
||||||
async def index_handler():
|
async def index_handler():
|
||||||
"""Handler for the index page."""
|
"""Handler for the index page."""
|
||||||
index_path = Path.cwd() / Path('static') / 'index.html'
|
index_path = Path.cwd() / Path("static") / "index.html"
|
||||||
return await render_template_string(
|
return await render_template_string(
|
||||||
index_path.read_text(), inst_name=app.config['NAME'])
|
index_path.read_text(), inst_name=app.config["NAME"]
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -21,4 +21,4 @@ from .billing import bp as user_billing
|
||||||
from .settings import bp as user_settings
|
from .settings import bp as user_settings
|
||||||
from .fake_store import bp as fake_store
|
from .fake_store import bp as fake_store
|
||||||
|
|
||||||
__all__ = ['user_billing', 'user_settings', 'fake_store']
|
__all__ = ["user_billing", "user_settings", "fake_store"]
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from litecord.enums import UserFlags, PremiumType
|
||||||
from litecord.blueprints.users import mass_user_update
|
from litecord.blueprints.users import mass_user_update
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('users_billing', __name__)
|
bp = Blueprint("users_billing", __name__)
|
||||||
|
|
||||||
|
|
||||||
class PaymentSource(Enum):
|
class PaymentSource(Enum):
|
||||||
|
|
@ -68,78 +68,87 @@ class PaymentStatus:
|
||||||
|
|
||||||
|
|
||||||
PLAN_ID_TO_TYPE = {
|
PLAN_ID_TO_TYPE = {
|
||||||
'premium_month_tier_1': PremiumType.TIER_1,
|
"premium_month_tier_1": PremiumType.TIER_1,
|
||||||
'premium_month_tier_2': PremiumType.TIER_2,
|
"premium_month_tier_2": PremiumType.TIER_2,
|
||||||
'premium_year_tier_1': PremiumType.TIER_1,
|
"premium_year_tier_1": PremiumType.TIER_1,
|
||||||
'premium_year_tier_2': PremiumType.TIER_2,
|
"premium_year_tier_2": PremiumType.TIER_2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# how much should a payment be, depending
|
# how much should a payment be, depending
|
||||||
# of the subscription
|
# of the subscription
|
||||||
AMOUNTS = {
|
AMOUNTS = {
|
||||||
'premium_month_tier_1': 499,
|
"premium_month_tier_1": 499,
|
||||||
'premium_month_tier_2': 999,
|
"premium_month_tier_2": 999,
|
||||||
'premium_year_tier_1': 4999,
|
"premium_year_tier_1": 4999,
|
||||||
'premium_year_tier_2': 9999,
|
"premium_year_tier_2": 9999,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
CREATE_SUBSCRIPTION = {
|
CREATE_SUBSCRIPTION = {
|
||||||
'payment_gateway_plan_id': {'type': 'string'},
|
"payment_gateway_plan_id": {"type": "string"},
|
||||||
'payment_source_id': {'coerce': int}
|
"payment_source_id": {"coerce": int},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
PAYMENT_SOURCE = {
|
PAYMENT_SOURCE = {
|
||||||
'billing_address': {
|
"billing_address": {
|
||||||
'type': 'dict',
|
"type": "dict",
|
||||||
'schema': {
|
"schema": {
|
||||||
'country': {'type': 'string', 'required': True},
|
"country": {"type": "string", "required": True},
|
||||||
'city': {'type': 'string', 'required': True},
|
"city": {"type": "string", "required": True},
|
||||||
'name': {'type': 'string', 'required': True},
|
"name": {"type": "string", "required": True},
|
||||||
'line_1': {'type': 'string', 'required': False},
|
"line_1": {"type": "string", "required": False},
|
||||||
'line_2': {'type': 'string', 'required': False},
|
"line_2": {"type": "string", "required": False},
|
||||||
'postal_code': {'type': 'string', 'required': True},
|
"postal_code": {"type": "string", "required": True},
|
||||||
'state': {'type': 'string', 'required': True},
|
"state": {"type": "string", "required": True},
|
||||||
}
|
|
||||||
},
|
},
|
||||||
'payment_gateway': {'type': 'number', 'required': True},
|
},
|
||||||
'token': {'type': 'string', 'required': True},
|
"payment_gateway": {"type": "number", "required": True},
|
||||||
|
"token": {"type": "string", "required": True},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_source_ids(user_id: int) -> list:
|
async def get_payment_source_ids(user_id: int) -> list:
|
||||||
rows = await app.db.fetch("""
|
rows = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [r['id'] for r in rows]
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_ids(user_id: int, db=None) -> list:
|
async def get_payment_ids(user_id: int, db=None) -> list:
|
||||||
if not db:
|
if not db:
|
||||||
db = app.db
|
db = app.db
|
||||||
|
|
||||||
rows = await db.fetch("""
|
rows = await db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [r['id'] for r in rows]
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_subscription_ids(user_id: int) -> list:
|
async def get_subscription_ids(user_id: int) -> list:
|
||||||
rows = await app.db.fetch("""
|
rows = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_subscriptions
|
FROM user_subscriptions
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [r['id'] for r in rows]
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||||
|
|
@ -148,41 +157,44 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||||
if not db:
|
if not db:
|
||||||
db = app.db
|
db = app.db
|
||||||
|
|
||||||
source_type = await db.fetchval("""
|
source_type = await db.fetchval(
|
||||||
|
"""
|
||||||
SELECT source_type
|
SELECT source_type
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
WHERE id = $1 AND user_id = $2
|
WHERE id = $1 AND user_id = $2
|
||||||
""", source_id, user_id)
|
""",
|
||||||
|
source_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
source_type = PaymentSource(source_type)
|
source_type = PaymentSource(source_type)
|
||||||
|
|
||||||
specific_fields = {
|
specific_fields = {
|
||||||
PaymentSource.PAYPAL: ['paypal_email'],
|
PaymentSource.PAYPAL: ["paypal_email"],
|
||||||
PaymentSource.CREDIT: ['expires_month', 'expires_year',
|
PaymentSource.CREDIT: ["expires_month", "expires_year", "brand", "cc_full"],
|
||||||
'brand', 'cc_full']
|
|
||||||
}[source_type]
|
}[source_type]
|
||||||
|
|
||||||
fields = ','.join(specific_fields)
|
fields = ",".join(specific_fields)
|
||||||
|
|
||||||
extras_row = await db.fetchrow(f"""
|
extras_row = await db.fetchrow(
|
||||||
|
f"""
|
||||||
SELECT {fields}, billing_address, default_, id::text
|
SELECT {fields}, billing_address, default_, id::text
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", source_id)
|
""",
|
||||||
|
source_id,
|
||||||
|
)
|
||||||
|
|
||||||
derow = dict(extras_row)
|
derow = dict(extras_row)
|
||||||
|
|
||||||
if source_type == PaymentSource.CREDIT:
|
if source_type == PaymentSource.CREDIT:
|
||||||
derow['last_4'] = derow['cc_full'][-4:]
|
derow["last_4"] = derow["cc_full"][-4:]
|
||||||
derow.pop('cc_full')
|
derow.pop("cc_full")
|
||||||
|
|
||||||
derow['default'] = derow['default_']
|
derow["default"] = derow["default_"]
|
||||||
derow.pop('default_')
|
derow.pop("default_")
|
||||||
|
|
||||||
source = {
|
source = {"id": str(source_id), "type": source_type.value}
|
||||||
'id': str(source_id),
|
|
||||||
'type': source_type.value,
|
|
||||||
}
|
|
||||||
|
|
||||||
return {**source, **derow}
|
return {**source, **derow}
|
||||||
|
|
||||||
|
|
@ -192,7 +204,8 @@ async def get_subscription(subscription_id: int, db=None):
|
||||||
if not db:
|
if not db:
|
||||||
db = app.db
|
db = app.db
|
||||||
|
|
||||||
row = await db.fetchrow("""
|
row = await db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT id::text, source_id::text AS payment_source_id,
|
SELECT id::text, source_id::text AS payment_source_id,
|
||||||
user_id,
|
user_id,
|
||||||
payment_gateway, payment_gateway_plan_id,
|
payment_gateway, payment_gateway_plan_id,
|
||||||
|
|
@ -201,14 +214,16 @@ async def get_subscription(subscription_id: int, db=None):
|
||||||
canceled_at, s_type, status
|
canceled_at, s_type, status
|
||||||
FROM user_subscriptions
|
FROM user_subscriptions
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", subscription_id)
|
""",
|
||||||
|
subscription_id,
|
||||||
|
)
|
||||||
|
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
drow['type'] = drow['s_type']
|
drow["type"] = drow["s_type"]
|
||||||
drow.pop('s_type')
|
drow.pop("s_type")
|
||||||
|
|
||||||
to_tstamp = ['current_period_start', 'current_period_end', 'canceled_at']
|
to_tstamp = ["current_period_start", "current_period_end", "canceled_at"]
|
||||||
|
|
||||||
for field in to_tstamp:
|
for field in to_tstamp:
|
||||||
drow[field] = timestamp_(drow[field])
|
drow[field] = timestamp_(drow[field])
|
||||||
|
|
@ -221,27 +236,30 @@ async def get_payment(payment_id: int, db=None):
|
||||||
if not db:
|
if not db:
|
||||||
db = app.db
|
db = app.db
|
||||||
|
|
||||||
row = await db.fetchrow("""
|
row = await db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT id::text, source_id, subscription_id, user_id,
|
SELECT id::text, source_id, subscription_id, user_id,
|
||||||
amount, amount_refunded, currency,
|
amount, amount_refunded, currency,
|
||||||
description, status, tax, tax_inclusive
|
description, status, tax, tax_inclusive
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", payment_id)
|
""",
|
||||||
|
payment_id,
|
||||||
|
)
|
||||||
|
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
drow.pop('source_id')
|
drow.pop("source_id")
|
||||||
drow.pop('subscription_id')
|
drow.pop("subscription_id")
|
||||||
drow.pop('user_id')
|
drow.pop("user_id")
|
||||||
|
|
||||||
drow['created_at'] = snowflake_datetime(int(drow['id']))
|
drow["created_at"] = snowflake_datetime(int(drow["id"]))
|
||||||
|
|
||||||
drow['payment_source'] = await get_payment_source(
|
drow["payment_source"] = await get_payment_source(
|
||||||
row['user_id'], row['source_id'], db)
|
row["user_id"], row["source_id"], db
|
||||||
|
)
|
||||||
|
|
||||||
drow['subscription'] = await get_subscription(
|
drow["subscription"] = await get_subscription(row["subscription_id"], db)
|
||||||
row['subscription_id'], db)
|
|
||||||
|
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
|
|
@ -255,7 +273,7 @@ async def create_payment(subscription_id, db=None):
|
||||||
|
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
||||||
amount = AMOUNTS[sub['payment_gateway_plan_id']]
|
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
|
||||||
|
|
||||||
await db.execute(
|
await db.execute(
|
||||||
"""
|
"""
|
||||||
|
|
@ -266,10 +284,16 @@ async def create_payment(subscription_id, db=None):
|
||||||
)
|
)
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
|
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
|
||||||
""", new_id, int(sub['payment_source_id']),
|
""",
|
||||||
subscription_id, int(sub['user_id']),
|
new_id,
|
||||||
amount, 'usd', 'FUCK NITRO',
|
int(sub["payment_source_id"]),
|
||||||
PaymentStatus.SUCCESS)
|
subscription_id,
|
||||||
|
int(sub["user_id"]),
|
||||||
|
amount,
|
||||||
|
"usd",
|
||||||
|
"FUCK NITRO",
|
||||||
|
PaymentStatus.SUCCESS,
|
||||||
|
)
|
||||||
|
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
|
|
@ -278,29 +302,34 @@ async def process_subscription(app, subscription_id: int):
|
||||||
"""Process a single subscription."""
|
"""Process a single subscription."""
|
||||||
sub = await get_subscription(subscription_id, app.db)
|
sub = await get_subscription(subscription_id, app.db)
|
||||||
|
|
||||||
user_id = int(sub['user_id'])
|
user_id = int(sub["user_id"])
|
||||||
|
|
||||||
if sub['status'] != SubscriptionStatus.ACTIVE:
|
if sub["status"] != SubscriptionStatus.ACTIVE:
|
||||||
log.debug('ignoring sub {}, not active',
|
log.debug("ignoring sub {}, not active", subscription_id)
|
||||||
subscription_id)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# if the subscription is still active
|
# if the subscription is still active
|
||||||
# (should get cancelled status on failed
|
# (should get cancelled status on failed
|
||||||
# payments), then we should update premium status
|
# payments), then we should update premium status
|
||||||
first_payment_id = await app.db.fetchval("""
|
first_payment_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT MIN(id)
|
SELECT MIN(id)
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE subscription_id = $1
|
WHERE subscription_id = $1
|
||||||
""", subscription_id)
|
""",
|
||||||
|
subscription_id,
|
||||||
|
)
|
||||||
|
|
||||||
first_payment_ts = snowflake_datetime(first_payment_id)
|
first_payment_ts = snowflake_datetime(first_payment_id)
|
||||||
|
|
||||||
premium_since = await app.db.fetchval("""
|
premium_since = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT premium_since
|
SELECT premium_since
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
premium_since = premium_since or datetime.datetime.fromtimestamp(0)
|
premium_since = premium_since or datetime.datetime.fromtimestamp(0)
|
||||||
|
|
||||||
|
|
@ -312,27 +341,34 @@ async def process_subscription(app, subscription_id: int):
|
||||||
if delta.total_seconds() < 24 * HOURS:
|
if delta.total_seconds() < 24 * HOURS:
|
||||||
return
|
return
|
||||||
|
|
||||||
old_flags = await app.db.fetchval("""
|
old_flags = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT flags
|
SELECT flags
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
new_flags = old_flags | UserFlags.premium_early
|
new_flags = old_flags | UserFlags.premium_early
|
||||||
log.debug('updating flags {}, {} => {}',
|
log.debug("updating flags {}, {} => {}", user_id, old_flags, new_flags)
|
||||||
user_id, old_flags, new_flags)
|
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET premium_since = $1, flags = $2
|
SET premium_since = $1, flags = $2
|
||||||
WHERE id = $3
|
WHERE id = $3
|
||||||
""", first_payment_ts, new_flags, user_id)
|
""",
|
||||||
|
first_payment_ts,
|
||||||
|
new_flags,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# dispatch updated user to all possible clients
|
# dispatch updated user to all possible clients
|
||||||
await mass_user_update(user_id, app)
|
await mass_user_update(user_id, app)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/payment-sources', methods=['GET'])
|
@bp.route("/@me/billing/payment-sources", methods=["GET"])
|
||||||
async def _get_billing_sources():
|
async def _get_billing_sources():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
source_ids = await get_payment_source_ids(user_id)
|
source_ids = await get_payment_source_ids(user_id)
|
||||||
|
|
@ -346,7 +382,7 @@ async def _get_billing_sources():
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/subscriptions', methods=['GET'])
|
@bp.route("/@me/billing/subscriptions", methods=["GET"])
|
||||||
async def _get_billing_subscriptions():
|
async def _get_billing_subscriptions():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
sub_ids = await get_subscription_ids(user_id)
|
sub_ids = await get_subscription_ids(user_id)
|
||||||
|
|
@ -358,7 +394,7 @@ async def _get_billing_subscriptions():
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/payments', methods=['GET'])
|
@bp.route("/@me/billing/payments", methods=["GET"])
|
||||||
async def _get_billing_payments():
|
async def _get_billing_payments():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
payment_ids = await get_payment_ids(user_id)
|
payment_ids = await get_payment_ids(user_id)
|
||||||
|
|
@ -370,7 +406,7 @@ async def _get_billing_payments():
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/payment-sources', methods=['POST'])
|
@bp.route("/@me/billing/payment-sources", methods=["POST"])
|
||||||
async def _create_payment_source():
|
async def _create_payment_source():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
j = validate(await request.get_json(), PAYMENT_SOURCE)
|
j = validate(await request.get_json(), PAYMENT_SOURCE)
|
||||||
|
|
@ -383,34 +419,40 @@ async def _create_payment_source():
|
||||||
default_, expires_month, expires_year, brand, cc_full,
|
default_, expires_month, expires_year, brand, cc_full,
|
||||||
billing_address)
|
billing_address)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
""", new_source_id, user_id, PaymentSource.CREDIT.value,
|
""",
|
||||||
True, 12, 6969, 'Visa', '4242424242424242',
|
new_source_id,
|
||||||
json.dumps(j['billing_address']))
|
user_id,
|
||||||
|
PaymentSource.CREDIT.value,
|
||||||
return jsonify(
|
True,
|
||||||
await get_payment_source(user_id, new_source_id)
|
12,
|
||||||
|
6969,
|
||||||
|
"Visa",
|
||||||
|
"4242424242424242",
|
||||||
|
json.dumps(j["billing_address"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return jsonify(await get_payment_source(user_id, new_source_id))
|
||||||
|
|
||||||
@bp.route('/@me/billing/subscriptions', methods=['POST'])
|
|
||||||
|
@bp.route("/@me/billing/subscriptions", methods=["POST"])
|
||||||
async def _create_subscription():
|
async def _create_subscription():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
j = validate(await request.get_json(), CREATE_SUBSCRIPTION)
|
j = validate(await request.get_json(), CREATE_SUBSCRIPTION)
|
||||||
|
|
||||||
source = await get_payment_source(user_id, j['payment_source_id'])
|
source = await get_payment_source(user_id, j["payment_source_id"])
|
||||||
if not source:
|
if not source:
|
||||||
raise BadRequest('invalid source id')
|
raise BadRequest("invalid source id")
|
||||||
|
|
||||||
plan_id = j['payment_gateway_plan_id']
|
plan_id = j["payment_gateway_plan_id"]
|
||||||
|
|
||||||
# tier 1 is lightro / classic
|
# tier 1 is lightro / classic
|
||||||
# tier 2 is nitro
|
# tier 2 is nitro
|
||||||
|
|
||||||
period_end = {
|
period_end = {
|
||||||
'premium_month_tier_1': '1 month',
|
"premium_month_tier_1": "1 month",
|
||||||
'premium_month_tier_2': '1 month',
|
"premium_month_tier_2": "1 month",
|
||||||
'premium_year_tier_1': '1 year',
|
"premium_year_tier_1": "1 year",
|
||||||
'premium_year_tier_2': '1 year',
|
"premium_year_tier_2": "1 year",
|
||||||
}[plan_id]
|
}[plan_id]
|
||||||
|
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
@ -422,9 +464,15 @@ async def _create_subscription():
|
||||||
status, period_end)
|
status, period_end)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7,
|
VALUES ($1, $2, $3, $4, $5, $6, $7,
|
||||||
now()::timestamp + interval '{period_end}')
|
now()::timestamp + interval '{period_end}')
|
||||||
""", new_id, j['payment_source_id'], user_id,
|
""",
|
||||||
SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
|
new_id,
|
||||||
plan_id, 1)
|
j["payment_source_id"],
|
||||||
|
user_id,
|
||||||
|
SubscriptionType.PURCHASE,
|
||||||
|
PaymentGateway.STRIPE,
|
||||||
|
plan_id,
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
await create_payment(new_id, app.db)
|
await create_payment(new_id, app.db)
|
||||||
|
|
||||||
|
|
@ -432,21 +480,17 @@ async def _create_subscription():
|
||||||
# and dispatch respective user updates to other people.
|
# and dispatch respective user updates to other people.
|
||||||
await process_subscription(app, new_id)
|
await process_subscription(app, new_id)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(await get_subscription(new_id))
|
||||||
await get_subscription(new_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
|
@bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["DELETE"])
|
||||||
methods=['DELETE'])
|
|
||||||
async def _delete_subscription(subscription_id):
|
async def _delete_subscription(subscription_id):
|
||||||
# user_id = await token_check()
|
# user_id = await token_check()
|
||||||
# return '', 204
|
# return '', 204
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
|
@bp.route("/@me/billing/subscriptions/<int:subscription_id>", methods=["PATCH"])
|
||||||
methods=['PATCH'])
|
|
||||||
async def _patch_subscription(subscription_id):
|
async def _patch_subscription(subscription_id):
|
||||||
"""change a subscription's payment source"""
|
"""change a subscription's payment source"""
|
||||||
# user_id = await token_check()
|
# user_id = await token_check()
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,11 @@ from asyncio import sleep, CancelledError
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from litecord.blueprints.user.billing import (
|
from litecord.blueprints.user.billing import (
|
||||||
get_subscription, get_payment_ids, get_payment, create_payment,
|
get_subscription,
|
||||||
process_subscription
|
get_payment_ids,
|
||||||
|
get_payment,
|
||||||
|
create_payment,
|
||||||
|
process_subscription,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.snowflake import snowflake_datetime
|
from litecord.snowflake import snowflake_datetime
|
||||||
|
|
@ -37,15 +40,15 @@ log = Logger(__name__)
|
||||||
# how many days until a payment needs
|
# how many days until a payment needs
|
||||||
# to be issued
|
# to be issued
|
||||||
THRESHOLDS = {
|
THRESHOLDS = {
|
||||||
'premium_month_tier_1': 30,
|
"premium_month_tier_1": 30,
|
||||||
'premium_month_tier_2': 30,
|
"premium_month_tier_2": 30,
|
||||||
'premium_year_tier_1': 365,
|
"premium_year_tier_1": 365,
|
||||||
'premium_year_tier_2': 365,
|
"premium_year_tier_2": 365,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _resched(app):
|
async def _resched(app):
|
||||||
log.debug('waiting 30 minutes for job.')
|
log.debug("waiting 30 minutes for job.")
|
||||||
await sleep(30 * MINUTES)
|
await sleep(30 * MINUTES)
|
||||||
app.sched.spawn(payment_job(app))
|
app.sched.spawn(payment_job(app))
|
||||||
|
|
||||||
|
|
@ -54,10 +57,10 @@ async def _process_user_payments(app, user_id: int):
|
||||||
payments = await get_payment_ids(user_id, app.db)
|
payments = await get_payment_ids(user_id, app.db)
|
||||||
|
|
||||||
if not payments:
|
if not payments:
|
||||||
log.debug('no payments for uid {}, skipping', user_id)
|
log.debug("no payments for uid {}, skipping", user_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
log.debug('{} payments for uid {}', len(payments), user_id)
|
log.debug("{} payments for uid {}", len(payments), user_id)
|
||||||
|
|
||||||
latest_payment = max(payments)
|
latest_payment = max(payments)
|
||||||
|
|
||||||
|
|
@ -66,33 +69,29 @@ async def _process_user_payments(app, user_id: int):
|
||||||
# calculate the difference between this payment
|
# calculate the difference between this payment
|
||||||
# and now.
|
# and now.
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
payment_tstamp = snowflake_datetime(int(payment_data['id']))
|
payment_tstamp = snowflake_datetime(int(payment_data["id"]))
|
||||||
|
|
||||||
delta = now - payment_tstamp
|
delta = now - payment_tstamp
|
||||||
|
|
||||||
sub_id = int(payment_data['subscription']['id'])
|
sub_id = int(payment_data["subscription"]["id"])
|
||||||
subscription = await get_subscription(
|
subscription = await get_subscription(sub_id, app.db)
|
||||||
sub_id, app.db)
|
|
||||||
|
|
||||||
# if the max payment is X days old, we create another.
|
# if the max payment is X days old, we create another.
|
||||||
# X is 30 for monthly subscriptions of nitro,
|
# X is 30 for monthly subscriptions of nitro,
|
||||||
# X is 365 for yearly subscriptions of nitro
|
# X is 365 for yearly subscriptions of nitro
|
||||||
threshold = THRESHOLDS[subscription['payment_gateway_plan_id']]
|
threshold = THRESHOLDS[subscription["payment_gateway_plan_id"]]
|
||||||
|
|
||||||
log.debug('delta {} delta days {} threshold {}',
|
log.debug("delta {} delta days {} threshold {}", delta, delta.days, threshold)
|
||||||
delta, delta.days, threshold)
|
|
||||||
|
|
||||||
if delta.days > threshold:
|
if delta.days > threshold:
|
||||||
log.info('creating payment for sid={}',
|
log.info("creating payment for sid={}", sub_id)
|
||||||
sub_id)
|
|
||||||
|
|
||||||
# create_payment does not call any Stripe
|
# create_payment does not call any Stripe
|
||||||
# or BrainTree APIs at all, since we'll just
|
# or BrainTree APIs at all, since we'll just
|
||||||
# give it as free.
|
# give it as free.
|
||||||
await create_payment(sub_id, app.db)
|
await create_payment(sub_id, app.db)
|
||||||
else:
|
else:
|
||||||
log.debug('sid={}, missing {} days',
|
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
|
||||||
sub_id, threshold - delta.days)
|
|
||||||
|
|
||||||
|
|
||||||
async def payment_job(app):
|
async def payment_job(app):
|
||||||
|
|
@ -101,35 +100,39 @@ async def payment_job(app):
|
||||||
This function will check through users' payments
|
This function will check through users' payments
|
||||||
and add a new one once a month / year.
|
and add a new one once a month / year.
|
||||||
"""
|
"""
|
||||||
log.debug('payment job start!')
|
log.debug("payment job start!")
|
||||||
|
|
||||||
user_ids = await app.db.fetch("""
|
user_ids = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT DISTINCT user_id
|
SELECT DISTINCT user_id
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
log.debug('working {} users', len(user_ids))
|
log.debug("working {} users", len(user_ids))
|
||||||
|
|
||||||
# go through each user's payments
|
# go through each user's payments
|
||||||
for row in user_ids:
|
for row in user_ids:
|
||||||
user_id = row['user_id']
|
user_id = row["user_id"]
|
||||||
try:
|
try:
|
||||||
await _process_user_payments(app, user_id)
|
await _process_user_payments(app, user_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('error while processing user payments')
|
log.exception("error while processing user payments")
|
||||||
|
|
||||||
subscribers = await app.db.fetch("""
|
subscribers = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_subscriptions
|
FROM user_subscriptions
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
for row in subscribers:
|
for row in subscribers:
|
||||||
try:
|
try:
|
||||||
await process_subscription(app, row['id'])
|
await process_subscription(app, row["id"])
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('error while processing subscription')
|
log.exception("error while processing subscription")
|
||||||
log.debug('rescheduling..')
|
log.debug("rescheduling..")
|
||||||
try:
|
try:
|
||||||
await _resched(app)
|
await _resched(app)
|
||||||
except CancelledError:
|
except CancelledError:
|
||||||
log.info('cancelled while waiting for resched')
|
log.info("cancelled while waiting for resched")
|
||||||
|
|
|
||||||
|
|
@ -22,24 +22,26 @@ fake routes for discord store
|
||||||
"""
|
"""
|
||||||
from quart import Blueprint, jsonify
|
from quart import Blueprint, jsonify
|
||||||
|
|
||||||
bp = Blueprint('fake_store', __name__)
|
bp = Blueprint("fake_store", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/promotions')
|
@bp.route("/promotions")
|
||||||
async def _get_promotions():
|
async def _get_promotions():
|
||||||
return jsonify([])
|
return jsonify([])
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/users/@me/library')
|
@bp.route("/users/@me/library")
|
||||||
async def _get_library():
|
async def _get_library():
|
||||||
return jsonify([])
|
return jsonify([])
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/users/@me/feed/settings')
|
@bp.route("/users/@me/feed/settings")
|
||||||
async def _get_feed_settings():
|
async def _get_feed_settings():
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'subscribed_games': [],
|
{
|
||||||
'subscribed_users': [],
|
"subscribed_games": [],
|
||||||
'unsubscribed_users': [],
|
"subscribed_users": [],
|
||||||
'unsubscribed_games': [],
|
"unsubscribed_users": [],
|
||||||
})
|
"unsubscribed_games": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -23,10 +23,10 @@ from litecord.auth import token_check
|
||||||
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
from litecord.schemas import validate, USER_SETTINGS, GUILD_SETTINGS
|
||||||
from litecord.blueprints.checks import guild_check
|
from litecord.blueprints.checks import guild_check
|
||||||
|
|
||||||
bp = Blueprint('users_settings', __name__)
|
bp = Blueprint("users_settings", __name__)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/settings', methods=['GET'])
|
@bp.route("/@me/settings", methods=["GET"])
|
||||||
async def get_user_settings():
|
async def get_user_settings():
|
||||||
"""Get the current user's settings."""
|
"""Get the current user's settings."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -34,7 +34,7 @@ async def get_user_settings():
|
||||||
return jsonify(settings)
|
return jsonify(settings)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/settings', methods=['PATCH'])
|
@bp.route("/@me/settings", methods=["PATCH"])
|
||||||
async def patch_current_settings():
|
async def patch_current_settings():
|
||||||
"""Patch the users' current settings.
|
"""Patch the users' current settings.
|
||||||
|
|
||||||
|
|
@ -47,19 +47,22 @@ async def patch_current_settings():
|
||||||
for key in j:
|
for key in j:
|
||||||
val = j[key]
|
val = j[key]
|
||||||
|
|
||||||
await app.storage.execute_with_json(f"""
|
await app.storage.execute_with_json(
|
||||||
|
f"""
|
||||||
UPDATE user_settings
|
UPDATE user_settings
|
||||||
SET {key}=$1
|
SET {key}=$1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", val, user_id)
|
""",
|
||||||
|
val,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
settings = await app.user_storage.get_user_settings(user_id)
|
settings = await app.user_storage.get_user_settings(user_id)
|
||||||
await app.dispatcher.dispatch_user(
|
await app.dispatcher.dispatch_user(user_id, "USER_SETTINGS_UPDATE", settings)
|
||||||
user_id, 'USER_SETTINGS_UPDATE', settings)
|
|
||||||
return jsonify(settings)
|
return jsonify(settings)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/guilds/<int:guild_id>/settings', methods=['PATCH'])
|
@bp.route("/@me/guilds/<int:guild_id>/settings", methods=["PATCH"])
|
||||||
async def patch_guild_settings(guild_id: int):
|
async def patch_guild_settings(guild_id: int):
|
||||||
"""Update the users' guild settings for a given guild.
|
"""Update the users' guild settings for a given guild.
|
||||||
|
|
||||||
|
|
@ -74,16 +77,21 @@ async def patch_guild_settings(guild_id: int):
|
||||||
# will make sure they exist in the table.
|
# will make sure they exist in the table.
|
||||||
await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||||
|
|
||||||
for field in (k for k in j.keys() if k != 'channel_overrides'):
|
for field in (k for k in j.keys() if k != "channel_overrides"):
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
UPDATE guild_settings
|
UPDATE guild_settings
|
||||||
SET {field} = $1
|
SET {field} = $1
|
||||||
WHERE user_id = $2 AND guild_id = $3
|
WHERE user_id = $2 AND guild_id = $3
|
||||||
""", j[field], user_id, guild_id)
|
""",
|
||||||
|
j[field],
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
chan_ids = await app.storage.get_channel_ids(guild_id)
|
chan_ids = await app.storage.get_channel_ids(guild_id)
|
||||||
|
|
||||||
for chandata in j.get('channel_overrides', {}).items():
|
for chandata in j.get("channel_overrides", {}).items():
|
||||||
chan_id, chan_overrides = chandata
|
chan_id, chan_overrides = chandata
|
||||||
chan_id = int(chan_id)
|
chan_id = int(chan_id)
|
||||||
|
|
||||||
|
|
@ -92,7 +100,8 @@ async def patch_guild_settings(guild_id: int):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for field in chan_overrides:
|
for field in chan_overrides:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
INSERT INTO guild_settings_channel_overrides
|
INSERT INTO guild_settings_channel_overrides
|
||||||
(user_id, guild_id, channel_id, {field})
|
(user_id, guild_id, channel_id, {field})
|
||||||
VALUES
|
VALUES
|
||||||
|
|
@ -105,18 +114,21 @@ async def patch_guild_settings(guild_id: int):
|
||||||
WHERE guild_settings_channel_overrides.user_id = $1
|
WHERE guild_settings_channel_overrides.user_id = $1
|
||||||
AND guild_settings_channel_overrides.guild_id = $2
|
AND guild_settings_channel_overrides.guild_id = $2
|
||||||
AND guild_settings_channel_overrides.channel_id = $3
|
AND guild_settings_channel_overrides.channel_id = $3
|
||||||
""", user_id, guild_id, chan_id, chan_overrides[field])
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
chan_id,
|
||||||
|
chan_overrides[field],
|
||||||
|
)
|
||||||
|
|
||||||
settings = await app.user_storage.get_guild_settings_one(
|
settings = await app.user_storage.get_guild_settings_one(user_id, guild_id)
|
||||||
user_id, guild_id)
|
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(
|
await app.dispatcher.dispatch_user(user_id, "USER_GUILD_SETTINGS_UPDATE", settings)
|
||||||
user_id, 'USER_GUILD_SETTINGS_UPDATE', settings)
|
|
||||||
|
|
||||||
return jsonify(settings)
|
return jsonify(settings)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/notes/<int:target_id>', methods=['PUT'])
|
@bp.route("/@me/notes/<int:target_id>", methods=["PUT"])
|
||||||
async def put_note(target_id: int):
|
async def put_note(target_id: int):
|
||||||
"""Put a note to a user.
|
"""Put a note to a user.
|
||||||
|
|
||||||
|
|
@ -126,10 +138,11 @@ async def put_note(target_id: int):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
note = str(j['note'])
|
note = str(j["note"])
|
||||||
|
|
||||||
# UPSERTs are beautiful
|
# UPSERTs are beautiful
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO notes (user_id, target_id, note)
|
INSERT INTO notes (user_id, target_id, note)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
|
|
||||||
|
|
@ -138,12 +151,14 @@ async def put_note(target_id: int):
|
||||||
note = $3
|
note = $3
|
||||||
WHERE notes.user_id = $1
|
WHERE notes.user_id = $1
|
||||||
AND notes.target_id = $2
|
AND notes.target_id = $2
|
||||||
""", user_id, target_id, note)
|
""",
|
||||||
|
user_id,
|
||||||
|
target_id,
|
||||||
|
note,
|
||||||
|
)
|
||||||
|
|
||||||
await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', {
|
await app.dispatcher.dispatch_user(
|
||||||
'id': str(target_id),
|
user_id, "USER_NOTE_UPDATE", {"id": str(target_id), "note": note}
|
||||||
'note': note,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return '', 204
|
|
||||||
|
|
||||||
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,7 @@ from ..errors import Forbidden, BadRequest, Unauthorized
|
||||||
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
from ..schemas import validate, USER_UPDATE, GET_MENTIONS
|
||||||
|
|
||||||
from .guilds import guild_check
|
from .guilds import guild_check
|
||||||
from litecord.auth import (
|
from litecord.auth import token_check, hash_data, check_username_usage, roll_discrim
|
||||||
token_check, hash_data, check_username_usage, roll_discrim
|
|
||||||
)
|
|
||||||
from litecord.blueprints.guild.mod import remove_member
|
from litecord.blueprints.guild.mod import remove_member
|
||||||
|
|
||||||
from litecord.enums import PremiumType
|
from litecord.enums import PremiumType
|
||||||
|
|
@ -39,7 +37,7 @@ from litecord.permissions import base_permissions
|
||||||
from litecord.blueprints.auth import check_password
|
from litecord.blueprints.auth import check_password
|
||||||
from litecord.utils import to_update
|
from litecord.utils import to_update
|
||||||
|
|
||||||
bp = Blueprint('user', __name__)
|
bp = Blueprint("user", __name__)
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -58,8 +56,7 @@ async def mass_user_update(user_id, app_=None):
|
||||||
private_user = await app_.storage.get_user(user_id, secure=True)
|
private_user = await app_.storage.get_user(user_id, secure=True)
|
||||||
|
|
||||||
session_ids.extend(
|
session_ids.extend(
|
||||||
await app_.dispatcher.dispatch_user(
|
await app_.dispatcher.dispatch_user(user_id, "USER_UPDATE", private_user)
|
||||||
user_id, 'USER_UPDATE', private_user)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
guild_ids = await app_.user_storage.get_user_guilds(user_id)
|
guild_ids = await app_.user_storage.get_user_guilds(user_id)
|
||||||
|
|
@ -67,26 +64,22 @@ async def mass_user_update(user_id, app_=None):
|
||||||
|
|
||||||
session_ids.extend(
|
session_ids.extend(
|
||||||
await app_.dispatcher.dispatch_many_filter_list(
|
await app_.dispatcher.dispatch_many_filter_list(
|
||||||
'guild', guild_ids, session_ids,
|
"guild", guild_ids, session_ids, "USER_UPDATE", public_user
|
||||||
'USER_UPDATE', public_user
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
session_ids.extend(
|
session_ids.extend(
|
||||||
await app_.dispatcher.dispatch_many_filter_list(
|
await app_.dispatcher.dispatch_many_filter_list(
|
||||||
'friend', friend_ids, session_ids,
|
"friend", friend_ids, session_ids, "USER_UPDATE", public_user
|
||||||
'USER_UPDATE', public_user
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await app_.dispatcher.dispatch_many(
|
await app_.dispatcher.dispatch_many("lazy_guild", guild_ids, "update_user", user_id)
|
||||||
'lazy_guild', guild_ids, 'update_user', user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return public_user, private_user
|
return public_user, private_user
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me', methods=['GET'])
|
@bp.route("/@me", methods=["GET"])
|
||||||
async def get_me():
|
async def get_me():
|
||||||
"""Get the current user's information."""
|
"""Get the current user's information."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -94,18 +87,21 @@ async def get_me():
|
||||||
return jsonify(user)
|
return jsonify(user)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:target_id>', methods=['GET'])
|
@bp.route("/<int:target_id>", methods=["GET"])
|
||||||
async def get_other(target_id):
|
async def get_other(target_id):
|
||||||
"""Get any user, given the user ID."""
|
"""Get any user, given the user ID."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
bot = await app.db.fetchval("""
|
bot = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT bot FROM users
|
SELECT bot FROM users
|
||||||
WHERE users.id = $1
|
WHERE users.id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not bot:
|
if not bot:
|
||||||
raise Forbidden('Only bots can use this endpoint')
|
raise Forbidden("Only bots can use this endpoint")
|
||||||
|
|
||||||
other = await app.storage.get_user(target_id)
|
other = await app.storage.get_user(target_id)
|
||||||
return jsonify(other)
|
return jsonify(other)
|
||||||
|
|
@ -116,66 +112,80 @@ async def _try_username_patch(user_id, new_username: str) -> str:
|
||||||
discrim = None
|
discrim = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET username = $1
|
SET username = $1
|
||||||
WHERE users.id = $2
|
WHERE users.id = $2
|
||||||
""", new_username, user_id)
|
""",
|
||||||
|
new_username,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return await app.db.fetchval("""
|
return await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT discriminator
|
SELECT discriminator
|
||||||
FROM users
|
FROM users
|
||||||
WHERE users.id = $1
|
WHERE users.id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
discrim = await roll_discrim(new_username)
|
discrim = await roll_discrim(new_username)
|
||||||
|
|
||||||
if not discrim:
|
if not discrim:
|
||||||
raise BadRequest('Unable to change username', {
|
raise BadRequest(
|
||||||
'username': 'Too many people are with this username.'
|
"Unable to change username",
|
||||||
})
|
{"username": "Too many people are with this username."},
|
||||||
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET username = $1, discriminator = $2
|
SET username = $1, discriminator = $2
|
||||||
WHERE users.id = $3
|
WHERE users.id = $3
|
||||||
""", new_username, discrim, user_id)
|
""",
|
||||||
|
new_username,
|
||||||
|
discrim,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return discrim
|
return discrim
|
||||||
|
|
||||||
|
|
||||||
async def _try_discrim_patch(user_id, new_discrim: str):
|
async def _try_discrim_patch(user_id, new_discrim: str):
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET discriminator = $1
|
SET discriminator = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_discrim, user_id)
|
""",
|
||||||
|
new_discrim,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
except UniqueViolationError:
|
except UniqueViolationError:
|
||||||
raise BadRequest('Invalid discriminator', {
|
raise BadRequest(
|
||||||
'discriminator': 'Someone already used this discriminator.'
|
"Invalid discriminator",
|
||||||
})
|
{"discriminator": "Someone already used this discriminator."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _check_pass(j, user):
|
async def _check_pass(j, user):
|
||||||
# Do not do password checks on unclaimed accounts
|
# Do not do password checks on unclaimed accounts
|
||||||
if user['email'] is None:
|
if user["email"] is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not j['password']:
|
if not j["password"]:
|
||||||
raise BadRequest('password required', {
|
raise BadRequest("password required", {"password": "password required"})
|
||||||
'password': 'password required'
|
|
||||||
})
|
|
||||||
|
|
||||||
phash = user['password_hash']
|
phash = user["password_hash"]
|
||||||
|
|
||||||
if not await check_password(phash, j['password']):
|
if not await check_password(phash, j["password"]):
|
||||||
raise BadRequest('password incorrect', {
|
raise BadRequest("password incorrect", {"password": "password does not match."})
|
||||||
'password': 'password does not match.'
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me', methods=['PATCH'])
|
@bp.route("/@me", methods=["PATCH"])
|
||||||
async def patch_me():
|
async def patch_me():
|
||||||
"""Patch the current user's information."""
|
"""Patch the current user's information."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -183,36 +193,43 @@ async def patch_me():
|
||||||
j = validate(await request.get_json(), USER_UPDATE)
|
j = validate(await request.get_json(), USER_UPDATE)
|
||||||
user = await app.storage.get_user(user_id, True)
|
user = await app.storage.get_user(user_id, True)
|
||||||
|
|
||||||
user['password_hash'] = await app.db.fetchval("""
|
user["password_hash"] = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if to_update(j, user, 'username'):
|
if to_update(j, user, "username"):
|
||||||
# this will take care of regenning a new discriminator
|
# this will take care of regenning a new discriminator
|
||||||
discrim = await _try_username_patch(user_id, j['username'])
|
discrim = await _try_username_patch(user_id, j["username"])
|
||||||
user['username'] = j['username']
|
user["username"] = j["username"]
|
||||||
user['discriminator'] = discrim
|
user["discriminator"] = discrim
|
||||||
|
|
||||||
if to_update(j, user, 'discriminator'):
|
if to_update(j, user, "discriminator"):
|
||||||
# the API treats discriminators as integers,
|
# the API treats discriminators as integers,
|
||||||
# but I work with strings on the database.
|
# but I work with strings on the database.
|
||||||
new_discrim = str(j['discriminator'])
|
new_discrim = str(j["discriminator"])
|
||||||
|
|
||||||
await _try_discrim_patch(user_id, new_discrim)
|
await _try_discrim_patch(user_id, new_discrim)
|
||||||
user['discriminator'] = new_discrim
|
user["discriminator"] = new_discrim
|
||||||
|
|
||||||
if to_update(j, user, 'email'):
|
if to_update(j, user, "email"):
|
||||||
await _check_pass(j, user)
|
await _check_pass(j, user)
|
||||||
|
|
||||||
# TODO: reverify the new email?
|
# TODO: reverify the new email?
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET email = $1
|
SET email = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['email'], user_id)
|
""",
|
||||||
user['email'] = j['email']
|
j["email"],
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
user["email"] = j["email"]
|
||||||
|
|
||||||
# only update if values are different
|
# only update if values are different
|
||||||
# from what the user gave.
|
# from what the user gave.
|
||||||
|
|
@ -224,44 +241,49 @@ async def patch_me():
|
||||||
|
|
||||||
# IconManager.update will take care of validating
|
# IconManager.update will take care of validating
|
||||||
# the value once put()-ing
|
# the value once put()-ing
|
||||||
if to_update(j, user, 'avatar'):
|
if to_update(j, user, "avatar"):
|
||||||
mime, _ = parse_data_uri(j['avatar'])
|
mime, _ = parse_data_uri(j["avatar"])
|
||||||
|
|
||||||
if mime == 'image/gif' and user['premium_type'] == PremiumType.NONE:
|
if mime == "image/gif" and user["premium_type"] == PremiumType.NONE:
|
||||||
raise BadRequest('no gif without nitro')
|
raise BadRequest("no gif without nitro")
|
||||||
|
|
||||||
new_icon = await app.icons.update(
|
new_icon = await app.icons.update("user", user_id, j["avatar"], size=(128, 128))
|
||||||
'user', user_id, j['avatar'], size=(128, 128))
|
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET avatar = $1
|
SET avatar = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_icon.icon_hash, user_id)
|
""",
|
||||||
|
new_icon.icon_hash,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if user['email'] is None and not 'new_password' in j:
|
if user["email"] is None and not "new_password" in j:
|
||||||
raise BadRequest('missing password', {
|
raise BadRequest("missing password", {"password": "Please set a password."})
|
||||||
'password': 'Please set a password.'
|
|
||||||
})
|
|
||||||
|
|
||||||
if 'new_password' in j and j['new_password']:
|
if "new_password" in j and j["new_password"]:
|
||||||
await _check_pass(j, user)
|
await _check_pass(j, user)
|
||||||
|
|
||||||
new_hash = await hash_data(j['new_password'])
|
new_hash = await hash_data(j["new_password"])
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET password_hash = $1
|
SET password_hash = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_hash, user_id)
|
""",
|
||||||
|
new_hash,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
user.pop('password_hash')
|
user.pop("password_hash")
|
||||||
|
|
||||||
_, private_user = await mass_user_update(user_id, app)
|
_, private_user = await mass_user_update(user_id, app)
|
||||||
return jsonify(private_user)
|
return jsonify(private_user)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/guilds', methods=['GET'])
|
@bp.route("/@me/guilds", methods=["GET"])
|
||||||
async def get_me_guilds():
|
async def get_me_guilds():
|
||||||
"""Get partial user guilds."""
|
"""Get partial user guilds."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -270,27 +292,30 @@ async def get_me_guilds():
|
||||||
partials = []
|
partials = []
|
||||||
|
|
||||||
for guild_id in guild_ids:
|
for guild_id in guild_ids:
|
||||||
partial = await app.db.fetchrow("""
|
partial = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT id::text, name, icon, owner_id
|
SELECT id::text, name, icon, owner_id
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE guilds.id = $1
|
WHERE guilds.id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
partial = dict(partial)
|
partial = dict(partial)
|
||||||
|
|
||||||
user_perms = await base_permissions(user_id, guild_id)
|
user_perms = await base_permissions(user_id, guild_id)
|
||||||
partial['permissions'] = user_perms.binary
|
partial["permissions"] = user_perms.binary
|
||||||
|
|
||||||
partial['owner'] = partial['owner_id'] == user_id
|
partial["owner"] = partial["owner_id"] == user_id
|
||||||
|
|
||||||
partial.pop('owner_id')
|
partial.pop("owner_id")
|
||||||
|
|
||||||
partials.append(partial)
|
partials.append(partial)
|
||||||
|
|
||||||
return jsonify(partials)
|
return jsonify(partials)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
|
@bp.route("/@me/guilds/<int:guild_id>", methods=["DELETE"])
|
||||||
async def leave_guild(guild_id: int):
|
async def leave_guild(guild_id: int):
|
||||||
"""Leave a guild."""
|
"""Leave a guild."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -298,7 +323,7 @@ async def leave_guild(guild_id: int):
|
||||||
|
|
||||||
await remove_member(guild_id, user_id)
|
await remove_member(guild_id, user_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
# @bp.route('/@me/connections', methods=['GET'])
|
# @bp.route('/@me/connections', methods=['GET'])
|
||||||
|
|
@ -306,7 +331,7 @@ async def get_connections():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/consent', methods=['GET', 'POST'])
|
@bp.route("/@me/consent", methods=["GET", "POST"])
|
||||||
async def get_consent():
|
async def get_consent():
|
||||||
"""Always disable data collection.
|
"""Always disable data collection.
|
||||||
|
|
||||||
|
|
@ -314,57 +339,58 @@ async def get_consent():
|
||||||
by the client and ignores them, as they
|
by the client and ignores them, as they
|
||||||
will always be false.
|
will always be false.
|
||||||
"""
|
"""
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'usage_statistics': {
|
{
|
||||||
'consented': False,
|
"usage_statistics": {"consented": False},
|
||||||
},
|
"personalization": {"consented": False},
|
||||||
'personalization': {
|
|
||||||
'consented': False,
|
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/harvest', methods=['GET'])
|
@bp.route("/@me/harvest", methods=["GET"])
|
||||||
async def get_harvest():
|
async def get_harvest():
|
||||||
"""Dummy route"""
|
"""Dummy route"""
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/activities/statistics/applications', methods=['GET'])
|
@bp.route("/@me/activities/statistics/applications", methods=["GET"])
|
||||||
async def get_stats_applications():
|
async def get_stats_applications():
|
||||||
"""Dummy route for info on gameplay time and such"""
|
"""Dummy route for info on gameplay time and such"""
|
||||||
return jsonify([])
|
return jsonify([])
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/library', methods=['GET'])
|
@bp.route("/@me/library", methods=["GET"])
|
||||||
async def get_library():
|
async def get_library():
|
||||||
"""Probably related to Discord Store?"""
|
"""Probably related to Discord Store?"""
|
||||||
return jsonify([])
|
return jsonify([])
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:peer_id>/profile', methods=['GET'])
|
@bp.route("/<int:peer_id>/profile", methods=["GET"])
|
||||||
async def get_profile(peer_id: int):
|
async def get_profile(peer_id: int):
|
||||||
"""Get a user's profile."""
|
"""Get a user's profile."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
peer = await app.storage.get_user(peer_id)
|
peer = await app.storage.get_user(peer_id)
|
||||||
|
|
||||||
if not peer:
|
if not peer:
|
||||||
return '', 404
|
return "", 404
|
||||||
|
|
||||||
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
||||||
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
friends = await app.user_storage.are_friends_with(user_id, peer_id)
|
||||||
|
|
||||||
# don't return a proper card if no guilds are being shared.
|
# don't return a proper card if no guilds are being shared.
|
||||||
if not mutuals and not friends:
|
if not mutuals and not friends:
|
||||||
return '', 404
|
return "", 404
|
||||||
|
|
||||||
# actual premium status is determined by that
|
# actual premium status is determined by that
|
||||||
# column being NULL or not
|
# column being NULL or not
|
||||||
peer_premium = await app.db.fetchval("""
|
peer_premium = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT premium_since
|
SELECT premium_since
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", peer_id)
|
""",
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
mutual_guilds = await app.user_storage.get_mutual_guilds(user_id, peer_id)
|
||||||
mutual_res = []
|
mutual_res = []
|
||||||
|
|
@ -372,45 +398,49 @@ async def get_profile(peer_id: int):
|
||||||
# ascending sorting
|
# ascending sorting
|
||||||
for guild_id in sorted(mutual_guilds):
|
for guild_id in sorted(mutual_guilds):
|
||||||
|
|
||||||
nick = await app.db.fetchval("""
|
nick = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT nickname
|
SELECT nickname
|
||||||
FROM members
|
FROM members
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, peer_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
mutual_res.append({
|
mutual_res.append({"id": str(guild_id), "nick": nick})
|
||||||
'id': str(guild_id),
|
|
||||||
'nick': nick,
|
|
||||||
})
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify(
|
||||||
'user': peer,
|
{
|
||||||
'connected_accounts': [],
|
"user": peer,
|
||||||
'premium_since': peer_premium,
|
"connected_accounts": [],
|
||||||
'mutual_guilds': mutual_res,
|
"premium_since": peer_premium,
|
||||||
})
|
"mutual_guilds": mutual_res,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/mentions', methods=['GET'])
|
@bp.route("/@me/mentions", methods=["GET"])
|
||||||
async def _get_mentions():
|
async def _get_mentions():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
j = validate(dict(request.args), GET_MENTIONS)
|
j = validate(dict(request.args), GET_MENTIONS)
|
||||||
|
|
||||||
guild_query = 'AND messages.guild_id = $2' if 'guild_id' in j else ''
|
guild_query = "AND messages.guild_id = $2" if "guild_id" in j else ""
|
||||||
role_query = "OR content LIKE '%<@&%'" if j['roles'] else ''
|
role_query = "OR content LIKE '%<@&%'" if j["roles"] else ""
|
||||||
everyone_query = "OR content LIKE '%@everyone%'" if j['everyone'] else ''
|
everyone_query = "OR content LIKE '%@everyone%'" if j["everyone"] else ""
|
||||||
mention_user = f'<@{user_id}>'
|
mention_user = f"<@{user_id}>"
|
||||||
|
|
||||||
args = [mention_user]
|
args = [mention_user]
|
||||||
|
|
||||||
if guild_query:
|
if guild_query:
|
||||||
args.append(j['guild_id'])
|
args.append(j["guild_id"])
|
||||||
|
|
||||||
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
guild_ids = await app.user_storage.get_user_guilds(user_id)
|
||||||
gids = ','.join(str(guild_id) for guild_id in guild_ids)
|
gids = ",".join(str(guild_id) for guild_id in guild_ids)
|
||||||
|
|
||||||
rows = await app.db.fetch(f"""
|
rows = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT messages.id
|
SELECT messages.id
|
||||||
FROM messages
|
FROM messages
|
||||||
JOIN channels ON messages.channel_id = channels.id
|
JOIN channels ON messages.channel_id = channels.id
|
||||||
|
|
@ -423,20 +453,20 @@ async def _get_mentions():
|
||||||
{guild_query}
|
{guild_query}
|
||||||
)
|
)
|
||||||
LIMIT {j["limit"]}
|
LIMIT {j["limit"]}
|
||||||
""", *args)
|
""",
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
message = await app.storage.get_message(row['id'])
|
message = await app.storage.get_message(row["id"])
|
||||||
gid = int(message['guild_id'])
|
gid = int(message["guild_id"])
|
||||||
|
|
||||||
# ignore messages pre-messages.guild_id
|
# ignore messages pre-messages.guild_id
|
||||||
if gid not in guild_ids:
|
if gid not in guild_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
res.append(
|
res.append(message)
|
||||||
message
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify(res)
|
return jsonify(res)
|
||||||
|
|
||||||
|
|
@ -449,18 +479,20 @@ def rand_hex(length: int = 8) -> str:
|
||||||
async def _del_from_table(db, table: str, user_id: int):
|
async def _del_from_table(db, table: str, user_id: int):
|
||||||
"""Delete a row from a table."""
|
"""Delete a row from a table."""
|
||||||
column = {
|
column = {
|
||||||
'channel_overwrites': 'target_user',
|
"channel_overwrites": "target_user",
|
||||||
'user_settings': 'id',
|
"user_settings": "id",
|
||||||
'group_dm_members': 'member_id'
|
"group_dm_members": "member_id",
|
||||||
}.get(table, 'user_id')
|
}.get(table, "user_id")
|
||||||
|
|
||||||
res = await db.execute(f"""
|
res = await db.execute(
|
||||||
|
f"""
|
||||||
DELETE FROM {table}
|
DELETE FROM {table}
|
||||||
WHERE {column} = $1
|
WHERE {column} = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
log.info('Deleting uid {} from {}, res: {!r}',
|
log.info("Deleting uid {} from {}, res: {!r}", user_id, table, res)
|
||||||
user_id, table, res)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_user(user_id, *, app_=None):
|
async def delete_user(user_id, *, app_=None):
|
||||||
|
|
@ -470,13 +502,14 @@ async def delete_user(user_id, *, app_=None):
|
||||||
|
|
||||||
db = app_.db
|
db = app_.db
|
||||||
|
|
||||||
new_username = f'Deleted User {rand_hex()}'
|
new_username = f"Deleted User {rand_hex()}"
|
||||||
|
|
||||||
# by using a random hex in password_hash
|
# by using a random hex in password_hash
|
||||||
# we break attempts at using the default '123' password hash
|
# we break attempts at using the default '123' password hash
|
||||||
# to issue valid tokens for deleted users.
|
# to issue valid tokens for deleted users.
|
||||||
|
|
||||||
await db.execute("""
|
await db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET
|
SET
|
||||||
username = $1,
|
username = $1,
|
||||||
|
|
@ -490,32 +523,39 @@ async def delete_user(user_id, *, app_=None):
|
||||||
password_hash = $2
|
password_hash = $2
|
||||||
WHERE
|
WHERE
|
||||||
id = $3
|
id = $3
|
||||||
""", new_username, rand_hex(32), user_id)
|
""",
|
||||||
|
new_username,
|
||||||
|
rand_hex(32),
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# remove the user from various tables
|
# remove the user from various tables
|
||||||
await _del_from_table(db, 'user_settings', user_id)
|
await _del_from_table(db, "user_settings", user_id)
|
||||||
await _del_from_table(db, 'user_payment_sources', user_id)
|
await _del_from_table(db, "user_payment_sources", user_id)
|
||||||
await _del_from_table(db, 'user_subscriptions', user_id)
|
await _del_from_table(db, "user_subscriptions", user_id)
|
||||||
await _del_from_table(db, 'user_payments', user_id)
|
await _del_from_table(db, "user_payments", user_id)
|
||||||
await _del_from_table(db, 'user_read_state', user_id)
|
await _del_from_table(db, "user_read_state", user_id)
|
||||||
await _del_from_table(db, 'guild_settings', user_id)
|
await _del_from_table(db, "guild_settings", user_id)
|
||||||
await _del_from_table(db, 'guild_settings_channel_overrides', user_id)
|
await _del_from_table(db, "guild_settings_channel_overrides", user_id)
|
||||||
|
|
||||||
await db.execute("""
|
await db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM relationships
|
DELETE FROM relationships
|
||||||
WHERE user_id = $1 OR peer_id = $1
|
WHERE user_id = $1 OR peer_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# DMs are still maintained, but not the state.
|
# DMs are still maintained, but not the state.
|
||||||
await _del_from_table(db, 'dm_channel_state', user_id)
|
await _del_from_table(db, "dm_channel_state", user_id)
|
||||||
|
|
||||||
# NOTE: we don't delete the group dms the user is an owner of...
|
# NOTE: we don't delete the group dms the user is an owner of...
|
||||||
# TODO: group dm owner reassign when the owner leaves a gdm
|
# TODO: group dm owner reassign when the owner leaves a gdm
|
||||||
await _del_from_table(db, 'group_dm_members', user_id)
|
await _del_from_table(db, "group_dm_members", user_id)
|
||||||
|
|
||||||
await _del_from_table(db, 'members', user_id)
|
await _del_from_table(db, "members", user_id)
|
||||||
await _del_from_table(db, 'member_roles', user_id)
|
await _del_from_table(db, "member_roles", user_id)
|
||||||
await _del_from_table(db, 'channel_overwrites', user_id)
|
await _del_from_table(db, "channel_overwrites", user_id)
|
||||||
|
|
||||||
# after updating the user, we send USER_UPDATE so that all the other
|
# after updating the user, we send USER_UPDATE so that all the other
|
||||||
# clients can refresh their caches on the now-deleted user
|
# clients can refresh their caches on the now-deleted user
|
||||||
|
|
@ -540,15 +580,12 @@ async def user_disconnect(user_id: int):
|
||||||
await state.ws.ws.close(4000)
|
await state.ws.ws.close(4000)
|
||||||
|
|
||||||
# force everyone to see the user as offline
|
# force everyone to see the user as offline
|
||||||
await app.presence.dispatch_pres(user_id, {
|
await app.presence.dispatch_pres(
|
||||||
'afk': False,
|
user_id, {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||||
'status': 'offline',
|
)
|
||||||
'game': None,
|
|
||||||
'since': 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/delete', methods=['POST'])
|
@bp.route("/@me/delete", methods=["POST"])
|
||||||
async def delete_account():
|
async def delete_account():
|
||||||
"""Delete own account.
|
"""Delete own account.
|
||||||
|
|
||||||
|
|
@ -560,29 +597,35 @@ async def delete_account():
|
||||||
j = await request.get_json()
|
j = await request.get_json()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
password = j['password']
|
password = j["password"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise BadRequest('password required')
|
raise BadRequest("password required")
|
||||||
|
|
||||||
owned_guilds = await app.db.fetchval("""
|
owned_guilds = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE owner_id = $1
|
WHERE owner_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if owned_guilds > 0:
|
if owned_guilds > 0:
|
||||||
raise BadRequest('You still own guilds.')
|
raise BadRequest("You still own guilds.")
|
||||||
|
|
||||||
pwd_hash = await app.db.fetchval("""
|
pwd_hash = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not await check_password(pwd_hash, password):
|
if not await check_password(pwd_hash, password):
|
||||||
raise Unauthorized('password does not match')
|
raise Unauthorized("password does not match")
|
||||||
|
|
||||||
await delete_user(user_id)
|
await delete_user(user_id)
|
||||||
await user_disconnect(user_id)
|
await user_disconnect(user_id)
|
||||||
|
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from quart import Blueprint, jsonify, current_app as app
|
||||||
|
|
||||||
from litecord.blueprints.auth import token_check
|
from litecord.blueprints.auth import token_check
|
||||||
|
|
||||||
bp = Blueprint('voice', __name__)
|
bp = Blueprint("voice", __name__)
|
||||||
|
|
||||||
|
|
||||||
def _majority_region_count(regions: list) -> str:
|
def _majority_region_count(regions: list) -> str:
|
||||||
|
|
@ -39,12 +39,14 @@ def _majority_region_count(regions: list) -> str:
|
||||||
|
|
||||||
async def _choose_random_region() -> Optional[str]:
|
async def _choose_random_region() -> Optional[str]:
|
||||||
"""Give a random voice region."""
|
"""Give a random voice region."""
|
||||||
regions = await app.db.fetch("""
|
regions = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM voice_regions
|
FROM voice_regions
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
regions = [r['id'] for r in regions]
|
regions = [r["id"] for r in regions]
|
||||||
|
|
||||||
if not regions:
|
if not regions:
|
||||||
return None
|
return None
|
||||||
|
|
@ -64,11 +66,14 @@ async def _majority_region_any(user_id) -> Optional[str]:
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for guild_id in guilds:
|
for guild_id in guilds:
|
||||||
region = await app.db.fetchval("""
|
region = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT region
|
SELECT region
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
res.append(region)
|
res.append(region)
|
||||||
|
|
||||||
|
|
@ -83,20 +88,23 @@ async def _majority_region_any(user_id) -> Optional[str]:
|
||||||
async def majority_region(user_id: int) -> Optional[str]:
|
async def majority_region(user_id: int) -> Optional[str]:
|
||||||
"""Given a user ID, give the most likely region for the user to be
|
"""Given a user ID, give the most likely region for the user to be
|
||||||
happy with."""
|
happy with."""
|
||||||
regions = await app.db.fetch("""
|
regions = await app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT region
|
SELECT region
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE owner_id = $1
|
WHERE owner_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not regions:
|
if not regions:
|
||||||
return await _majority_region_any(user_id)
|
return await _majority_region_any(user_id)
|
||||||
|
|
||||||
regions = [r['region'] for r in regions]
|
regions = [r["region"] for r in regions]
|
||||||
return _majority_region_count(regions)
|
return _majority_region_count(regions)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/regions', methods=['GET'])
|
@bp.route("/regions", methods=["GET"])
|
||||||
async def voice_regions():
|
async def voice_regions():
|
||||||
"""Return voice regions."""
|
"""Return voice regions."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -105,6 +113,6 @@ async def voice_regions():
|
||||||
regions = await app.storage.all_voice_regions()
|
regions = await app.storage.all_voice_regions()
|
||||||
|
|
||||||
for region in regions:
|
for region in regions:
|
||||||
region['optimal'] = region['id'] == best_region
|
region["optimal"] = region["id"] == best_region
|
||||||
|
|
||||||
return jsonify(regions)
|
return jsonify(regions)
|
||||||
|
|
|
||||||
|
|
@ -26,22 +26,28 @@ from quart import Blueprint, jsonify, current_app as app, request
|
||||||
|
|
||||||
from litecord.auth import token_check
|
from litecord.auth import token_check
|
||||||
from litecord.blueprints.checks import (
|
from litecord.blueprints.checks import (
|
||||||
channel_check, channel_perm_check, guild_check, guild_perm_check
|
channel_check,
|
||||||
|
channel_perm_check,
|
||||||
|
guild_check,
|
||||||
|
guild_perm_check,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.schemas import (
|
from litecord.schemas import (
|
||||||
validate, WEBHOOK_CREATE, WEBHOOK_UPDATE, WEBHOOK_MESSAGE_CREATE
|
validate,
|
||||||
|
WEBHOOK_CREATE,
|
||||||
|
WEBHOOK_UPDATE,
|
||||||
|
WEBHOOK_MESSAGE_CREATE,
|
||||||
)
|
)
|
||||||
from litecord.enums import ChannelType
|
from litecord.enums import ChannelType
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
from litecord.utils import async_map
|
from litecord.utils import async_map
|
||||||
from litecord.errors import (
|
from litecord.errors import WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
|
||||||
WebhookNotFound, Unauthorized, ChannelNotFound, BadRequest
|
|
||||||
)
|
|
||||||
|
|
||||||
from litecord.blueprints.channel.messages import (
|
from litecord.blueprints.channel.messages import (
|
||||||
msg_create_request, msg_create_check_content, msg_add_attachment,
|
msg_create_request,
|
||||||
msg_guild_text_mentions
|
msg_create_check_content,
|
||||||
|
msg_add_attachment,
|
||||||
|
msg_guild_text_mentions,
|
||||||
)
|
)
|
||||||
from litecord.embed.sanitizer import fill_embed, fetch_raw_img
|
from litecord.embed.sanitizer import fill_embed, fetch_raw_img
|
||||||
from litecord.embed.messages import process_url_embed, is_media_url
|
from litecord.embed.messages import process_url_embed, is_media_url
|
||||||
|
|
@ -50,30 +56,34 @@ from litecord.utils import pg_set_json
|
||||||
from litecord.enums import MessageType
|
from litecord.enums import MessageType
|
||||||
from litecord.images import STATIC_IMAGE_MIMES
|
from litecord.images import STATIC_IMAGE_MIMES
|
||||||
|
|
||||||
bp = Blueprint('webhooks', __name__)
|
bp = Blueprint("webhooks", __name__)
|
||||||
|
|
||||||
|
|
||||||
async def get_webhook(webhook_id: int, *,
|
async def get_webhook(
|
||||||
secure: bool=True) -> Optional[Dict[str, Any]]:
|
webhook_id: int, *, secure: bool = True
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""Get a webhook data"""
|
"""Get a webhook data"""
|
||||||
row = await app.db.fetchrow("""
|
row = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT id::text, guild_id::text, channel_id::text, creator_id,
|
SELECT id::text, guild_id::text, channel_id::text, creator_id,
|
||||||
name, avatar, token
|
name, avatar, token
|
||||||
FROM webhooks
|
FROM webhooks
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", webhook_id)
|
""",
|
||||||
|
webhook_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
drow['user'] = await app.storage.get_user(row['creator_id'])
|
drow["user"] = await app.storage.get_user(row["creator_id"])
|
||||||
drow.pop('creator_id')
|
drow.pop("creator_id")
|
||||||
|
|
||||||
if not secure:
|
if not secure:
|
||||||
drow.pop('user')
|
drow.pop("user")
|
||||||
drow.pop('guild_id')
|
drow.pop("guild_id")
|
||||||
|
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
|
|
@ -82,7 +92,7 @@ async def _webhook_check(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT)
|
await channel_check(user_id, channel_id, only=ChannelType.GUILD_TEXT)
|
||||||
await channel_perm_check(user_id, channel_id, 'manage_webhooks')
|
await channel_perm_check(user_id, channel_id, "manage_webhooks")
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
@ -91,17 +101,20 @@ async def _webhook_check_guild(guild_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
await guild_check(user_id, guild_id)
|
await guild_check(user_id, guild_id)
|
||||||
await guild_perm_check(user_id, guild_id, 'manage_webhooks')
|
await guild_perm_check(user_id, guild_id, "manage_webhooks")
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
async def _webhook_check_fw(webhook_id):
|
async def _webhook_check_fw(webhook_id):
|
||||||
"""Make a check from an incoming webhook id (fw = from webhook)."""
|
"""Make a check from an incoming webhook id (fw = from webhook)."""
|
||||||
guild_id = await app.db.fetchval("""
|
guild_id = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT guild_id FROM webhooks
|
SELECT guild_id FROM webhooks
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", webhook_id)
|
""",
|
||||||
|
webhook_id,
|
||||||
|
)
|
||||||
|
|
||||||
if guild_id is None:
|
if guild_id is None:
|
||||||
raise WebhookNotFound()
|
raise WebhookNotFound()
|
||||||
|
|
@ -110,42 +123,48 @@ async def _webhook_check_fw(webhook_id):
|
||||||
|
|
||||||
|
|
||||||
async def _webhook_many(where_clause, arg: int):
|
async def _webhook_many(where_clause, arg: int):
|
||||||
webhook_ids = await app.db.fetch(f"""
|
webhook_ids = await app.db.fetch(
|
||||||
|
f"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM webhooks
|
FROM webhooks
|
||||||
{where_clause}
|
{where_clause}
|
||||||
""", arg)
|
""",
|
||||||
|
arg,
|
||||||
webhook_ids = [r['id'] for r in webhook_ids]
|
|
||||||
|
|
||||||
return jsonify(
|
|
||||||
await async_map(get_webhook, webhook_ids)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
webhook_ids = [r["id"] for r in webhook_ids]
|
||||||
|
|
||||||
|
return jsonify(await async_map(get_webhook, webhook_ids))
|
||||||
|
|
||||||
|
|
||||||
async def webhook_token_check(webhook_id: int, webhook_token: str):
|
async def webhook_token_check(webhook_id: int, webhook_token: str):
|
||||||
"""token_check() equivalent for webhooks."""
|
"""token_check() equivalent for webhooks."""
|
||||||
row = await app.db.fetchrow("""
|
row = await app.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT guild_id, channel_id
|
SELECT guild_id, channel_id
|
||||||
FROM webhooks
|
FROM webhooks
|
||||||
WHERE id = $1 AND token = $2
|
WHERE id = $1 AND token = $2
|
||||||
""", webhook_id, webhook_token)
|
""",
|
||||||
|
webhook_id,
|
||||||
|
webhook_token,
|
||||||
|
)
|
||||||
|
|
||||||
if row is None:
|
if row is None:
|
||||||
raise Unauthorized('webhook not found or unauthorized')
|
raise Unauthorized("webhook not found or unauthorized")
|
||||||
|
|
||||||
return row['guild_id'], row['channel_id']
|
return row["guild_id"], row["channel_id"]
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
async def _dispatch_webhook_update(guild_id: int, channel_id):
|
||||||
await app.dispatcher.dispatch('guild', guild_id, 'WEBHOOKS_UPDATE', {
|
await app.dispatcher.dispatch(
|
||||||
'guild_id': str(guild_id),
|
"guild",
|
||||||
'channel_id': str(channel_id)
|
guild_id,
|
||||||
})
|
"WEBHOOKS_UPDATE",
|
||||||
|
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@bp.route("/channels/<int:channel_id>/webhooks", methods=["POST"])
|
||||||
@bp.route('/channels/<int:channel_id>/webhooks', methods=['POST'])
|
|
||||||
async def create_webhook(channel_id: int):
|
async def create_webhook(channel_id: int):
|
||||||
"""Create a webhook given a channel."""
|
"""Create a webhook given a channel."""
|
||||||
user_id = await _webhook_check(channel_id)
|
user_id = await _webhook_check(channel_id)
|
||||||
|
|
@ -162,8 +181,7 @@ async def create_webhook(channel_id: int):
|
||||||
token = secrets.token_urlsafe(40)
|
token = secrets.token_urlsafe(40)
|
||||||
|
|
||||||
webhook_icon = await app.icons.put(
|
webhook_icon = await app.icons.put(
|
||||||
'user', webhook_id, j.get('avatar'),
|
"user", webhook_id, j.get("avatar"), always_icon=True, size=(128, 128)
|
||||||
always_icon=True, size=(128, 128)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
|
|
@ -173,36 +191,41 @@ async def create_webhook(channel_id: int):
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, $4, $5, $6, $7)
|
($1, $2, $3, $4, $5, $6, $7)
|
||||||
""",
|
""",
|
||||||
webhook_id, guild_id, channel_id, user_id,
|
webhook_id,
|
||||||
j['name'], webhook_icon.icon_hash, token
|
guild_id,
|
||||||
|
channel_id,
|
||||||
|
user_id,
|
||||||
|
j["name"],
|
||||||
|
webhook_icon.icon_hash,
|
||||||
|
token,
|
||||||
)
|
)
|
||||||
|
|
||||||
await _dispatch_webhook_update(guild_id, channel_id)
|
await _dispatch_webhook_update(guild_id, channel_id)
|
||||||
return jsonify(await get_webhook(webhook_id))
|
return jsonify(await get_webhook(webhook_id))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/channels/<int:channel_id>/webhooks', methods=['GET'])
|
@bp.route("/channels/<int:channel_id>/webhooks", methods=["GET"])
|
||||||
async def get_channel_webhook(channel_id: int):
|
async def get_channel_webhook(channel_id: int):
|
||||||
"""Get a list of webhooks in a channel"""
|
"""Get a list of webhooks in a channel"""
|
||||||
await _webhook_check(channel_id)
|
await _webhook_check(channel_id)
|
||||||
return await _webhook_many('WHERE channel_id = $1', channel_id)
|
return await _webhook_many("WHERE channel_id = $1", channel_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/guilds/<int:guild_id>/webhooks', methods=['GET'])
|
@bp.route("/guilds/<int:guild_id>/webhooks", methods=["GET"])
|
||||||
async def get_guild_webhook(guild_id):
|
async def get_guild_webhook(guild_id):
|
||||||
"""Get all webhooks in a guild"""
|
"""Get all webhooks in a guild"""
|
||||||
await _webhook_check_guild(guild_id)
|
await _webhook_check_guild(guild_id)
|
||||||
return await _webhook_many('WHERE guild_id = $1', guild_id)
|
return await _webhook_many("WHERE guild_id = $1", guild_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>', methods=['GET'])
|
@bp.route("/webhooks/<int:webhook_id>", methods=["GET"])
|
||||||
async def get_single_webhook(webhook_id):
|
async def get_single_webhook(webhook_id):
|
||||||
"""Get a single webhook's information."""
|
"""Get a single webhook's information."""
|
||||||
await _webhook_check_fw(webhook_id)
|
await _webhook_check_fw(webhook_id)
|
||||||
return await jsonify(await get_webhook(webhook_id))
|
return await jsonify(await get_webhook(webhook_id))
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['GET'])
|
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["GET"])
|
||||||
async def get_tokened_webhook(webhook_id, webhook_token):
|
async def get_tokened_webhook(webhook_id, webhook_token):
|
||||||
"""Get a webhook using its token."""
|
"""Get a webhook using its token."""
|
||||||
await webhook_token_check(webhook_id, webhook_token)
|
await webhook_token_check(webhook_id, webhook_token)
|
||||||
|
|
@ -210,46 +233,58 @@ async def get_tokened_webhook(webhook_id, webhook_token):
|
||||||
|
|
||||||
|
|
||||||
async def _update_webhook(webhook_id: int, j: dict):
|
async def _update_webhook(webhook_id: int, j: dict):
|
||||||
if 'name' in j:
|
if "name" in j:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE webhooks
|
UPDATE webhooks
|
||||||
SET name = $1
|
SET name = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['name'], webhook_id)
|
""",
|
||||||
|
j["name"],
|
||||||
|
webhook_id,
|
||||||
|
)
|
||||||
|
|
||||||
if 'channel_id' in j:
|
if "channel_id" in j:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE webhooks
|
UPDATE webhooks
|
||||||
SET channel_id = $1
|
SET channel_id = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", j['channel_id'], webhook_id)
|
""",
|
||||||
|
j["channel_id"],
|
||||||
if 'avatar' in j:
|
webhook_id,
|
||||||
new_icon = await app.icons.update(
|
|
||||||
'user', webhook_id, j['avatar'], always_icon=True, size=(128, 128)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await app.db.execute("""
|
if "avatar" in j:
|
||||||
|
new_icon = await app.icons.update(
|
||||||
|
"user", webhook_id, j["avatar"], always_icon=True, size=(128, 128)
|
||||||
|
)
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE webhooks
|
UPDATE webhooks
|
||||||
SET icon = $1
|
SET icon = $1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_icon.icon_hash, webhook_id)
|
""",
|
||||||
|
new_icon.icon_hash,
|
||||||
|
webhook_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>', methods=['PATCH'])
|
@bp.route("/webhooks/<int:webhook_id>", methods=["PATCH"])
|
||||||
async def modify_webhook(webhook_id: int):
|
async def modify_webhook(webhook_id: int):
|
||||||
"""Patch a webhook."""
|
"""Patch a webhook."""
|
||||||
_user_id, guild_id = await _webhook_check_fw(webhook_id)
|
_user_id, guild_id = await _webhook_check_fw(webhook_id)
|
||||||
j = validate(await request.get_json(), WEBHOOK_UPDATE)
|
j = validate(await request.get_json(), WEBHOOK_UPDATE)
|
||||||
|
|
||||||
if 'channel_id' in j:
|
if "channel_id" in j:
|
||||||
# pre checks
|
# pre checks
|
||||||
chan = await app.storage.get_channel(j['channel_id'])
|
chan = await app.storage.get_channel(j["channel_id"])
|
||||||
|
|
||||||
# short-circuiting should ensure chan isn't none
|
# short-circuiting should ensure chan isn't none
|
||||||
# by the time we do chan['guild_id']
|
# by the time we do chan['guild_id']
|
||||||
if chan and chan['guild_id'] != str(guild_id):
|
if chan and chan["guild_id"] != str(guild_id):
|
||||||
raise ChannelNotFound('cant assign webhook to channel')
|
raise ChannelNotFound("cant assign webhook to channel")
|
||||||
|
|
||||||
await _update_webhook(webhook_id, j)
|
await _update_webhook(webhook_id, j)
|
||||||
|
|
||||||
|
|
@ -257,20 +292,18 @@ async def modify_webhook(webhook_id: int):
|
||||||
|
|
||||||
# we don't need to cast channel_id to int since that isn't
|
# we don't need to cast channel_id to int since that isn't
|
||||||
# used in the dispatcher call
|
# used in the dispatcher call
|
||||||
await _dispatch_webhook_update(guild_id, webhook['channel_id'])
|
await _dispatch_webhook_update(guild_id, webhook["channel_id"])
|
||||||
return jsonify(webhook)
|
return jsonify(webhook)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['PATCH'])
|
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["PATCH"])
|
||||||
async def modify_webhook_tokened(webhook_id, webhook_token):
|
async def modify_webhook_tokened(webhook_id, webhook_token):
|
||||||
"""Modify a webhook, using its token."""
|
"""Modify a webhook, using its token."""
|
||||||
guild_id, channel_id = await webhook_token_check(
|
guild_id, channel_id = await webhook_token_check(webhook_id, webhook_token)
|
||||||
webhook_id, webhook_token)
|
|
||||||
|
|
||||||
# forcefully pop() the channel id out of the schema
|
# forcefully pop() the channel id out of the schema
|
||||||
# instead of making another, for simplicity's sake
|
# instead of making another, for simplicity's sake
|
||||||
j = validate(await request.get_json(),
|
j = validate(await request.get_json(), WEBHOOK_UPDATE.pop("channel_id"))
|
||||||
WEBHOOK_UPDATE.pop('channel_id'))
|
|
||||||
|
|
||||||
await _update_webhook(webhook_id, j)
|
await _update_webhook(webhook_id, j)
|
||||||
await _dispatch_webhook_update(guild_id, channel_id)
|
await _dispatch_webhook_update(guild_id, channel_id)
|
||||||
|
|
@ -281,35 +314,36 @@ async def delete_webhook(webhook_id: int):
|
||||||
"""Delete a webhook."""
|
"""Delete a webhook."""
|
||||||
webhook = await get_webhook(webhook_id)
|
webhook = await get_webhook(webhook_id)
|
||||||
|
|
||||||
res = await app.db.execute("""
|
res = await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM webhooks
|
DELETE FROM webhooks
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", webhook_id)
|
""",
|
||||||
|
webhook_id,
|
||||||
|
)
|
||||||
|
|
||||||
if res.lower() == 'delete 0':
|
if res.lower() == "delete 0":
|
||||||
raise WebhookNotFound()
|
raise WebhookNotFound()
|
||||||
|
|
||||||
# only casting the guild id since that's whats used
|
# only casting the guild id since that's whats used
|
||||||
# on the dispatcher call.
|
# on the dispatcher call.
|
||||||
await _dispatch_webhook_update(
|
await _dispatch_webhook_update(int(webhook["guild_id"]), webhook["channel_id"])
|
||||||
int(webhook['guild_id']), webhook['channel_id']
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>', methods=['DELETE'])
|
@bp.route("/webhooks/<int:webhook_id>", methods=["DELETE"])
|
||||||
async def del_webhook(webhook_id):
|
async def del_webhook(webhook_id):
|
||||||
"""Delete a webhook."""
|
"""Delete a webhook."""
|
||||||
await _webhook_check_fw(webhook_id)
|
await _webhook_check_fw(webhook_id)
|
||||||
await delete_webhook(webhook_id)
|
await delete_webhook(webhook_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['DELETE'])
|
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["DELETE"])
|
||||||
async def del_webhook_tokened(webhook_id, webhook_token):
|
async def del_webhook_tokened(webhook_id, webhook_token):
|
||||||
"""Delete a webhook, with its token."""
|
"""Delete a webhook, with its token."""
|
||||||
await webhook_token_check(webhook_id, webhook_token)
|
await webhook_token_check(webhook_id, webhook_token)
|
||||||
await delete_webhook(webhook_id)
|
await delete_webhook(webhook_id)
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
|
|
||||||
async def create_message_webhook(guild_id, channel_id, webhook_id, data):
|
async def create_message_webhook(guild_id, channel_id, webhook_id, data):
|
||||||
|
|
@ -328,23 +362,27 @@ async def create_message_webhook(guild_id, channel_id, webhook_id, data):
|
||||||
message_id,
|
message_id,
|
||||||
channel_id,
|
channel_id,
|
||||||
guild_id,
|
guild_id,
|
||||||
data['content'],
|
data["content"],
|
||||||
|
data["tts"],
|
||||||
data['tts'],
|
data["everyone_mention"],
|
||||||
data['everyone_mention'],
|
|
||||||
|
|
||||||
MessageType.DEFAULT.value,
|
MessageType.DEFAULT.value,
|
||||||
data.get('embeds', [])
|
data.get("embeds", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
info = data['info']
|
info = data["info"]
|
||||||
|
|
||||||
await conn.execute("""
|
await conn.execute(
|
||||||
|
"""
|
||||||
INSERT INTO message_webhook_info
|
INSERT INTO message_webhook_info
|
||||||
(message_id, webhook_id, name, avatar)
|
(message_id, webhook_id, name, avatar)
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, $4)
|
($1, $2, $3, $4)
|
||||||
""", message_id, webhook_id, info['name'], info['avatar'])
|
""",
|
||||||
|
message_id,
|
||||||
|
webhook_id,
|
||||||
|
info["name"],
|
||||||
|
info["avatar"],
|
||||||
|
)
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
||||||
|
|
@ -354,10 +392,15 @@ async def _webhook_avy_redir(webhook_id: int, avatar_url: EmbedURL):
|
||||||
url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest()
|
url_hash = hashlib.sha256(avatar_url.to_md_path.encode()).hexdigest()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir)
|
INSERT INTO webhook_avatars (webhook_id, hash, md_url_redir)
|
||||||
VALUES ($1, $2, $3)
|
VALUES ($1, $2, $3)
|
||||||
""", webhook_id, url_hash, avatar_url.url)
|
""",
|
||||||
|
webhook_id,
|
||||||
|
url_hash,
|
||||||
|
avatar_url.url,
|
||||||
|
)
|
||||||
except asyncpg.UniqueViolationError:
|
except asyncpg.UniqueViolationError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -371,11 +414,11 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
|
||||||
Litecord will write an URL that redirects to the given avatar_url,
|
Litecord will write an URL that redirects to the given avatar_url,
|
||||||
using mediaproxy.
|
using mediaproxy.
|
||||||
"""
|
"""
|
||||||
if avatar_url.scheme not in ('http', 'https'):
|
if avatar_url.scheme not in ("http", "https"):
|
||||||
raise BadRequest('invalid avatar url scheme')
|
raise BadRequest("invalid avatar url scheme")
|
||||||
|
|
||||||
if not is_media_url(avatar_url):
|
if not is_media_url(avatar_url):
|
||||||
raise BadRequest('url is not media url')
|
raise BadRequest("url is not media url")
|
||||||
|
|
||||||
# we still fetch the URL to check its validity, mimetypes, etc
|
# we still fetch the URL to check its validity, mimetypes, etc
|
||||||
# but in the end, we will store it under the webhook_avatars table,
|
# but in the end, we will store it under the webhook_avatars table,
|
||||||
|
|
@ -383,11 +426,11 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
|
||||||
resp, raw = await fetch_raw_img(avatar_url)
|
resp, raw = await fetch_raw_img(avatar_url)
|
||||||
# raw_b64 = base64.b64encode(raw).decode()
|
# raw_b64 = base64.b64encode(raw).decode()
|
||||||
|
|
||||||
mime = resp.headers['content-type']
|
mime = resp.headers["content-type"]
|
||||||
|
|
||||||
# TODO: apng checks are missing (for this and everywhere else)
|
# TODO: apng checks are missing (for this and everywhere else)
|
||||||
if mime not in STATIC_IMAGE_MIMES:
|
if mime not in STATIC_IMAGE_MIMES:
|
||||||
raise BadRequest('invalid mime type for given url')
|
raise BadRequest("invalid mime type for given url")
|
||||||
|
|
||||||
# b64_data = f'data:{mime};base64,{raw_b64}'
|
# b64_data = f'data:{mime};base64,{raw_b64}'
|
||||||
|
|
||||||
|
|
@ -400,7 +443,7 @@ async def _create_avatar(webhook_id: int, avatar_url: EmbedURL) -> str:
|
||||||
return await _webhook_avy_redir(webhook_id, avatar_url)
|
return await _webhook_avy_redir(webhook_id, avatar_url)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>', methods=['POST'])
|
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>", methods=["POST"])
|
||||||
async def execute_webhook(webhook_id: int, webhook_token):
|
async def execute_webhook(webhook_id: int, webhook_token):
|
||||||
"""Execute a webhook. Sends a message to the channel the webhook
|
"""Execute a webhook. Sends a message to the channel the webhook
|
||||||
is tied to."""
|
is tied to."""
|
||||||
|
|
@ -413,41 +456,39 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
||||||
# NOTE: we really pop here instead of adding a kwarg
|
# NOTE: we really pop here instead of adding a kwarg
|
||||||
# to msg_create_request just because of webhooks.
|
# to msg_create_request just because of webhooks.
|
||||||
# nonce isn't allowed on WEBHOOK_MESSAGE_CREATE
|
# nonce isn't allowed on WEBHOOK_MESSAGE_CREATE
|
||||||
payload_json.pop('nonce')
|
payload_json.pop("nonce")
|
||||||
|
|
||||||
j = validate(payload_json, WEBHOOK_MESSAGE_CREATE)
|
j = validate(payload_json, WEBHOOK_MESSAGE_CREATE)
|
||||||
|
|
||||||
msg_create_check_content(j, files)
|
msg_create_check_content(j, files)
|
||||||
|
|
||||||
# webhooks don't need permissions.
|
# webhooks don't need permissions.
|
||||||
mentions_everyone = '@everyone' in j['content']
|
mentions_everyone = "@everyone" in j["content"]
|
||||||
mentions_here = '@here' in j['content']
|
mentions_here = "@here" in j["content"]
|
||||||
|
|
||||||
given_embeds = j.get('embeds', [])
|
given_embeds = j.get("embeds", [])
|
||||||
|
|
||||||
webhook = await get_webhook(webhook_id)
|
webhook = await get_webhook(webhook_id)
|
||||||
|
|
||||||
# webhooks have TWO avatars. one is from settings, the other is from
|
# webhooks have TWO avatars. one is from settings, the other is from
|
||||||
# the json's icon_url. one can be handled gracefully by IconManager,
|
# the json's icon_url. one can be handled gracefully by IconManager,
|
||||||
# but the other can't, at all.
|
# but the other can't, at all.
|
||||||
avatar = webhook['avatar']
|
avatar = webhook["avatar"]
|
||||||
|
|
||||||
if 'avatar_url' in j and j['avatar_url'] is not None:
|
if "avatar_url" in j and j["avatar_url"] is not None:
|
||||||
avatar = await _create_avatar(webhook_id, j['avatar_url'])
|
avatar = await _create_avatar(webhook_id, j["avatar_url"])
|
||||||
|
|
||||||
message_id = await create_message_webhook(
|
message_id = await create_message_webhook(
|
||||||
guild_id, channel_id, webhook_id, {
|
guild_id,
|
||||||
'content': j.get('content', ''),
|
channel_id,
|
||||||
'tts': j.get('tts', False),
|
webhook_id,
|
||||||
|
{
|
||||||
'everyone_mention': mentions_everyone or mentions_here,
|
"content": j.get("content", ""),
|
||||||
'embeds': await async_map(fill_embed, given_embeds),
|
"tts": j.get("tts", False),
|
||||||
|
"everyone_mention": mentions_everyone or mentions_here,
|
||||||
'info': {
|
"embeds": await async_map(fill_embed, given_embeds),
|
||||||
'name': j.get('username', webhook['name']),
|
"info": {"name": j.get("username", webhook["name"]), "avatar": avatar},
|
||||||
'avatar': avatar
|
},
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for pre_attachment in files:
|
for pre_attachment in files:
|
||||||
|
|
@ -455,33 +496,28 @@ async def execute_webhook(webhook_id: int, webhook_token):
|
||||||
|
|
||||||
payload = await app.storage.get_message(message_id)
|
payload = await app.storage.get_message(message_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", payload)
|
||||||
'MESSAGE_CREATE', payload)
|
|
||||||
|
|
||||||
# spawn embedder in the background, even when we're on a webhook.
|
# spawn embedder in the background, even when we're on a webhook.
|
||||||
app.sched.spawn(
|
app.sched.spawn(
|
||||||
process_url_embed(
|
process_url_embed(app.config, app.storage, app.dispatcher, app.session, payload)
|
||||||
app.config, app.storage, app.dispatcher, app.session,
|
|
||||||
payload
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# we can assume its a guild text channel, so just call it
|
# we can assume its a guild text channel, so just call it
|
||||||
await msg_guild_text_mentions(
|
await msg_guild_text_mentions(payload, guild_id, mentions_everyone, mentions_here)
|
||||||
payload, guild_id, mentions_everyone, mentions_here)
|
|
||||||
|
|
||||||
# TODO: is it really 204?
|
# TODO: is it really 204?
|
||||||
return '', 204
|
return "", 204
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/slack',
|
|
||||||
methods=['POST'])
|
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>/slack", methods=["POST"])
|
||||||
async def execute_slack_webhook(webhook_id, webhook_token):
|
async def execute_slack_webhook(webhook_id, webhook_token):
|
||||||
"""Execute a webhook but expecting Slack data."""
|
"""Execute a webhook but expecting Slack data."""
|
||||||
# TODO: know slack webhooks
|
# TODO: know slack webhooks
|
||||||
await webhook_token_check(webhook_id, webhook_token)
|
await webhook_token_check(webhook_id, webhook_token)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/webhooks/<int:webhook_id>/<webhook_token>/github', methods=['POST'])
|
@bp.route("/webhooks/<int:webhook_id>/<webhook_token>/github", methods=["POST"])
|
||||||
async def execute_github_webhook(webhook_id, webhook_token):
|
async def execute_github_webhook(webhook_id, webhook_token):
|
||||||
"""Execute a webhook but expecting GitHub data."""
|
"""Execute a webhook but expecting GitHub data."""
|
||||||
# TODO: know github webhooks
|
# TODO: know github webhooks
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,14 @@ from typing import List, Any, Dict
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
from .pubsub import (
|
||||||
UserDispatcher, ChannelDispatcher, FriendDispatcher, \
|
GuildDispatcher,
|
||||||
LazyGuildDispatcher
|
MemberDispatcher,
|
||||||
|
UserDispatcher,
|
||||||
|
ChannelDispatcher,
|
||||||
|
FriendDispatcher,
|
||||||
|
LazyGuildDispatcher,
|
||||||
|
)
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -44,17 +49,18 @@ class EventDispatcher:
|
||||||
when dispatching, the backend can do its own logic, given
|
when dispatching, the backend can do its own logic, given
|
||||||
its subscriber ids.
|
its subscriber ids.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.state_manager = app.state_manager
|
self.state_manager = app.state_manager
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
self.backends = {
|
self.backends = {
|
||||||
'guild': GuildDispatcher(self),
|
"guild": GuildDispatcher(self),
|
||||||
'member': MemberDispatcher(self),
|
"member": MemberDispatcher(self),
|
||||||
'channel': ChannelDispatcher(self),
|
"channel": ChannelDispatcher(self),
|
||||||
'user': UserDispatcher(self),
|
"user": UserDispatcher(self),
|
||||||
'friend': FriendDispatcher(self),
|
"friend": FriendDispatcher(self),
|
||||||
'lazy_guild': LazyGuildDispatcher(self),
|
"lazy_guild": LazyGuildDispatcher(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
||||||
|
|
@ -71,13 +77,13 @@ class EventDispatcher:
|
||||||
|
|
||||||
return await method(key, identifier, *args)
|
return await method(key, identifier, *args)
|
||||||
|
|
||||||
async def subscribe(self, backend: str, key: Any, identifier: Any,
|
async def subscribe(
|
||||||
flags: Dict[str, Any] = None):
|
self, backend: str, key: Any, identifier: Any, flags: Dict[str, Any] = None
|
||||||
|
):
|
||||||
"""Subscribe a single element to the given backend."""
|
"""Subscribe a single element to the given backend."""
|
||||||
flags = flags or {}
|
flags = flags or {}
|
||||||
|
|
||||||
log.debug('SUB backend={} key={} <= id={}',
|
log.debug("SUB backend={} key={} <= id={}", backend, key, identifier, backend)
|
||||||
backend, key, identifier, backend)
|
|
||||||
|
|
||||||
# this is a hacky solution for backwards compatibility between backends
|
# this is a hacky solution for backwards compatibility between backends
|
||||||
# that implement flags and backends that don't.
|
# that implement flags and backends that don't.
|
||||||
|
|
@ -85,16 +91,15 @@ class EventDispatcher:
|
||||||
# passing flags to backends that don't implement flags will
|
# passing flags to backends that don't implement flags will
|
||||||
# cause errors as expected.
|
# cause errors as expected.
|
||||||
if flags:
|
if flags:
|
||||||
return await self.action(backend, 'sub', key, identifier, flags)
|
return await self.action(backend, "sub", key, identifier, flags)
|
||||||
|
|
||||||
return await self.action(backend, 'sub', key, identifier)
|
return await self.action(backend, "sub", key, identifier)
|
||||||
|
|
||||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||||
"""Unsubscribe an element from the given backend."""
|
"""Unsubscribe an element from the given backend."""
|
||||||
log.debug('UNSUB backend={} key={} => id={}',
|
log.debug("UNSUB backend={} key={} => id={}", backend, key, identifier, backend)
|
||||||
backend, key, identifier, backend)
|
|
||||||
|
|
||||||
return await self.action(backend, 'unsub', key, identifier)
|
return await self.action(backend, "unsub", key, identifier)
|
||||||
|
|
||||||
async def sub(self, backend, key, identifier):
|
async def sub(self, backend, key, identifier):
|
||||||
"""Alias to subscribe()."""
|
"""Alias to subscribe()."""
|
||||||
|
|
@ -104,8 +109,13 @@ class EventDispatcher:
|
||||||
"""Alias to unsubscribe()."""
|
"""Alias to unsubscribe()."""
|
||||||
return await self.unsubscribe(backend, key, identifier)
|
return await self.unsubscribe(backend, key, identifier)
|
||||||
|
|
||||||
async def sub_many(self, backend_str: str, identifier: Any,
|
async def sub_many(
|
||||||
keys: list, flags: Dict[str, Any] = None):
|
self,
|
||||||
|
backend_str: str,
|
||||||
|
identifier: Any,
|
||||||
|
keys: list,
|
||||||
|
flags: Dict[str, Any] = None,
|
||||||
|
):
|
||||||
"""Subscribe to multiple channels (all in a single backend)
|
"""Subscribe to multiple channels (all in a single backend)
|
||||||
at a time.
|
at a time.
|
||||||
|
|
||||||
|
|
@ -116,8 +126,7 @@ class EventDispatcher:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.subscribe(backend_str, key, identifier, flags)
|
await self.subscribe(backend_str, key, identifier, flags)
|
||||||
|
|
||||||
async def mass_sub(self, identifier: Any,
|
async def mass_sub(self, identifier: Any, backends: List[tuple]):
|
||||||
backends: List[tuple]):
|
|
||||||
"""Mass subscribe to many backends at once."""
|
"""Mass subscribe to many backends at once."""
|
||||||
for bcall in backends:
|
for bcall in backends:
|
||||||
backend_str, keys = bcall[0], bcall[1]
|
backend_str, keys = bcall[0], bcall[1]
|
||||||
|
|
@ -128,8 +137,13 @@ class EventDispatcher:
|
||||||
# we have flags
|
# we have flags
|
||||||
flags = bcall[2]
|
flags = bcall[2]
|
||||||
|
|
||||||
log.debug('subscribing {} to {} keys in backend {}, flags: {}',
|
log.debug(
|
||||||
identifier, len(keys), backend_str, flags)
|
"subscribing {} to {} keys in backend {}, flags: {}",
|
||||||
|
identifier,
|
||||||
|
len(keys),
|
||||||
|
backend_str,
|
||||||
|
flags,
|
||||||
|
)
|
||||||
|
|
||||||
await self.sub_many(backend_str, identifier, keys, flags)
|
await self.sub_many(backend_str, identifier, keys, flags)
|
||||||
|
|
||||||
|
|
@ -145,17 +159,14 @@ class EventDispatcher:
|
||||||
key = backend.KEY_TYPE(key)
|
key = backend.KEY_TYPE(key)
|
||||||
return await backend.dispatch(key, *args, **kwargs)
|
return await backend.dispatch(key, *args, **kwargs)
|
||||||
|
|
||||||
async def dispatch_many(self, backend_str: str,
|
async def dispatch_many(self, backend_str: str, keys: List[Any], *args, **kwargs):
|
||||||
keys: List[Any], *args, **kwargs):
|
|
||||||
"""Dispatch to multiple keys in a single backend."""
|
"""Dispatch to multiple keys in a single backend."""
|
||||||
log.info('MULTI DISPATCH: {!r}, {} keys',
|
log.info("MULTI DISPATCH: {!r}, {} keys", backend_str, len(keys))
|
||||||
backend_str, len(keys))
|
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.dispatch(backend_str, key, *args, **kwargs)
|
await self.dispatch(backend_str, key, *args, **kwargs)
|
||||||
|
|
||||||
async def dispatch_filter(self, backend_str: str,
|
async def dispatch_filter(self, backend_str: str, key: Any, func, *args):
|
||||||
key: Any, func, *args):
|
|
||||||
"""Dispatch to a backend that only accepts
|
"""Dispatch to a backend that only accepts
|
||||||
(event, data) arguments with an optional filter
|
(event, data) arguments with an optional filter
|
||||||
function."""
|
function."""
|
||||||
|
|
@ -163,9 +174,9 @@ class EventDispatcher:
|
||||||
key = backend.KEY_TYPE(key)
|
key = backend.KEY_TYPE(key)
|
||||||
return await backend.dispatch_filter(key, func, *args)
|
return await backend.dispatch_filter(key, func, *args)
|
||||||
|
|
||||||
async def dispatch_many_filter_list(self, backend_str: str,
|
async def dispatch_many_filter_list(
|
||||||
keys: List[Any], sess_list: List[str],
|
self, backend_str: str, keys: List[Any], sess_list: List[str], *args
|
||||||
*args):
|
):
|
||||||
"""Make a "unique" dispatch given a list of session ids.
|
"""Make a "unique" dispatch given a list of session ids.
|
||||||
|
|
||||||
This only works for backends that have a dispatch_filter
|
This only works for backends that have a dispatch_filter
|
||||||
|
|
@ -175,9 +186,8 @@ class EventDispatcher:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
sess_list.extend(
|
sess_list.extend(
|
||||||
await self.dispatch_filter(
|
await self.dispatch_filter(
|
||||||
backend_str, key,
|
backend_str, key, lambda sess_id: sess_id not in sess_list, *args
|
||||||
lambda sess_id: sess_id not in sess_list,
|
)
|
||||||
*args)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return sess_list
|
return sess_list
|
||||||
|
|
@ -197,12 +207,12 @@ class EventDispatcher:
|
||||||
|
|
||||||
async def dispatch_guild(self, guild_id, event, data):
|
async def dispatch_guild(self, guild_id, event, data):
|
||||||
"""Backwards compatibility with old EventDispatcher."""
|
"""Backwards compatibility with old EventDispatcher."""
|
||||||
return await self.dispatch('guild', guild_id, event, data)
|
return await self.dispatch("guild", guild_id, event, data)
|
||||||
|
|
||||||
async def dispatch_user_guild(self, user_id, guild_id, event, data):
|
async def dispatch_user_guild(self, user_id, guild_id, event, data):
|
||||||
"""Backwards compatibility with old EventDispatcher."""
|
"""Backwards compatibility with old EventDispatcher."""
|
||||||
return await self.dispatch('member', (guild_id, user_id), event, data)
|
return await self.dispatch("member", (guild_id, user_id), event, data)
|
||||||
|
|
||||||
async def dispatch_user(self, user_id, event, data):
|
async def dispatch_user(self, user_id, event, data):
|
||||||
"""Backwards compatibility with old EventDispatcher."""
|
"""Backwards compatibility with old EventDispatcher."""
|
||||||
return await self.dispatch('user', user_id, event, data)
|
return await self.dispatch("user", user_id, event, data)
|
||||||
|
|
|
||||||
|
|
@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from .sanitizer import sanitize_embed
|
from .sanitizer import sanitize_embed
|
||||||
|
|
||||||
__all__ = ['sanitize_embed']
|
__all__ = ["sanitize_embed"]
|
||||||
|
|
|
||||||
|
|
@ -30,11 +30,7 @@ from litecord.embed.schemas import EmbedURL
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
MEDIA_EXTENSIONS = (
|
MEDIA_EXTENSIONS = ("png", "jpg", "jpeg", "gif", "webm")
|
||||||
'png',
|
|
||||||
'jpg', 'jpeg',
|
|
||||||
'gif', 'webm'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def insert_media_meta(url, config, session):
|
async def insert_media_meta(url, config, session):
|
||||||
|
|
@ -45,18 +41,18 @@ async def insert_media_meta(url, config, session):
|
||||||
if meta is None:
|
if meta is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not meta['image']:
|
if not meta["image"]:
|
||||||
return
|
return
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'type': 'image',
|
"type": "image",
|
||||||
'url': url,
|
"url": url,
|
||||||
'thumbnail': {
|
"thumbnail": {
|
||||||
'width': meta['width'],
|
"width": meta["width"],
|
||||||
'height': meta['height'],
|
"height": meta["height"],
|
||||||
'url': url,
|
"url": url,
|
||||||
'proxy_url': img_proxy_url
|
"proxy_url": img_proxy_url,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -64,29 +60,32 @@ async def msg_update_embeds(payload, new_embeds, storage, dispatcher):
|
||||||
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
|
"""Update the message with the given embeds and dispatch a MESSAGE_UPDATE
|
||||||
to users."""
|
to users."""
|
||||||
|
|
||||||
message_id = int(payload['id'])
|
message_id = int(payload["id"])
|
||||||
channel_id = int(payload['channel_id'])
|
channel_id = int(payload["channel_id"])
|
||||||
|
|
||||||
await storage.execute_with_json("""
|
await storage.execute_with_json(
|
||||||
|
"""
|
||||||
UPDATE messages
|
UPDATE messages
|
||||||
SET embeds = $1
|
SET embeds = $1
|
||||||
WHERE messages.id = $2
|
WHERE messages.id = $2
|
||||||
""", new_embeds, message_id)
|
""",
|
||||||
|
new_embeds,
|
||||||
|
message_id,
|
||||||
|
)
|
||||||
|
|
||||||
update_payload = {
|
update_payload = {
|
||||||
'id': str(message_id),
|
"id": str(message_id),
|
||||||
'channel_id': str(channel_id),
|
"channel_id": str(channel_id),
|
||||||
'embeds': new_embeds,
|
"embeds": new_embeds,
|
||||||
}
|
}
|
||||||
|
|
||||||
if 'guild_id' in payload:
|
if "guild_id" in payload:
|
||||||
update_payload['guild_id'] = payload['guild_id']
|
update_payload["guild_id"] = payload["guild_id"]
|
||||||
|
|
||||||
if 'flags' in payload:
|
if "flags" in payload:
|
||||||
update_payload['flags'] = payload['flags']
|
update_payload["flags"] = payload["flags"]
|
||||||
|
|
||||||
await dispatcher.dispatch(
|
await dispatcher.dispatch("channel", channel_id, "MESSAGE_UPDATE", update_payload)
|
||||||
'channel', channel_id, 'MESSAGE_UPDATE', update_payload)
|
|
||||||
|
|
||||||
|
|
||||||
def is_media_url(url) -> bool:
|
def is_media_url(url) -> bool:
|
||||||
|
|
@ -98,7 +97,7 @@ def is_media_url(url) -> bool:
|
||||||
parsed = urllib.parse.urlparse(url)
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
|
||||||
path = Path(parsed.path)
|
path = Path(parsed.path)
|
||||||
extension = path.suffix.lstrip('.')
|
extension = path.suffix.lstrip(".")
|
||||||
|
|
||||||
return extension in MEDIA_EXTENSIONS
|
return extension in MEDIA_EXTENSIONS
|
||||||
|
|
||||||
|
|
@ -109,20 +108,20 @@ async def insert_mp_embed(parsed, config, session):
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
||||||
async def process_url_embed(config, storage, dispatcher,
|
async def process_url_embed(
|
||||||
session, payload: dict, *, delay=0):
|
config, storage, dispatcher, session, payload: dict, *, delay=0
|
||||||
|
):
|
||||||
"""Process URLs in a message and generate embeds based on that."""
|
"""Process URLs in a message and generate embeds based on that."""
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
message_id = int(payload['id'])
|
message_id = int(payload["id"])
|
||||||
|
|
||||||
# if we already have embeds
|
# if we already have embeds
|
||||||
# we shouldn't add our own.
|
# we shouldn't add our own.
|
||||||
embeds = payload['embeds']
|
embeds = payload["embeds"]
|
||||||
|
|
||||||
if embeds:
|
if embeds:
|
||||||
log.debug('url processor: ignoring existing embeds @ mid {}',
|
log.debug("url processor: ignoring existing embeds @ mid {}", message_id)
|
||||||
message_id)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# now, we have two types of embeds:
|
# now, we have two types of embeds:
|
||||||
|
|
@ -130,7 +129,7 @@ async def process_url_embed(config, storage, dispatcher,
|
||||||
# - url embeds
|
# - url embeds
|
||||||
|
|
||||||
# use regex to get URLs
|
# use regex to get URLs
|
||||||
urls = re.findall(r'(https?://\S+)', payload['content'])
|
urls = re.findall(r"(https?://\S+)", payload["content"])
|
||||||
urls = urls[:5]
|
urls = urls[:5]
|
||||||
|
|
||||||
# from there, we need to parse each found url and check its path.
|
# from there, we need to parse each found url and check its path.
|
||||||
|
|
@ -159,7 +158,6 @@ async def process_url_embed(config, storage, dispatcher,
|
||||||
if not new_embeds:
|
if not new_embeds:
|
||||||
return
|
return
|
||||||
|
|
||||||
log.debug('made {} embeds for mid {}',
|
log.debug("made {} embeds for mid {}", len(new_embeds), message_id)
|
||||||
len(new_embeds), message_id)
|
|
||||||
|
|
||||||
await msg_update_embeds(payload, new_embeds, storage, dispatcher)
|
await msg_update_embeds(payload, new_embeds, storage, dispatcher)
|
||||||
|
|
|
||||||
|
|
@ -39,9 +39,7 @@ def sanitize_embed(embed: Embed) -> Embed:
|
||||||
This is non-complex sanitization as it doesn't
|
This is non-complex sanitization as it doesn't
|
||||||
need the app object.
|
need the app object.
|
||||||
"""
|
"""
|
||||||
return {**embed, **{
|
return {**embed, **{"type": "rich"}}
|
||||||
'type': 'rich'
|
|
||||||
}}
|
|
||||||
|
|
||||||
|
|
||||||
def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||||
|
|
@ -55,7 +53,7 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||||
|
|
||||||
# get the list of components given
|
# get the list of components given
|
||||||
if isinstance(components_in, str):
|
if isinstance(components_in, str):
|
||||||
components = components_in.split('.')
|
components = components_in.split(".")
|
||||||
else:
|
else:
|
||||||
components = list(components_in)
|
components = list(components_in)
|
||||||
|
|
||||||
|
|
@ -77,7 +75,6 @@ def path_exists(embed: Embed, components_in: Union[List[str], str]):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _mk_cfg_sess(config, session) -> tuple:
|
def _mk_cfg_sess(config, session) -> tuple:
|
||||||
"""Return a tuple of (config, session)."""
|
"""Return a tuple of (config, session)."""
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|
@ -91,11 +88,11 @@ def _mk_cfg_sess(config, session) -> tuple:
|
||||||
|
|
||||||
def _md_base(config) -> Optional[tuple]:
|
def _md_base(config) -> Optional[tuple]:
|
||||||
"""Return the protocol and base url for the mediaproxy."""
|
"""Return the protocol and base url for the mediaproxy."""
|
||||||
md_base_url = config['MEDIA_PROXY']
|
md_base_url = config["MEDIA_PROXY"]
|
||||||
if md_base_url is None:
|
if md_base_url is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
proto = 'https' if config['IS_SSL'] else 'http'
|
proto = "https" if config["IS_SSL"] else "http"
|
||||||
|
|
||||||
return proto, md_base_url
|
return proto, md_base_url
|
||||||
|
|
||||||
|
|
@ -111,7 +108,7 @@ def make_md_req_url(config, scope: str, url):
|
||||||
return url.url if isinstance(url, EmbedURL) else url
|
return url.url if isinstance(url, EmbedURL) else url
|
||||||
|
|
||||||
proto, base_url = base
|
proto, base_url = base
|
||||||
return f'{proto}://{base_url}/{scope}/{url.to_md_path}'
|
return f"{proto}://{base_url}/{scope}/{url.to_md_path}"
|
||||||
|
|
||||||
|
|
||||||
def proxify(url, *, config=None) -> str:
|
def proxify(url, *, config=None) -> str:
|
||||||
|
|
@ -122,11 +119,12 @@ def proxify(url, *, config=None) -> str:
|
||||||
if isinstance(url, str):
|
if isinstance(url, str):
|
||||||
url = EmbedURL(url)
|
url = EmbedURL(url)
|
||||||
|
|
||||||
return make_md_req_url(config, 'img', url)
|
return make_md_req_url(config, "img", url)
|
||||||
|
|
||||||
|
|
||||||
async def _md_client_req(config, session, scope: str,
|
async def _md_client_req(
|
||||||
url, *, ret_resp=False) -> Optional[Union[Tuple, Dict]]:
|
config, session, scope: str, url, *, ret_resp=False
|
||||||
|
) -> Optional[Union[Tuple, Dict]]:
|
||||||
"""Makes a request to the mediaproxy.
|
"""Makes a request to the mediaproxy.
|
||||||
|
|
||||||
This has common code between all the main mediaproxy request functions
|
This has common code between all the main mediaproxy request functions
|
||||||
|
|
@ -172,17 +170,13 @@ async def _md_client_req(config, session, scope: str,
|
||||||
return await resp.json()
|
return await resp.json()
|
||||||
|
|
||||||
body = await resp.text()
|
body = await resp.text()
|
||||||
log.warning('failed to call {!r}, {} {!r}',
|
log.warning("failed to call {!r}, {} {!r}", request_url, resp.status, body)
|
||||||
request_url, resp.status, body)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def fetch_metadata(url, *, config=None,
|
async def fetch_metadata(url, *, config=None, session=None) -> Optional[Dict]:
|
||||||
session=None) -> Optional[Dict]:
|
|
||||||
"""Fetch metadata for a url (image width, mime, etc)."""
|
"""Fetch metadata for a url (image width, mime, etc)."""
|
||||||
return await _md_client_req(
|
return await _md_client_req(config, session, "meta", url)
|
||||||
config, session, 'meta', url
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
||||||
|
|
@ -191,9 +185,7 @@ async def fetch_raw_img(url, *, config=None, session=None) -> Optional[tuple]:
|
||||||
Returns a tuple containing the response object and the raw bytes given by
|
Returns a tuple containing the response object and the raw bytes given by
|
||||||
the website.
|
the website.
|
||||||
"""
|
"""
|
||||||
tup = await _md_client_req(
|
tup = await _md_client_req(config, session, "img", url, ret_resp=True)
|
||||||
config, session, 'img', url, ret_resp=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tup:
|
if not tup:
|
||||||
return None
|
return None
|
||||||
|
|
@ -207,9 +199,7 @@ async def fetch_embed(url, *, config=None, session=None) -> Dict[str, Any]:
|
||||||
|
|
||||||
Returns a discord embed object.
|
Returns a discord embed object.
|
||||||
"""
|
"""
|
||||||
return await _md_client_req(
|
return await _md_client_req(config, session, "embed", url)
|
||||||
config, session, 'embed', url
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
||||||
|
|
@ -229,22 +219,20 @@ async def fill_embed(embed: Optional[Embed]) -> Optional[Embed]:
|
||||||
|
|
||||||
embed = sanitize_embed(embed)
|
embed = sanitize_embed(embed)
|
||||||
|
|
||||||
if path_exists(embed, 'footer.icon_url'):
|
if path_exists(embed, "footer.icon_url"):
|
||||||
embed['footer']['proxy_icon_url'] = \
|
embed["footer"]["proxy_icon_url"] = proxify(embed["footer"]["icon_url"])
|
||||||
proxify(embed['footer']['icon_url'])
|
|
||||||
|
|
||||||
if path_exists(embed, 'author.icon_url'):
|
if path_exists(embed, "author.icon_url"):
|
||||||
embed['author']['proxy_icon_url'] = \
|
embed["author"]["proxy_icon_url"] = proxify(embed["author"]["icon_url"])
|
||||||
proxify(embed['author']['icon_url'])
|
|
||||||
|
|
||||||
if path_exists(embed, 'image.url'):
|
if path_exists(embed, "image.url"):
|
||||||
image_url = embed['image']['url']
|
image_url = embed["image"]["url"]
|
||||||
|
|
||||||
meta = await fetch_metadata(image_url)
|
meta = await fetch_metadata(image_url)
|
||||||
embed['image']['proxy_url'] = proxify(image_url)
|
embed["image"]["proxy_url"] = proxify(image_url)
|
||||||
|
|
||||||
if meta and meta['image']:
|
if meta and meta["image"]:
|
||||||
embed['image']['width'] = meta['width']
|
embed["image"]["width"] = meta["width"]
|
||||||
embed['image']['height'] = meta['height']
|
embed["image"]["height"] = meta["height"]
|
||||||
|
|
||||||
return embed
|
return embed
|
||||||
|
|
|
||||||
|
|
@ -28,8 +28,8 @@ class EmbedURL:
|
||||||
def __init__(self, url: str):
|
def __init__(self, url: str):
|
||||||
parsed = urllib.parse.urlparse(url)
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
|
||||||
if parsed.scheme not in ('http', 'https', 'attachment'):
|
if parsed.scheme not in ("http", "https", "attachment"):
|
||||||
raise ValueError('Invalid URL scheme')
|
raise ValueError("Invalid URL scheme")
|
||||||
|
|
||||||
self.scheme = parsed.scheme
|
self.scheme = parsed.scheme
|
||||||
self.raw_url = url
|
self.raw_url = url
|
||||||
|
|
@ -54,105 +54,61 @@ class EmbedURL:
|
||||||
def to_md_path(self) -> str:
|
def to_md_path(self) -> str:
|
||||||
"""Convert the EmbedURL to a mediaproxy path (post img/meta)."""
|
"""Convert the EmbedURL to a mediaproxy path (post img/meta)."""
|
||||||
parsed = self.parsed
|
parsed = self.parsed
|
||||||
return (
|
return f"{parsed.scheme}/{parsed.netloc}" f"{parsed.path}?{parsed.query}"
|
||||||
f'{parsed.scheme}/{parsed.netloc}'
|
|
||||||
f'{parsed.path}?{parsed.query}'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
EMBED_FOOTER = {
|
EMBED_FOOTER = {
|
||||||
'text': {
|
"text": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True},
|
"icon_url": {"coerce": EmbedURL, "required": False},
|
||||||
|
|
||||||
'icon_url': {
|
|
||||||
'coerce': EmbedURL, 'required': False,
|
|
||||||
},
|
|
||||||
|
|
||||||
# NOTE: proxy_icon_url set by us
|
# NOTE: proxy_icon_url set by us
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBED_IMAGE = {
|
EMBED_IMAGE = {
|
||||||
'url': {'coerce': EmbedURL, 'required': True},
|
"url": {"coerce": EmbedURL, "required": True},
|
||||||
|
|
||||||
# NOTE: proxy_url, width, height set by us
|
# NOTE: proxy_url, width, height set by us
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBED_THUMBNAIL = EMBED_IMAGE
|
EMBED_THUMBNAIL = EMBED_IMAGE
|
||||||
|
|
||||||
EMBED_AUTHOR = {
|
EMBED_AUTHOR = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False
|
"url": {"coerce": EmbedURL, "required": False},
|
||||||
},
|
"icon_url": {"coerce": EmbedURL, "required": False}
|
||||||
'url': {
|
|
||||||
'coerce': EmbedURL, 'required': False,
|
|
||||||
},
|
|
||||||
'icon_url': {
|
|
||||||
'coerce': EmbedURL, 'required': False,
|
|
||||||
}
|
|
||||||
|
|
||||||
# NOTE: proxy_icon_url set by us
|
# NOTE: proxy_icon_url set by us
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBED_FIELD = {
|
EMBED_FIELD = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True
|
"value": {"type": "string", "minlength": 1, "maxlength": 1024, "required": True},
|
||||||
},
|
"inline": {"type": "boolean", "required": False, "default": True},
|
||||||
'value': {
|
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': True
|
|
||||||
},
|
|
||||||
'inline': {
|
|
||||||
'type': 'boolean', 'required': False, 'default': True,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
EMBED_OBJECT = {
|
EMBED_OBJECT = {
|
||||||
'title': {
|
"title": {"type": "string", "minlength": 1, "maxlength": 256, "required": False},
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': False},
|
|
||||||
# NOTE: type set by us
|
# NOTE: type set by us
|
||||||
'description': {
|
"description": {
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 2048, 'required': False,
|
"type": "string",
|
||||||
|
"minlength": 1,
|
||||||
|
"maxlength": 2048,
|
||||||
|
"required": False,
|
||||||
},
|
},
|
||||||
'url': {
|
"url": {"coerce": EmbedURL, "required": False},
|
||||||
'coerce': EmbedURL, 'required': False,
|
"timestamp": {
|
||||||
},
|
|
||||||
'timestamp': {
|
|
||||||
# TODO: an ISO 8601 type
|
# TODO: an ISO 8601 type
|
||||||
# TODO: maybe replace the default in here with now().isoformat?
|
# TODO: maybe replace the default in here with now().isoformat?
|
||||||
'type': 'string', 'required': False
|
"type": "string",
|
||||||
|
"required": False,
|
||||||
},
|
},
|
||||||
|
"color": {"coerce": Color, "required": False},
|
||||||
'color': {
|
"footer": {"type": "dict", "schema": EMBED_FOOTER, "required": False},
|
||||||
'coerce': Color, 'required': False
|
"image": {"type": "dict", "schema": EMBED_IMAGE, "required": False},
|
||||||
},
|
"thumbnail": {"type": "dict", "schema": EMBED_THUMBNAIL, "required": False},
|
||||||
|
|
||||||
'footer': {
|
|
||||||
'type': 'dict',
|
|
||||||
'schema': EMBED_FOOTER,
|
|
||||||
'required': False,
|
|
||||||
},
|
|
||||||
'image': {
|
|
||||||
'type': 'dict',
|
|
||||||
'schema': EMBED_IMAGE,
|
|
||||||
'required': False,
|
|
||||||
},
|
|
||||||
'thumbnail': {
|
|
||||||
'type': 'dict',
|
|
||||||
'schema': EMBED_THUMBNAIL,
|
|
||||||
'required': False,
|
|
||||||
},
|
|
||||||
|
|
||||||
# NOTE: 'video' set by us
|
# NOTE: 'video' set by us
|
||||||
# NOTE: 'provider' set by us
|
# NOTE: 'provider' set by us
|
||||||
|
"author": {"type": "dict", "schema": EMBED_AUTHOR, "required": False},
|
||||||
'author': {
|
"fields": {
|
||||||
'type': 'dict',
|
"type": "list",
|
||||||
'schema': EMBED_AUTHOR,
|
"schema": {"type": "dict", "schema": EMBED_FIELD},
|
||||||
'required': False,
|
"required": False,
|
||||||
},
|
|
||||||
|
|
||||||
'fields': {
|
|
||||||
'type': 'list',
|
|
||||||
'schema': {'type': 'dict', 'schema': EMBED_FIELD},
|
|
||||||
'required': False,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -52,13 +52,14 @@ class Flags:
|
||||||
>>> i2.is_field_3
|
>>> i2.is_field_3
|
||||||
False
|
False
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init_subclass__(cls, **_kwargs):
|
def __init_subclass__(cls, **_kwargs):
|
||||||
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
|
attrs = inspect.getmembers(cls, lambda x: not inspect.isroutine(x))
|
||||||
|
|
||||||
def _make_int(value):
|
def _make_int(value):
|
||||||
res = Flags()
|
res = Flags()
|
||||||
|
|
||||||
setattr(res, 'value', value)
|
setattr(res, "value", value)
|
||||||
|
|
||||||
for attr, val in attrs:
|
for attr, val in attrs:
|
||||||
# get only the ones that represent a field in the
|
# get only the ones that represent a field in the
|
||||||
|
|
@ -69,7 +70,7 @@ class Flags:
|
||||||
has_attr = (value & val) == val
|
has_attr = (value & val) == val
|
||||||
|
|
||||||
# set each attribute
|
# set each attribute
|
||||||
setattr(res, f'is_{attr}', has_attr)
|
setattr(res, f"is_{attr}", has_attr)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
@ -84,17 +85,16 @@ class ChannelType(EasyEnum):
|
||||||
GUILD_CATEGORY = 4
|
GUILD_CATEGORY = 4
|
||||||
|
|
||||||
|
|
||||||
GUILD_CHANS = (ChannelType.GUILD_TEXT,
|
GUILD_CHANS = (
|
||||||
|
ChannelType.GUILD_TEXT,
|
||||||
ChannelType.GUILD_VOICE,
|
ChannelType.GUILD_VOICE,
|
||||||
ChannelType.GUILD_CATEGORY)
|
ChannelType.GUILD_CATEGORY,
|
||||||
|
|
||||||
|
|
||||||
VOICE_CHANNELS = (
|
|
||||||
ChannelType.DM, ChannelType.GUILD_VOICE,
|
|
||||||
ChannelType.GUILD_CATEGORY
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
VOICE_CHANNELS = (ChannelType.DM, ChannelType.GUILD_VOICE, ChannelType.GUILD_CATEGORY)
|
||||||
|
|
||||||
|
|
||||||
class ActivityType(EasyEnum):
|
class ActivityType(EasyEnum):
|
||||||
PLAYING = 0
|
PLAYING = 0
|
||||||
STREAMING = 1
|
STREAMING = 1
|
||||||
|
|
@ -120,7 +120,7 @@ SYS_MESSAGES = (
|
||||||
MessageType.CHANNEL_NAME_CHANGE,
|
MessageType.CHANNEL_NAME_CHANGE,
|
||||||
MessageType.CHANNEL_ICON_CHANGE,
|
MessageType.CHANNEL_ICON_CHANGE,
|
||||||
MessageType.CHANNEL_PINNED_MESSAGE,
|
MessageType.CHANNEL_PINNED_MESSAGE,
|
||||||
MessageType.GUILD_MEMBER_JOIN
|
MessageType.GUILD_MEMBER_JOIN,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -137,6 +137,7 @@ class ActivityFlags(Flags):
|
||||||
|
|
||||||
Only related to rich presence.
|
Only related to rich presence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
instance = 1
|
instance = 1
|
||||||
join = 2
|
join = 2
|
||||||
spectate = 4
|
spectate = 4
|
||||||
|
|
@ -150,6 +151,7 @@ class UserFlags(Flags):
|
||||||
|
|
||||||
Used by the client to show badges.
|
Used by the client to show badges.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
staff = 1
|
staff = 1
|
||||||
partner = 2
|
partner = 2
|
||||||
hypesquad = 4
|
hypesquad = 4
|
||||||
|
|
@ -166,6 +168,7 @@ class UserFlags(Flags):
|
||||||
|
|
||||||
class MessageFlags(Flags):
|
class MessageFlags(Flags):
|
||||||
"""Message flags."""
|
"""Message flags."""
|
||||||
|
|
||||||
none = 0
|
none = 0
|
||||||
|
|
||||||
crossposted = 1 << 0
|
crossposted = 1 << 0
|
||||||
|
|
@ -175,11 +178,12 @@ class MessageFlags(Flags):
|
||||||
|
|
||||||
class StatusType(EasyEnum):
|
class StatusType(EasyEnum):
|
||||||
"""All statuses there can be in a presence."""
|
"""All statuses there can be in a presence."""
|
||||||
ONLINE = 'online'
|
|
||||||
DND = 'dnd'
|
ONLINE = "online"
|
||||||
IDLE = 'idle'
|
DND = "dnd"
|
||||||
INVISIBLE = 'invisible'
|
IDLE = "idle"
|
||||||
OFFLINE = 'offline'
|
INVISIBLE = "invisible"
|
||||||
|
OFFLINE = "offline"
|
||||||
|
|
||||||
|
|
||||||
class ExplicitFilter(EasyEnum):
|
class ExplicitFilter(EasyEnum):
|
||||||
|
|
@ -187,6 +191,7 @@ class ExplicitFilter(EasyEnum):
|
||||||
|
|
||||||
Also applies to guilds.
|
Also applies to guilds.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
EDGE = 0
|
EDGE = 0
|
||||||
FRIENDS = 1
|
FRIENDS = 1
|
||||||
SAFE = 2
|
SAFE = 2
|
||||||
|
|
@ -194,6 +199,7 @@ class ExplicitFilter(EasyEnum):
|
||||||
|
|
||||||
class VerificationLevel(IntEnum):
|
class VerificationLevel(IntEnum):
|
||||||
"""Verification level for guilds."""
|
"""Verification level for guilds."""
|
||||||
|
|
||||||
NONE = 0
|
NONE = 0
|
||||||
LOW = 1
|
LOW = 1
|
||||||
MEDIUM = 2
|
MEDIUM = 2
|
||||||
|
|
@ -205,6 +211,7 @@ class VerificationLevel(IntEnum):
|
||||||
|
|
||||||
class RelationshipType(EasyEnum):
|
class RelationshipType(EasyEnum):
|
||||||
"""Relationship types between users."""
|
"""Relationship types between users."""
|
||||||
|
|
||||||
FRIEND = 1
|
FRIEND = 1
|
||||||
BLOCK = 2
|
BLOCK = 2
|
||||||
INCOMING = 3
|
INCOMING = 3
|
||||||
|
|
@ -213,6 +220,7 @@ class RelationshipType(EasyEnum):
|
||||||
|
|
||||||
class MessageNotifications(EasyEnum):
|
class MessageNotifications(EasyEnum):
|
||||||
"""Message notifications"""
|
"""Message notifications"""
|
||||||
|
|
||||||
ALL = 0
|
ALL = 0
|
||||||
MENTIONS = 1
|
MENTIONS = 1
|
||||||
NOTHING = 2
|
NOTHING = 2
|
||||||
|
|
@ -220,6 +228,7 @@ class MessageNotifications(EasyEnum):
|
||||||
|
|
||||||
class PremiumType:
|
class PremiumType:
|
||||||
"""Premium (Nitro) type."""
|
"""Premium (Nitro) type."""
|
||||||
|
|
||||||
TIER_1 = 1
|
TIER_1 = 1
|
||||||
TIER_2 = 2
|
TIER_2 = 2
|
||||||
NONE = None
|
NONE = None
|
||||||
|
|
@ -227,12 +236,13 @@ class PremiumType:
|
||||||
|
|
||||||
class Feature(EasyEnum):
|
class Feature(EasyEnum):
|
||||||
"""Guild features."""
|
"""Guild features."""
|
||||||
invite_splash = 'INVITE_SPLASH'
|
|
||||||
vip = 'VIP_REGIONS'
|
invite_splash = "INVITE_SPLASH"
|
||||||
vanity = 'VANITY_URL'
|
vip = "VIP_REGIONS"
|
||||||
emoji = 'MORE_EMOJI'
|
vanity = "VANITY_URL"
|
||||||
verified = 'VERIFIED'
|
emoji = "MORE_EMOJI"
|
||||||
|
verified = "VERIFIED"
|
||||||
|
|
||||||
# unknown
|
# unknown
|
||||||
commerce = 'COMMERCE'
|
commerce = "COMMERCE"
|
||||||
news = 'NEWS'
|
news = "NEWS"
|
||||||
|
|
|
||||||
|
|
@ -18,60 +18,64 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ERR_MSG_MAP = {
|
ERR_MSG_MAP = {
|
||||||
10001: 'Unknown account',
|
10001: "Unknown account",
|
||||||
10002: 'Unknown application',
|
10002: "Unknown application",
|
||||||
10003: 'Unknown channel',
|
10003: "Unknown channel",
|
||||||
10004: 'Unknown guild',
|
10004: "Unknown guild",
|
||||||
10005: 'Unknown integration',
|
10005: "Unknown integration",
|
||||||
10006: 'Unknown invite',
|
10006: "Unknown invite",
|
||||||
10007: 'Unknown member',
|
10007: "Unknown member",
|
||||||
10008: 'Unknown message',
|
10008: "Unknown message",
|
||||||
10009: 'Unknown overwrite',
|
10009: "Unknown overwrite",
|
||||||
10010: 'Unknown provider',
|
10010: "Unknown provider",
|
||||||
10011: 'Unknown role',
|
10011: "Unknown role",
|
||||||
10012: 'Unknown token',
|
10012: "Unknown token",
|
||||||
10013: 'Unknown user',
|
10013: "Unknown user",
|
||||||
10014: 'Unknown Emoji',
|
10014: "Unknown Emoji",
|
||||||
10015: 'Unknown Webhook',
|
10015: "Unknown Webhook",
|
||||||
20001: 'Bots cannot use this endpoint',
|
20001: "Bots cannot use this endpoint",
|
||||||
20002: 'Only bots can use this endpoint',
|
20002: "Only bots can use this endpoint",
|
||||||
30001: 'Maximum number of guilds reached (100)',
|
30001: "Maximum number of guilds reached (100)",
|
||||||
30002: 'Maximum number of friends reached (1000)',
|
30002: "Maximum number of friends reached (1000)",
|
||||||
30003: 'Maximum number of pins reached (50)',
|
30003: "Maximum number of pins reached (50)",
|
||||||
30005: 'Maximum number of guild roles reached (250)',
|
30005: "Maximum number of guild roles reached (250)",
|
||||||
30010: 'Maximum number of reactions reached (20)',
|
30010: "Maximum number of reactions reached (20)",
|
||||||
30013: 'Maximum number of guild channels reached (500)',
|
30013: "Maximum number of guild channels reached (500)",
|
||||||
40001: 'Unauthorized',
|
40001: "Unauthorized",
|
||||||
50001: 'Missing access',
|
50001: "Missing access",
|
||||||
50002: 'Invalid account type',
|
50002: "Invalid account type",
|
||||||
50003: 'Cannot execute action on a DM channel',
|
50003: "Cannot execute action on a DM channel",
|
||||||
50004: 'Widget Disabled',
|
50004: "Widget Disabled",
|
||||||
50005: 'Cannot edit a message authored by another user',
|
50005: "Cannot edit a message authored by another user",
|
||||||
50006: 'Cannot send an empty message',
|
50006: "Cannot send an empty message",
|
||||||
50007: 'Cannot send messages to this user',
|
50007: "Cannot send messages to this user",
|
||||||
50008: 'Cannot send messages in a voice channel',
|
50008: "Cannot send messages in a voice channel",
|
||||||
50009: 'Channel verification level is too high',
|
50009: "Channel verification level is too high",
|
||||||
50010: 'OAuth2 application does not have a bot',
|
50010: "OAuth2 application does not have a bot",
|
||||||
50011: 'OAuth2 application limit reached',
|
50011: "OAuth2 application limit reached",
|
||||||
50012: 'Invalid OAuth state',
|
50012: "Invalid OAuth state",
|
||||||
50013: 'Missing permissions',
|
50013: "Missing permissions",
|
||||||
50014: 'Invalid authentication token',
|
50014: "Invalid authentication token",
|
||||||
50015: 'Note is too long',
|
50015: "Note is too long",
|
||||||
50016: ('Provided too few or too many messages to delete. Must provide at '
|
50016: (
|
||||||
'least 2 and fewer than 100 messages to delete.'),
|
"Provided too few or too many messages to delete. Must provide at "
|
||||||
50019: 'A message can only be pinned to the channel it was sent in',
|
"least 2 and fewer than 100 messages to delete."
|
||||||
50020: 'Invite code is either invalid or taken.',
|
),
|
||||||
50021: 'Cannot execute action on a system message',
|
50019: "A message can only be pinned to the channel it was sent in",
|
||||||
50025: 'Invalid OAuth2 access token',
|
50020: "Invite code is either invalid or taken.",
|
||||||
50034: 'A message provided was too old to bulk delete',
|
50021: "Cannot execute action on a system message",
|
||||||
50035: 'Invalid Form Body',
|
50025: "Invalid OAuth2 access token",
|
||||||
50036: 'An invite was accepted to a guild the application\'s bot is not in',
|
50034: "A message provided was too old to bulk delete",
|
||||||
50041: 'Invalid API version',
|
50035: "Invalid Form Body",
|
||||||
90001: 'Reaction blocked',
|
50036: "An invite was accepted to a guild the application's bot is not in",
|
||||||
|
50041: "Invalid API version",
|
||||||
|
90001: "Reaction blocked",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class LitecordError(Exception):
|
class LitecordError(Exception):
|
||||||
"""Base class for litecord errors"""
|
"""Base class for litecord errors"""
|
||||||
|
|
||||||
status_code = 500
|
status_code = 500
|
||||||
|
|
||||||
def _get_err_msg(self, err_code: int) -> str:
|
def _get_err_msg(self, err_code: int) -> str:
|
||||||
|
|
@ -91,7 +95,7 @@ class LitecordError(Exception):
|
||||||
|
|
||||||
return message
|
return message
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return self._get_err_msg(getattr(self, 'error_code', None))
|
return self._get_err_msg(getattr(self, "error_code", None))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def json(self):
|
def json(self):
|
||||||
|
|
@ -143,7 +147,7 @@ class MissingPermissions(Forbidden):
|
||||||
class WebsocketClose(Exception):
|
class WebsocketClose(Exception):
|
||||||
@property
|
@property
|
||||||
def code(self):
|
def code(self):
|
||||||
from_class = getattr(self, 'close_code', None)
|
from_class = getattr(self, "close_code", None)
|
||||||
|
|
||||||
if from_class:
|
if from_class:
|
||||||
return from_class
|
return from_class
|
||||||
|
|
@ -152,7 +156,7 @@ class WebsocketClose(Exception):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reason(self):
|
def reason(self):
|
||||||
from_class = getattr(self, 'close_code', None)
|
from_class = getattr(self, "close_code", None)
|
||||||
|
|
||||||
if from_class:
|
if from_class:
|
||||||
return self.args[0]
|
return self.args[0]
|
||||||
|
|
|
||||||
|
|
@ -25,8 +25,7 @@ from litecord.utils import LitecordJSONEncoder
|
||||||
|
|
||||||
def encode_json(payload) -> str:
|
def encode_json(payload) -> str:
|
||||||
"""Encode a given payload to JSON."""
|
"""Encode a given payload to JSON."""
|
||||||
return json.dumps(payload, separators=(',', ':'),
|
return json.dumps(payload, separators=(",", ":"), cls=LitecordJSONEncoder)
|
||||||
cls=LitecordJSONEncoder)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_json(data: str):
|
def decode_json(data: str):
|
||||||
|
|
@ -71,6 +70,7 @@ def _etf_decode_dict(data):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def decode_etf(data: bytes):
|
def decode_etf(data: bytes):
|
||||||
"""Decode data in ETF to any."""
|
"""Decode data in ETF to any."""
|
||||||
res = earl.unpack(data)
|
res = earl.unpack(data)
|
||||||
|
|
|
||||||
|
|
@ -24,37 +24,36 @@ from litecord.gateway.websocket import GatewayWebsocket
|
||||||
async def websocket_handler(app, ws, url):
|
async def websocket_handler(app, ws, url):
|
||||||
"""Main websocket handler, checks query arguments when connecting to
|
"""Main websocket handler, checks query arguments when connecting to
|
||||||
the gateway and spawns a GatewayWebsocket instance for the connection."""
|
the gateway and spawns a GatewayWebsocket instance for the connection."""
|
||||||
args = urllib.parse.parse_qs(
|
args = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
|
||||||
urllib.parse.urlparse(url).query
|
|
||||||
)
|
|
||||||
|
|
||||||
# pull a dict.get but in a really bad way.
|
# pull a dict.get but in a really bad way.
|
||||||
try:
|
try:
|
||||||
gw_version = args['v'][0]
|
gw_version = args["v"][0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
gw_version = '6'
|
gw_version = "6"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
gw_encoding = args['encoding'][0]
|
gw_encoding = args["encoding"][0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
gw_encoding = 'json'
|
gw_encoding = "json"
|
||||||
|
|
||||||
if gw_version not in ('6', '7'):
|
if gw_version not in ("6", "7"):
|
||||||
return await ws.close(1000, 'Invalid gateway version')
|
return await ws.close(1000, "Invalid gateway version")
|
||||||
|
|
||||||
if gw_encoding not in ('json', 'etf'):
|
if gw_encoding not in ("json", "etf"):
|
||||||
return await ws.close(1000, 'Invalid gateway encoding')
|
return await ws.close(1000, "Invalid gateway encoding")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
gw_compress = args['compress'][0]
|
gw_compress = args["compress"][0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
gw_compress = None
|
gw_compress = None
|
||||||
|
|
||||||
if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'):
|
if gw_compress and gw_compress not in ("zlib-stream", "zstd-stream"):
|
||||||
return await ws.close(1000, 'Invalid gateway compress')
|
return await ws.close(1000, "Invalid gateway compress")
|
||||||
|
|
||||||
gws = GatewayWebsocket(
|
gws = GatewayWebsocket(
|
||||||
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress)
|
ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress
|
||||||
|
)
|
||||||
|
|
||||||
# this can be run with a single await since this whole coroutine
|
# this can be run with a single await since this whole coroutine
|
||||||
# is already running in the background.
|
# is already running in the background.
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class OP:
|
class OP:
|
||||||
"""Gateway OP codes."""
|
"""Gateway OP codes."""
|
||||||
|
|
||||||
DISPATCH = 0
|
DISPATCH = 0
|
||||||
HEARTBEAT = 1
|
HEARTBEAT = 1
|
||||||
IDENTIFY = 2
|
IDENTIFY = 2
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,7 @@ class PayloadStore:
|
||||||
This will only store a maximum of MAX_STORE_SIZE,
|
This will only store a maximum of MAX_STORE_SIZE,
|
||||||
dropping the older payloads when adding new ones.
|
dropping the older payloads when adding new ones.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MAX_STORE_SIZE = 250
|
MAX_STORE_SIZE = 250
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -60,20 +61,20 @@ class GatewayState:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.session_id = kwargs.get('session_id', gen_session_id())
|
self.session_id = kwargs.get("session_id", gen_session_id())
|
||||||
|
|
||||||
#: event sequence number
|
#: event sequence number
|
||||||
self.seq = kwargs.get('seq', 0)
|
self.seq = kwargs.get("seq", 0)
|
||||||
|
|
||||||
#: last seq sent by us, the backend
|
#: last seq sent by us, the backend
|
||||||
self.last_seq = 0
|
self.last_seq = 0
|
||||||
|
|
||||||
#: shard information about the state,
|
#: shard information about the state,
|
||||||
# its id and shard count
|
# its id and shard count
|
||||||
self.shard = kwargs.get('shard', [0, 1])
|
self.shard = kwargs.get("shard", [0, 1])
|
||||||
|
|
||||||
self.user_id = kwargs.get('user_id')
|
self.user_id = kwargs.get("user_id")
|
||||||
self.bot = kwargs.get('bot', False)
|
self.bot = kwargs.get("bot", False)
|
||||||
|
|
||||||
#: set by the gateway connection
|
#: set by the gateway connection
|
||||||
# on OP STATUS_UPDATE
|
# on OP STATUS_UPDATE
|
||||||
|
|
@ -90,5 +91,4 @@ class GatewayState:
|
||||||
self.__dict__[key] = value
|
self.__dict__[key] = value
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f'GatewayState<seq={self.seq} '
|
return f"GatewayState<seq={self.seq} " f"shard={self.shard} uid={self.user_id}>"
|
||||||
f'shard={self.shard} uid={self.user_id}>')
|
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ class ManagerClose(Exception):
|
||||||
class StateDictWrapper:
|
class StateDictWrapper:
|
||||||
"""Wrap a mapping so that any kind of access to the mapping while the
|
"""Wrap a mapping so that any kind of access to the mapping while the
|
||||||
state manager is closed raises a ManagerClose error"""
|
state manager is closed raises a ManagerClose error"""
|
||||||
|
|
||||||
def __init__(self, state_manager, mapping):
|
def __init__(self, state_manager, mapping):
|
||||||
self.state_manager = state_manager
|
self.state_manager = state_manager
|
||||||
self._map = mapping
|
self._map = mapping
|
||||||
|
|
@ -98,7 +99,7 @@ class StateManager:
|
||||||
"""Insert a new state object."""
|
"""Insert a new state object."""
|
||||||
user_states = self.states[state.user_id]
|
user_states = self.states[state.user_id]
|
||||||
|
|
||||||
log.debug('inserting state: {!r}', state)
|
log.debug("inserting state: {!r}", state)
|
||||||
user_states[state.session_id] = state
|
user_states[state.session_id] = state
|
||||||
self.states_raw[state.session_id] = state
|
self.states_raw[state.session_id] = state
|
||||||
|
|
||||||
|
|
@ -128,7 +129,7 @@ class StateManager:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.debug('removing state: {!r}', state)
|
log.debug("removing state: {!r}", state)
|
||||||
self.states[state.user_id].pop(state.session_id)
|
self.states[state.user_id].pop(state.session_id)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
@ -152,8 +153,7 @@ class StateManager:
|
||||||
"""Fetch all states tied to a single user."""
|
"""Fetch all states tied to a single user."""
|
||||||
return list(self.states[user_id].values())
|
return list(self.states[user_id].values())
|
||||||
|
|
||||||
def guild_states(self, member_ids: List[int],
|
def guild_states(self, member_ids: List[int], guild_id: int) -> List[GatewayState]:
|
||||||
guild_id: int) -> List[GatewayState]:
|
|
||||||
"""Fetch all possible states about members in a guild."""
|
"""Fetch all possible states about members in a guild."""
|
||||||
states = []
|
states = []
|
||||||
|
|
||||||
|
|
@ -164,14 +164,14 @@ class StateManager:
|
||||||
# since server start, so we need to add a dummy state
|
# since server start, so we need to add a dummy state
|
||||||
if not member_states:
|
if not member_states:
|
||||||
dummy_state = GatewayState(
|
dummy_state = GatewayState(
|
||||||
session_id='',
|
session_id="",
|
||||||
user_id=member_id,
|
user_id=member_id,
|
||||||
presence={
|
presence={
|
||||||
'afk': False,
|
"afk": False,
|
||||||
'status': 'offline',
|
"status": "offline",
|
||||||
'game': None,
|
"game": None,
|
||||||
'since': 0
|
"since": 0,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
states.append(dummy_state)
|
states.append(dummy_state)
|
||||||
|
|
@ -187,9 +187,7 @@ class StateManager:
|
||||||
"""Send OP Reconnect to a single connection."""
|
"""Send OP Reconnect to a single connection."""
|
||||||
websocket = state.ws
|
websocket = state.ws
|
||||||
|
|
||||||
await websocket.send({
|
await websocket.send({"op": OP.RECONNECT})
|
||||||
'op': OP.RECONNECT
|
|
||||||
})
|
|
||||||
|
|
||||||
# wait 200ms
|
# wait 200ms
|
||||||
# so that the client has time to process
|
# so that the client has time to process
|
||||||
|
|
@ -198,12 +196,9 @@ class StateManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# try to close the connection ourselves
|
# try to close the connection ourselves
|
||||||
await websocket.ws.close(
|
await websocket.ws.close(code=4000, reason="litecord shutting down")
|
||||||
code=4000,
|
|
||||||
reason='litecord shutting down'
|
|
||||||
)
|
|
||||||
except ConnectionClosed:
|
except ConnectionClosed:
|
||||||
log.info('client {} already closed', state)
|
log.info("client {} already closed", state)
|
||||||
|
|
||||||
def gen_close_tasks(self):
|
def gen_close_tasks(self):
|
||||||
"""Generate the tasks that will order the clients
|
"""Generate the tasks that will order the clients
|
||||||
|
|
@ -222,11 +217,9 @@ class StateManager:
|
||||||
if not state.ws:
|
if not state.ws:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tasks.append(
|
tasks.append(self.shutdown_single(state))
|
||||||
self.shutdown_single(state)
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info('made {} shutdown tasks', len(tasks))
|
log.info("made {} shutdown tasks", len(tasks))
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,11 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class WebsocketFileHandler:
|
class WebsocketFileHandler:
|
||||||
"""A handler around a websocket that wraps normal I/O calls into
|
"""A handler around a websocket that wraps normal I/O calls into
|
||||||
the websocket's respective asyncio calls via asyncio.ensure_future."""
|
the websocket's respective asyncio calls via asyncio.ensure_future."""
|
||||||
|
|
||||||
def __init__(self, ws):
|
def __init__(self, ws):
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,23 +31,20 @@ from logbook import Logger
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType, ChannelType
|
from litecord.enums import RelationshipType, ChannelType
|
||||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||||
from litecord.utils import (
|
from litecord.utils import task_wrapper, yield_chunks, maybe_int
|
||||||
task_wrapper, yield_chunks, maybe_int
|
|
||||||
)
|
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.state import GatewayState
|
from litecord.gateway.state import GatewayState
|
||||||
|
|
||||||
from litecord.errors import (
|
from litecord.errors import WebsocketClose, Unauthorized, Forbidden, BadRequest
|
||||||
WebsocketClose, Unauthorized, Forbidden, BadRequest
|
|
||||||
)
|
|
||||||
from litecord.gateway.errors import (
|
from litecord.gateway.errors import (
|
||||||
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
|
DecodeError,
|
||||||
)
|
UnknownOPCode,
|
||||||
from litecord.gateway.encoding import (
|
InvalidShard,
|
||||||
encode_json, decode_json, encode_etf, decode_etf
|
ShardingRequired,
|
||||||
)
|
)
|
||||||
|
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
||||||
|
|
||||||
from litecord.gateway.utils import WebsocketFileHandler
|
from litecord.gateway.utils import WebsocketFileHandler
|
||||||
|
|
||||||
|
|
@ -56,15 +53,22 @@ from litecord.storage import int_
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
WebsocketProperties = collections.namedtuple(
|
WebsocketProperties = collections.namedtuple(
|
||||||
'WebsocketProperties', 'v encoding compress zctx zsctx tasks'
|
"WebsocketProperties", "v encoding compress zctx zsctx tasks"
|
||||||
)
|
)
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
WebsocketObjects = collections.namedtuple(
|
||||||
'WebsocketObjects', (
|
"WebsocketObjects",
|
||||||
'db', 'state_manager', 'storage',
|
(
|
||||||
'loop', 'dispatcher', 'presence', 'ratelimiter',
|
"db",
|
||||||
'user_storage', 'voice'
|
"state_manager",
|
||||||
)
|
"storage",
|
||||||
|
"loop",
|
||||||
|
"dispatcher",
|
||||||
|
"presence",
|
||||||
|
"ratelimiter",
|
||||||
|
"user_storage",
|
||||||
|
"voice",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -73,9 +77,15 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
def __init__(self, ws, app, **kwargs):
|
def __init__(self, ws, app, **kwargs):
|
||||||
self.ext = WebsocketObjects(
|
self.ext = WebsocketObjects(
|
||||||
app.db, app.state_manager, app.storage, app.loop,
|
app.db,
|
||||||
app.dispatcher, app.presence, app.ratelimiter,
|
app.state_manager,
|
||||||
app.user_storage, app.voice
|
app.storage,
|
||||||
|
app.loop,
|
||||||
|
app.dispatcher,
|
||||||
|
app.presence,
|
||||||
|
app.ratelimiter,
|
||||||
|
app.user_storage,
|
||||||
|
app.voice,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.storage = self.ext.storage
|
self.storage = self.ext.storage
|
||||||
|
|
@ -84,15 +94,15 @@ class GatewayWebsocket:
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
||||||
self.wsp = WebsocketProperties(
|
self.wsp = WebsocketProperties(
|
||||||
kwargs.get('v'),
|
kwargs.get("v"),
|
||||||
kwargs.get('encoding', 'json'),
|
kwargs.get("encoding", "json"),
|
||||||
kwargs.get('compress', None),
|
kwargs.get("compress", None),
|
||||||
zlib.compressobj(),
|
zlib.compressobj(),
|
||||||
zstd.ZstdCompressor(),
|
zstd.ZstdCompressor(),
|
||||||
{}
|
{},
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug('websocket properties: {!r}', self.wsp)
|
log.debug("websocket properties: {!r}", self.wsp)
|
||||||
|
|
||||||
self.state = None
|
self.state = None
|
||||||
|
|
||||||
|
|
@ -102,8 +112,8 @@ class GatewayWebsocket:
|
||||||
encoding = self.wsp.encoding
|
encoding = self.wsp.encoding
|
||||||
|
|
||||||
encodings = {
|
encodings = {
|
||||||
'json': (encode_json, decode_json),
|
"json": (encode_json, decode_json),
|
||||||
'etf': (encode_etf, decode_etf),
|
"etf": (encode_etf, decode_etf),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.encoder, self.decoder = encodings[encoding]
|
self.encoder, self.decoder = encodings[encoding]
|
||||||
|
|
@ -111,16 +121,17 @@ class GatewayWebsocket:
|
||||||
async def _chunked_send(self, data: bytes, chunk_size: int):
|
async def _chunked_send(self, data: bytes, chunk_size: int):
|
||||||
"""Split data in chunk_size-big chunks and send them
|
"""Split data in chunk_size-big chunks and send them
|
||||||
over the websocket."""
|
over the websocket."""
|
||||||
log.debug('zlib-stream: chunking {} bytes into {}-byte chunks',
|
log.debug(
|
||||||
len(data), chunk_size)
|
"zlib-stream: chunking {} bytes into {}-byte chunks", len(data), chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
total_chunks = 0
|
total_chunks = 0
|
||||||
for chunk in yield_chunks(data, chunk_size):
|
for chunk in yield_chunks(data, chunk_size):
|
||||||
total_chunks += 1
|
total_chunks += 1
|
||||||
log.debug('zlib-stream: chunk {}', total_chunks)
|
log.debug("zlib-stream: chunk {}", total_chunks)
|
||||||
await self.ws.send(chunk)
|
await self.ws.send(chunk)
|
||||||
|
|
||||||
log.debug('zlib-stream: sent {} chunks', total_chunks)
|
log.debug("zlib-stream: sent {} chunks", total_chunks)
|
||||||
|
|
||||||
async def _zlib_stream_send(self, encoded):
|
async def _zlib_stream_send(self, encoded):
|
||||||
"""Sending a single payload across multiple compressed
|
"""Sending a single payload across multiple compressed
|
||||||
|
|
@ -130,8 +141,12 @@ class GatewayWebsocket:
|
||||||
data1 = self.wsp.zctx.compress(encoded)
|
data1 = self.wsp.zctx.compress(encoded)
|
||||||
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
|
||||||
|
|
||||||
log.debug('zlib-stream: length {} -> compressed ({} + {})',
|
log.debug(
|
||||||
len(encoded), len(data1), len(data2))
|
"zlib-stream: length {} -> compressed ({} + {})",
|
||||||
|
len(encoded),
|
||||||
|
len(data1),
|
||||||
|
len(data2),
|
||||||
|
)
|
||||||
|
|
||||||
if not data1:
|
if not data1:
|
||||||
# if data1 is nothing, that might cause problems
|
# if data1 is nothing, that might cause problems
|
||||||
|
|
@ -139,8 +154,11 @@ class GatewayWebsocket:
|
||||||
data1 = bytes([data2[0]])
|
data1 = bytes([data2[0]])
|
||||||
data2 = data2[1:]
|
data2 = data2[1:]
|
||||||
|
|
||||||
log.debug('zlib-stream: len(data1) == 0, remaking as ({} + {})',
|
log.debug(
|
||||||
len(data1), len(data2))
|
"zlib-stream: len(data1) == 0, remaking as ({} + {})",
|
||||||
|
len(data1),
|
||||||
|
len(data2),
|
||||||
|
)
|
||||||
|
|
||||||
# NOTE: the old approach was ws.send(data1 + data2).
|
# NOTE: the old approach was ws.send(data1 + data2).
|
||||||
# I changed this to a chunked send of data1 and data2
|
# I changed this to a chunked send of data1 and data2
|
||||||
|
|
@ -157,8 +175,7 @@ class GatewayWebsocket:
|
||||||
await self._chunked_send(data2, 1024)
|
await self._chunked_send(data2, 1024)
|
||||||
|
|
||||||
async def _zstd_stream_send(self, encoded):
|
async def _zstd_stream_send(self, encoded):
|
||||||
compressor = self.wsp.zsctx.stream_writer(
|
compressor = self.wsp.zsctx.stream_writer(WebsocketFileHandler(self.ws))
|
||||||
WebsocketFileHandler(self.ws))
|
|
||||||
|
|
||||||
compressor.write(encoded)
|
compressor.write(encoded)
|
||||||
compressor.flush(zstd.FLUSH_FRAME)
|
compressor.flush(zstd.FLUSH_FRAME)
|
||||||
|
|
@ -172,21 +189,23 @@ class GatewayWebsocket:
|
||||||
encoded = self.encoder(payload)
|
encoded = self.encoder(payload)
|
||||||
|
|
||||||
if len(encoded) < 2048:
|
if len(encoded) < 2048:
|
||||||
log.debug('sending\n{}', pprint.pformat(payload))
|
log.debug("sending\n{}", pprint.pformat(payload))
|
||||||
else:
|
else:
|
||||||
log.debug('sending {}', pprint.pformat(payload))
|
log.debug("sending {}", pprint.pformat(payload))
|
||||||
log.debug('sending op={} s={} t={} (too big)',
|
log.debug(
|
||||||
payload.get('op'),
|
"sending op={} s={} t={} (too big)",
|
||||||
payload.get('s'),
|
payload.get("op"),
|
||||||
payload.get('t'))
|
payload.get("s"),
|
||||||
|
payload.get("t"),
|
||||||
|
)
|
||||||
|
|
||||||
# treat encoded as bytes
|
# treat encoded as bytes
|
||||||
if not isinstance(encoded, bytes):
|
if not isinstance(encoded, bytes):
|
||||||
encoded = encoded.encode()
|
encoded = encoded.encode()
|
||||||
|
|
||||||
if self.wsp.compress == 'zlib-stream':
|
if self.wsp.compress == "zlib-stream":
|
||||||
await self._zlib_stream_send(encoded)
|
await self._zlib_stream_send(encoded)
|
||||||
elif self.wsp.compress == 'zstd-stream':
|
elif self.wsp.compress == "zstd-stream":
|
||||||
await self._zstd_stream_send(encoded)
|
await self._zstd_stream_send(encoded)
|
||||||
elif self.state and self.state.compress and len(encoded) > 1024:
|
elif self.state and self.state.compress and len(encoded) > 1024:
|
||||||
# TODO: should we only compress on >1KB packets? or maybe we
|
# TODO: should we only compress on >1KB packets? or maybe we
|
||||||
|
|
@ -203,16 +222,10 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def send_op(self, op_code: int, data: Any):
|
async def send_op(self, op_code: int, data: Any):
|
||||||
"""Send a packet but just the OP code information is filled in."""
|
"""Send a packet but just the OP code information is filled in."""
|
||||||
await self.send({
|
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||||
'op': op_code,
|
|
||||||
'd': data,
|
|
||||||
|
|
||||||
't': None,
|
|
||||||
's': None
|
|
||||||
})
|
|
||||||
|
|
||||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
def _check_ratelimit(self, key: str, ratelimit_key):
|
||||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
ratelimit = self.ext.ratelimiter.get_ratelimit(f"_ws.{key}")
|
||||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
return bucket.update_rate_limit()
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
|
|
@ -221,19 +234,19 @@ class GatewayWebsocket:
|
||||||
# if the client heartbeats in time,
|
# if the client heartbeats in time,
|
||||||
# this task will be cancelled.
|
# this task will be cancelled.
|
||||||
await asyncio.sleep(interval / 1000)
|
await asyncio.sleep(interval / 1000)
|
||||||
await self.ws.close(4000, 'Heartbeat expired')
|
await self.ws.close(4000, "Heartbeat expired")
|
||||||
|
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _hb_start(self, interval: int):
|
def _hb_start(self, interval: int):
|
||||||
# always refresh the heartbeat task
|
# always refresh the heartbeat task
|
||||||
# when possible
|
# when possible
|
||||||
task = self.wsp.tasks.get('heartbeat')
|
task = self.wsp.tasks.get("heartbeat")
|
||||||
if task:
|
if task:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
self.wsp.tasks['heartbeat'] = self.ext.loop.create_task(
|
self.wsp.tasks["heartbeat"] = self.ext.loop.create_task(
|
||||||
task_wrapper('hb wait', self._hb_wait(interval))
|
task_wrapper("hb wait", self._hb_wait(interval))
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _send_hello(self):
|
async def _send_hello(self):
|
||||||
|
|
@ -241,12 +254,9 @@ class GatewayWebsocket:
|
||||||
# random heartbeat intervals
|
# random heartbeat intervals
|
||||||
interval = randint(40, 46) * 1000
|
interval = randint(40, 46) * 1000
|
||||||
|
|
||||||
await self.send_op(OP.HELLO, {
|
await self.send_op(
|
||||||
'heartbeat_interval': interval,
|
OP.HELLO, {"heartbeat_interval": interval, "_trace": ["lesbian-server"]}
|
||||||
'_trace': [
|
)
|
||||||
'lesbian-server'
|
|
||||||
],
|
|
||||||
})
|
|
||||||
|
|
||||||
self._hb_start(interval)
|
self._hb_start(interval)
|
||||||
|
|
||||||
|
|
@ -255,16 +265,15 @@ class GatewayWebsocket:
|
||||||
self.state.seq += 1
|
self.state.seq += 1
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
'op': OP.DISPATCH,
|
"op": OP.DISPATCH,
|
||||||
't': event.upper(),
|
"t": event.upper(),
|
||||||
's': self.state.seq,
|
"s": self.state.seq,
|
||||||
'd': data,
|
"d": data,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.state.store[self.state.seq] = payload
|
self.state.store[self.state.seq] = payload
|
||||||
|
|
||||||
log.debug('sending payload {!r} sid {}',
|
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
|
||||||
event.upper(), self.state.session_id)
|
|
||||||
|
|
||||||
await self.send(payload)
|
await self.send(payload)
|
||||||
|
|
||||||
|
|
@ -274,16 +283,14 @@ class GatewayWebsocket:
|
||||||
guild_ids = await self._guild_ids()
|
guild_ids = await self._guild_ids()
|
||||||
|
|
||||||
if self.state.bot:
|
if self.state.bot:
|
||||||
return [{
|
return [{"id": row, "unavailable": True} for row in guild_ids]
|
||||||
'id': row,
|
|
||||||
'unavailable': True,
|
|
||||||
} for row in guild_ids]
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
**await self.storage.get_guild(guild_id, user_id),
|
**await self.storage.get_guild(guild_id, user_id),
|
||||||
**await self.storage.get_guild_extra(guild_id, user_id,
|
**await self.storage.get_guild_extra(
|
||||||
self.state.large)
|
guild_id, user_id, self.state.large
|
||||||
|
),
|
||||||
}
|
}
|
||||||
for guild_id in guild_ids
|
for guild_id in guild_ids
|
||||||
]
|
]
|
||||||
|
|
@ -298,13 +305,13 @@ class GatewayWebsocket:
|
||||||
for guild_obj in unavailable_guilds:
|
for guild_obj in unavailable_guilds:
|
||||||
# fetch full guild object including the 'large' field
|
# fetch full guild object including the 'large' field
|
||||||
guild = await self.storage.get_guild_full(
|
guild = await self.storage.get_guild_full(
|
||||||
int(guild_obj['id']), self.state.user_id, self.state.large
|
int(guild_obj["id"]), self.state.user_id, self.state.large
|
||||||
)
|
)
|
||||||
|
|
||||||
if guild is None:
|
if guild is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self.dispatch('GUILD_CREATE', guild)
|
await self.dispatch("GUILD_CREATE", guild)
|
||||||
|
|
||||||
async def _user_ready(self) -> dict:
|
async def _user_ready(self) -> dict:
|
||||||
"""Fetch information about users in the READY packet.
|
"""Fetch information about users in the READY packet.
|
||||||
|
|
@ -317,28 +324,28 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
relationships = await self.user_storage.get_relationships(user_id)
|
relationships = await self.user_storage.get_relationships(user_id)
|
||||||
|
|
||||||
friend_ids = [int(r['user']['id']) for r in relationships
|
friend_ids = [
|
||||||
if r['type'] == RelationshipType.FRIEND.value]
|
int(r["user"]["id"])
|
||||||
|
for r in relationships
|
||||||
|
if r["type"] == RelationshipType.FRIEND.value
|
||||||
|
]
|
||||||
|
|
||||||
friend_presences = await self.ext.presence.friend_presences(friend_ids)
|
friend_presences = await self.ext.presence.friend_presences(friend_ids)
|
||||||
settings = await self.user_storage.get_user_settings(user_id)
|
settings = await self.user_storage.get_user_settings(user_id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'user_settings': settings,
|
"user_settings": settings,
|
||||||
'notes': await self.user_storage.fetch_notes(user_id),
|
"notes": await self.user_storage.fetch_notes(user_id),
|
||||||
'relationships': relationships,
|
"relationships": relationships,
|
||||||
'presences': friend_presences,
|
"presences": friend_presences,
|
||||||
'read_state': await self.user_storage.get_read_state(user_id),
|
"read_state": await self.user_storage.get_read_state(user_id),
|
||||||
'user_guild_settings': await self.user_storage.get_guild_settings(
|
"user_guild_settings": await self.user_storage.get_guild_settings(user_id),
|
||||||
user_id),
|
"friend_suggestion_count": 0,
|
||||||
|
|
||||||
'friend_suggestion_count': 0,
|
|
||||||
|
|
||||||
# those are unused default values.
|
# those are unused default values.
|
||||||
'connected_accounts': [],
|
"connected_accounts": [],
|
||||||
'experiments': [],
|
"experiments": [],
|
||||||
'guild_experiments': [],
|
"guild_experiments": [],
|
||||||
'analytics_token': 'transbian',
|
"analytics_token": "transbian",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def dispatch_ready(self):
|
async def dispatch_ready(self):
|
||||||
|
|
@ -353,24 +360,21 @@ class GatewayWebsocket:
|
||||||
# user, fetch info
|
# user, fetch info
|
||||||
user_ready = await self._user_ready()
|
user_ready = await self._user_ready()
|
||||||
|
|
||||||
private_channels = (
|
private_channels = await self.user_storage.get_dms(
|
||||||
await self.user_storage.get_dms(user_id) +
|
user_id
|
||||||
await self.user_storage.get_gdms(user_id)
|
) + await self.user_storage.get_gdms(user_id)
|
||||||
)
|
|
||||||
|
|
||||||
base_ready = {
|
base_ready = {
|
||||||
'v': 6,
|
"v": 6,
|
||||||
'user': user,
|
"user": user,
|
||||||
|
"private_channels": private_channels,
|
||||||
'private_channels': private_channels,
|
"guilds": guilds,
|
||||||
|
"session_id": self.state.session_id,
|
||||||
'guilds': guilds,
|
"_trace": ["transbian"],
|
||||||
'session_id': self.state.session_id,
|
"shard": self.state.shard,
|
||||||
'_trace': ['transbian'],
|
|
||||||
'shard': self.state.shard,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
await self.dispatch('READY', {**base_ready, **user_ready})
|
await self.dispatch("READY", {**base_ready, **user_ready})
|
||||||
|
|
||||||
# async dispatch of guilds
|
# async dispatch of guilds
|
||||||
self.ext.loop.create_task(self._guild_dispatch(guilds))
|
self.ext.loop.create_task(self._guild_dispatch(guilds))
|
||||||
|
|
@ -380,33 +384,32 @@ class GatewayWebsocket:
|
||||||
"""
|
"""
|
||||||
current_shard, shard_count = shard
|
current_shard, shard_count = shard
|
||||||
|
|
||||||
guilds = await self.ext.db.fetchval("""
|
guilds = await self.ext.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM members
|
FROM members
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
recommended = max(int(guilds / 1200), 1)
|
recommended = max(int(guilds / 1200), 1)
|
||||||
|
|
||||||
if shard_count < recommended:
|
if shard_count < recommended:
|
||||||
raise ShardingRequired('Too many guilds for shard '
|
raise ShardingRequired("Too many guilds for shard " f"{current_shard}")
|
||||||
f'{current_shard}')
|
|
||||||
|
|
||||||
if guilds > 2500 and guilds / shard_count > 0.8:
|
if guilds > 2500 and guilds / shard_count > 0.8:
|
||||||
raise ShardingRequired('Too many shards. '
|
raise ShardingRequired("Too many shards. " f"(g={guilds} sc={shard_count})")
|
||||||
f'(g={guilds} sc={shard_count})')
|
|
||||||
|
|
||||||
if current_shard > shard_count:
|
if current_shard > shard_count:
|
||||||
raise InvalidShard('Shard count > Total shards')
|
raise InvalidShard("Shard count > Total shards")
|
||||||
|
|
||||||
async def _guild_ids(self) -> list:
|
async def _guild_ids(self) -> list:
|
||||||
"""Get a list of Guild IDs that are tied to this connection.
|
"""Get a list of Guild IDs that are tied to this connection.
|
||||||
|
|
||||||
The implementation is shard-aware.
|
The implementation is shard-aware.
|
||||||
"""
|
"""
|
||||||
guild_ids = await self.user_storage.get_user_guilds(
|
guild_ids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||||
self.state.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
shard_id = self.state.current_shard
|
shard_id = self.state.current_shard
|
||||||
shard_count = self.state.shard_count
|
shard_count = self.state.shard_count
|
||||||
|
|
@ -414,10 +417,7 @@ class GatewayWebsocket:
|
||||||
def _get_shard(guild_id):
|
def _get_shard(guild_id):
|
||||||
return (guild_id >> 22) % shard_count
|
return (guild_id >> 22) % shard_count
|
||||||
|
|
||||||
filtered = filter(
|
filtered = filter(lambda guild_id: _get_shard(guild_id) == shard_id, guild_ids)
|
||||||
lambda guild_id: _get_shard(guild_id) == shard_id,
|
|
||||||
guild_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
return list(filtered)
|
return list(filtered)
|
||||||
|
|
||||||
|
|
@ -432,13 +432,17 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
# subscribe the user to all dms they have OPENED.
|
# subscribe the user to all dms they have OPENED.
|
||||||
dms = await self.user_storage.get_dms(user_id)
|
dms = await self.user_storage.get_dms(user_id)
|
||||||
dm_ids = [int(dm['id']) for dm in dms]
|
dm_ids = [int(dm["id"]) for dm in dms]
|
||||||
|
|
||||||
# fetch all group dms the user is a member of.
|
# fetch all group dms the user is a member of.
|
||||||
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
||||||
|
|
||||||
log.info('subscribing to {} guilds {} dms {} gdms',
|
log.info(
|
||||||
len(guild_ids), len(dm_ids), len(gdm_ids))
|
"subscribing to {} guilds {} dms {} gdms",
|
||||||
|
len(guild_ids),
|
||||||
|
len(dm_ids),
|
||||||
|
len(gdm_ids),
|
||||||
|
)
|
||||||
|
|
||||||
# guild_subscriptions:
|
# guild_subscriptions:
|
||||||
# enables dispatching of guild subscription events
|
# enables dispatching of guild subscription events
|
||||||
|
|
@ -447,10 +451,13 @@ class GatewayWebsocket:
|
||||||
# we enable processing of guild_subscriptions by adding flags
|
# we enable processing of guild_subscriptions by adding flags
|
||||||
# when subscribing to the given backend. those are optional.
|
# when subscribing to the given backend. those are optional.
|
||||||
channels_to_sub = [
|
channels_to_sub = [
|
||||||
('guild', guild_ids,
|
(
|
||||||
{'presence': guild_subscriptions, 'typing': guild_subscriptions}),
|
"guild",
|
||||||
('channel', dm_ids),
|
guild_ids,
|
||||||
('channel', gdm_ids),
|
{"presence": guild_subscriptions, "typing": guild_subscriptions},
|
||||||
|
),
|
||||||
|
("channel", dm_ids),
|
||||||
|
("channel", gdm_ids),
|
||||||
]
|
]
|
||||||
|
|
||||||
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||||
|
|
@ -460,28 +467,26 @@ class GatewayWebsocket:
|
||||||
# (their friends will also subscribe back
|
# (their friends will also subscribe back
|
||||||
# when they come online)
|
# when they come online)
|
||||||
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
friend_ids = await self.user_storage.get_friend_ids(user_id)
|
||||||
log.info('subscribing to {} friends', len(friend_ids))
|
log.info("subscribing to {} friends", len(friend_ids))
|
||||||
await self.ext.dispatcher.sub_many('friend', user_id, friend_ids)
|
await self.ext.dispatcher.sub_many("friend", user_id, friend_ids)
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
if not self.state:
|
if not self.state:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._check_ratelimit('presence', self.state.session_id):
|
if self._check_ratelimit("presence", self.state.session_id):
|
||||||
# Presence Updates beyond the ratelimit
|
# Presence Updates beyond the ratelimit
|
||||||
# are just silently dropped.
|
# are just silently dropped.
|
||||||
return
|
return
|
||||||
|
|
||||||
default_status = {
|
default_status = {
|
||||||
'afk': False,
|
"afk": False,
|
||||||
|
|
||||||
# TODO: fetch status from settings
|
# TODO: fetch status from settings
|
||||||
'status': 'online',
|
"status": "online",
|
||||||
'game': None,
|
"game": None,
|
||||||
|
|
||||||
# TODO: this
|
# TODO: this
|
||||||
'since': 0,
|
"since": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
status = {**(status or {}), **default_status}
|
status = {**(status or {}), **default_status}
|
||||||
|
|
@ -489,39 +494,40 @@ class GatewayWebsocket:
|
||||||
try:
|
try:
|
||||||
status = validate(status, GW_STATUS_UPDATE)
|
status = validate(status, GW_STATUS_UPDATE)
|
||||||
except BadRequest as err:
|
except BadRequest as err:
|
||||||
log.warning(f'Invalid status update: {err}')
|
log.warning(f"Invalid status update: {err}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# try to extract game from activities
|
# try to extract game from activities
|
||||||
# when game not provided
|
# when game not provided
|
||||||
if not status.get('game'):
|
if not status.get("game"):
|
||||||
try:
|
try:
|
||||||
game = status['activities'][0]
|
game = status["activities"][0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
game = None
|
game = None
|
||||||
else:
|
else:
|
||||||
game = status['game']
|
game = status["game"]
|
||||||
|
|
||||||
# construct final status
|
# construct final status
|
||||||
status = {
|
status = {
|
||||||
'afk': status.get('afk', False),
|
"afk": status.get("afk", False),
|
||||||
'status': status.get('status', 'online'),
|
"status": status.get("status", "online"),
|
||||||
'game': game,
|
"game": game,
|
||||||
'since': status.get('since', 0),
|
"since": status.get("since", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.state.presence = status
|
self.state.presence = status
|
||||||
log.info(f'Updating presence status={status["status"]} for '
|
log.info(
|
||||||
f'uid={self.state.user_id}')
|
f'Updating presence status={status["status"]} for '
|
||||||
await self.ext.presence.dispatch_pres(self.state.user_id,
|
f"uid={self.state.user_id}"
|
||||||
self.state.presence)
|
)
|
||||||
|
await self.ext.presence.dispatch_pres(self.state.user_id, self.state.presence)
|
||||||
|
|
||||||
async def handle_1(self, payload: Dict[str, Any]):
|
async def handle_1(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 1 Heartbeat packets."""
|
"""Handle OP 1 Heartbeat packets."""
|
||||||
# give the client 3 more seconds before we
|
# give the client 3 more seconds before we
|
||||||
# close the websocket
|
# close the websocket
|
||||||
self._hb_start((46 + 3) * 1000)
|
self._hb_start((46 + 3) * 1000)
|
||||||
cliseq = payload.get('d')
|
cliseq = payload.get("d")
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
self.state.last_seq = cliseq
|
self.state.last_seq = cliseq
|
||||||
|
|
@ -529,39 +535,42 @@ class GatewayWebsocket:
|
||||||
await self.send_op(OP.HEARTBEAT_ACK, None)
|
await self.send_op(OP.HEARTBEAT_ACK, None)
|
||||||
|
|
||||||
async def _connect_ratelimit(self, user_id: int):
|
async def _connect_ratelimit(self, user_id: int):
|
||||||
if self._check_ratelimit('connect', user_id):
|
if self._check_ratelimit("connect", user_id):
|
||||||
await self.invalidate_session(False)
|
await self.invalidate_session(False)
|
||||||
raise WebsocketClose(4009, 'You are being ratelimited.')
|
raise WebsocketClose(4009, "You are being ratelimited.")
|
||||||
|
|
||||||
if self._check_ratelimit('session', user_id):
|
if self._check_ratelimit("session", user_id):
|
||||||
await self.invalidate_session(False)
|
await self.invalidate_session(False)
|
||||||
raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.')
|
raise WebsocketClose(4004, "Websocket Session Ratelimit reached.")
|
||||||
|
|
||||||
async def handle_2(self, payload: Dict[str, Any]):
|
async def handle_2(self, payload: Dict[str, Any]):
|
||||||
"""Handle the OP 2 Identify packet."""
|
"""Handle the OP 2 Identify packet."""
|
||||||
try:
|
try:
|
||||||
data = payload['d']
|
data = payload["d"]
|
||||||
token = data['token']
|
token = data["token"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise DecodeError('Invalid identify parameters')
|
raise DecodeError("Invalid identify parameters")
|
||||||
|
|
||||||
compress = data.get('compress', False)
|
compress = data.get("compress", False)
|
||||||
large = data.get('large_threshold', 50)
|
large = data.get("large_threshold", 50)
|
||||||
|
|
||||||
shard = data.get('shard', [0, 1])
|
shard = data.get("shard", [0, 1])
|
||||||
presence = data.get('presence')
|
presence = data.get("presence")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.ext.db)
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, 'Authentication failed')
|
raise WebsocketClose(4004, "Authentication failed")
|
||||||
|
|
||||||
await self._connect_ratelimit(user_id)
|
await self._connect_ratelimit(user_id)
|
||||||
|
|
||||||
bot = await self.ext.db.fetchval("""
|
bot = await self.ext.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT bot FROM users
|
SELECT bot FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
await self._check_shards(shard, user_id)
|
await self._check_shards(shard, user_id)
|
||||||
|
|
||||||
|
|
@ -574,19 +583,19 @@ class GatewayWebsocket:
|
||||||
shard=shard,
|
shard=shard,
|
||||||
current_shard=shard[0],
|
current_shard=shard[0],
|
||||||
shard_count=shard[1],
|
shard_count=shard[1],
|
||||||
ws=self
|
ws=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
# link the state to the user
|
# link the state to the user
|
||||||
self.ext.state_manager.insert(self.state)
|
self.ext.state_manager.insert(self.state)
|
||||||
|
|
||||||
await self.update_status(presence)
|
await self.update_status(presence)
|
||||||
await self.subscribe_all(data.get('guild_subscriptions', True))
|
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||||
await self.dispatch_ready()
|
await self.dispatch_ready()
|
||||||
|
|
||||||
async def handle_3(self, payload: Dict[str, Any]):
|
async def handle_3(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 3 Status Update."""
|
"""Handle OP 3 Status Update."""
|
||||||
presence = payload['d']
|
presence = payload["d"]
|
||||||
|
|
||||||
# update_status will take care of validation and
|
# update_status will take care of validation and
|
||||||
# setting new presence to state
|
# setting new presence to state
|
||||||
|
|
@ -597,27 +606,27 @@ class GatewayWebsocket:
|
||||||
user settings."""
|
user settings."""
|
||||||
try:
|
try:
|
||||||
# TODO: fetch from settings if not provided
|
# TODO: fetch from settings if not provided
|
||||||
self_deaf = bool(data['self_deaf'])
|
self_deaf = bool(data["self_deaf"])
|
||||||
self_mute = bool(data['self_mute'])
|
self_mute = bool(data["self_mute"])
|
||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'deaf': state.deaf,
|
"deaf": state.deaf,
|
||||||
'mute': state.mute,
|
"mute": state.mute,
|
||||||
'self_deaf': self_deaf,
|
"self_deaf": self_deaf,
|
||||||
'self_mute': self_mute,
|
"self_mute": self_mute,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def handle_4(self, payload: Dict[str, Any]):
|
async def handle_4(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 4 Voice Status Update."""
|
"""Handle OP 4 Voice Status Update."""
|
||||||
data = payload['d']
|
data = payload["d"]
|
||||||
|
|
||||||
if not self.state:
|
if not self.state:
|
||||||
return
|
return
|
||||||
|
|
||||||
channel_id = int_(data.get('channel_id'))
|
channel_id = int_(data.get("channel_id"))
|
||||||
guild_id = int_(data.get('guild_id'))
|
guild_id = int_(data.get("guild_id"))
|
||||||
|
|
||||||
# if its null and null, disconnect the user from any voice
|
# if its null and null, disconnect the user from any voice
|
||||||
# TODO: maybe just leave from DMs? idk...
|
# TODO: maybe just leave from DMs? idk...
|
||||||
|
|
@ -630,9 +639,7 @@ class GatewayWebsocket:
|
||||||
return await self.ext.voice.leave(guild_id, self.state.user_id)
|
return await self.ext.voice.leave(guild_id, self.state.user_id)
|
||||||
|
|
||||||
# fetch an existing state given user and guild OR user and channel
|
# fetch an existing state given user and guild OR user and channel
|
||||||
chan_type = ChannelType(
|
chan_type = ChannelType(await self.storage.get_chan_type(channel_id))
|
||||||
await self.storage.get_chan_type(channel_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
state_id2 = channel_id
|
state_id2 = channel_id
|
||||||
|
|
||||||
|
|
@ -704,39 +711,38 @@ class GatewayWebsocket:
|
||||||
# ignore unknown seqs
|
# ignore unknown seqs
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload_t = payload.get('t')
|
payload_t = payload.get("t")
|
||||||
|
|
||||||
# presence resumption happens
|
# presence resumption happens
|
||||||
# on a separate event, PRESENCE_REPLACE.
|
# on a separate event, PRESENCE_REPLACE.
|
||||||
if payload_t == 'PRESENCE_UPDATE':
|
if payload_t == "PRESENCE_UPDATE":
|
||||||
presences.append(payload.get('d'))
|
presences.append(payload.get("d"))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self.send(payload)
|
await self.send(payload)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('error while resuming')
|
log.exception("error while resuming")
|
||||||
await self.invalidate_session(False)
|
await self.invalidate_session(False)
|
||||||
return
|
return
|
||||||
|
|
||||||
if presences:
|
if presences:
|
||||||
await self.dispatch('PRESENCE_REPLACE', presences)
|
await self.dispatch("PRESENCE_REPLACE", presences)
|
||||||
|
|
||||||
await self.dispatch('RESUMED', {})
|
await self.dispatch("RESUMED", {})
|
||||||
|
|
||||||
async def handle_6(self, payload: Dict[str, Any]):
|
async def handle_6(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 6 Resume."""
|
"""Handle OP 6 Resume."""
|
||||||
data = payload['d']
|
data = payload["d"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
token, sess_id, seq = data['token'], \
|
token, sess_id, seq = data["token"], data["session_id"], data["seq"]
|
||||||
data['session_id'], data['seq']
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise DecodeError('Invalid resume payload')
|
raise DecodeError("Invalid resume payload")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_id = await raw_token_check(token, self.ext.db)
|
user_id = await raw_token_check(token, self.ext.db)
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, 'Invalid token')
|
raise WebsocketClose(4004, "Invalid token")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state = self.ext.state_manager.fetch(user_id, sess_id)
|
state = self.ext.state_manager.fetch(user_id, sess_id)
|
||||||
|
|
@ -744,11 +750,11 @@ class GatewayWebsocket:
|
||||||
return await self.invalidate_session(False)
|
return await self.invalidate_session(False)
|
||||||
|
|
||||||
if seq > state.seq:
|
if seq > state.seq:
|
||||||
raise WebsocketClose(4007, 'Invalid seq')
|
raise WebsocketClose(4007, "Invalid seq")
|
||||||
|
|
||||||
# check if a websocket isnt on that state already
|
# check if a websocket isnt on that state already
|
||||||
if state.ws is not None:
|
if state.ws is not None:
|
||||||
log.info('Resuming failed, websocket already connected')
|
log.info("Resuming failed, websocket already connected")
|
||||||
return await self.invalidate_session(False)
|
return await self.invalidate_session(False)
|
||||||
|
|
||||||
# relink this connection
|
# relink this connection
|
||||||
|
|
@ -757,8 +763,9 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
await self._resume(range(seq, state.seq))
|
await self._resume(range(seq, state.seq))
|
||||||
|
|
||||||
async def _req_guild_members(self, guild_id, user_ids: List[int],
|
async def _req_guild_members(
|
||||||
query: str, limit: int):
|
self, guild_id, user_ids: List[int], query: str, limit: int
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
guild_id = int(guild_id)
|
guild_id = int(guild_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
|
|
@ -778,32 +785,32 @@ class GatewayWebsocket:
|
||||||
# ASSUMPTION: requesting user_ids means we don't do query.
|
# ASSUMPTION: requesting user_ids means we don't do query.
|
||||||
if user_ids:
|
if user_ids:
|
||||||
members = await self.storage.get_member_multi(guild_id, user_ids)
|
members = await self.storage.get_member_multi(guild_id, user_ids)
|
||||||
mids = [m['user']['id'] for m in members]
|
mids = [m["user"]["id"] for m in members]
|
||||||
not_found = [uid for uid in user_ids if uid not in mids]
|
not_found = [uid for uid in user_ids if uid not in mids]
|
||||||
|
|
||||||
await self.dispatch('GUILD_MEMBERS_CHUNK', {
|
await self.dispatch(
|
||||||
'guild_id': str(guild_id),
|
"GUILD_MEMBERS_CHUNK",
|
||||||
'members': members,
|
{"guild_id": str(guild_id), "members": members, "not_found": not_found},
|
||||||
'not_found': not_found,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# do the search
|
# do the search
|
||||||
result = await self.storage.query_members(guild_id, query, limit)
|
result = await self.storage.query_members(guild_id, query, limit)
|
||||||
await self.dispatch('GUILD_MEMBERS_CHUNK', {
|
await self.dispatch(
|
||||||
'guild_id': str(guild_id),
|
"GUILD_MEMBERS_CHUNK", {"guild_id": str(guild_id), "members": result}
|
||||||
'members': result
|
)
|
||||||
})
|
|
||||||
|
|
||||||
async def handle_8(self, payload: Dict):
|
async def handle_8(self, payload: Dict):
|
||||||
"""Handle OP 8 Request Guild Members."""
|
"""Handle OP 8 Request Guild Members."""
|
||||||
data = payload['d']
|
data = payload["d"]
|
||||||
gids = data['guild_id']
|
gids = data["guild_id"]
|
||||||
|
|
||||||
uids, query, limit = data.get('user_ids', []), \
|
uids, query, limit = (
|
||||||
data.get('query', ''), \
|
data.get("user_ids", []),
|
||||||
data.get('limit', 0)
|
data.get("query", ""),
|
||||||
|
data.get("limit", 0),
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(gids, str):
|
if isinstance(gids, str):
|
||||||
await self._req_guild_members(gids, uids, query, limit)
|
await self._req_guild_members(gids, uids, query, limit)
|
||||||
|
|
@ -820,23 +827,21 @@ class GatewayWebsocket:
|
||||||
GUILD_SYNC event with that info.
|
GUILD_SYNC event with that info.
|
||||||
"""
|
"""
|
||||||
members = await self.storage.get_member_data(guild_id)
|
members = await self.storage.get_member_data(guild_id)
|
||||||
member_ids = [int(m['user']['id']) for m in members]
|
member_ids = [int(m["user"]["id"]) for m in members]
|
||||||
|
|
||||||
log.debug(f'Syncing guild {guild_id} with {len(member_ids)} members')
|
log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
|
||||||
presences = await self.presence.guild_presences(member_ids, guild_id)
|
presences = await self.presence.guild_presences(member_ids, guild_id)
|
||||||
|
|
||||||
await self.dispatch('GUILD_SYNC', {
|
await self.dispatch(
|
||||||
'id': str(guild_id),
|
"GUILD_SYNC",
|
||||||
'presences': presences,
|
{"id": str(guild_id), "presences": presences, "members": members},
|
||||||
'members': members,
|
)
|
||||||
})
|
|
||||||
|
|
||||||
async def handle_12(self, payload: Dict[str, Any]):
|
async def handle_12(self, payload: Dict[str, Any]):
|
||||||
"""Handle OP 12 Guild Sync."""
|
"""Handle OP 12 Guild Sync."""
|
||||||
data = payload['d']
|
data = payload["d"]
|
||||||
|
|
||||||
gids = await self.user_storage.get_user_guilds(
|
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||||
self.state.user_id)
|
|
||||||
|
|
||||||
for guild_id in data:
|
for guild_id in data:
|
||||||
try:
|
try:
|
||||||
|
|
@ -931,35 +936,33 @@ class GatewayWebsocket:
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
data = payload['d']
|
data = payload["d"]
|
||||||
|
|
||||||
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
gids = await self.user_storage.get_user_guilds(self.state.user_id)
|
||||||
guild_id = int(data['guild_id'])
|
guild_id = int(data["guild_id"])
|
||||||
|
|
||||||
# make sure to not extract info you shouldn't get
|
# make sure to not extract info you shouldn't get
|
||||||
if guild_id not in gids:
|
if guild_id not in gids:
|
||||||
return
|
return
|
||||||
|
|
||||||
log.debug('lazy request: members: {}',
|
log.debug("lazy request: members: {}", data.get("members", []))
|
||||||
data.get('members', []))
|
|
||||||
|
|
||||||
# make shard query
|
# make shard query
|
||||||
lazy_guilds = self.ext.dispatcher.backends['lazy_guild']
|
lazy_guilds = self.ext.dispatcher.backends["lazy_guild"]
|
||||||
|
|
||||||
for chan_id, ranges in data.get('channels', {}).items():
|
for chan_id, ranges in data.get("channels", {}).items():
|
||||||
chan_id = int(chan_id)
|
chan_id = int(chan_id)
|
||||||
member_list = await lazy_guilds.get_gml(chan_id)
|
member_list = await lazy_guilds.get_gml(chan_id)
|
||||||
|
|
||||||
perms = await get_permissions(
|
perms = await get_permissions(
|
||||||
self.state.user_id, chan_id, storage=self.storage)
|
self.state.user_id, chan_id, storage=self.storage
|
||||||
|
)
|
||||||
|
|
||||||
if not perms.bits.read_messages:
|
if not perms.bits.read_messages:
|
||||||
# ignore requests to unknown channels
|
# ignore requests to unknown channels
|
||||||
return
|
return
|
||||||
|
|
||||||
await member_list.shard_query(
|
await member_list.shard_query(self.state.session_id, ranges)
|
||||||
self.state.session_id, ranges
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_23(self, payload):
|
async def _handle_23(self, payload):
|
||||||
# TODO reverse-engineer opcode 23, sent by client
|
# TODO reverse-engineer opcode 23, sent by client
|
||||||
|
|
@ -968,21 +971,21 @@ class GatewayWebsocket:
|
||||||
async def _process_message(self, payload):
|
async def _process_message(self, payload):
|
||||||
"""Process a single message coming in from the client."""
|
"""Process a single message coming in from the client."""
|
||||||
try:
|
try:
|
||||||
op_code = payload['op']
|
op_code = payload["op"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise UnknownOPCode('No OP code')
|
raise UnknownOPCode("No OP code")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handler = getattr(self, f'handle_{op_code}')
|
handler = getattr(self, f"handle_{op_code}")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
log.warning('Payload with bad op: {}', pprint.pformat(payload))
|
log.warning("Payload with bad op: {}", pprint.pformat(payload))
|
||||||
raise UnknownOPCode(f'Bad OP code: {op_code}')
|
raise UnknownOPCode(f"Bad OP code: {op_code}")
|
||||||
|
|
||||||
await handler(payload)
|
await handler(payload)
|
||||||
|
|
||||||
async def _msg_ratelimit(self):
|
async def _msg_ratelimit(self):
|
||||||
if self._check_ratelimit('messages', self.state.session_id):
|
if self._check_ratelimit("messages", self.state.session_id):
|
||||||
raise WebsocketClose(4008, 'You are being ratelimited.')
|
raise WebsocketClose(4008, "You are being ratelimited.")
|
||||||
|
|
||||||
async def _listen_messages(self):
|
async def _listen_messages(self):
|
||||||
"""Listen for messages coming in from the websocket."""
|
"""Listen for messages coming in from the websocket."""
|
||||||
|
|
@ -990,15 +993,15 @@ class GatewayWebsocket:
|
||||||
# close anyone trying to login while the
|
# close anyone trying to login while the
|
||||||
# server is shutting down
|
# server is shutting down
|
||||||
if self.ext.state_manager.closed:
|
if self.ext.state_manager.closed:
|
||||||
raise WebsocketClose(4000, 'state manager closed')
|
raise WebsocketClose(4000, "state manager closed")
|
||||||
|
|
||||||
if not self.ext.state_manager.accept_new:
|
if not self.ext.state_manager.accept_new:
|
||||||
raise WebsocketClose(4000, 'state manager closed for new')
|
raise WebsocketClose(4000, "state manager closed for new")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
message = await self.ws.recv()
|
message = await self.ws.recv()
|
||||||
if len(message) > 4096:
|
if len(message) > 4096:
|
||||||
raise DecodeError('Payload length exceeded')
|
raise DecodeError("Payload length exceeded")
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
await self._msg_ratelimit()
|
await self._msg_ratelimit()
|
||||||
|
|
@ -1033,17 +1036,9 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
# there arent any other states with websocket
|
# there arent any other states with websocket
|
||||||
if not with_ws:
|
if not with_ws:
|
||||||
offline = {
|
offline = {"afk": False, "status": "offline", "game": None, "since": 0}
|
||||||
'afk': False,
|
|
||||||
'status': 'offline',
|
|
||||||
'game': None,
|
|
||||||
'since': 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.ext.presence.dispatch_pres(
|
await self.ext.presence.dispatch_pres(user_id, offline)
|
||||||
user_id,
|
|
||||||
offline
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Wrap :meth:`listen_messages` inside
|
"""Wrap :meth:`listen_messages` inside
|
||||||
|
|
@ -1052,12 +1047,12 @@ class GatewayWebsocket:
|
||||||
await self._send_hello()
|
await self._send_hello()
|
||||||
await self._listen_messages()
|
await self._listen_messages()
|
||||||
except websockets.exceptions.ConnectionClosed as err:
|
except websockets.exceptions.ConnectionClosed as err:
|
||||||
log.warning('conn close, state={}, err={}', self.state, err)
|
log.warning("conn close, state={}, err={}", self.state, err)
|
||||||
except WebsocketClose as err:
|
except WebsocketClose as err:
|
||||||
log.warning('ws close, state={} err={}', self.state, err)
|
log.warning("ws close, state={} err={}", self.state, err)
|
||||||
await self.ws.close(code=err.code, reason=err.reason)
|
await self.ws.close(code=err.code, reason=err.reason)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
log.exception('An exception has occoured. state={}', self.state)
|
log.exception("An exception has occoured. state={}", self.state)
|
||||||
await self.ws.close(code=4000, reason=repr(err))
|
await self.ws.close(code=4000, reason=repr(err))
|
||||||
finally:
|
finally:
|
||||||
user_id = self.state.user_id if self.state else None
|
user_id = self.state.user_id if self.state else None
|
||||||
|
|
|
||||||
|
|
@ -17,19 +17,21 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GuildMemoryStore:
|
class GuildMemoryStore:
|
||||||
"""Store in-memory properties about guilds.
|
"""Store in-memory properties about guilds.
|
||||||
|
|
||||||
I could have just used Redis... probably too overkill to add
|
I could have just used Redis... probably too overkill to add
|
||||||
aioredis to the already long depedency list, plus, I don't need
|
aioredis to the already long depedency list, plus, I don't need
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._store = {}
|
self._store = {}
|
||||||
|
|
||||||
def get(self, guild_id: int, attribute: str, default=None):
|
def get(self, guild_id: int, attribute: str, default=None):
|
||||||
"""get a key"""
|
"""get a key"""
|
||||||
return self._store.get(f'{guild_id}:{attribute}', default)
|
return self._store.get(f"{guild_id}:{attribute}", default)
|
||||||
|
|
||||||
def set(self, guild_id: int, attribute: str, value):
|
def set(self, guild_id: int, attribute: str, value):
|
||||||
"""set a key"""
|
"""set a key"""
|
||||||
self._store[f'{guild_id}:{attribute}'] = value
|
self._store[f"{guild_id}:{attribute}"] = value
|
||||||
|
|
|
||||||
|
|
@ -33,47 +33,42 @@ from logbook import Logger
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
IMAGE_FOLDER = Path('./images')
|
IMAGE_FOLDER = Path("./images")
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
EXTENSIONS = {
|
EXTENSIONS = {"image/jpeg": "jpeg", "image/webp": "webp"}
|
||||||
'image/jpeg': 'jpeg',
|
|
||||||
'image/webp': 'webp'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MIMES = {
|
MIMES = {
|
||||||
'jpg': 'image/jpeg',
|
"jpg": "image/jpeg",
|
||||||
'jpe': 'image/jpeg',
|
"jpe": "image/jpeg",
|
||||||
'jpeg': 'image/jpeg',
|
"jpeg": "image/jpeg",
|
||||||
'webp': 'image/webp',
|
"webp": "image/webp",
|
||||||
}
|
}
|
||||||
|
|
||||||
STATIC_IMAGE_MIMES = [
|
STATIC_IMAGE_MIMES = ["image/png", "image/jpeg", "image/webp"]
|
||||||
'image/png',
|
|
||||||
'image/jpeg',
|
|
||||||
'image/webp'
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_ext(mime: str) -> str:
|
def get_ext(mime: str) -> str:
|
||||||
if mime in EXTENSIONS:
|
if mime in EXTENSIONS:
|
||||||
return EXTENSIONS[mime]
|
return EXTENSIONS[mime]
|
||||||
|
|
||||||
extensions = mimetypes.guess_all_extensions(mime)
|
extensions = mimetypes.guess_all_extensions(mime)
|
||||||
return extensions[0].strip('.')
|
return extensions[0].strip(".")
|
||||||
|
|
||||||
|
|
||||||
def get_mime(ext: str):
|
def get_mime(ext: str):
|
||||||
if ext in MIMES:
|
if ext in MIMES:
|
||||||
return MIMES[ext]
|
return MIMES[ext]
|
||||||
|
|
||||||
return mimetypes.types_map[f'.{ext}']
|
return mimetypes.types_map[f".{ext}"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Icon:
|
class Icon:
|
||||||
"""Main icon class"""
|
"""Main icon class"""
|
||||||
|
|
||||||
key: Optional[str]
|
key: Optional[str]
|
||||||
icon_hash: Optional[str]
|
icon_hash: Optional[str]
|
||||||
mime: Optional[str]
|
mime: Optional[str]
|
||||||
|
|
@ -85,7 +80,7 @@ class Icon:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ext = get_ext(self.mime)
|
ext = get_ext(self.mime)
|
||||||
return str(IMAGE_FOLDER / f'{self.key}_{self.icon_hash}.{ext}')
|
return str(IMAGE_FOLDER / f"{self.key}_{self.icon_hash}.{ext}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def as_pathlib(self) -> Optional[Path]:
|
def as_pathlib(self) -> Optional[Path]:
|
||||||
|
|
@ -106,13 +101,14 @@ class Icon:
|
||||||
|
|
||||||
class ImageError(Exception):
|
class ImageError(Exception):
|
||||||
"""Image error class."""
|
"""Image error class."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def to_raw(data_type: str, data: str) -> Optional[bytes]:
|
def to_raw(data_type: str, data: str) -> Optional[bytes]:
|
||||||
"""Given a data type in the data URI and data,
|
"""Given a data type in the data URI and data,
|
||||||
give the raw bytes being encoded."""
|
give the raw bytes being encoded."""
|
||||||
if data_type == 'base64':
|
if data_type == "base64":
|
||||||
return base64.b64decode(data)
|
return base64.b64decode(data)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
@ -136,7 +132,7 @@ def _calculate_hash(fhandler) -> str:
|
||||||
"""
|
"""
|
||||||
hash_obj = sha256()
|
hash_obj = sha256()
|
||||||
|
|
||||||
for chunk in iter(lambda: fhandler.read(4096), b''):
|
for chunk in iter(lambda: fhandler.read(4096), b""):
|
||||||
hash_obj.update(chunk)
|
hash_obj.update(chunk)
|
||||||
|
|
||||||
# so that we can reuse the same handler
|
# so that we can reuse the same handler
|
||||||
|
|
@ -162,39 +158,36 @@ async def calculate_hash(fhandle, loop=None) -> str:
|
||||||
def parse_data_uri(string) -> tuple:
|
def parse_data_uri(string) -> tuple:
|
||||||
"""Extract image data."""
|
"""Extract image data."""
|
||||||
try:
|
try:
|
||||||
header, headered_data = string.split(';')
|
header, headered_data = string.split(";")
|
||||||
|
|
||||||
_, given_mime = header.split(':')
|
_, given_mime = header.split(":")
|
||||||
data_type, data = headered_data.split(',')
|
data_type, data = headered_data.split(",")
|
||||||
|
|
||||||
raw_data = to_raw(data_type, data)
|
raw_data = to_raw(data_type, data)
|
||||||
if raw_data is None:
|
if raw_data is None:
|
||||||
raise ImageError('Unknown data header')
|
raise ImageError("Unknown data header")
|
||||||
|
|
||||||
return given_mime, raw_data
|
return given_mime, raw_data
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ImageError('data URI invalid syntax')
|
raise ImageError("data URI invalid syntax")
|
||||||
|
|
||||||
|
|
||||||
def _gen_update_sql(scope: str) -> str:
|
def _gen_update_sql(scope: str) -> str:
|
||||||
# match a scope to (table, field)
|
# match a scope to (table, field)
|
||||||
field = {
|
field = {
|
||||||
'user': 'avatar',
|
"user": "avatar",
|
||||||
'guild': 'icon',
|
"guild": "icon",
|
||||||
'splash': 'splash',
|
"splash": "splash",
|
||||||
'banner': 'banner',
|
"banner": "banner",
|
||||||
|
"channel-icons": "icon",
|
||||||
'channel-icons': 'icon',
|
|
||||||
}[scope]
|
}[scope]
|
||||||
|
|
||||||
table = {
|
table = {
|
||||||
'user': 'users',
|
"user": "users",
|
||||||
|
"guild": "guilds",
|
||||||
'guild': 'guilds',
|
"splash": "guilds",
|
||||||
'splash': 'guilds',
|
"banner": "guilds",
|
||||||
'banner': 'guilds',
|
"channel-icons": "group_dm_channels",
|
||||||
|
|
||||||
'channel-icons': 'group_dm_channels'
|
|
||||||
}[scope]
|
}[scope]
|
||||||
|
|
||||||
return f"""
|
return f"""
|
||||||
|
|
@ -204,10 +197,10 @@ def _gen_update_sql(scope: str) -> str:
|
||||||
|
|
||||||
def _invalid(kwargs: dict) -> Optional[Icon]:
|
def _invalid(kwargs: dict) -> Optional[Icon]:
|
||||||
"""Send an invalid value."""
|
"""Send an invalid value."""
|
||||||
if not kwargs.get('always_icon', False):
|
if not kwargs.get("always_icon", False):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return Icon(None, None, '')
|
return Icon(None, None, "")
|
||||||
|
|
||||||
|
|
||||||
def try_unlink(path: Union[Path, str]):
|
def try_unlink(path: Union[Path, str]):
|
||||||
|
|
@ -225,18 +218,17 @@ def try_unlink(path: Union[Path, str]):
|
||||||
async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
||||||
"""Resize a GIF image."""
|
"""Resize a GIF image."""
|
||||||
# generate a temporary file to call gifsticle to and from.
|
# generate a temporary file to call gifsticle to and from.
|
||||||
input_fd, input_path = tempfile.mkstemp(suffix='.gif')
|
input_fd, input_path = tempfile.mkstemp(suffix=".gif")
|
||||||
_, output_path = tempfile.mkstemp(suffix='.gif')
|
_, output_path = tempfile.mkstemp(suffix=".gif")
|
||||||
|
|
||||||
input_handler = os.fdopen(input_fd, 'wb')
|
input_handler = os.fdopen(input_fd, "wb")
|
||||||
|
|
||||||
# make sure its valid image data
|
# make sure its valid image data
|
||||||
data_fd = BytesIO(raw_data)
|
data_fd = BytesIO(raw_data)
|
||||||
image = Image.open(data_fd)
|
image = Image.open(data_fd)
|
||||||
image.close()
|
image.close()
|
||||||
|
|
||||||
log.info('resizing a GIF from {} to {}',
|
log.info("resizing a GIF from {} to {}", image.size, target)
|
||||||
image.size, target)
|
|
||||||
|
|
||||||
# insert image info on input_handler
|
# insert image info on input_handler
|
||||||
# close it to make it ready for consumption by gifsicle
|
# close it to make it ready for consumption by gifsicle
|
||||||
|
|
@ -244,12 +236,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
||||||
input_handler.close()
|
input_handler.close()
|
||||||
|
|
||||||
# call gifsicle under subprocess
|
# call gifsicle under subprocess
|
||||||
log.debug('input: {}', input_path)
|
log.debug("input: {}", input_path)
|
||||||
log.debug('output: {}', output_path)
|
log.debug("output: {}", output_path)
|
||||||
|
|
||||||
process = await asyncio.create_subprocess_shell(
|
process = await asyncio.create_subprocess_shell(
|
||||||
f'gifsicle --resize {target[0]}x{target[1]} '
|
f"gifsicle --resize {target[0]}x{target[1]} " f"{input_path} > {output_path}",
|
||||||
f'{input_path} > {output_path}',
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
|
|
@ -257,11 +248,11 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
||||||
# run it, etc.
|
# run it, etc.
|
||||||
out, err = await process.communicate()
|
out, err = await process.communicate()
|
||||||
|
|
||||||
log.debug('out + err from gifsicle: {}', out + err)
|
log.debug("out + err from gifsicle: {}", out + err)
|
||||||
|
|
||||||
# write over an empty data_fd
|
# write over an empty data_fd
|
||||||
data_fd = BytesIO()
|
data_fd = BytesIO()
|
||||||
output_handler = open(output_path, 'rb')
|
output_handler = open(output_path, "rb")
|
||||||
data_fd.write(output_handler.read())
|
data_fd.write(output_handler.read())
|
||||||
|
|
||||||
# close unused handlers
|
# close unused handlers
|
||||||
|
|
@ -283,40 +274,40 @@ async def resize_gif(raw_data: bytes, target: tuple) -> tuple:
|
||||||
|
|
||||||
class IconManager:
|
class IconManager:
|
||||||
"""Main icon manager."""
|
"""Main icon manager."""
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.storage = app.storage
|
self.storage = app.storage
|
||||||
|
|
||||||
async def _convert_ext(self, icon: Icon, target: str):
|
async def _convert_ext(self, icon: Icon, target: str):
|
||||||
target = 'jpeg' if target == 'jpg' else target
|
target = "jpeg" if target == "jpg" else target
|
||||||
|
|
||||||
target_mime = get_mime(target)
|
target_mime = get_mime(target)
|
||||||
log.info('converting from {} to {}', icon.mime, target_mime)
|
log.info("converting from {} to {}", icon.mime, target_mime)
|
||||||
|
|
||||||
target_path = IMAGE_FOLDER / f'{icon.key}_{icon.icon_hash}.{target}'
|
target_path = IMAGE_FOLDER / f"{icon.key}_{icon.icon_hash}.{target}"
|
||||||
|
|
||||||
if target_path.exists():
|
if target_path.exists():
|
||||||
return Icon(icon.key, icon.icon_hash, target_mime)
|
return Icon(icon.key, icon.icon_hash, target_mime)
|
||||||
|
|
||||||
image = Image.open(icon.as_path)
|
image = Image.open(icon.as_path)
|
||||||
target_fd = target_path.open('wb')
|
target_fd = target_path.open("wb")
|
||||||
|
|
||||||
if target == 'jpeg':
|
if target == "jpeg":
|
||||||
image = image.convert('RGB')
|
image = image.convert("RGB")
|
||||||
|
|
||||||
image.save(target_fd, format=target)
|
image.save(target_fd, format=target)
|
||||||
target_fd.close()
|
target_fd.close()
|
||||||
|
|
||||||
return Icon(icon.key, icon.icon_hash, target_mime)
|
return Icon(icon.key, icon.icon_hash, target_mime)
|
||||||
|
|
||||||
async def generic_get(self, scope, key, icon_hash,
|
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Optional[Icon]:
|
||||||
**kwargs) -> Optional[Icon]:
|
|
||||||
"""Get any icon."""
|
"""Get any icon."""
|
||||||
|
|
||||||
log.debug('GET {} {} {}', scope, key, icon_hash)
|
log.debug("GET {} {} {}", scope, key, icon_hash)
|
||||||
key = str(key)
|
key = str(key)
|
||||||
|
|
||||||
hash_query = 'AND hash = $3' if icon_hash else ''
|
hash_query = "AND hash = $3" if icon_hash else ""
|
||||||
|
|
||||||
# hacky solution to only add icon_hash
|
# hacky solution to only add icon_hash
|
||||||
# when needed.
|
# when needed.
|
||||||
|
|
@ -325,18 +316,21 @@ class IconManager:
|
||||||
if icon_hash:
|
if icon_hash:
|
||||||
args.append(icon_hash)
|
args.append(icon_hash)
|
||||||
|
|
||||||
icon_row = await self.storage.db.fetchrow(f"""
|
icon_row = await self.storage.db.fetchrow(
|
||||||
|
f"""
|
||||||
SELECT key, hash, mime
|
SELECT key, hash, mime
|
||||||
FROM icons
|
FROM icons
|
||||||
WHERE scope = $1
|
WHERE scope = $1
|
||||||
AND key = $2
|
AND key = $2
|
||||||
{hash_query}
|
{hash_query}
|
||||||
""", *args)
|
""",
|
||||||
|
*args,
|
||||||
|
)
|
||||||
|
|
||||||
if not icon_row:
|
if not icon_row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
icon = Icon(icon_row['key'], icon_row['hash'], icon_row['mime'])
|
icon = Icon(icon_row["key"], icon_row["hash"], icon_row["mime"])
|
||||||
|
|
||||||
# ensure we aren't messing with NULLs everywhere.
|
# ensure we aren't messing with NULLs everywhere.
|
||||||
if icon.as_pathlib is None:
|
if icon.as_pathlib is None:
|
||||||
|
|
@ -349,18 +343,16 @@ class IconManager:
|
||||||
if icon.extension is None:
|
if icon.extension is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if 'ext' in kwargs and kwargs['ext'] != icon.extension:
|
if "ext" in kwargs and kwargs["ext"] != icon.extension:
|
||||||
return await self._convert_ext(icon, kwargs['ext'])
|
return await self._convert_ext(icon, kwargs["ext"])
|
||||||
|
|
||||||
return icon
|
return icon
|
||||||
|
|
||||||
async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs):
|
async def get_guild_icon(self, guild_id: int, icon_hash: str, **kwargs):
|
||||||
"""Get an icon for a guild."""
|
"""Get an icon for a guild."""
|
||||||
return await self.generic_get(
|
return await self.generic_get("guild", guild_id, icon_hash, **kwargs)
|
||||||
'guild', guild_id, icon_hash, **kwargs)
|
|
||||||
|
|
||||||
async def put(self, scope: str, key: str,
|
async def put(self, scope: str, key: str, b64_data: str, **kwargs) -> Icon:
|
||||||
b64_data: str, **kwargs) -> Icon:
|
|
||||||
"""Insert an icon."""
|
"""Insert an icon."""
|
||||||
if b64_data is None:
|
if b64_data is None:
|
||||||
return _invalid(kwargs)
|
return _invalid(kwargs)
|
||||||
|
|
@ -373,23 +365,22 @@ class IconManager:
|
||||||
# get an extension for the given data uri
|
# get an extension for the given data uri
|
||||||
extension = get_ext(mime)
|
extension = get_ext(mime)
|
||||||
|
|
||||||
if 'bsize' in kwargs and len(raw_data) > kwargs['bsize']:
|
if "bsize" in kwargs and len(raw_data) > kwargs["bsize"]:
|
||||||
return _invalid(kwargs)
|
return _invalid(kwargs)
|
||||||
|
|
||||||
# size management is different for gif files
|
# size management is different for gif files
|
||||||
# as they're composed of multiple frames.
|
# as they're composed of multiple frames.
|
||||||
if 'size' in kwargs and mime == 'image/gif':
|
if "size" in kwargs and mime == "image/gif":
|
||||||
data_fd, raw_data = await resize_gif(raw_data, kwargs['size'])
|
data_fd, raw_data = await resize_gif(raw_data, kwargs["size"])
|
||||||
elif 'size' in kwargs:
|
elif "size" in kwargs:
|
||||||
image = Image.open(data_fd)
|
image = Image.open(data_fd)
|
||||||
|
|
||||||
if mime == 'image/jpeg':
|
if mime == "image/jpeg":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
want = kwargs['size']
|
want = kwargs["size"]
|
||||||
|
|
||||||
log.info('resizing from {} to {}',
|
log.info("resizing from {} to {}", image.size, want)
|
||||||
image.size, want)
|
|
||||||
|
|
||||||
resized = image.resize(want, resample=Image.LANCZOS)
|
resized = image.resize(want, resample=Image.LANCZOS)
|
||||||
|
|
||||||
|
|
@ -404,23 +395,26 @@ class IconManager:
|
||||||
|
|
||||||
# calculate sha256
|
# calculate sha256
|
||||||
# ignore icon hashes if we're talking about emoji
|
# ignore icon hashes if we're talking about emoji
|
||||||
icon_hash = (await calculate_hash(data_fd)
|
icon_hash = await calculate_hash(data_fd) if scope != "emoji" else None
|
||||||
if scope != 'emoji'
|
|
||||||
else None)
|
|
||||||
|
|
||||||
if scope == 'user' and mime == 'image/gif':
|
if scope == "user" and mime == "image/gif":
|
||||||
icon_hash = f'a_{icon_hash}'
|
icon_hash = f"a_{icon_hash}"
|
||||||
|
|
||||||
log.debug('PUT icon {!r} {!r} {!r} {!r}',
|
log.debug("PUT icon {!r} {!r} {!r} {!r}", scope, key, icon_hash, mime)
|
||||||
scope, key, icon_hash, mime)
|
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO icons (scope, key, hash, mime)
|
INSERT INTO icons (scope, key, hash, mime)
|
||||||
VALUES ($1, $2, $3, $4)
|
VALUES ($1, $2, $3, $4)
|
||||||
""", scope, str(key), icon_hash, mime)
|
""",
|
||||||
|
scope,
|
||||||
|
str(key),
|
||||||
|
icon_hash,
|
||||||
|
mime,
|
||||||
|
)
|
||||||
|
|
||||||
# write it off to fs
|
# write it off to fs
|
||||||
icon_path = IMAGE_FOLDER / f'{key}_{icon_hash}.{extension}'
|
icon_path = IMAGE_FOLDER / f"{key}_{icon_hash}.{extension}"
|
||||||
icon_path.write_bytes(raw_data)
|
icon_path.write_bytes(raw_data)
|
||||||
|
|
||||||
# copy from data_fd to icon_fd
|
# copy from data_fd to icon_fd
|
||||||
|
|
@ -434,57 +428,80 @@ class IconManager:
|
||||||
if not icon:
|
if not icon:
|
||||||
return
|
return
|
||||||
|
|
||||||
log.debug('DEL {}',
|
log.debug("DEL {}", icon)
|
||||||
icon)
|
|
||||||
|
|
||||||
# dereference
|
# dereference
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET avatar = NULL
|
SET avatar = NULL
|
||||||
WHERE avatar = $1
|
WHERE avatar = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
UPDATE group_dm_channels
|
UPDATE group_dm_channels
|
||||||
SET icon = NULL
|
SET icon = NULL
|
||||||
WHERE icon = $1
|
WHERE icon = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM guild_emoji
|
DELETE FROM guild_emoji
|
||||||
WHERE image = $1
|
WHERE image = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET icon = NULL
|
SET icon = NULL
|
||||||
WHERE icon = $1
|
WHERE icon = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET splash = NULL
|
SET splash = NULL
|
||||||
WHERE splash = $1
|
WHERE splash = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
UPDATE guilds
|
UPDATE guilds
|
||||||
SET banner = NULL
|
SET banner = NULL
|
||||||
WHERE banner = $1
|
WHERE banner = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
UPDATE group_dm_channels
|
UPDATE group_dm_channels
|
||||||
SET icon = NULL
|
SET icon = NULL
|
||||||
WHERE icon = $1
|
WHERE icon = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
await self.storage.db.execute("""
|
await self.storage.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM icons
|
DELETE FROM icons
|
||||||
WHERE hash = $1
|
WHERE hash = $1
|
||||||
""", icon.icon_hash)
|
""",
|
||||||
|
icon.icon_hash,
|
||||||
|
)
|
||||||
|
|
||||||
paths = IMAGE_FOLDER.glob(f'{icon.key}_{icon.icon_hash}.*')
|
paths = IMAGE_FOLDER.glob(f"{icon.key}_{icon.icon_hash}.*")
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
try:
|
try:
|
||||||
|
|
@ -492,11 +509,9 @@ class IconManager:
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update(self, scope: str, key: str,
|
async def update(self, scope: str, key: str, new_icon_data: str, **kwargs) -> Icon:
|
||||||
new_icon_data: str, **kwargs) -> Icon:
|
|
||||||
"""Update an icon on a key."""
|
"""Update an icon on a key."""
|
||||||
old_icon_hash = await self.storage.db.fetchval(
|
old_icon_hash = await self.storage.db.fetchval(_gen_update_sql(scope), key)
|
||||||
_gen_update_sql(scope), key)
|
|
||||||
|
|
||||||
# converting key to str only here since from here onwards
|
# converting key to str only here since from here onwards
|
||||||
# its operations on the icons table (or a dereference with
|
# its operations on the icons table (or a dereference with
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -30,6 +31,7 @@ class JobManager:
|
||||||
use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses
|
use helpers such as asyncio.gather and asyncio.Task.all_tasks. It only uses
|
||||||
its own internal list of jobs.
|
its own internal list of jobs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, loop=None):
|
def __init__(self, loop=None):
|
||||||
self.loop = loop or asyncio.get_event_loop()
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
self.jobs = []
|
self.jobs = []
|
||||||
|
|
@ -41,13 +43,11 @@ class JobManager:
|
||||||
try:
|
try:
|
||||||
await coro
|
await coro
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('Error while running job')
|
log.exception("Error while running job")
|
||||||
|
|
||||||
def spawn(self, coro):
|
def spawn(self, coro):
|
||||||
"""Spawn a given future or coroutine in the background."""
|
"""Spawn a given future or coroutine in the background."""
|
||||||
task = self.loop.create_task(
|
task = self.loop.create_task(self._wrapper(coro))
|
||||||
self._wrapper(coro)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.jobs.append(task)
|
self.jobs.append(task)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,40 +26,42 @@ from quart import current_app as app
|
||||||
# type for all the fields
|
# type for all the fields
|
||||||
_i = ctypes.c_uint8
|
_i = ctypes.c_uint8
|
||||||
|
|
||||||
|
|
||||||
class _RawPermsBits(ctypes.LittleEndianStructure):
|
class _RawPermsBits(ctypes.LittleEndianStructure):
|
||||||
"""raw bitfield for discord's permission number."""
|
"""raw bitfield for discord's permission number."""
|
||||||
|
|
||||||
_fields_ = [
|
_fields_ = [
|
||||||
('create_invites', _i, 1),
|
("create_invites", _i, 1),
|
||||||
('kick_members', _i, 1),
|
("kick_members", _i, 1),
|
||||||
('ban_members', _i, 1),
|
("ban_members", _i, 1),
|
||||||
('administrator', _i, 1),
|
("administrator", _i, 1),
|
||||||
('manage_channels', _i, 1),
|
("manage_channels", _i, 1),
|
||||||
('manage_guild', _i, 1),
|
("manage_guild", _i, 1),
|
||||||
('add_reactions', _i, 1),
|
("add_reactions", _i, 1),
|
||||||
('view_audit_log', _i, 1),
|
("view_audit_log", _i, 1),
|
||||||
('priority_speaker', _i, 1),
|
("priority_speaker", _i, 1),
|
||||||
('stream', _i, 1),
|
("stream", _i, 1),
|
||||||
('read_messages', _i, 1),
|
("read_messages", _i, 1),
|
||||||
('send_messages', _i, 1),
|
("send_messages", _i, 1),
|
||||||
('send_tts', _i, 1),
|
("send_tts", _i, 1),
|
||||||
('manage_messages', _i, 1),
|
("manage_messages", _i, 1),
|
||||||
('embed_links', _i, 1),
|
("embed_links", _i, 1),
|
||||||
('attach_files', _i, 1),
|
("attach_files", _i, 1),
|
||||||
('read_history', _i, 1),
|
("read_history", _i, 1),
|
||||||
('mention_everyone', _i, 1),
|
("mention_everyone", _i, 1),
|
||||||
('external_emojis', _i, 1),
|
("external_emojis", _i, 1),
|
||||||
('_unused2', _i, 1),
|
("_unused2", _i, 1),
|
||||||
('connect', _i, 1),
|
("connect", _i, 1),
|
||||||
('speak', _i, 1),
|
("speak", _i, 1),
|
||||||
('mute_members', _i, 1),
|
("mute_members", _i, 1),
|
||||||
('deafen_members', _i, 1),
|
("deafen_members", _i, 1),
|
||||||
('move_members', _i, 1),
|
("move_members", _i, 1),
|
||||||
('use_voice_activation', _i, 1),
|
("use_voice_activation", _i, 1),
|
||||||
('change_nickname', _i, 1),
|
("change_nickname", _i, 1),
|
||||||
('manage_nicknames', _i, 1),
|
("manage_nicknames", _i, 1),
|
||||||
('manage_roles', _i, 1),
|
("manage_roles", _i, 1),
|
||||||
('manage_webhooks', _i, 1),
|
("manage_webhooks", _i, 1),
|
||||||
('manage_emojis', _i, 1),
|
("manage_emojis", _i, 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -72,16 +74,14 @@ class Permissions(ctypes.Union):
|
||||||
val
|
val
|
||||||
The permissions value as an integer.
|
The permissions value as an integer.
|
||||||
"""
|
"""
|
||||||
_fields_ = [
|
|
||||||
('bits', _RawPermsBits),
|
_fields_ = [("bits", _RawPermsBits), ("binary", ctypes.c_uint64)]
|
||||||
('binary', ctypes.c_uint64),
|
|
||||||
]
|
|
||||||
|
|
||||||
def __init__(self, val: int):
|
def __init__(self, val: int):
|
||||||
self.binary = val
|
self.binary = val
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'<Permissions binary={self.binary}>'
|
return f"<Permissions binary={self.binary}>"
|
||||||
|
|
||||||
def __int__(self):
|
def __int__(self):
|
||||||
return self.binary
|
return self.binary
|
||||||
|
|
@ -95,11 +95,15 @@ async def get_role_perms(guild_id, role_id, storage=None) -> Permissions:
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = app.storage
|
storage = app.storage
|
||||||
|
|
||||||
perms = await storage.db.fetchval("""
|
perms = await storage.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT permissions
|
SELECT permissions
|
||||||
FROM roles
|
FROM roles
|
||||||
WHERE guild_id = $1 AND id = $2
|
WHERE guild_id = $1 AND id = $2
|
||||||
""", guild_id, role_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
role_id,
|
||||||
|
)
|
||||||
|
|
||||||
return Permissions(perms)
|
return Permissions(perms)
|
||||||
|
|
||||||
|
|
@ -118,11 +122,14 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = app.storage
|
storage = app.storage
|
||||||
|
|
||||||
owner_id = await storage.db.fetchval("""
|
owner_id = await storage.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT owner_id
|
SELECT owner_id
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if owner_id == member_id:
|
if owner_id == member_id:
|
||||||
return ALL_PERMISSIONS
|
return ALL_PERMISSIONS
|
||||||
|
|
@ -130,20 +137,27 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
||||||
# get permissions for @everyone
|
# get permissions for @everyone
|
||||||
permissions = await get_role_perms(guild_id, guild_id, storage)
|
permissions = await get_role_perms(guild_id, guild_id, storage)
|
||||||
|
|
||||||
role_ids = await storage.db.fetch("""
|
role_ids = await storage.db.fetch(
|
||||||
|
"""
|
||||||
SELECT role_id
|
SELECT role_id
|
||||||
FROM member_roles
|
FROM member_roles
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, member_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
member_id,
|
||||||
|
)
|
||||||
|
|
||||||
role_perms = []
|
role_perms = []
|
||||||
|
|
||||||
for row in role_ids:
|
for row in role_ids:
|
||||||
rperm = await storage.db.fetchval("""
|
rperm = await storage.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT permissions
|
SELECT permissions
|
||||||
FROM roles
|
FROM roles
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", row['role_id'])
|
""",
|
||||||
|
row["role_id"],
|
||||||
|
)
|
||||||
|
|
||||||
role_perms.append(rperm)
|
role_perms.append(rperm)
|
||||||
|
|
||||||
|
|
@ -164,16 +178,17 @@ def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
|
||||||
result = perms.binary
|
result = perms.binary
|
||||||
|
|
||||||
# negate the permissions that are denied
|
# negate the permissions that are denied
|
||||||
result &= ~overwrite['deny']
|
result &= ~overwrite["deny"]
|
||||||
|
|
||||||
# combine the permissions that are allowed
|
# combine the permissions that are allowed
|
||||||
result |= overwrite['allow']
|
result |= overwrite["allow"]
|
||||||
|
|
||||||
return Permissions(result)
|
return Permissions(result)
|
||||||
|
|
||||||
|
|
||||||
def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
def overwrite_find_mix(
|
||||||
target_id: int) -> Permissions:
|
perms: Permissions, overwrites: dict, target_id: int
|
||||||
|
) -> Permissions:
|
||||||
"""Mix a given permission with a given overwrite.
|
"""Mix a given permission with a given overwrite.
|
||||||
|
|
||||||
Returns the given permission if an overwrite is not found.
|
Returns the given permission if an overwrite is not found.
|
||||||
|
|
@ -201,19 +216,25 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
||||||
return perms
|
return perms
|
||||||
|
|
||||||
|
|
||||||
async def role_permissions(guild_id: int, role_id: int,
|
async def role_permissions(
|
||||||
channel_id: int, storage=None) -> Permissions:
|
guild_id: int, role_id: int, channel_id: int, storage=None
|
||||||
|
) -> Permissions:
|
||||||
"""Get the permissions for a role, in relation to a channel"""
|
"""Get the permissions for a role, in relation to a channel"""
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = app.storage
|
storage = app.storage
|
||||||
|
|
||||||
perms = await get_role_perms(guild_id, role_id, storage)
|
perms = await get_role_perms(guild_id, role_id, storage)
|
||||||
|
|
||||||
overwrite = await storage.db.fetchrow("""
|
overwrite = await storage.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT allow, deny
|
SELECT allow, deny
|
||||||
FROM channel_overwrites
|
FROM channel_overwrites
|
||||||
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
|
WHERE channel_id = $1 AND target_type = $2 AND target_role = $3
|
||||||
""", channel_id, 1, role_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
1,
|
||||||
|
role_id,
|
||||||
|
)
|
||||||
|
|
||||||
if overwrite:
|
if overwrite:
|
||||||
perms = overwrite_mix(perms, overwrite)
|
perms = overwrite_mix(perms, overwrite)
|
||||||
|
|
@ -221,10 +242,13 @@ async def role_permissions(guild_id: int, role_id: int,
|
||||||
return perms
|
return perms
|
||||||
|
|
||||||
|
|
||||||
async def compute_overwrites(base_perms: Permissions,
|
async def compute_overwrites(
|
||||||
user_id, channel_id: int,
|
base_perms: Permissions,
|
||||||
|
user_id,
|
||||||
|
channel_id: int,
|
||||||
guild_id: Optional[int] = None,
|
guild_id: Optional[int] = None,
|
||||||
storage=None):
|
storage=None,
|
||||||
|
):
|
||||||
"""Compute the permissions in the context of a channel."""
|
"""Compute the permissions in the context of a channel."""
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = app.storage
|
storage = app.storage
|
||||||
|
|
@ -245,7 +269,7 @@ async def compute_overwrites(base_perms: Permissions,
|
||||||
return ALL_PERMISSIONS
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
# make it a map for better usage
|
# make it a map for better usage
|
||||||
overwrites = {int(o['id']): o for o in overwrites}
|
overwrites = {int(o["id"]): o for o in overwrites}
|
||||||
|
|
||||||
perms = overwrite_find_mix(perms, overwrites, guild_id)
|
perms = overwrite_find_mix(perms, overwrites, guild_id)
|
||||||
|
|
||||||
|
|
@ -260,14 +284,11 @@ async def compute_overwrites(base_perms: Permissions,
|
||||||
for role_id in role_ids:
|
for role_id in role_ids:
|
||||||
overwrite = overwrites.get(role_id)
|
overwrite = overwrites.get(role_id)
|
||||||
if overwrite:
|
if overwrite:
|
||||||
allow |= overwrite['allow']
|
allow |= overwrite["allow"]
|
||||||
deny |= overwrite['deny']
|
deny |= overwrite["deny"]
|
||||||
|
|
||||||
# final step for roles: mix
|
# final step for roles: mix
|
||||||
perms = overwrite_mix(perms, {
|
perms = overwrite_mix(perms, {"allow": allow, "deny": deny})
|
||||||
'allow': allow,
|
|
||||||
'deny': deny
|
|
||||||
})
|
|
||||||
|
|
||||||
# apply member specific overwrites
|
# apply member specific overwrites
|
||||||
perms = overwrite_find_mix(perms, overwrites, user_id)
|
perms = overwrite_find_mix(perms, overwrites, user_id)
|
||||||
|
|
@ -275,8 +296,7 @@ async def compute_overwrites(base_perms: Permissions,
|
||||||
return perms
|
return perms
|
||||||
|
|
||||||
|
|
||||||
async def get_permissions(member_id: int, channel_id,
|
async def get_permissions(member_id: int, channel_id, *, storage=None) -> Permissions:
|
||||||
*, storage=None) -> Permissions:
|
|
||||||
"""Get the permissions for a user in a channel."""
|
"""Get the permissions for a user in a channel."""
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = app.storage
|
storage = app.storage
|
||||||
|
|
@ -290,4 +310,5 @@ async def get_permissions(member_id: int, channel_id,
|
||||||
base_perms = await base_permissions(member_id, guild_id, storage)
|
base_perms = await base_permissions(member_id, guild_id, storage)
|
||||||
|
|
||||||
return await compute_overwrites(
|
return await compute_overwrites(
|
||||||
base_perms, member_id, channel_id, guild_id, storage)
|
base_perms, member_id, channel_id, guild_id, storage
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -32,62 +32,56 @@ def status_cmp(status: str, other_status: str) -> bool:
|
||||||
in the status hierarchy.
|
in the status hierarchy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hierarchy = {
|
hierarchy = {"online": 3, "idle": 2, "dnd": 1, "offline": 0, None: -1}
|
||||||
'online': 3,
|
|
||||||
'idle': 2,
|
|
||||||
'dnd': 1,
|
|
||||||
'offline': 0,
|
|
||||||
None: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
return hierarchy[status] > hierarchy[other_status]
|
return hierarchy[status] > hierarchy[other_status]
|
||||||
|
|
||||||
|
|
||||||
def _best_presence(shards):
|
def _best_presence(shards):
|
||||||
"""Find the 'best' presence given a list of GatewayState."""
|
"""Find the 'best' presence given a list of GatewayState."""
|
||||||
best = {'status': None, 'game': None}
|
best = {"status": None, "game": None}
|
||||||
|
|
||||||
for state in shards:
|
for state in shards:
|
||||||
presence = state.presence
|
presence = state.presence
|
||||||
|
|
||||||
status = presence['status']
|
status = presence["status"]
|
||||||
|
|
||||||
if not presence:
|
if not presence:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# shards with a better status
|
# shards with a better status
|
||||||
# in the hierarchy are treated as best
|
# in the hierarchy are treated as best
|
||||||
if status_cmp(status, best['status']):
|
if status_cmp(status, best["status"]):
|
||||||
best['status'] = status
|
best["status"] = status
|
||||||
|
|
||||||
# if we have any game, use it
|
# if we have any game, use it
|
||||||
if presence['game'] is not None:
|
if presence["game"] is not None:
|
||||||
best['game'] = presence['game']
|
best["game"] = presence["game"]
|
||||||
|
|
||||||
# best['status'] is None when no
|
# best['status'] is None when no
|
||||||
# status was good enough.
|
# status was good enough.
|
||||||
return None if not best['status'] else best
|
return None if not best["status"] else best
|
||||||
|
|
||||||
|
|
||||||
def fill_presence(presence: dict, *, game=None) -> dict:
|
def fill_presence(presence: dict, *, game=None) -> dict:
|
||||||
"""Fill a given presence object with some specific fields."""
|
"""Fill a given presence object with some specific fields."""
|
||||||
presence['client_status'] = {}
|
presence["client_status"] = {}
|
||||||
presence['mobile'] = False
|
presence["mobile"] = False
|
||||||
|
|
||||||
if 'since' not in presence:
|
if "since" not in presence:
|
||||||
presence['since'] = 0
|
presence["since"] = 0
|
||||||
|
|
||||||
# fill game and activities array depending if game
|
# fill game and activities array depending if game
|
||||||
# is there or not
|
# is there or not
|
||||||
game = game or presence.get('game')
|
game = game or presence.get("game")
|
||||||
|
|
||||||
# casting to bool since a game of {} is still invalid
|
# casting to bool since a game of {} is still invalid
|
||||||
if game:
|
if game:
|
||||||
presence['game'] = game
|
presence["game"] = game
|
||||||
presence['activities'] = [game]
|
presence["activities"] = [game]
|
||||||
else:
|
else:
|
||||||
presence['game'] = None
|
presence["game"] = None
|
||||||
presence['activities'] = []
|
presence["activities"] = []
|
||||||
|
|
||||||
return presence
|
return presence
|
||||||
|
|
||||||
|
|
@ -96,14 +90,13 @@ async def _pres(storage, user_id: int, status_obj: dict) -> dict:
|
||||||
"""Convert a given status into a presence, given the User ID and the
|
"""Convert a given status into a presence, given the User ID and the
|
||||||
:class:`Storage` instance."""
|
:class:`Storage` instance."""
|
||||||
ext = {
|
ext = {
|
||||||
'user': await storage.get_user(user_id),
|
"user": await storage.get_user(user_id),
|
||||||
'activities': [],
|
"activities": [],
|
||||||
|
|
||||||
# NOTE: we are purposefully overwriting the fields, as there
|
# NOTE: we are purposefully overwriting the fields, as there
|
||||||
# isn't any push for us to actually implement mobile detection, or
|
# isn't any push for us to actually implement mobile detection, or
|
||||||
# web detection, etc.
|
# web detection, etc.
|
||||||
'client_status': {},
|
"client_status": {},
|
||||||
'mobile': False,
|
"mobile": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
return fill_presence({**status_obj, **ext})
|
return fill_presence({**status_obj, **ext})
|
||||||
|
|
@ -115,14 +108,16 @@ class PresenceManager:
|
||||||
Has common functions to deal with fetching or updating presences, including
|
Has common functions to deal with fetching or updating presences, including
|
||||||
side-effects (events).
|
side-effects (events).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.storage = app.storage
|
self.storage = app.storage
|
||||||
self.user_storage = app.user_storage
|
self.user_storage = app.user_storage
|
||||||
self.state_manager = app.state_manager
|
self.state_manager = app.state_manager
|
||||||
self.dispatcher = app.dispatcher
|
self.dispatcher = app.dispatcher
|
||||||
|
|
||||||
async def guild_presences(self, member_ids: List[int],
|
async def guild_presences(
|
||||||
guild_id: int) -> List[Dict[Any, str]]:
|
self, member_ids: List[int], guild_id: int
|
||||||
|
) -> List[Dict[Any, str]]:
|
||||||
"""Fetch all presences in a guild."""
|
"""Fetch all presences in a guild."""
|
||||||
# this works via fetching all connected GatewayState on a guild
|
# this works via fetching all connected GatewayState on a guild
|
||||||
# then fetching its respective member and merging that info with
|
# then fetching its respective member and merging that info with
|
||||||
|
|
@ -132,34 +127,36 @@ class PresenceManager:
|
||||||
presences = []
|
presences = []
|
||||||
|
|
||||||
for state in states:
|
for state in states:
|
||||||
member = await self.storage.get_member_data_one(
|
member = await self.storage.get_member_data_one(guild_id, state.user_id)
|
||||||
guild_id, state.user_id)
|
|
||||||
|
|
||||||
game = state.presence.get('game', None)
|
game = state.presence.get("game", None)
|
||||||
|
|
||||||
# only use the data we need.
|
# only use the data we need.
|
||||||
presences.append(fill_presence({
|
presences.append(
|
||||||
'user': member['user'],
|
fill_presence(
|
||||||
'roles': member['roles'],
|
{
|
||||||
'guild_id': str(guild_id),
|
"user": member["user"],
|
||||||
|
"roles": member["roles"],
|
||||||
|
"guild_id": str(guild_id),
|
||||||
# if a state is connected to the guild
|
# if a state is connected to the guild
|
||||||
# we assume its online.
|
# we assume its online.
|
||||||
'status': state.presence.get('status', 'online'),
|
"status": state.presence.get("status", "online"),
|
||||||
}, game=game))
|
},
|
||||||
|
game=game,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return presences
|
return presences
|
||||||
|
|
||||||
async def dispatch_guild_pres(self, guild_id: int,
|
async def dispatch_guild_pres(self, guild_id: int, user_id: int, new_state: dict):
|
||||||
user_id: int, new_state: dict):
|
|
||||||
"""Dispatch a Presence update to an entire guild."""
|
"""Dispatch a Presence update to an entire guild."""
|
||||||
state = dict(new_state)
|
state = dict(new_state)
|
||||||
|
|
||||||
member = await self.storage.get_member_data_one(guild_id, user_id)
|
member = await self.storage.get_member_data_one(guild_id, user_id)
|
||||||
|
|
||||||
game = state['game']
|
game = state["game"]
|
||||||
|
|
||||||
lazy_guild_store = self.dispatcher.backends['lazy_guild']
|
lazy_guild_store = self.dispatcher.backends["lazy_guild"]
|
||||||
lists = lazy_guild_store.get_gml_guild(guild_id)
|
lists = lazy_guild_store.get_gml_guild(guild_id)
|
||||||
|
|
||||||
# shards that are in lazy guilds with 'everyone'
|
# shards that are in lazy guilds with 'everyone'
|
||||||
|
|
@ -168,49 +165,44 @@ class PresenceManager:
|
||||||
|
|
||||||
for member_list in lists:
|
for member_list in lists:
|
||||||
session_ids = await member_list.pres_update(
|
session_ids = await member_list.pres_update(
|
||||||
int(member['user']['id']),
|
int(member["user"]["id"]),
|
||||||
{
|
{"roles": member["roles"], "status": state["status"], "game": game},
|
||||||
'roles': member['roles'],
|
|
||||||
'status': state['status'],
|
|
||||||
'game': game
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
log.debug('Lazy Dispatch to {}',
|
log.debug("Lazy Dispatch to {}", len(session_ids))
|
||||||
len(session_ids))
|
|
||||||
|
|
||||||
# if we are on the 'everyone' member list, we don't
|
# if we are on the 'everyone' member list, we don't
|
||||||
# dispatch a PRESENCE_UPDATE for those shards.
|
# dispatch a PRESENCE_UPDATE for those shards.
|
||||||
if member_list.channel_id == member_list.guild_id:
|
if member_list.channel_id == member_list.guild_id:
|
||||||
in_lazy.extend(session_ids)
|
in_lazy.extend(session_ids)
|
||||||
|
|
||||||
pres_update_payload = fill_presence({
|
pres_update_payload = fill_presence(
|
||||||
'guild_id': str(guild_id),
|
{
|
||||||
'user': member['user'],
|
"guild_id": str(guild_id),
|
||||||
'roles': member['roles'],
|
"user": member["user"],
|
||||||
'status': state['status'],
|
"roles": member["roles"],
|
||||||
}, game=game)
|
"status": state["status"],
|
||||||
|
},
|
||||||
|
game=game,
|
||||||
|
)
|
||||||
|
|
||||||
# given a session id, return if the session id actually connects to
|
# given a session id, return if the session id actually connects to
|
||||||
# a given user, and if the state has not been dispatched via lazy guild.
|
# a given user, and if the state has not been dispatched via lazy guild.
|
||||||
def _session_check(session_id):
|
def _session_check(session_id):
|
||||||
state = self.state_manager.fetch_raw(session_id)
|
state = self.state_manager.fetch_raw(session_id)
|
||||||
uid = int(member['user']['id'])
|
uid = int(member["user"]["id"])
|
||||||
|
|
||||||
if not state:
|
if not state:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# we don't want to send a presence update
|
# we don't want to send a presence update
|
||||||
# to the same user
|
# to the same user
|
||||||
return (state.user_id != uid and
|
return state.user_id != uid and session_id not in in_lazy
|
||||||
session_id not in in_lazy)
|
|
||||||
|
|
||||||
# everyone not in lazy guild mode
|
# everyone not in lazy guild mode
|
||||||
# gets a PRESENCE_UPDATE
|
# gets a PRESENCE_UPDATE
|
||||||
await self.dispatcher.dispatch_filter(
|
await self.dispatcher.dispatch_filter(
|
||||||
'guild', guild_id,
|
"guild", guild_id, _session_check, "PRESENCE_UPDATE", pres_update_payload
|
||||||
_session_check,
|
|
||||||
'PRESENCE_UPDATE', pres_update_payload
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return in_lazy
|
return in_lazy
|
||||||
|
|
@ -220,25 +212,25 @@ class PresenceManager:
|
||||||
|
|
||||||
Also dispatches the presence to all the users' friends
|
Also dispatches the presence to all the users' friends
|
||||||
"""
|
"""
|
||||||
if state['status'] == 'invisible':
|
if state["status"] == "invisible":
|
||||||
state['status'] = 'offline'
|
state["status"] = "offline"
|
||||||
|
|
||||||
# TODO: shard-aware
|
# TODO: shard-aware
|
||||||
guild_ids = await self.user_storage.get_user_guilds(user_id)
|
guild_ids = await self.user_storage.get_user_guilds(user_id)
|
||||||
|
|
||||||
for guild_id in guild_ids:
|
for guild_id in guild_ids:
|
||||||
await self.dispatch_guild_pres(
|
await self.dispatch_guild_pres(guild_id, user_id, state)
|
||||||
guild_id, user_id, state)
|
|
||||||
|
|
||||||
# dispatch to all friends that are subscribed to them
|
# dispatch to all friends that are subscribed to them
|
||||||
user = await self.storage.get_user(user_id)
|
user = await self.storage.get_user(user_id)
|
||||||
game = state['game']
|
game = state["game"]
|
||||||
|
|
||||||
await self.dispatcher.dispatch(
|
await self.dispatcher.dispatch(
|
||||||
'friend', user_id, 'PRESENCE_UPDATE', fill_presence({
|
"friend",
|
||||||
'user': user,
|
user_id,
|
||||||
'status': state['status'],
|
"PRESENCE_UPDATE",
|
||||||
}, game=game))
|
fill_presence({"user": user, "status": state["status"]}, game=game),
|
||||||
|
)
|
||||||
|
|
||||||
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
|
async def friend_presences(self, friend_ids: Iterable[int]) -> List[Presence]:
|
||||||
"""Fetch presences for a group of users.
|
"""Fetch presences for a group of users.
|
||||||
|
|
@ -254,22 +246,25 @@ class PresenceManager:
|
||||||
|
|
||||||
if not friend_states:
|
if not friend_states:
|
||||||
# append offline
|
# append offline
|
||||||
res.append(await _pres(storage, friend_id, {
|
res.append(
|
||||||
'afk': False,
|
await _pres(
|
||||||
'status': 'offline',
|
storage,
|
||||||
'game': None,
|
friend_id,
|
||||||
'since': 0
|
{"afk": False, "status": "offline", "game": None, "since": 0},
|
||||||
}))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# filter the best shards:
|
# filter the best shards:
|
||||||
# - all with id 0 (are the first shards in the collection) or
|
# - all with id 0 (are the first shards in the collection) or
|
||||||
# - all shards with count = 1 (single shards)
|
# - all shards with count = 1 (single shards)
|
||||||
good_shards = list(filter(
|
good_shards = list(
|
||||||
|
filter(
|
||||||
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
|
lambda state: state.shard[0] == 0 or state.shard[1] == 1,
|
||||||
friend_states
|
friend_states,
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if good_shards:
|
if good_shards:
|
||||||
best_pres = _best_presence(good_shards)
|
best_pres = _best_presence(good_shards)
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,11 @@ from .channel import ChannelDispatcher
|
||||||
from .friend import FriendDispatcher
|
from .friend import FriendDispatcher
|
||||||
from .lazy_guild import LazyGuildDispatcher
|
from .lazy_guild import LazyGuildDispatcher
|
||||||
|
|
||||||
__all__ = ['GuildDispatcher', 'MemberDispatcher',
|
__all__ = [
|
||||||
'UserDispatcher', 'ChannelDispatcher',
|
"GuildDispatcher",
|
||||||
'FriendDispatcher', 'LazyGuildDispatcher']
|
"MemberDispatcher",
|
||||||
|
"UserDispatcher",
|
||||||
|
"ChannelDispatcher",
|
||||||
|
"FriendDispatcher",
|
||||||
|
"LazyGuildDispatcher",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -38,23 +38,20 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
||||||
# make a copy or the original channel object
|
# make a copy or the original channel object
|
||||||
data = dict(orig)
|
data = dict(orig)
|
||||||
|
|
||||||
idx = index_by_func(
|
idx = index_by_func(lambda user: user["id"] == str(user_id), data["recipients"])
|
||||||
lambda user: user['id'] == str(user_id),
|
|
||||||
data['recipients']
|
|
||||||
)
|
|
||||||
|
|
||||||
data['recipients'].pop(idx)
|
data["recipients"].pop(idx)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ChannelDispatcher(DispatcherWithFlags):
|
class ChannelDispatcher(DispatcherWithFlags):
|
||||||
"""Main channel Pub/Sub logic."""
|
"""Main channel Pub/Sub logic."""
|
||||||
|
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
VAL_TYPE = int
|
VAL_TYPE = int
|
||||||
|
|
||||||
async def dispatch(self, channel_id,
|
async def dispatch(self, channel_id, event: str, data: Any) -> List[str]:
|
||||||
event: str, data: Any) -> List[str]:
|
|
||||||
"""Dispatch an event to a channel."""
|
"""Dispatch an event to a channel."""
|
||||||
# get everyone who is subscribed
|
# get everyone who is subscribed
|
||||||
# and store the number of states we dispatched the event to
|
# and store the number of states we dispatched the event to
|
||||||
|
|
@ -75,9 +72,11 @@ class ChannelDispatcher(DispatcherWithFlags):
|
||||||
# TODO: make a fetch_states that fetches shards
|
# TODO: make a fetch_states that fetches shards
|
||||||
# - with id 0 (count any) OR
|
# - with id 0 (count any) OR
|
||||||
# - single shards (id=0, count=1)
|
# - single shards (id=0, count=1)
|
||||||
states = (self.sm.fetch_states(user_id, guild_id)
|
states = (
|
||||||
if guild_id else
|
self.sm.fetch_states(user_id, guild_id)
|
||||||
self.sm.user_states(user_id))
|
if guild_id
|
||||||
|
else self.sm.user_states(user_id)
|
||||||
|
)
|
||||||
|
|
||||||
# unsub people who don't have any states tied to the channel.
|
# unsub people who don't have any states tied to the channel.
|
||||||
if not states:
|
if not states:
|
||||||
|
|
@ -85,28 +84,28 @@ class ChannelDispatcher(DispatcherWithFlags):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# skip typing events for users that don't want it
|
# skip typing events for users that don't want it
|
||||||
if event.startswith('TYPING_') and \
|
if event.startswith("TYPING_") and not self.flags_get(
|
||||||
not self.flags_get(channel_id, user_id, 'typing', True):
|
channel_id, user_id, "typing", True
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
cur_sess = []
|
cur_sess = []
|
||||||
|
|
||||||
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
if (
|
||||||
and data.get('type') == ChannelType.GROUP_DM.value:
|
event in ("CHANNEL_CREATE", "CHANNEL_UPDATE")
|
||||||
|
and data.get("type") == ChannelType.GROUP_DM.value
|
||||||
|
):
|
||||||
# we edit the channel payload so it doesn't show
|
# we edit the channel payload so it doesn't show
|
||||||
# the user as a recipient
|
# the user as a recipient
|
||||||
|
|
||||||
new_data = gdm_recipient_view(data, user_id)
|
new_data = gdm_recipient_view(data, user_id)
|
||||||
cur_sess = await self._dispatch_states(
|
cur_sess = await self._dispatch_states(states, event, new_data)
|
||||||
states, event, new_data)
|
|
||||||
else:
|
else:
|
||||||
cur_sess = await self._dispatch_states(
|
cur_sess = await self._dispatch_states(states, event, data)
|
||||||
states, event, data)
|
|
||||||
|
|
||||||
sessions.extend(cur_sess)
|
sessions.extend(cur_sess)
|
||||||
dispatched += len(cur_sess)
|
dispatched += len(cur_sess)
|
||||||
|
|
||||||
log.info('Dispatched chan={} {!r} to {} states',
|
log.info("Dispatched chan={} {!r} to {} states", channel_id, event, dispatched)
|
||||||
channel_id, event, dispatched)
|
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
|
||||||
|
|
@ -80,8 +80,7 @@ class Dispatcher:
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def _dispatch_states(self, states: list, event: str,
|
async def _dispatch_states(self, states: list, event: str, data) -> List[str]:
|
||||||
data) -> List[str]:
|
|
||||||
"""Dispatch an event to a list of states."""
|
"""Dispatch an event to a list of states."""
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
|
|
@ -90,7 +89,7 @@ class Dispatcher:
|
||||||
await state.ws.dispatch(event, data)
|
await state.ws.dispatch(event, data)
|
||||||
res.append(state.session_id)
|
res.append(state.session_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('error while dispatching')
|
log.exception("error while dispatching")
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
@ -102,6 +101,7 @@ class DispatcherWithState(Dispatcher):
|
||||||
of boilerplate code on Pub/Sub backends
|
of boilerplate code on Pub/Sub backends
|
||||||
that have that dictionary.
|
that have that dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, main):
|
def __init__(self, main):
|
||||||
super().__init__(main)
|
super().__init__(main)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ class FriendDispatcher(DispatcherWithState):
|
||||||
channels. If that friend updates their presence, it will be
|
channels. If that friend updates their presence, it will be
|
||||||
broadcasted through that channel to basically all their friends.
|
broadcasted through that channel to basically all their friends.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
VAL_TYPE = int
|
VAL_TYPE = int
|
||||||
|
|
||||||
|
|
@ -44,17 +45,13 @@ class FriendDispatcher(DispatcherWithState):
|
||||||
# since relationships broadcast to all shards.
|
# since relationships broadcast to all shards.
|
||||||
sessions.extend(
|
sessions.extend(
|
||||||
await self.main_dispatcher.dispatch_filter(
|
await self.main_dispatcher.dispatch_filter(
|
||||||
'user', peer_id, func, event, data)
|
"user", peer_id, func, event, data
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info('dispatched uid={} {!r} to {} states',
|
log.info("dispatched uid={} {!r} to {} states", user_id, event, len(sessions))
|
||||||
user_id, event, len(sessions))
|
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
async def dispatch(self, user_id, event, data):
|
async def dispatch(self, user_id, event, data):
|
||||||
return await self.dispatch_filter(
|
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||||
user_id,
|
|
||||||
lambda sess_id: True,
|
|
||||||
event, data,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -29,11 +29,11 @@ log = Logger(__name__)
|
||||||
|
|
||||||
class GuildDispatcher(DispatcherWithFlags):
|
class GuildDispatcher(DispatcherWithFlags):
|
||||||
"""Guild backend for Pub/Sub"""
|
"""Guild backend for Pub/Sub"""
|
||||||
|
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
VAL_TYPE = int
|
VAL_TYPE = int
|
||||||
|
|
||||||
async def _chan_action(self, action: str,
|
async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None):
|
||||||
guild_id: int, user_id: int, flags=None):
|
|
||||||
"""Send an action to all channels of the guild."""
|
"""Send an action to all channels of the guild."""
|
||||||
flags = flags or {}
|
flags = flags or {}
|
||||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||||
|
|
@ -43,33 +43,31 @@ class GuildDispatcher(DispatcherWithFlags):
|
||||||
# only do an action for users that can
|
# only do an action for users that can
|
||||||
# actually read the channel to start with.
|
# actually read the channel to start with.
|
||||||
chan_perms = await get_permissions(
|
chan_perms = await get_permissions(
|
||||||
user_id, chan_id,
|
user_id, chan_id, storage=self.main_dispatcher.app.storage
|
||||||
storage=self.main_dispatcher.app.storage)
|
)
|
||||||
|
|
||||||
if not chan_perms.bits.read_messages:
|
if not chan_perms.bits.read_messages:
|
||||||
log.debug('skipping cid={}, no read messages',
|
log.debug("skipping cid={}, no read messages", chan_id)
|
||||||
chan_id)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
log.debug('sending raw action {!r} to chan={}',
|
log.debug("sending raw action {!r} to chan={}", action, chan_id)
|
||||||
action, chan_id)
|
|
||||||
|
|
||||||
# for now, only sub() has support for flags.
|
# for now, only sub() has support for flags.
|
||||||
# it is an idea to have flags support for other actions
|
# it is an idea to have flags support for other actions
|
||||||
args = []
|
args = []
|
||||||
if action == 'sub':
|
if action == "sub":
|
||||||
chanflags = dict(flags)
|
chanflags = dict(flags)
|
||||||
|
|
||||||
# channels don't need presence flags
|
# channels don't need presence flags
|
||||||
try:
|
try:
|
||||||
chanflags.pop('presence')
|
chanflags.pop("presence")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
args.append(chanflags)
|
args.append(chanflags)
|
||||||
|
|
||||||
await self.main_dispatcher.action(
|
await self.main_dispatcher.action(
|
||||||
'channel', action, chan_id, user_id, *args
|
"channel", action, chan_id, user_id, *args
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||||
|
|
@ -77,26 +75,24 @@ class GuildDispatcher(DispatcherWithFlags):
|
||||||
in the guild."""
|
in the guild."""
|
||||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||||
|
|
||||||
chan_dispatcher = self.main_dispatcher.backends['channel']
|
chan_dispatcher = self.main_dispatcher.backends["channel"]
|
||||||
method = getattr(chan_dispatcher, meth)
|
method = getattr(chan_dispatcher, meth)
|
||||||
|
|
||||||
for chan_id in chan_ids:
|
for chan_id in chan_ids:
|
||||||
log.debug('calling {} to chan={}',
|
log.debug("calling {} to chan={}", meth, chan_id)
|
||||||
meth, chan_id)
|
|
||||||
await method(chan_id, *args)
|
await method(chan_id, *args)
|
||||||
|
|
||||||
async def sub(self, guild_id: int, user_id: int, flags=None):
|
async def sub(self, guild_id: int, user_id: int, flags=None):
|
||||||
"""Subscribe a user to the guild."""
|
"""Subscribe a user to the guild."""
|
||||||
await super().sub(guild_id, user_id, flags)
|
await super().sub(guild_id, user_id, flags)
|
||||||
await self._chan_action('sub', guild_id, user_id, flags)
|
await self._chan_action("sub", guild_id, user_id, flags)
|
||||||
|
|
||||||
async def unsub(self, guild_id: int, user_id: int):
|
async def unsub(self, guild_id: int, user_id: int):
|
||||||
"""Unsubscribe a user from the guild."""
|
"""Unsubscribe a user from the guild."""
|
||||||
await super().unsub(guild_id, user_id)
|
await super().unsub(guild_id, user_id)
|
||||||
await self._chan_action('unsub', guild_id, user_id)
|
await self._chan_action("unsub", guild_id, user_id)
|
||||||
|
|
||||||
async def dispatch_filter(self, guild_id: int, func,
|
async def dispatch_filter(self, guild_id: int, func, event: str, data: Any):
|
||||||
event: str, data: Any):
|
|
||||||
"""Selectively dispatch to session ids that have
|
"""Selectively dispatch to session ids that have
|
||||||
func(session_id) true."""
|
func(session_id) true."""
|
||||||
user_ids = self.state[guild_id]
|
user_ids = self.state[guild_id]
|
||||||
|
|
@ -121,31 +117,23 @@ class GuildDispatcher(DispatcherWithFlags):
|
||||||
|
|
||||||
# note that this does not equate to any unsubscription
|
# note that this does not equate to any unsubscription
|
||||||
# of the channel.
|
# of the channel.
|
||||||
if event.startswith('PRESENCE_') and \
|
if event.startswith("PRESENCE_") and not self.flags_get(
|
||||||
not self.flags_get(guild_id, user_id, 'presence', True):
|
guild_id, user_id, "presence", True
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# filter the ones that matter
|
# filter the ones that matter
|
||||||
states = list(filter(
|
states = list(filter(lambda state: func(state.session_id), states))
|
||||||
lambda state: func(state.session_id), states
|
|
||||||
))
|
|
||||||
|
|
||||||
cur_sess = await self._dispatch_states(
|
cur_sess = await self._dispatch_states(states, event, data)
|
||||||
states, event, data)
|
|
||||||
|
|
||||||
sessions.extend(cur_sess)
|
sessions.extend(cur_sess)
|
||||||
dispatched += len(cur_sess)
|
dispatched += len(cur_sess)
|
||||||
|
|
||||||
log.info('Dispatched {} {!r} to {} states',
|
log.info("Dispatched {} {!r} to {} states", guild_id, event, dispatched)
|
||||||
guild_id, event, dispatched)
|
|
||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
async def dispatch(self, guild_id: int,
|
async def dispatch(self, guild_id: int, event: str, data: Any):
|
||||||
event: str, data: Any):
|
|
||||||
"""Dispatch an event to all subscribers of the guild."""
|
"""Dispatch an event to all subscribers of the guild."""
|
||||||
return await self.dispatch_filter(
|
return await self.dispatch_filter(guild_id, lambda sess_id: True, event, data)
|
||||||
guild_id,
|
|
||||||
lambda sess_id: True,
|
|
||||||
event, data,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -22,6 +22,7 @@ from .dispatcher import Dispatcher
|
||||||
|
|
||||||
class MemberDispatcher(Dispatcher):
|
class MemberDispatcher(Dispatcher):
|
||||||
"""Member backend for Pub/Sub."""
|
"""Member backend for Pub/Sub."""
|
||||||
|
|
||||||
KEY_TYPE = tuple
|
KEY_TYPE = tuple
|
||||||
|
|
||||||
async def dispatch(self, key, event, data):
|
async def dispatch(self, key, event, data):
|
||||||
|
|
@ -39,7 +40,7 @@ class MemberDispatcher(Dispatcher):
|
||||||
# if no states were found, we should
|
# if no states were found, we should
|
||||||
# unsub the user from the GUILD channel
|
# unsub the user from the GUILD channel
|
||||||
if not states:
|
if not states:
|
||||||
await self.main_dispatcher.unsub('guild', guild_id, user_id)
|
await self.main_dispatcher.unsub("guild", guild_id, user_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
return await self._dispatch_states(states, event, data)
|
return await self._dispatch_states(states, event, data)
|
||||||
|
|
|
||||||
|
|
@ -22,22 +22,18 @@ from .dispatcher import Dispatcher
|
||||||
|
|
||||||
class UserDispatcher(Dispatcher):
|
class UserDispatcher(Dispatcher):
|
||||||
"""User backend for Pub/Sub."""
|
"""User backend for Pub/Sub."""
|
||||||
|
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
|
|
||||||
async def dispatch_filter(self, user_id: int, func, event, data):
|
async def dispatch_filter(self, user_id: int, func, event, data):
|
||||||
"""Dispatch an event to all shards of a user."""
|
"""Dispatch an event to all shards of a user."""
|
||||||
|
|
||||||
# filter only states where func() gives true
|
# filter only states where func() gives true
|
||||||
states = list(filter(
|
states = list(
|
||||||
lambda state: func(state.session_id),
|
filter(lambda state: func(state.session_id), self.sm.user_states(user_id))
|
||||||
self.sm.user_states(user_id)
|
)
|
||||||
))
|
|
||||||
|
|
||||||
return await self._dispatch_states(states, event, data)
|
return await self._dispatch_states(states, event, data)
|
||||||
|
|
||||||
async def dispatch(self, user_id: int, event, data):
|
async def dispatch(self, user_id: int, event, data):
|
||||||
return await self.dispatch_filter(
|
return await self.dispatch_filter(user_id, lambda sess_id: True, event, data)
|
||||||
user_id,
|
|
||||||
lambda sess_id: True,
|
|
||||||
event, data,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ import time
|
||||||
|
|
||||||
class RatelimitBucket:
|
class RatelimitBucket:
|
||||||
"""Main ratelimit bucket class."""
|
"""Main ratelimit bucket class."""
|
||||||
|
|
||||||
def __init__(self, tokens, second):
|
def __init__(self, tokens, second):
|
||||||
self.requests = tokens
|
self.requests = tokens
|
||||||
self.second = second
|
self.second = second
|
||||||
|
|
@ -88,17 +89,19 @@ class RatelimitBucket:
|
||||||
|
|
||||||
Used to manage multiple ratelimits to users.
|
Used to manage multiple ratelimits to users.
|
||||||
"""
|
"""
|
||||||
return RatelimitBucket(self.requests,
|
return RatelimitBucket(self.requests, self.second)
|
||||||
self.second)
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f'<RatelimitBucket requests={self.requests} '
|
return (
|
||||||
f'second={self.second} window: {self._window} '
|
f"<RatelimitBucket requests={self.requests} "
|
||||||
f'tokens={self._tokens}>')
|
f"second={self.second} window: {self._window} "
|
||||||
|
f"tokens={self._tokens}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Ratelimit:
|
class Ratelimit:
|
||||||
"""Manages buckets."""
|
"""Manages buckets."""
|
||||||
|
|
||||||
def __init__(self, tokens, second, keys=None):
|
def __init__(self, tokens, second, keys=None):
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
if keys is None:
|
if keys is None:
|
||||||
|
|
@ -107,12 +110,11 @@ class Ratelimit:
|
||||||
self._cooldown = RatelimitBucket(tokens, second)
|
self._cooldown = RatelimitBucket(tokens, second)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return (f'<Ratelimit cooldown={self._cooldown}>')
|
return f"<Ratelimit cooldown={self._cooldown}>"
|
||||||
|
|
||||||
def _verify_cache(self):
|
def _verify_cache(self):
|
||||||
current = time.time()
|
current = time.time()
|
||||||
dead_keys = [k for k, v in self._cache.items()
|
dead_keys = [k for k, v in self._cache.items() if current > v._last + v.second]
|
||||||
if current > v._last + v.second]
|
|
||||||
|
|
||||||
for k in dead_keys:
|
for k in dead_keys:
|
||||||
del self._cache[k]
|
del self._cache[k]
|
||||||
|
|
|
||||||
|
|
@ -31,10 +31,10 @@ async def _check_bucket(bucket):
|
||||||
if retry_after:
|
if retry_after:
|
||||||
request.retry_after = retry_after
|
request.retry_after = retry_after
|
||||||
|
|
||||||
raise Ratelimited('You are being rate limited.', {
|
raise Ratelimited(
|
||||||
'retry_after': int(retry_after * 1000),
|
"You are being rate limited.",
|
||||||
'global': request.bucket_global,
|
{"retry_after": int(retry_after * 1000), "global": request.bucket_global},
|
||||||
})
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_global(ratelimit):
|
async def _handle_global(ratelimit):
|
||||||
|
|
@ -59,13 +59,13 @@ async def _handle_specific(ratelimit):
|
||||||
keys = ratelimit.keys
|
keys = ratelimit.keys
|
||||||
|
|
||||||
# base key is the user id
|
# base key is the user id
|
||||||
key_components = [f'user_id:{user_id}']
|
key_components = [f"user_id:{user_id}"]
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
val = request.view_args[key]
|
val = request.view_args[key]
|
||||||
key_components.append(f'{key}:{val}')
|
key_components.append(f"{key}:{val}")
|
||||||
|
|
||||||
bucket_key = ':'.join(key_components)
|
bucket_key = ":".join(key_components)
|
||||||
bucket = ratelimit.get_bucket(bucket_key)
|
bucket = ratelimit.get_bucket(bucket_key)
|
||||||
await _check_bucket(bucket)
|
await _check_bucket(bucket)
|
||||||
|
|
||||||
|
|
@ -78,9 +78,7 @@ async def ratelimit_handler():
|
||||||
rule = request.url_rule
|
rule = request.url_rule
|
||||||
|
|
||||||
if rule is None:
|
if rule is None:
|
||||||
return await _handle_global(
|
return await _handle_global(app.ratelimiter.global_bucket)
|
||||||
app.ratelimiter.global_bucket
|
|
||||||
)
|
|
||||||
|
|
||||||
# rule.endpoint is composed of '<blueprint>.<function>'
|
# rule.endpoint is composed of '<blueprint>.<function>'
|
||||||
# and so we can use that to make routes with different
|
# and so we can use that to make routes with different
|
||||||
|
|
@ -97,6 +95,4 @@ async def ratelimit_handler():
|
||||||
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
||||||
await _handle_specific(ratelimit)
|
await _handle_specific(ratelimit)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
await _handle_global(
|
await _handle_global(app.ratelimiter.global_bucket)
|
||||||
app.ratelimiter.global_bucket
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -34,33 +34,30 @@ WS:
|
||||||
|All Sent Messages| | 120/60s | per-session
|
|All Sent Messages| | 120/60s | per-session
|
||||||
"""
|
"""
|
||||||
|
|
||||||
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
|
REACTION_BUCKET = Ratelimit(1, 0.25, ("channel_id"))
|
||||||
|
|
||||||
RATELIMITS = {
|
RATELIMITS = {
|
||||||
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
|
"channel_messages.create_message": Ratelimit(5, 5, ("channel_id")),
|
||||||
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
|
"channel_messages.delete_message": Ratelimit(5, 1, ("channel_id")),
|
||||||
|
|
||||||
# all of those share the same bucket.
|
# all of those share the same bucket.
|
||||||
'channel_reactions.add_reaction': REACTION_BUCKET,
|
"channel_reactions.add_reaction": REACTION_BUCKET,
|
||||||
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
|
"channel_reactions.remove_own_reaction": REACTION_BUCKET,
|
||||||
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
|
"channel_reactions.remove_user_reaction": REACTION_BUCKET,
|
||||||
|
"guild_members.modify_guild_member": Ratelimit(10, 10, ("guild_id")),
|
||||||
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
|
"guild_members.update_nickname": Ratelimit(1, 1, ("guild_id")),
|
||||||
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
|
|
||||||
|
|
||||||
# this only applies to username.
|
# this only applies to username.
|
||||||
# 'users.patch_me': Ratelimit(2, 3600),
|
# 'users.patch_me': Ratelimit(2, 3600),
|
||||||
|
"_ws.connect": Ratelimit(1, 5),
|
||||||
'_ws.connect': Ratelimit(1, 5),
|
"_ws.presence": Ratelimit(5, 60),
|
||||||
'_ws.presence': Ratelimit(5, 60),
|
"_ws.messages": Ratelimit(120, 60),
|
||||||
'_ws.messages': Ratelimit(120, 60),
|
|
||||||
|
|
||||||
# 1000 / 4h for new session issuing
|
# 1000 / 4h for new session issuing
|
||||||
'_ws.session': Ratelimit(1000, 14400)
|
"_ws.session": Ratelimit(1000, 14400),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RatelimitManager:
|
class RatelimitManager:
|
||||||
"""Manager for the bucket managers"""
|
"""Manager for the bucket managers"""
|
||||||
|
|
||||||
def __init__(self, testing_flag=False):
|
def __init__(self, testing_flag=False):
|
||||||
self._ratelimiters = {}
|
self._ratelimiters = {}
|
||||||
self._test = testing_flag
|
self._test = testing_flag
|
||||||
|
|
@ -74,9 +71,7 @@ class RatelimitManager:
|
||||||
|
|
||||||
# NOTE: this is a bad way to do it, but
|
# NOTE: this is a bad way to do it, but
|
||||||
# we only need to change that one for now.
|
# we only need to change that one for now.
|
||||||
rtl = (Ratelimit(10, 1)
|
rtl = Ratelimit(10, 1) if self._test and path == "_ws.connect" else rtl
|
||||||
if self._test and path == '_ws.connect'
|
|
||||||
else rtl)
|
|
||||||
|
|
||||||
self._ratelimiters[path] = rtl
|
self._ratelimiters[path] = rtl
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,26 +28,30 @@ from .errors import BadRequest
|
||||||
from .permissions import Permissions
|
from .permissions import Permissions
|
||||||
from .types import Color
|
from .types import Color
|
||||||
from .enums import (
|
from .enums import (
|
||||||
ActivityType, StatusType, ExplicitFilter, RelationshipType,
|
ActivityType,
|
||||||
MessageNotifications, ChannelType, VerificationLevel
|
StatusType,
|
||||||
|
ExplicitFilter,
|
||||||
|
RelationshipType,
|
||||||
|
MessageNotifications,
|
||||||
|
ChannelType,
|
||||||
|
VerificationLevel,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
|
from litecord.embed.schemas import EMBED_OBJECT, EmbedURL
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_ ]{2,30}$', re.A)
|
USERNAME_REGEX = re.compile(r"^[a-zA-Z0-9_ ]{2,30}$", re.A)
|
||||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
|
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$", re.A)
|
||||||
re.A)
|
DATA_REGEX = re.compile(r"data\:image/(png|jpeg|gif);base64,(.+)", re.A)
|
||||||
DATA_REGEX = re.compile(r'data\:image/(png|jpeg|gif);base64,(.+)', re.A)
|
|
||||||
|
|
||||||
|
|
||||||
# collection of regexes
|
# collection of regexes
|
||||||
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M)
|
USER_MENTION = re.compile(r"<@!?(\d+)>", re.A | re.M)
|
||||||
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M)
|
CHAN_MENTION = re.compile(r"<#(\d+)>", re.A | re.M)
|
||||||
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M)
|
ROLE_MENTION = re.compile(r"<@&(\d+)>", re.A | re.M)
|
||||||
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
|
EMOJO_MENTION = re.compile(r"<:(\.+):(\d+)>", re.A | re.M)
|
||||||
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
|
ANIMOJI_MENTION = re.compile(r"<a:(\.+):(\d+)>", re.A | re.M)
|
||||||
|
|
||||||
|
|
||||||
def _in_enum(enum, value) -> bool:
|
def _in_enum(enum, value) -> bool:
|
||||||
|
|
@ -61,6 +65,7 @@ def _in_enum(enum, value) -> bool:
|
||||||
|
|
||||||
class LitecordValidator(Validator):
|
class LitecordValidator(Validator):
|
||||||
"""Main validator class for Litecord, containing custom types."""
|
"""Main validator class for Litecord, containing custom types."""
|
||||||
|
|
||||||
def _validate_type_username(self, value: str) -> bool:
|
def _validate_type_username(self, value: str) -> bool:
|
||||||
"""Validate against the username regex."""
|
"""Validate against the username regex."""
|
||||||
return bool(USERNAME_REGEX.match(value))
|
return bool(USERNAME_REGEX.match(value))
|
||||||
|
|
@ -130,8 +135,7 @@ class LitecordValidator(Validator):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# nobody is allowed to use the INCOMING and OUTGOING rel types
|
# nobody is allowed to use the INCOMING and OUTGOING rel types
|
||||||
return val in (RelationshipType.FRIEND.value,
|
return val in (RelationshipType.FRIEND.value, RelationshipType.BLOCK.value)
|
||||||
RelationshipType.BLOCK.value)
|
|
||||||
|
|
||||||
def _validate_type_msg_notifications(self, value: str):
|
def _validate_type_msg_notifications(self, value: str):
|
||||||
try:
|
try:
|
||||||
|
|
@ -152,14 +156,15 @@ class LitecordValidator(Validator):
|
||||||
return self._validate_type_guild_name(value)
|
return self._validate_type_guild_name(value)
|
||||||
|
|
||||||
def _validate_type_theme(self, value: str) -> bool:
|
def _validate_type_theme(self, value: str) -> bool:
|
||||||
return value in ['light', 'dark']
|
return value in ["light", "dark"]
|
||||||
|
|
||||||
def _validate_type_nickname(self, value: str) -> bool:
|
def _validate_type_nickname(self, value: str) -> bool:
|
||||||
return isinstance(value, str) and (len(value) < 32)
|
return isinstance(value, str) and (len(value) < 32)
|
||||||
|
|
||||||
|
|
||||||
def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
|
def validate(
|
||||||
raise_err: bool = True) -> Dict:
|
reqjson: Optional[Union[Dict, List]], schema: Dict, raise_err: bool = True
|
||||||
|
) -> Dict:
|
||||||
"""Validate the given user-given data against a schema, giving the
|
"""Validate the given user-given data against a schema, giving the
|
||||||
"correct" version of the document, with all defaults applied.
|
"correct" version of the document, with all defaults applied.
|
||||||
|
|
||||||
|
|
@ -176,20 +181,20 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
|
||||||
validator = LitecordValidator(schema)
|
validator = LitecordValidator(schema)
|
||||||
|
|
||||||
if reqjson is None:
|
if reqjson is None:
|
||||||
raise BadRequest('No JSON provided')
|
raise BadRequest("No JSON provided")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
valid = validator.validate(reqjson)
|
valid = validator.validate(reqjson)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('Error while validating')
|
log.exception("Error while validating")
|
||||||
raise Exception(f'Error while validating: {reqjson}')
|
raise Exception(f"Error while validating: {reqjson}")
|
||||||
|
|
||||||
if not valid:
|
if not valid:
|
||||||
errs = validator.errors
|
errs = validator.errors
|
||||||
log.warning('Error validating doc {!r}: {!r}', reqjson, errs)
|
log.warning("Error validating doc {!r}: {!r}", reqjson, errs)
|
||||||
|
|
||||||
if raise_err:
|
if raise_err:
|
||||||
raise BadRequest('bad payload', errs)
|
raise BadRequest("bad payload", errs)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -197,554 +202,441 @@ def validate(reqjson: Optional[Union[Dict, List]], schema: Dict,
|
||||||
|
|
||||||
|
|
||||||
REGISTER = {
|
REGISTER = {
|
||||||
'username': {'type': 'username', 'required': True},
|
"username": {"type": "username", "required": True},
|
||||||
'email': {'type': 'email', 'required': False},
|
"email": {"type": "email", "required": False},
|
||||||
'password': {'type': 'password', 'required': False},
|
"password": {"type": "password", "required": False},
|
||||||
|
|
||||||
# invite stands for a guild invite, not an instance invite (that's on
|
# invite stands for a guild invite, not an instance invite (that's on
|
||||||
# the register_with_invite handler).
|
# the register_with_invite handler).
|
||||||
'invite': {'type': 'string', 'required': False, 'nullable': True},
|
"invite": {"type": "string", "required": False, "nullable": True},
|
||||||
|
|
||||||
# following fields only sent by official client, unused by us
|
# following fields only sent by official client, unused by us
|
||||||
'fingerprint': {'type': 'string', 'required': False, 'nullable': True},
|
"fingerprint": {"type": "string", "required": False, "nullable": True},
|
||||||
'captcha_key': {'type': 'string', 'required': False, 'nullable': True},
|
"captcha_key": {"type": "string", "required": False, "nullable": True},
|
||||||
'gift_code_sku_id': {'type': 'string', 'required': False, 'nullable': True},
|
"gift_code_sku_id": {"type": "string", "required": False, "nullable": True},
|
||||||
'consent': {'type': 'boolean', 'required': False},
|
"consent": {"type": "boolean", "required": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
# only used by us, not discord, hence 'invcode' (to separate from discord)
|
# only used by us, not discord, hence 'invcode' (to separate from discord)
|
||||||
REGISTER_WITH_INVITE = {**REGISTER, **{
|
REGISTER_WITH_INVITE = {**REGISTER, **{"invcode": {"type": "string", "required": True}}}
|
||||||
'invcode': {'type': 'string', 'required': True}
|
|
||||||
}}
|
|
||||||
|
|
||||||
|
|
||||||
USER_UPDATE = {
|
USER_UPDATE = {
|
||||||
'username': {
|
"username": {
|
||||||
'type': 'username', 'minlength': 2,
|
"type": "username",
|
||||||
'maxlength': 30, 'required': False},
|
"minlength": 2,
|
||||||
|
"maxlength": 30,
|
||||||
'discriminator': {
|
"required": False,
|
||||||
'type': 'discriminator',
|
|
||||||
'required': False,
|
|
||||||
'nullable': True,
|
|
||||||
},
|
},
|
||||||
|
"discriminator": {"type": "discriminator", "required": False, "nullable": True},
|
||||||
'password': {
|
"password": {"type": "password", "required": False},
|
||||||
'type': 'password', 'required': False,
|
"new_password": {
|
||||||
|
"type": "password",
|
||||||
|
"required": False,
|
||||||
|
"dependencies": "password",
|
||||||
|
"nullable": True,
|
||||||
},
|
},
|
||||||
|
"email": {"type": "email", "required": False, "dependencies": "password"},
|
||||||
'new_password': {
|
"avatar": {
|
||||||
'type': 'password', 'required': False,
|
|
||||||
'dependencies': 'password', 'nullable': True
|
|
||||||
},
|
|
||||||
|
|
||||||
'email': {
|
|
||||||
'type': 'email', 'required': False, 'dependencies': 'password',
|
|
||||||
},
|
|
||||||
|
|
||||||
'avatar': {
|
|
||||||
# can be both b64_icon or string (just the hash)
|
# can be both b64_icon or string (just the hash)
|
||||||
'type': 'string', 'required': False,
|
"type": "string",
|
||||||
'nullable': True
|
"required": False,
|
||||||
|
"nullable": True,
|
||||||
},
|
},
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PARTIAL_ROLE_GUILD_CREATE = {
|
PARTIAL_ROLE_GUILD_CREATE = {
|
||||||
'type': 'dict',
|
"type": "dict",
|
||||||
'schema': {
|
"schema": {
|
||||||
'name': {'type': 'role_name'},
|
"name": {"type": "role_name"},
|
||||||
'color': {'type': 'number', 'default': 0},
|
"color": {"type": "number", "default": 0},
|
||||||
'hoist': {'type': 'boolean', 'default': False},
|
"hoist": {"type": "boolean", "default": False},
|
||||||
|
|
||||||
# NOTE: no position on partial role (on guild create)
|
# NOTE: no position on partial role (on guild create)
|
||||||
|
"permissions": {"coerce": Permissions, "required": False},
|
||||||
'permissions': {'coerce': Permissions, 'required': False},
|
"mentionable": {"type": "boolean", "default": False},
|
||||||
'mentionable': {'type': 'boolean', 'default': False},
|
},
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PARTIAL_CHANNEL_GUILD_CREATE = {
|
PARTIAL_CHANNEL_GUILD_CREATE = {
|
||||||
'type': 'dict',
|
"type": "dict",
|
||||||
'schema': {
|
"schema": {"name": {"type": "channel_name"}, "type": {"type": "channel_type"}},
|
||||||
'name': {'type': 'channel_name'},
|
|
||||||
'type': {'type': 'channel_type'},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GUILD_CREATE = {
|
GUILD_CREATE = {
|
||||||
'name': {'type': 'guild_name'},
|
"name": {"type": "guild_name"},
|
||||||
'region': {'type': 'voice_region', 'nullable': True},
|
"region": {"type": "voice_region", "nullable": True},
|
||||||
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
"icon": {"type": "b64_icon", "required": False, "nullable": True},
|
||||||
|
"verification_level": {"type": "verification_level", "default": 0},
|
||||||
'verification_level': {
|
"default_message_notifications": {"type": "msg_notifications", "default": 0},
|
||||||
'type': 'verification_level', 'default': 0},
|
"explicit_content_filter": {"type": "explicit", "default": 0},
|
||||||
'default_message_notifications': {
|
"roles": {"type": "list", "required": False, "schema": PARTIAL_ROLE_GUILD_CREATE},
|
||||||
'type': 'msg_notifications', 'default': 0},
|
"channels": {"type": "list", "default": [], "schema": PARTIAL_CHANNEL_GUILD_CREATE},
|
||||||
'explicit_content_filter': {
|
|
||||||
'type': 'explicit', 'default': 0},
|
|
||||||
|
|
||||||
'roles': {
|
|
||||||
'type': 'list', 'required': False,
|
|
||||||
'schema': PARTIAL_ROLE_GUILD_CREATE},
|
|
||||||
'channels': {
|
|
||||||
'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
GUILD_UPDATE = {
|
GUILD_UPDATE = {
|
||||||
'name': {
|
"name": {"type": "guild_name", "required": False},
|
||||||
'type': 'guild_name',
|
"region": {"type": "voice_region", "required": False, "nullable": True},
|
||||||
'required': False
|
|
||||||
},
|
|
||||||
'region': {'type': 'voice_region', 'required': False, 'nullable': True},
|
|
||||||
|
|
||||||
# all three can have hashes
|
# all three can have hashes
|
||||||
'icon': {'type': 'string', 'required': False, 'nullable': True},
|
"icon": {"type": "string", "required": False, "nullable": True},
|
||||||
'banner': {'type': 'string', 'required': False, 'nullable': True},
|
"banner": {"type": "string", "required": False, "nullable": True},
|
||||||
'splash': {'type': 'string', 'required': False, 'nullable': True},
|
"splash": {"type": "string", "required": False, "nullable": True},
|
||||||
|
"description": {
|
||||||
'description': {
|
"type": "string",
|
||||||
'type': 'string', 'required': False,
|
"required": False,
|
||||||
'minlength': 1, 'maxlength': 120,
|
"minlength": 1,
|
||||||
'nullable': True
|
"maxlength": 120,
|
||||||
|
"nullable": True,
|
||||||
},
|
},
|
||||||
|
"verification_level": {"type": "verification_level", "required": False},
|
||||||
'verification_level': {
|
"default_message_notifications": {"type": "msg_notifications", "required": False},
|
||||||
'type': 'verification_level', 'required': False},
|
"explicit_content_filter": {"type": "explicit", "required": False},
|
||||||
'default_message_notifications': {
|
"afk_channel_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||||
'type': 'msg_notifications', 'required': False},
|
"afk_timeout": {"type": "number", "required": False},
|
||||||
'explicit_content_filter': {'type': 'explicit', 'required': False},
|
"owner_id": {"type": "snowflake", "required": False},
|
||||||
|
"system_channel_id": {"type": "snowflake", "required": False, "nullable": True},
|
||||||
'afk_channel_id': {
|
|
||||||
'type': 'snowflake', 'required': False, 'nullable': True},
|
|
||||||
'afk_timeout': {'type': 'number', 'required': False},
|
|
||||||
|
|
||||||
'owner_id': {'type': 'snowflake', 'required': False},
|
|
||||||
|
|
||||||
'system_channel_id': {
|
|
||||||
'type': 'snowflake', 'required': False, 'nullable': True},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
CHAN_OVERWRITE = {
|
CHAN_OVERWRITE = {
|
||||||
'id': {'coerce': int},
|
"id": {"coerce": int},
|
||||||
'type': {'type': 'string', 'allowed': ['role', 'member']},
|
"type": {"type": "string", "allowed": ["role", "member"]},
|
||||||
'allow': {'coerce': Permissions},
|
"allow": {"coerce": Permissions},
|
||||||
'deny': {'coerce': Permissions}
|
"deny": {"coerce": Permissions},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
CHAN_CREATE = {
|
CHAN_CREATE = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": True},
|
||||||
'type': 'string', 'minlength': 2,
|
"type": {"type": "channel_type", "default": ChannelType.GUILD_TEXT.value},
|
||||||
'maxlength': 100, 'required': True
|
"position": {"coerce": int, "required": False},
|
||||||
},
|
"topic": {"type": "string", "minlength": 0, "maxlength": 1024, "required": False},
|
||||||
|
"nsfw": {"type": "boolean", "required": False},
|
||||||
'type': {'type': 'channel_type',
|
"rate_limit_per_user": {"coerce": int, "min": 0, "max": 120, "required": False},
|
||||||
'default': ChannelType.GUILD_TEXT.value},
|
"bitrate": {
|
||||||
|
"coerce": int,
|
||||||
'position': {'coerce': int, 'required': False},
|
"min": 8000,
|
||||||
|
|
||||||
'topic': {
|
|
||||||
'type': 'string', 'minlength': 0,
|
|
||||||
'maxlength': 1024, 'required': False},
|
|
||||||
|
|
||||||
'nsfw': {'type': 'boolean', 'required': False},
|
|
||||||
'rate_limit_per_user': {
|
|
||||||
'coerce': int, 'min': 0,
|
|
||||||
'max': 120, 'required': False},
|
|
||||||
|
|
||||||
'bitrate': {
|
|
||||||
'coerce': int, 'min': 8000,
|
|
||||||
|
|
||||||
# NOTE: 'max' is 96000 for non-vip guilds
|
# NOTE: 'max' is 96000 for non-vip guilds
|
||||||
'max': 128000, 'required': False},
|
"max": 128000,
|
||||||
|
"required": False,
|
||||||
'user_limit': {
|
},
|
||||||
|
"user_limit": {
|
||||||
# user_limit being 0 means infinite.
|
# user_limit being 0 means infinite.
|
||||||
'coerce': int, 'min': 0,
|
"coerce": int,
|
||||||
'max': 99, 'required': False
|
"min": 0,
|
||||||
|
"max": 99,
|
||||||
|
"required": False,
|
||||||
},
|
},
|
||||||
|
"permission_overwrites": {
|
||||||
'permission_overwrites': {
|
"type": "list",
|
||||||
'type': 'list',
|
"schema": {"type": "dict", "schema": CHAN_OVERWRITE},
|
||||||
'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE},
|
"required": False,
|
||||||
'required': False
|
|
||||||
},
|
},
|
||||||
|
"parent_id": {"coerce": int, "required": False, "nullable": True},
|
||||||
'parent_id': {'coerce': int, 'required': False, 'nullable': True}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
CHAN_UPDATE = {**CHAN_CREATE, **{
|
CHAN_UPDATE = {
|
||||||
'name': {
|
**CHAN_CREATE,
|
||||||
'type': 'string', 'minlength': 2,
|
**{"name": {"type": "string", "minlength": 2, "maxlength": 100, "required": False}},
|
||||||
'maxlength': 100, 'required': False},
|
}
|
||||||
|
|
||||||
}}
|
|
||||||
|
|
||||||
|
|
||||||
ROLE_CREATE = {
|
ROLE_CREATE = {
|
||||||
'name': {'type': 'string', 'default': 'new role'},
|
"name": {"type": "string", "default": "new role"},
|
||||||
'permissions': {'coerce': Permissions, 'nullable': True},
|
"permissions": {"coerce": Permissions, "nullable": True},
|
||||||
'color': {'coerce': Color, 'default': 0},
|
"color": {"coerce": Color, "default": 0},
|
||||||
'hoist': {'type': 'boolean', 'default': False},
|
"hoist": {"type": "boolean", "default": False},
|
||||||
'mentionable': {'type': 'boolean', 'default': False},
|
"mentionable": {"type": "boolean", "default": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
ROLE_UPDATE = {
|
ROLE_UPDATE = {
|
||||||
'name': {'type': 'string', 'required': False},
|
"name": {"type": "string", "required": False},
|
||||||
'permissions': {'coerce': Permissions, 'required': False},
|
"permissions": {"coerce": Permissions, "required": False},
|
||||||
'color': {'coerce': Color, 'required': False},
|
"color": {"coerce": Color, "required": False},
|
||||||
'hoist': {'type': 'boolean', 'required': False},
|
"hoist": {"type": "boolean", "required": False},
|
||||||
'mentionable': {'type': 'boolean', 'required': False},
|
"mentionable": {"type": "boolean", "required": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ROLE_UPDATE_POSITION = {
|
ROLE_UPDATE_POSITION = {
|
||||||
'roles': {
|
"roles": {
|
||||||
'type': 'list',
|
"type": "list",
|
||||||
'schema': {
|
"schema": {
|
||||||
'type': 'dict',
|
"type": "dict",
|
||||||
'schema': {
|
"schema": {"id": {"coerce": int}, "position": {"coerce": int}},
|
||||||
'id': {'coerce': int},
|
|
||||||
'position': {'coerce': int},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MEMBER_UPDATE = {
|
MEMBER_UPDATE = {
|
||||||
'nick': {
|
"nick": {"type": "nickname", "required": False},
|
||||||
'type': 'nickname', 'required': False},
|
"roles": {"type": "list", "required": False, "schema": {"coerce": int}},
|
||||||
'roles': {'type': 'list', 'required': False,
|
"mute": {"type": "boolean", "required": False},
|
||||||
'schema': {'coerce': int}},
|
"deaf": {"type": "boolean", "required": False},
|
||||||
'mute': {'type': 'boolean', 'required': False},
|
"channel_id": {"type": "snowflake", "required": False},
|
||||||
'deaf': {'type': 'boolean', 'required': False},
|
|
||||||
'channel_id': {'type': 'snowflake', 'required': False},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# NOTE: things such as payload_json are parsed at the handler
|
# NOTE: things such as payload_json are parsed at the handler
|
||||||
# for creating a message.
|
# for creating a message.
|
||||||
MESSAGE_CREATE = {
|
MESSAGE_CREATE = {
|
||||||
'content': {'type': 'string', 'minlength': 0, 'maxlength': 2000},
|
"content": {"type": "string", "minlength": 0, "maxlength": 2000},
|
||||||
'nonce': {'type': 'snowflake', 'required': False},
|
"nonce": {"type": "snowflake", "required": False},
|
||||||
'tts': {'type': 'boolean', 'required': False},
|
"tts": {"type": "boolean", "required": False},
|
||||||
|
"embed": {
|
||||||
'embed': {
|
"type": "dict",
|
||||||
'type': 'dict',
|
"schema": EMBED_OBJECT,
|
||||||
'schema': EMBED_OBJECT,
|
"required": False,
|
||||||
'required': False,
|
"nullable": True,
|
||||||
'nullable': True
|
},
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
GW_ACTIVITY = {
|
GW_ACTIVITY = {
|
||||||
'name': {'type': 'string', 'required': True},
|
"name": {"type": "string", "required": True},
|
||||||
'type': {'type': 'activity_type', 'required': True},
|
"type": {"type": "activity_type", "required": True},
|
||||||
|
"url": {"type": "string", "required": False, "nullable": True},
|
||||||
'url': {'type': 'string', 'required': False, 'nullable': True},
|
"timestamps": {
|
||||||
|
"type": "dict",
|
||||||
'timestamps': {
|
"required": False,
|
||||||
'type': 'dict',
|
"schema": {
|
||||||
'required': False,
|
"start": {"type": "number", "required": False},
|
||||||
'schema': {
|
"end": {"type": "number", "required": False},
|
||||||
'start': {'type': 'number', 'required': False},
|
|
||||||
'end': {'type': 'number', 'required': False},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"application_id": {"type": "snowflake", "required": False, "nullable": False},
|
||||||
'application_id': {'type': 'snowflake', 'required': False,
|
"details": {"type": "string", "required": False, "nullable": True},
|
||||||
'nullable': False},
|
"state": {"type": "string", "required": False, "nullable": True},
|
||||||
'details': {'type': 'string', 'required': False, 'nullable': True},
|
"party": {
|
||||||
'state': {'type': 'string', 'required': False, 'nullable': True},
|
"type": "dict",
|
||||||
|
"required": False,
|
||||||
'party': {
|
"schema": {
|
||||||
'type': 'dict',
|
"id": {"type": "snowflake", "required": False},
|
||||||
'required': False,
|
"size": {"type": "list", "required": False},
|
||||||
'schema': {
|
|
||||||
'id': {'type': 'snowflake', 'required': False},
|
|
||||||
'size': {'type': 'list', 'required': False},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
|
||||||
'assets': {
|
|
||||||
'type': 'dict',
|
|
||||||
'required': False,
|
|
||||||
'schema': {
|
|
||||||
'large_image': {'type': 'snowflake', 'required': False},
|
|
||||||
'large_text': {'type': 'string', 'required': False},
|
|
||||||
'small_image': {'type': 'snowflake', 'required': False},
|
|
||||||
'small_text': {'type': 'string', 'required': False},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
"assets": {
|
||||||
'secrets': {
|
"type": "dict",
|
||||||
'type': 'dict',
|
"required": False,
|
||||||
'required': False,
|
"schema": {
|
||||||
'schema': {
|
"large_image": {"type": "snowflake", "required": False},
|
||||||
'join': {'type': 'string', 'required': False},
|
"large_text": {"type": "string", "required": False},
|
||||||
'spectate': {'type': 'string', 'required': False},
|
"small_image": {"type": "snowflake", "required": False},
|
||||||
'match': {'type': 'string', 'required': False},
|
"small_text": {"type": "string", "required": False},
|
||||||
}
|
|
||||||
},
|
},
|
||||||
|
},
|
||||||
'instance': {'type': 'boolean', 'required': False},
|
"secrets": {
|
||||||
'flags': {'type': 'number', 'required': False},
|
"type": "dict",
|
||||||
|
"required": False,
|
||||||
|
"schema": {
|
||||||
|
"join": {"type": "string", "required": False},
|
||||||
|
"spectate": {"type": "string", "required": False},
|
||||||
|
"match": {"type": "string", "required": False},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"instance": {"type": "boolean", "required": False},
|
||||||
|
"flags": {"type": "number", "required": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
GW_STATUS_UPDATE = {
|
GW_STATUS_UPDATE = {
|
||||||
'status': {'type': 'status_external', 'required': False,
|
"status": {"type": "status_external", "required": False, "default": "online"},
|
||||||
'default': 'online'},
|
"activities": {
|
||||||
'activities': {
|
"type": "list",
|
||||||
'type': 'list', 'required': False,
|
"required": False,
|
||||||
'schema': {'type': 'dict', 'schema': GW_ACTIVITY}
|
"schema": {"type": "dict", "schema": GW_ACTIVITY},
|
||||||
},
|
},
|
||||||
'afk': {'type': 'boolean', 'required': False},
|
"afk": {"type": "boolean", "required": False},
|
||||||
|
"since": {"type": "number", "required": False, "nullable": True},
|
||||||
'since': {'type': 'number', 'required': False, 'nullable': True},
|
"game": {
|
||||||
'game': {
|
"type": "dict",
|
||||||
'type': 'dict',
|
"required": False,
|
||||||
'required': False,
|
"nullable": True,
|
||||||
'nullable': True,
|
"schema": GW_ACTIVITY,
|
||||||
'schema': GW_ACTIVITY,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
INVITE = {
|
INVITE = {
|
||||||
# max_age in seconds
|
# max_age in seconds
|
||||||
# 0 for infinite
|
# 0 for infinite
|
||||||
'max_age': {
|
"max_age": {
|
||||||
'type': 'number',
|
"type": "number",
|
||||||
'min': 0,
|
"min": 0,
|
||||||
'max': 86400,
|
"max": 86400,
|
||||||
|
|
||||||
# a day
|
# a day
|
||||||
'default': 86400
|
"default": 86400,
|
||||||
},
|
},
|
||||||
|
|
||||||
# max invite uses
|
# max invite uses
|
||||||
'max_uses': {
|
"max_uses": {
|
||||||
'type': 'number',
|
"type": "number",
|
||||||
'min': 0,
|
"min": 0,
|
||||||
|
|
||||||
# idk
|
# idk
|
||||||
'max': 1000,
|
"max": 1000,
|
||||||
|
|
||||||
# default infinite
|
# default infinite
|
||||||
'default': 0
|
"default": 0,
|
||||||
},
|
},
|
||||||
|
"temporary": {"type": "boolean", "required": False, "default": False},
|
||||||
'temporary': {'type': 'boolean', 'required': False, 'default': False},
|
"unique": {"type": "boolean", "required": False, "default": True},
|
||||||
'unique': {'type': 'boolean', 'required': False, 'default': True},
|
"validate": {
|
||||||
'validate': {'type': 'string', 'required': False, 'nullable': True} # discord client sends invite code there
|
"type": "string",
|
||||||
|
"required": False,
|
||||||
|
"nullable": True,
|
||||||
|
}, # discord client sends invite code there
|
||||||
}
|
}
|
||||||
|
|
||||||
USER_SETTINGS = {
|
USER_SETTINGS = {
|
||||||
'afk_timeout': {
|
"afk_timeout": {"type": "number", "required": False, "min": 0, "max": 3000},
|
||||||
'type': 'number', 'required': False, 'min': 0, 'max': 3000},
|
"animate_emoji": {"type": "boolean", "required": False},
|
||||||
|
"convert_emoticons": {"type": "boolean", "required": False},
|
||||||
'animate_emoji': {'type': 'boolean', 'required': False},
|
"default_guilds_restricted": {"type": "boolean", "required": False},
|
||||||
'convert_emoticons': {'type': 'boolean', 'required': False},
|
"detect_platform_accounts": {"type": "boolean", "required": False},
|
||||||
'default_guilds_restricted': {'type': 'boolean', 'required': False},
|
"developer_mode": {"type": "boolean", "required": False},
|
||||||
'detect_platform_accounts': {'type': 'boolean', 'required': False},
|
"disable_games_tab": {"type": "boolean", "required": False},
|
||||||
'developer_mode': {'type': 'boolean', 'required': False},
|
"enable_tts_command": {"type": "boolean", "required": False},
|
||||||
'disable_games_tab': {'type': 'boolean', 'required': False},
|
"explicit_content_filter": {"type": "explicit", "required": False},
|
||||||
'enable_tts_command': {'type': 'boolean', 'required': False},
|
"friend_source": {
|
||||||
|
"type": "dict",
|
||||||
'explicit_content_filter': {'type': 'explicit', 'required': False},
|
"required": False,
|
||||||
|
"schema": {
|
||||||
'friend_source': {
|
"all": {"type": "boolean", "required": False},
|
||||||
'type': 'dict',
|
"mutual_guilds": {"type": "boolean", "required": False},
|
||||||
'required': False,
|
"mutual_friends": {"type": "boolean", "required": False},
|
||||||
'schema': {
|
|
||||||
'all': {'type': 'boolean', 'required': False},
|
|
||||||
'mutual_guilds': {'type': 'boolean', 'required': False},
|
|
||||||
'mutual_friends': {'type': 'boolean', 'required': False},
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
'guild_positions': {
|
|
||||||
'type': 'list',
|
|
||||||
'required': False,
|
|
||||||
'schema': {'type': 'snowflake'}
|
|
||||||
},
|
},
|
||||||
'restricted_guilds': {
|
"guild_positions": {
|
||||||
'type': 'list',
|
"type": "list",
|
||||||
'required': False,
|
"required": False,
|
||||||
'schema': {'type': 'snowflake'}
|
"schema": {"type": "snowflake"},
|
||||||
},
|
},
|
||||||
|
"restricted_guilds": {
|
||||||
'gif_auto_play': {'type': 'boolean', 'required': False},
|
"type": "list",
|
||||||
'inline_attachment_media': {'type': 'boolean', 'required': False},
|
"required": False,
|
||||||
'inline_embed_media': {'type': 'boolean', 'required': False},
|
"schema": {"type": "snowflake"},
|
||||||
'message_display_compact': {'type': 'boolean', 'required': False},
|
},
|
||||||
'render_embeds': {'type': 'boolean', 'required': False},
|
"gif_auto_play": {"type": "boolean", "required": False},
|
||||||
'render_reactions': {'type': 'boolean', 'required': False},
|
"inline_attachment_media": {"type": "boolean", "required": False},
|
||||||
'show_current_game': {'type': 'boolean', 'required': False},
|
"inline_embed_media": {"type": "boolean", "required": False},
|
||||||
|
"message_display_compact": {"type": "boolean", "required": False},
|
||||||
'timezone_offset': {'type': 'number', 'required': False},
|
"render_embeds": {"type": "boolean", "required": False},
|
||||||
|
"render_reactions": {"type": "boolean", "required": False},
|
||||||
'status': {'type': 'status_external', 'required': False},
|
"show_current_game": {"type": "boolean", "required": False},
|
||||||
'theme': {'type': 'theme', 'required': False}
|
"timezone_offset": {"type": "number", "required": False},
|
||||||
|
"status": {"type": "status_external", "required": False},
|
||||||
|
"theme": {"type": "theme", "required": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
RELATIONSHIP = {
|
RELATIONSHIP = {
|
||||||
'type': {
|
"type": {
|
||||||
'type': 'rel_type',
|
"type": "rel_type",
|
||||||
'required': False,
|
"required": False,
|
||||||
'default': RelationshipType.FRIEND.value
|
"default": RelationshipType.FRIEND.value,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
CREATE_DM = {
|
CREATE_DM = {"recipient_id": {"type": "snowflake", "required": True}}
|
||||||
'recipient_id': {
|
|
||||||
'type': 'snowflake',
|
|
||||||
'required': True
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CREATE_GROUP_DM = {
|
CREATE_GROUP_DM = {
|
||||||
'recipients': {
|
"recipients": {"type": "list", "required": True, "schema": {"type": "snowflake"}}
|
||||||
'type': 'list',
|
|
||||||
'required': True,
|
|
||||||
'schema': {'type': 'snowflake'}
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GROUP_DM_UPDATE = {
|
GROUP_DM_UPDATE = {
|
||||||
'name': {
|
"name": {"type": "guild_name", "required": False},
|
||||||
'type': 'guild_name',
|
"icon": {"type": "b64_icon", "required": False, "nullable": True},
|
||||||
'required': False
|
|
||||||
},
|
|
||||||
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SPECIFIC_FRIEND = {
|
SPECIFIC_FRIEND = {
|
||||||
'username': {'type': 'username'},
|
"username": {"type": "username"},
|
||||||
'discriminator': {'type': 'discriminator'}
|
"discriminator": {"type": "discriminator"},
|
||||||
}
|
}
|
||||||
|
|
||||||
GUILD_SETTINGS_CHAN_OVERRIDE = {
|
GUILD_SETTINGS_CHAN_OVERRIDE = {
|
||||||
'type': 'dict',
|
"type": "dict",
|
||||||
'schema': {
|
"schema": {
|
||||||
'muted': {
|
"muted": {"type": "boolean", "required": False},
|
||||||
'type': 'boolean', 'required': False},
|
"message_notifications": {"type": "msg_notifications", "required": False},
|
||||||
'message_notifications': {
|
},
|
||||||
'type': 'msg_notifications',
|
|
||||||
'required': False,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GUILD_SETTINGS = {
|
GUILD_SETTINGS = {
|
||||||
'channel_overrides': {
|
"channel_overrides": {
|
||||||
'type': 'dict',
|
"type": "dict",
|
||||||
'valueschema': GUILD_SETTINGS_CHAN_OVERRIDE,
|
"valueschema": GUILD_SETTINGS_CHAN_OVERRIDE,
|
||||||
'keyschema': {'type': 'snowflake'},
|
"keyschema": {"type": "snowflake"},
|
||||||
'required': False,
|
"required": False,
|
||||||
},
|
},
|
||||||
'suppress_everyone': {
|
"suppress_everyone": {"type": "boolean", "required": False},
|
||||||
'type': 'boolean', 'required': False},
|
"muted": {"type": "boolean", "required": False},
|
||||||
'muted': {
|
"mobile_push": {"type": "boolean", "required": False},
|
||||||
'type': 'boolean', 'required': False},
|
"message_notifications": {"type": "msg_notifications", "required": False},
|
||||||
'mobile_push': {
|
|
||||||
'type': 'boolean', 'required': False},
|
|
||||||
'message_notifications': {
|
|
||||||
'type': 'msg_notifications',
|
|
||||||
'required': False,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GUILD_PRUNE = {
|
GUILD_PRUNE = {
|
||||||
'days': {'type': 'number', 'coerce': int, 'min': 1, 'max': 30, 'default': 7},
|
"days": {"type": "number", "coerce": int, "min": 1, "max": 30, "default": 7},
|
||||||
'compute_prune_count': {'type': 'string', 'default': 'true'}
|
"compute_prune_count": {"type": "string", "default": "true"},
|
||||||
}
|
}
|
||||||
|
|
||||||
NEW_EMOJI = {
|
NEW_EMOJI = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True},
|
"image": {"type": "b64_icon", "required": True},
|
||||||
'image': {'type': 'b64_icon', 'required': True},
|
"roles": {"type": "list", "schema": {"coerce": int}},
|
||||||
'roles': {'type': 'list', 'schema': {'coerce': int}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PATCH_EMOJI = {
|
PATCH_EMOJI = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 1, "maxlength": 256, "required": True},
|
||||||
'type': 'string', 'minlength': 1, 'maxlength': 256, 'required': True},
|
"roles": {"type": "list", "schema": {"coerce": int}},
|
||||||
'roles': {'type': 'list', 'schema': {'coerce': int}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
SEARCH_CHANNEL = {
|
SEARCH_CHANNEL = {
|
||||||
'content': {'type': 'string', 'minlength': 1, 'required': True},
|
"content": {"type": "string", "minlength": 1, "required": True},
|
||||||
'include_nsfw': {'coerce': bool, 'default': False},
|
"include_nsfw": {"coerce": bool, "default": False},
|
||||||
'offset': {'coerce': int, 'default': 0}
|
"offset": {"coerce": int, "default": 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
GET_MENTIONS = {
|
GET_MENTIONS = {
|
||||||
'limit': {'coerce': int, 'default': 25},
|
"limit": {"coerce": int, "default": 25},
|
||||||
'roles': {'coerce': bool, 'default': True},
|
"roles": {"coerce": bool, "default": True},
|
||||||
'everyone': {'coerce': bool, 'default': True},
|
"everyone": {"coerce": bool, "default": True},
|
||||||
'guild_id': {'coerce': int, 'required': False}
|
"guild_id": {"coerce": int, "required": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
VANITY_URL_PATCH = {
|
VANITY_URL_PATCH = {
|
||||||
# TODO: put proper values in maybe an invite data type
|
# TODO: put proper values in maybe an invite data type
|
||||||
'code': {'type': 'string', 'minlength': 5, 'maxlength': 30}
|
"code": {"type": "string", "minlength": 5, "maxlength": 30}
|
||||||
}
|
}
|
||||||
|
|
||||||
WEBHOOK_CREATE = {
|
WEBHOOK_CREATE = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 2, "maxlength": 32, "required": True},
|
||||||
'type': 'string', 'minlength': 2, 'maxlength': 32,
|
"avatar": {"type": "b64_icon", "required": False, "nullable": False},
|
||||||
'required': True
|
|
||||||
},
|
|
||||||
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
WEBHOOK_UPDATE = {
|
WEBHOOK_UPDATE = {
|
||||||
'name': {
|
"name": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
|
||||||
'type': 'string', 'minlength': 2, 'maxlength': 32,
|
|
||||||
'required': False
|
|
||||||
},
|
|
||||||
|
|
||||||
# TODO: check if its b64_icon or string since the client
|
# TODO: check if its b64_icon or string since the client
|
||||||
# could pass an icon hash instead.
|
# could pass an icon hash instead.
|
||||||
'avatar': {'type': 'b64_icon', 'required': False, 'nullable': False},
|
"avatar": {"type": "b64_icon", "required": False, "nullable": False},
|
||||||
'channel_id': {'coerce': int, 'required': False, 'nullable': False}
|
"channel_id": {"coerce": int, "required": False, "nullable": False},
|
||||||
}
|
}
|
||||||
|
|
||||||
WEBHOOK_MESSAGE_CREATE = {
|
WEBHOOK_MESSAGE_CREATE = {
|
||||||
'content': {
|
"content": {"type": "string", "minlength": 0, "maxlength": 2000, "required": False},
|
||||||
'type': 'string',
|
"tts": {"type": "boolean", "required": False},
|
||||||
'minlength': 0, 'maxlength': 2000, 'required': False
|
"username": {"type": "string", "minlength": 2, "maxlength": 32, "required": False},
|
||||||
|
"avatar_url": {"coerce": EmbedURL, "required": False},
|
||||||
|
"embeds": {
|
||||||
|
"type": "list",
|
||||||
|
"required": False,
|
||||||
|
"schema": {"type": "dict", "schema": EMBED_OBJECT},
|
||||||
},
|
},
|
||||||
'tts': {'type': 'boolean', 'required': False},
|
|
||||||
|
|
||||||
'username': {
|
|
||||||
'type': 'string',
|
|
||||||
'minlength': 2, 'maxlength': 32, 'required': False
|
|
||||||
},
|
|
||||||
|
|
||||||
'avatar_url': {
|
|
||||||
'coerce': EmbedURL, 'required': False
|
|
||||||
},
|
|
||||||
|
|
||||||
'embeds': {
|
|
||||||
'type': 'list',
|
|
||||||
'required': False,
|
|
||||||
'schema': {'type': 'dict', 'schema': EMBED_OBJECT}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BULK_DELETE = {
|
BULK_DELETE = {
|
||||||
'messages': {
|
"messages": {
|
||||||
'type': 'list', 'required': True,
|
"type": "list",
|
||||||
'minlength': 2, 'maxlength': 100,
|
"required": True,
|
||||||
'schema': {'coerce': int}
|
"minlength": 2,
|
||||||
|
"maxlength": 100,
|
||||||
|
"schema": {"coerce": int},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -61,19 +61,19 @@ def _snowflake(timestamp: int) -> Snowflake:
|
||||||
# bits 0-12 encode _generated_ids (size 12)
|
# bits 0-12 encode _generated_ids (size 12)
|
||||||
|
|
||||||
# modulo'd to prevent overflows
|
# modulo'd to prevent overflows
|
||||||
genid_b = '{0:012b}'.format(_generated_ids % 4096)
|
genid_b = "{0:012b}".format(_generated_ids % 4096)
|
||||||
|
|
||||||
# bits 12-17 encode PROCESS_ID (size 5)
|
# bits 12-17 encode PROCESS_ID (size 5)
|
||||||
procid_b = '{0:05b}'.format(PROCESS_ID)
|
procid_b = "{0:05b}".format(PROCESS_ID)
|
||||||
|
|
||||||
# bits 17-22 encode WORKER_ID (size 5)
|
# bits 17-22 encode WORKER_ID (size 5)
|
||||||
workid_b = '{0:05b}'.format(WORKER_ID)
|
workid_b = "{0:05b}".format(WORKER_ID)
|
||||||
|
|
||||||
# bits 22-64 encode (timestamp - EPOCH) (size 42)
|
# bits 22-64 encode (timestamp - EPOCH) (size 42)
|
||||||
epochized = timestamp - EPOCH
|
epochized = timestamp - EPOCH
|
||||||
epoch_b = '{0:042b}'.format(epochized)
|
epoch_b = "{0:042b}".format(epochized)
|
||||||
|
|
||||||
snowflake_b = f'{epoch_b}{workid_b}{procid_b}{genid_b}'
|
snowflake_b = f"{epoch_b}{workid_b}{procid_b}{genid_b}"
|
||||||
_generated_ids += 1
|
_generated_ids += 1
|
||||||
|
|
||||||
return int(snowflake_b, 2)
|
return int(snowflake_b, 2)
|
||||||
|
|
@ -87,7 +87,7 @@ def snowflake_time(snowflake: Snowflake) -> float:
|
||||||
# the total size for a snowflake is 64 bits,
|
# the total size for a snowflake is 64 bits,
|
||||||
# considering it is a string, position 0 to 42 will give us
|
# considering it is a string, position 0 to 42 will give us
|
||||||
# the `epochized` variable
|
# the `epochized` variable
|
||||||
snowflake_b = '{0:064b}'.format(snowflake)
|
snowflake_b = "{0:064b}".format(snowflake)
|
||||||
epochized_b = snowflake_b[:42]
|
epochized_b = snowflake_b[:42]
|
||||||
epochized = int(epochized_b, 2)
|
epochized = int(epochized_b, 2)
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -24,6 +24,7 @@ from litecord.enums import MessageType
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
||||||
"""Handle a message pin."""
|
"""Handle a message pin."""
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
@ -37,8 +38,10 @@ async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
|
||||||
($1, $2, NULL, $3, NULL, '',
|
($1, $2, NULL, $3, NULL, '',
|
||||||
$4)
|
$4)
|
||||||
""",
|
""",
|
||||||
new_id, channel_id, author_id,
|
new_id,
|
||||||
MessageType.CHANNEL_PINNED_MESSAGE.value
|
channel_id,
|
||||||
|
author_id,
|
||||||
|
MessageType.CHANNEL_PINNED_MESSAGE.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_id
|
return new_id
|
||||||
|
|
@ -56,15 +59,16 @@ async def _handle_recp_add(app, channel_id, author_id, peer_id):
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, NULL, $4, $5)
|
($1, $2, $3, NULL, $4, $5)
|
||||||
""",
|
""",
|
||||||
new_id, channel_id, author_id,
|
new_id,
|
||||||
f'<@{peer_id}>',
|
channel_id,
|
||||||
MessageType.RECIPIENT_ADD.value
|
author_id,
|
||||||
|
f"<@{peer_id}>",
|
||||||
|
MessageType.RECIPIENT_ADD.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
|
@ -76,9 +80,11 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, NULL, $4, $5)
|
($1, $2, $3, NULL, $4, $5)
|
||||||
""",
|
""",
|
||||||
new_id, channel_id, author_id,
|
new_id,
|
||||||
f'<@{peer_id}>',
|
channel_id,
|
||||||
MessageType.RECIPIENT_REMOVE.value
|
author_id,
|
||||||
|
f"<@{peer_id}>",
|
||||||
|
MessageType.RECIPIENT_REMOVE.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_id
|
return new_id
|
||||||
|
|
@ -87,13 +93,16 @@ async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
|
||||||
async def _handle_gdm_name_edit(app, channel_id, author_id):
|
async def _handle_gdm_name_edit(app, channel_id, author_id):
|
||||||
new_id = get_snowflake()
|
new_id = get_snowflake()
|
||||||
|
|
||||||
gdm_name = await app.db.fetchval("""
|
gdm_name = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT name FROM group_dm_channels
|
SELECT name FROM group_dm_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""",
|
||||||
|
channel_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not gdm_name:
|
if not gdm_name:
|
||||||
log.warning('no gdm name found for sys message')
|
log.warning("no gdm name found for sys message")
|
||||||
return
|
return
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
|
|
@ -104,9 +113,11 @@ async def _handle_gdm_name_edit(app, channel_id, author_id):
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, NULL, $4, $5)
|
($1, $2, $3, NULL, $4, $5)
|
||||||
""",
|
""",
|
||||||
new_id, channel_id, author_id,
|
new_id,
|
||||||
|
channel_id,
|
||||||
|
author_id,
|
||||||
gdm_name,
|
gdm_name,
|
||||||
MessageType.CHANNEL_NAME_CHANGE.value
|
MessageType.CHANNEL_NAME_CHANGE.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_id
|
return new_id
|
||||||
|
|
@ -123,16 +134,19 @@ async def _handle_gdm_icon_edit(app, channel_id, author_id):
|
||||||
VALUES
|
VALUES
|
||||||
($1, $2, $3, NULL, $4, $5)
|
($1, $2, $3, NULL, $4, $5)
|
||||||
""",
|
""",
|
||||||
new_id, channel_id, author_id,
|
new_id,
|
||||||
'',
|
channel_id,
|
||||||
MessageType.CHANNEL_ICON_CHANGE.value
|
author_id,
|
||||||
|
"",
|
||||||
|
MessageType.CHANNEL_ICON_CHANGE.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_id
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
async def send_sys_message(app, channel_id: int, m_type: MessageType,
|
async def send_sys_message(
|
||||||
*args, **kwargs) -> int:
|
app, channel_id: int, m_type: MessageType, *args, **kwargs
|
||||||
|
) -> int:
|
||||||
"""Send a system message.
|
"""Send a system message.
|
||||||
|
|
||||||
The handler for a given message type MUST return an integer, that integer
|
The handler for a given message type MUST return an integer, that integer
|
||||||
|
|
@ -156,22 +170,19 @@ async def send_sys_message(app, channel_id: int, m_type: MessageType,
|
||||||
try:
|
try:
|
||||||
handler = {
|
handler = {
|
||||||
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
|
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
|
||||||
|
|
||||||
# gdm specific
|
# gdm specific
|
||||||
MessageType.RECIPIENT_ADD: _handle_recp_add,
|
MessageType.RECIPIENT_ADD: _handle_recp_add,
|
||||||
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
|
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
|
||||||
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
|
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
|
||||||
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit
|
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit,
|
||||||
}[m_type]
|
}[m_type]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise ValueError('Invalid system message type')
|
raise ValueError("Invalid system message type")
|
||||||
|
|
||||||
message_id = await handler(app, channel_id, *args, **kwargs)
|
message_id = await handler(app, channel_id, *args, **kwargs)
|
||||||
|
|
||||||
message = await app.storage.get_message(message_id)
|
message = await app.storage.get_message(message_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch(
|
await app.dispatcher.dispatch("channel", channel_id, "MESSAGE_CREATE", message)
|
||||||
'channel', channel_id, 'MESSAGE_CREATE', message
|
|
||||||
)
|
|
||||||
|
|
||||||
return message_id
|
return message_id
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ HOURS = 60 * MINUTES
|
||||||
|
|
||||||
class Color:
|
class Color:
|
||||||
"""Custom color class"""
|
"""Custom color class"""
|
||||||
|
|
||||||
def __init__(self, val: int):
|
def __init__(self, val: int):
|
||||||
self.blue = val & 255
|
self.blue = val & 255
|
||||||
self.green = (val >> 8) & 255
|
self.green = (val >> 8) & 255
|
||||||
|
|
@ -37,7 +38,7 @@ class Color:
|
||||||
@property
|
@property
|
||||||
def value(self):
|
def value(self):
|
||||||
"""Give the actual RGB integer encoding this color."""
|
"""Give the actual RGB integer encoding this color."""
|
||||||
return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16)
|
return int("%02x%02x%02x" % (self.red, self.green, self.blue), 16)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
|
|
@ -49,4 +50,4 @@ class Color:
|
||||||
|
|
||||||
def timestamp_(dt) -> Optional[str]:
|
def timestamp_(dt) -> Optional[str]:
|
||||||
"""safer version for dt.isoformat()"""
|
"""safer version for dt.isoformat()"""
|
||||||
return f'{dt.isoformat()}+00:00' if dt else None
|
return f"{dt.isoformat()}+00:00" if dt else None
|
||||||
|
|
|
||||||
|
|
@ -27,43 +27,52 @@ log = Logger(__name__)
|
||||||
|
|
||||||
class UserStorage:
|
class UserStorage:
|
||||||
"""Storage functions related to a single user."""
|
"""Storage functions related to a single user."""
|
||||||
|
|
||||||
def __init__(self, storage):
|
def __init__(self, storage):
|
||||||
self.storage = storage
|
self.storage = storage
|
||||||
self.db = storage.db
|
self.db = storage.db
|
||||||
|
|
||||||
async def fetch_notes(self, user_id: int) -> dict:
|
async def fetch_notes(self, user_id: int) -> dict:
|
||||||
"""Fetch a users' notes"""
|
"""Fetch a users' notes"""
|
||||||
note_rows = await self.db.fetch("""
|
note_rows = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT target_id, note
|
SELECT target_id, note
|
||||||
FROM notes
|
FROM notes
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return {str(row['target_id']): row['note']
|
return {str(row["target_id"]): row["note"] for row in note_rows}
|
||||||
for row in note_rows}
|
|
||||||
|
|
||||||
async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
|
async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
|
||||||
"""Get current user settings."""
|
"""Get current user settings."""
|
||||||
row = await self.storage.fetchrow_with_json("""
|
row = await self.storage.fetchrow_with_json(
|
||||||
|
"""
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM user_settings
|
FROM user_settings
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
log.info('Generating user settings for {}', user_id)
|
log.info("Generating user settings for {}", user_id)
|
||||||
|
|
||||||
await self.db.execute("""
|
await self.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO user_settings (id)
|
INSERT INTO user_settings (id)
|
||||||
VALUES ($1)
|
VALUES ($1)
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
# recalling get_user_settings
|
# recalling get_user_settings
|
||||||
# should work after adding
|
# should work after adding
|
||||||
return await self.get_user_settings(user_id)
|
return await self.get_user_settings(user_id)
|
||||||
|
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
drow.pop('id')
|
drow.pop("id")
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]:
|
async def get_relationships(self, user_id: int) -> List[Dict[str, Any]]:
|
||||||
|
|
@ -76,11 +85,15 @@ class UserStorage:
|
||||||
_outgoing = RelationshipType.OUTGOING.value
|
_outgoing = RelationshipType.OUTGOING.value
|
||||||
|
|
||||||
# check all outgoing friends
|
# check all outgoing friends
|
||||||
friends = await self.db.fetch("""
|
friends = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT user_id, peer_id, rel_type
|
SELECT user_id, peer_id, rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND rel_type = $2
|
WHERE user_id = $1 AND rel_type = $2
|
||||||
""", user_id, _friend)
|
""",
|
||||||
|
user_id,
|
||||||
|
_friend,
|
||||||
|
)
|
||||||
friends = list(map(dict, friends))
|
friends = list(map(dict, friends))
|
||||||
|
|
||||||
# mutuals is a list of ints
|
# mutuals is a list of ints
|
||||||
|
|
@ -95,66 +108,80 @@ class UserStorage:
|
||||||
SELECT user_id, peer_id
|
SELECT user_id, peer_id
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
WHERE user_id = $1 AND peer_id = $2 AND rel_type = $3
|
||||||
""", row['peer_id'], row['user_id'],
|
""",
|
||||||
_friend)
|
row["peer_id"],
|
||||||
|
row["user_id"],
|
||||||
|
_friend,
|
||||||
|
)
|
||||||
|
|
||||||
if is_friend is not None:
|
if is_friend is not None:
|
||||||
mutuals.append(row['peer_id'])
|
mutuals.append(row["peer_id"])
|
||||||
|
|
||||||
# fetch friend requests directed at us
|
# fetch friend requests directed at us
|
||||||
incoming_friends = await self.db.fetch("""
|
incoming_friends = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT user_id, peer_id
|
SELECT user_id, peer_id
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE peer_id = $1 AND rel_type = $2
|
WHERE peer_id = $1 AND rel_type = $2
|
||||||
""", user_id, _friend)
|
""",
|
||||||
|
user_id,
|
||||||
|
_friend,
|
||||||
|
)
|
||||||
|
|
||||||
# only need their ids
|
# only need their ids
|
||||||
incoming_friends = [r['user_id'] for r in incoming_friends
|
incoming_friends = [
|
||||||
if r['user_id'] not in mutuals]
|
r["user_id"] for r in incoming_friends if r["user_id"] not in mutuals
|
||||||
|
]
|
||||||
|
|
||||||
# only fetch blocks we did,
|
# only fetch blocks we did,
|
||||||
# not fetching the ones people did to us
|
# not fetching the ones people did to us
|
||||||
blocks = await self.db.fetch("""
|
blocks = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT user_id, peer_id, rel_type
|
SELECT user_id, peer_id, rel_type
|
||||||
FROM relationships
|
FROM relationships
|
||||||
WHERE user_id = $1 AND rel_type = $2
|
WHERE user_id = $1 AND rel_type = $2
|
||||||
""", user_id, _block)
|
""",
|
||||||
|
user_id,
|
||||||
|
_block,
|
||||||
|
)
|
||||||
blocks = list(map(dict, blocks))
|
blocks = list(map(dict, blocks))
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for drow in friends:
|
for drow in friends:
|
||||||
drow['type'] = drow['rel_type']
|
drow["type"] = drow["rel_type"]
|
||||||
drow['id'] = str(drow['peer_id'])
|
drow["id"] = str(drow["peer_id"])
|
||||||
drow.pop('rel_type')
|
drow.pop("rel_type")
|
||||||
|
|
||||||
# check if the receiver is a mutual
|
# check if the receiver is a mutual
|
||||||
# if it isnt, its still on a friend request stage
|
# if it isnt, its still on a friend request stage
|
||||||
if drow['peer_id'] not in mutuals:
|
if drow["peer_id"] not in mutuals:
|
||||||
drow['type'] = _outgoing
|
drow["type"] = _outgoing
|
||||||
|
|
||||||
drow['user'] = await self.storage.get_user(drow['peer_id'])
|
drow["user"] = await self.storage.get_user(drow["peer_id"])
|
||||||
|
|
||||||
drow.pop('user_id')
|
drow.pop("user_id")
|
||||||
drow.pop('peer_id')
|
drow.pop("peer_id")
|
||||||
res.append(drow)
|
res.append(drow)
|
||||||
|
|
||||||
for peer_id in incoming_friends:
|
for peer_id in incoming_friends:
|
||||||
res.append({
|
res.append(
|
||||||
'id': str(peer_id),
|
{
|
||||||
'user': await self.storage.get_user(peer_id),
|
"id": str(peer_id),
|
||||||
'type': _incoming,
|
"user": await self.storage.get_user(peer_id),
|
||||||
})
|
"type": _incoming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
for drow in blocks:
|
for drow in blocks:
|
||||||
drow['type'] = drow['rel_type']
|
drow["type"] = drow["rel_type"]
|
||||||
drow.pop('rel_type')
|
drow.pop("rel_type")
|
||||||
|
|
||||||
drow['id'] = str(drow['peer_id'])
|
drow["id"] = str(drow["peer_id"])
|
||||||
drow['user'] = await self.storage.get_user(drow['peer_id'])
|
drow["user"] = await self.storage.get_user(drow["peer_id"])
|
||||||
|
|
||||||
drow.pop('user_id')
|
drow.pop("user_id")
|
||||||
drow.pop('peer_id')
|
drow.pop("peer_id")
|
||||||
res.append(drow)
|
res.append(drow)
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
@ -163,9 +190,11 @@ class UserStorage:
|
||||||
"""Get all friend IDs for a user."""
|
"""Get all friend IDs for a user."""
|
||||||
rels = await self.get_relationships(user_id)
|
rels = await self.get_relationships(user_id)
|
||||||
|
|
||||||
return [int(r['user']['id'])
|
return [
|
||||||
|
int(r["user"]["id"])
|
||||||
for r in rels
|
for r in rels
|
||||||
if r['type'] == RelationshipType.FRIEND.value]
|
if r["type"] == RelationshipType.FRIEND.value
|
||||||
|
]
|
||||||
|
|
||||||
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
|
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
|
||||||
"""Get all DM channels for a user, including group DMs.
|
"""Get all DM channels for a user, including group DMs.
|
||||||
|
|
@ -173,13 +202,16 @@ class UserStorage:
|
||||||
This will only fetch channels the user has in their state,
|
This will only fetch channels the user has in their state,
|
||||||
which is different than the whole list of DM channels.
|
which is different than the whole list of DM channels.
|
||||||
"""
|
"""
|
||||||
dm_ids = await self.db.fetch("""
|
dm_ids = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT dm_id
|
SELECT dm_id
|
||||||
FROM dm_channel_state
|
FROM dm_channel_state
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
dm_ids = [r['dm_id'] for r in dm_ids]
|
dm_ids = [r["dm_id"] for r in dm_ids]
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
|
|
@ -191,21 +223,24 @@ class UserStorage:
|
||||||
|
|
||||||
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
|
async def get_read_state(self, user_id: int) -> List[Dict[str, Any]]:
|
||||||
"""Get the read state for a user."""
|
"""Get the read state for a user."""
|
||||||
rows = await self.db.fetch("""
|
rows = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT channel_id, last_message_id, mention_count
|
SELECT channel_id, last_message_id, mention_count
|
||||||
FROM user_read_state
|
FROM user_read_state
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
drow['id'] = str(drow['channel_id'])
|
drow["id"] = str(drow["channel_id"])
|
||||||
drow.pop('channel_id')
|
drow.pop("channel_id")
|
||||||
|
|
||||||
drow['last_message_id'] = str(drow['last_message_id'])
|
drow["last_message_id"] = str(drow["last_message_id"])
|
||||||
|
|
||||||
res.append(drow)
|
res.append(drow)
|
||||||
|
|
||||||
|
|
@ -214,13 +249,17 @@ class UserStorage:
|
||||||
async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List:
|
async def _get_chan_overrides(self, user_id: int, guild_id: int) -> List:
|
||||||
chan_overrides = []
|
chan_overrides = []
|
||||||
|
|
||||||
overrides = await self.db.fetch("""
|
overrides = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT channel_id::text, muted, message_notifications
|
SELECT channel_id::text, muted, message_notifications
|
||||||
FROM guild_settings_channel_overrides
|
FROM guild_settings_channel_overrides
|
||||||
WHERE
|
WHERE
|
||||||
user_id = $1
|
user_id = $1
|
||||||
AND guild_id = $2
|
AND guild_id = $2
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
for chan_row in overrides:
|
for chan_row in overrides:
|
||||||
dcrow = dict(chan_row)
|
dcrow = dict(chan_row)
|
||||||
|
|
@ -228,30 +267,35 @@ class UserStorage:
|
||||||
|
|
||||||
return chan_overrides
|
return chan_overrides
|
||||||
|
|
||||||
async def get_guild_settings_one(self, user_id: int,
|
async def get_guild_settings_one(self, user_id: int, guild_id: int) -> dict:
|
||||||
guild_id: int) -> dict:
|
|
||||||
"""Get guild settings information for a single guild."""
|
"""Get guild settings information for a single guild."""
|
||||||
row = await self.db.fetchrow("""
|
row = await self.db.fetchrow(
|
||||||
|
"""
|
||||||
SELECT guild_id::text, suppress_everyone, muted,
|
SELECT guild_id::text, suppress_everyone, muted,
|
||||||
message_notifications, mobile_push
|
message_notifications, mobile_push
|
||||||
FROM guild_settings
|
FROM guild_settings
|
||||||
WHERE user_id = $1 AND guild_id = $2
|
WHERE user_id = $1 AND guild_id = $2
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
if not row:
|
if not row:
|
||||||
await self.db.execute("""
|
await self.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO guild_settings (user_id, guild_id)
|
INSERT INTO guild_settings (user_id, guild_id)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", user_id, guild_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
return await self.get_guild_settings_one(user_id, guild_id)
|
return await self.get_guild_settings_one(user_id, guild_id)
|
||||||
|
|
||||||
gid = int(row['guild_id'])
|
gid = int(row["guild_id"])
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
chan_overrides = await self._get_chan_overrides(user_id, gid)
|
chan_overrides = await self._get_chan_overrides(user_id, gid)
|
||||||
return {**drow, **{
|
return {**drow, **{"channel_overrides": chan_overrides}}
|
||||||
'channel_overrides': chan_overrides
|
|
||||||
}}
|
|
||||||
|
|
||||||
async def get_guild_settings(self, user_id: int):
|
async def get_guild_settings(self, user_id: int):
|
||||||
"""Get the specific User Guild Settings,
|
"""Get the specific User Guild Settings,
|
||||||
|
|
@ -259,34 +303,38 @@ class UserStorage:
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
settings = await self.db.fetch("""
|
settings = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT guild_id::text, suppress_everyone, muted,
|
SELECT guild_id::text, suppress_everyone, muted,
|
||||||
message_notifications, mobile_push
|
message_notifications, mobile_push
|
||||||
FROM guild_settings
|
FROM guild_settings
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
for row in settings:
|
for row in settings:
|
||||||
gid = int(row['guild_id'])
|
gid = int(row["guild_id"])
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
chan_overrides = await self._get_chan_overrides(user_id, gid)
|
chan_overrides = await self._get_chan_overrides(user_id, gid)
|
||||||
|
|
||||||
res.append({**drow, **{
|
res.append({**drow, **{"channel_overrides": chan_overrides}})
|
||||||
'channel_overrides': chan_overrides
|
|
||||||
}})
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_user_guilds(self, user_id: int) -> List[int]:
|
async def get_user_guilds(self, user_id: int) -> List[int]:
|
||||||
"""Get all guild IDs a user is on."""
|
"""Get all guild IDs a user is on."""
|
||||||
guild_ids = await self.db.fetch("""
|
guild_ids = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
FROM members
|
FROM members
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [row['guild_id'] for row in guild_ids]
|
return [row["guild_id"] for row in guild_ids]
|
||||||
|
|
||||||
async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]:
|
async def get_mutual_guilds(self, user_id: int, peer_id: int) -> List[int]:
|
||||||
"""Get a list of guilds two separate users
|
"""Get a list of guilds two separate users
|
||||||
|
|
@ -301,13 +349,17 @@ class UserStorage:
|
||||||
|
|
||||||
return await self.get_user_guilds(user_id) or [0]
|
return await self.get_user_guilds(user_id) or [0]
|
||||||
|
|
||||||
mutual_guilds = await self.db.fetch("""
|
mutual_guilds = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT guild_id FROM members WHERE user_id = $1
|
SELECT guild_id FROM members WHERE user_id = $1
|
||||||
INTERSECT
|
INTERSECT
|
||||||
SELECT guild_id FROM members WHERE user_id = $2
|
SELECT guild_id FROM members WHERE user_id = $2
|
||||||
""", user_id, peer_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
mutual_guilds = [r['guild_id'] for r in mutual_guilds]
|
mutual_guilds = [r["guild_id"] for r in mutual_guilds]
|
||||||
|
|
||||||
return mutual_guilds
|
return mutual_guilds
|
||||||
|
|
||||||
|
|
@ -316,7 +368,8 @@ class UserStorage:
|
||||||
|
|
||||||
This returns false even if there is a friend request.
|
This returns false even if there is a friend request.
|
||||||
"""
|
"""
|
||||||
return await self.db.fetchval("""
|
return await self.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT
|
SELECT
|
||||||
(
|
(
|
||||||
SELECT EXISTS(
|
SELECT EXISTS(
|
||||||
|
|
@ -337,17 +390,23 @@ class UserStorage:
|
||||||
AND rel_type = 1
|
AND rel_type = 1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
""", user_id, peer_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
peer_id,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_gdms_internal(self, user_id) -> List[int]:
|
async def get_gdms_internal(self, user_id) -> List[int]:
|
||||||
"""Return a list of Group DM IDs the user is a member of."""
|
"""Return a list of Group DM IDs the user is a member of."""
|
||||||
rows = await self.db.fetch("""
|
rows = await self.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM group_dm_members
|
FROM group_dm_members
|
||||||
WHERE member_id = $1
|
WHERE member_id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
return [r['id'] for r in rows]
|
return [r["id"] for r in rows]
|
||||||
|
|
||||||
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
|
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
|
||||||
"""Get list of group DMs a user is in."""
|
"""Get list of group DMs a user is in."""
|
||||||
|
|
@ -356,8 +415,6 @@ class UserStorage:
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for gdm_id in gdm_ids:
|
for gdm_id in gdm_ids:
|
||||||
res.append(
|
res.append(await self.storage.get_channel(gdm_id, user_id=user_id))
|
||||||
await self.storage.get_channel(gdm_id, user_id=user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ async def task_wrapper(name: str, coro):
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except:
|
except:
|
||||||
log.exception('{} task error', name)
|
log.exception("{} task error", name)
|
||||||
|
|
||||||
|
|
||||||
def dict_get(mapping, key, default):
|
def dict_get(mapping, key, default):
|
||||||
|
|
@ -84,54 +84,66 @@ def mmh3(inp_str: str, seed: int = 0):
|
||||||
h1 = seed
|
h1 = seed
|
||||||
|
|
||||||
# mm3 constants
|
# mm3 constants
|
||||||
c1 = 0xcc9e2d51
|
c1 = 0xCC9E2D51
|
||||||
c2 = 0x1b873593
|
c2 = 0x1B873593
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
while i < bytecount:
|
while i < bytecount:
|
||||||
k1 = (
|
k1 = (
|
||||||
(key[i] & 0xff) |
|
(key[i] & 0xFF)
|
||||||
((key[i + 1] & 0xff) << 8) |
|
| ((key[i + 1] & 0xFF) << 8)
|
||||||
((key[i + 2] & 0xff) << 16) |
|
| ((key[i + 2] & 0xFF) << 16)
|
||||||
((key[i + 3] & 0xff) << 24)
|
| ((key[i + 3] & 0xFF) << 24)
|
||||||
)
|
)
|
||||||
|
|
||||||
i += 4
|
i += 4
|
||||||
|
|
||||||
k1 = ((((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16))) & 0xffffffff
|
k1 = (
|
||||||
|
(((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16))
|
||||||
|
) & 0xFFFFFFFF
|
||||||
k1 = (k1 << 15) | (_u(k1) >> 17)
|
k1 = (k1 << 15) | (_u(k1) >> 17)
|
||||||
k1 = ((((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16))) & 0xffffffff;
|
k1 = (
|
||||||
|
(((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16))
|
||||||
|
) & 0xFFFFFFFF
|
||||||
|
|
||||||
h1 ^= k1
|
h1 ^= k1
|
||||||
h1 = (h1 << 13) | (_u(h1) >> 19);
|
h1 = (h1 << 13) | (_u(h1) >> 19)
|
||||||
h1b = ((((h1 & 0xffff) * 5) + ((((_u(h1) >> 16) * 5) & 0xffff) << 16))) & 0xffffffff;
|
h1b = (
|
||||||
h1 = (((h1b & 0xffff) + 0x6b64) + ((((_u(h1b) >> 16) + 0xe654) & 0xffff) << 16))
|
(((h1 & 0xFFFF) * 5) + ((((_u(h1) >> 16) * 5) & 0xFFFF) << 16))
|
||||||
|
) & 0xFFFFFFFF
|
||||||
|
h1 = ((h1b & 0xFFFF) + 0x6B64) + ((((_u(h1b) >> 16) + 0xE654) & 0xFFFF) << 16)
|
||||||
|
|
||||||
k1 = 0
|
k1 = 0
|
||||||
v = None
|
v = None
|
||||||
|
|
||||||
if remainder == 3:
|
if remainder == 3:
|
||||||
v = (key[i + 2] & 0xff) << 16
|
v = (key[i + 2] & 0xFF) << 16
|
||||||
elif remainder == 2:
|
elif remainder == 2:
|
||||||
v = (key[i + 1] & 0xff) << 8
|
v = (key[i + 1] & 0xFF) << 8
|
||||||
elif remainder == 1:
|
elif remainder == 1:
|
||||||
v = (key[i] & 0xff)
|
v = key[i] & 0xFF
|
||||||
|
|
||||||
if v is not None:
|
if v is not None:
|
||||||
k1 ^= v
|
k1 ^= v
|
||||||
|
|
||||||
k1 = (((k1 & 0xffff) * c1) + ((((_u(k1) >> 16) * c1) & 0xffff) << 16)) & 0xffffffff
|
k1 = (((k1 & 0xFFFF) * c1) + ((((_u(k1) >> 16) * c1) & 0xFFFF) << 16)) & 0xFFFFFFFF
|
||||||
k1 = (k1 << 15) | (_u(k1) >> 17)
|
k1 = (k1 << 15) | (_u(k1) >> 17)
|
||||||
k1 = (((k1 & 0xffff) * c2) + ((((_u(k1) >> 16) * c2) & 0xffff) << 16)) & 0xffffffff
|
k1 = (((k1 & 0xFFFF) * c2) + ((((_u(k1) >> 16) * c2) & 0xFFFF) << 16)) & 0xFFFFFFFF
|
||||||
h1 ^= k1
|
h1 ^= k1
|
||||||
|
|
||||||
h1 ^= len(key)
|
h1 ^= len(key)
|
||||||
|
|
||||||
h1 ^= _u(h1) >> 16
|
h1 ^= _u(h1) >> 16
|
||||||
h1 = (((h1 & 0xffff) * 0x85ebca6b) + ((((_u(h1) >> 16) * 0x85ebca6b) & 0xffff) << 16)) & 0xffffffff
|
h1 = (
|
||||||
|
((h1 & 0xFFFF) * 0x85EBCA6B) + ((((_u(h1) >> 16) * 0x85EBCA6B) & 0xFFFF) << 16)
|
||||||
|
) & 0xFFFFFFFF
|
||||||
h1 ^= _u(h1) >> 13
|
h1 ^= _u(h1) >> 13
|
||||||
h1 = ((((h1 & 0xffff) * 0xc2b2ae35) + ((((_u(h1) >> 16) * 0xc2b2ae35) & 0xffff) << 16))) & 0xffffffff
|
h1 = (
|
||||||
|
(
|
||||||
|
((h1 & 0xFFFF) * 0xC2B2AE35)
|
||||||
|
+ ((((_u(h1) >> 16) * 0xC2B2AE35) & 0xFFFF) << 16)
|
||||||
|
)
|
||||||
|
) & 0xFFFFFFFF
|
||||||
h1 ^= _u(h1) >> 16
|
h1 ^= _u(h1) >> 16
|
||||||
|
|
||||||
return _u(h1) >> 0
|
return _u(h1) >> 0
|
||||||
|
|
@ -139,6 +151,7 @@ def mmh3(inp_str: str, seed: int = 0):
|
||||||
|
|
||||||
class LitecordJSONEncoder(JSONEncoder):
|
class LitecordJSONEncoder(JSONEncoder):
|
||||||
"""Custom JSON encoder for Litecord."""
|
"""Custom JSON encoder for Litecord."""
|
||||||
|
|
||||||
def default(self, value: Any):
|
def default(self, value: Any):
|
||||||
"""By default, this will try to get the to_json attribute of a given
|
"""By default, this will try to get the to_json attribute of a given
|
||||||
value being JSON encoded."""
|
value being JSON encoded."""
|
||||||
|
|
@ -151,17 +164,17 @@ class LitecordJSONEncoder(JSONEncoder):
|
||||||
async def pg_set_json(con):
|
async def pg_set_json(con):
|
||||||
"""Set JSON and JSONB codecs for an asyncpg connection."""
|
"""Set JSON and JSONB codecs for an asyncpg connection."""
|
||||||
await con.set_type_codec(
|
await con.set_type_codec(
|
||||||
'json',
|
"json",
|
||||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
decoder=json.loads,
|
decoder=json.loads,
|
||||||
schema='pg_catalog'
|
schema="pg_catalog",
|
||||||
)
|
)
|
||||||
|
|
||||||
await con.set_type_codec(
|
await con.set_type_codec(
|
||||||
'jsonb',
|
"jsonb",
|
||||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
decoder=json.loads,
|
decoder=json.loads,
|
||||||
schema='pg_catalog'
|
schema="pg_catalog",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -179,6 +192,7 @@ def yield_chunks(input_list: Sequence[Any], chunk_size: int):
|
||||||
for idx in range(0, len(input_list), chunk_size):
|
for idx in range(0, len(input_list), chunk_size):
|
||||||
yield input_list[idx : idx + chunk_size]
|
yield input_list[idx : idx + chunk_size]
|
||||||
|
|
||||||
|
|
||||||
def to_update(j: dict, orig: dict, field: str) -> bool:
|
def to_update(j: dict, orig: dict, field: str) -> bool:
|
||||||
"""Compare values to check if j[field] is actually updating
|
"""Compare values to check if j[field] is actually updating
|
||||||
the value in orig[field]. Useful for icon checks."""
|
the value in orig[field]. Useful for icon checks."""
|
||||||
|
|
@ -193,27 +207,23 @@ async def search_result_from_list(rows: List) -> Dict[str, Any]:
|
||||||
- An int (?) on `total_results`
|
- An int (?) on `total_results`
|
||||||
- Two bigint[], each on `before` and `after` respectively.
|
- Two bigint[], each on `before` and `after` respectively.
|
||||||
"""
|
"""
|
||||||
results = 0 if not rows else rows[0]['total_results']
|
results = 0 if not rows else rows[0]["total_results"]
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
before, after = [], []
|
before, after = [], []
|
||||||
|
|
||||||
for before_id in reversed(row['before']):
|
for before_id in reversed(row["before"]):
|
||||||
before.append(await app.storage.get_message(before_id))
|
before.append(await app.storage.get_message(before_id))
|
||||||
|
|
||||||
for after_id in row['after']:
|
for after_id in row["after"]:
|
||||||
after.append(await app.storage.get_message(after_id))
|
after.append(await app.storage.get_message(after_id))
|
||||||
|
|
||||||
msg = await app.storage.get_message(row['current_id'])
|
msg = await app.storage.get_message(row["current_id"])
|
||||||
msg['hit'] = True
|
msg["hit"] = True
|
||||||
res.append(before + [msg] + after)
|
res.append(before + [msg] + after)
|
||||||
|
|
||||||
return {
|
return {"total_results": results, "messages": res, "analytics_id": ""}
|
||||||
'total_results': results,
|
|
||||||
'messages': res,
|
|
||||||
'analytics_id': '',
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_int(val: Any) -> Union[int, Any]:
|
def maybe_int(val: Any) -> Union[int, Any]:
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ log = Logger(__name__)
|
||||||
|
|
||||||
class LVSPConnection:
|
class LVSPConnection:
|
||||||
"""Represents a single LVSP connection."""
|
"""Represents a single LVSP connection."""
|
||||||
|
|
||||||
def __init__(self, lvsp, region: str, hostname: str):
|
def __init__(self, lvsp, region: str, hostname: str):
|
||||||
self.lvsp = lvsp
|
self.lvsp = lvsp
|
||||||
self.app = lvsp.app
|
self.app = lvsp.app
|
||||||
|
|
@ -46,7 +47,7 @@ class LVSPConnection:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _log_id(self):
|
def _log_id(self):
|
||||||
return f'region={self.region} hostname={self.hostname}'
|
return f"region={self.region} hostname={self.hostname}"
|
||||||
|
|
||||||
async def send(self, payload):
|
async def send(self, payload):
|
||||||
"""Send a payload down the websocket."""
|
"""Send a payload down the websocket."""
|
||||||
|
|
@ -61,50 +62,42 @@ class LVSPConnection:
|
||||||
|
|
||||||
async def send_op(self, opcode: int, data: dict):
|
async def send_op(self, opcode: int, data: dict):
|
||||||
"""Send a message with an OP code included"""
|
"""Send a message with an OP code included"""
|
||||||
await self.send({
|
await self.send({"op": opcode, "d": data})
|
||||||
'op': opcode,
|
|
||||||
'd': data
|
|
||||||
})
|
|
||||||
|
|
||||||
async def send_info(self, info_type: str, info_data: Dict):
|
async def send_info(self, info_type: str, info_data: Dict):
|
||||||
"""Send an INFO message down the websocket."""
|
"""Send an INFO message down the websocket."""
|
||||||
await self.send({
|
await self.send(
|
||||||
'op': OP.info,
|
{
|
||||||
'd': {
|
"op": OP.info,
|
||||||
'type': InfoTable[info_type.upper()],
|
"d": {"type": InfoTable[info_type.upper()], "data": info_data},
|
||||||
'data': info_data
|
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
async def _heartbeater(self, hb_interval: int):
|
async def _heartbeater(self, hb_interval: int):
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(hb_interval)
|
await asyncio.sleep(hb_interval)
|
||||||
|
|
||||||
# TODO: add self._seq
|
# TODO: add self._seq
|
||||||
await self.send_op(OP.heartbeat, {
|
await self.send_op(OP.heartbeat, {"s": 0})
|
||||||
's': 0
|
|
||||||
})
|
|
||||||
|
|
||||||
# give the server 300 milliseconds to reply.
|
# give the server 300 milliseconds to reply.
|
||||||
await asyncio.sleep(300)
|
await asyncio.sleep(300)
|
||||||
await self.conn.close(4000, 'heartbeat timeout')
|
await self.conn.close(4000, "heartbeat timeout")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _start_hb(self):
|
def _start_hb(self):
|
||||||
self._hb_task = self.app.loop.create_task(
|
self._hb_task = self.app.loop.create_task(self._heartbeater(self._hb_interval))
|
||||||
self._heartbeater(self._hb_interval)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _stop_hb(self):
|
def _stop_hb(self):
|
||||||
self._hb_task.cancel()
|
self._hb_task.cancel()
|
||||||
|
|
||||||
async def _handle_0(self, msg):
|
async def _handle_0(self, msg):
|
||||||
"""Handle HELLO message."""
|
"""Handle HELLO message."""
|
||||||
data = msg['d']
|
data = msg["d"]
|
||||||
|
|
||||||
# nonce = data['nonce']
|
# nonce = data['nonce']
|
||||||
self._hb_interval = data['heartbeat_interval']
|
self._hb_interval = data["heartbeat_interval"]
|
||||||
|
|
||||||
# TODO: send identify
|
# TODO: send identify
|
||||||
|
|
||||||
|
|
@ -112,48 +105,52 @@ class LVSPConnection:
|
||||||
"""Update the health value of a given voice server."""
|
"""Update the health value of a given voice server."""
|
||||||
self.health = new_health
|
self.health = new_health
|
||||||
|
|
||||||
await self.app.db.execute("""
|
await self.app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE voice_servers
|
UPDATE voice_servers
|
||||||
SET health = $1
|
SET health = $1
|
||||||
WHERE hostname = $2
|
WHERE hostname = $2
|
||||||
""", new_health, self.hostname)
|
""",
|
||||||
|
new_health,
|
||||||
|
self.hostname,
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_3(self, msg):
|
async def _handle_3(self, msg):
|
||||||
"""Handle READY message.
|
"""Handle READY message.
|
||||||
|
|
||||||
We only start heartbeating after READY.
|
We only start heartbeating after READY.
|
||||||
"""
|
"""
|
||||||
await self._update_health(msg['health'])
|
await self._update_health(msg["health"])
|
||||||
self._start_hb()
|
self._start_hb()
|
||||||
|
|
||||||
async def _handle_5(self, msg):
|
async def _handle_5(self, msg):
|
||||||
"""Handle HEARTBEAT_ACK."""
|
"""Handle HEARTBEAT_ACK."""
|
||||||
self._stop_hb()
|
self._stop_hb()
|
||||||
await self._update_health(msg['health'])
|
await self._update_health(msg["health"])
|
||||||
self._start_hb()
|
self._start_hb()
|
||||||
|
|
||||||
async def _handle_6(self, msg):
|
async def _handle_6(self, msg):
|
||||||
"""Handle INFO messages."""
|
"""Handle INFO messages."""
|
||||||
info = msg['d']
|
info = msg["d"]
|
||||||
info_type_str = InfoReverse[info['type']].lower()
|
info_type_str = InfoReverse[info["type"]].lower()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
info_handler = getattr(self, f'_handle_info_{info_type_str}')
|
info_handler = getattr(self, f"_handle_info_{info_type_str}")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
return
|
return
|
||||||
|
|
||||||
await info_handler(info['data'])
|
await info_handler(info["data"])
|
||||||
|
|
||||||
async def _handle_info_channel_assign(self, data: dict):
|
async def _handle_info_channel_assign(self, data: dict):
|
||||||
"""called by the server once we got a channel assign."""
|
"""called by the server once we got a channel assign."""
|
||||||
try:
|
try:
|
||||||
channel_id = data['channel_id']
|
channel_id = data["channel_id"]
|
||||||
channel_id = int(channel_id)
|
channel_id = int(channel_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
guild_id = data['guild_id']
|
guild_id = data["guild_id"]
|
||||||
guild_id = int(guild_id)
|
guild_id = int(guild_id)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
guild_id = None
|
guild_id = None
|
||||||
|
|
@ -166,19 +163,19 @@ class LVSPConnection:
|
||||||
msg = await self.recv()
|
msg = await self.recv()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
opcode = msg['op']
|
opcode = msg["op"]
|
||||||
handler = getattr(self, f'_handle_{opcode}')
|
handler = getattr(self, f"_handle_{opcode}")
|
||||||
await handler(msg)
|
await handler(msg)
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
# TODO: error codes in LVSP
|
# TODO: error codes in LVSP
|
||||||
raise Exception('invalid op code')
|
raise Exception("invalid op code")
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""Try to start a websocket connection."""
|
"""Try to start a websocket connection."""
|
||||||
try:
|
try:
|
||||||
self.conn = await websockets.connect(f'wss://{self.hostname}')
|
self.conn = await websockets.connect(f"wss://{self.hostname}")
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('failed to start lvsp conn to {}', self.hostname)
|
log.exception("failed to start lvsp conn to {}", self.hostname)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Start the websocket."""
|
"""Start the websocket."""
|
||||||
|
|
@ -186,15 +183,15 @@ class LVSPConnection:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.conn:
|
if not self.conn:
|
||||||
log.error('failed to start lvsp connection, stopping')
|
log.error("failed to start lvsp connection, stopping")
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._loop()
|
await self._loop()
|
||||||
except websockets.exceptions.ConnectionClosed as err:
|
except websockets.exceptions.ConnectionClosed as err:
|
||||||
log.warning('conn close, {}, err={}', self._log_id, err)
|
log.warning("conn close, {}, err={}", self._log_id, err)
|
||||||
# except WebsocketClose as err:
|
# except WebsocketClose as err:
|
||||||
# log.warning('ws close, state={} err={}', self.state, err)
|
# log.warning('ws close, state={} err={}', self.state, err)
|
||||||
# await self.conn.close(code=err.code, reason=err.reason)
|
# await self.conn.close(code=err.code, reason=err.reason)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
log.exception('An exception has occoured. {}', self._log_id)
|
log.exception("An exception has occoured. {}", self._log_id)
|
||||||
await self.conn.close(code=4000, reason=repr(err))
|
await self.conn.close(code=4000, reason=repr(err))
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ log = Logger(__name__)
|
||||||
@dataclass
|
@dataclass
|
||||||
class Region:
|
class Region:
|
||||||
"""Voice region data."""
|
"""Voice region data."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
vip: bool
|
vip: bool
|
||||||
|
|
||||||
|
|
@ -40,6 +41,7 @@ class LVSPManager:
|
||||||
|
|
||||||
Spawns :class:`LVSPConnection` as needed, etc.
|
Spawns :class:`LVSPConnection` as needed, etc.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, app, voice):
|
def __init__(self, app, voice):
|
||||||
self.app = app
|
self.app = app
|
||||||
self.voice = voice
|
self.voice = voice
|
||||||
|
|
@ -61,49 +63,50 @@ class LVSPManager:
|
||||||
async def _spawn(self):
|
async def _spawn(self):
|
||||||
"""Spawn LVSPConnection for each region."""
|
"""Spawn LVSPConnection for each region."""
|
||||||
|
|
||||||
regions = await self.app.db.fetch("""
|
regions = await self.app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT id, vip
|
SELECT id, vip
|
||||||
FROM voice_regions
|
FROM voice_regions
|
||||||
WHERE deprecated = false
|
WHERE deprecated = false
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
regions = [Region(r['id'], r['vip']) for r in regions]
|
regions = [Region(r["id"], r["vip"]) for r in regions]
|
||||||
|
|
||||||
if not regions:
|
if not regions:
|
||||||
log.warning('no regions are setup')
|
log.warning("no regions are setup")
|
||||||
return
|
return
|
||||||
|
|
||||||
for region in regions:
|
for region in regions:
|
||||||
# store it locally for region() function
|
# store it locally for region() function
|
||||||
self.regions[region.id] = region
|
self.regions[region.id] = region
|
||||||
|
|
||||||
self.app.loop.create_task(
|
self.app.loop.create_task(self._spawn_region(region))
|
||||||
self._spawn_region(region)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _spawn_region(self, region: Region):
|
async def _spawn_region(self, region: Region):
|
||||||
"""Spawn a region. Involves fetching all the hostnames
|
"""Spawn a region. Involves fetching all the hostnames
|
||||||
for the regions and spawning a LVSPConnection for each."""
|
for the regions and spawning a LVSPConnection for each."""
|
||||||
servers = await self.app.db.fetch("""
|
servers = await self.app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT hostname
|
SELECT hostname
|
||||||
FROM voice_servers
|
FROM voice_servers
|
||||||
WHERE region_id = $1
|
WHERE region_id = $1
|
||||||
""", region.id)
|
""",
|
||||||
|
region.id,
|
||||||
|
)
|
||||||
|
|
||||||
if not servers:
|
if not servers:
|
||||||
log.warning('region {} does not have servers', region)
|
log.warning("region {} does not have servers", region)
|
||||||
return
|
return
|
||||||
|
|
||||||
servers = [r['hostname'] for r in servers]
|
servers = [r["hostname"] for r in servers]
|
||||||
self.servers[region.id] = servers
|
self.servers[region.id] = servers
|
||||||
|
|
||||||
for hostname in servers:
|
for hostname in servers:
|
||||||
conn = LVSPConnection(self, region.id, hostname)
|
conn = LVSPConnection(self, region.id, hostname)
|
||||||
self.conns[hostname] = conn
|
self.conns[hostname] = conn
|
||||||
|
|
||||||
self.app.loop.create_task(
|
self.app.loop.create_task(conn.run())
|
||||||
conn.run()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def del_conn(self, conn):
|
async def del_conn(self, conn):
|
||||||
"""Delete a connection from the connection pool."""
|
"""Delete a connection from the connection pool."""
|
||||||
|
|
@ -119,11 +122,14 @@ class LVSPManager:
|
||||||
|
|
||||||
async def guild_region(self, guild_id: int) -> Optional[str]:
|
async def guild_region(self, guild_id: int) -> Optional[str]:
|
||||||
"""Return the voice region of a guild."""
|
"""Return the voice region of a guild."""
|
||||||
return await self.app.db.fetchval("""
|
return await self.app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT region
|
SELECT region
|
||||||
FROM guilds
|
FROM guilds
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", guild_id)
|
""",
|
||||||
|
guild_id,
|
||||||
|
)
|
||||||
|
|
||||||
def get_health(self, hostname: str) -> float:
|
def get_health(self, hostname: str) -> float:
|
||||||
"""Get voice server health, given hostname."""
|
"""Get voice server health, given hostname."""
|
||||||
|
|
@ -144,10 +150,7 @@ class LVSPManager:
|
||||||
region = await self.guild_region(guild_id)
|
region = await self.guild_region(guild_id)
|
||||||
|
|
||||||
# sort connected servers by health
|
# sort connected servers by health
|
||||||
sorted_servers = sorted(
|
sorted_servers = sorted(self.servers[region], key=self.get_health)
|
||||||
self.servers[region],
|
|
||||||
key=self.get_health
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
hostname = sorted_servers[0]
|
hostname = sorted_servers[0]
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class OPCodes:
|
class OPCodes:
|
||||||
"""LVSP OP codes."""
|
"""LVSP OP codes."""
|
||||||
|
|
||||||
hello = 0
|
hello = 0
|
||||||
identify = 1
|
identify = 1
|
||||||
resume = 2
|
resume = 2
|
||||||
|
|
@ -29,13 +31,13 @@ class OPCodes:
|
||||||
|
|
||||||
|
|
||||||
InfoTable = {
|
InfoTable = {
|
||||||
'CHANNEL_REQ': 0,
|
"CHANNEL_REQ": 0,
|
||||||
'CHANNEL_ASSIGN': 1,
|
"CHANNEL_ASSIGN": 1,
|
||||||
'CHANNEL_UPDATE': 2,
|
"CHANNEL_UPDATE": 2,
|
||||||
'CHANNEL_DESTROY': 3,
|
"CHANNEL_DESTROY": 3,
|
||||||
'VST_CREATE': 4,
|
"VST_CREATE": 4,
|
||||||
'VST_UPDATE': 5,
|
"VST_UPDATE": 5,
|
||||||
'VST_LEAVE': 6,
|
"VST_LEAVE": 6,
|
||||||
}
|
}
|
||||||
|
|
||||||
InfoReverse = {v: k for k, v in InfoTable.items()}
|
InfoReverse = {v: k for k, v in InfoTable.items()}
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ def _construct_state(state_dict: dict) -> VoiceState:
|
||||||
|
|
||||||
class VoiceManager:
|
class VoiceManager:
|
||||||
"""Main voice manager class."""
|
"""Main voice manager class."""
|
||||||
|
|
||||||
def __init__(self, app):
|
def __init__(self, app):
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
|
|
@ -56,7 +57,7 @@ class VoiceManager:
|
||||||
"""Return if a user can join a channel."""
|
"""Return if a user can join a channel."""
|
||||||
|
|
||||||
channel = await self.app.storage.get_channel(channel_id)
|
channel = await self.app.storage.get_channel(channel_id)
|
||||||
ctype = ChannelType(channel['type'])
|
ctype = ChannelType(channel["type"])
|
||||||
|
|
||||||
if ctype not in VOICE_CHANNELS:
|
if ctype not in VOICE_CHANNELS:
|
||||||
return
|
return
|
||||||
|
|
@ -65,14 +66,12 @@ class VoiceManager:
|
||||||
|
|
||||||
# get_permissions returns ALL_PERMISSIONS when
|
# get_permissions returns ALL_PERMISSIONS when
|
||||||
# the channel isn't from a guild
|
# the channel isn't from a guild
|
||||||
perms = await get_permissions(
|
perms = await get_permissions(user_id, channel_id, storage=self.app.storage)
|
||||||
user_id, channel_id, storage=self.app.storage
|
|
||||||
)
|
|
||||||
|
|
||||||
# hacky user_limit but should work, as channels not
|
# hacky user_limit but should work, as channels not
|
||||||
# in guilds won't have that field.
|
# in guilds won't have that field.
|
||||||
is_full = states >= channel.get('user_limit', 100)
|
is_full = states >= channel.get("user_limit", 100)
|
||||||
is_bot = (await self.app.storage.get_user(user_id))['bot']
|
is_bot = (await self.app.storage.get_user(user_id))["bot"]
|
||||||
is_manager = perms.bits.manage_channels
|
is_manager = perms.bits.manage_channels
|
||||||
|
|
||||||
# if the channel is full AND:
|
# if the channel is full AND:
|
||||||
|
|
@ -140,8 +139,8 @@ class VoiceManager:
|
||||||
|
|
||||||
for field in prop:
|
for field in prop:
|
||||||
# NOTE: this should not happen, ever.
|
# NOTE: this should not happen, ever.
|
||||||
if field in ('channel_id', 'user_id'):
|
if field in ("channel_id", "user_id"):
|
||||||
raise ValueError('properties are updating channel or user')
|
raise ValueError("properties are updating channel or user")
|
||||||
|
|
||||||
new_state_dict[field] = prop[field]
|
new_state_dict[field] = prop[field]
|
||||||
|
|
||||||
|
|
@ -153,27 +152,28 @@ class VoiceManager:
|
||||||
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
async def move_channels(self, old_voice_key: VoiceKey, channel_id: int):
|
||||||
"""Move a user between channels."""
|
"""Move a user between channels."""
|
||||||
await self.del_state(old_voice_key)
|
await self.del_state(old_voice_key)
|
||||||
await self.create_state(old_voice_key, {'channel_id': channel_id})
|
await self.create_state(old_voice_key, {"channel_id": channel_id})
|
||||||
|
|
||||||
async def _lvsp_info_guild(self, guild_id, info_type, info_data):
|
async def _lvsp_info_guild(self, guild_id, info_type, info_data):
|
||||||
hostname = await self.lvsp.get_guild_server(guild_id)
|
hostname = await self.lvsp.get_guild_server(guild_id)
|
||||||
if hostname is None:
|
if hostname is None:
|
||||||
log.error('no voice server for guild id {}', guild_id)
|
log.error("no voice server for guild id {}", guild_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
conn = self.lvsp.get_conn(hostname)
|
conn = self.lvsp.get_conn(hostname)
|
||||||
await conn.send_info(info_type, info_data)
|
await conn.send_info(info_type, info_data)
|
||||||
|
|
||||||
async def _create_ctx_guild(self, guild_id, channel_id):
|
async def _create_ctx_guild(self, guild_id, channel_id):
|
||||||
await self._lvsp_info_guild(guild_id, 'CHANNEL_REQ', {
|
await self._lvsp_info_guild(
|
||||||
'guild_id': str(guild_id),
|
guild_id,
|
||||||
'channel_id': str(channel_id),
|
"CHANNEL_REQ",
|
||||||
})
|
{"guild_id": str(guild_id), "channel_id": str(channel_id)},
|
||||||
|
)
|
||||||
|
|
||||||
async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
|
async def _start_voice_guild(self, voice_key: VoiceKey, data: dict):
|
||||||
"""Start a voice context in a guild."""
|
"""Start a voice context in a guild."""
|
||||||
user_id, guild_id = voice_key
|
user_id, guild_id = voice_key
|
||||||
channel_id = int(data['channel_id'])
|
channel_id = int(data["channel_id"])
|
||||||
|
|
||||||
existing_states = self.states[voice_key]
|
existing_states = self.states[voice_key]
|
||||||
channel_exists = any(
|
channel_exists = any(
|
||||||
|
|
@ -183,11 +183,15 @@ class VoiceManager:
|
||||||
if not channel_exists:
|
if not channel_exists:
|
||||||
await self._create_ctx_guild(guild_id, channel_id)
|
await self._create_ctx_guild(guild_id, channel_id)
|
||||||
|
|
||||||
await self._lvsp_info_guild(guild_id, 'VST_CREATE', {
|
await self._lvsp_info_guild(
|
||||||
'user_id': str(user_id),
|
guild_id,
|
||||||
'guild_id': str(guild_id),
|
"VST_CREATE",
|
||||||
'channel_id': str(channel_id),
|
{
|
||||||
})
|
"user_id": str(user_id),
|
||||||
|
"guild_id": str(guild_id),
|
||||||
|
"channel_id": str(channel_id),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def create_state(self, voice_key: VoiceKey, data: dict):
|
async def create_state(self, voice_key: VoiceKey, data: dict):
|
||||||
"""Creates (or tries to create) a voice state.
|
"""Creates (or tries to create) a voice state.
|
||||||
|
|
@ -249,10 +253,13 @@ class VoiceManager:
|
||||||
|
|
||||||
async def voice_server_list(self, region: str) -> List[dict]:
|
async def voice_server_list(self, region: str) -> List[dict]:
|
||||||
"""Get a list of voice server objects"""
|
"""Get a list of voice server objects"""
|
||||||
rows = await self.app.db.fetch("""
|
rows = await self.app.db.fetch(
|
||||||
|
"""
|
||||||
SELECT hostname, last_health
|
SELECT hostname, last_health
|
||||||
FROM voice_servers
|
FROM voice_servers
|
||||||
WHERE region_id = $1
|
WHERE region_id = $1
|
||||||
""", region)
|
""",
|
||||||
|
region,
|
||||||
|
)
|
||||||
|
|
||||||
return list(map(dict, rows))
|
return list(map(dict, rows))
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from dataclasses import dataclass, asdict
|
||||||
@dataclass
|
@dataclass
|
||||||
class VoiceState:
|
class VoiceState:
|
||||||
"""Represents a voice state."""
|
"""Represents a voice state."""
|
||||||
|
|
||||||
guild_id: int
|
guild_id: int
|
||||||
channel_id: int
|
channel_id: int
|
||||||
user_id: int
|
user_id: int
|
||||||
|
|
@ -55,7 +56,7 @@ class VoiceState:
|
||||||
|
|
||||||
# a better approach would be actually using
|
# a better approach would be actually using
|
||||||
# the suppressed_by field for backend efficiency.
|
# the suppressed_by field for backend efficiency.
|
||||||
self_dict['suppress'] = user_id == self.suppressed_by
|
self_dict["suppress"] = user_id == self.suppressed_by
|
||||||
self_dict.pop('suppressed_by')
|
self_dict.pop("suppressed_by")
|
||||||
|
|
||||||
return self_dict
|
return self_dict
|
||||||
|
|
|
||||||
|
|
@ -27,5 +27,5 @@ import config
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
sys.exit(main(config))
|
sys.exit(main(config))
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ ALPHABET = string.ascii_lowercase + string.ascii_uppercase + string.digits
|
||||||
|
|
||||||
async def _gen_inv() -> str:
|
async def _gen_inv() -> str:
|
||||||
"""Generate an invite code"""
|
"""Generate an invite code"""
|
||||||
return ''.join(choice(ALPHABET) for _ in range(6))
|
return "".join(choice(ALPHABET) for _ in range(6))
|
||||||
|
|
||||||
|
|
||||||
async def gen_inv(ctx) -> str:
|
async def gen_inv(ctx) -> str:
|
||||||
|
|
@ -34,11 +34,14 @@ async def gen_inv(ctx) -> str:
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
possible_inv = await _gen_inv()
|
possible_inv = await _gen_inv()
|
||||||
|
|
||||||
created_at = await ctx.db.fetchval("""
|
created_at = await ctx.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT created_at
|
SELECT created_at
|
||||||
FROM instance_invites
|
FROM instance_invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", possible_inv)
|
""",
|
||||||
|
possible_inv,
|
||||||
|
)
|
||||||
|
|
||||||
if created_at is None:
|
if created_at is None:
|
||||||
return possible_inv
|
return possible_inv
|
||||||
|
|
@ -51,27 +54,32 @@ async def make_inv(ctx, args):
|
||||||
|
|
||||||
max_uses = args.max_uses
|
max_uses = args.max_uses
|
||||||
|
|
||||||
await ctx.db.execute("""
|
await ctx.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO instance_invites (code, max_uses)
|
INSERT INTO instance_invites (code, max_uses)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", code, max_uses)
|
""",
|
||||||
|
code,
|
||||||
|
max_uses,
|
||||||
|
)
|
||||||
|
|
||||||
print(f'invite created with {max_uses} max uses', code)
|
print(f"invite created with {max_uses} max uses", code)
|
||||||
|
|
||||||
|
|
||||||
async def list_invs(ctx, args):
|
async def list_invs(ctx, args):
|
||||||
rows = await ctx.db.fetch("""
|
rows = await ctx.db.fetch(
|
||||||
|
"""
|
||||||
SELECT code, created_at, uses, max_uses
|
SELECT code, created_at, uses, max_uses
|
||||||
FROM instance_invites
|
FROM instance_invites
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
print(len(rows), 'invites')
|
print(len(rows), "invites")
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
max_uses = row['max_uses']
|
max_uses = row["max_uses"]
|
||||||
delta = datetime.datetime.utcnow() - row['created_at']
|
delta = datetime.datetime.utcnow() - row["created_at"]
|
||||||
usage = ('infinite uses' if max_uses == -1
|
usage = "infinite uses" if max_uses == -1 else f'{row["uses"]} / {max_uses}'
|
||||||
else f'{row["uses"]} / {max_uses}')
|
|
||||||
|
|
||||||
print(f'\t{row["code"]}, {usage}, made {delta} ago')
|
print(f'\t{row["code"]}, {usage}, made {delta} ago')
|
||||||
|
|
||||||
|
|
@ -79,40 +87,37 @@ async def list_invs(ctx, args):
|
||||||
async def delete_inv(ctx, args):
|
async def delete_inv(ctx, args):
|
||||||
inv = args.invite_code
|
inv = args.invite_code
|
||||||
|
|
||||||
res = await ctx.db.execute("""
|
res = await ctx.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM instance_invites
|
DELETE FROM instance_invites
|
||||||
WHERE code = $1
|
WHERE code = $1
|
||||||
""", inv)
|
""",
|
||||||
|
inv,
|
||||||
|
)
|
||||||
|
|
||||||
if res == 'DELETE 0':
|
if res == "DELETE 0":
|
||||||
print('NOT FOUND')
|
print("NOT FOUND")
|
||||||
return
|
return
|
||||||
|
|
||||||
print('OK')
|
print("OK")
|
||||||
|
|
||||||
|
|
||||||
def setup(subparser):
|
def setup(subparser):
|
||||||
makeinv_parser = subparser.add_parser(
|
makeinv_parser = subparser.add_parser("makeinv", help="create an invite")
|
||||||
'makeinv',
|
|
||||||
help='create an invite',
|
|
||||||
)
|
|
||||||
|
|
||||||
makeinv_parser.add_argument(
|
makeinv_parser.add_argument(
|
||||||
'max_uses', nargs='?', type=int, default=-1,
|
"max_uses",
|
||||||
help='Maximum amount of uses before the invite is unavailable',
|
nargs="?",
|
||||||
|
type=int,
|
||||||
|
default=-1,
|
||||||
|
help="Maximum amount of uses before the invite is unavailable",
|
||||||
)
|
)
|
||||||
|
|
||||||
makeinv_parser.set_defaults(func=make_inv)
|
makeinv_parser.set_defaults(func=make_inv)
|
||||||
|
|
||||||
listinv_parser = subparser.add_parser(
|
listinv_parser = subparser.add_parser("listinv", help="list all invites")
|
||||||
'listinv',
|
|
||||||
help='list all invites',
|
|
||||||
)
|
|
||||||
listinv_parser.set_defaults(func=list_invs)
|
listinv_parser.set_defaults(func=list_invs)
|
||||||
|
|
||||||
delinv_parser = subparser.add_parser(
|
delinv_parser = subparser.add_parser("delinv", help="delete an invite")
|
||||||
'delinv',
|
delinv_parser.add_argument("invite_code")
|
||||||
help='delete an invite',
|
|
||||||
)
|
|
||||||
delinv_parser.add_argument('invite_code')
|
|
||||||
delinv_parser.set_defaults(func=delete_inv)
|
delinv_parser.set_defaults(func=delete_inv)
|
||||||
|
|
|
||||||
|
|
@ -19,4 +19,4 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from .command import setup as migration
|
from .command import setup as migration
|
||||||
|
|
||||||
__all__ = ['migration']
|
__all__ = ["migration"]
|
||||||
|
|
|
||||||
|
|
@ -32,18 +32,19 @@ from logbook import Logger
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
Migration = namedtuple('Migration', 'id name path')
|
Migration = namedtuple("Migration", "id name path")
|
||||||
|
|
||||||
# line of change, 4 april 2019, at 1am (gmt+0)
|
# line of change, 4 april 2019, at 1am (gmt+0)
|
||||||
BREAK = datetime.datetime(2019, 4, 4, 1)
|
BREAK = datetime.datetime(2019, 4, 4, 1)
|
||||||
|
|
||||||
# if a database has those tables, it ran 0_base.sql.
|
# if a database has those tables, it ran 0_base.sql.
|
||||||
HAS_BASE = ['users', 'guilds', 'e']
|
HAS_BASE = ["users", "guilds", "e"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MigrationContext:
|
class MigrationContext:
|
||||||
"""Hold information about migration."""
|
"""Hold information about migration."""
|
||||||
|
|
||||||
migration_folder: Path
|
migration_folder: Path
|
||||||
scripts: Dict[int, Migration]
|
scripts: Dict[int, Migration]
|
||||||
|
|
||||||
|
|
@ -60,22 +61,21 @@ def make_migration_ctx() -> MigrationContext:
|
||||||
script_folder = os.sep.join(script_path.split(os.sep)[:-1])
|
script_folder = os.sep.join(script_path.split(os.sep)[:-1])
|
||||||
script_folder = Path(script_folder)
|
script_folder = Path(script_folder)
|
||||||
|
|
||||||
migration_folder = script_folder / 'scripts'
|
migration_folder = script_folder / "scripts"
|
||||||
|
|
||||||
mctx = MigrationContext(migration_folder, {})
|
mctx = MigrationContext(migration_folder, {})
|
||||||
|
|
||||||
for mig_path in migration_folder.glob('*.sql'):
|
for mig_path in migration_folder.glob("*.sql"):
|
||||||
mig_path_str = str(mig_path)
|
mig_path_str = str(mig_path)
|
||||||
|
|
||||||
# extract migration script id and name
|
# extract migration script id and name
|
||||||
mig_filename = mig_path_str.split(os.sep)[-1].split('.')[0]
|
mig_filename = mig_path_str.split(os.sep)[-1].split(".")[0]
|
||||||
name_fragments = mig_filename.split('_')
|
name_fragments = mig_filename.split("_")
|
||||||
|
|
||||||
mig_id = int(name_fragments[0])
|
mig_id = int(name_fragments[0])
|
||||||
mig_name = '_'.join(name_fragments[1:])
|
mig_name = "_".join(name_fragments[1:])
|
||||||
|
|
||||||
mctx.scripts[mig_id] = Migration(
|
mctx.scripts[mig_id] = Migration(mig_id, mig_name, mig_path)
|
||||||
mig_id, mig_name, mig_path)
|
|
||||||
|
|
||||||
return mctx
|
return mctx
|
||||||
|
|
||||||
|
|
@ -83,7 +83,8 @@ def make_migration_ctx() -> MigrationContext:
|
||||||
async def _ensure_changelog(app, ctx):
|
async def _ensure_changelog(app, ctx):
|
||||||
# make sure we have the migration table up
|
# make sure we have the migration table up
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE migration_log (
|
CREATE TABLE migration_log (
|
||||||
change_num bigint NOT NULL,
|
change_num bigint NOT NULL,
|
||||||
|
|
||||||
|
|
@ -94,43 +95,56 @@ async def _ensure_changelog(app, ctx):
|
||||||
|
|
||||||
PRIMARY KEY (change_num)
|
PRIMARY KEY (change_num)
|
||||||
);
|
);
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
except asyncpg.DuplicateTableError:
|
except asyncpg.DuplicateTableError:
|
||||||
log.debug('existing migration table')
|
log.debug("existing migration table")
|
||||||
|
|
||||||
# NOTE: this is a migration breakage,
|
# NOTE: this is a migration breakage,
|
||||||
# only applying to databases that had their first migration
|
# only applying to databases that had their first migration
|
||||||
# before 4 april 2019 (more on BREAK)
|
# before 4 april 2019 (more on BREAK)
|
||||||
|
|
||||||
# if migration_log is empty, just assume this is new
|
# if migration_log is empty, just assume this is new
|
||||||
first = await app.db.fetchval("""
|
first = (
|
||||||
|
await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT apply_ts FROM migration_log
|
SELECT apply_ts FROM migration_log
|
||||||
ORDER BY apply_ts ASC
|
ORDER BY apply_ts ASC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
""") or BREAK
|
"""
|
||||||
|
)
|
||||||
|
or BREAK
|
||||||
|
)
|
||||||
if first < BREAK:
|
if first < BREAK:
|
||||||
log.info('deleting migration_log due to migration structure change')
|
log.info("deleting migration_log due to migration structure change")
|
||||||
await app.db.execute("DROP TABLE migration_log")
|
await app.db.execute("DROP TABLE migration_log")
|
||||||
await _ensure_changelog(app, ctx)
|
await _ensure_changelog(app, ctx)
|
||||||
|
|
||||||
|
|
||||||
async def _insert_log(app, migration_id: int, description) -> bool:
|
async def _insert_log(app, migration_id: int, description) -> bool:
|
||||||
try:
|
try:
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
INSERT INTO migration_log (change_num, description)
|
INSERT INTO migration_log (change_num, description)
|
||||||
VALUES ($1, $2)
|
VALUES ($1, $2)
|
||||||
""", migration_id, description)
|
""",
|
||||||
|
migration_id,
|
||||||
|
description,
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except asyncpg.UniqueViolationError:
|
except asyncpg.UniqueViolationError:
|
||||||
log.warning('already inserted {}', migration_id)
|
log.warning("already inserted {}", migration_id)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _delete_log(app, migration_id: int):
|
async def _delete_log(app, migration_id: int):
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
DELETE FROM migration_log WHERE change_num = $1
|
DELETE FROM migration_log WHERE change_num = $1
|
||||||
""", migration_id)
|
""",
|
||||||
|
migration_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def apply_migration(app, migration: Migration) -> bool:
|
async def apply_migration(app, migration: Migration) -> bool:
|
||||||
|
|
@ -144,21 +158,20 @@ async def apply_migration(app, migration: Migration) -> bool:
|
||||||
|
|
||||||
Returns a boolean signaling if this failed or not.
|
Returns a boolean signaling if this failed or not.
|
||||||
"""
|
"""
|
||||||
migration_sql = migration.path.read_text(encoding='utf-8')
|
migration_sql = migration.path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
res = await _insert_log(
|
res = await _insert_log(app, migration.id, f"migration: {migration.name}")
|
||||||
app, migration.id, f'migration: {migration.name}')
|
|
||||||
|
|
||||||
if not res:
|
if not res:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await app.db.execute(migration_sql)
|
await app.db.execute(migration_sql)
|
||||||
log.info('applied {} {}', migration.id, migration.name)
|
log.info("applied {} {}", migration.id, migration.name)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except:
|
except:
|
||||||
log.exception('failed to run migration, rollbacking log')
|
log.exception("failed to run migration, rollbacking log")
|
||||||
await _delete_log(app, migration.id)
|
await _delete_log(app, migration.id)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
@ -169,9 +182,11 @@ async def _check_base(app) -> bool:
|
||||||
file."""
|
file."""
|
||||||
try:
|
try:
|
||||||
for table in HAS_BASE:
|
for table in HAS_BASE:
|
||||||
await app.db.execute(f"""
|
await app.db.execute(
|
||||||
|
f"""
|
||||||
SELECT * FROM {table} LIMIT 0
|
SELECT * FROM {table} LIMIT 0
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
except asyncpg.UndefinedTableError:
|
except asyncpg.UndefinedTableError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -197,14 +212,16 @@ async def migrate_cmd(app, _args):
|
||||||
has_base = await _check_base(app)
|
has_base = await _check_base(app)
|
||||||
|
|
||||||
# fetch latest local migration that has been run on this database
|
# fetch latest local migration that has been run on this database
|
||||||
local_change = await app.db.fetchval("""
|
local_change = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT max(change_num)
|
SELECT max(change_num)
|
||||||
FROM migration_log
|
FROM migration_log
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# if base exists, add it to logs, if not, apply (and add to logs)
|
# if base exists, add it to logs, if not, apply (and add to logs)
|
||||||
if has_base:
|
if has_base:
|
||||||
await _insert_log(app, 0, 'migration setup (from existing)')
|
await _insert_log(app, 0, "migration setup (from existing)")
|
||||||
else:
|
else:
|
||||||
await apply_migration(app, ctx.scripts[0])
|
await apply_migration(app, ctx.scripts[0])
|
||||||
|
|
||||||
|
|
@ -215,10 +232,10 @@ async def migrate_cmd(app, _args):
|
||||||
local_change = local_change or 0
|
local_change = local_change or 0
|
||||||
latest_change = ctx.latest
|
latest_change = ctx.latest
|
||||||
|
|
||||||
log.debug('local: {}, latest: {}', local_change, latest_change)
|
log.debug("local: {}, latest: {}", local_change, latest_change)
|
||||||
|
|
||||||
if local_change == latest_change:
|
if local_change == latest_change:
|
||||||
print('no changes to do, exiting')
|
print("no changes to do, exiting")
|
||||||
return
|
return
|
||||||
|
|
||||||
# we do local_change + 1 so we start from the
|
# we do local_change + 1 so we start from the
|
||||||
|
|
@ -227,15 +244,13 @@ async def migrate_cmd(app, _args):
|
||||||
for idx in range(local_change + 1, latest_change + 1):
|
for idx in range(local_change + 1, latest_change + 1):
|
||||||
migration = ctx.scripts.get(idx)
|
migration = ctx.scripts.get(idx)
|
||||||
|
|
||||||
print('applying', migration.id, migration.name)
|
print("applying", migration.id, migration.name)
|
||||||
await apply_migration(app, migration)
|
await apply_migration(app, migration)
|
||||||
|
|
||||||
|
|
||||||
def setup(subparser):
|
def setup(subparser):
|
||||||
migrate_parser = subparser.add_parser(
|
migrate_parser = subparser.add_parser(
|
||||||
'migrate',
|
"migrate", help="Run migration tasks", description=migrate_cmd.__doc__
|
||||||
help='Run migration tasks',
|
|
||||||
description=migrate_cmd.__doc__
|
|
||||||
)
|
)
|
||||||
|
|
||||||
migrate_parser.set_defaults(func=migrate_cmd)
|
migrate_parser.set_defaults(func=migrate_cmd)
|
||||||
|
|
|
||||||
|
|
@ -24,39 +24,51 @@ from litecord.enums import UserFlags
|
||||||
|
|
||||||
async def find_user(username, discrim, ctx) -> int:
|
async def find_user(username, discrim, ctx) -> int:
|
||||||
"""Get a user ID via the username/discrim pair."""
|
"""Get a user ID via the username/discrim pair."""
|
||||||
return await ctx.db.fetchval("""
|
return await ctx.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM users
|
FROM users
|
||||||
WHERE username = $1 AND discriminator = $2
|
WHERE username = $1 AND discriminator = $2
|
||||||
""", username, discrim)
|
""",
|
||||||
|
username,
|
||||||
|
discrim,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def set_user_staff(user_id, ctx):
|
async def set_user_staff(user_id, ctx):
|
||||||
"""Give a single user staff status."""
|
"""Give a single user staff status."""
|
||||||
old_flags = await ctx.db.fetchval("""
|
old_flags = await ctx.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT flags
|
SELECT flags
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
new_flags = old_flags | UserFlags.staff
|
new_flags = old_flags | UserFlags.staff
|
||||||
|
|
||||||
await ctx.db.execute("""
|
await ctx.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users
|
UPDATE users
|
||||||
SET flags=$1
|
SET flags=$1
|
||||||
WHERE id = $2
|
WHERE id = $2
|
||||||
""", new_flags, user_id)
|
""",
|
||||||
|
new_flags,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def adduser(ctx, args):
|
async def adduser(ctx, args):
|
||||||
"""Create a single user."""
|
"""Create a single user."""
|
||||||
uid, _ = await create_user(args.username, args.email,
|
uid, _ = await create_user(
|
||||||
args.password, ctx.db, ctx.loop)
|
args.username, args.email, args.password, ctx.db, ctx.loop
|
||||||
|
)
|
||||||
|
|
||||||
user = await ctx.storage.get_user(uid)
|
user = await ctx.storage.get_user(uid)
|
||||||
|
|
||||||
print('created!')
|
print("created!")
|
||||||
print(f'\tuid: {uid}')
|
print(f"\tuid: {uid}")
|
||||||
print(f'\tusername: {user["username"]}')
|
print(f'\tusername: {user["username"]}')
|
||||||
print(f'\tdiscrim: {user["discriminator"]}')
|
print(f'\tdiscrim: {user["discriminator"]}')
|
||||||
|
|
||||||
|
|
@ -72,22 +84,26 @@ async def make_staff(ctx, args):
|
||||||
uid = await find_user(args.username, args.discrim, ctx)
|
uid = await find_user(args.username, args.discrim, ctx)
|
||||||
|
|
||||||
if not uid:
|
if not uid:
|
||||||
return print('user not found')
|
return print("user not found")
|
||||||
|
|
||||||
await set_user_staff(uid, ctx)
|
await set_user_staff(uid, ctx)
|
||||||
print('OK: set staff')
|
print("OK: set staff")
|
||||||
|
|
||||||
|
|
||||||
async def generate_bot_token(ctx, args):
|
async def generate_bot_token(ctx, args):
|
||||||
"""Generate a token for specified bot."""
|
"""Generate a token for specified bot."""
|
||||||
|
|
||||||
password_hash = await ctx.db.fetchval("""
|
password_hash = await ctx.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT password_hash
|
SELECT password_hash
|
||||||
FROM users
|
FROM users
|
||||||
WHERE id = $1 AND bot = 'true'
|
WHERE id = $1 AND bot = 'true'
|
||||||
""", int(args.user_id))
|
""",
|
||||||
|
int(args.user_id),
|
||||||
|
)
|
||||||
|
|
||||||
if not password_hash:
|
if not password_hash:
|
||||||
return print('cannot find a bot with specified id')
|
return print("cannot find a bot with specified id")
|
||||||
|
|
||||||
print(make_token(args.user_id, password_hash))
|
print(make_token(args.user_id, password_hash))
|
||||||
|
|
||||||
|
|
@ -97,7 +113,7 @@ async def del_user(ctx, args):
|
||||||
uid = await find_user(args.username, args.discrim, ctx)
|
uid = await find_user(args.username, args.discrim, ctx)
|
||||||
|
|
||||||
if uid is None:
|
if uid is None:
|
||||||
print('user not found')
|
print("user not found")
|
||||||
return
|
return
|
||||||
|
|
||||||
user = await ctx.storage.get_user(uid)
|
user = await ctx.storage.get_user(uid)
|
||||||
|
|
@ -106,57 +122,48 @@ async def del_user(ctx, args):
|
||||||
print(f'\tuname: {user["username"]}')
|
print(f'\tuname: {user["username"]}')
|
||||||
print(f'\tdiscrim: {user["discriminator"]}')
|
print(f'\tdiscrim: {user["discriminator"]}')
|
||||||
|
|
||||||
print('\n you sure you want to delete user? press Y (uppercase)')
|
print("\n you sure you want to delete user? press Y (uppercase)")
|
||||||
confirm = input()
|
confirm = input()
|
||||||
|
|
||||||
if confirm != 'Y':
|
if confirm != "Y":
|
||||||
print('not confirmed')
|
print("not confirmed")
|
||||||
return
|
return
|
||||||
|
|
||||||
await delete_user(uid, app_=ctx)
|
await delete_user(uid, app_=ctx)
|
||||||
print('ok')
|
print("ok")
|
||||||
|
|
||||||
|
|
||||||
def setup(subparser):
|
def setup(subparser):
|
||||||
setup_test_parser = subparser.add_parser(
|
setup_test_parser = subparser.add_parser("adduser", help="create a user")
|
||||||
'adduser',
|
|
||||||
help='create a user',
|
|
||||||
)
|
|
||||||
|
|
||||||
setup_test_parser.add_argument(
|
setup_test_parser.add_argument("username", help="username of the user")
|
||||||
'username', help='username of the user')
|
setup_test_parser.add_argument("email", help="email of the user")
|
||||||
setup_test_parser.add_argument(
|
setup_test_parser.add_argument("password", help="password of the user")
|
||||||
'email', help='email of the user')
|
|
||||||
setup_test_parser.add_argument(
|
|
||||||
'password', help='password of the user')
|
|
||||||
|
|
||||||
setup_test_parser.set_defaults(func=adduser)
|
setup_test_parser.set_defaults(func=adduser)
|
||||||
|
|
||||||
staff_parser = subparser.add_parser(
|
staff_parser = subparser.add_parser(
|
||||||
'make_staff',
|
"make_staff", help="make a user staff", description=make_staff.__doc__
|
||||||
help='make a user staff',
|
|
||||||
description=make_staff.__doc__
|
|
||||||
)
|
)
|
||||||
|
|
||||||
staff_parser.add_argument('username')
|
staff_parser.add_argument("username")
|
||||||
staff_parser.add_argument(
|
staff_parser.add_argument("discrim", help="the discriminator of the user")
|
||||||
'discrim', help='the discriminator of the user')
|
|
||||||
|
|
||||||
staff_parser.set_defaults(func=make_staff)
|
staff_parser.set_defaults(func=make_staff)
|
||||||
|
|
||||||
del_user_parser = subparser.add_parser(
|
del_user_parser = subparser.add_parser("deluser", help="delete a single user")
|
||||||
'deluser', help='delete a single user')
|
|
||||||
|
|
||||||
del_user_parser.add_argument('username')
|
del_user_parser.add_argument("username")
|
||||||
del_user_parser.add_argument('discrim')
|
del_user_parser.add_argument("discrim")
|
||||||
|
|
||||||
del_user_parser.set_defaults(func=del_user)
|
del_user_parser.set_defaults(func=del_user)
|
||||||
|
|
||||||
token_parser = subparser.add_parser(
|
token_parser = subparser.add_parser(
|
||||||
'generate_token',
|
"generate_token",
|
||||||
help='generate a token for specified bot',
|
help="generate a token for specified bot",
|
||||||
description=generate_bot_token.__doc__)
|
description=generate_bot_token.__doc__,
|
||||||
|
)
|
||||||
|
|
||||||
token_parser.add_argument('user_id')
|
token_parser.add_argument("user_id")
|
||||||
|
|
||||||
token_parser.set_defaults(func=generate_bot_token)
|
token_parser.set_defaults(func=generate_bot_token)
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ log = Logger(__name__)
|
||||||
@dataclass
|
@dataclass
|
||||||
class FakeApp:
|
class FakeApp:
|
||||||
"""Fake app instance."""
|
"""Fake app instance."""
|
||||||
|
|
||||||
config: dict
|
config: dict
|
||||||
db = None
|
db = None
|
||||||
loop: asyncio.BaseEventLoop = None
|
loop: asyncio.BaseEventLoop = None
|
||||||
|
|
@ -50,7 +51,7 @@ class FakeApp:
|
||||||
|
|
||||||
def init_parser():
|
def init_parser():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
subparser = parser.add_subparsers(help='operations')
|
subparser = parser.add_subparsers(help="operations")
|
||||||
|
|
||||||
migration(subparser)
|
migration(subparser)
|
||||||
users.setup(subparser)
|
users.setup(subparser)
|
||||||
|
|
@ -78,12 +79,12 @@ def main(config):
|
||||||
# only init app managers when we aren't migrating
|
# only init app managers when we aren't migrating
|
||||||
# as the managers require it
|
# as the managers require it
|
||||||
# and the migrate command also sets the db up
|
# and the migrate command also sets the db up
|
||||||
if argv[1] != 'migrate':
|
if argv[1] != "migrate":
|
||||||
init_app_managers(app, voice=False)
|
init_app_managers(app, voice=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
loop.run_until_complete(args.func(app, args))
|
loop.run_until_complete(args.func(app, args))
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception('error while running command')
|
log.exception("error while running command")
|
||||||
finally:
|
finally:
|
||||||
loop.run_until_complete(app.db.close())
|
loop.run_until_complete(app.db.close())
|
||||||
|
|
|
||||||
236
run.py
236
run.py
|
|
@ -33,32 +33,51 @@ from aiohttp import ClientSession
|
||||||
import config
|
import config
|
||||||
|
|
||||||
from litecord.blueprints import (
|
from litecord.blueprints import (
|
||||||
gateway, auth, users, guilds, channels, webhooks, science,
|
gateway,
|
||||||
voice, invites, relationships, dms, icons, nodeinfo, static,
|
auth,
|
||||||
attachments, dm_channels
|
users,
|
||||||
|
guilds,
|
||||||
|
channels,
|
||||||
|
webhooks,
|
||||||
|
science,
|
||||||
|
voice,
|
||||||
|
invites,
|
||||||
|
relationships,
|
||||||
|
dms,
|
||||||
|
icons,
|
||||||
|
nodeinfo,
|
||||||
|
static,
|
||||||
|
attachments,
|
||||||
|
dm_channels,
|
||||||
)
|
)
|
||||||
|
|
||||||
# those blueprints are separated from the "main" ones
|
# those blueprints are separated from the "main" ones
|
||||||
# for code readability if people want to dig through
|
# for code readability if people want to dig through
|
||||||
# the codebase.
|
# the codebase.
|
||||||
from litecord.blueprints.guild import (
|
from litecord.blueprints.guild import (
|
||||||
guild_roles, guild_members, guild_channels, guild_mod,
|
guild_roles,
|
||||||
guild_emoji
|
guild_members,
|
||||||
|
guild_channels,
|
||||||
|
guild_mod,
|
||||||
|
guild_emoji,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.channel import (
|
from litecord.blueprints.channel import (
|
||||||
channel_messages, channel_reactions, channel_pins
|
channel_messages,
|
||||||
|
channel_reactions,
|
||||||
|
channel_pins,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.user import (
|
from litecord.blueprints.user import user_settings, user_billing, fake_store
|
||||||
user_settings, user_billing, fake_store
|
|
||||||
)
|
|
||||||
|
|
||||||
from litecord.blueprints.user.billing_job import payment_job
|
from litecord.blueprints.user.billing_job import payment_job
|
||||||
|
|
||||||
from litecord.blueprints.admin_api import (
|
from litecord.blueprints.admin_api import (
|
||||||
voice as voice_admin, features as features_admin,
|
voice as voice_admin,
|
||||||
guilds as guilds_admin, users as users_admin, instance_invites
|
features as features_admin,
|
||||||
|
guilds as guilds_admin,
|
||||||
|
users as users_admin,
|
||||||
|
instance_invites,
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.blueprints.admin_api.voice import guild_region_check
|
from litecord.blueprints.admin_api.voice import guild_region_check
|
||||||
|
|
@ -84,23 +103,23 @@ from litecord.utils import LitecordJSONEncoder
|
||||||
# setup logbook
|
# setup logbook
|
||||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||||
handler.push_application()
|
handler.push_application()
|
||||||
log = Logger('litecord.boot')
|
log = Logger("litecord.boot")
|
||||||
redirect_logging()
|
redirect_logging()
|
||||||
|
|
||||||
|
|
||||||
def make_app():
|
def make_app():
|
||||||
app = Quart(__name__)
|
app = Quart(__name__)
|
||||||
app.config.from_object(f'config.{config.MODE}')
|
app.config.from_object(f"config.{config.MODE}")
|
||||||
is_debug = app.config.get('DEBUG', False)
|
is_debug = app.config.get("DEBUG", False)
|
||||||
app.debug = is_debug
|
app.debug = is_debug
|
||||||
|
|
||||||
if is_debug:
|
if is_debug:
|
||||||
log.info('on debug')
|
log.info("on debug")
|
||||||
handler.level = logbook.DEBUG
|
handler.level = logbook.DEBUG
|
||||||
app.logger.level = logbook.DEBUG
|
app.logger.level = logbook.DEBUG
|
||||||
|
|
||||||
# always keep websockets on INFO
|
# always keep websockets on INFO
|
||||||
logging.getLogger('websockets').setLevel(logbook.INFO)
|
logging.getLogger("websockets").setLevel(logbook.INFO)
|
||||||
|
|
||||||
# use our custom json encoder for custom data types
|
# use our custom json encoder for custom data types
|
||||||
app.json_encoder = LitecordJSONEncoder
|
app.json_encoder = LitecordJSONEncoder
|
||||||
|
|
@ -112,51 +131,44 @@ def set_blueprints(app_):
|
||||||
"""Set the blueprints for a given app instance"""
|
"""Set the blueprints for a given app instance"""
|
||||||
bps = {
|
bps = {
|
||||||
gateway: None,
|
gateway: None,
|
||||||
auth: '/auth',
|
auth: "/auth",
|
||||||
|
users: "/users",
|
||||||
users: '/users',
|
user_settings: "/users",
|
||||||
user_settings: '/users',
|
user_billing: "/users",
|
||||||
user_billing: '/users',
|
relationships: "/users",
|
||||||
relationships: '/users',
|
guilds: "/guilds",
|
||||||
|
guild_roles: "/guilds",
|
||||||
guilds: '/guilds',
|
guild_members: "/guilds",
|
||||||
guild_roles: '/guilds',
|
guild_channels: "/guilds",
|
||||||
guild_members: '/guilds',
|
guild_mod: "/guilds",
|
||||||
guild_channels: '/guilds',
|
guild_emoji: "/guilds",
|
||||||
guild_mod: '/guilds',
|
channels: "/channels",
|
||||||
guild_emoji: '/guilds',
|
channel_messages: "/channels",
|
||||||
|
channel_reactions: "/channels",
|
||||||
channels: '/channels',
|
channel_pins: "/channels",
|
||||||
channel_messages: '/channels',
|
|
||||||
channel_reactions: '/channels',
|
|
||||||
channel_pins: '/channels',
|
|
||||||
|
|
||||||
webhooks: None,
|
webhooks: None,
|
||||||
science: None,
|
science: None,
|
||||||
voice: '/voice',
|
voice: "/voice",
|
||||||
invites: None,
|
invites: None,
|
||||||
dms: '/users',
|
dms: "/users",
|
||||||
dm_channels: '/channels',
|
dm_channels: "/channels",
|
||||||
|
|
||||||
fake_store: None,
|
fake_store: None,
|
||||||
|
|
||||||
icons: -1,
|
icons: -1,
|
||||||
attachments: -1,
|
attachments: -1,
|
||||||
nodeinfo: -1,
|
nodeinfo: -1,
|
||||||
static: -1,
|
static: -1,
|
||||||
|
voice_admin: "/admin/voice",
|
||||||
voice_admin: '/admin/voice',
|
features_admin: "/admin/guilds",
|
||||||
features_admin: '/admin/guilds',
|
guilds_admin: "/admin/guilds",
|
||||||
guilds_admin: '/admin/guilds',
|
users_admin: "/admin/users",
|
||||||
users_admin: '/admin/users',
|
instance_invites: "/admin/instance/invites",
|
||||||
instance_invites: '/admin/instance/invites'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for bp, suffix in bps.items():
|
for bp, suffix in bps.items():
|
||||||
url_prefix = f'/api/v6{suffix or ""}'
|
url_prefix = f'/api/v6{suffix or ""}'
|
||||||
|
|
||||||
if suffix == -1:
|
if suffix == -1:
|
||||||
url_prefix = ''
|
url_prefix = ""
|
||||||
|
|
||||||
app_.register_blueprint(bp, url_prefix=url_prefix)
|
app_.register_blueprint(bp, url_prefix=url_prefix)
|
||||||
|
|
||||||
|
|
@ -175,37 +187,35 @@ async def app_before_request():
|
||||||
@app.after_request
|
@app.after_request
|
||||||
async def app_after_request(resp):
|
async def app_after_request(resp):
|
||||||
"""Handle CORS headers."""
|
"""Handle CORS headers."""
|
||||||
origin = request.headers.get('Origin', '*')
|
origin = request.headers.get("Origin", "*")
|
||||||
resp.headers['Access-Control-Allow-Origin'] = origin
|
resp.headers["Access-Control-Allow-Origin"] = origin
|
||||||
resp.headers['Access-Control-Allow-Headers'] = (
|
resp.headers["Access-Control-Allow-Headers"] = (
|
||||||
'*, X-Super-Properties, '
|
"*, X-Super-Properties, "
|
||||||
'X-Fingerprint, '
|
"X-Fingerprint, "
|
||||||
'X-Context-Properties, '
|
"X-Context-Properties, "
|
||||||
'X-Failed-Requests, '
|
"X-Failed-Requests, "
|
||||||
'X-Debug-Options, '
|
"X-Debug-Options, "
|
||||||
'Content-Type, '
|
"Content-Type, "
|
||||||
'Authorization, '
|
"Authorization, "
|
||||||
'Origin, '
|
"Origin, "
|
||||||
'If-None-Match'
|
"If-None-Match"
|
||||||
)
|
)
|
||||||
resp.headers['Access-Control-Allow-Methods'] = \
|
resp.headers["Access-Control-Allow-Methods"] = resp.headers.get("allow", "*")
|
||||||
resp.headers.get('allow', '*')
|
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
def _set_rtl_reset(bucket, resp):
|
def _set_rtl_reset(bucket, resp):
|
||||||
reset = bucket._window + bucket.second
|
reset = bucket._window + bucket.second
|
||||||
precision = request.headers.get('x-ratelimit-precision', 'second')
|
precision = request.headers.get("x-ratelimit-precision", "second")
|
||||||
|
|
||||||
if precision == 'second':
|
if precision == "second":
|
||||||
resp.headers['X-RateLimit-Reset'] = str(round(reset))
|
resp.headers["X-RateLimit-Reset"] = str(round(reset))
|
||||||
elif precision == 'millisecond':
|
elif precision == "millisecond":
|
||||||
resp.headers['X-RateLimit-Reset'] = str(reset)
|
resp.headers["X-RateLimit-Reset"] = str(reset)
|
||||||
else:
|
else:
|
||||||
resp.headers['X-RateLimit-Reset'] = (
|
resp.headers["X-RateLimit-Reset"] = (
|
||||||
'Invalid X-RateLimit-Precision, '
|
"Invalid X-RateLimit-Precision, " "valid options are (second, millisecond)"
|
||||||
'valid options are (second, millisecond)'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -218,15 +228,15 @@ async def app_set_ratelimit_headers(resp):
|
||||||
if bucket is None:
|
if bucket is None:
|
||||||
raise AttributeError()
|
raise AttributeError()
|
||||||
|
|
||||||
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
|
resp.headers["X-RateLimit-Limit"] = str(bucket.requests)
|
||||||
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
|
resp.headers["X-RateLimit-Remaining"] = str(bucket._tokens)
|
||||||
resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower()
|
resp.headers["X-RateLimit-Global"] = str(request.bucket_global).lower()
|
||||||
_set_rtl_reset(bucket, resp)
|
_set_rtl_reset(bucket, resp)
|
||||||
|
|
||||||
# only add Retry-After if we actually hit a ratelimit
|
# only add Retry-After if we actually hit a ratelimit
|
||||||
retry_after = request.retry_after
|
retry_after = request.retry_after
|
||||||
if request.retry_after:
|
if request.retry_after:
|
||||||
resp.headers['Retry-After'] = str(retry_after)
|
resp.headers["Retry-After"] = str(retry_after)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
@ -238,8 +248,8 @@ async def init_app_db(app_):
|
||||||
|
|
||||||
Also spawns the job scheduler.
|
Also spawns the job scheduler.
|
||||||
"""
|
"""
|
||||||
log.info('db connect')
|
log.info("db connect")
|
||||||
app_.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
app_.db = await asyncpg.create_pool(**app.config["POSTGRES"])
|
||||||
|
|
||||||
app_.sched = JobManager()
|
app_.sched = JobManager()
|
||||||
|
|
||||||
|
|
@ -247,7 +257,7 @@ async def init_app_db(app_):
|
||||||
def init_app_managers(app_, *, voice=True):
|
def init_app_managers(app_, *, voice=True):
|
||||||
"""Initialize singleton classes."""
|
"""Initialize singleton classes."""
|
||||||
app_.loop = asyncio.get_event_loop()
|
app_.loop = asyncio.get_event_loop()
|
||||||
app_.ratelimiter = RatelimitManager(app_.config.get('_testing'))
|
app_.ratelimiter = RatelimitManager(app_.config.get("_testing"))
|
||||||
app_.state_manager = StateManager()
|
app_.state_manager = StateManager()
|
||||||
|
|
||||||
app_.storage = Storage(app_)
|
app_.storage = Storage(app_)
|
||||||
|
|
@ -274,15 +284,12 @@ async def api_index(app_):
|
||||||
to_find = {}
|
to_find = {}
|
||||||
found = []
|
found = []
|
||||||
|
|
||||||
with open('discord_endpoints.txt') as fd:
|
with open("discord_endpoints.txt") as fd:
|
||||||
for line in fd.readlines():
|
for line in fd.readlines():
|
||||||
components = line.split(' ')
|
components = line.split(" ")
|
||||||
components = list(filter(
|
components = list(filter(bool, components))
|
||||||
bool,
|
|
||||||
components
|
|
||||||
))
|
|
||||||
name, method, path = components
|
name, method, path = components
|
||||||
path = f'/api/v6{path.strip()}'
|
path = f"/api/v6{path.strip()}"
|
||||||
method = method.strip()
|
method = method.strip()
|
||||||
to_find[(path, method)] = name
|
to_find[(path, method)] = name
|
||||||
|
|
||||||
|
|
@ -290,17 +297,17 @@ async def api_index(app_):
|
||||||
path = rule.rule
|
path = rule.rule
|
||||||
|
|
||||||
# convert the path to the discord_endpoints file's style
|
# convert the path to the discord_endpoints file's style
|
||||||
path = path.replace('_', '.')
|
path = path.replace("_", ".")
|
||||||
path = path.replace('<', '{')
|
path = path.replace("<", "{")
|
||||||
path = path.replace('>', '}')
|
path = path.replace(">", "}")
|
||||||
path = path.replace('int:', '')
|
path = path.replace("int:", "")
|
||||||
|
|
||||||
# change our parameters into user.id
|
# change our parameters into user.id
|
||||||
path = path.replace('member.id', 'user.id')
|
path = path.replace("member.id", "user.id")
|
||||||
path = path.replace('banned.id', 'user.id')
|
path = path.replace("banned.id", "user.id")
|
||||||
path = path.replace('target.id', 'user.id')
|
path = path.replace("target.id", "user.id")
|
||||||
path = path.replace('other.id', 'user.id')
|
path = path.replace("other.id", "user.id")
|
||||||
path = path.replace('peer.id', 'user.id')
|
path = path.replace("peer.id", "user.id")
|
||||||
|
|
||||||
methods = rule.methods
|
methods = rule.methods
|
||||||
|
|
||||||
|
|
@ -317,10 +324,15 @@ async def api_index(app_):
|
||||||
percentage = (len(found) / len(api)) * 100
|
percentage = (len(found) / len(api)) * 100
|
||||||
percentage = round(percentage, 2)
|
percentage = round(percentage, 2)
|
||||||
|
|
||||||
log.debug('API compliance: {} out of {} ({} missing), {}% compliant',
|
log.debug(
|
||||||
len(found), len(api), len(missing), percentage)
|
"API compliance: {} out of {} ({} missing), {}% compliant",
|
||||||
|
len(found),
|
||||||
|
len(api),
|
||||||
|
len(missing),
|
||||||
|
percentage,
|
||||||
|
)
|
||||||
|
|
||||||
log.debug('missing: {}', missing)
|
log.debug("missing: {}", missing)
|
||||||
|
|
||||||
|
|
||||||
async def post_app_start(app_):
|
async def post_app_start(app_):
|
||||||
|
|
@ -332,7 +344,7 @@ async def post_app_start(app_):
|
||||||
|
|
||||||
def start_websocket(host, port, ws_handler) -> asyncio.Future:
|
def start_websocket(host, port, ws_handler) -> asyncio.Future:
|
||||||
"""Start a websocket. Returns the websocket future"""
|
"""Start a websocket. Returns the websocket future"""
|
||||||
log.info(f'starting websocket at {host} {port}')
|
log.info(f"starting websocket at {host} {port}")
|
||||||
|
|
||||||
async def _wrapper(ws, url):
|
async def _wrapper(ws, url):
|
||||||
# We wrap the main websocket_handler
|
# We wrap the main websocket_handler
|
||||||
|
|
@ -348,7 +360,7 @@ async def app_before_serving():
|
||||||
|
|
||||||
Also sets up the websocket handlers.
|
Also sets up the websocket handlers.
|
||||||
"""
|
"""
|
||||||
log.info('opening db')
|
log.info("opening db")
|
||||||
await init_app_db(app)
|
await init_app_db(app)
|
||||||
|
|
||||||
app.session = ClientSession()
|
app.session = ClientSession()
|
||||||
|
|
@ -359,8 +371,7 @@ async def app_before_serving():
|
||||||
# start gateway websocket
|
# start gateway websocket
|
||||||
# voice websocket is handled by the voice server
|
# voice websocket is handled by the voice server
|
||||||
ws_fut = start_websocket(
|
ws_fut = start_websocket(
|
||||||
app.config['WS_HOST'], app.config['WS_PORT'],
|
app.config["WS_HOST"], app.config["WS_PORT"], websocket_handler
|
||||||
websocket_handler
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await ws_fut
|
await ws_fut
|
||||||
|
|
@ -379,7 +390,7 @@ async def app_after_serving():
|
||||||
|
|
||||||
app.sched.close()
|
app.sched.close()
|
||||||
|
|
||||||
log.info('closing db')
|
log.info("closing db")
|
||||||
await app.db.close()
|
await app.db.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -391,24 +402,23 @@ async def handle_litecord_err(err):
|
||||||
ejson = {}
|
ejson = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ejson['code'] = err.error_code
|
ejson["code"] = err.error_code
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.warning('error: {} {!r}', err.status_code, err.message)
|
log.warning("error: {} {!r}", err.status_code, err.message)
|
||||||
|
|
||||||
return jsonify({
|
return (
|
||||||
'error': True,
|
jsonify(
|
||||||
'status': err.status_code,
|
{"error": True, "status": err.status_code, "message": err.message, **ejson}
|
||||||
'message': err.message,
|
),
|
||||||
**ejson
|
err.status_code,
|
||||||
}), err.status_code
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.errorhandler(500)
|
@app.errorhandler(500)
|
||||||
async def handle_500(err):
|
async def handle_500(err):
|
||||||
return jsonify({
|
return (
|
||||||
'error': True,
|
jsonify({"error": True, "message": repr(err), "internal_server_error": True}),
|
||||||
'message': repr(err),
|
500,
|
||||||
'internal_server_error': True,
|
)
|
||||||
}), 500
|
|
||||||
|
|
|
||||||
12
setup.py
12
setup.py
|
|
@ -20,10 +20,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name='litecord',
|
name="litecord",
|
||||||
version='0.0.1',
|
version="0.0.1",
|
||||||
description='Implementation of the Discord API',
|
description="Implementation of the Discord API",
|
||||||
url='https://litecord.top',
|
url="https://litecord.top",
|
||||||
author='Luna Mendes',
|
author="Luna Mendes",
|
||||||
python_requires='>=3.7'
|
python_requires=">=3.7",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -19,13 +19,15 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
|
|
||||||
def email() -> str:
|
def email() -> str:
|
||||||
return f'{secrets.token_hex(5)}@{secrets.token_hex(5)}.com'
|
return f"{secrets.token_hex(5)}@{secrets.token_hex(5)}.com"
|
||||||
|
|
||||||
|
|
||||||
class TestClient:
|
class TestClient:
|
||||||
"""Test client that wraps pytest-sanic's TestClient and a test
|
"""Test client that wraps pytest-sanic's TestClient and a test
|
||||||
user and adds authorization headers to test requests."""
|
user and adds authorization headers to test requests."""
|
||||||
|
|
||||||
def __init__(self, test_cli, test_user):
|
def __init__(self, test_cli, test_user):
|
||||||
self.cli = test_cli
|
self.cli = test_cli
|
||||||
self.app = test_cli.app
|
self.app = test_cli.app
|
||||||
|
|
@ -37,31 +39,31 @@ class TestClient:
|
||||||
def _inject_auth(self, kwargs: dict) -> list:
|
def _inject_auth(self, kwargs: dict) -> list:
|
||||||
"""Inject the test user's API key into the test request before
|
"""Inject the test user's API key into the test request before
|
||||||
passing the request on to the underlying TestClient."""
|
passing the request on to the underlying TestClient."""
|
||||||
headers = kwargs.get('headers', {})
|
headers = kwargs.get("headers", {})
|
||||||
headers['authorization'] = self.user['token']
|
headers["authorization"] = self.user["token"]
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
async def get(self, *args, **kwargs):
|
async def get(self, *args, **kwargs):
|
||||||
"""Send a GET request."""
|
"""Send a GET request."""
|
||||||
kwargs['headers'] = self._inject_auth(kwargs)
|
kwargs["headers"] = self._inject_auth(kwargs)
|
||||||
return await self.cli.get(*args, **kwargs)
|
return await self.cli.get(*args, **kwargs)
|
||||||
|
|
||||||
async def post(self, *args, **kwargs):
|
async def post(self, *args, **kwargs):
|
||||||
"""Send a POST request."""
|
"""Send a POST request."""
|
||||||
kwargs['headers'] = self._inject_auth(kwargs)
|
kwargs["headers"] = self._inject_auth(kwargs)
|
||||||
return await self.cli.post(*args, **kwargs)
|
return await self.cli.post(*args, **kwargs)
|
||||||
|
|
||||||
async def put(self, *args, **kwargs):
|
async def put(self, *args, **kwargs):
|
||||||
"""Send a POST request."""
|
"""Send a POST request."""
|
||||||
kwargs['headers'] = self._inject_auth(kwargs)
|
kwargs["headers"] = self._inject_auth(kwargs)
|
||||||
return await self.cli.put(*args, **kwargs)
|
return await self.cli.put(*args, **kwargs)
|
||||||
|
|
||||||
async def patch(self, *args, **kwargs):
|
async def patch(self, *args, **kwargs):
|
||||||
"""Send a PATCH request."""
|
"""Send a PATCH request."""
|
||||||
kwargs['headers'] = self._inject_auth(kwargs)
|
kwargs["headers"] = self._inject_auth(kwargs)
|
||||||
return await self.cli.patch(*args, **kwargs)
|
return await self.cli.patch(*args, **kwargs)
|
||||||
|
|
||||||
async def delete(self, *args, **kwargs):
|
async def delete(self, *args, **kwargs):
|
||||||
"""Send a DELETE request."""
|
"""Send a DELETE request."""
|
||||||
kwargs['headers'] = self._inject_auth(kwargs)
|
kwargs["headers"] = self._inject_auth(kwargs)
|
||||||
return await self.cli.delete(*args, **kwargs)
|
return await self.cli.delete(*args, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -36,22 +36,22 @@ from litecord.blueprints.auth import make_token
|
||||||
from litecord.blueprints.users import delete_user
|
from litecord.blueprints.users import delete_user
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name='app')
|
@pytest.fixture(name="app")
|
||||||
def _test_app(unused_tcp_port, event_loop):
|
def _test_app(unused_tcp_port, event_loop):
|
||||||
set_blueprints(main_app)
|
set_blueprints(main_app)
|
||||||
main_app.config['_testing'] = True
|
main_app.config["_testing"] = True
|
||||||
|
|
||||||
# reassign an unused tcp port for websockets
|
# reassign an unused tcp port for websockets
|
||||||
# since the config might give a used one.
|
# since the config might give a used one.
|
||||||
ws_port = unused_tcp_port
|
ws_port = unused_tcp_port
|
||||||
|
|
||||||
main_app.config['IS_SSL'] = False
|
main_app.config["IS_SSL"] = False
|
||||||
main_app.config['WS_PORT'] = ws_port
|
main_app.config["WS_PORT"] = ws_port
|
||||||
main_app.config['WEBSOCKET_URL'] = f'localhost:{ws_port}'
|
main_app.config["WEBSOCKET_URL"] = f"localhost:{ws_port}"
|
||||||
|
|
||||||
# testing user creations requires hardcoding this to true
|
# testing user creations requires hardcoding this to true
|
||||||
# on testing
|
# on testing
|
||||||
main_app.config['REGISTRATIONS'] = True
|
main_app.config["REGISTRATIONS"] = True
|
||||||
|
|
||||||
# make sure we're calling the before_serving hooks
|
# make sure we're calling the before_serving hooks
|
||||||
event_loop.run_until_complete(main_app.startup())
|
event_loop.run_until_complete(main_app.startup())
|
||||||
|
|
@ -63,11 +63,12 @@ def _test_app(unused_tcp_port, event_loop):
|
||||||
event_loop.run_until_complete(main_app.shutdown())
|
event_loop.run_until_complete(main_app.shutdown())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name='test_cli')
|
@pytest.fixture(name="test_cli")
|
||||||
def _test_cli(app):
|
def _test_cli(app):
|
||||||
"""Give a test client."""
|
"""Give a test client."""
|
||||||
return app.test_client()
|
return app.test_client()
|
||||||
|
|
||||||
|
|
||||||
# code shamelessly stolen from my elixire mr
|
# code shamelessly stolen from my elixire mr
|
||||||
# https://gitlab.com/elixire/elixire/merge_requests/52
|
# https://gitlab.com/elixire/elixire/merge_requests/52
|
||||||
async def _user_fixture_setup(app):
|
async def _user_fixture_setup(app):
|
||||||
|
|
@ -76,21 +77,26 @@ async def _user_fixture_setup(app):
|
||||||
user_email = email()
|
user_email = email()
|
||||||
|
|
||||||
user_id, pwd_hash = await create_user(
|
user_id, pwd_hash = await create_user(
|
||||||
username, user_email, password, app.db, app.loop)
|
username, user_email, password, app.db, app.loop
|
||||||
|
)
|
||||||
|
|
||||||
# generate a token for api access
|
# generate a token for api access
|
||||||
user_token = make_token(user_id, pwd_hash)
|
user_token = make_token(user_id, pwd_hash)
|
||||||
|
|
||||||
return {'id': user_id, 'token': user_token,
|
return {
|
||||||
'email': user_email, 'username': username,
|
"id": user_id,
|
||||||
'password': password}
|
"token": user_token,
|
||||||
|
"email": user_email,
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def _user_fixture_teardown(app, udata: dict):
|
async def _user_fixture_teardown(app, udata: dict):
|
||||||
await delete_user(udata['id'], app_=app)
|
await delete_user(udata["id"], app_=app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name='test_user')
|
@pytest.fixture(name="test_user")
|
||||||
async def test_user_fixture(app):
|
async def test_user_fixture(app):
|
||||||
"""Yield a randomly generated test user."""
|
"""Yield a randomly generated test user."""
|
||||||
udata = await _user_fixture_setup(app)
|
udata = await _user_fixture_setup(app)
|
||||||
|
|
@ -113,18 +119,25 @@ async def test_cli_staff(test_cli):
|
||||||
# same test_cli_user, which isn't acceptable.
|
# same test_cli_user, which isn't acceptable.
|
||||||
app = test_cli.app
|
app = test_cli.app
|
||||||
test_user = await _user_fixture_setup(app)
|
test_user = await _user_fixture_setup(app)
|
||||||
user_id = test_user['id']
|
user_id = test_user["id"]
|
||||||
|
|
||||||
# copied from manage.cmd.users.set_user_staff.
|
# copied from manage.cmd.users.set_user_staff.
|
||||||
old_flags = await app.db.fetchval("""
|
old_flags = await app.db.fetchval(
|
||||||
|
"""
|
||||||
SELECT flags FROM users WHERE id = $1
|
SELECT flags FROM users WHERE id = $1
|
||||||
""", user_id)
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
new_flags = old_flags | UserFlags.staff
|
new_flags = old_flags | UserFlags.staff
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute(
|
||||||
|
"""
|
||||||
UPDATE users SET flags = $1 WHERE id = $2
|
UPDATE users SET flags = $1 WHERE id = $2
|
||||||
""", new_flags, user_id)
|
""",
|
||||||
|
new_flags,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
yield TestClient(test_cli, test_user)
|
yield TestClient(test_cli, test_user)
|
||||||
await _user_fixture_teardown(test_cli.app, test_user)
|
await _user_fixture_teardown(test_cli.app, test_user)
|
||||||
|
|
|
||||||
|
|
@ -24,24 +24,24 @@ import pytest
|
||||||
from litecord.blueprints.guilds import delete_guild
|
from litecord.blueprints.guilds import delete_guild
|
||||||
from litecord.errors import GuildNotFound
|
from litecord.errors import GuildNotFound
|
||||||
|
|
||||||
|
|
||||||
async def _create_guild(test_cli_staff):
|
async def _create_guild(test_cli_staff):
|
||||||
genned_name = secrets.token_hex(6)
|
genned_name = secrets.token_hex(6)
|
||||||
|
|
||||||
resp = await test_cli_staff.post('/api/v6/guilds', json={
|
resp = await test_cli_staff.post(
|
||||||
'name': genned_name,
|
"/api/v6/guilds", json={"name": genned_name, "region": None}
|
||||||
'region': None
|
)
|
||||||
})
|
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
assert isinstance(rjson, dict)
|
assert isinstance(rjson, dict)
|
||||||
assert rjson['name'] == genned_name
|
assert rjson["name"] == genned_name
|
||||||
|
|
||||||
return rjson
|
return rjson
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
||||||
resp = await test_cli_staff.get(f'/api/v6/admin/guilds/{guild_id}')
|
resp = await test_cli_staff.get(f"/api/v6/admin/guilds/{guild_id}")
|
||||||
|
|
||||||
if ret_early:
|
if ret_early:
|
||||||
return resp
|
return resp
|
||||||
|
|
@ -49,7 +49,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
assert isinstance(rjson, dict)
|
assert isinstance(rjson, dict)
|
||||||
assert rjson['id'] == guild_id
|
assert rjson["id"] == guild_id
|
||||||
|
|
||||||
return rjson
|
return rjson
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ async def _fetch_guild(test_cli_staff, guild_id, *, ret_early=False):
|
||||||
async def test_guild_fetch(test_cli_staff):
|
async def test_guild_fetch(test_cli_staff):
|
||||||
"""Test the creation and fetching of a guild via the Admin API."""
|
"""Test the creation and fetching of a guild via the Admin API."""
|
||||||
rjson = await _create_guild(test_cli_staff)
|
rjson = await _create_guild(test_cli_staff)
|
||||||
guild_id = rjson['id']
|
guild_id = rjson["id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await _fetch_guild(test_cli_staff, guild_id)
|
await _fetch_guild(test_cli_staff, guild_id)
|
||||||
|
|
@ -70,8 +70,8 @@ async def test_guild_fetch(test_cli_staff):
|
||||||
async def test_guild_update(test_cli_staff):
|
async def test_guild_update(test_cli_staff):
|
||||||
"""Test the update of a guild via the Admin API."""
|
"""Test the update of a guild via the Admin API."""
|
||||||
rjson = await _create_guild(test_cli_staff)
|
rjson = await _create_guild(test_cli_staff)
|
||||||
guild_id = rjson['id']
|
guild_id = rjson["id"]
|
||||||
assert not rjson['unavailable']
|
assert not rjson["unavailable"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# I believe setting up an entire gateway client registered to the guild
|
# I believe setting up an entire gateway client registered to the guild
|
||||||
|
|
@ -79,19 +79,17 @@ async def test_guild_update(test_cli_staff):
|
||||||
# testing them. Yes, I know its a bad idea, but if someone has an easier
|
# testing them. Yes, I know its a bad idea, but if someone has an easier
|
||||||
# way to write that, do send an MR.
|
# way to write that, do send an MR.
|
||||||
resp = await test_cli_staff.patch(
|
resp = await test_cli_staff.patch(
|
||||||
f'/api/v6/admin/guilds/{guild_id}',
|
f"/api/v6/admin/guilds/{guild_id}", json={"unavailable": True}
|
||||||
json={
|
)
|
||||||
'unavailable': True
|
|
||||||
})
|
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
assert isinstance(rjson, dict)
|
assert isinstance(rjson, dict)
|
||||||
assert rjson['id'] == guild_id
|
assert rjson["id"] == guild_id
|
||||||
assert rjson['unavailable']
|
assert rjson["unavailable"]
|
||||||
|
|
||||||
rjson = await _fetch_guild(test_cli_staff, guild_id)
|
rjson = await _fetch_guild(test_cli_staff, guild_id)
|
||||||
assert rjson['unavailable']
|
assert rjson["unavailable"]
|
||||||
finally:
|
finally:
|
||||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
||||||
|
|
||||||
|
|
@ -100,20 +98,19 @@ async def test_guild_update(test_cli_staff):
|
||||||
async def test_guild_delete(test_cli_staff):
|
async def test_guild_delete(test_cli_staff):
|
||||||
"""Test the update of a guild via the Admin API."""
|
"""Test the update of a guild via the Admin API."""
|
||||||
rjson = await _create_guild(test_cli_staff)
|
rjson = await _create_guild(test_cli_staff)
|
||||||
guild_id = rjson['id']
|
guild_id = rjson["id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await test_cli_staff.delete(f'/api/v6/admin/guilds/{guild_id}')
|
resp = await test_cli_staff.delete(f"/api/v6/admin/guilds/{guild_id}")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
resp = await _fetch_guild(
|
resp = await _fetch_guild(test_cli_staff, guild_id, ret_early=True)
|
||||||
test_cli_staff, guild_id, ret_early=True)
|
|
||||||
|
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
assert isinstance(rjson, dict)
|
assert isinstance(rjson, dict)
|
||||||
assert rjson['error']
|
assert rjson["error"]
|
||||||
assert rjson['code'] == GuildNotFound.error_code
|
assert rjson["code"] == GuildNotFound.error_code
|
||||||
finally:
|
finally:
|
||||||
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
await delete_guild(int(guild_id), app_=test_cli_staff.app)
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ import pytest
|
||||||
|
|
||||||
|
|
||||||
async def _get_invs(test_cli):
|
async def _get_invs(test_cli):
|
||||||
resp = await test_cli.get('/api/v6/admin/instance/invites')
|
resp = await test_cli.get("/api/v6/admin/instance/invites")
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
|
|
@ -39,7 +39,7 @@ async def test_get_invites(test_cli_staff):
|
||||||
async def test_inv_delete_invalid(test_cli_staff):
|
async def test_inv_delete_invalid(test_cli_staff):
|
||||||
"""Test errors happen when trying to delete a
|
"""Test errors happen when trying to delete a
|
||||||
non-existing instance invite."""
|
non-existing instance invite."""
|
||||||
resp = await test_cli_staff.delete('/api/v6/admin/instance/invites/aaaaaa')
|
resp = await test_cli_staff.delete("/api/v6/admin/instance/invites/aaaaaa")
|
||||||
|
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
@ -48,21 +48,20 @@ async def test_inv_delete_invalid(test_cli_staff):
|
||||||
async def test_create_invite(test_cli_staff):
|
async def test_create_invite(test_cli_staff):
|
||||||
"""Test the creation of an instance invite, then listing it,
|
"""Test the creation of an instance invite, then listing it,
|
||||||
then deleting it."""
|
then deleting it."""
|
||||||
resp = await test_cli_staff.put('/api/v6/admin/instance/invites', json={
|
resp = await test_cli_staff.put(
|
||||||
'max_uses': 1
|
"/api/v6/admin/instance/invites", json={"max_uses": 1}
|
||||||
})
|
)
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
rjson = await resp.json
|
rjson = await resp.json
|
||||||
assert isinstance(rjson, dict)
|
assert isinstance(rjson, dict)
|
||||||
code = rjson['code']
|
code = rjson["code"]
|
||||||
|
|
||||||
# assert that the invite is in the list
|
# assert that the invite is in the list
|
||||||
invites = await _get_invs(test_cli_staff)
|
invites = await _get_invs(test_cli_staff)
|
||||||
assert any(inv['code'] == code for inv in invites)
|
assert any(inv["code"] == code for inv in invites)
|
||||||
|
|
||||||
# delete it, and assert it worked
|
# delete it, and assert it worked
|
||||||
resp = await test_cli_staff.delete(
|
resp = await test_cli_staff.delete(f"/api/v6/admin/instance/invites/{code}")
|
||||||
f'/api/v6/admin/instance/invites/{code}')
|
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue