"""
Litecord
Copyright (C) 2018 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import json
import datetime
from enum import Enum
from quart import Blueprint, jsonify, request, current_app as app
from logbook import Logger
from litecord.auth import token_check
from litecord.schemas import validate
from litecord.snowflake import snowflake_datetime, get_snowflake
from litecord.errors import BadRequest
from litecord.types import timestamp_, HOURS
from litecord.enums import UserFlags, PremiumType
from litecord.blueprints.users import mass_user_update
# Blueprint that every billing route below is registered on.
bp = Blueprint('users_billing', __name__)

# Logger for this module (used by process_subscription).
log = Logger(__name__)
class PaymentSource(Enum):
    """Kind of a stored payment source.

    Values are persisted in ``user_payment_sources.source_type`` and decide
    which extra columns get_payment_source reads for a source.
    """

    # credit card; the only kind _create_payment_source currently inserts
    CREDIT = 1
    PAYPAL = 2
class SubscriptionStatus:
    """Integer codes stored in ``user_subscriptions.status``.

    process_subscription only acts on subscriptions whose status is ACTIVE.
    """

    ACTIVE, CANCELLED = 1, 3
class SubscriptionType:
    """Integer codes stored in ``user_subscriptions.s_type``.

    get_subscription exposes this column as ``type``.
    NOTE: the exact meaning of these values is unknown / unconfirmed.
    """

    PURCHASE, UPGRADE = 1, 2
class SubscriptionPlan:
    """Integer codes for the two premium plans.

    Not used by the visible billing code, which identifies plans through
    ``payment_gateway_plan_id`` strings. Presumably CLASSIC corresponds to
    the ``*_tier_1`` plan ids and NITRO to ``*_tier_2`` (TODO confirm).
    """

    CLASSIC, NITRO = 1, 2
class PaymentGateway:
    """Integer codes identifying a payment gateway.

    _create_subscription always records STRIPE.
    """

    STRIPE, BRAINTREE = 1, 2
class PaymentStatus:
    """Integer codes stored in ``user_payments.status``.

    create_payment always records SUCCESS.
    """

    SUCCESS, FAILED = 1, 2
# Premium tier granted by each payment_gateway_plan_id; the monthly and
# yearly variants of a tier map to the same PremiumType.
PLAN_ID_TO_TYPE = dict(
    premium_month_tier_1=PremiumType.TIER_1,
    premium_month_tier_2=PremiumType.TIER_2,
    premium_year_tier_1=PremiumType.TIER_1,
    premium_year_tier_2=PremiumType.TIER_2,
)
# Amount charged for one payment of each subscription plan
# (presumably in cents, since create_payment records the currency as 'usd').
AMOUNTS = dict(
    premium_month_tier_1=499,
    premium_month_tier_2=999,
    premium_year_tier_1=4999,
    premium_year_tier_2=9999,
)
# Request body schema for POST /@me/billing/subscriptions.
# Both fields are required: _create_subscription indexes them directly, so a
# missing field must be rejected by validation (400) rather than surfacing
# as a KeyError (500) inside the handler.
CREATE_SUBSCRIPTION = {
    'payment_gateway_plan_id': {'type': 'string', 'required': True},
    'payment_source_id': {'coerce': int, 'required': True},
}
# Request body schema for POST /@me/billing/payment-sources.
# billing_address is required because _create_payment_source stores
# j['billing_address'] unconditionally; without it the handler would fail
# with a KeyError (500) instead of a validation error (400).
PAYMENT_SOURCE = {
    'billing_address': {
        'type': 'dict',
        'required': True,
        'schema': {
            'country': {'type': 'string', 'required': True},
            'city': {'type': 'string', 'required': True},
            'name': {'type': 'string', 'required': True},
            'line_1': {'type': 'string', 'required': False},
            'line_2': {'type': 'string', 'required': False},
            'postal_code': {'type': 'string', 'required': True},
            'state': {'type': 'string', 'required': True},
        }
    },
    'payment_gateway': {'type': 'number', 'required': True},
    'token': {'type': 'string', 'required': True},
}
async def get_payment_source_ids(user_id: int, db=None) -> list:
    """Return the ids of every payment source owned by ``user_id``.

    ``db`` defaults to ``app.db``; it is optional so this getter takes the
    same connection argument as get_payment_ids and get_payment_source.
    """
    if db is None:
        db = app.db

    rows = await db.fetch("""
        SELECT id
        FROM user_payment_sources
        WHERE user_id = $1
    """, user_id)

    return [r['id'] for r in rows]
async def get_payment_ids(user_id: int, db=None) -> list:
    """Fetch the ids of all payments recorded for ``user_id``.

    Uses ``db`` when a truthy connection is given, otherwise ``app.db``.
    """
    conn = db or app.db

    payment_rows = await conn.fetch("""
        SELECT id
        FROM user_payments
        WHERE user_id = $1
    """, user_id)

    return [record['id'] for record in payment_rows]
async def get_subscription_ids(user_id: int, db=None) -> list:
    """Return the ids of every subscription belonging to ``user_id``.

    ``db`` defaults to ``app.db``; it is optional so this getter takes the
    same connection argument as get_payment_ids and get_subscription.
    """
    if db is None:
        db = app.db

    rows = await db.fetch("""
        SELECT id
        FROM user_subscriptions
        WHERE user_id = $1
    """, user_id)

    return [r['id'] for r in rows]
async def get_payment_source(user_id: int, source_id: int, db=None) -> dict:
    """Get a payment source's information.

    Returns an empty dict when ``source_id`` does not exist or is not owned
    by ``user_id``, so callers can reject it with a truthiness check (as
    _create_subscription does).

    For credit cards the stored card number is never returned; only its
    last four digits are exposed as ``last_4``. The ``default_`` column is
    exposed as ``default``.
    """
    if db is None:
        db = app.db

    source_type = await db.fetchval("""
        SELECT source_type
        FROM user_payment_sources
        WHERE id = $1 AND user_id = $2
    """, source_id, user_id)

    # fetchval yields None when no row matched (unknown id or another
    # user's source). Report "not found" instead of letting
    # PaymentSource(None) raise ValueError.
    if source_type is None:
        return {}

    source_type = PaymentSource(source_type)

    specific_fields = {
        PaymentSource.PAYPAL: ['paypal_email'],
        PaymentSource.CREDIT: ['expires_month', 'expires_year',
                               'brand', 'cc_full'],
    }[source_type]

    # column names come from the fixed mapping above, never from user
    # input, so interpolating them into the query is safe
    fields = ','.join(specific_fields)

    extras_row = await db.fetchrow(f"""
        SELECT {fields}, billing_address, default_, id::text
        FROM user_payment_sources
        WHERE id = $1 AND user_id = $2
    """, source_id, user_id)

    # the row may have been removed between the two queries
    if extras_row is None:
        return {}

    derow = dict(extras_row)

    if source_type == PaymentSource.CREDIT:
        # expose only the last four digits, never the full card number
        derow['last_4'] = derow.pop('cc_full')[-4:]

    derow['default'] = derow.pop('default_')

    source = {
        'id': str(source_id),
        'type': source_type.value,
    }

    return {**source, **derow}
async def get_subscription(subscription_id: int, db=None):
    """Load one subscription row and shape it for the API.

    The ``s_type`` column is renamed to ``type`` (moved to the end of the
    dict), and the period start/end and cancellation datetimes are passed
    through ``timestamp_``. Uses ``db`` when truthy, otherwise ``app.db``.
    """
    conn = db or app.db

    record = await conn.fetchrow("""
        SELECT id::text, source_id::text AS payment_source_id,
            user_id,
            payment_gateway, payment_gateway_plan_id,
            period_start AS current_period_start,
            period_end AS current_period_end,
            canceled_at, s_type, status
        FROM user_subscriptions
        WHERE id = $1
    """, subscription_id)

    sub = dict(record)
    sub['type'] = sub.pop('s_type')

    for key in ('current_period_start', 'current_period_end', 'canceled_at'):
        sub[key] = timestamp_(sub[key])

    return sub
async def get_payment(payment_id: int, db=None):
    """Load one payment and embed its payment source and subscription.

    The raw ``source_id``, ``subscription_id`` and ``user_id`` columns are
    only used to resolve the nested objects and are left out of the
    result. ``created_at`` is derived from the payment's snowflake id.
    Uses ``db`` when truthy, otherwise ``app.db``.
    """
    conn = db or app.db

    record = await conn.fetchrow("""
        SELECT id::text, source_id, subscription_id, user_id,
            amount, amount_refunded, currency,
            description, status, tax, tax_inclusive
        FROM user_payments
        WHERE id = $1
    """, payment_id)

    hidden = ('source_id', 'subscription_id', 'user_id')
    payment = {
        key: value
        for key, value in dict(record).items()
        if key not in hidden
    }

    payment['created_at'] = snowflake_datetime(int(payment['id']))
    payment['payment_source'] = await get_payment_source(
        record['user_id'], record['source_id'], conn)
    payment['subscription'] = await get_subscription(
        record['subscription_id'], conn)

    return payment
async def create_payment(subscription_id, db=None):
    """Record a successful payment for a subscription and return its id.

    The amount is looked up in AMOUNTS by the subscription's plan id. The
    payment is stored with currency 'usd', nothing refunded, zero tax
    (not tax-inclusive) and status PaymentStatus.SUCCESS.
    Uses ``db`` when truthy, otherwise ``app.db``.
    """
    conn = db or app.db

    subscription = await get_subscription(subscription_id, conn)
    payment_id = get_snowflake()
    amount = AMOUNTS[subscription['payment_gateway_plan_id']]

    source_id = int(subscription['payment_source_id'])
    owner_id = int(subscription['user_id'])

    await conn.execute("""
        INSERT INTO user_payments (
            id, source_id, subscription_id, user_id,
            amount, amount_refunded, currency,
            description, status, tax, tax_inclusive
        )
        VALUES
            ($1, $2, $3, $4, $5, 0, $6, $7, $8, 0, false)
    """, payment_id, source_id, subscription_id, owner_id,
        amount, 'usd', 'FUCK NITRO', PaymentStatus.SUCCESS)

    return payment_id
async def process_subscription(app, subscription_id: int):
    """Bring a user's premium state in line with one of their subscriptions.

    Non-active subscriptions are skipped. For an active one, if the user's
    ``premium_since`` is 24 hours or more away from the subscription's first
    payment, ``premium_since`` is set to that payment's time, the
    ``premium_early`` flag is OR-ed into the user's flags, and the updated
    user is dispatched via mass_user_update.

    ``app`` is an explicit parameter (shadowing the module-level
    current_app proxy), presumably so this can run outside a request
    context — TODO confirm.
    """
    sub = await get_subscription(subscription_id, app.db)
    user_id = int(sub['user_id'])

    if sub['status'] != SubscriptionStatus.ACTIVE:
        log.debug('ignoring sub {}, not active',
                  subscription_id)
        return

    # an active subscription is expected to have paid successfully (failed
    # payments should cancel it), so premium status may need updating
    first_payment_id = await app.db.fetchval("""
        SELECT MIN(id)
        FROM user_payments
        WHERE subscription_id = $1
    """, subscription_id)
    first_paid_at = snowflake_datetime(first_payment_id)

    premium_since = await app.db.fetchval("""
        SELECT premium_since
        FROM users
        WHERE id = $1
    """, user_id)

    # a user without premium_since is treated as premium since the epoch
    if not premium_since:
        premium_since = datetime.datetime.fromtimestamp(0)

    gap = abs(first_paid_at - premium_since)

    # already in sync (less than 24h apart): nothing to change.
    # assumes HOURS is the number of seconds in an hour — confirm.
    if gap.total_seconds() < 24 * HOURS:
        return

    old_flags = await app.db.fetchval("""
        SELECT flags
        FROM users
        WHERE id = $1
    """, user_id)
    new_flags = old_flags | UserFlags.premium_early

    log.debug('updating flags {}, {} => {}',
              user_id, old_flags, new_flags)

    await app.db.execute("""
        UPDATE users
        SET premium_since = $1, flags = $2
        WHERE id = $3
    """, first_paid_at, new_flags, user_id)

    # push the updated user to every client that can see it
    await mass_user_update(user_id, app)
@bp.route('/@me/billing/payment-sources', methods=['GET'])
async def _get_billing_sources():
    """List every payment source of the authenticated user."""
    user_id = await token_check()

    sources = [
        await get_payment_source(user_id, source_id)
        for source_id in await get_payment_source_ids(user_id)
    ]

    return jsonify(sources)
@bp.route('/@me/billing/subscriptions', methods=['GET'])
async def _get_billing_subscriptions():
    """List every subscription of the authenticated user."""
    user_id = await token_check()

    subscriptions = [
        await get_subscription(subscription_id)
        for subscription_id in await get_subscription_ids(user_id)
    ]

    return jsonify(subscriptions)
@bp.route('/@me/billing/payments', methods=['GET'])
async def _get_billing_payments():
    """List every payment made by the authenticated user."""
    user_id = await token_check()

    payments = [
        await get_payment(payment_id)
        for payment_id in await get_payment_ids(user_id)
    ]

    return jsonify(payments)
@bp.route('/@me/billing/payment-sources', methods=['POST'])
async def _create_payment_source():
    """Register a credit-card payment source for the authenticated user.

    Only the request's billing address is stored. The card details are
    fixed placeholder values (the request's token is not used), and the
    new source is always marked as default.
    """
    user_id = await token_check()
    j = validate(await request.get_json(), PAYMENT_SOURCE)

    source_id = get_snowflake()

    # expires_month, expires_year, brand, cc_full
    placeholder_card = (12, 6969, 'Visa', '4242424242424242')

    await app.db.execute("""
        INSERT INTO user_payment_sources (id, user_id, source_type,
            default_, expires_month, expires_year, brand, cc_full,
            billing_address)
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
    """, source_id, user_id, PaymentSource.CREDIT.value,
        True, *placeholder_card,
        json.dumps(j['billing_address']))

    created = await get_payment_source(user_id, source_id)
    return jsonify(created)
@bp.route('/@me/billing/subscriptions', methods=['POST'])
async def _create_subscription():
    """Create a premium subscription for the authenticated user.

    The subscription's period ends one month or one year from now,
    depending on the plan. Its first payment is recorded, and the user's
    premium status is then refreshed via process_subscription.

    Raises BadRequest for an unknown plan id, or for a payment source that
    does not exist or does not belong to the user.
    """
    user_id = await token_check()
    j = validate(await request.get_json(), CREATE_SUBSCRIPTION)

    plan_id = j['payment_gateway_plan_id']

    # tier 1 is lightro / classic
    # tier 2 is nitro
    period_end = {
        'premium_month_tier_1': '1 month',
        'premium_month_tier_2': '1 month',
        'premium_year_tier_1': '1 year',
        'premium_year_tier_2': '1 year',
    }.get(plan_id)

    # an unknown plan id must be a client error, not a KeyError (500)
    if period_end is None:
        raise BadRequest('invalid plan id')

    source = await get_payment_source(user_id, j['payment_source_id'])
    if not source:
        raise BadRequest('invalid source id')

    new_id = get_snowflake()

    # period_end only ever holds one of the fixed interval strings above,
    # so interpolating it into the query cannot inject SQL
    await app.db.execute(
        f"""
        INSERT INTO user_subscriptions (id, source_id, user_id,
            s_type, payment_gateway, payment_gateway_plan_id,
            status, period_end)
        VALUES ($1, $2, $3, $4, $5, $6, $7,
            now()::timestamp + interval '{period_end}')
        """, new_id, j['payment_source_id'], user_id,
        SubscriptionType.PURCHASE, PaymentGateway.STRIPE,
        plan_id, SubscriptionStatus.ACTIVE)

    await create_payment(new_id, app.db)

    # make sure we update the user's premium status
    # and dispatch respective user updates to other people.
    await process_subscription(app, new_id)

    return jsonify(
        await get_subscription(new_id)
    )
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
          methods=['DELETE'])
async def _delete_subscription(subscription_id):
    """Cancel a subscription of the authenticated user (not implemented).

    The route binds ``subscription_id`` from the URL so the view can be
    called with it. Until cancellation exists, the caller is authenticated
    and the request is answered with 501 Not Implemented, which is a valid
    response (a view returning None is not).

    TODO: mark the subscription as cancelled and respond 204.
    """
    await token_check()
    return '', 501
@bp.route('/@me/billing/subscriptions/<int:subscription_id>',
          methods=['PATCH'])
async def _patch_subscription(subscription_id):
    """Change a subscription's payment source (not implemented).

    The route binds ``subscription_id`` from the URL so the view can be
    called with it. Until the update exists, the caller is authenticated
    and the request is answered with 501 Not Implemented instead of
    returning None, which is not a valid response.

    TODO: validate the request body against a PATCH schema, update the
    subscription's source and return the subscription object.
    """
    await token_check()
    return '', 501