mirror of https://gitlab.com/litecord/litecord.git
user.billing: add create_payment
- user: add billing_job for recurring payments (monthly or weekly) - user.billing: make main functions accept external db object - user.billing: fix get_payment's fields - litecord: add job module with JobManager
This commit is contained in:
parent
976f8d0ed8
commit
0ec615f3bd
|
|
@ -178,10 +178,10 @@
|
||||||
},
|
},
|
||||||
"hypercorn": {
|
"hypercorn": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:d563272b41269e9b2a73b1058f2471a511b545f64090b3fc2d6006cbaf8109fe",
|
"sha256:3931144309c40341a46a2d054ac550bbd012a1f1a803774b5d6a3add90f52259",
|
||||||
"sha256:f6e8c5f02e9c97d6981c56098be379cb0f9c2a48f64c1d5dd2a8f228b61bc2b8"
|
"sha256:4df03fbc101efb4faf0b0883863ff7e620f94310e309311ceafaadb38ee1fa36"
|
||||||
],
|
],
|
||||||
"version": "==0.4.1"
|
"version": "==0.4.2"
|
||||||
},
|
},
|
||||||
"hyperframe": {
|
"hyperframe": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import pprint
|
import pprint
|
||||||
import json
|
import json
|
||||||
import datetime
|
import datetime
|
||||||
|
from asyncio import sleep
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from quart import Blueprint, jsonify, request, current_app as app
|
from quart import Blueprint, jsonify, request, current_app as app
|
||||||
|
|
@ -46,6 +47,16 @@ class PaymentStatus:
|
||||||
FAILED = 2
|
FAILED = 2
|
||||||
|
|
||||||
|
|
||||||
|
# how much should a payment be, depending
|
||||||
|
# of the subscription
|
||||||
|
AMOUNTS = {
|
||||||
|
'premium_month_tier_1': 499,
|
||||||
|
'premium_month_tier_2': 999,
|
||||||
|
'premium_year_tier_1': 4999,
|
||||||
|
'premium_year_tier_2': 9999,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
CREATE_SUBSCRIPTION = {
|
CREATE_SUBSCRIPTION = {
|
||||||
'payment_gateway_plan_id': {'type': 'string'},
|
'payment_gateway_plan_id': {'type': 'string'},
|
||||||
'payment_source_id': {'coerce': int}
|
'payment_source_id': {'coerce': int}
|
||||||
|
|
@ -80,8 +91,11 @@ async def get_payment_source_ids(user_id: int) -> list:
|
||||||
return [r['id'] for r in rows]
|
return [r['id'] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_ids(user_id: int) -> list:
|
async def get_payment_ids(user_id: int, db=None) -> list:
|
||||||
rows = await app.db.fetch("""
|
if not db:
|
||||||
|
db = app.db
|
||||||
|
|
||||||
|
rows = await db.fetch("""
|
||||||
SELECT id
|
SELECT id
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
WHERE user_id = $1
|
WHERE user_id = $1
|
||||||
|
|
@ -100,11 +114,15 @@ async def get_subscription_ids(user_id: int) -> list:
|
||||||
return [r['id'] for r in rows]
|
return [r['id'] for r in rows]
|
||||||
|
|
||||||
|
|
||||||
async def get_payment_source(user_id: int, source_id: int) -> dict:
|
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
|
||||||
"""Get a payment source's information."""
|
"""Get a payment source's information."""
|
||||||
|
|
||||||
|
if not db:
|
||||||
|
db = app.db
|
||||||
|
|
||||||
source = {}
|
source = {}
|
||||||
|
|
||||||
source_type = await app.db.fetchval("""
|
source_type = await db.fetchval("""
|
||||||
SELECT source_type
|
SELECT source_type
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
WHERE id = $1 AND user_id = $2
|
WHERE id = $1 AND user_id = $2
|
||||||
|
|
@ -120,7 +138,7 @@ async def get_payment_source(user_id: int, source_id: int) -> dict:
|
||||||
|
|
||||||
fields = ','.join(specific_fields)
|
fields = ','.join(specific_fields)
|
||||||
|
|
||||||
extras_row = await app.db.fetchrow(f"""
|
extras_row = await db.fetchrow(f"""
|
||||||
SELECT {fields}, billing_address, default_, id::text
|
SELECT {fields}, billing_address, default_, id::text
|
||||||
FROM user_payment_sources
|
FROM user_payment_sources
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -143,9 +161,14 @@ async def get_payment_source(user_id: int, source_id: int) -> dict:
|
||||||
return {**source, **derow}
|
return {**source, **derow}
|
||||||
|
|
||||||
|
|
||||||
async def get_subscription(subscription_id: int):
|
async def get_subscription(subscription_id: int, db=None):
|
||||||
row = await app.db.fetchrow("""
|
"""Get a subscription's information."""
|
||||||
|
if not db:
|
||||||
|
db = app.db
|
||||||
|
|
||||||
|
row = await db.fetchrow("""
|
||||||
SELECT id::text, source_id::text AS payment_source_id,
|
SELECT id::text, source_id::text AS payment_source_id,
|
||||||
|
user_id,
|
||||||
payment_gateway, payment_gateway_plan_id,
|
payment_gateway, payment_gateway_plan_id,
|
||||||
period_start AS current_period_start,
|
period_start AS current_period_start,
|
||||||
period_end AS current_period_end,
|
period_end AS current_period_end,
|
||||||
|
|
@ -167,9 +190,13 @@ async def get_subscription(subscription_id: int):
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
|
|
||||||
async def get_payment(payment_id: int):
|
async def get_payment(payment_id: int, db=None):
|
||||||
row = await app.db.fetchrow("""
|
"""Get a single payment's information."""
|
||||||
SELECT id::text, source_id, subscription_id,
|
if not db:
|
||||||
|
db = app.db
|
||||||
|
|
||||||
|
row = await db.fetchrow("""
|
||||||
|
SELECT id::text, source_id, subscription_id, user_id,
|
||||||
amount, amount_refunded, currency,
|
amount, amount_refunded, currency,
|
||||||
description, status, tax, tax_inclusive
|
description, status, tax, tax_inclusive
|
||||||
FROM user_payments
|
FROM user_payments
|
||||||
|
|
@ -177,10 +204,47 @@ async def get_payment(payment_id: int):
|
||||||
""", payment_id)
|
""", payment_id)
|
||||||
|
|
||||||
drow = dict(row)
|
drow = dict(row)
|
||||||
|
|
||||||
|
drow.pop('source_id')
|
||||||
|
drow.pop('subscription_id')
|
||||||
|
drow.pop('user_id')
|
||||||
|
|
||||||
drow['created_at'] = snowflake_datetime(int(drow['id']))
|
drow['created_at'] = snowflake_datetime(int(drow['id']))
|
||||||
|
|
||||||
|
drow['payment_source'] = await get_payment_source(
|
||||||
|
row['user_id'], row['source_id'], db)
|
||||||
|
|
||||||
|
drow['subscription'] = await get_subscription(
|
||||||
|
row['subscription_id'], db)
|
||||||
|
|
||||||
return drow
|
return drow
|
||||||
|
|
||||||
|
|
||||||
|
async def create_payment(subscription_id, app):
|
||||||
|
"""Create a payment."""
|
||||||
|
sub = await get_subscription(subscription_id, app.db)
|
||||||
|
|
||||||
|
new_id = get_snowflake()
|
||||||
|
|
||||||
|
amount = AMOUNTS[sub['payment_gateway_plan_id']]
|
||||||
|
|
||||||
|
await app.db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO user_payments (
|
||||||
|
id, source_id, subscription_id, user_id,
|
||||||
|
amount, amount_refunded, currency,
|
||||||
|
description, status, tax, tax_inclusive
|
||||||
|
)
|
||||||
|
VALUES
|
||||||
|
($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
|
||||||
|
""", new_id, int(sub['payment_source_id']),
|
||||||
|
subscription_id, int(sub['user_id']),
|
||||||
|
amount, 'usd', 'FUCK NITRO',
|
||||||
|
PaymentStatus.SUCCESS)
|
||||||
|
|
||||||
|
return new_id
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/@me/billing/payment-sources', methods=['GET'])
|
@bp.route('/@me/billing/payment-sources', methods=['GET'])
|
||||||
async def _get_billing_sources():
|
async def _get_billing_sources():
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
@ -248,7 +312,7 @@ async def _create_subscription():
|
||||||
|
|
||||||
source = await get_payment_source(user_id, j['payment_source_id'])
|
source = await get_payment_source(user_id, j['payment_source_id'])
|
||||||
if not source:
|
if not source:
|
||||||
raise BadInput('invalid source id')
|
raise BadRequest('invalid source id')
|
||||||
|
|
||||||
plan_id = j['payment_gateway_plan_id']
|
plan_id = j['payment_gateway_plan_id']
|
||||||
|
|
||||||
|
|
@ -273,6 +337,8 @@ async def _create_subscription():
|
||||||
SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
|
SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
|
||||||
plan_id, 1)
|
plan_id, 1)
|
||||||
|
|
||||||
|
await create_payment(new_id, app)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
await get_subscription(new_id)
|
await get_subscription(new_id)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""
|
||||||
|
this file only serves the periodic payment job code.
|
||||||
|
"""
|
||||||
|
import datetime
|
||||||
|
from asyncio import sleep
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
from litecord.blueprints.user.billing import (
|
||||||
|
get_subscription, get_payment_ids, get_payment, PaymentStatus,
|
||||||
|
create_payment
|
||||||
|
)
|
||||||
|
|
||||||
|
from litecord.snowflake import snowflake_datetime, get_snowflake
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
# how many days until a payment needs
|
||||||
|
# to be issued
|
||||||
|
THRESHOLDS = {
|
||||||
|
'premium_month_tier_1': 30,
|
||||||
|
'premium_month_tier_2': 30,
|
||||||
|
'premium_year_tier_1': 365,
|
||||||
|
'premium_year_tier_2': 365,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _resched(app):
|
||||||
|
log.debug('waiting 2 minutes for job.')
|
||||||
|
await sleep(120)
|
||||||
|
await app.sched.spawn(payment_job(app))
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_user_payments(app, user_id: int):
|
||||||
|
payments = await get_payment_ids(user_id, app.db)
|
||||||
|
|
||||||
|
if not payments:
|
||||||
|
log.debug('no payments for uid {}, skipping', user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
log.debug('{} payments for uid {}', len(payments), user_id)
|
||||||
|
|
||||||
|
latest_payment = max(payments)
|
||||||
|
|
||||||
|
payment_data = await get_payment(latest_payment, app.db)
|
||||||
|
|
||||||
|
# calculate the difference between this payment
|
||||||
|
# and now.
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
payment_tstamp = snowflake_datetime(int(payment_data['id']))
|
||||||
|
|
||||||
|
delta = now - payment_tstamp
|
||||||
|
|
||||||
|
sub_id = int(payment_data['subscription']['id'])
|
||||||
|
subscription = await get_subscription(
|
||||||
|
sub_id, app.db)
|
||||||
|
|
||||||
|
threshold = THRESHOLDS[subscription['payment_gateway_plan_id']]
|
||||||
|
|
||||||
|
log.debug('delta {} delta days {} threshold {}',
|
||||||
|
delta, delta.days, threshold)
|
||||||
|
|
||||||
|
if delta.days > threshold:
|
||||||
|
# insert new payment, for free !!!!!!
|
||||||
|
log.info('creating payment for sid={}',
|
||||||
|
sub_id)
|
||||||
|
await create_payment(sub_id, app)
|
||||||
|
else:
|
||||||
|
log.debug('not there yet for sid={}', sub_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def payment_job(app):
|
||||||
|
"""Main payment job function.
|
||||||
|
|
||||||
|
This function will check through users' payments
|
||||||
|
and add a new one once a month / year.
|
||||||
|
"""
|
||||||
|
log.info('payment job start!')
|
||||||
|
|
||||||
|
user_ids = await app.db.fetch("""
|
||||||
|
SELECT DISTINCT user_id
|
||||||
|
FROM user_payments
|
||||||
|
""")
|
||||||
|
|
||||||
|
log.debug('working {} users', len(user_ids))
|
||||||
|
print(user_ids)
|
||||||
|
|
||||||
|
# go through each user's payments
|
||||||
|
for row in user_ids:
|
||||||
|
user_id = row['user_id']
|
||||||
|
try:
|
||||||
|
await _process_user_payments(app, user_id)
|
||||||
|
except Exception:
|
||||||
|
log.exception('error while processing user payments')
|
||||||
|
|
||||||
|
await _resched(app)
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class JobManager:
|
||||||
|
"""Manage background jobs"""
|
||||||
|
def __init__(self, loop=None):
|
||||||
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
self.jobs = []
|
||||||
|
|
||||||
|
def spawn(self, coro):
|
||||||
|
task = self.loop.create_task(coro)
|
||||||
|
self.jobs.append(task)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
for job in self.jobs:
|
||||||
|
job.cancel()
|
||||||
21
run.py
21
run.py
|
|
@ -32,6 +32,10 @@ from litecord.blueprints.user import (
|
||||||
user_settings, user_billing
|
user_settings, user_billing
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from litecord.blueprints.user.billing_job import (
|
||||||
|
payment_job
|
||||||
|
)
|
||||||
|
|
||||||
from litecord.ratelimits.handler import ratelimit_handler
|
from litecord.ratelimits.handler import ratelimit_handler
|
||||||
from litecord.ratelimits.main import RatelimitManager
|
from litecord.ratelimits.main import RatelimitManager
|
||||||
|
|
||||||
|
|
@ -42,6 +46,7 @@ from litecord.storage import Storage
|
||||||
from litecord.dispatcher import EventDispatcher
|
from litecord.dispatcher import EventDispatcher
|
||||||
from litecord.presence import PresenceManager
|
from litecord.presence import PresenceManager
|
||||||
from litecord.images import IconManager
|
from litecord.images import IconManager
|
||||||
|
from litecord.jobs import JobManager
|
||||||
|
|
||||||
# setup logbook
|
# setup logbook
|
||||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||||
|
|
@ -164,10 +169,15 @@ async def app_set_ratelimit_headers(resp):
|
||||||
|
|
||||||
|
|
||||||
async def init_app_db(app):
|
async def init_app_db(app):
|
||||||
"""Connect to databases"""
|
"""Connect to databases.
|
||||||
|
|
||||||
|
Also spawns the job scheduler.
|
||||||
|
"""
|
||||||
log.info('db connect')
|
log.info('db connect')
|
||||||
app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
app.db = await asyncpg.create_pool(**app.config['POSTGRES'])
|
||||||
|
|
||||||
|
app.sched = JobManager()
|
||||||
|
|
||||||
|
|
||||||
def init_app_managers(app):
|
def init_app_managers(app):
|
||||||
"""Initialize singleton classes."""
|
"""Initialize singleton classes."""
|
||||||
|
|
@ -180,9 +190,15 @@ def init_app_managers(app):
|
||||||
app.dispatcher = EventDispatcher(app)
|
app.dispatcher = EventDispatcher(app)
|
||||||
app.presence = PresenceManager(app.storage,
|
app.presence = PresenceManager(app.storage,
|
||||||
app.state_manager, app.dispatcher)
|
app.state_manager, app.dispatcher)
|
||||||
|
|
||||||
app.storage.presence = app.presence
|
app.storage.presence = app.presence
|
||||||
|
|
||||||
|
|
||||||
|
async def post_app_start(app):
|
||||||
|
# we'll need to start a billing job
|
||||||
|
app.sched.spawn(payment_job(app))
|
||||||
|
|
||||||
|
|
||||||
@app.before_serving
|
@app.before_serving
|
||||||
async def app_before_serving():
|
async def app_before_serving():
|
||||||
log.info('opening db')
|
log.info('opening db')
|
||||||
|
|
@ -209,6 +225,7 @@ async def app_before_serving():
|
||||||
|
|
||||||
ws_future = websockets.serve(_wrapper, host, port)
|
ws_future = websockets.serve(_wrapper, host, port)
|
||||||
|
|
||||||
|
await post_app_start(app)
|
||||||
await ws_future
|
await ws_future
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -223,6 +240,8 @@ async def app_after_serving():
|
||||||
|
|
||||||
app.state_manager.close()
|
app.state_manager.close()
|
||||||
|
|
||||||
|
app.sched.close()
|
||||||
|
|
||||||
log.info('closing db')
|
log.info('closing db')
|
||||||
await app.db.close()
|
await app.db.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue