From 0ec615f3bdd55f63e02630e0a45e1a6e52b45e32 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 14 Nov 2018 19:49:36 -0300 Subject: [PATCH] user.billing: add create_payment - user: add billing_job for recurring payments (monthly or weekly) - user.billing: make main functions accept external db object - user.billing: fix get_payment's fields - litecord: add job module with JobManager --- Pipfile.lock | 6 +- litecord/blueprints/user/billing.py | 88 ++++++++++++++++++++--- litecord/blueprints/user/billing_job.py | 95 +++++++++++++++++++++++++ litecord/jobs.py | 16 +++++ run.py | 21 +++++- 5 files changed, 211 insertions(+), 15 deletions(-) create mode 100644 litecord/blueprints/user/billing_job.py create mode 100644 litecord/jobs.py diff --git a/Pipfile.lock b/Pipfile.lock index 17f4f66..3924626 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -178,10 +178,10 @@ }, "hypercorn": { "hashes": [ - "sha256:d563272b41269e9b2a73b1058f2471a511b545f64090b3fc2d6006cbaf8109fe", - "sha256:f6e8c5f02e9c97d6981c56098be379cb0f9c2a48f64c1d5dd2a8f228b61bc2b8" + "sha256:3931144309c40341a46a2d054ac550bbd012a1f1a803774b5d6a3add90f52259", + "sha256:4df03fbc101efb4faf0b0883863ff7e620f94310e309311ceafaadb38ee1fa36" ], - "version": "==0.4.1" + "version": "==0.4.2" }, "hyperframe": { "hashes": [ diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index 1e6398d..56ff7cc 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -1,6 +1,7 @@ import pprint import json import datetime +from asyncio import sleep from enum import Enum from quart import Blueprint, jsonify, request, current_app as app @@ -46,6 +47,16 @@ class PaymentStatus: FAILED = 2 +# how much should a payment be, depending +# of the subscription +AMOUNTS = { + 'premium_month_tier_1': 499, + 'premium_month_tier_2': 999, + 'premium_year_tier_1': 4999, + 'premium_year_tier_2': 9999, +} + + CREATE_SUBSCRIPTION = { 'payment_gateway_plan_id': {'type': 'string'}, 'payment_source_id': {'coerce': int} @@ -80,8 +91,11 @@ async def get_payment_source_ids(user_id: int) -> list: return [r['id'] for r in rows] -async def get_payment_ids(user_id: int) -> list: - rows = await app.db.fetch(""" +async def get_payment_ids(user_id: int, db=None) -> list: + if not db: + db = app.db + + rows = await db.fetch(""" SELECT id FROM user_payments WHERE user_id = $1 @@ -100,11 +114,15 @@ async def get_subscription_ids(user_id: int) -> list: return [r['id'] for r in rows] -async def get_payment_source(user_id: int, source_id: int) -> dict: +async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: """Get a payment source's information.""" + + if not db: + db = app.db + source = {} - source_type = await app.db.fetchval(""" + source_type = await db.fetchval(""" SELECT source_type FROM user_payment_sources WHERE id = $1 AND user_id = $2 @@ -120,7 +138,7 @@ async def get_payment_source(user_id: int, source_id: int) -> dict: fields = ','.join(specific_fields) - extras_row = await app.db.fetchrow(f""" + extras_row = await db.fetchrow(f""" SELECT {fields}, billing_address, default_, id::text FROM user_payment_sources WHERE id = $1 @@ -143,9 +161,14 @@ async def get_payment_source(user_id: int, source_id: int) -> dict: return {**source, **derow} -async def get_subscription(subscription_id: int): - row = await app.db.fetchrow(""" +async def get_subscription(subscription_id: int, db=None): + """Get a subscription's information.""" + if not db: + db = app.db + + row = await db.fetchrow(""" SELECT id::text, source_id::text AS payment_source_id, + user_id, payment_gateway, payment_gateway_plan_id, period_start AS current_period_start, period_end AS current_period_end, @@ -167,9 +190,13 @@ async def get_subscription(subscription_id: int): return drow -async def get_payment(payment_id: int): - row = await app.db.fetchrow(""" - SELECT id::text, source_id, subscription_id, +async def get_payment(payment_id: int, db=None): + """Get a single payment's information.""" + if not db: + db = app.db + + row = await db.fetchrow(""" + SELECT id::text, source_id, subscription_id, user_id, amount, amount_refunded, currency, description, status, tax, tax_inclusive FROM user_payments @@ -177,10 +204,47 @@ async def get_payment(payment_id: int): """, payment_id) drow = dict(row) + + drow.pop('source_id') + drow.pop('subscription_id') + drow.pop('user_id') + drow['created_at'] = snowflake_datetime(int(drow['id'])) + + drow['payment_source'] = await get_payment_source( + row['user_id'], row['source_id'], db) + + drow['subscription'] = await get_subscription( + row['subscription_id'], db) + return drow +async def create_payment(subscription_id, app): + """Create a payment.""" + sub = await get_subscription(subscription_id, app.db) + + new_id = get_snowflake() + + amount = AMOUNTS[sub['payment_gateway_plan_id']] + + await app.db.execute( + """ + INSERT INTO user_payments ( + id, source_id, subscription_id, user_id, + amount, amount_refunded, currency, + description, status, tax, tax_inclusive + ) + VALUES + ($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false) + """, new_id, int(sub['payment_source_id']), + subscription_id, int(sub['user_id']), + amount, 'usd', 'FUCK NITRO', + PaymentStatus.SUCCESS) + + return new_id + + @bp.route('/@me/billing/payment-sources', methods=['GET']) async def _get_billing_sources(): user_id = await token_check() @@ -248,7 +312,7 @@ async def _create_subscription(): source = await get_payment_source(user_id, j['payment_source_id']) if not source: - raise BadInput('invalid source id') + raise BadRequest('invalid source id') plan_id = j['payment_gateway_plan_id'] @@ -273,6 +337,8 @@ async def _create_subscription(): SubscriptionType.PURCHASE, PaymentGateway.STRIPE, plan_id, 1) + await create_payment(new_id, app) + return jsonify( await get_subscription(new_id) ) diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py new file mode 100644 index 0000000..6dce986 --- /dev/null +++ b/litecord/blueprints/user/billing_job.py @@ -0,0 +1,95 @@ +""" +this file only serves the periodic payment job code. +""" +import datetime +from asyncio import sleep +from logbook import Logger + +from litecord.blueprints.user.billing import ( + get_subscription, get_payment_ids, get_payment, PaymentStatus, + create_payment +) + +from litecord.snowflake import snowflake_datetime, get_snowflake + +log = Logger(__name__) + +# how many days until a payment needs +# to be issued +THRESHOLDS = { + 'premium_month_tier_1': 30, + 'premium_month_tier_2': 30, + 'premium_year_tier_1': 365, + 'premium_year_tier_2': 365, +} + + +async def _resched(app): + log.debug('waiting 2 minutes for job.') + await sleep(120) + await app.sched.spawn(payment_job(app)) + + +async def _process_user_payments(app, user_id: int): + payments = await get_payment_ids(user_id, app.db) + + if not payments: + log.debug('no payments for uid {}, skipping', user_id) + return + + log.debug('{} payments for uid {}', len(payments), user_id) + + latest_payment = max(payments) + + payment_data = await get_payment(latest_payment, app.db) + + # calculate the difference between this payment + # and now. + now = datetime.datetime.now() + payment_tstamp = snowflake_datetime(int(payment_data['id'])) + + delta = now - payment_tstamp + + sub_id = int(payment_data['subscription']['id']) + subscription = await get_subscription( + sub_id, app.db) + + threshold = THRESHOLDS[subscription['payment_gateway_plan_id']] + + log.debug('delta {} delta days {} threshold {}', + delta, delta.days, threshold) + + if delta.days > threshold: + # insert new payment, for free !!!!!! + log.info('creating payment for sid={}', + sub_id) + await create_payment(sub_id, app) + else: + log.debug('not there yet for sid={}', sub_id) + + +async def payment_job(app): + """Main payment job function. + + This function will check through users' payments + and add a new one once a month / year. + """ + log.info('payment job start!') + + user_ids = await app.db.fetch(""" + SELECT DISTINCT user_id + FROM user_payments + """) + + log.debug('working {} users', len(user_ids)) + print(user_ids) + + # go through each user's payments + for row in user_ids: + user_id = row['user_id'] + try: + await _process_user_payments(app, user_id) + except Exception: + log.exception('error while processing user payments') + + await _resched(app) diff --git a/litecord/jobs.py b/litecord/jobs.py new file mode 100644 index 0000000..f2161e4 --- /dev/null +++ b/litecord/jobs.py @@ -0,0 +1,16 @@ +import asyncio + + +class JobManager: + """Manage background jobs""" + def __init__(self, loop=None): + self.loop = loop or asyncio.get_event_loop() + self.jobs = [] + + def spawn(self, coro): + task = self.loop.create_task(coro) + self.jobs.append(task) + + def close(self): + for job in self.jobs: + job.cancel() diff --git a/run.py b/run.py index 8e0d632..e0a35d3 100644 --- a/run.py +++ b/run.py @@ -32,6 +32,10 @@ from litecord.blueprints.user import ( user_settings, user_billing ) +from litecord.blueprints.user.billing_job import ( + payment_job +) + from litecord.ratelimits.handler import ratelimit_handler from litecord.ratelimits.main import RatelimitManager @@ -42,6 +46,7 @@ from litecord.storage import Storage from litecord.dispatcher import EventDispatcher from litecord.presence import PresenceManager from litecord.images import IconManager +from litecord.jobs import JobManager # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) @@ -164,10 +169,15 @@ async def app_set_ratelimit_headers(resp): async def init_app_db(app): - """Connect to databases""" + """Connect to databases. + + Also spawns the job scheduler. + """ log.info('db connect') app.db = await asyncpg.create_pool(**app.config['POSTGRES']) + app.sched = JobManager() + def init_app_managers(app): """Initialize singleton classes.""" @@ -180,9 +190,15 @@ def init_app_managers(app): app.dispatcher = EventDispatcher(app) app.presence = PresenceManager(app.storage, app.state_manager, app.dispatcher) + app.storage.presence = app.presence +async def post_app_start(app): + # we'll need to start a billing job + app.sched.spawn(payment_job(app)) + + @app.before_serving async def app_before_serving(): log.info('opening db') @@ -209,6 +225,7 @@ async def app_before_serving(): ws_future = websockets.serve(_wrapper, host, port) + await post_app_start(app) await ws_future @@ -223,6 +240,8 @@ async def app_after_serving(): app.state_manager.close() + app.sched.close() + log.info('closing db') await app.db.close()