remove app param from billing functions

This commit is contained in:
Luna 2019-10-25 11:23:40 -03:00
parent ce04ac5c5f
commit f6f50a1cff
3 changed files with 78 additions and 94 deletions

View File

@ -122,11 +122,8 @@ async def get_payment_source_ids(user_id: int) -> list:
return [r["id"] for r in rows] return [r["id"] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list: async def get_payment_ids(user_id: int) -> list:
if not db: rows = await app.db.fetch(
db = app.db
rows = await db.fetch(
""" """
SELECT id SELECT id
FROM user_payments FROM user_payments
@ -151,13 +148,9 @@ async def get_subscription_ids(user_id: int) -> list:
return [r["id"] for r in rows] return [r["id"] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict: async def get_payment_source(user_id: int, source_id: int) -> dict:
"""Get a payment source's information.""" """Get a payment source's information."""
source_type = await app.db.fetchval(
if not db:
db = app.db
source_type = await db.fetchval(
""" """
SELECT source_type SELECT source_type
FROM user_payment_sources FROM user_payment_sources
@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
fields = ",".join(specific_fields) fields = ",".join(specific_fields)
extras_row = await db.fetchrow( extras_row = await app.db.fetchrow(
f""" f"""
SELECT {fields}, billing_address, default_, id::text SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources FROM user_payment_sources
@ -199,12 +192,9 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
return {**source, **derow} return {**source, **derow}
async def get_subscription(subscription_id: int, db=None): async def get_subscription(subscription_id: int):
"""Get a subscription's information.""" """Get a subscription's information."""
if not db: row = await app.db.fetchrow(
db = app.db
row = await db.fetchrow(
""" """
SELECT id::text, source_id::text AS payment_source_id, SELECT id::text, source_id::text AS payment_source_id,
user_id, user_id,
@ -231,12 +221,9 @@ async def get_subscription(subscription_id: int, db=None):
return drow return drow
async def get_payment(payment_id: int, db=None): async def get_payment(payment_id: int):
"""Get a single payment's information.""" """Get a single payment's information."""
if not db: row = await app.db.fetchrow(
db = app.db
row = await db.fetchrow(
""" """
SELECT id::text, source_id, subscription_id, user_id, SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency, amount, amount_refunded, currency,
@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None):
drow["created_at"] = snowflake_datetime(int(drow["id"])) drow["created_at"] = snowflake_datetime(int(drow["id"]))
drow["payment_source"] = await get_payment_source( drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"])
row["user_id"], row["source_id"], db
)
drow["subscription"] = await get_subscription(row["subscription_id"], db) drow["subscription"] = await get_subscription(row["subscription_id"])
return drow return drow
async def create_payment(subscription_id, db=None): async def create_payment(subscription_id):
"""Create a payment.""" """Create a payment."""
if not db: sub = await get_subscription(subscription_id)
db = app.db
sub = await get_subscription(subscription_id, db)
new_id = get_snowflake() new_id = get_snowflake()
amount = AMOUNTS[sub["payment_gateway_plan_id"]] amount = AMOUNTS[sub["payment_gateway_plan_id"]]
await db.execute( await app.db.execute(
""" """
INSERT INTO user_payments ( INSERT INTO user_payments (
id, source_id, subscription_id, user_id, id, source_id, subscription_id, user_id,
@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None):
return new_id return new_id
async def process_subscription(app, subscription_id: int): async def process_subscription(subscription_id: int):
"""Process a single subscription.""" """Process a single subscription."""
sub = await get_subscription(subscription_id, app.db) sub = await get_subscription(subscription_id)
user_id = int(sub["user_id"]) user_id = int(sub["user_id"])
@ -365,7 +347,7 @@ async def process_subscription(app, subscription_id: int):
) )
# dispatch updated user to all possible clients # dispatch updated user to all possible clients
await mass_user_update(user_id, app) await mass_user_update(user_id)
@bp.route("/@me/billing/payment-sources", methods=["GET"]) @bp.route("/@me/billing/payment-sources", methods=["GET"])
@ -474,11 +456,11 @@ async def _create_subscription():
1, 1,
) )
await create_payment(new_id, app.db) await create_payment(new_id)
# make sure we update the user's premium status # make sure we update the user's premium status
# and dispatch respective user updates to other people. # and dispatch respective user updates to other people.
await process_subscription(app, new_id) await process_subscription(new_id)
return jsonify(await get_subscription(new_id)) return jsonify(await get_subscription(new_id))

View File

@ -21,6 +21,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
this file only serves the periodic payment job code. this file only serves the periodic payment job code.
""" """
import datetime import datetime
from quart import current_app as app
from asyncio import sleep, CancelledError from asyncio import sleep, CancelledError
from logbook import Logger from logbook import Logger
@ -47,14 +49,14 @@ THRESHOLDS = {
} }
async def _resched(app): async def _resched():
log.debug("waiting 30 minutes for job.") log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES) await sleep(30 * MINUTES)
app.sched.spawn(payment_job(app)) app.sched.spawn(payment_job())
async def _process_user_payments(app, user_id: int): async def _process_user_payments(user_id: int):
payments = await get_payment_ids(user_id, app.db) payments = await get_payment_ids(user_id)
if not payments: if not payments:
log.debug("no payments for uid {}, skipping", user_id) log.debug("no payments for uid {}, skipping", user_id)
@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int):
latest_payment = max(payments) latest_payment = max(payments)
payment_data = await get_payment(latest_payment, app.db) payment_data = await get_payment(latest_payment)
# calculate the difference between this payment # calculate the difference between this payment
# and now. # and now.
@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int):
delta = now - payment_tstamp delta = now - payment_tstamp
sub_id = int(payment_data["subscription"]["id"]) sub_id = int(payment_data["subscription"]["id"])
subscription = await get_subscription(sub_id, app.db) subscription = await get_subscription(sub_id)
# if the max payment is X days old, we create another. # if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro, # X is 30 for monthly subscriptions of nitro,
@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int):
# create_payment does not call any Stripe # create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just # or BrainTree APIs at all, since we'll just
# give it as free. # give it as free.
await create_payment(sub_id, app.db) await create_payment(sub_id)
else: else:
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days) log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
async def payment_job(app): async def payment_job():
"""Main payment job function. """Main payment job function.
This function will check through users' payments This function will check through users' payments
@ -115,7 +117,7 @@ async def payment_job(app):
for row in user_ids: for row in user_ids:
user_id = row["user_id"] user_id = row["user_id"]
try: try:
await _process_user_payments(app, user_id) await _process_user_payments(user_id)
except Exception: except Exception:
log.exception("error while processing user payments") log.exception("error while processing user payments")
@ -128,11 +130,11 @@ async def payment_job(app):
for row in subscribers: for row in subscribers:
try: try:
await process_subscription(app, row["id"]) await process_subscription(row["id"])
except Exception: except Exception:
log.exception("error while processing subscription") log.exception("error while processing subscription")
log.debug("rescheduling..") log.debug("rescheduling..")
try: try:
await _resched(app) await _resched()
except CancelledError: except CancelledError:
log.info("cancelled while waiting for resched") log.info("cancelled while waiting for resched")

2
run.py
View File

@ -337,7 +337,7 @@ async def api_index(app_):
async def post_app_start(app_): async def post_app_start(app_):
# we'll need to start a billing job # we'll need to start a billing job
app_.sched.spawn(payment_job(app_)) app_.sched.spawn(payment_job())
app_.sched.spawn(api_index(app_)) app_.sched.spawn(api_index(app_))
app_.sched.spawn(guild_region_check()) app_.sched.spawn(guild_region_check())