#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.11"
# dependencies = ["httpx>=0.27", "typer>=0.12", "keyring>=25"]
# ///
"""sprite-cp — scp-style file transfers via the Sprites API.

Usage:  sprite-cp [OPTIONS] SOURCE... DEST

  SOURCE and DEST use scp path syntax: a path is remote when it contains
  ':' before the first '/', e.g. mysprite:/home/user/file.txt.  All other
  paths are local; prefix with './' to force-local a path with a colon.

Examples:
  sprite-cp ./hello.txt       mysprite:/tmp/          # upload
  sprite-cp mysprite:/tmp/log ./log                   # download
  sprite-cp -r ./src          mysprite:/home/user/src # recursive upload
  sprite-cp -r mysprite:/data ./data                  # recursive download
"""

from __future__ import annotations

import asyncio
import json
import os
import re
import stat
import sys
from pathlib import Path
from typing import Annotated, AsyncIterator, Optional
from urllib.parse import quote

import httpx
import typer

__version__ = "0.2.0"
DEFAULT_BASE_URL = "https://api.sprites.dev"
_CHUNK = 1024 * 1024  # stream file bodies in 1 MiB pieces
# Environment variables checked for a bearer token, in priority order.
_TOKEN_ENVS = (
    "SPRITE_TOKEN",
    "SPRITE_API_TOKEN",
    "FLY_API_TOKEN",
    "FLY_ACCESS_TOKEN",
)
_DEFAULT_CONCURRENCY = 8  # parallel transfers when -j is not given

# Plain tracebacks (no Rich rendering) and no shell-completion options.
app = typer.Typer(pretty_exceptions_enable=False, add_completion=False)


# ── Auth ──────────────────────────────────────────────────────────────────────

def _fly_config_token() -> tuple[str | None, str]:
    d = os.environ.get("FLY_CONFIG_DIR") or os.path.expanduser("~/.fly")
    path = os.path.join(d, "config.yml")
    try:
        s = os.stat(path)
        if s.st_mode & (stat.S_IRGRP | stat.S_IROTH):
            print(f"Warning: {path} is group/world-readable (chmod 600 recommended)", file=sys.stderr)
        with open(path) as f:
            for line in f:
                m = re.match(r"^\s*access_token:\s*['\"]?([^'\"#\s]+)", line)
                if m:
                    return m.group(1), path
    except OSError:
        pass
    return None, path


def _sprites_config_token() -> tuple[str | None, str]:
    """Resolve a token from the on-disk state of the official `sprite` CLI.

    Layout (macOS):
      ~/.sprites/sprites.json
        -> current_user, current_selection{url,org}, users[], path_history{}
      ~/.sprites/users/<id>-<hash>.json
        -> urls[<url>].orgs[<org>].keyring_key
      macOS Keychain generic-password:
        service = "sprites-cli:<userID>",  account = keyring_key
    """
    sprites_home = Path(os.environ.get("SPRITES_HOME") or os.path.expanduser("~/.sprites"))
    state_path = sprites_home / "sprites.json"
    try:
        state = json.loads(state_path.read_text())
    except (OSError, json.JSONDecodeError):
        return None, str(state_path)

    user_id = state.get("current_user")
    users = {u.get("id"): u for u in state.get("users", []) if u.get("id")}
    user = users.get(user_id) if user_id else None
    if not user:
        return None, str(state_path)

    # Pick the org the way `sprite` does: longest matching path_history prefix
    # of cwd, else current_selection.
    cwd = os.getcwd()
    history = state.get("path_history", {}) or {}
    best = max(
        (p for p in history if cwd == p or cwd.startswith(p.rstrip("/") + "/")),
        key=len,
        default=None,
    )
    sel = history[best] if best else (state.get("current_selection") or {})
    api_url = sel.get("url") or "https://api.sprites.dev"
    org = sel.get("org")
    if not org:
        return None, str(state_path)

    # Per-user config holds the exact keyring account string.
    user_cfg_path = user.get("config_path") or ""
    try:
        user_cfg = json.loads(Path(user_cfg_path).read_text())
        keyring_key = user_cfg["urls"][api_url]["orgs"][org]["keyring_key"]
    except (OSError, KeyError, json.JSONDecodeError):
        keyring_key = f"sprites:org:{api_url}:{org}"

    service = f"sprites-cli:{user_id}"
    try:
        import keyring
        from keyring.errors import KeyringError
        tok = keyring.get_password(service, keyring_key)
    except (ImportError, KeyringError):
        tok = None
    # zalando/go-keyring wraps stored values as "go-keyring-base64:<b64>" to
    # dodge a macOS Keychain bug with non-ASCII bytes. Python's keyring lib
    # doesn't unwrap; do it here.
    PREFIX = "go-keyring-base64:"
    if tok and tok.startswith(PREFIX):
        import base64
        try:
            tok = base64.b64decode(tok[len(PREFIX):], validate=True).decode("utf-8")
        except (ValueError, UnicodeDecodeError):
            tok = None

    if tok:
        return tok, f"keyring ({service} / {org}@{api_url})"
    return None, f"keyring miss ({service} / {keyring_key})"


def resolve_token(cli_token: str | None) -> tuple[str, str]:
    """Return (token, source_description). Exits if no token is found."""
    if cli_token:
        return cli_token, "--token"
    for var in _TOKEN_ENVS:
        if val := os.environ.get(var):
            return val, f"${var}"
    for fn in (_sprites_config_token, _fly_config_token):
        tok, src = fn()
        if tok:
            return tok, src
    typer.echo("No auth token found. Run `flyctl auth login` or set SPRITE_TOKEN.", err=True)
    raise typer.Exit(1)


# ── Path parsing ──────────────────────────────────────────────────────────────

def parse_path(s: str) -> tuple:
    """Classify an scp-style path argument.

    Returns ``("local", s)`` or ``("remote", sprite_name, path)``.

    ``s`` is remote only when the text before its first ':' is non-empty and
    contains no '/'.  Write ``./name:with:colons`` to keep such a path local.
    A leading ``org@`` on the sprite name is accepted and dropped, because
    the org comes from the token.  An empty remote path means ``/``.
    """
    head, colon, tail = s.partition(":")
    is_remote = bool(colon) and bool(head) and "/" not in head
    if not is_remote:
        return "local", s
    sprite = head.rpartition("@")[2]
    return "remote", sprite, tail or "/"


# ── Sprites HTTP client (async) ───────────────────────────────────────────────

def _error_msg(body: str | bytes) -> str:
    text = body.decode("utf-8", "replace") if isinstance(body, bytes) else body
    try:
        return json.loads(text).get("error", text[:200])
    except (ValueError, AttributeError):
        return text[:200]


class APIError(Exception):
    """A Sprites API call failed; ``str(exc)`` is the message shown to the user."""


class SpriteClient:
    """Async client for a sprite's filesystem endpoints (``/v1/sprites/<name>/fs/...``).

    Use as ``async with SpriteClient(token) as client:`` so the pooled HTTP
    connections are closed on exit.  Unsuccessful HTTP statuses and unusable
    listing bodies raise :class:`APIError`.  Transport failures (connect,
    read, TLS, …) propagate as ``httpx`` exceptions.
    """

    def __init__(self, token: str, base_url: str = DEFAULT_BASE_URL) -> None:
        self._base = base_url.rstrip("/")
        # Token lives only in this header — never in URLs, query params, or logs.
        # Read/write timeouts are None (unbounded), so a slow transfer is never
        # timed out; connecting and waiting for a pooled connection are capped.
        self._http = httpx.AsyncClient(
            headers={"Authorization": f"Bearer {token}"},
            timeout=httpx.Timeout(connect=10.0, read=None, write=None, pool=10.0),
        )

    async def __aenter__(self) -> "SpriteClient":
        return self

    async def __aexit__(self, *exc_info) -> None:
        await self._http.aclose()

    def _url(self, sprite: str, endpoint: str) -> str:
        """Build the URL of fs ``endpoint`` for ``sprite`` (name percent-encoded)."""
        return f"{self._base}/v1/sprites/{quote(sprite, safe='')}/fs/{endpoint}"

    def _raise(self, r: httpx.Response, op: str, path: str) -> None:
        """Raise APIError unless ``r`` has status 200 or 201."""
        if r.status_code in (200, 201):
            return
        raise APIError(f"{op} {path!r}: HTTP {r.status_code}: {_error_msg(r.text)}")

    async def _list_raw(self, sprite: str, path: str) -> dict:
        """GET ``fs/list`` for ``path`` and return the decoded JSON object.

        Raises APIError on a bad status, or when the body is not a JSON object.
        Without this check, bad bodies would escape as ValueError/AttributeError.
        """
        r = await self._http.get(self._url(sprite, "list"), params={"path": path})
        self._raise(r, "list", path)
        try:
            data = r.json()
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            raise APIError(f"list {path!r}: response is not valid JSON") from exc
        if not isinstance(data, dict):
            raise APIError(f"list {path!r}: unexpected response ({type(data).__name__}, expected object)")
        return data

    async def list_dir(self, sprite: str, path: str) -> list[dict]:
        """Return the ``entries`` of a listing (empty list if absent)."""
        return (await self._list_raw(sprite, path)).get("entries", [])

    async def stat(self, sprite: str, path: str) -> dict | None:
        """Return the entry for a file path, or None if path is a directory.

        The server returns the file's own entry when given a file path, and the
        directory's children when given a directory path.  We detect the
        difference by comparing the response's canonical path to entries[0]:
        they match for a file, they differ for a directory (entries are children).
        """
        data = await self._list_raw(sprite, path)
        entries = data.get("entries", [])
        canonical = data.get("path", "").rstrip("/")
        if len(entries) == 1 and entries[0]["path"].rstrip("/") == canonical:
            return entries[0]
        return None

    async def walk(self, sprite: str, root: str) -> AsyncIterator[dict]:
        """Depth-first walk of a remote directory, yielding file entries only."""
        for entry in await self.list_dir(sprite, root):
            if entry["isDir"]:
                async for child in self.walk(sprite, entry["path"]):
                    yield child
            else:
                yield entry

    async def upload(self, sprite: str, src: Path, dst: str) -> None:
        """Stream local ``src`` to remote ``dst``, keeping its permission bits.

        The request asks the server to create missing parent directories
        (``mkdirParents``).  The body is streamed in ``_CHUNK`` pieces, with
        Content-Length taken from ``src.stat()``.
        """
        st = src.stat()
        mode = oct(st.st_mode & 0o777)[2:]

        async def _chunks() -> AsyncIterator[bytes]:
            # File I/O runs in worker threads so the event loop stays free.
            fh = await asyncio.to_thread(src.open, "rb")
            try:
                while data := await asyncio.to_thread(fh.read, _CHUNK):
                    yield data
            finally:
                await asyncio.to_thread(fh.close)

        r = await self._http.put(
            self._url(sprite, "write"),
            params={"path": dst, "mode": mode, "mkdirParents": "true"},
            content=_chunks(),
            headers={"Content-Type": "application/octet-stream",
                     "Content-Length": str(st.st_size)},
        )
        self._raise(r, "upload", dst)

    async def download(self, sprite: str, src: str, dst: Path) -> None:
        """Stream remote ``src`` into local ``dst``, creating parent directories.

        ``dst`` is opened (truncated or created) only after the server replies
        with a success status.  If anything fails after that, the partially
        written ``dst`` is removed so no truncated file is left looking like a
        finished copy.  Failures include a network error, cancellation because
        a sibling transfer failed, or a local write error.
        """
        dst.parent.mkdir(parents=True, exist_ok=True)
        async with self._http.stream("GET", self._url(sprite, "read"), params={"path": src}) as r:
            if r.status_code not in (200, 206):
                body = await r.aread()
                raise APIError(f"download {src!r}: HTTP {r.status_code}: {_error_msg(body)}")
            fh = await asyncio.to_thread(dst.open, "wb")
            complete = False
            try:
                try:
                    async for chunk in r.aiter_bytes(chunk_size=_CHUNK):
                        await asyncio.to_thread(fh.write, chunk)
                finally:
                    await asyncio.to_thread(fh.close)
                complete = True  # only after a successful close (final flush)
            finally:
                if not complete:
                    try:
                        dst.unlink()
                    except OSError:
                        pass  # best-effort cleanup; the original error propagates


# ── Transfer helpers ──────────────────────────────────────────────────────────

def safe_join(root: Path, untrusted: str) -> Path:
    """Join a server-supplied relative path onto ``root`` without escaping it.

    The joined path is fully resolved (``..`` components and symlinks) and
    must lie inside ``root.resolve()``.  The resolved path is returned.
    Anything else prints an error and exits with status 1.
    """
    anchor = root.resolve()
    candidate = (root / untrusted).resolve()
    if candidate.is_relative_to(anchor):
        return candidate
    typer.echo(f"sprite-cp: path traversal rejected: {untrusted!r}", err=True)
    raise typer.Exit(1)


def _log(msg: str, quiet: bool) -> None:
    if not quiet:
        print(msg)


async def upload_file(
    client: SpriteClient, src: Path, sprite: str, dst: str, quiet: bool, sem: asyncio.Semaphore,
) -> None:
    """Upload one local file while holding a slot of ``sem``, then report it."""
    await sem.acquire()
    try:
        await client.upload(sprite, src, dst)
        line = f"{src} -> {sprite}:{dst}"
        _log(line, quiet)
    finally:
        sem.release()


async def upload_tree(
    client: SpriteClient, src: Path, sprite: str, dst: str, quiet: bool, sem: asyncio.Semaphore,
) -> None:
    """Upload every file under local directory ``src`` to ``dst`` on ``sprite``.

    Each file goes to ``dst/<path relative to src>``.  That relative part is
    always '/'-separated, whatever separator the local OS uses.  One task is
    started per file, and ``sem`` bounds how many run at once.

    Traversal is sorted and does not descend into symlinked directories.
    Symlinks to files are listed as files and uploaded with their target's
    content.  Empty directories are not created remotely.

    A directory that cannot be listed aborts the whole copy with its
    OSError.  Otherwise ``os.walk`` would skip it silently and the copy
    would look complete.
    """
    def _walk_error(err: OSError) -> None:
        raise err

    remote_root = dst.rstrip("/")
    async with asyncio.TaskGroup() as tg:
        for dirpath, dirs, files in os.walk(src, followlinks=False, onerror=_walk_error):
            dirs.sort()  # in place: os.walk descends in this order
            for fname in sorted(files):
                local = Path(dirpath) / fname
                # as_posix(): remote paths must use '/' even on Windows.
                rel = local.relative_to(src).as_posix()
                tg.create_task(
                    upload_file(client, local, sprite, f"{remote_root}/{rel}", quiet, sem)
                )


async def download_file(
    client: SpriteClient, sprite: str, src: str, dst: Path, quiet: bool, sem: asyncio.Semaphore,
) -> None:
    """Download one remote file while holding a slot of ``sem``, then report it."""
    await sem.acquire()
    try:
        await client.download(sprite, src, dst)
        line = f"{sprite}:{src} -> {dst}"
        _log(line, quiet)
    finally:
        sem.release()


async def download_tree(
    client: SpriteClient, sprite: str, src: str, dst: Path, quiet: bool, sem: asyncio.Semaphore,
) -> None:
    """Recursively download remote directory ``src`` into local directory ``dst``.

    Each remote file lands at ``dst/<its path relative to src>``.  Entry
    paths come back from the server and may not begin with ``src`` as the
    user typed it, e.g. for a relative ``mysprite:data`` or ``/a//b``.  So the
    relative part is taken against the path the listing reports for ``src``
    first, then against ``src`` itself.  Every relative path still goes
    through ``safe_join``, so a hostile listing cannot write outside ``dst``.

    Only files are downloaded; empty remote directories are not recreated.
    """
    dst.mkdir(parents=True, exist_ok=True)
    # One listing serves two purposes: its "path" is the server's spelling of
    # src, and its entries seed the walk (so the top level isn't listed twice).
    listing = await client._list_raw(sprite, src)
    prefixes = [
        root.rstrip("/") + "/"
        for root in (listing.get("path"), src)
        if isinstance(root, str) and root
    ]

    def _relative(remote: str) -> str:
        """Strip the first matching root prefix; leading '/' never survives."""
        for prefix in prefixes:
            if remote.startswith(prefix):
                return remote[len(prefix):].lstrip("/")
        return remote.lstrip("/")

    async def _files() -> AsyncIterator[dict]:
        """Same traversal as SpriteClient.walk, reusing the listing above."""
        for entry in listing.get("entries", []):
            if entry["isDir"]:
                async for child in client.walk(sprite, entry["path"]):
                    yield child
            else:
                yield entry

    async with asyncio.TaskGroup() as tg:
        async for entry in _files():
            remote = entry["path"]
            local = safe_join(dst, _relative(remote))
            tg.create_task(download_file(client, sprite, remote, local, quiet, sem))


# ── CLI ───────────────────────────────────────────────────────────────────────

def _version_cb(val: bool) -> None:
    if val:
        typer.echo(f"sprite-cp {__version__}")
        raise typer.Exit()


def _first_api_error(eg: BaseException) -> APIError | None:
    """Return the first APIError found depth-first in ``eg``, or None.

    ``eg`` may be a bare exception or an arbitrarily nested
    (Base)ExceptionGroup; children are searched in their stored order.
    """
    pending: list[BaseException] = [eg]
    while pending:
        current = pending.pop()
        if isinstance(current, APIError):
            return current
        if isinstance(current, BaseExceptionGroup):
            # Reversed so the group's first child is popped (searched) first.
            pending.extend(reversed(current.exceptions))
    return None


@app.command(no_args_is_help=True)
def main(
    paths: Annotated[list[str], typer.Argument(help="SOURCE [SOURCE...] DEST  (SPRITE:PATH for remote)")],
    recursive:        Annotated[bool,         typer.Option("-r", "--recursive",         help="Copy directories recursively")] = False,
    quiet:            Annotated[bool,         typer.Option("-q", "--quiet",             help="Suppress per-file output")] = False,
    concurrency:      Annotated[int,          typer.Option("-j", "--concurrency",       help=f"Max parallel transfers (default {_DEFAULT_CONCURRENCY})", min=1, max=64)] = _DEFAULT_CONCURRENCY,
    token:            Annotated[Optional[str], typer.Option("--token",                  help="Auth token (overrides env / config)")] = None,
    api_url:          Annotated[Optional[str], typer.Option("--api-url", envvar="SPRITE_URL", hidden=True)] = None,
    show_token_source:Annotated[bool,         typer.Option("--show-token-source",       hidden=True)] = False,
    version:          Annotated[Optional[bool], typer.Option("--version", callback=_version_cb, is_eager=True, is_flag=True)] = None,
) -> None:
    """Copy files to or from a sprite, scp-style.

    \b
    Remote paths use SPRITE:PATH syntax, e.g. mysprite:/home/user/file.
    Prefix with ./ to force-local a path that contains a colon.
    Recursive upload skips symlinked directories (scp -r follows them);
    symlinked files are uploaded as regular files.
    """
    if len(paths) < 2:
        typer.echo("sprite-cp: need at least SOURCE and DEST", err=True)
        raise typer.Exit(1)

    # The last positional is DEST; everything before it is a SOURCE.
    *srcs, dest = paths
    pd = parse_path(dest)
    ps_list = [parse_path(s) for s in srcs]

    # Exactly one side must be remote: all sources on one side, DEST on the other.
    if len({p[0] for p in ps_list}) > 1:
        typer.echo("sprite-cp: cannot mix local and remote sources", err=True)
        raise typer.Exit(1)

    uploading = ps_list[0][0] == "local"
    if uploading and pd[0] != "remote":
        typer.echo("sprite-cp: local-to-local not supported (use cp)", err=True)
        raise typer.Exit(1)
    if not uploading and pd[0] != "local":
        typer.echo("sprite-cp: remote-to-remote not supported (use 'sprite exec cp')", err=True)
        raise typer.Exit(1)

    # Token lookup happens only after the arguments are known to be usable.
    tok, tok_src = resolve_token(token)

    if show_token_source:
        typer.echo(tok_src)
        raise typer.Exit()

    asyncio.run(_amain(
        uploading=uploading, ps_list=ps_list, pd=pd, srcs=srcs,
        recursive=recursive, quiet=quiet, concurrency=concurrency,
        token=tok, api_url=api_url or DEFAULT_BASE_URL,
    ))


async def _amain(
    *, uploading: bool, ps_list: list[tuple], pd: tuple, srcs: list[str],
    recursive: bool, quiet: bool, concurrency: int, token: str, api_url: str,
) -> None:
    sem = asyncio.Semaphore(concurrency)
    try:
        async with SpriteClient(token, api_url) as client:
            async with asyncio.TaskGroup() as tg:
                if uploading:
                    _, d_sprite, d_path = pd
                    for ps in ps_list:
                        src_path = Path(ps[1])
                        if src_path.is_dir():
                            if not recursive:
                                typer.echo(f"sprite-cp: {src_path}: is a directory (use -r)", err=True)
                                raise typer.Exit(1)
                            # dest/srcname if dest ends with /; dest as rename otherwise
                            dst = d_path.rstrip("/") + "/" + src_path.name if d_path.endswith("/") else d_path
                            tg.create_task(upload_tree(client, src_path, d_sprite, dst, quiet, sem))
                        else:
                            # append basename when dest is explicitly a dir path or has multiple sources
                            dst = d_path.rstrip("/") + "/" + src_path.name if (d_path.endswith("/") or len(srcs) > 1) else d_path
                            tg.create_task(upload_file(client, src_path, d_sprite, dst, quiet, sem))
                else:
                    d_path = Path(pd[1])
                    for ps in ps_list:
                        _, s_sprite, s_path = ps
                        # stat is sequential (cheap, decides which task to spawn)
                        entry = await client.stat(s_sprite, s_path)
                        is_file = entry is not None
                        if is_file:
                            local = d_path / Path(s_path).name if (d_path.is_dir() or pd[1].endswith("/")) else d_path
                            tg.create_task(download_file(client, s_sprite, s_path, local, quiet, sem))
                        else:
                            if not recursive:
                                typer.echo(f"sprite-cp: {s_sprite}:{s_path}: is a directory (use -r)", err=True)
                                raise typer.Exit(1)
                            # into dest/srcname if dest exists or ends with /; rename otherwise
                            if d_path.exists() or pd[1].endswith("/"):
                                dst_dir = d_path / Path(s_path.rstrip("/")).name
                            else:
                                dst_dir = d_path
                            tg.create_task(download_tree(client, s_sprite, s_path, dst_dir, quiet, sem))

    except* typer.Exit:
        raise typer.Exit(1)
    except* APIError as eg:
        first = _first_api_error(eg) or APIError("unknown error")
        typer.echo(f"sprite-cp: {first}", err=True)
        raise typer.Exit(1)


if __name__ == "__main__":
    # Entry point when the file is executed as a script: hand argv to Typer.
    app()
