From f6f50a1cff83298806aeb6dce9d29bac4a373c5c Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 25 Oct 2019 11:23:40 -0300 Subject: [PATCH] remove app param from billing functions --- litecord/blueprints/user/billing.py | 134 ++++++++++-------------- litecord/blueprints/user/billing_job.py | 36 ++++--- run.py | 2 +- 3 files changed, 78 insertions(+), 94 deletions(-) diff --git a/litecord/blueprints/user/billing.py b/litecord/blueprints/user/billing.py index a1a4148..2ac1972 100644 --- a/litecord/blueprints/user/billing.py +++ b/litecord/blueprints/user/billing.py @@ -122,16 +122,13 @@ async def get_payment_source_ids(user_id: int) -> list: return [r["id"] for r in rows] -async def get_payment_ids(user_id: int, db=None) -> list: - if not db: - db = app.db - - rows = await db.fetch( +async def get_payment_ids(user_id: int) -> list: + rows = await app.db.fetch( """ - SELECT id - FROM user_payments - WHERE user_id = $1 - """, + SELECT id + FROM user_payments + WHERE user_id = $1 + """, user_id, ) @@ -151,18 +148,14 @@ async def get_subscription_ids(user_id: int) -> list: return [r["id"] for r in rows] -async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: +async def get_payment_source(user_id: int, source_id: int) -> dict: """Get a payment source's information.""" - - if not db: - db = app.db - - source_type = await db.fetchval( + source_type = await app.db.fetchval( """ - SELECT source_type - FROM user_payment_sources - WHERE id = $1 AND user_id = $2 - """, + SELECT source_type + FROM user_payment_sources + WHERE id = $1 AND user_id = $2 + """, source_id, user_id, ) @@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: fields = ",".join(specific_fields) - extras_row = await db.fetchrow( + extras_row = await app.db.fetchrow( f""" SELECT {fields}, billing_address, default_, id::text FROM user_payment_sources @@ -199,22 +192,19 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: return {**source, **derow} -async def get_subscription(subscription_id: int, db=None): +async def get_subscription(subscription_id: int): """Get a subscription's information.""" - if not db: - db = app.db - - row = await db.fetchrow( + row = await app.db.fetchrow( """ - SELECT id::text, source_id::text AS payment_source_id, - user_id, - payment_gateway, payment_gateway_plan_id, - period_start AS current_period_start, - period_end AS current_period_end, - canceled_at, s_type, status - FROM user_subscriptions - WHERE id = $1 - """, + SELECT id::text, source_id::text AS payment_source_id, + user_id, + payment_gateway, payment_gateway_plan_id, + period_start AS current_period_start, + period_end AS current_period_end, + canceled_at, s_type, status + FROM user_subscriptions + WHERE id = $1 + """, subscription_id, ) @@ -231,19 +221,16 @@ async def get_subscription(subscription_id: int, db=None): return drow -async def get_payment(payment_id: int, db=None): +async def get_payment(payment_id: int): """Get a single payment's information.""" - if not db: - db = app.db - - row = await db.fetchrow( + row = await app.db.fetchrow( """ - SELECT id::text, source_id, subscription_id, user_id, - amount, amount_refunded, currency, - description, status, tax, tax_inclusive - FROM user_payments - WHERE id = $1 - """, + SELECT id::text, source_id, subscription_id, user_id, + amount, amount_refunded, currency, + description, status, tax, tax_inclusive + FROM user_payments + WHERE id = $1 + """, payment_id, ) @@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None): drow["created_at"] = snowflake_datetime(int(drow["id"])) - drow["payment_source"] = await get_payment_source( - row["user_id"], row["source_id"], db - ) + drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"]) - drow["subscription"] = await get_subscription(row["subscription_id"], db) + drow["subscription"] = await get_subscription(row["subscription_id"]) return drow -async def create_payment(subscription_id, db=None): +async def create_payment(subscription_id): """Create a payment.""" - if not db: - db = app.db - - sub = await get_subscription(subscription_id, db) + sub = await get_subscription(subscription_id) new_id = get_snowflake() amount = AMOUNTS[sub["payment_gateway_plan_id"]] - await db.execute( + await app.db.execute( """ INSERT INTO user_payments ( id, source_id, subscription_id, user_id, @@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None): return new_id -async def process_subscription(app, subscription_id: int): +async def process_subscription(subscription_id: int): """Process a single subscription.""" - sub = await get_subscription(subscription_id, app.db) + sub = await get_subscription(subscription_id) user_id = int(sub["user_id"]) @@ -313,10 +295,10 @@ async def process_subscription(app, subscription_id: int): # payments), then we should update premium status first_payment_id = await app.db.fetchval( """ - SELECT MIN(id) - FROM user_payments - WHERE subscription_id = $1 - """, + SELECT MIN(id) + FROM user_payments + WHERE subscription_id = $1 + """, subscription_id, ) @@ -324,10 +306,10 @@ async def process_subscription(app, subscription_id: int): premium_since = await app.db.fetchval( """ - SELECT premium_since - FROM users - WHERE id = $1 - """, + SELECT premium_since + FROM users + WHERE id = $1 + """, user_id, ) @@ -343,10 +325,10 @@ async def process_subscription(app, subscription_id: int): old_flags = await app.db.fetchval( """ - SELECT flags - FROM users - WHERE id = $1 - """, + SELECT flags + FROM users + WHERE id = $1 + """, user_id, ) @@ -355,17 +337,17 @@ async def process_subscription(app, subscription_id: int): await app.db.execute( """ - UPDATE users - SET premium_since = $1, flags = $2 - WHERE id = $3 - """, + UPDATE users + SET premium_since = $1, flags = $2 + WHERE id = $3 + """, first_payment_ts, new_flags, user_id, ) # dispatch updated user to all possible clients - await mass_user_update(user_id, app) + await mass_user_update(user_id) @bp.route("/@me/billing/payment-sources", methods=["GET"]) @@ -474,11 +456,11 @@ async def _create_subscription(): 1, ) - await create_payment(new_id, app.db) + await create_payment(new_id) # make sure we update the user's premium status # and dispatch respective user updates to other people. - await process_subscription(app, new_id) + await process_subscription(new_id) return jsonify(await get_subscription(new_id)) diff --git a/litecord/blueprints/user/billing_job.py b/litecord/blueprints/user/billing_job.py index 4148415..ee50c33 100644 --- a/litecord/blueprints/user/billing_job.py +++ b/litecord/blueprints/user/billing_job.py @@ -21,6 +21,8 @@ along with this program. If not, see . this file only serves the periodic payment job code. """ import datetime + +from quart import current_app as app from asyncio import sleep, CancelledError from logbook import Logger @@ -47,14 +49,14 @@ THRESHOLDS = { } -async def _resched(app): +async def _resched(): log.debug("waiting 30 minutes for job.") await sleep(30 * MINUTES) - app.sched.spawn(payment_job(app)) + app.sched.spawn(payment_job()) -async def _process_user_payments(app, user_id: int): - payments = await get_payment_ids(user_id, app.db) +async def _process_user_payments(user_id: int): + payments = await get_payment_ids(user_id) if not payments: log.debug("no payments for uid {}, skipping", user_id) @@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int): latest_payment = max(payments) - payment_data = await get_payment(latest_payment, app.db) + payment_data = await get_payment(latest_payment) # calculate the difference between this payment # and now. @@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int): delta = now - payment_tstamp sub_id = int(payment_data["subscription"]["id"]) - subscription = await get_subscription(sub_id, app.db) + subscription = await get_subscription(sub_id) # if the max payment is X days old, we create another. # X is 30 for monthly subscriptions of nitro, @@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int): # create_payment does not call any Stripe # or BrainTree APIs at all, since we'll just # give it as free. - await create_payment(sub_id, app.db) + await create_payment(sub_id) else: log.debug("sid={}, missing {} days", sub_id, threshold - delta.days) -async def payment_job(app): +async def payment_job(): """Main payment job function. This function will check through users' payments @@ -104,9 +106,9 @@ async def payment_job(app): user_ids = await app.db.fetch( """ - SELECT DISTINCT user_id - FROM user_payments - """ + SELECT DISTINCT user_id + FROM user_payments + """ ) log.debug("working {} users", len(user_ids)) @@ -115,24 +117,24 @@ async def payment_job(app): for row in user_ids: user_id = row["user_id"] try: - await _process_user_payments(app, user_id) + await _process_user_payments(user_id) except Exception: log.exception("error while processing user payments") subscribers = await app.db.fetch( """ - SELECT id - FROM user_subscriptions - """ + SELECT id + FROM user_subscriptions + """ ) for row in subscribers: try: - await process_subscription(app, row["id"]) + await process_subscription(row["id"]) except Exception: log.exception("error while processing subscription") log.debug("rescheduling..") try: - await _resched(app) + await _resched() except CancelledError: log.info("cancelled while waiting for resched") diff --git a/run.py b/run.py index eaaa9b2..55aeb37 100644 --- a/run.py +++ b/run.py @@ -337,7 +337,7 @@ async def api_index(app_): async def post_app_start(app_): # we'll need to start a billing job - app_.sched.spawn(payment_job(app_)) + app_.sched.spawn(payment_job()) app_.sched.spawn(api_index(app_)) app_.sched.spawn(guild_region_check())