diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index 7817ec9..592d3e4 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -5,9 +5,9 @@ from quart import Blueprint, jsonify, request, current_app as app from litecord.auth import token_check from litecord.schemas import validate -from litecord.storage import timestamp_ from litecord.snowflake import snowflake_datetime, get_snowflake from litecord.errors import BadRequest +from litecord.types import timestamp_ bp = Blueprint('users_billing', __name__) @@ -43,6 +43,20 @@ class PaymentStatus: FAILED = 2 +class PremiumType: + TIER_1 = 1 + TIER_2 = 2 + NONE = None + + +PLAN_ID_TO_TYPE = { + 'premium_month_tier_1': PremiumType.TIER_1, + 'premium_month_tier_2': PremiumType.TIER_2, + 'premium_year_tier_1': PremiumType.TIER_1, + 'premium_year_tier_2': PremiumType.TIER_2, +} + + # how much should a payment be, depending # of the subscription AMOUNTS = { diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py index a89de2e..3f91072 100644 --- a/litecord/blueprints/user/billing_job.py +++ b/litecord/blueprints/user/billing_job.py @@ -6,11 +6,13 @@ from asyncio import sleep, CancelledError from logbook import Logger from litecord.blueprints.user.billing import ( - get_subscription, get_payment_ids, get_payment, create_payment + get_subscription, get_payment_ids, get_payment, create_payment, + SubscriptionStatus ) from litecord.snowflake import snowflake_datetime -from litecord.types import MINUTES +from litecord.types import MINUTES, HOURS +from litecord.enums import UserFlags log = Logger(__name__) @@ -75,6 +77,65 @@ async def _process_user_payments(app, user_id: int): sub_id, threshold - delta.days) +async def _process_subscription(app, subscription_id: int): + sub = await get_subscription(subscription_id, app.db) + + user_id = int(sub['user_id']) + + if sub['status'] != SubscriptionStatus.ACTIVE: + log.debug('ignoring sub {}, not active', + subscription_id) + return + + # if the subscription is still active + # (should get cancelled status on failed + # payments), then we should update premium status + first_payment_id = await app.db.fetchval(""" + SELECT MIN(id) + FROM user_payments + WHERE subscription_id = $1 + """, subscription_id) + + first_payment_ts = snowflake_datetime(first_payment_id) + + premium_since = await app.db.fetchval(""" + SELECT premium_since + FROM users + WHERE id = $1 + """, user_id) + + premium_since = premium_since or datetime.datetime.fromtimestamp(0) + + delta = abs(first_payment_ts - premium_since) + + # if the time difference between the first payment + # and the premium_since column is more than 24h + # we update it. + if delta.total_seconds() < 24 * HOURS: + return + + old_flags = await app.db.fetchval(""" + SELECT flags + FROM users + WHERE id = $1 + """, user_id) + + new_flags = old_flags | UserFlags.premium_early + log.debug('updating flags {}, {} => {}', + user_id, old_flags, new_flags) + + await app.db.execute(""" + UPDATE users + SET premium_since = $1, flags = $2 + WHERE id = $3 + """, first_payment_ts, new_flags, user_id) + + user_object = await app.storage.get_user(user_id, secure=True) + + # dispatch updated user to all clients + await app.dispatcher.dispatch_user(user_id, 'USER_UPDATE', user_object) + + async def payment_job(app): """Main payment job function. @@ -98,6 +159,16 @@ async def payment_job(app): except Exception: log.exception('error while processing user payments') + subscribers = await app.db.fetch(""" + SELECT id + FROM user_subscriptions + """) + + for row in subscribers: + try: + await _process_subscription(app, row['id']) + except Exception: + log.exception('error while processing subscription') log.debug('rescheduling..') try: await _resched(app) diff --git a/litecord/storage.py b/litecord/storage.py index 7adef28..04dfe87 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -9,6 +9,10 @@ from litecord.blueprints.channel.reactions import ( EmojiType, emoji_sql, partial_emoji ) +from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE + +from litecord.types import timestamp_ + log = Logger(__name__) @@ -29,10 +33,6 @@ def str_(val): return maybe(str, val) -def timestamp_(dt): - return f'{dt.isoformat()}+00:00' if dt else None - - async def _set_json(con): """Set JSON and JSONB codecs for an asyncpg connection.""" @@ -110,6 +110,15 @@ class Storage: duser['mobile'] = False duser['phone'] = None + plan_id = await self.db.fetchval(""" + SELECT payment_gateway_plan_id + FROM user_subscriptions + WHERE status = 1 + AND user_id = $1 + """, user_id) + + duser['premium_type'] = PLAN_ID_TO_TYPE.get(plan_id) + return duser async def search_user(self, username: str, discriminator: str) -> int: diff --git a/litecord/types.py b/litecord/types.py index 8b3a5bf..43c21ce 100644 --- a/litecord/types.py +++ b/litecord/types.py @@ -3,7 +3,7 @@ KILOBYTES = 1024 # time units MINUTES = 60 -HOUR = 60 * MINUTES +HOURS = 60 * MINUTES class Color: @@ -20,3 +20,7 @@ class Color: def __int__(self): return self.value + + +def timestamp_(dt): + return f'{dt.isoformat()}+00:00' if dt else None diff --git a/schema.sql b/schema.sql index ef8278c..cfdf955 100644 --- a/schema.sql +++ b/schema.sql @@ -207,6 +207,8 @@ CREATE TABLE IF NOT EXISTS user_subscriptions ( -- gateway = 1: stripe -- gateway = 2: braintree payment_gateway int DEFAULT 0, + + -- "premium__tier_" payment_gateway_plan_id text, -- status = 1: active