mirror of https://gitlab.com/litecord/litecord.git
user.billing: add create_payment
- user: add billing_job for recurring payments (monthly or weekly) - user.billing: make main functions accept external db object - user.billing: fix get_payment's fields - litecord: add job module with JobManager
This commit is contained in:
parent
976f8d0ed8
commit
0ec615f3bd
|
|
@ -178,10 +178,10 @@
|
|||
},
|
||||
"hypercorn": {
|
||||
"hashes": [
|
||||
"sha256:d563272b41269e9b2a73b1058f2471a511b545f64090b3fc2d6006cbaf8109fe",
|
||||
"sha256:f6e8c5f02e9c97d6981c56098be379cb0f9c2a48f64c1d5dd2a8f228b61bc2b8"
|
||||
"sha256:3931144309c40341a46a2d054ac550bbd012a1f1a803774b5d6a3add90f52259",
|
||||
"sha256:4df03fbc101efb4faf0b0883863ff7e620f94310e309311ceafaadb38ee1fa36"
|
||||
],
|
||||
"version": "==0.4.1"
|
||||
"version": "==0.4.2"
|
||||
},
|
||||
"hyperframe": {
|
||||
"hashes": [
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pprint
|
||||
import json
|
||||
import datetime
|
||||
from asyncio import sleep
|
||||
from enum import Enum
|
||||
|
||||
from quart import Blueprint, jsonify, request, current_app as app
|
||||
|
|
@ -46,6 +47,16 @@ class PaymentStatus:
|
|||
FAILED = 2
|
||||
|
||||
|
||||
# how much should a payment be, depending
|
||||
# of the subscription
|
||||
AMOUNTS = {
|
||||
'premium_month_tier_1': 499,
|
||||
'premium_month_tier_2': 999,
|
||||
'premium_year_tier_1': 4999,
|
||||
'premium_year_tier_2': 9999,
|
||||
}
|
||||
|
||||
|
||||
CREATE_SUBSCRIPTION = {
|
||||
'payment_gateway_plan_id': {'type': 'string'},
|
||||
'payment_source_id': {'coerce': int}
|
||||
|
|
@ -80,8 +91,11 @@ async def get_payment_source_ids(user_id: int) -> list:
|
|||
return [r['id'] for r in rows]
|
||||
|
||||
|
||||
async def get_payment_ids(user_id: int) -> list:
|
||||
rows = await app.db.fetch("""
|
||||
async def get_payment_ids(user_id: int, db=None) -> list:
|
||||
if not db:
|
||||
db = app.db
|
||||
|
||||
rows = await db.fetch("""
|
||||
SELECT id
|
||||
FROM user_payments
|
||||
WHERE user_id = $1
|
||||
|
|
@ -100,11 +114,15 @@ async def get_subscription_ids(user_id: int) -> list:
|
|||
return [r['id'] for r in rows]
|
||||
|
||||
|
||||
async def get_payment_source(user_id: int, source_id: int) -> dict:
|
||||
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||
"""Get a payment source's information."""
|
||||
|
||||
if not db:
|
||||
db = app.db
|
||||
|
||||
source = {}
|
||||
|
||||
source_type = await app.db.fetchval("""
|
||||
source_type = await db.fetchval("""
|
||||
SELECT source_type
|
||||
FROM user_payment_sources
|
||||
WHERE id = $1 AND user_id = $2
|
||||
|
|
@ -120,7 +138,7 @@ async def get_payment_source(user_id: int, source_id: int) -> dict:
|
|||
|
||||
fields = ','.join(specific_fields)
|
||||
|
||||
extras_row = await app.db.fetchrow(f"""
|
||||
extras_row = await db.fetchrow(f"""
|
||||
SELECT {fields}, billing_address, default_, id::text
|
||||
FROM user_payment_sources
|
||||
WHERE id = $1
|
||||
|
|
@ -143,9 +161,14 @@ async def get_payment_source(user_id: int, source_id: int) -> dict:
|
|||
return {**source, **derow}
|
||||
|
||||
|
||||
async def get_subscription(subscription_id: int):
|
||||
row = await app.db.fetchrow("""
|
||||
async def get_subscription(subscription_id: int, db=None):
|
||||
"""Get a subscription's information."""
|
||||
if not db:
|
||||
db = app.db
|
||||
|
||||
row = await db.fetchrow("""
|
||||
SELECT id::text, source_id::text AS payment_source_id,
|
||||
user_id,
|
||||
payment_gateway, payment_gateway_plan_id,
|
||||
period_start AS current_period_start,
|
||||
period_end AS current_period_end,
|
||||
|
|
@ -167,9 +190,13 @@ async def get_subscription(subscription_id: int):
|
|||
return drow
|
||||
|
||||
|
||||
async def get_payment(payment_id: int):
|
||||
row = await app.db.fetchrow("""
|
||||
SELECT id::text, source_id, subscription_id,
|
||||
async def get_payment(payment_id: int, db=None):
|
||||
"""Get a single payment's information."""
|
||||
if not db:
|
||||
db = app.db
|
||||
|
||||
row = await db.fetchrow("""
|
||||
SELECT id::text, source_id, subscription_id, user_id,
|
||||
amount, amount_refunded, currency,
|
||||
description, status, tax, tax_inclusive
|
||||
FROM user_payments
|
||||
|
|
@ -177,10 +204,47 @@ async def get_payment(payment_id: int):
|
|||
""", payment_id)
|
||||
|
||||
drow = dict(row)
|
||||
|
||||
drow.pop('source_id')
|
||||
drow.pop('subscription_id')
|
||||
drow.pop('user_id')
|
||||
|
||||
drow['created_at'] = snowflake_datetime(int(drow['id']))
|
||||
|
||||
drow['payment_source'] = await get_payment_source(
|
||||
row['user_id'], row['source_id'], db)
|
||||
|
||||
drow['subscription'] = await get_subscription(
|
||||
row['subscription_id'], db)
|
||||
|
||||
return drow
|
||||
|
||||
|
||||
async def create_payment(subscription_id, app):
|
||||
"""Create a payment."""
|
||||
sub = await get_subscription(subscription_id, app.db)
|
||||
|
||||
new_id = get_snowflake()
|
||||
|
||||
amount = AMOUNTS[sub['payment_gateway_plan_id']]
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO user_payments (
|
||||
id, source_id, subscription_id, user_id,
|
||||
amount, amount_refunded, currency,
|
||||
description, status, tax, tax_inclusive
|
||||
)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
|
||||
""", new_id, int(sub['payment_source_id']),
|
||||
subscription_id, int(sub['user_id']),
|
||||
amount, 'usd', 'FUCK NITRO',
|
||||
PaymentStatus.SUCCESS)
|
||||
|
||||
return new_id
|
||||
|
||||
|
||||
@bp.route('/@me/billing/payment-sources', methods=['GET'])
|
||||
async def _get_billing_sources():
|
||||
user_id = await token_check()
|
||||
|
|
@ -248,7 +312,7 @@ async def _create_subscription():
|
|||
|
||||
source = await get_payment_source(user_id, j['payment_source_id'])
|
||||
if not source:
|
||||
raise BadInput('invalid source id')
|
||||
raise BadRequest('invalid source id')
|
||||
|
||||
plan_id = j['payment_gateway_plan_id']
|
||||
|
||||
|
|
@ -273,6 +337,8 @@ async def _create_subscription():
|
|||
SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
|
||||
plan_id, 1)
|
||||
|
||||
await create_payment(new_id, app)
|
||||
|
||||
return jsonify(
|
||||
await get_subscription(new_id)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
this file only serves the periodic payment job code.
|
||||
"""
|
||||
import datetime
|
||||
from asyncio import sleep
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.blueprints.user.billing import (
|
||||
get_subscription, get_payment_ids, get_payment, PaymentStatus,
|
||||
create_payment
|
||||
)
|
||||
|
||||
from litecord.snowflake import snowflake_datetime, get_snowflake
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
# how many days until a payment needs
|
||||
# to be issued
|
||||
THRESHOLDS = {
|
||||
'premium_month_tier_1': 30,
|
||||
'premium_month_tier_2': 30,
|
||||
'premium_year_tier_1': 365,
|
||||
'premium_year_tier_2': 365,
|
||||
}
|
||||
|
||||
|
||||
async def _resched(app):
|
||||
log.debug('waiting 2 minutes for job.')
|
||||
await sleep(120)
|
||||
await app.sched.spawn(payment_job(app))
|
||||
|
||||
|
||||
async def _process_user_payments(app, user_id: int):
|
||||
payments = await get_payment_ids(user_id, app.db)
|
||||
|
||||
if not payments:
|
||||
log.debug('no payments for uid {}, skipping', user_id)
|
||||
return
|
||||
|
||||
log.debug('{} payments for uid {}', len(payments), user_id)
|
||||
|
||||
latest_payment = max(payments)
|
||||
|
||||
payment_data = await get_payment(latest_payment, app.db)
|
||||
|
||||
# calculate the difference between this payment
|
||||
# and now.
|
||||
now = datetime.datetime.now()
|
||||
payment_tstamp = snowflake_datetime(int(payment_data['id']))
|
||||
|
||||
delta = now - payment_tstamp
|
||||
|
||||
sub_id = int(payment_data['subscription']['id'])
|
||||
subscription = await get_subscription(
|
||||
sub_id, app.db)
|
||||
|
||||
threshold = THRESHOLDS[subscription['payment_gateway_plan_id']]
|
||||
|
||||
log.debug('delta {} delta days {} threshold {}',
|
||||
delta, delta.days, threshold)
|
||||
|
||||
if delta.days > threshold:
|
||||
# insert new payment, for free !!!!!!
|
||||
log.info('creating payment for sid={}',
|
||||
sub_id)
|
||||
await create_payment(sub_id, app)
|
||||
else:
|
||||
log.debug('not there yet for sid={}', sub_id)
|
||||
|
||||
|
||||
async def payment_job(app):
|
||||
"""Main payment job function.
|
||||
|
||||
This function will check through users' payments
|
||||
and add a new one once a month / year.
|
||||
"""
|
||||
log.info('payment job start!')
|
||||
|
||||
user_ids = await app.db.fetch("""
|
||||
SELECT DISTINCT user_id
|
||||
FROM user_payments
|
||||
""")
|
||||
|
||||
log.debug('working {} users', len(user_ids))
|
||||
print(user_ids)
|
||||
|
||||
# go through each user's payments
|
||||
for row in user_ids:
|
||||
user_id = row['user_id']
|
||||
try:
|
||||
await _process_user_payments(app, user_id)
|
||||
except Exception:
|
||||
log.exception('error while processing user payments')
|
||||
|
||||
await _resched(app)
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
import asyncio
|
||||
|
||||
|
||||
class JobManager:
|
||||
"""Manage background jobs"""
|
||||
def __init__(self, loop=None):
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
self.jobs = []
|
||||
|
||||
def spawn(self, coro):
|
||||
task = self.loop.create_task(coro)
|
||||
self.jobs.append(task)
|
||||
|
||||
def close(self):
|
||||
for job in self.jobs:
|
||||
job.cancel()
|
||||
21
run.py
21
run.py
|
|
@ -32,6 +32,10 @@ from litecord.blueprints.user import (
|
|||
user_settings, user_billing
|
||||
)
|
||||
|
||||
from litecord.blueprints.user.billing_job import (
|
||||
payment_job
|
||||
)
|
||||
|
||||
from litecord.ratelimits.handler import ratelimit_handler
|
||||
from litecord.ratelimits.main import RatelimitManager
|
||||
|
||||
|
|
@ -42,6 +46,7 @@ from litecord.storage import Storage
|
|||
from litecord.dispatcher import EventDispatcher
|
||||
from litecord.presence import PresenceManager
|
||||
from litecord.images import IconManager
|
||||
from litecord.jobs import JobManager
|
||||
|
||||
# setup logbook
|
||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||
|
|
@ -164,10 +169,15 @@ async def app_set_ratelimit_headers(resp):
|
|||
|
||||
|
||||
async def init_app_db(app):
|
||||
"""Connect to databases"""
|
||||
"""Connect to databases.
|
||||
|
||||
Also spawns the job scheduler.
|
||||
"""
|
||||
log.info('db connect')
|
||||
app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
||||
|
||||
app.sched = JobManager()
|
||||
|
||||
|
||||
def init_app_managers(app):
|
||||
"""Initialize singleton classes."""
|
||||
|
|
@ -180,9 +190,15 @@ def init_app_managers(app):
|
|||
app.dispatcher = EventDispatcher(app)
|
||||
app.presence = PresenceManager(app.storage,
|
||||
app.state_manager, app.dispatcher)
|
||||
|
||||
app.storage.presence = app.presence
|
||||
|
||||
|
||||
async def post_app_start(app):
|
||||
# we'll need to start a billing job
|
||||
app.sched.spawn(payment_job(app))
|
||||
|
||||
|
||||
@app.before_serving
|
||||
async def app_before_serving():
|
||||
log.info('opening db')
|
||||
|
|
@ -209,6 +225,7 @@ async def app_before_serving():
|
|||
|
||||
ws_future = websockets.serve(_wrapper, host, port)
|
||||
|
||||
await post_app_start(app)
|
||||
await ws_future
|
||||
|
||||
|
||||
|
|
@ -223,6 +240,8 @@ async def app_after_serving():
|
|||
|
||||
app.state_manager.close()
|
||||
|
||||
app.sched.close()
|
||||
|
||||
log.info('closing db')
|
||||
await app.db.close()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue