diff --git a/app/__main__.py b/app/__main__.py index 2057c04..645f6ee 100644 --- a/app/__main__.py +++ b/app/__main__.py @@ -12,17 +12,18 @@ try: except BaseException: pass -from app.core._utils.create_maria_tables import create_maria_tables +from app.core._utils.create_maria_tables import create_db_tables from app.core.storage import engine if startup_target == '__main__': - create_maria_tables(engine) + # Ensure DB schema exists (async engine) + asyncio.get_event_loop().run_until_complete(create_db_tables(engine)) else: time.sleep(7) from app.api import app from app.bot import dp as uploader_bot_dp from app.client_bot import dp as client_bot_dp -from app.core._config import SANIC_PORT, MYSQL_URI, PROJECT_HOST +from app.core._config import SANIC_PORT, PROJECT_HOST, DATABASE_URL from app.core.logger import make_log if int(os.getenv("SANIC_MAINTENANCE", '0')) == 1: @@ -52,7 +53,11 @@ async def execute_queue(app): make_log(None, f"Application normally started. HTTP port: {SANIC_PORT}") make_log(None, f"Telegram bot: https://t.me/{telegram_bot_username}") make_log(None, f"Client Telegram bot: https://t.me/{client_telegram_bot_username}") - make_log(None, f"MariaDB host: {MYSQL_URI.split('@')[1].split('/')[0].replace('/', '')}") + try: + _db_host = DATABASE_URL.split('@')[1].split('/')[0].replace('/', '') + except Exception: + _db_host = 'postgres://' + make_log(None, f"PostgreSQL host: {_db_host}") make_log(None, f"API host: {PROJECT_HOST}") while True: try: diff --git a/app/api/middleware.py b/app/api/middleware.py index 548b9f5..7cdd950 100644 --- a/app/api/middleware.py +++ b/app/api/middleware.py @@ -8,7 +8,8 @@ from app.core.models.keys import KnownKey from app.core.models._telegram.wrapped_bot import Wrapped_CBotChat from app.core.models.user_activity import UserActivity from app.core.models.user import User -from app.core.storage import Session +from sqlalchemy import select +from app.core.storage import new_session from datetime import datetime, timedelta @@ -30,7 +31,8 @@ async def try_authorization(request): make_log("auth", "Invalid token length", level="warning") return - known_key = request.ctx.db_session.query(KnownKey).filter(KnownKey.seed == token).first() + result = await request.ctx.db_session.execute(select(KnownKey).where(KnownKey.seed == token)) + known_key = result.scalars().first() if not known_key: make_log("auth", "Unknown key", level="warning") return @@ -58,7 +60,8 @@ async def try_authorization(request): make_log("auth", f"User ID mismatch: {known_key.meta.get('I_user_id', -1)} != {user_id}", level="warning") return - user = request.ctx.db_session.query(User).filter(User.id == known_key.meta['I_user_id']).first() + result = await request.ctx.db_session.execute(select(User).where(User.id == known_key.meta['I_user_id'])) + user = result.scalars().first() if not user: make_log("auth", "No user from key", level="warning") return @@ -130,14 +133,14 @@ async def save_activity(request): created=datetime.now() ) request.ctx.db_session.add(new_user_activity) - request.ctx.db_session.commit() + await request.ctx.db_session.commit() async def attach_user_to_request(request): if request.method == 'OPTIONS': return attach_headers(sanic_response.text("OK")) - request.ctx.db_session = Session() + request.ctx.db_session = new_session() request.ctx.verified_hash = None request.ctx.user = None request.ctx.user_key = None @@ -153,8 +156,8 @@ async def close_request_handler(request, response): response = sanic_response.text("OK") try: - request.ctx.db_session.close() - except BaseException as e: + await request.ctx.db_session.close() + except BaseException: pass response = attach_headers(response) diff --git a/app/api/routes/_blockchain.py b/app/api/routes/_blockchain.py index a6a47aa..dbff5dc 100644 --- a/app/api/routes/_blockchain.py +++ b/app/api/routes/_blockchain.py @@ -3,7 +3,7 @@ from datetime import datetime import traceback from sanic import response -from sqlalchemy import and_ +from sqlalchemy import and_, select, func from tonsdk.boc import begin_cell, begin_dict from tonsdk.utils import Address @@ -61,9 +61,9 @@ async def s_api_v1_blockchain_send_new_content_message(request): assert not err, f"Invalid content CID" # Поиск исходного файла загруженного - decrypted_content = request.ctx.db_session.query(StoredContent).filter( - StoredContent.hash == decrypted_content_cid.content_hash_b58 - ).first() + decrypted_content = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.hash == decrypted_content_cid.content_hash_b58) + )).scalars().first() assert decrypted_content, "No content locally found" assert decrypted_content.type == "local/content_bin", "Invalid content type" @@ -74,9 +74,9 @@ async def s_api_v1_blockchain_send_new_content_message(request): if request.json['image']: image_content_cid, err = resolve_content(request.json['image']) assert not err, f"Invalid image CID" - image_content = request.ctx.db_session.query(StoredContent).filter( - StoredContent.hash == image_content_cid.content_hash_b58 - ).first() + image_content = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.hash == image_content_cid.content_hash_b58) + )).scalars().first() assert image_content, "No image locally found" else: image_content_cid = None @@ -105,18 +105,22 @@ async def s_api_v1_blockchain_send_new_content_message(request): ) i += 1 - promo_free_upload_available = ( - 3 - (request.ctx.db_session.query(PromoAction).filter( - PromoAction.user_internal_id == request.ctx.user.id, - PromoAction.action_type == 'freeUpload', - ).count()) - ) - if request.ctx.db_session.query(BlockchainTask).filter( - and_( - BlockchainTask.user_id == request.ctx.user.id, - BlockchainTask.status != 'done', + _cnt = (await request.ctx.db_session.execute( + select(func.count()).select_from(PromoAction).where( + and_( + PromoAction.user_internal_id == request.ctx.user.id, + PromoAction.action_type == 'freeUpload' + ) ) - ).first(): + )).scalar() + promo_free_upload_available = 3 - int(_cnt or 0) + + has_pending_task = (await request.ctx.db_session.execute( + select(BlockchainTask).where( + and_(BlockchainTask.user_id == request.ctx.user.id, BlockchainTask.status != 'done') + ) + )).scalars().first() + if has_pending_task: make_log("Blockchain", f"User {request.ctx.user.id} already has a pending task", level='warning') promo_free_upload_available = 0 @@ -139,7 +143,7 @@ async def s_api_v1_blockchain_send_new_content_message(request): begin_cell() .store_uint(0x5491d08c, 32) .store_uint(int.from_bytes(encrypted_content_cid.content_hash, "big", signed=False), 256) - .store_address(Address(request.ctx.user.wallet_address(request.ctx.db_session))) + .store_address(Address(await request.ctx.user.wallet_address_async(request.ctx.db_session))) .store_ref( begin_cell() .store_ref( @@ -177,7 +181,7 @@ async def s_api_v1_blockchain_send_new_content_message(request): user_id = request.ctx.user.id ) request.ctx.db_session.add(blockchain_task) - request.ctx.db_session.commit() + await request.ctx.db_session.commit() await request.ctx.user_uploader_wrapper.send_message( request.ctx.user.translated('p_uploadContentTxPromo').format( @@ -258,18 +262,37 @@ async def s_api_v1_blockchain_send_purchase_content_message(request): assert field_key in request.json, f"No {field_key} provided" assert field_value(request.json[field_key]), f"Invalid {field_key} provided" - if not request.ctx.user.wallet_address(request.ctx.db_session): + if not (await request.ctx.user.wallet_address_async(request.ctx.db_session)): return response.json({"error": "No wallet address provided"}, status=400) - license_exist = request.ctx.db_session.query(UserContent).filter_by( - onchain_address=request.json['content_address'], - ).first() + from sqlalchemy import select + license_exist = (await request.ctx.db_session.execute(select(UserContent).where( + UserContent.onchain_address == request.json['content_address'] + ))).scalars().first() if license_exist: - r_content = StoredContent.from_cid(request.ctx.db_session, license_exist.content.cid.serialize_v2()) + from app.core.content.content_id import ContentId + _cid = ContentId.deserialize(license_exist.content.cid.serialize_v2()) + r_content = (await request.ctx.db_session.execute(select(StoredContent).where(StoredContent.hash == _cid.content_hash_b58))).scalars().first() else: - r_content = StoredContent.from_cid(request.ctx.db_session, request.json['content_address']) + from app.core.content.content_id import ContentId + _cid = ContentId.deserialize(request.json['content_address']) + r_content = (await request.ctx.db_session.execute(select(StoredContent).where(StoredContent.hash == _cid.content_hash_b58))).scalars().first() - content = r_content.open_content(request.ctx.db_session) + async def open_content_async(session, sc: StoredContent): + if not sc.encrypted: + decrypted = sc + encrypted = (await session.execute(select(StoredContent).where(StoredContent.decrypted_content_id == sc.id))).scalars().first() + else: + encrypted = sc + decrypted = (await session.execute(select(StoredContent).where(StoredContent.id == sc.decrypted_content_id))).scalars().first() + assert decrypted and encrypted, "Can't open content" + ctype = decrypted.json_format().get('content_type', 'application/x-binary') + try: + content_type = ctype.split('/')[0] + except Exception: + content_type = 'application' + return {'encrypted_content': encrypted, 'decrypted_content': decrypted, 'content_type': content_type} + content = await open_content_async(request.ctx.db_session, r_content) licenses_cost = content['encrypted_content'].json_format()['license'] assert request.json['license_type'] in licenses_cost diff --git a/app/api/routes/_system.py b/app/api/routes/_system.py index f9d30f7..a32f448 100644 --- a/app/api/routes/_system.py +++ b/app/api/routes/_system.py @@ -6,6 +6,7 @@ from base58 import b58encode, b58decode from sanic import response from app.core.models.node_storage import StoredContent +from sqlalchemy import select from app.core._blockchain.ton.platform import platform from app.core._crypto.signer import Signer from app.core._secrets import hot_pubkey, service_wallet, hot_seed @@ -19,10 +20,10 @@ def get_git_info(): async def s_api_v1_node(request): # /api/v1/node - last_known_index = request.ctx.db_session.query(StoredContent).filter( - StoredContent.onchain_index != None - ).order_by(StoredContent.onchain_index.desc()).first() - last_known_index = last_known_index.onchain_index if last_known_index else 0 + last_known_index_obj = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.onchain_index != None).order_by(StoredContent.onchain_index.desc()) + )).scalars().first() + last_known_index = last_known_index_obj.onchain_index if last_known_index_obj else 0 last_known_index = max(last_known_index, 0) return response.json({ 'id': b58encode(hot_pubkey).decode(), @@ -39,10 +40,10 @@ async def s_api_v1_node(request): # /api/v1/node }) async def s_api_v1_node_friendly(request): - last_known_index = request.ctx.db_session.query(StoredContent).filter( - StoredContent.onchain_index != None - ).order_by(StoredContent.onchain_index.desc()).first() - last_known_index = last_known_index.onchain_index if last_known_index else 0 + last_known_index_obj = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.onchain_index != None).order_by(StoredContent.onchain_index.desc()) + )).scalars().first() + last_known_index = last_known_index_obj.onchain_index if last_known_index_obj else 0 last_known_index = max(last_known_index, 0) response_plain_text = f""" Node address: {service_wallet.address.to_string(1, 1, 1)} diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index 6e4d7ce..d8cd4be 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -37,7 +37,9 @@ async def s_api_v1_auth_twa(request): make_log("auth", "Invalid TWA data", level="warning") return response.json({"error": "Invalid TWA data"}, status=401) - known_user = request.ctx.db_session.query(User).filter(User.telegram_id == twa_data.user.id).first() + known_user = (await request.ctx.db_session.execute( + select(User).where(User.telegram_id == twa_data.user.id) + )).scalars().first() if not known_user: new_user = User( telegram_id=twa_data.user.id, @@ -52,9 +54,11 @@ async def s_api_v1_auth_twa(request): created=datetime.now() ) request.ctx.db_session.add(new_user) - request.ctx.db_session.commit() + await request.ctx.db_session.commit() - known_user = request.ctx.db_session.query(User).filter(User.telegram_id == twa_data.user.id).first() + known_user = (await request.ctx.db_session.execute( + select(User).where(User.telegram_id == twa_data.user.id) + )).scalars().first() assert known_user, "User not created" new_user_key = await known_user.create_api_token_v1(request.ctx.db_session, "USER_API_V1") @@ -65,12 +69,12 @@ async def s_api_v1_auth_twa(request): wallet_info.account = Account.from_dict(auth_data['ton_proof']['account']) wallet_info.ton_proof = TonProof.from_dict({'proof': auth_data['ton_proof']['ton_proof']}) connection_payload = auth_data['ton_proof']['ton_proof']['payload'] - known_payload = (request.ctx.db_session.execute(select(KnownKey).where(KnownKey.seed == connection_payload))).scalars().first() + known_payload = (await request.ctx.db_session.execute(select(KnownKey).where(KnownKey.seed == connection_payload))).scalars().first() assert known_payload, "Unknown payload" assert known_payload.meta['I_user_id'] == known_user.id, "Invalid user_id" assert wallet_info.check_proof(connection_payload), "Invalid proof" - for known_connection in (request.ctx.db_session.execute(select(WalletConnection).where( + for known_connection in (await request.ctx.db_session.execute(select(WalletConnection).where( and_( WalletConnection.user_id == known_user.id, WalletConnection.network == 'ton' @@ -78,7 +82,7 @@ async def s_api_v1_auth_twa(request): ))).scalars().all(): known_connection.invalidated = True - for other_connection in (request.ctx.db_session.execute(select(WalletConnection).where( + for other_connection in (await request.ctx.db_session.execute(select(WalletConnection).where( WalletConnection.wallet_address == Address(wallet_info.account.address).to_string(1, 1, 1) ))).scalars().all(): other_connection.invalidated = True @@ -99,12 +103,12 @@ async def s_api_v1_auth_twa(request): without_pk=False ) request.ctx.db_session.add(new_connection) - request.ctx.db_session.commit() + await request.ctx.db_session.commit() except BaseException as e: make_log("auth", f"Invalid ton_proof: {e}", level="warning") return response.json({"error": "Invalid ton_proof"}, status=400) - ton_connection = (request.ctx.db_session.execute(select(WalletConnection).where( + ton_connection = (await request.ctx.db_session.execute(select(WalletConnection).where( and_( WalletConnection.user_id == known_user.id, WalletConnection.network == 'ton', @@ -112,7 +116,7 @@ async def s_api_v1_auth_twa(request): ) ).order_by(WalletConnection.created.desc()))).scalars().first() known_user.last_use = datetime.now() - request.ctx.db_session.commit() + await request.ctx.db_session.commit() return response.json({ 'user': known_user.json_format(), @@ -124,7 +128,7 @@ async def s_api_v1_auth_me(request): if not request.ctx.user: return response.json({"error": "Unauthorized"}, status=401) - ton_connection = (request.ctx.db_session.execute( + ton_connection = (await request.ctx.db_session.execute( select(WalletConnection).where( and_( WalletConnection.user_id == request.ctx.user.id, @@ -132,7 +136,7 @@ async def s_api_v1_auth_me(request): WalletConnection.invalidated == False ) ).order_by(WalletConnection.created.desc()) - )).scalars().first() + ))).scalars().first() return response.json({ 'user': request.ctx.user.json_format(), @@ -159,10 +163,12 @@ async def s_api_v1_auth_select_wallet(request): user = request.ctx.user # Check if a WalletConnection already exists for this user with the given canonical wallet address - existing_connection = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.wallet_address == canonical_address - ).first() + existing_connection = (await db_session.execute(select(WalletConnection).where( + and_( + WalletConnection.user_id == user.id, + WalletConnection.wallet_address == canonical_address + ) + ))).scalars().first() if not existing_connection: return response.json({"error": "Wallet connection not found"}, status=404) @@ -185,6 +191,6 @@ async def s_api_v1_auth_select_wallet(request): without_pk=False ) db_session.add(new_connection) - db_session.commit() + await db_session.commit() return response.empty(status=200) diff --git a/app/api/routes/content.py b/app/api/routes/content.py index d392f28..24b056d 100644 --- a/app/api/routes/content.py +++ b/app/api/routes/content.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta from sanic import response +from sqlalchemy import select, and_, func from aiogram import Bot, types from sqlalchemy import and_ from app.core.logger import make_log @@ -22,13 +23,20 @@ async def s_api_v1_content_list(request): store = request.args.get('store', 'local') assert store in ('local', 'onchain'), "Invalid store" - content_list = request.ctx.db_session.query(StoredContent).filter( - StoredContent.type.like(store + '%'), - StoredContent.disabled == False - ).order_by(StoredContent.created.desc()).offset(offset).limit(limit) - make_log("Content", f"Listed {content_list.count()} contents", level='info') + stmt = ( + select(StoredContent) + .where( + StoredContent.type.like(store + '%'), + StoredContent.disabled == False + ) + .order_by(StoredContent.created.desc()) + .offset(offset) + .limit(limit) + ) + rows = (await request.ctx.db_session.execute(stmt)).scalars().all() + make_log("Content", f"Listed {len(rows)} contents", level='info') result = {} - for content in content_list.all(): + for content in rows: content_json = content.json_format() result[content_json["cid"]] = content_json @@ -38,23 +46,41 @@ async def s_api_v1_content_list(request): async def s_api_v1_content_view(request, content_address: str): # content_address can be CID or TON address - license_exist = request.ctx.db_session.query(UserContent).filter_by( - onchain_address=content_address, - ).first() + license_exist = (await request.ctx.db_session.execute( + select(UserContent).where(UserContent.onchain_address == content_address) + )).scalars().first() if license_exist: content_address = license_exist.content.cid.serialize_v2() - r_content = StoredContent.from_cid(request.ctx.db_session, content_address) - content = r_content.open_content(request.ctx.db_session) + from app.core.content.content_id import ContentId + cid = ContentId.deserialize(content_address) + r_content = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.hash == cid.content_hash_b58) + )).scalars().first() + async def open_content_async(session, sc: StoredContent): + if not sc.encrypted: + decrypted = sc + encrypted = (await session.execute(select(StoredContent).where(StoredContent.decrypted_content_id == sc.id))).scalars().first() + else: + encrypted = sc + decrypted = (await session.execute(select(StoredContent).where(StoredContent.id == sc.decrypted_content_id))).scalars().first() + assert decrypted and encrypted, "Can't open content" + ctype = decrypted.json_format().get('content_type', 'application/x-binary') + try: + content_type = ctype.split('/')[0] + except Exception: + content_type = 'application' + return {'encrypted_content': encrypted, 'decrypted_content': decrypted, 'content_type': content_type} + content = await open_content_async(request.ctx.db_session, r_content) opts = { 'content_type': content['content_type'], # возможно с ошибками, нужно переделать на ffprobe 'content_address': content['encrypted_content'].meta.get('item_address', '') } if content['encrypted_content'].key_id: - known_key = request.ctx.db_session.query(KnownKey).filter( - KnownKey.id == content['encrypted_content'].key_id - ).first() + known_key = (await request.ctx.db_session.execute( + select(KnownKey).where(KnownKey.id == content['encrypted_content'].key_id) + )).scalars().first() if known_key: opts['key_hash'] = known_key.seed_hash # нахер не нужно на данный момент @@ -64,22 +90,23 @@ async def s_api_v1_content_view(request, content_address: str): have_access = False if request.ctx.user: - user_wallet_address = request.ctx.user.wallet_address(request.ctx.db_session) + user_wallet_address = await request.ctx.user.wallet_address_async(request.ctx.db_session) have_access = ( (content['encrypted_content'].owner_address == user_wallet_address) - or bool(request.ctx.db_session.query(UserContent).filter_by(owner_address=user_wallet_address, status='active', - content_id=content['encrypted_content'].id).first()) \ - or bool(request.ctx.db_session.query(StarsInvoice).filter( + or bool((await request.ctx.db_session.execute(select(UserContent).where( + and_(UserContent.owner_address == user_wallet_address, UserContent.status == 'active', UserContent.content_id == content['encrypted_content'].id) + ))).scalars().first()) \ + or bool((await request.ctx.db_session.execute(select(StarsInvoice).where( and_( StarsInvoice.user_id == request.ctx.user.id, StarsInvoice.content_hash == content['encrypted_content'].hash, StarsInvoice.paid == True ) - ).first()) + ))).scalars().first()) ) if not have_access: - current_star_rate = ServiceConfig(request.ctx.db_session).get('live_tonPerStar', [0, 0])[0] + current_star_rate = (await ServiceConfig(request.ctx.db_session).get('live_tonPerStar', [0, 0]))[0] if current_star_rate < 0: current_star_rate = 0.00000001 @@ -88,14 +115,14 @@ async def s_api_v1_content_view(request, content_address: str): stars_cost = 2 invoice_id = f"access_{uuid.uuid4().hex}" - exist_invoice = request.ctx.db_session.query(StarsInvoice).filter( + exist_invoice = (await request.ctx.db_session.execute(select(StarsInvoice).where( and_( StarsInvoice.user_id == request.ctx.user.id, StarsInvoice.created > datetime.now() - timedelta(minutes=25), StarsInvoice.amount == stars_cost, StarsInvoice.content_hash == content['encrypted_content'].hash, ) - ).first() + ))).scalars().first() if exist_invoice: invoice_url = exist_invoice.invoice_url else: @@ -119,7 +146,7 @@ async def s_api_v1_content_view(request, content_address: str): invoice_url=invoice_url ) ) - request.ctx.db_session.commit() + await request.ctx.db_session.commit() except BaseException as e: make_log("Content", f"Can't create invoice link: {e}", level='warning') @@ -142,15 +169,20 @@ async def s_api_v1_content_view(request, content_address: str): if have_access: user_content_option = 'low' # TODO: подключать high если человек внезапно меломан - converted_content = request.ctx.db_session.query(StoredContent).filter( + converted_content = (await request.ctx.db_session.execute(select(StoredContent).where( StoredContent.hash == converted_content[user_content_option] - ).first() + ))).scalars().first() if converted_content: display_options['content_url'] = converted_content.web_url opts['content_ext'] = converted_content.filename.split('.')[-1] content_meta = content['encrypted_content'].json_format() - content_metadata = StoredContent.from_cid(request.ctx.db_session, content_meta.get('metadata_cid') or None) + from app.core.content.content_id import ContentId + _mcid = content_meta.get('metadata_cid') or None + content_metadata = None + if _mcid: + _cid = ContentId.deserialize(_mcid) + content_metadata = (await request.ctx.db_session.execute(select(StoredContent).where(StoredContent.hash == _cid.content_hash_b58))).scalars().first() with open(content_metadata.filepath, 'r') as f: content_metadata_json = json.loads(f.read()) @@ -187,14 +219,17 @@ async def s_api_v1_content_friendly_list(request): """ - for content in request.ctx.db_session.query(StoredContent).filter( + contents = (await request.ctx.db_session.execute(select(StoredContent).where( StoredContent.type == 'onchain/content' - ).all(): + ))).scalars().all() + for content in contents: if not content.meta.get('metadata_cid'): make_log("Content", f"Content {content.cid.serialize_v2()} has no metadata", level='warning') continue - metadata_content = StoredContent.from_cid(request.ctx.db_session, content.meta.get('metadata_cid')) + from app.core.content.content_id import ContentId + _cid = ContentId.deserialize(content.meta.get('metadata_cid')) + metadata_content = (await request.ctx.db_session.execute(select(StoredContent).where(StoredContent.hash == _cid.content_hash_b58))).scalars().first() with open(metadata_content.filepath, 'r') as f: metadata = json.loads(f.read()) @@ -228,10 +263,12 @@ async def s_api_v1_5_content_list(request): return response.json({'error': 'Invalid limit'}, status=400) # Query onchain contents which are not disabled - contents = request.ctx.db_session.query(StoredContent).filter( - StoredContent.type == 'onchain/content', - StoredContent.disabled == False - ).order_by(StoredContent.created.desc()).offset(offset).limit(limit).all() + contents = (await request.ctx.db_session.execute( + select(StoredContent) + .where(StoredContent.type == 'onchain/content', StoredContent.disabled == False) + .order_by(StoredContent.created.desc()) + .offset(offset).limit(limit) + )).scalars().all() result = [] for content in contents: @@ -240,7 +277,9 @@ async def s_api_v1_5_content_list(request): if not metadata_cid: continue # Skip if no metadata_cid is found - metadata_content = StoredContent.from_cid(request.ctx.db_session, metadata_cid) + from app.core.content.content_id import ContentId + _cid = ContentId.deserialize(metadata_cid) + metadata_content = (await request.ctx.db_session.execute(select(StoredContent).where(StoredContent.hash == _cid.content_hash_b58))).scalars().first() try: with open(metadata_content.filepath, 'r') as f: metadata = json.load(f) @@ -256,9 +295,9 @@ async def s_api_v1_5_content_list(request): preview_link = None converted_content = content.meta.get('converted_content') if converted_content: - converted_content = request.ctx.db_session.query(StoredContent).filter( + converted_content = (await request.ctx.db_session.execute(select(StoredContent).where( StoredContent.hash == converted_content['low_preview'] - ).first() + ))).scalars().first() preview_link = converted_content.web_url if converted_content.filename.split('.')[-1] in ('mp4', 'mov'): media_type = 'video' diff --git a/app/api/routes/node_storage.py b/app/api/routes/node_storage.py index f4a4b84..4ef2877 100644 --- a/app/api/routes/node_storage.py +++ b/app/api/routes/node_storage.py @@ -11,6 +11,7 @@ from sanic import response import json from app.core._config import UPLOADS_DIR +from sqlalchemy import select from app.core._utils.resolve_content import resolve_content from app.core.logger import make_log from app.core.models.node_storage import StoredContent @@ -52,7 +53,9 @@ async def s_api_v1_storage_post(request): try: file_hash_bin = hashlib.sha256(file_content).digest() file_hash = b58encode(file_hash_bin).decode() - stored_content = request.ctx.db_session.query(StoredContent).filter(StoredContent.hash == file_hash).first() + stored_content = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.hash == file_hash) + )).scalars().first() if stored_content: stored_cid = stored_content.cid.serialize_v1() stored_cid_v2 = stored_content.cid.serialize_v2() @@ -80,7 +83,7 @@ async def s_api_v1_storage_post(request): key_id=None, ) request.ctx.db_session.add(new_content) - request.ctx.db_session.commit() + await request.ctx.db_session.commit() file_path = os.path.join(UPLOADS_DIR, file_hash) async with aiofiles.open(file_path, "wb") as file: @@ -112,7 +115,9 @@ async def s_api_v1_storage_get(request, file_hash=None): return response.json({"error": errmsg}, status=400) content_sha256 = b58encode(cid.content_hash).decode() - content = request.ctx.db_session.query(StoredContent).filter(StoredContent.hash == content_sha256).first() + content = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.hash == content_sha256) + )).scalars().first() if not content: return response.json({"error": "File not found"}, status=404) @@ -139,7 +144,16 @@ async def s_api_v1_storage_get(request, file_hash=None): tempfile_path += "_mpeg" + (f"_{seconds_limit}" if seconds_limit else "") if not os.path.exists(tempfile_path): try: - cover_content = StoredContent.from_cid(content.meta.get('cover_cid')) + # Resolve cover content by CID (async) + from app.core.content.content_id import ContentId + try: + _cid = ContentId.deserialize(content.meta.get('cover_cid')) + _cover_hash = _cid.content_hash_b58 + cover_content = (await request.ctx.db_session.execute( + select(StoredContent).where(StoredContent.hash == _cover_hash) + )).scalars().first() + except Exception: + cover_content = None cover_tempfile_path = os.path.join(UPLOADS_DIR, f"tmp_{cover_content.hash}_jpeg") if not os.path.exists(cover_tempfile_path): cover_image = Image.open(cover_content.filepath) diff --git a/app/api/routes/progressive_storage.py b/app/api/routes/progressive_storage.py index 4cf7333..6650f8b 100644 --- a/app/api/routes/progressive_storage.py +++ b/app/api/routes/progressive_storage.py @@ -11,6 +11,7 @@ from base58 import b58encode from sanic import response from app.core.logger import make_log +from sqlalchemy import select from app.core.models.node_storage import StoredContent from app.core._config import UPLOADS_DIR from app.core._utils.resolve_content import resolve_content @@ -130,7 +131,7 @@ async def s_api_v1_5_storage_post(request): return response.json({"error": "Failed to finalize file storage"}, status=500) db_session = request.ctx.db_session - existing = db_session.query(StoredContent).filter_by(hash=computed_hash_b58).first() + existing = (await db_session.execute(select(StoredContent).where(StoredContent.hash == computed_hash_b58))).scalars().first() if existing: make_log("uploader_v1.5", f"File with hash {computed_hash_b58} already exists in DB", level="INFO") serialized_v2 = existing.cid.serialize_v2() @@ -156,7 +157,7 @@ async def s_api_v1_5_storage_post(request): created=datetime.utcnow() ) db_session.add(new_content) - db_session.commit() + await db_session.commit() make_log("uploader_v1.5", f"New file stored and indexed for user {user_id} with hash {computed_hash_b58}", level="INFO") except Exception as e: make_log("uploader_v1.5", f"Database error: {e}", level="ERROR") @@ -191,7 +192,7 @@ async def s_api_v1_5_storage_get(request, file_hash): return response.json({"error": "File not found"}, status=404) db_session = request.ctx.db_session - stored = db_session.query(StoredContent).filter_by(hash=file_hash).first() + stored = (await db_session.execute(select(StoredContent).where(StoredContent.hash == file_hash))).scalars().first() if stored and stored.filename: filename_for_mime = stored.filename else: diff --git a/app/api/routes/tonconnect.py b/app/api/routes/tonconnect.py index 8cc9fb1..03a85ff 100644 --- a/app/api/routes/tonconnect.py +++ b/app/api/routes/tonconnect.py @@ -4,6 +4,7 @@ from aiogram.utils.web_app import safe_parse_webapp_init_data from sanic import response from app.core._blockchain.ton.connect import TonConnect, unpack_wallet_info, WalletConnection +from sqlalchemy import select, and_ from app.core._config import TELEGRAM_API_KEY from app.core.models.user import User from app.core.logger import make_log @@ -23,8 +24,19 @@ async def s_api_v1_tonconnect_new(request): db_session = request.ctx.db_session user = request.ctx.user memory = request.ctx.memory - ton_connect, ton_connection = TonConnect.by_user(db_session, user) - await ton_connect.restore_connection() + # Try restore last connection from DB + ton_connection = (await db_session.execute(select(WalletConnection).where( + and_( + WalletConnection.user_id == user.id, + WalletConnection.invalidated == False, + WalletConnection.network == 'ton' + ) + ).order_by(WalletConnection.created.desc()))).scalars().first() + if ton_connection: + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) + await ton_connect.restore_connection() + else: + ton_connect = TonConnect() make_log("TonConnect_API", f"SDK connected?: {ton_connect.connected}", level='info') if ton_connect.connected: return response.json({"error": "Already connected"}, status=400) @@ -47,13 +59,11 @@ async def s_api_v1_tonconnect_logout(request): user = request.ctx.user memory = request.ctx.memory - wallet_connections = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).all() + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + )) + wallet_connections = result.scalars().all() for wallet_connection in wallet_connections: wallet_connection.invalidated = True - - db_session.commit() + await db_session.commit() return response.json({"success": True}) - diff --git a/app/bot/middleware.py b/app/bot/middleware.py index f926c4a..1029023 100644 --- a/app/bot/middleware.py +++ b/app/bot/middleware.py @@ -1,6 +1,7 @@ from app.core.logger import make_log, logger from app.core.models._telegram import Wrapped_CBotChat from app.core.models.user import User +from sqlalchemy import select from app.core.storage import db_session from aiogram import BaseMiddleware, types from app.core.models.messages import KnownTelegramMessage @@ -21,9 +22,9 @@ class UserDataMiddleware(BaseMiddleware): # TODO: maybe make users cache - with db_session(auto_commit=False) as session: + async with db_session(auto_commit=False) as session: try: - user = session.query(User).filter_by(telegram_id=user_id).first() + user = (await session.execute(select(User).where(User.telegram_id == user_id))).scalars().first() except BaseException as e: logger.error(f"Error when middleware getting user: {e}") user = None @@ -42,7 +43,7 @@ class UserDataMiddleware(BaseMiddleware): created=datetime.now() ) session.add(user) - session.commit() + await session.commit() else: if user.username != update_body.from_user.username: user.username = update_body.from_user.username @@ -60,7 +61,7 @@ class UserDataMiddleware(BaseMiddleware): } user.last_use = datetime.now() - session.commit() + await session.commit() data['user'] = user data['db_session'] = session @@ -72,11 +73,11 @@ class UserDataMiddleware(BaseMiddleware): if update_body.text.startswith('/start'): message_type = 'start_command' - if session.query(KnownTelegramMessage).filter_by( - chat_id=update_body.chat.id, - message_id=update_body.message_id, - from_user=True - ).first(): + if (await session.execute(select(KnownTelegramMessage).where( + (KnownTelegramMessage.chat_id == update_body.chat.id) & + (KnownTelegramMessage.message_id == update_body.message_id) & + (KnownTelegramMessage.from_user == True) + ))).scalars().first(): make_log("UserDataMiddleware", f"Message {update_body.message_id} already processed", level='debug') return @@ -91,7 +92,7 @@ class UserDataMiddleware(BaseMiddleware): meta={} ) session.add(new_message) - session.commit() + await session.commit() result = await handler(event, data) return result diff --git a/app/bot/routers/content.py b/app/bot/routers/content.py index c5bf935..6bacb75 100644 --- a/app/bot/routers/content.py +++ b/app/bot/routers/content.py @@ -6,6 +6,7 @@ from app.core._keyboards import get_inline_keyboard from app.core._utils.tg_process_template import tg_process_template from app.core.logger import make_log from app.core.models.node_storage import StoredContent +from sqlalchemy import select, and_ import json router = Router() @@ -20,12 +21,13 @@ def chunks(lst, n): async def t_callback_owned_content(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): message_text = user.translated("ownedContent_menu") content_list = [] - for content in db_session.query(StoredContent).filter_by( - owner_address=user.wallet_address(db_session), - type='onchain/content' - ).all(): + user_addr = await user.wallet_address_async(db_session) + result = await db_session.execute(select(StoredContent).where( + and_(StoredContent.owner_address == user_addr, StoredContent.type == 'onchain/content') + )) + for content in result.scalars().all(): try: - metadata_content = StoredContent.from_cid(db_session, content.json_format()['metadata_cid']) + metadata_content = await StoredContent.from_cid_async(db_session, content.json_format()['metadata_cid']) with open(metadata_content.filepath, 'r') as f: metadata_content_json = json.loads(f.read()) except BaseException as e: @@ -59,10 +61,9 @@ async def t_callback_owned_content(query: types.CallbackQuery, memory=None, user async def t_callback_node_content(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): content_oid = int(query.data.split('_')[1]) + row = (await db_session.execute(select(StoredContent).where(StoredContent.id == content_oid))).scalars().first() return await chat_wrap.send_content( - db_session, db_session.query(StoredContent).filter_by( - id=content_oid - ).first(), + db_session, row, extra_buttons=[ [{ 'text': user.translated('back_button'), diff --git a/app/bot/routers/home.py b/app/bot/routers/home.py index f325621..686d34c 100644 --- a/app/bot/routers/home.py +++ b/app/bot/routers/home.py @@ -3,6 +3,7 @@ from aiogram.filters import Command from tonsdk.utils import Address from app.core._blockchain.ton.connect import TonConnect +from sqlalchemy import select, and_ from app.core._keyboards import get_inline_keyboard from app.core._utils.tg_process_template import tg_process_template from app.core.models.wallet_connection import WalletConnection @@ -32,8 +33,14 @@ async def send_home_menu(chat_wrap, user, wallet_connection, **kwargs): async def send_connect_wallets_list(db_session, chat_wrap, user, **kwargs): - ton_connect, ton_connection = TonConnect.by_user(db_session, user, callback_fn=()) - await ton_connect.restore_connection() + # Try to restore existing connection via DB + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() + if ton_connection: + await ton_connect.restore_connection() wallets = ton_connect._sdk_client.get_wallets() message_text = user.translated("connectWalletsList_menu") return await tg_process_template( @@ -66,10 +73,9 @@ async def t_home_menu(__msg, **extra): else: message_id = None - wallet_connection = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).first() + wallet_connection = (await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + ))).scalars().first() # if not wallet_connection: # return await send_connect_wallets_list(db_session, chat_wrap, user, message_id=message_id) diff --git a/app/bot/routers/tonconnect.py b/app/bot/routers/tonconnect.py index cd5b293..a7052f1 100644 --- a/app/bot/routers/tonconnect.py +++ b/app/bot/routers/tonconnect.py @@ -7,6 +7,7 @@ from aiogram.filters import Command from app.bot.routers.home import send_connect_wallets_list, send_home_menu from app.core._blockchain.ton.connect import TonConnect, unpack_wallet_info +from sqlalchemy import select, and_ from app.core._keyboards import get_inline_keyboard from app.core._utils.tg_process_template import tg_process_template from app.core.logger import make_log @@ -33,15 +34,21 @@ async def t_tonconnect_dev_menu(message: types.Message, memory=None, user=None, keyboard = [] - ton_connect, ton_connection = TonConnect.by_user(db_session, user, callback_fn=()) + # Restore recent connection + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() make_log("TonConnect_DevMenu", f"Available wallets: {ton_connect._sdk_client.get_wallets()}", level='debug') - await ton_connect.restore_connection() + if ton_connection: + await ton_connect.restore_connection() make_log("TonConnect_DevMenu", f"SDK connected?: {ton_connect.connected}", level='info') if not ton_connect.connected: if ton_connection: make_log("TonConnect_DevMenu", f"Invalidating old connection", level='debug') ton_connection.invalidated = True - db_session.commit() + await db_session.commit() message_text = f"""Wallet is not connected @@ -71,8 +78,13 @@ Use /dev_tonconnect {wallet_app_name} for connect to wallet.""" async def t_callback_init_tonconnect(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): wallet_app_name = query.data.split("_")[1] - ton_connect, ton_connection = TonConnect.by_user(db_session, user) - await ton_connect.restore_connection() + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() + if ton_connection: + await ton_connect.restore_connection() connection_link = await ton_connect.new_connection(wallet_app_name) ton_connect.connected memory.add_task(pause_ton_connection, ton_connect, delay_s=60 * 3) @@ -98,10 +110,9 @@ async def t_callback_init_tonconnect(query: types.CallbackQuery, memory=None, us start_ts = datetime.now() while datetime.now() - start_ts < timedelta(seconds=180): - new_connection = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).first() + new_connection = (await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + ))).scalars().first() if new_connection: await tg_process_template( chat_wrap, user.translated('p_successConnectWallet') @@ -115,14 +126,13 @@ async def t_callback_init_tonconnect(query: types.CallbackQuery, memory=None, us async def t_callback_disconnect_wallet(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): - wallet_connections = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).all() + wallet_connections = (await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + ))).scalars().all() for wallet_connection in wallet_connections: wallet_connection.invalidated = True - db_session.commit() + await db_session.commit() return await send_home_menu(chat_wrap, user, None, message_id=query.message.message_id) diff --git a/app/client_bot/routers/content.py b/app/client_bot/routers/content.py index dd8fe7d..fe0b8d7 100644 --- a/app/client_bot/routers/content.py +++ b/app/client_bot/routers/content.py @@ -6,6 +6,7 @@ from aiogram import types, Router, F from app.core._keyboards import get_inline_keyboard from app.core.models.node_storage import StoredContent +from sqlalchemy import select, and_ import json from app.core.logger import make_log from app.core.models.content.user_content import UserAction, UserContent @@ -30,7 +31,7 @@ CACHE_CHAT_ID = -1002390124789 async def t_callback_purchase_node_content(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): content_oid = int(query.data.split('_')[1]) is_cancel_request = query.data.split('_')[2] == 'cancel' if len(query.data.split('_')) > 2 else False - content = db_session.query(StoredContent).filter_by(id=content_oid).first() + content = (await db_session.execute(select(StoredContent).where(StoredContent.id == content_oid))).scalars().first() if not content: return await query.answer(user.translated('error_contentNotFound'), show_alert=True) @@ -43,11 +44,16 @@ async def t_callback_purchase_node_content(query: types.CallbackQuery, memory=No make_log("Purchase", f"User {user.id} initiated purchase for content ID {content_oid}. License price: {license_price_num}.", level='info') - ton_connect, ton_connection = TonConnect.by_user(db_session, user, callback_fn=()) - await ton_connect.restore_connection() + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() + if ton_connection: + await ton_connect.restore_connection() assert ton_connect.connected, "No connected wallet" - user_wallet_address = user.wallet_address(db_session) + user_wallet_address = await user.wallet_address_async(db_session) memory._app.add_task(ton_connect._sdk_client.send_transaction({ 'valid_until': int(datetime.now().timestamp() + 300), @@ -76,18 +82,15 @@ async def t_callback_purchase_node_content(query: types.CallbackQuery, memory=No else: # Logging cancellation attempt with detailed information make_log("Purchase", f"User {user.id} cancelled purchase for content ID {content_oid}.", level='info') - action = db_session.query(UserAction).filter_by( - type='purchase', - content_id=content_oid, - user_id=user.id, - status='requested' - ).first() + action = (await db_session.execute(select(UserAction).where( + and_(UserAction.type == 'purchase', UserAction.content_id == content_oid, UserAction.user_id == user.id, UserAction.status == 'requested') + ))).scalars().first() if not action: return await query.answer() action.status = 'canceled' - db_session.commit() + await db_session.commit() await chat_wrap.send_content(db_session, content, message_id=query.message.message_id) @@ -104,9 +107,7 @@ async def t_inline_query_node_content(query: types.InlineQuery, memory=None, use args = None if source_args_ext.startswith('Q'): license_onchain_address = source_args_ext[1:] - licensed_content = db_session.query(UserContent).filter_by( - onchain_address=license_onchain_address, - ).first().content + licensed_content = (await db_session.execute(select(UserContent).where(UserContent.onchain_address == license_onchain_address))).scalars().first().content make_log("InlineSearch", f"Query '{query.query}' is a license query for content ID {licensed_content.id}.", level='info') args = licensed_content.cid.serialize_v2() else: @@ -118,15 +119,15 @@ async def t_inline_query_node_content(query: types.InlineQuery, memory=None, use content_list = [] search_query = {'hash': cid.content_hash_b58} make_log("InlineSearch", f"Searching with query '{search_query}'.", level='info') - content = db_session.query(StoredContent).filter_by(**search_query).first() - content_prod = content.open_content(db_session) + content = (await db_session.execute(select(StoredContent).where(StoredContent.hash == cid.content_hash_b58))).scalars().first() + content_prod = await content.open_content_async(db_session) # Get both encrypted and decrypted content objects encrypted_content = content_prod['encrypted_content'] decrypted_content = content_prod['decrypted_content'] decrypted_content_meta = decrypted_content.json_format() try: - metadata_content = StoredContent.from_cid(db_session, content.json_format()['metadata_cid']) + metadata_content = await StoredContent.from_cid_async(db_session, content.json_format()['metadata_cid']) with open(metadata_content.filepath, 'r') as f: metadata_content_json = json.loads(f.read()) except BaseException as e: @@ -144,7 +145,7 @@ async def t_inline_query_node_content(query: types.InlineQuery, memory=None, use result_kwargs = {} try: - cover_content = StoredContent.from_cid(db_session, decrypted_content_meta.get('cover_cid') or None) + cover_content = await StoredContent.from_cid_async(db_session, decrypted_content_meta.get('cover_cid') or None) except BaseException as e: cover_content = None @@ -152,9 +153,7 @@ async def t_inline_query_node_content(query: types.InlineQuery, memory=None, use result_kwargs['thumb_url'] = cover_content.web_url content_type_declared = decrypted_content_meta.get('content_type', 'application/x-binary').split('/')[0] - preview_content = db_session.query(StoredContent).filter_by( - hash=content.meta.get('converted_content', {}).get('low_preview') - ).first() + preview_content = (await db_session.execute(select(StoredContent).where(StoredContent.hash == content.meta.get('converted_content', {}).get('low_preview')))).scalars().first() content_type_declared = { 'mp3': 'audio', 'flac': 'audio', @@ -196,7 +195,7 @@ async def t_inline_query_node_content(query: types.InlineQuery, memory=None, use **decrypted_content.meta, 'telegram_file_cache_preview': preview_file_id } - db_session.commit() + await db_session.commit() except Exception as e: # Logging error during preview upload with detailed content type and query information make_log("InlineSearch", f"Error uploading preview for content type '{content_type_declared}' during inline query '{query.query}': {e}", level='error') diff --git a/app/client_bot/routers/home.py b/app/client_bot/routers/home.py index 5e14ef4..5125f2f 100644 --- a/app/client_bot/routers/home.py +++ b/app/client_bot/routers/home.py @@ -3,6 +3,7 @@ from aiogram.filters import Command from tonsdk.utils import Address from app.core._blockchain.ton.connect import TonConnect +from sqlalchemy import select, and_ from app.core._keyboards import get_inline_keyboard from app.core._utils.tg_process_template import tg_process_template from app.core.logger import make_log @@ -32,8 +33,13 @@ async def send_home_menu(chat_wrap, user, wallet_connection, **kwargs): async def send_connect_wallets_list(db_session, chat_wrap, user, **kwargs): - ton_connect, ton_connection = TonConnect.by_user(db_session, user, callback_fn=()) - await ton_connect.restore_connection() + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() + if ton_connection: + await ton_connect.restore_connection() wallets = ton_connect._sdk_client.get_wallets() message_text = user.translated("connectWalletsList_menu") return await tg_process_template( @@ -66,10 +72,9 @@ async def t_home_menu(__msg, **extra): else: message_id = None - wallet_connection = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).first() + wallet_connection = (await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + ))).scalars().first() # if not wallet_connection: # return await send_connect_wallets_list(db_session, chat_wrap, user, message_id=message_id) diff --git a/app/client_bot/routers/tonconnect.py b/app/client_bot/routers/tonconnect.py index 6de8f29..30e6e21 100644 --- a/app/client_bot/routers/tonconnect.py +++ b/app/client_bot/routers/tonconnect.py @@ -7,6 +7,7 @@ from aiogram.filters import Command from app.client_bot.routers.home import send_connect_wallets_list, send_home_menu from app.core._blockchain.ton.connect import TonConnect, unpack_wallet_info +from sqlalchemy import select, and_ from app.core._keyboards import get_inline_keyboard from app.core._utils.tg_process_template import tg_process_template from app.core.logger import make_log @@ -34,15 +35,20 @@ async def t_tonconnect_dev_menu(message: types.Message, memory=None, user=None, keyboard = [] - ton_connect, ton_connection = TonConnect.by_user(db_session, user, callback_fn=()) + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() make_log("TonConnect_DevMenu", f"Available wallets: {ton_connect._sdk_client.get_wallets()}", level='debug') - await ton_connect.restore_connection() + if ton_connection: + await ton_connect.restore_connection() make_log("TonConnect_DevMenu", f"SDK connected?: {ton_connect.connected}", level='info') if not ton_connect.connected: if ton_connection: make_log("TonConnect_DevMenu", f"Invalidating old connection", level='debug') ton_connection.invalidated = True - db_session.commit() + await db_session.commit() message_text = f"""Wallet is not connected @@ -73,8 +79,13 @@ Use /dev_tonconnect {wallet_app_name} for connect to wallet.""" async def t_callback_init_tonconnect(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): wallet_app_name = query.data.split("_")[1] - ton_connect, ton_connection = TonConnect.by_user(db_session, user) - await ton_connect.restore_connection() + result = await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False, WalletConnection.network == 'ton') + ).order_by(WalletConnection.created.desc())) + ton_connection = result.scalars().first() + ton_connect = TonConnect.by_key(ton_connection.keys["connection_key"]) if ton_connection else TonConnect() + if ton_connection: + await ton_connect.restore_connection() connection_link = await ton_connect.new_connection(wallet_app_name) ton_connect.connected memory.add_task(pause_ton_connection, ton_connect, delay_s=60 * 3) @@ -100,10 +111,9 @@ async def t_callback_init_tonconnect(query: types.CallbackQuery, memory=None, us start_ts = datetime.now() while datetime.now() - start_ts < timedelta(seconds=180): - new_connection = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).first() + new_connection = (await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + ))).scalars().first() if new_connection: await tg_process_template( chat_wrap, user.translated('p_successConnectWallet') @@ -118,14 +128,13 @@ async def t_callback_init_tonconnect(query: types.CallbackQuery, memory=None, us async def t_callback_disconnect_wallet(query: types.CallbackQuery, memory=None, user=None, db_session=None, chat_wrap=None, **extra): - wallet_connections = db_session.query(WalletConnection).filter( - WalletConnection.user_id == user.id, - WalletConnection.invalidated == False - ).all() + wallet_connections = (await db_session.execute(select(WalletConnection).where( + and_(WalletConnection.user_id == user.id, WalletConnection.invalidated == False) + ))).scalars().all() for wallet_connection in wallet_connections: wallet_connection.invalidated = True - db_session.commit() + await db_session.commit() return await send_home_menu(chat_wrap, user, None, message_id=query.message.message_id) diff --git a/app/core/_config.py b/app/core/_config.py index 1f7372b..67c2d65 100644 --- a/app/core/_config.py +++ b/app/core/_config.py @@ -19,9 +19,8 @@ import httpx TELEGRAM_BOT_USERNAME = httpx.get(f"https://api.telegram.org/bot{TELEGRAM_API_KEY}/getMe").json()['result']['username'] CLIENT_TELEGRAM_BOT_USERNAME = httpx.get(f"https://api.telegram.org/bot{CLIENT_TELEGRAM_API_KEY}/getMe").json()['result']['username'] - -MYSQL_URI = os.environ['MYSQL_URI'] -MYSQL_DATABASE = os.environ['MYSQL_DATABASE'] +# Unified database URL (PostgreSQL) +DATABASE_URL = os.environ['DATABASE_URL'] LOG_LEVEL = os.getenv('LOG_LEVEL', 'DEBUG') LOG_DIR = os.getenv('LOG_DIR', 'logs') diff --git a/app/core/_crypto/content.py b/app/core/_crypto/content.py index 6a5b446..fdb7c3f 100644 --- a/app/core/_crypto/content.py +++ b/app/core/_crypto/content.py @@ -36,9 +36,10 @@ async def create_new_encryption_key(db_session, user_id: int = None) -> KnownKey meta={"I_user_id": user_id} if user_id else None, created=datetime.now() ) + from sqlalchemy import select db_session.add(new_key) - db_session.commit() - new_key = db_session.query(KnownKey).filter(KnownKey.seed_hash == new_seed_hash).first() + await db_session.commit() + new_key = (await db_session.execute(select(KnownKey).where(KnownKey.seed_hash == new_seed_hash))).scalars().first() assert new_key, "Key not created" return new_key @@ -46,9 +47,10 @@ async def create_new_encryption_key(db_session, user_id: int = None) -> KnownKey async def create_encrypted_content( db_session, decrypted_content: StoredContent, ) -> StoredContent: - encrypted_content = db_session.query(StoredContent).filter( + from sqlalchemy import select + encrypted_content = (await db_session.execute(select(StoredContent).where( StoredContent.id == decrypted_content.decrypted_content_id - ).first() + ))).scalars().first() if encrypted_content: make_log("create_encrypted_content", f"(d={decrypted_content.cid.serialize_v2()}) => (e={encrypted_content.cid.serialize_v2()}): already exist (found by decrypted content)", level="debug") return encrypted_content @@ -57,10 +59,8 @@ async def create_encrypted_content( if decrypted_content.key is None: key = await create_new_encryption_key(db_session, user_id=decrypted_content.user_id) decrypted_content.key_id = key.id - db_session.commit() - decrypted_content = db_session.query(StoredContent).filter( - StoredContent.id == decrypted_content.id - ).first() + await db_session.commit() + decrypted_content = (await db_session.execute(select(StoredContent).where(StoredContent.id == decrypted_content.id))).scalars().first() assert decrypted_content.key_id, "Key not assigned" decrypted_path = os.path.join(UPLOADS_DIR, decrypted_content.hash) @@ -72,9 +72,7 @@ async def create_encrypted_content( encrypted_bin = cipher.encrypt(decrypted_bin) encrypted_hash_bin = sha256(encrypted_bin).digest() encrypted_hash = b58encode(encrypted_hash_bin).decode() - encrypted_content = db_session.query(StoredContent).filter( - StoredContent.hash == encrypted_hash - ).first() + encrypted_content = (await db_session.execute(select(StoredContent).where(StoredContent.hash == encrypted_hash))).scalars().first() if encrypted_content: make_log("create_encrypted_content", f"(d={decrypted_content.cid.serialize_v2()}) => (e={encrypted_content.cid.serialize_v2()}): already exist (found by encrypted_hash)", level="debug") return encrypted_content @@ -99,19 +97,16 @@ async def create_encrypted_content( created=datetime.now(), ) db_session.add(encrypted_content) - db_session.commit() + await db_session.commit() encrypted_path = os.path.join(UPLOADS_DIR, encrypted_hash) async with aiofiles.open(encrypted_path, mode='wb') as file: await file.write(encrypted_bin) - encrypted_content = db_session.query(StoredContent).filter( - StoredContent.hash == encrypted_hash - ).first() + encrypted_content = (await db_session.execute(select(StoredContent).where(StoredContent.hash == encrypted_hash))).scalars().first() assert encrypted_content, "Content not created" make_log("create_encrypted_content", f"(d={decrypted_content.cid.serialize_v2()}) => (e={encrypted_content.cid.serialize_v2()}): created new content/bin", level="debug") return encrypted_content - diff --git a/app/core/_secrets.py b/app/core/_secrets.py index 1a4c403..889626a 100644 --- a/app/core/_secrets.py +++ b/app/core/_secrets.py @@ -1,4 +1,6 @@ +import asyncio from os import getenv, urandom +import os from nacl.bindings import crypto_sign_seed_keypair from tonsdk.utils import Address @@ -7,38 +9,44 @@ from app.core._blockchain.ton.wallet_v3cr3 import WalletV3CR3 from app.core.models._config import ServiceConfig from app.core.storage import db_session from app.core.logger import make_log -import os -def load_hot_pair(): - with db_session() as session: +async def load_hot_pair_async(): + async with db_session() as session: service_config = ServiceConfig(session) - hot_seed = service_config.get('private_key') + hot_seed = await service_config.get('private_key') if hot_seed is None: make_log("HotWallet", "No seed found, generating new one", level='info') - hot_seed = os.getenv("TON_INIT_HOT_SEED") - if not hot_seed: - hot_seed = urandom(32) + hot_seed_env = os.getenv("TON_INIT_HOT_SEED") + if not hot_seed_env: + hot_seed_bytes = urandom(32) make_log("HotWallet", f"Generated random seed") else: - hot_seed = bytes.fromhex(hot_seed) + hot_seed_bytes = bytes.fromhex(hot_seed_env) make_log("HotWallet", f"Loaded seed from env") - service_config.set('private_key', hot_seed.hex()) - return load_hot_pair() + await service_config.set('private_key', hot_seed_bytes.hex()) + hot_seed = hot_seed_bytes.hex() - hot_seed = bytes.fromhex(hot_seed) - public_key, private_key = crypto_sign_seed_keypair(hot_seed) - return hot_seed, public_key, private_key + hot_seed_bytes = bytes.fromhex(hot_seed) + public_key, private_key = crypto_sign_seed_keypair(hot_seed_bytes) + return hot_seed_bytes, public_key, private_key _extra_ton_wallet_options = {} if getenv('TON_CUSTOM_WALLET_ADDRESS'): _extra_ton_wallet_options['address'] = Address(getenv('TON_CUSTOM_WALLET_ADDRESS')) -hot_seed, hot_pubkey, hot_privkey = load_hot_pair() -service_wallet = WalletV3CR3( - private_key=hot_privkey, - public_key=hot_pubkey, - **_extra_ton_wallet_options -) + +def _init_wallet(): + # Safe to call at import time; Sanic event loop not running yet + hot_seed, hot_pubkey, hot_privkey = asyncio.run(load_hot_pair_async()) + wallet = WalletV3CR3( + private_key=hot_privkey, + public_key=hot_pubkey, + **_extra_ton_wallet_options + ) + return hot_seed, hot_pubkey, hot_privkey, wallet + + +hot_seed, hot_pubkey, hot_privkey, service_wallet = _init_wallet() diff --git a/app/core/_utils/create_maria_tables.py b/app/core/_utils/create_maria_tables.py index 9ba16e5..33b8538 100644 --- a/app/core/_utils/create_maria_tables.py +++ b/app/core/_utils/create_maria_tables.py @@ -1,10 +1,12 @@ +from sqlalchemy.ext.asyncio import AsyncEngine from app.core.models import BlockchainTask from app.core.models.base import AlchemyBase -def create_maria_tables(engine): - """Create all tables in the database.""" +async def create_db_tables(engine: AsyncEngine): + """Create all tables in the database (PostgreSQL, async).""" + # ensure model import side-effects initialize mappers BlockchainTask() - AlchemyBase.metadata.create_all(engine) + async with engine.begin() as conn: + await conn.run_sync(AlchemyBase.metadata.create_all) - diff --git a/app/core/auth_v1.py b/app/core/auth_v1.py index a0e08ff..18a6ddf 100644 --- a/app/core/auth_v1.py +++ b/app/core/auth_v1.py @@ -56,9 +56,10 @@ class AuthenticationMixin: }, created=datetime.fromtimestamp(init_ts) ) + from sqlalchemy import select db_session.add(new_key) - db_session.commit() - new_key = db_session.query(KnownKey).filter(KnownKey.seed_hash == new_key.seed_hash).first() + await db_session.commit() + new_key = (await db_session.execute(select(KnownKey).where(KnownKey.seed_hash == new_key.seed_hash))).scalars().first() assert new_key, "Key not created" make_log("auth", f"[new-K] User {user_id} created new {token_type} key {new_key.id}") return { diff --git a/app/core/background/convert_service.py b/app/core/background/convert_service.py index 0d243ad..692bb12 100644 --- a/app/core/background/convert_service.py +++ b/app/core/background/convert_service.py @@ -6,7 +6,7 @@ import json import shutil import magic # python-magic for MIME detection from base58 import b58decode, b58encode -from sqlalchemy import and_, or_ +from sqlalchemy import and_, or_, select from app.core.models.node_storage import StoredContent from app.core.models._telegram import Wrapped_CBotChat from app.core._utils.send_status import send_status @@ -19,9 +19,9 @@ from app.core.content.content_id import ContentId async def convert_loop(memory): - with db_session() as session: + async with db_session() as session: # Query for unprocessed encrypted content - unprocessed_encrypted_content = session.query(StoredContent).filter( + unprocessed_encrypted_content = (await session.execute(select(StoredContent).where( and_( StoredContent.type == "onchain/content", or_( @@ -29,15 +29,15 @@ async def convert_loop(memory): StoredContent.ipfs_cid == None, ) ) - ).first() + ))).scalars().first() if not unprocessed_encrypted_content: make_log("ConvertProcess", "No content to convert", level="debug") return # Достаем расшифрованный файл - decrypted_content = session.query(StoredContent).filter( + decrypted_content = (await session.execute(select(StoredContent).where( StoredContent.id == unprocessed_encrypted_content.decrypted_content_id - ).first() + ))).scalars().first() if not decrypted_content: make_log("ConvertProcess", "Decrypted content not found", level="error") return @@ -78,7 +78,7 @@ async def convert_loop(memory): option_name: decrypted_content.hash for option_name in ['high', 'low', 'low_preview'] } } - session.commit() + await session.commit() return # ==== Конвертация для видео или аудио: оригинальная логика ==== @@ -171,9 +171,7 @@ async def convert_loop(memory): file_hash = b58encode(bytes.fromhex(file_hash)).decode() # Save new StoredContent if not exists - if not session.query(StoredContent).filter( - StoredContent.hash == file_hash - ).first(): + if not (await session.execute(select(StoredContent).where(StoredContent.hash == file_hash))).scalars().first(): new_content = StoredContent( type="local/content_bin", hash=file_hash, @@ -183,7 +181,7 @@ async def convert_loop(memory): created=datetime.now(), ) session.add(new_content) - session.commit() + await session.commit() save_path = os.path.join(UPLOADS_DIR, file_hash) try: @@ -233,13 +231,13 @@ async def convert_loop(memory): **unprocessed_encrypted_content.meta, 'converted_content': converted_content } - session.commit() + await session.commit() # Notify user if needed if not unprocessed_encrypted_content.meta.get('upload_notify_msg_id'): - wallet_owner_connection = session.query(WalletConnection).filter( + wallet_owner_connection = (await session.execute(select(WalletConnection).where( WalletConnection.wallet_address == unprocessed_encrypted_content.owner_address - ).order_by(WalletConnection.id.desc()).first() + ).order_by(WalletConnection.id.desc()))).scalars().first() if wallet_owner_connection: wallet_owner_user = wallet_owner_connection.user bot = Wrapped_CBotChat( @@ -249,7 +247,7 @@ async def convert_loop(memory): db_session=session ) unprocessed_encrypted_content.meta['upload_notify_msg_id'] = await bot.send_content(session, unprocessed_encrypted_content) - session.commit() + await session.commit() async def main_fn(memory): diff --git a/app/core/background/indexer_service.py b/app/core/background/indexer_service.py index fedd4c0..496d7ac 100644 --- a/app/core/background/indexer_service.py +++ b/app/core/background/indexer_service.py @@ -17,6 +17,7 @@ from app.core._utils.resolve_content import resolve_content from app.core.models.wallet_connection import WalletConnection from app.core._keyboards import get_inline_keyboard from app.core.models._telegram import Wrapped_CBotChat +from sqlalchemy import select, and_, desc from app.core.storage import db_session import os import traceback @@ -33,7 +34,7 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: platform_found = True make_log("Indexer", "Service running", level="debug") - with db_session() as session: + async with db_session() as session: try: result = await toncenter.run_get_method('EQD8TJ8xEWB1SpnRE4d89YO3jl0W0EiBnNS4IBaHaUmdfizE', 'get_pool_data') assert result['exit_code'] == 0, f"Error in get-method: {result}" @@ -41,21 +42,21 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: assert result['stack'][1][0] == 'num', f"get second element is not num" usdt_per_ton = (int(result['stack'][0][1], 16) * 1e3) / int(result['stack'][1][1], 16) ton_per_star = 0.014 / usdt_per_ton - ServiceConfig(session).set('live_tonPerStar', [ton_per_star, datetime.utcnow().timestamp()]) + await ServiceConfig(session).set('live_tonPerStar', [ton_per_star, datetime.utcnow().timestamp()]) make_log("TON_Daemon", f"TON per STAR price: {ton_per_star}", level="DEBUG") except BaseException as e: make_log("TON_Daemon", f"Error while saving TON per STAR price: {e}" + '\n' + traceback.format_exc(), level="ERROR") - new_licenses = session.query(UserContent).filter( + new_licenses = (await session.execute(select(UserContent).where( and_( ~UserContent.meta.contains({'notification_sent': True}), UserContent.type == 'nft/listen' ) - ).all() + ))).scalars().all() for new_license in new_licenses: - licensed_content = session.query(StoredContent).filter( + licensed_content = (await session.execute(select(StoredContent).where( StoredContent.id == new_license.content_id - ).first() + ))).scalars().first() if not licensed_content: make_log("Indexer", f"Licensed content not found: {new_license.content_id}", level="error") @@ -70,10 +71,12 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: session, licensed_content ) - wallet_owner_connection = session.query(WalletConnection).filter_by( - wallet_address=licensed_content.owner_address, - invalidated=False - ).order_by(desc(WalletConnection.id)).first() + wallet_owner_connection = (await session.execute( + select(WalletConnection).where( + WalletConnection.wallet_address == licensed_content.owner_address, + WalletConnection.invalidated == False + ).order_by(desc(WalletConnection.id)) + )).scalars().first() wallet_owner_user = wallet_owner_connection.user if wallet_owner_user.telegram_id: wallet_owner_bot = Wrapped_CBotChat(memory._telegram_bot, chat_id=wallet_owner_user.telegram_id, user=wallet_owner_user, db_session=session) @@ -89,21 +92,19 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: make_log("IndexerSendNewLicense", f"Error: {e}" + '\n' + traceback.format_exc(), level="error") new_license.meta = {**new_license.meta, 'notification_sent': True} - session.commit() + await session.commit() - content_without_cid = session.query(StoredContent).filter( - StoredContent.content_id == None - ) + content_without_cid = (await session.execute(select(StoredContent).where(StoredContent.content_id == None))).scalars().all() for target_content in content_without_cid: target_cid = target_content.cid.serialize_v2() make_log("Indexer", f"Content without CID: {target_content.hash}, setting CID: {target_cid}", level="debug") target_content.content_id = target_cid - session.commit() + await session.commit() - last_known_index_ = session.query(StoredContent).filter( - StoredContent.onchain_index != None - ).order_by(StoredContent.onchain_index.desc()).first() + last_known_index_ = (await session.execute( + select(StoredContent).where(StoredContent.onchain_index != None).order_by(StoredContent.onchain_index.desc()) + )).scalars().first() last_known_index = last_known_index_.onchain_index if last_known_index_ else 0 last_known_index = max(last_known_index, 0) make_log("Indexer", f"Last known index: {last_known_index}", level="debug") @@ -196,14 +197,13 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: user_wallet_connection = None if item_owner_address: - user_wallet_connection = session.query(WalletConnection).filter( + user_wallet_connection = (await session.execute(select(WalletConnection).where( WalletConnection.wallet_address == item_owner_address.to_string(1, 1, 1) - ).first() + ))).scalars().first() - encrypted_stored_content = session.query(StoredContent).filter( - StoredContent.hash == item_content_hash_str, - # StoredContent.type.like("local%") - ).first() + encrypted_stored_content = (await session.execute(select(StoredContent).where( + StoredContent.hash == item_content_hash_str + ))).scalars().first() if encrypted_stored_content: is_duplicate = encrypted_stored_content.type.startswith("onchain") \ and encrypted_stored_content.onchain_index != item_index @@ -234,14 +234,15 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: ) try: - for hint_message in session.query(KnownTelegramMessage).filter( + result = await session.execute(select(KnownTelegramMessage).where( and_( KnownTelegramMessage.chat_id == user.telegram_id, KnownTelegramMessage.type == 'hint', cast(KnownTelegramMessage.meta['encrypted_content_hash'], String) == encrypted_stored_content.hash, KnownTelegramMessage.deleted == False ) - ).all(): + )) + for hint_message in result.scalars().all(): await user_uploader_wrapper.delete_message(hint_message.message_id) except BaseException as e: make_log("Indexer", f"Error while deleting hint messages: {e}" + '\n' + traceback.format_exc(), level="error") @@ -260,7 +261,7 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: **item_metadata_packed } - session.commit() + await session.commit() return platform_found, seqno else: item_metadata_packed['copied_from'] = encrypted_stored_content.id @@ -282,7 +283,7 @@ async def indexer_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: updated=datetime.now() ) session.add(onchain_stored_content) - session.commit() + await session.commit() make_log("Indexer", f"Item indexed: {item_content_hash_str}", level="info") last_known_index += 1 diff --git a/app/core/background/license_service.py b/app/core/background/license_service.py index 17098ab..a6d1a07 100644 --- a/app/core/background/license_service.py +++ b/app/core/background/license_service.py @@ -3,7 +3,7 @@ from base64 import b64decode from datetime import datetime, timedelta from base58 import b58encode -from sqlalchemy import and_, or_ +from sqlalchemy import and_, or_, select, desc from tonsdk.boc import Cell from tonsdk.utils import Address @@ -27,7 +27,7 @@ import traceback async def license_index_loop(memory, platform_found: bool, seqno: int) -> [bool, int]: make_log("LicenseIndex", "Service running", level="debug") - with db_session() as session: + async with db_session() as session: async def check_telegram_stars_transactions(): # Проверка звездных telegram транзакций, обновление paid offset = {'desc': 'Статичное число заранее известного количества транзакций, которое даже не знает наш бот', 'value': 1}['value'] + \ @@ -45,19 +45,19 @@ async def license_index_loop(memory, platform_found: bool, seqno: int) -> [bool, continue try: - existing_invoice = session.query(StarsInvoice).filter( + existing_invoice = (await session.execute(select(StarsInvoice).where( StarsInvoice.external_id == star_payment.source.invoice_payload - ).first() + ))).scalars().first() if not existing_invoice: continue if star_payment.amount == existing_invoice.amount: if not existing_invoice.paid: existing_invoice.paid = True - session.commit() + await session.commit() - licensed_content = session.query(StoredContent).filter(StoredContent.hash == existing_invoice.content_hash).first() - user = session.query(User).filter(User.id == existing_invoice.user_id).first() + licensed_content = (await session.execute(select(StoredContent).where(StoredContent.hash == existing_invoice.content_hash))).scalars().first() + user = (await session.execute(select(User).where(User.id == existing_invoice.user_id))).scalars().first() await (Wrapped_CBotChat(memory._client_telegram_bot, chat_id=user.telegram_id, user=user, db_session=session)).send_content( session, licensed_content @@ -73,9 +73,10 @@ async def license_index_loop(memory, platform_found: bool, seqno: int) -> [bool, make_log("StarsProcessing", f"Error: {e}" + '\n' + traceback.format_exc(), level="error") # Проверка кошельков пользователей на появление новых NFT, добавление их в базу как неопознанные - for user in session.query(User).filter( + users = (await session.execute(select(User).where( User.last_use > datetime.now() - timedelta(hours=4) - ).order_by(User.updated.asc()).all(): + ).order_by(User.updated.asc()))).scalars().all() + for user in users: user_wallet_address = user.wallet_address(session) if not user_wallet_address: make_log("LicenseIndex", f"User {user.id} has no wallet address", level="info") @@ -91,17 +92,17 @@ async def license_index_loop(memory, platform_found: bool, seqno: int) -> [bool, try: await user.scan_owned_user_content(session) user.meta = {**user.meta, 'last_updated_licenses': datetime.now().isoformat()} - session.commit() + await session.commit() except BaseException as e: make_log("LicenseIndex", f"Error: {e}" + '\n' + traceback.format_exc(), level="error") # Проверка NFT на актуальность данных, в том числе уже проверенные - process_content = session.query(UserContent).filter( + process_content = (await session.execute(select(UserContent).where( and_( UserContent.type.startswith('nft/'), UserContent.updated < (datetime.now() - timedelta(minutes=60)), ) - ).order_by(UserContent.updated.asc()).first() + ).order_by(UserContent.updated.asc()))).scalars().first() if process_content: make_log("LicenseIndex", f"Syncing content with blockchain: {process_content.id}", level="info") try: @@ -110,7 +111,7 @@ async def license_index_loop(memory, platform_found: bool, seqno: int) -> [bool, make_log("LicenseIndex", f"Error: {e}" + '\n' + traceback.format_exc(), level="error") finally: process_content.updated = datetime.now() - session.commit() + await session.commit() return platform_found, seqno diff --git a/app/core/background/ton_service.py b/app/core/background/ton_service.py index e9dba98..de505f9 100644 --- a/app/core/background/ton_service.py +++ b/app/core/background/ton_service.py @@ -125,7 +125,7 @@ async def main_fn(memory): sw_seqno_value = await get_sw_seqno() make_log("TON", f"Service running ({sw_seqno_value})", level="debug") - with db_session() as session: + async with db_session() as session: # Проверка отправленных сообщений await send_status("ton_daemon", f"working: processing in-txs (seqno={sw_seqno_value})") async def process_incoming_transaction(transaction: dict): @@ -142,14 +142,17 @@ async def main_fn(memory): in_msg_created_at = in_msg_slice.read_uint(64) in_msg_epoch = int(in_msg_created_at // (60 * 60)) in_msg_seqno = HighloadQueryId.from_query_id(in_msg_query_id).to_seqno() + from sqlalchemy import select in_msg_blockchain_task = ( - session.query(BlockchainTask).filter( - and_( - BlockchainTask.seqno == in_msg_seqno, - BlockchainTask.epoch == in_msg_epoch, + await session.execute( + select(BlockchainTask).where( + and_( + BlockchainTask.seqno == in_msg_seqno, + BlockchainTask.epoch == in_msg_epoch, + ) ) ) - ).first() + ).scalars().first() if not in_msg_blockchain_task: return @@ -157,7 +160,7 @@ async def main_fn(memory): in_msg_blockchain_task.status = 'done' in_msg_blockchain_task.transaction_hash = transaction_hash in_msg_blockchain_task.transaction_lt = transaction_lt - session.commit() + await session.commit() for blockchain_message in [transaction['in_msg']]: try: @@ -177,11 +180,11 @@ async def main_fn(memory): await send_status("ton_daemon", f"working: processing out-txs (seqno={sw_seqno_value})") # Отправка подписанных сообщений - for blockchain_task in ( - session.query(BlockchainTask).filter( - BlockchainTask.status == 'processing', - ).order_by(BlockchainTask.updated.asc()).all() - ): + from sqlalchemy import select + _processing = (await session.execute(select(BlockchainTask).where( + BlockchainTask.status == 'processing' + ).order_by(BlockchainTask.updated.asc()))).scalars().all() + for blockchain_task in _processing: make_log("TON_Daemon", f"Processing task (processing) {blockchain_task.id}") query_boc = bytes.fromhex(blockchain_task.meta['signed_message']) errors_list = [] @@ -210,23 +213,22 @@ async def main_fn(memory): # or sum([int("terminating vm with exit code 36" in e) for e in errors_list]) > 0: make_log("TON_Daemon", f"Task {blockchain_task.id} done", level="DEBUG") blockchain_task.status = 'done' - session.commit() + await session.commit() continue await asyncio.sleep(0.5) await send_status("ton_daemon", f"working: creating new messages (seqno={sw_seqno_value})") # Создание новых подписей - for blockchain_task in ( - session.query(BlockchainTask).filter(BlockchainTask.status == 'wait').all() - ): + _waiting = (await session.execute(select(BlockchainTask).where(BlockchainTask.status == 'wait'))).scalars().all() + for blockchain_task in _waiting: try: # Check processing tasks in current epoch < 3_000_000 - if ( - session.query(BlockchainTask).filter( - BlockchainTask.epoch == blockchain_task.epoch, - ).count() > 3_000_000 - ): + from sqlalchemy import func + _cnt = (await session.execute(select(func.count()).select_from(BlockchainTask).where( + BlockchainTask.epoch == blockchain_task.epoch + ))).scalar() or 0 + if _cnt > 3_000_000: make_log("TON", f"Too many processing tasks in epoch {blockchain_task.epoch}", level="error") await send_status("ton_daemon", f"working: too many tasks in epoch {blockchain_task.epoch}") await asyncio.sleep(5) @@ -235,10 +237,11 @@ async def main_fn(memory): sign_created = int(datetime.utcnow().timestamp()) - 60 try: current_epoch = int(datetime.utcnow().timestamp() // (60 * 60)) + from sqlalchemy import func max_epoch_seqno = ( - session.query(func.max(BlockchainTask.seqno)).filter( + (await session.execute(select(func.max(BlockchainTask.seqno)).where( BlockchainTask.epoch == current_epoch - ).scalar() or 0 + ))).scalar() or 0 ) current_epoch_shift = 3_000_000 if current_epoch % 2 == 0 else 0 current_seqno = max_epoch_seqno + 1 + (current_epoch_shift if max_epoch_seqno == 0 else 0) @@ -266,7 +269,7 @@ async def main_fn(memory): 'sign_created': sign_created, 'signed_message': query_boc.hex(), } - session.commit() + await session.commit() make_log("TON", f"Created signed message for task {blockchain_task.id}" + '\n' + traceback.format_exc(), level="info") except BaseException as e: make_log("TON", f"Error processing task {blockchain_task.id}: {e}" + '\n' + traceback.format_exc(), level="error") @@ -287,4 +290,3 @@ async def main_fn(memory): - diff --git a/app/core/content/utils.py b/app/core/content/utils.py index a607d3d..0f597fa 100644 --- a/app/core/content/utils.py +++ b/app/core/content/utils.py @@ -26,7 +26,9 @@ async def create_new_content( content_hash_bin = sha256(content_bin).digest() content_hash_b58 = b58encode(content_hash_bin).decode() - new_content = db_session.query(StoredContent).filter(StoredContent.hash == content_hash_b58).first() + from sqlalchemy import select + result = await db_session.execute(select(StoredContent).where(StoredContent.hash == content_hash_b58)) + new_content = result.scalars().first() if new_content: return new_content, False @@ -38,8 +40,9 @@ async def create_new_content( ) db_session.add(new_content) - db_session.commit() - new_content = db_session.query(StoredContent).filter(StoredContent.hash == content_hash_b58).first() + await db_session.commit() + result = await db_session.execute(select(StoredContent).where(StoredContent.hash == content_hash_b58)) + new_content = result.scalars().first() assert new_content, "Content not created (through utils)" content_filepath = os.path.join(UPLOADS_DIR, content_hash_b58) async with aiofiles.open(content_filepath, 'wb') as file: diff --git a/app/core/models/_config.py b/app/core/models/_config.py index 10b2442..4891d8e 100644 --- a/app/core/models/_config.py +++ b/app/core/models/_config.py @@ -1,6 +1,6 @@ from app.core.models.base import AlchemyBase -from sqlalchemy import Column, BigInteger, Integer, String, ForeignKey, DateTime, JSON, Boolean +from sqlalchemy import Column, Integer, String, JSON, select class ServiceConfigValue(AlchemyBase): @@ -19,20 +19,18 @@ class ServiceConfig: def __init__(self, session): self.session = session - def get(self, key, default=None): - result = self.session.query(ServiceConfigValue).filter(ServiceConfigValue.key == key).first() + async def get(self, key, default=None): + result = (await self.session.execute(select(ServiceConfigValue).where(ServiceConfigValue.key == key))).scalars().first() return (result.value if result else None) or default - def set(self, key, value): - config_value = self.session.query(ServiceConfigValue).filter( - ServiceConfigValue.key == key - ).first() - if not config_value: - config_value = ServiceConfigValue(key=key) - self.session.add(config_value) - self.session.commit() - return self.set(key, value) + async def set(self, key, value): + result = (await self.session.execute(select(ServiceConfigValue).where(ServiceConfigValue.key == key))).scalars().first() + if not result: + result = ServiceConfigValue(key=key) + self.session.add(result) + await self.session.commit() + return await self.set(key, value) - config_value.packed_value = {'value': value} - self.session.commit() + result.packed_value = {'value': value} + await self.session.commit() return diff --git a/app/core/models/_telegram/templates/player.py b/app/core/models/_telegram/templates/player.py index 1e3d378..8f7d1bc 100644 --- a/app/core/models/_telegram/templates/player.py +++ b/app/core/models/_telegram/templates/player.py @@ -1,4 +1,4 @@ -from sqlalchemy import and_ +from sqlalchemy import and_, select from app.core.models.node_storage import StoredContent from app.core.models.content.user_content import UserContent, UserAction from app.core.logger import make_log @@ -27,17 +27,17 @@ class PlayerTemplates: if not content.encrypted: local_content = content else: - local_content = db_session.query(StoredContent).filter_by( - id=content.decrypted_content_id - ).first() + local_content = (await db_session.execute(select(StoredContent).where(StoredContent.id == content.decrypted_content_id))).scalars().first() # TODO: add check decrypted_content by .format_json()['content_cid'] if local_content: cd_log += f"Decrypted: {local_content.hash}. " else: cd_log += "Can't decrypt content. " - user_wallet_address = self.user.wallet_address(self.db_session) - user_existing_license = self.db_session.query(UserContent).filter_by(user_id=self.user.id, content_id=content.id).first() + user_wallet_address = await self.user.wallet_address_async(self.db_session) + user_existing_license = (await self.db_session.execute(select(UserContent).where( + and_(UserContent.user_id == self.user.id, UserContent.content_id == content.id) + ))).scalars().first() if local_content: content_meta = content.json_format() @@ -48,12 +48,12 @@ class PlayerTemplates: except: content_type, content_encoding = 'application', 'x-binary' - content_metadata = StoredContent.from_cid(db_session, content_meta.get('metadata_cid') or None) + content_metadata = await StoredContent.from_cid_async(db_session, content_meta.get('metadata_cid') or None) with open(content_metadata.filepath, 'r') as f: content_metadata_json = json.loads(f.read()) try: - cover_content = StoredContent.from_cid(self.db_session, content_meta.get('cover_cid') or None) + cover_content = await StoredContent.from_cid_async(self.db_session, content_meta.get('cover_cid') or None) cd_log += f"Cover content: {cover_content.cid.serialize_v2()}. " except BaseException as e: cd_log += f"Can't get cover content: {e}. " @@ -88,12 +88,15 @@ class PlayerTemplates:
🔴 «открыть в MY»
""" make_log("TG-Player", f"Send content {content_type} ({content_encoding}) to chat {self._chat_id}. {cd_log}") - for kmsg in self.db_session.query(KnownTelegramMessage).filter_by( - content_id=content.id, - chat_id=self._chat_id, - type=f'content/{content_type}', - deleted=False - ).all(): + kmsgs = (await self.db_session.execute(select(KnownTelegramMessage).where( + and_( + KnownTelegramMessage.content_id == content.id, + KnownTelegramMessage.chat_id == self._chat_id, + KnownTelegramMessage.type == f'content/{content_type}', + KnownTelegramMessage.deleted == False + ) + ))).scalars().all() + for kmsg in kmsgs: await self.delete_message(kmsg.message_id) r = await tg_process_template( diff --git a/app/core/models/_telegram/wrapped_bot.py b/app/core/models/_telegram/wrapped_bot.py index b4c1b5e..60aa8b6 100644 --- a/app/core/models/_telegram/wrapped_bot.py +++ b/app/core/models/_telegram/wrapped_bot.py @@ -1,7 +1,7 @@ from aiogram import Bot, types from datetime import datetime, timedelta -from sqlalchemy import and_ +from sqlalchemy import and_, select from app.core.logger import make_log from app.core.models.messages import KnownTelegramMessage @@ -46,14 +46,15 @@ class Wrapped_CBotChat(T, PlayerTemplates): if self.db_session: if message_type == 'common': ci = 0 - for oc_msg in self.db_session.query(KnownTelegramMessage).filter( + result = await self.db_session.execute(select(KnownTelegramMessage).where( and_( KnownTelegramMessage.type == 'common', KnownTelegramMessage.bot_id == self.bot_id, KnownTelegramMessage.chat_id == self._chat_id, KnownTelegramMessage.deleted == False ) - ).all(): + )) + for oc_msg in result.scalars().all(): make_log(self, f"Delete old message {oc_msg.message_id} {oc_msg.type} {oc_msg.bot_id} {oc_msg.chat_id}") await self.delete_message(oc_msg.message_id) ci += 1 @@ -75,7 +76,7 @@ class Wrapped_CBotChat(T, PlayerTemplates): content_id=content_id ) ) - self.db_session.commit() + await self.db_session.commit() else: make_log(self, f"Unknown result type: {type(result)}", level='warning') @@ -127,14 +128,16 @@ class Wrapped_CBotChat(T, PlayerTemplates): message_id )): if self.db_session: - known_message = self.db_session.query(KnownTelegramMessage).filter( - KnownTelegramMessage.bot_id == self.bot_id, - KnownTelegramMessage.chat_id == self._chat_id, - KnownTelegramMessage.message_id == message_id - ).first() + known_message = (await self.db_session.execute(select(KnownTelegramMessage).where( + and_( + KnownTelegramMessage.bot_id == self.bot_id, + KnownTelegramMessage.chat_id == self._chat_id, + KnownTelegramMessage.message_id == message_id + ) + ))).scalars().first() if known_message: known_message.deleted = True - self.db_session.commit() + await self.db_session.commit() except Exception as e: make_log(self, f"Error deleting message {self._chat_id}/{message_id}. Error: {e}", level='warning') return None diff --git a/app/core/models/asset.py b/app/core/models/asset.py index 3cdbeea..fc2ca5c 100644 --- a/app/core/models/asset.py +++ b/app/core/models/asset.py @@ -29,22 +29,21 @@ class Asset(AlchemyBase): AlchemyBase.metadata.create_all(engine) @classmethod - def find(cls, session, **kwargs): + async def find_async(cls, session, **kwargs): + from sqlalchemy import select, func if 'symbol' in kwargs: kwargs['symbol'] = kwargs['symbol'].upper() - result = session.query(cls).filter_by(**kwargs) - results_count = result.count() - if results_count == 0: - any_count = session.query(cls).count() - if any_count == 0: - init_asset = cls(**DEFAULT_ASSET_INITOBJ) - session.add(init_asset) - session.commit() - return cls.find(session, **kwargs) + result = await session.execute(select(cls).filter_by(**kwargs)) + row = result.scalars().first() + if row: + return row - raise Exception(f"Asset not found: {kwargs}") - elif results_count == 1: - return result.first() - else: - raise Exception(f"Multiple assets found: {results_count}") + any_count = (await session.execute(select(func.count()).select_from(cls))).scalar() or 0 + if any_count == 0: + init_asset = cls(**DEFAULT_ASSET_INITOBJ) + session.add(init_asset) + await session.commit() + return await cls.find_async(session, **kwargs) + + raise Exception(f"Asset not found: {kwargs}") diff --git a/app/core/models/content/indexation_mixins.py b/app/core/models/content/indexation_mixins.py index 4edf042..61ee5d5 100644 --- a/app/core/models/content/indexation_mixins.py +++ b/app/core/models/content/indexation_mixins.py @@ -1,7 +1,7 @@ import traceback import base58 -from sqlalchemy import and_ +from sqlalchemy import and_, select from app.core.logger import make_log from app.core.models import StoredContent @@ -57,13 +57,9 @@ class UserContentIndexationMixin: values_slice = cc_indexator_data['values'].begin_parse() content_hash_b58 = base58.b58encode(bytes.fromhex(hex(values_slice.read_uint(256))[2:])).decode() make_log("UserContent", f"License ({self.onchain_address}) content hash: {content_hash_b58}", level="info") - stored_content = db_session.query(StoredContent).filter( - and_( - StoredContent.type == 'onchain/content', - StoredContent.hash == content_hash_b58, - - ) - ).first() + stored_content = (await db_session.execute(select(StoredContent).where( + and_(StoredContent.type == 'onchain/content', StoredContent.hash == content_hash_b58) + ))).scalars().first() trusted_cop_address_result = await toncenter.run_get_method(stored_content.meta['item_address'], 'get_nft_address_by_index', [['num', cc_indexator_data['index']]]) assert trusted_cop_address_result.get('exit_code', -1) == 0, "Trusted cop address error" trusted_cop_address = Cell.one_from_boc(b64decode(trusted_cop_address_result['stack'][0][1]['bytes'])).begin_parse().read_msg_addr().to_string(1, 1, 1) @@ -72,7 +68,7 @@ class UserContentIndexationMixin: self.owner_address = cc_indexator_data['owner_address'] self.type = 'nft/listen' self.content_id = stored_content.id - db_session.commit() + await db_session.commit() except BaseException as e: errored = True make_log("UserContent", f"Error: {e}" + '\n' + traceback.format_exc(), level="error") @@ -80,7 +76,6 @@ class UserContentIndexationMixin: if errored is True: self.type = 'nft/unknown' self.content_id = None - db_session.commit() - + await db_session.commit() diff --git a/app/core/models/node_storage.py b/app/core/models/node_storage.py index aafe1c6..7573afe 100644 --- a/app/core/models/node_storage.py +++ b/app/core/models/node_storage.py @@ -96,6 +96,30 @@ class StoredContent(AlchemyBase, AudioContentMixin): make_log("NodeStorage.open_content", f"Can't open content: {self.id} {e}", level='warning') raise e + async def open_content_async(self, db_session, content_type=None): + from sqlalchemy import select + try: + decrypted_content = self if not self.encrypted else None + encrypted_content = self if self.encrypted else None + if not decrypted_content: + decrypted_content = (await db_session.execute(select(StoredContent).where(StoredContent.id == self.decrypted_content_id))).scalars().first() + else: + encrypted_content = (await db_session.execute(select(StoredContent).where(StoredContent.decrypted_content_id == self.id))).scalars().first() + + assert decrypted_content, "Can't get decrypted content" + assert encrypted_content, "Can't get encrypted content" + _ct = content_type or decrypted_content.json_format()['content_type'] + content_type = _ct.split('/')[0] if _ct else 'application' + + return { + 'encrypted_content': encrypted_content, + 'decrypted_content': decrypted_content, + 'content_type': content_type or 'application/x-binary' + } + except BaseException as e: + make_log("NodeStorage.open_content_async", f"Can't open content: {self.id} {e}", level='warning') + raise e + def json_format(self): extra_fields = {} if self.type.startswith('local'): @@ -155,3 +179,15 @@ class StoredContent(AlchemyBase, AudioContentMixin): assert content, "Content not found" return content + @classmethod + async def from_cid_async(cls, db_session, content_id): + from sqlalchemy import select + if isinstance(content_id, str): + cid = ContentId.deserialize(content_id) + else: + cid = content_id + + result = await db_session.execute(select(StoredContent).where(StoredContent.hash == cid.content_hash_b58)) + content = result.scalars().first() + assert content, "Content not found" + return content diff --git a/app/core/models/user/wallet_mixin.py b/app/core/models/user/wallet_mixin.py index 5b80d43..d7e18ba 100644 --- a/app/core/models/user/wallet_mixin.py +++ b/app/core/models/user/wallet_mixin.py @@ -10,18 +10,21 @@ from httpx import AsyncClient class WalletMixin: - def wallet_connection(self, db_session): - return db_session.query(WalletConnection).filter( - WalletConnection.user_id == self.id, - WalletConnection.invalidated == False - ).order_by(WalletConnection.created.desc()).first() + async def wallet_connection_async(self, db_session): + from sqlalchemy import select, and_, desc + result = await db_session.execute( + select(WalletConnection) + .where(and_(WalletConnection.user_id == self.id, WalletConnection.invalidated == False)) + .order_by(WalletConnection.created.desc()) + ) + return result.scalars().first() - def wallet_address(self, db_session): - wallet_connection = self.wallet_connection(db_session) - return wallet_connection.wallet_address if wallet_connection else None + async def wallet_address_async(self, db_session): + wc = await self.wallet_connection_async(db_session) + return wc.wallet_address if wc else None async def scan_owned_user_content(self, db_session): - user_wallet_address = self.wallet_address(db_session) + user_wallet_address = await self.wallet_address_async(db_session) async def get_nft_items_list(): try: @@ -40,9 +43,8 @@ class WalletMixin: item_address = Address(nft_item['address']).to_string(1, 1, 1) owner_address = Address(nft_item['owner']['address']).to_string(1, 1, 1) - user_content = db_session.query(UserContent).filter( - UserContent.onchain_address == item_address - ).first() + from sqlalchemy import select + user_content = (await db_session.execute(select(UserContent).where(UserContent.onchain_address == item_address))).scalars().first() if user_content: continue @@ -57,18 +59,18 @@ class WalletMixin: created=datetime.now(), meta={}, user_id=self.id, - wallet_connection_id=self.wallet_connection(db_session).id, + wallet_connection_id=(await self.wallet_connection_async(db_session)).id, status="active" ) db_session.add(user_content) - db_session.commit() + await db_session.commit() make_log(self, f"New onchain NFT found: {item_address}", level='info') async def ____scan_owned_user_content(self, db_session): page_id = -1 page_size = 100 have_next_page = True - user_wallet_address = self.wallet_address(db_session) + user_wallet_address = await self.wallet_address_async(db_session) while have_next_page: page_id += 1 nfts_list = await toncenter.get_nft_items(limit=100, offset=page_id * page_size, owner_address=user_wallet_address) @@ -81,9 +83,8 @@ class WalletMixin: item_address = Address(nft_item['address']).to_string(1, 1, 1) owner_address = Address(nft_item['owner_address']).to_string(1, 1, 1) - user_content = db_session.query(UserContent).filter( - UserContent.onchain_address == item_address - ).first() + from sqlalchemy import select + user_content = (await db_session.execute(select(UserContent).where(UserContent.onchain_address == item_address))).scalars().first() if user_content: continue @@ -105,11 +106,11 @@ class WalletMixin: 'metadata_uri': nft_content, }, user_id=self.id, - wallet_connection_id=self.wallet_connection(db_session).id, + wallet_connection_id=(await self.wallet_connection_async(db_session)).id, status="active" ) db_session.add(user_content) - db_session.commit() + await db_session.commit() make_log(self, f"New onchain NFT found: {item_address}", level='info') except BaseException as e: @@ -122,6 +123,6 @@ class WalletMixin: except BaseException as e: make_log(self, f"Error while scanning user content: {e}", level='error') - return self.db_session.query(UserContent).filter( - UserContent.user_id == self.id - ).offset(offset).limit(limit).all() + from sqlalchemy import select + result = await db_session.execute(select(UserContent).where(UserContent.user_id == self.id).offset(offset).limit(limit)) + return result.scalars().all() diff --git a/app/core/storage.py b/app/core/storage.py index 9b11725..918ec11 100644 --- a/app/core/storage.py +++ b/app/core/storage.py @@ -1,45 +1,57 @@ import time -from contextlib import contextmanager +from contextlib import asynccontextmanager -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.sql import text +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession -from app.core._config import MYSQL_URI, MYSQL_DATABASE +from app.core._config import DATABASE_URL from app.core.logger import make_log -from sqlalchemy.pool import NullPool - -engine = create_engine(MYSQL_URI, poolclass=NullPool) #, echo=True) -Session = sessionmaker(bind=engine) -database_initialized = False -while not database_initialized: +def _to_async_dsn(url: str) -> str: + # Convert psycopg2 DSN to asyncpg DSN + # postgresql+psycopg2://user:pass@host:5432/db -> postgresql+asyncpg://user:pass@host:5432/db + return url.replace("+psycopg2", "+asyncpg") + + +# Async engine for PostgreSQL +engine = create_async_engine( + _to_async_dsn(DATABASE_URL), + pool_size=10, + max_overflow=20, + pool_timeout=30, + pool_recycle=1800, + pool_pre_ping=True, +) + +AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + + +async def wait_db_ready(): + ready = False + while not ready: + try: + async with engine.connect() as conn: + await conn.execute(text("SELECT 1")) + ready = True + except Exception as e: + make_log("SQL", 'PostgreSQL is not ready yet: ' + str(e), level='debug') + time.sleep(1) + + +@asynccontextmanager +async def db_session(auto_commit: bool = False): + session: AsyncSession = AsyncSessionLocal() try: - with Session() as session: - databases_list = session.execute(text("SHOW DATABASES;")) - databases_list = [row[0] for row in databases_list] - make_log("SQL", 'Database list: ' + str(databases_list), level='debug') - assert MYSQL_DATABASE in databases_list, 'Database not found' - database_initialized = True - except Exception as e: - make_log("SQL", 'MariaDB is not ready yet: ' + str(e), level='debug') - time.sleep(1) - -engine = create_engine(f"{MYSQL_URI}/{MYSQL_DATABASE}", poolclass=NullPool) -Session = sessionmaker(bind=engine) - - -@contextmanager -def db_session(auto_commit=False): - _session = Session() - try: - yield _session - if auto_commit is True: - _session.commit() + yield session + if auto_commit: + await session.commit() except BaseException as e: - _session.rollback() + await session.rollback() raise e finally: - _session.close() + await session.close() + +def new_session() -> AsyncSession: + return AsyncSessionLocal() diff --git a/app/core/transactions.py b/app/core/transactions.py index 0a08a59..3400882 100644 --- a/app/core/transactions.py +++ b/app/core/transactions.py @@ -1,19 +1,19 @@ from datetime import datetime +from sqlalchemy import select, and_, func from app.core.logger import make_log from app.core.models import Memory, User, UserBalance, Asset, InternalTransaction from app.core.storage import db_session -def get_user_balance(session, user: User, asset: Asset) -> UserBalance: +async def get_user_balance(session, user: User, asset: Asset) -> UserBalance: assert user, "No user" assert asset, "No asset" - result = session.query(UserBalance).filter( - UserBalance.user_id == user.id, - UserBalance.asset_id == asset.id - ) - results_count = result.count() - if results_count == 0: + result = await session.execute(select(UserBalance).where( + and_(UserBalance.user_id == user.id, UserBalance.asset_id == asset.id) + )) + row = result.scalars().first() + if not row: user_balance = UserBalance( user_id=user.id, asset_id=asset.id, @@ -21,12 +21,9 @@ def get_user_balance(session, user: User, asset: Asset) -> UserBalance: created=datetime.now(), ) session.add(user_balance) - session.commit() - return get_user_balance(session, user, asset) - elif results_count == 1: - return result.first() - else: - raise Exception(f"Multiple user balances found: {results_count}") + await session.commit() + return await get_user_balance(session, user, asset) + return row async def make_internal_transaction( @@ -46,13 +43,13 @@ async def make_internal_transaction( raise Exception(f"Invalid amount: {amount}") abs_amount = abs(amount) - with db_session(auto_commit=False) as session: + async with db_session(auto_commit=False) as session: async with memory.transaction(): - user = session.query(User).filter_by(id=user_id).first() + user = (await session.execute(select(User).where(User.id == user_id))).scalars().first() assert user, "No user" - asset = session.query(Asset).filter_by(id=asset_id).first() + asset = (await session.execute(select(Asset).where(Asset.id == asset_id))).scalars().first() assert asset, "No asset" - user_balance = get_user_balance(session, user, asset) + user_balance = await get_user_balance(session, user, asset) assert user_balance, "No user balance" if is_spent is True: if abs_amount > user_balance.balance: @@ -71,6 +68,6 @@ async def make_internal_transaction( created=datetime.now(), ) session.add(internal_transaction) - session.commit() + await session.commit() make_log(user, f"Made internal transaction: {'-' if is_spent else ''}{abs_amount} {asset.symbol}, type: {type}") diff --git a/requirements.txt b/requirements.txt index 08aa8aa..402a037 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,8 @@ sanic==21.9.1 websockets==10.0 sqlalchemy==2.0.23 python-dotenv==1.0.0 -pymysql==1.1.0 +psycopg2-binary==2.9.9 +asyncpg==0.29.0 aiogram==3.13.0 pytonconnect==0.3.0 base58==2.1.1 @@ -16,5 +17,3 @@ pydub==0.25.1 pillow==10.2.0 ffmpeg-python==0.2.0 python-magic==0.4.27 - -