remove app param from billing functions

This commit is contained in:
Luna 2019-10-25 11:23:40 -03:00
parent ce04ac5c5f
commit f6f50a1cff
3 changed files with 78 additions and 94 deletions

View File

@ -122,16 +122,13 @@ async def get_payment_source_ids(user_id: int) -> list:
return [r["id"] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list:
if not db:
db = app.db
rows = await db.fetch(
async def get_payment_ids(user_id: int) -> list:
rows = await app.db.fetch(
"""
SELECT id
FROM user_payments
WHERE user_id = $1
""",
SELECT id
FROM user_payments
WHERE user_id = $1
""",
user_id,
)
@ -151,18 +148,14 @@ async def get_subscription_ids(user_id: int) -> list:
return [r["id"] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
async def get_payment_source(user_id: int, source_id: int) -> dict:
"""Get a payment source's information."""
if not db:
db = app.db
source_type = await db.fetchval(
source_type = await app.db.fetchval(
"""
SELECT source_type
FROM user_payment_sources
WHERE id = $1 AND user_id = $2
""",
SELECT source_type
FROM user_payment_sources
WHERE id = $1 AND user_id = $2
""",
source_id,
user_id,
)
@ -176,7 +169,7 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
fields = ",".join(specific_fields)
extras_row = await db.fetchrow(
extras_row = await app.db.fetchrow(
f"""
SELECT {fields}, billing_address, default_, id::text
FROM user_payment_sources
@ -199,22 +192,19 @@ async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
return {**source, **derow}
async def get_subscription(subscription_id: int, db=None):
async def get_subscription(subscription_id: int):
"""Get a subscription's information."""
if not db:
db = app.db
row = await db.fetchrow(
row = await app.db.fetchrow(
"""
SELECT id::text, source_id::text AS payment_source_id,
user_id,
payment_gateway, payment_gateway_plan_id,
period_start AS current_period_start,
period_end AS current_period_end,
canceled_at, s_type, status
FROM user_subscriptions
WHERE id = $1
""",
SELECT id::text, source_id::text AS payment_source_id,
user_id,
payment_gateway, payment_gateway_plan_id,
period_start AS current_period_start,
period_end AS current_period_end,
canceled_at, s_type, status
FROM user_subscriptions
WHERE id = $1
""",
subscription_id,
)
@ -231,19 +221,16 @@ async def get_subscription(subscription_id: int, db=None):
return drow
async def get_payment(payment_id: int, db=None):
async def get_payment(payment_id: int):
"""Get a single payment's information."""
if not db:
db = app.db
row = await db.fetchrow(
row = await app.db.fetchrow(
"""
SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency,
description, status, tax, tax_inclusive
FROM user_payments
WHERE id = $1
""",
SELECT id::text, source_id, subscription_id, user_id,
amount, amount_refunded, currency,
description, status, tax, tax_inclusive
FROM user_payments
WHERE id = $1
""",
payment_id,
)
@ -255,27 +242,22 @@ async def get_payment(payment_id: int, db=None):
drow["created_at"] = snowflake_datetime(int(drow["id"]))
drow["payment_source"] = await get_payment_source(
row["user_id"], row["source_id"], db
)
drow["payment_source"] = await get_payment_source(row["user_id"], row["source_id"])
drow["subscription"] = await get_subscription(row["subscription_id"], db)
drow["subscription"] = await get_subscription(row["subscription_id"])
return drow
async def create_payment(subscription_id, db=None):
async def create_payment(subscription_id):
"""Create a payment."""
if not db:
db = app.db
sub = await get_subscription(subscription_id, db)
sub = await get_subscription(subscription_id)
new_id = get_snowflake()
amount = AMOUNTS[sub["payment_gateway_plan_id"]]
await db.execute(
await app.db.execute(
"""
INSERT INTO user_payments (
id, source_id, subscription_id, user_id,
@ -298,9 +280,9 @@ async def create_payment(subscription_id, db=None):
return new_id
async def process_subscription(app, subscription_id: int):
async def process_subscription(subscription_id: int):
"""Process a single subscription."""
sub = await get_subscription(subscription_id, app.db)
sub = await get_subscription(subscription_id)
user_id = int(sub["user_id"])
@ -313,10 +295,10 @@ async def process_subscription(app, subscription_id: int):
# payments), then we should update premium status
first_payment_id = await app.db.fetchval(
"""
SELECT MIN(id)
FROM user_payments
WHERE subscription_id = $1
""",
SELECT MIN(id)
FROM user_payments
WHERE subscription_id = $1
""",
subscription_id,
)
@ -324,10 +306,10 @@ async def process_subscription(app, subscription_id: int):
premium_since = await app.db.fetchval(
"""
SELECT premium_since
FROM users
WHERE id = $1
""",
SELECT premium_since
FROM users
WHERE id = $1
""",
user_id,
)
@ -343,10 +325,10 @@ async def process_subscription(app, subscription_id: int):
old_flags = await app.db.fetchval(
"""
SELECT flags
FROM users
WHERE id = $1
""",
SELECT flags
FROM users
WHERE id = $1
""",
user_id,
)
@ -355,17 +337,17 @@ async def process_subscription(app, subscription_id: int):
await app.db.execute(
"""
UPDATE users
SET premium_since = $1, flags = $2
WHERE id = $3
""",
UPDATE users
SET premium_since = $1, flags = $2
WHERE id = $3
""",
first_payment_ts,
new_flags,
user_id,
)
# dispatch updated user to all possible clients
await mass_user_update(user_id, app)
await mass_user_update(user_id)
@bp.route("/@me/billing/payment-sources", methods=["GET"])
@ -474,11 +456,11 @@ async def _create_subscription():
1,
)
await create_payment(new_id, app.db)
await create_payment(new_id)
# make sure we update the user's premium status
# and dispatch respective user updates to other people.
await process_subscription(app, new_id)
await process_subscription(new_id)
return jsonify(await get_subscription(new_id))

View File

@ -21,6 +21,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
this file only serves the periodic payment job code.
"""
import datetime
from quart import current_app as app
from asyncio import sleep, CancelledError
from logbook import Logger
@ -47,14 +49,14 @@ THRESHOLDS = {
}
async def _resched(app):
async def _resched():
log.debug("waiting 30 minutes for job.")
await sleep(30 * MINUTES)
app.sched.spawn(payment_job(app))
app.sched.spawn(payment_job())
async def _process_user_payments(app, user_id: int):
payments = await get_payment_ids(user_id, app.db)
async def _process_user_payments(user_id: int):
payments = await get_payment_ids(user_id)
if not payments:
log.debug("no payments for uid {}, skipping", user_id)
@ -64,7 +66,7 @@ async def _process_user_payments(app, user_id: int):
latest_payment = max(payments)
payment_data = await get_payment(latest_payment, app.db)
payment_data = await get_payment(latest_payment)
# calculate the difference between this payment
# and now.
@ -74,7 +76,7 @@ async def _process_user_payments(app, user_id: int):
delta = now - payment_tstamp
sub_id = int(payment_data["subscription"]["id"])
subscription = await get_subscription(sub_id, app.db)
subscription = await get_subscription(sub_id)
# if the max payment is X days old, we create another.
# X is 30 for monthly subscriptions of nitro,
@ -89,12 +91,12 @@ async def _process_user_payments(app, user_id: int):
# create_payment does not call any Stripe
# or BrainTree APIs at all, since we'll just
# give it as free.
await create_payment(sub_id, app.db)
await create_payment(sub_id)
else:
log.debug("sid={}, missing {} days", sub_id, threshold - delta.days)
async def payment_job(app):
async def payment_job():
"""Main payment job function.
This function will check through users' payments
@ -104,9 +106,9 @@ async def payment_job(app):
user_ids = await app.db.fetch(
"""
SELECT DISTINCT user_id
FROM user_payments
"""
SELECT DISTINCT user_id
FROM user_payments
"""
)
log.debug("working {} users", len(user_ids))
@ -115,24 +117,24 @@ async def payment_job(app):
for row in user_ids:
user_id = row["user_id"]
try:
await _process_user_payments(app, user_id)
await _process_user_payments(user_id)
except Exception:
log.exception("error while processing user payments")
subscribers = await app.db.fetch(
"""
SELECT id
FROM user_subscriptions
"""
SELECT id
FROM user_subscriptions
"""
)
for row in subscribers:
try:
await process_subscription(app, row["id"])
await process_subscription(row["id"])
except Exception:
log.exception("error while processing subscription")
log.debug("rescheduling..")
try:
await _resched(app)
await _resched()
except CancelledError:
log.info("cancelled while waiting for resched")

2
run.py
View File

@ -337,7 +337,7 @@ async def api_index(app_):
async def post_app_start(app_):
# we'll need to start a billing job
app_.sched.spawn(payment_job(app_))
app_.sched.spawn(payment_job())
app_.sched.spawn(api_index(app_))
app_.sched.spawn(guild_region_check())