"""Wallet-based authentication for the JAE-AI agent.

Endpoints:
    POST /api/auth/nonce   - issue a sign-in nonce for a Solana pubkey
    POST /api/auth/verify  - verify an Ed25519 signature, issue a JWT cookie
    GET  /api/auth/whoami  - introspect the cookie + holdings -> tier
    POST /api/auth/logout  - clear the cookie
"""
import base64
import datetime
import json
import os
import re
import secrets
import time
from pathlib import Path

import base58
import jwt
import requests as req
from flask import Blueprint, jsonify, make_response, request
from nacl.exceptions import BadSignatureError
from nacl.signing import VerifyKey

from agent_tiers import compute_tier, pick_model

auth_bp = Blueprint('auth_bp', __name__, url_prefix='/api/auth')

DATA_DIR = Path(__file__).parent / 'data'
APIKEYS_FILE = DATA_DIR / 'apikeys.json'

# Solana RPC endpoints (mainnet). Helius used if key present, else public RPC.
PUBLIC_RPC = 'https://api.mainnet-beta.solana.com'

# In-memory nonce store: address -> (nonce, expires_ts).
# NOTE(review): this dict is per-process; if the app runs under several worker
# processes, a nonce issued by one worker is invisible to the others — confirm
# the deployment routes /nonce and /verify to the same process.
_NONCES: dict[str, tuple[str, float]] = {}
_NONCE_TTL = 300  # 5 minutes

# Balance cache: address -> (lamports, ts)
_BALANCE_CACHE: dict[str, tuple[int, float]] = {}
_BALANCE_TTL = 60  # 1 minute

COOKIE_NAME = 'jae_session'
JWT_ALGO = 'HS256'
JWT_EXPIRY_HOURS = 24

_LAMPORTS_PER_SOL = 1_000_000_000
_SIGNATURE_LEN = 64  # Ed25519 signatures are 64 bytes

# Matches the Helius "api-key=<value>" query parameter so it can be redacted
# from log output.
_API_KEY_RE = re.compile(r'(api-key=)[^&\s\'"]+')

# JWT secret generated in this process when none could be read from disk.
# Reused on every call so tokens stay verifiable even if persisting fails.
_JWT_SECRET_FALLBACK: str | None = None


def _read_apikeys_file() -> dict | None:
    """Read APIKEYS_FILE.

    Returns:
        The parsed dict; ``{}`` if the file does not exist; ``None`` if it
        exists but cannot be read or parsed, or is not a JSON object. The
        ``None`` case lets writers avoid clobbering a file they could not read.
    """
    try:
        with open(APIKEYS_FILE, encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError):
        return None
    return data if isinstance(data, dict) else None


def _load_apikeys() -> dict:
    """Return the API-key config dict, or ``{}`` if missing or unreadable."""
    return _read_apikeys_file() or {}


def _save_apikeys(keys: dict):
    """Atomically write *keys* to APIKEYS_FILE; log (never raise) on failure.

    The data is written to a sibling temp file, then ``os.replace``d over the
    target. A crash or a concurrent reader therefore never sees a half-written
    file.
    """
    tmp_path = APIKEYS_FILE.with_name(
        f'.{APIKEYS_FILE.name}.{secrets.token_hex(4)}.tmp'
    )
    try:
        APIKEYS_FILE.parent.mkdir(parents=True, exist_ok=True)
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(keys, f, indent=2)
        try:
            # Preserve the existing file's permission bits (it holds secrets).
            os.chmod(tmp_path, APIKEYS_FILE.stat().st_mode & 0o777)
        except FileNotFoundError:
            pass
        os.replace(tmp_path, APIKEYS_FILE)
    except Exception as e:
        print(f'[auth] Failed to save apikeys: {e}')
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass


def get_jwt_secret() -> str:
    """Return the JWT signing secret (``jae_jwt_secret`` in apikeys.json).

    If the file has no secret, one is generated once per process and reused.
    It is persisted whenever the file could be read. If the file exists but is
    unreadable, it is NOT overwritten: doing so would destroy the other API
    keys stored in it.
    """
    global _JWT_SECRET_FALLBACK
    keys = _read_apikeys_file()
    secret = (keys or {}).get('jae_jwt_secret') or ''
    if secret:
        return secret

    if _JWT_SECRET_FALLBACK is None:
        _JWT_SECRET_FALLBACK = secrets.token_hex(32)
        if keys is None:
            print('[auth] apikeys.json unreadable; using an in-memory JWT '
                  'secret that is not persisted')
    if keys is not None:
        keys['jae_jwt_secret'] = _JWT_SECRET_FALLBACK
        _save_apikeys(keys)
    return _JWT_SECRET_FALLBACK


def get_rpc_url() -> str:
    """Return the Solana RPC URL: Helius if a key is configured, else public.

    Checks the ``custom1``..``custom3`` slots first, using the first one whose
    ``name`` contains "helius" and that has a ``key``. Then checks a top-level
    ``helius.api_key`` entry.
    """
    keys = _load_apikeys()
    for slot in ('custom1', 'custom2', 'custom3'):
        entry = keys.get(slot)
        if not isinstance(entry, dict):
            continue
        name = str(entry.get('name') or '').lower()
        if 'helius' in name and entry.get('key'):
            return f"https://mainnet.helius-rpc.com/?api-key={entry['key']}"

    helius = keys.get('helius')
    if isinstance(helius, dict) and helius.get('api_key'):
        return f"https://mainnet.helius-rpc.com/?api-key={helius['api_key']}"
    return PUBLIC_RPC


def fetch_sol_balance(address: str) -> float:
    """Return the SOL balance of *address* as a float.

    Results are cached for ``_BALANCE_TTL`` seconds. On any RPC error this
    returns 0.0 and the error is not cached.
    """
    now = time.time()
    cached = _BALANCE_CACHE.get(address)
    if cached and now - cached[1] < _BALANCE_TTL:
        return cached[0] / _LAMPORTS_PER_SOL

    try:
        r = req.post(
            get_rpc_url(),
            json={'jsonrpc': '2.0', 'id': 1, 'method': 'getBalance',
                  'params': [address]},
            timeout=6,
        )
        r.raise_for_status()
        lamports = int(r.json()['result']['value'])
    except Exception as e:
        # requests errors embed the request URL/path, which may carry the
        # Helius API key in its query string; redact it before logging.
        detail = _API_KEY_RE.sub(r'\1***', str(e))
        print(f'[auth] getBalance failed for {address[:8]}…: {detail}')
        return 0.0

    _BALANCE_CACHE[address] = (lamports, now)
    return lamports / _LAMPORTS_PER_SOL


def _cleanup_nonces():
    """Drop nonces whose expiry time has passed."""
    now = time.time()
    stale = [addr for addr, (_, exp) in _NONCES.items() if exp < now]
    for addr in stale:
        _NONCES.pop(addr, None)


def _valid_solana_address(addr: str) -> bool:
    """True if *addr* is a base58 string that decodes to exactly 32 bytes."""
    if not isinstance(addr, str) or not (32 <= len(addr) <= 44):
        return False
    try:
        raw = base58.b58decode(addr)
    except ValueError:
        return False
    return len(raw) == 32


def _iso_utc(ts: float) -> str:
    """Format epoch seconds as naive-UTC ISO-8601 with a trailing 'Z'."""
    dt = datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
    return dt.replace(tzinfo=None).isoformat() + 'Z'


def _build_message(nonce: str, expires_ts: float) -> str:
    """Build the exact sign-in message the wallet signs.

    /nonce and /verify both use this helper so the two can never disagree.
    """
    return (f'Sign in to jaeswift.xyz\nnonce: {nonce}\n'
            f'expires: {_iso_utc(expires_ts)}')


def _json_body() -> dict:
    """Return the request JSON body if it is an object, else ``{}``."""
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _str_field(data: dict, key: str) -> str:
    """Return ``data[key]`` stripped if it is a string, else ``''``."""
    value = data.get(key)
    return value.strip() if isinstance(value, str) else ''


def _decode_signature(sig: str) -> bytes:
    """Decode a signature supplied as base58 or base64.

    Base58 is tried first and accepted only if it yields 64 bytes. Otherwise
    base64 is used. A wrong-length result is left for ``VerifyKey.verify`` to
    reject. Raises ValueError if base64 decoding also fails.
    """
    try:
        raw = base58.b58decode(sig)
    except ValueError:
        raw = b''
    if len(raw) == _SIGNATURE_LEN:
        return raw
    return base64.b64decode(sig)


def read_session() -> dict | None:
    """Decode the session JWT cookie.

    Returns the payload dict, or None if the cookie is absent, invalid or
    expired.
    """
    token = request.cookies.get(COOKIE_NAME)
    if not token:
        return None
    try:
        return jwt.decode(token, get_jwt_secret(), algorithms=[JWT_ALGO])
    except jwt.InvalidTokenError:
        return None


# ── Endpoints ────────────────────────────────────────────────────────────

@auth_bp.route('/nonce', methods=['POST'])
def issue_nonce():
    """Issue a single-use sign-in nonce (valid ``_NONCE_TTL`` s) for an address.

    Any previous nonce for the same address is replaced.
    """
    address = _str_field(_json_body(), 'address')
    if not _valid_solana_address(address):
        return jsonify({'error': 'Invalid Solana address'}), 400

    _cleanup_nonces()
    nonce = secrets.token_hex(16)
    expires_ts = time.time() + _NONCE_TTL
    _NONCES[address] = (nonce, expires_ts)

    return jsonify({
        'message': _build_message(nonce, expires_ts),
        'nonce': nonce,
        'expires': _iso_utc(expires_ts),
    })


@auth_bp.route('/verify', methods=['POST'])
def verify_signature():
    """Verify a signed sign-in message and set the session cookie.

    Expects JSON ``{"address", "signature", "nonce"}``. *signature* is the
    Ed25519 signature (base58 or base64) over the message returned by /nonce.
    On success the nonce is consumed, the tier is computed from the live SOL
    balance, and a JWT is set as an HttpOnly cookie.
    """
    data = _json_body()
    address = _str_field(data, 'address')
    signature_b58 = _str_field(data, 'signature')
    nonce = _str_field(data, 'nonce')

    if not _valid_solana_address(address):
        return jsonify({'error': 'Invalid address'}), 400
    if not signature_b58 or not nonce:
        return jsonify({'error': 'Missing signature or nonce'}), 400

    stored = _NONCES.get(address)
    if not stored:
        return jsonify({'error': 'No active nonce for this address'}), 401
    stored_nonce, expires_ts = stored
    # Constant-time comparison. Compare bytes because compare_digest rejects
    # non-ASCII str, and the nonce is client-supplied.
    if not secrets.compare_digest(stored_nonce.encode('utf-8'),
                                  nonce.encode('utf-8', 'surrogatepass')):
        return jsonify({'error': 'Nonce mismatch'}), 401
    if time.time() > expires_ts:
        _NONCES.pop(address, None)
        return jsonify({'error': 'Nonce expired'}), 401

    # Rebuild the exact message /nonce handed out for this nonce.
    message = _build_message(nonce, expires_ts)
    try:
        vk = VerifyKey(base58.b58decode(address))
        vk.verify(message.encode('utf-8'), _decode_signature(signature_b58))
    except BadSignatureError:
        return jsonify({'error': 'Signature verification failed'}), 401
    except Exception as e:
        return jsonify({'error': f'Signature decode error: {e}'}), 400

    # Consume the nonce so the signature cannot be replayed.
    _NONCES.pop(address, None)

    balance = fetch_sol_balance(address)
    tier = compute_tier(address, balance)
    ttl = JWT_EXPIRY_HOURS * 3600
    now = int(time.time())
    payload = {
        'address': address,
        'tier': tier,
        'iat': now,
        'exp': now + ttl,
    }
    token = jwt.encode(payload, get_jwt_secret(), algorithm=JWT_ALGO)

    resp = make_response(jsonify({
        'authenticated': True,
        'address': address,
        'tier': tier,
        'balance_sol': round(balance, 4),
        'model': pick_model(tier),
        'expires_in': ttl,
    }))
    resp.set_cookie(
        COOKIE_NAME, token,
        max_age=ttl,
        httponly=True,
        secure=True,
        samesite='Lax',
        path='/',
    )
    return resp


@auth_bp.route('/whoami', methods=['GET'])
def whoami():
    """Describe the current session.

    The tier is recomputed from the (cached) live balance, so it can differ
    from the tier recorded in the JWT at sign-in.
    """
    sess = read_session()
    if not sess:
        return jsonify({
            'authenticated': False,
            'address': None,
            'tier': 'anonymous',
            'model': pick_model('anonymous'),
            'balance_sol': 0,
            'holdings': {},
        })

    address = sess.get('address')
    balance = fetch_sol_balance(address) if address else 0.0
    tier = compute_tier(address, balance)
    return jsonify({
        'authenticated': True,
        'address': address,
        'tier': tier,
        'model': pick_model(tier),
        'balance_sol': round(balance, 4),
        'holdings': {'sol': round(balance, 4)},
        'expires_at': sess.get('exp'),
    })


@auth_bp.route('/logout', methods=['POST'])
def logout():
    """Clear the session cookie (expire it immediately)."""
    resp = make_response(jsonify({'ok': True}))
    resp.set_cookie(COOKIE_NAME, '', expires=0, path='/', secure=True,
                    httponly=True, samesite='Lax')
    return resp