#!/usr/bin/python3
# Copyright (C) 2018 Jelmer Vernooij <jelmer@jelmer.uk>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

import asyncio
import contextlib
import json
import logging
import os
import subprocess
import tempfile
from typing import List, Dict, Tuple, Optional

from urllib.request import urlopen
from urllib.error import HTTPError, URLError

import breezy.bzr
import breezy.git
from breezy.errors import NotBranchError
from breezy.revision import RevisionID
from breezy.trace import note
from breezy.transform import MalformedTransform
from breezy.tree import Tree

from breezy.transport import NoSuchFile
from breezy.workingtree import WorkingTree

from debian.changelog import Changelog, Version

from breezy.plugins.debian.info import versions_dict
from breezy.plugins.debian.upstream import PackageVersionNotPresent
from breezy.plugins.debian.import_dsc import (
    DistributionBranch,
    DistributionBranchSet,
    VersionAlreadyImported,
)
from breezy.plugins.debian.util import MissingChangelogError
from breezy.plugins.debian.apt_repo import (
    LocalApt,
    RemoteApt,
    NoAptSources,
    AptSourceError,
)

# NOTE(review): not referenced anywhere in this file; presumably the branch
# name that callers use for the imported commits -- confirm before relying
# on it.
BRANCH_NAME = "missing-commits"


def connect_udd_mirror():
    """Open a psycopg2 connection to the ``udd`` database.

    The connection targets the host ``udd-mirror.debian.net`` using the
    credentials hard-coded below. ``psycopg2`` is imported lazily so that
    the rest of the module can be used without it installed.

    Returns:
        The connection object returned by ``psycopg2.connect``.
    """
    import psycopg2

    settings = {
        "database": "udd",
        "user": "udd-mirror",
        "password": "udd-mirror",
        "host": "udd-mirror.debian.net",
    }
    return psycopg2.connect(**settings)


def select_vcswatch_packages():
    """List source packages in sid whose vcswatch status is OLD or UNREL.

    Queries the UDD mirror (via ``connect_udd_mirror``) by joining the
    ``vcswatch`` and ``sources`` tables. The VCS URL is selected by the
    query but only the package names are returned.

    Returns:
        A list of source package names, in the order the database
        returned the rows.
    """
    query = """\
    SELECT sources.source, vcswatch.url
    FROM vcswatch JOIN sources ON sources.source = vcswatch.source
    WHERE
     vcswatch.status IN ('OLD', 'UNREL') AND
     sources.release = 'sid'
"""
    # Close both the cursor and the connection once the rows are fetched;
    # previously they were left open until garbage collection.
    with contextlib.closing(connect_udd_mirror()) as conn:
        with contextlib.closing(conn.cursor()) as cursor:
            cursor.execute(query)
            return [package for package, _vcs_url in cursor.fetchall()]


class SnapshotDownloadError(Exception):
    """Raised by ``download_snapshot`` when fetching a URL fails.

    Attributes:
        url: the URL whose retrieval failed.
        inner: the underlying exception that caused the failure.
    """

    def __init__(self, url, inner):
        # Forward both values so ``args`` remains ``(url, inner)``.
        super().__init__(url, inner)
        self.url = url
        self.inner = inner

    def __str__(self):
        # Give the exception a readable message, like the other exception
        # classes in this module, instead of the default tuple repr.
        return "Downloading %s failed: %s" % (self.url, self.inner)


def download_snapshot(package: str, version: Version, output_dir: str) -> str:
    """Download the source files of a package version from snapshot.debian.org.

    The list of files is fetched from the snapshot machine-readable API and
    each file is then downloaded by hash into ``output_dir`` under its base
    name.

    Args:
        package: source package name.
        version: version of the source package to fetch.
        output_dir: existing directory to write the downloaded files to.

    Returns:
        The path in ``output_dir`` where the ``.dsc`` file is expected,
        i.e. ``<package>_<version without epoch>.dsc``. The path is
        computed, not checked for existence.

    Raises:
        SnapshotDownloadError: if the file listing or any file could not
            be retrieved (``HTTPError`` or ``URLError``).
    """
    note("Downloading %s %s", package, version)
    srcfiles_url = (
        "https://snapshot.debian.org/mr/package/%s/%s/"
        "srcfiles?fileinfo=1" % (package, version)
    )
    try:
        # Use the response as a context manager so it is always closed.
        with urlopen(srcfiles_url) as response:
            fileinfo = json.load(response)["fileinfo"]
    except (HTTPError, URLError) as e:
        raise SnapshotDownloadError(srcfiles_url, e) from e
    # Map each file name to the hash it can be downloaded by.
    files = {}
    for hsh, entries in fileinfo.items():
        for entry in entries:
            files[entry["name"]] = hsh
    for filename, hsh in files.items():
        # Only the base name is used, so entries cannot escape output_dir.
        local_path = os.path.join(output_dir, os.path.basename(filename))
        url = "https://snapshot.debian.org/file/%s" % hsh
        note('.. Downloading %s', url)
        try:
            with urlopen(url) as g:
                data = g.read()
        except (HTTPError, URLError) as e:
            raise SnapshotDownloadError(url, e) from e
        # Create the file only after a successful download, so a failed
        # request does not leave an empty file behind.
        with open(local_path, "wb") as f:
            f.write(data)
    # .dsc file names do not include the epoch.
    file_version = Version(version)
    file_version.epoch = None
    dsc_filename = "%s_%s.dsc" % (package, file_version)
    return os.path.join(output_dir, dsc_filename)


class NoMissingVersions(Exception):
    """The archive has no versions that are absent from the VCS.

    Attributes:
        vcs_version: version found in the VCS tree.
        archive_version: version found in the archive.
    """

    def __init__(self, vcs_version, archive_version):
        message = (
            "No missing versions after all. Archive has %s, VCS has %s"
            % (archive_version, vcs_version)
        )
        super().__init__(message)
        self.vcs_version = vcs_version
        self.archive_version = archive_version


class TreeVersionNotInArchiveChangelog(Exception):
    """The version in the tree's changelog is absent from the archive's.

    Attributes:
        tree_version: the version taken from the tree.
    """

    def __init__(self, tree_version):
        message = (
            "tree version %s does not appear in archive changelog" %
            tree_version
        )
        super().__init__(message)
        self.tree_version = tree_version


class TreeUpstreamVersionMissing(Exception):
    """The upstream version referenced by the tree could not be found.

    Attributes:
        upstream_version: the upstream version that was looked up.
    """

    def __init__(self, upstream_version):
        message = "unable to find upstream version %r" % upstream_version
        super().__init__(message)
        self.upstream_version = upstream_version


def import_uncommitted(
        tree: Tree, subpath: str, apt,
        source_version: Optional[str] = None) -> List[Tuple[str, Version, RevisionID]]:
    """Import archive uploads that are missing from the tree's branch.

    The package name and current version are read from ``debian/changelog``
    under ``subpath`` in ``tree``. The source package is then retrieved
    through ``apt`` and unpacked, and its changelog is compared with the
    tree's to work out which versions are missing. Each missing version is
    downloaded with ``download_snapshot`` and imported into the tree's
    branch.

    Args:
        tree: tree to import into; its ``branch`` is used for the import.
        subpath: path inside ``tree`` that contains the ``debian``
            directory.
        apt: context manager providing ``retrieve_source``; in this file
            it is a ``LocalApt`` or ``RemoteApt`` instance.
        source_version: optional version passed on to
            ``apt.retrieve_source``.

    Returns:
        A list of ``(tag_name, version, revision)`` tuples, one per
        imported version, in the order they were imported.

    Raises:
        MissingChangelogError: the tree has no ``debian/changelog``.
        TreeVersionNotInArchiveChangelog: the tree's version does not
            appear in the archive changelog.
        NoMissingVersions: the archive changelog has no entries ahead of
            the tree's version.
        TreeUpstreamVersionMissing: the upstream version of the tree's
            package is not available from the pristine upstream source.
        SnapshotDownloadError: propagated from ``download_snapshot``.
    """
    cl_path = os.path.join(subpath, "debian/changelog")
    try:
        with tree.get_file(cl_path) as f:
            tree_cl = Changelog(f)
            package_name = tree_cl.package
    except NoSuchFile as e:
        raise MissingChangelogError([cl_path]) from e

    # The ExitStack owns the apt context and every temporary directory
    # created below; all of them are released when the block exits.
    with contextlib.ExitStack() as es:
        es.enter_context(apt)
        archive_source = es.enter_context(tempfile.TemporaryDirectory())
        apt.retrieve_source(
            package_name, archive_source, source_version=source_version)
        # Exactly one .dsc is expected; the unpacking assignment raises
        # ValueError otherwise.
        [dsc] = [e.name for e in os.scandir(archive_source)
                 if e.name.endswith('.dsc')]
        note("Unpacking source %s", dsc)
        subprocess.check_output(['dpkg-source', '-x', dsc], cwd=archive_source)
        # Likewise, exactly one unpacked directory is expected.
        [subdir] = [e.path for e in os.scandir(archive_source) if e.is_dir()]
        with open(os.path.join(subdir, "debian", "changelog"), "r") as f:
            archive_cl = Changelog(f)
        # Collect every archive changelog entry that precedes the entry
        # matching the tree's version; those are the missing uploads.
        missing_versions: List[Version] = []
        for block in archive_cl:
            if block.version == tree_cl.version:
                break
            missing_versions.append(block.version)
        else:
            # The loop never hit the tree's version.
            raise TreeVersionNotInArchiveChangelog(tree_cl.version)
        if len(missing_versions) == 0:
            raise NoMissingVersions(tree_cl.version, archive_cl.version)
        note("Missing versions: %s", ", ".join(map(str, missing_versions)))
        ret = []
        dbs = DistributionBranchSet()
        db = DistributionBranch(tree.branch, tree.branch, tree=tree)
        dbs.add_branch(db)
        if tree_cl.version.debian_revision:
            note("Extracting upstream version %s.",
                tree_cl.version.upstream_version)
            upstream_dir = es.enter_context(tempfile.TemporaryDirectory())
            try:
                upstream_tips = db.pristine_upstream_source.version_as_revisions(
                    tree_cl.package, tree_cl.version.upstream_version)
            except PackageVersionNotPresent as e:
                # TODO(jelmer): Should we import it instead?
                raise TreeUpstreamVersionMissing(tree_cl.version.upstream_version) from e
            db.extract_upstream_tree(upstream_tips, upstream_dir)
        # NOTE(review): this path is not joined with ``subpath``, unlike
        # the changelog path above -- confirm that is intended.
        applied_patches = tree.has_filename(".pc/applied-patches")
        # Download everything first, so a download failure happens before
        # any version has been imported.
        version_path: Dict[Version, str] = {}
        for version in missing_versions:
            output_dir = es.enter_context(tempfile.TemporaryDirectory())
            version_path[version] = download_snapshot(
                package_name, version, output_dir)
        # Import in the reverse of changelog order (presumably oldest
        # first, assuming the changelog lists newest entries first).
        for version in reversed(missing_versions):
            note("Importing %s", version)
            dsc_path = version_path[version]
            try:
                tag_name = db.import_package(dsc_path, apply_patches=applied_patches)
            except VersionAlreadyImported as e:
                # Present in the repository, just not on the branch
                note("%s was already imported (tag: %s), just not on the "
                     "branch. Updating tree.", e.version, e.tag_name) 
                tag_name = e.tag_name
                db.tree.update(revision=db.branch.tags.lookup_tag(e.tag_name))
            revision = db.branch.tags.lookup_tag(tag_name)
            ret.append((tag_name, version, revision))
    return ret


def report_fatal(code, description):
    """Report a fatal result and log its description.

    When the environment variable ``SVP_API`` is ``'1'``, a JSON document
    with the tool versions, the result code and the description is written
    to the file named by ``SVP_RESULT``. The description is always logged
    at fatal level.

    Args:
        code: short machine-readable result code.
        description: human-readable explanation.
    """
    svp_enabled = os.environ.get('SVP_API') == '1'
    if svp_enabled:
        result_path = os.environ['SVP_RESULT']
        with open(result_path, 'w') as result_file:
            payload = {
                'versions': versions_dict(),
                'result_code': code,
                'description': description,
            }
            json.dump(payload, result_file)
    logging.fatal('%s', description)


async def main():
    """Command-line entry point: import archive uploads missing from the VCS.

    Parses the command line, picks an APT source (remote when
    ``--apt-repository`` or the ``APT_REPOSITORY``/``REPOSITORIES``
    environment variables are set, otherwise the local configuration),
    opens the working tree containing the current directory and runs
    ``import_uncommitted`` on it. Failures are reported through
    ``report_fatal``. On success, if ``SVP_API`` is ``'1'``, a JSON result
    is written to the file named by ``SVP_RESULT``.

    Returns:
        ``1`` on failure, ``0`` when there was nothing to import, and
        ``None`` after a successful import.
    """
    import argparse

    cli = argparse.ArgumentParser()
    default_repository = (
        os.environ.get('APT_REPOSITORY')
        or os.environ.get('REPOSITORIES'))
    cli.add_argument(
        '--apt-repository', type=str,
        default=default_repository,
        help='APT repository to use. Defaults to locally configured.')
    cli.add_argument(
        '--apt-repository-key', type=str,
        default=os.environ.get('APT_REPOSITORY_KEY'),
        help=('APT repository key to use for validation, '
              'if --apt-repository is set.'))
    cli.add_argument(
        '--version', type=str,
        help='Source version to import')
    options = cli.parse_args()

    logging.basicConfig(level=logging.INFO, format='%(message)s')

    apt = (
        RemoteApt.from_string(
            options.apt_repository, options.apt_repository_key)
        if options.apt_repository
        else LocalApt())

    try:
        local_tree, subpath = WorkingTree.open_containing('.')
    except NotBranchError:
        report_fatal(
            "not-branch-error",
            "Not running in a version-controlled directory")
        return 1

    try:
        imported = import_uncommitted(
            local_tree, subpath, apt, source_version=options.version)
    except AptSourceError as e:
        # The reason may be a list of messages; report only the last one.
        reason = e.reason[-1] if isinstance(e.reason, list) else e.reason
        report_fatal("apt-source-error", reason)
        return 1
    except MissingChangelogError as e:
        report_fatal(
            "missing-changelog",
            "Missing changelog: %s" % e.location[0])
        return 1
    except NoAptSources:
        report_fatal(
            "no-apt-sources",
            "No sources configured in /etc/apt/sources.list")
        return 1
    except TreeUpstreamVersionMissing as e:
        report_fatal("tree-upstream-version-missing", str(e))
        return 1
    except TreeVersionNotInArchiveChangelog as e:
        report_fatal("tree-version-not-in-archive-changelog", str(e))
        return 1
    except NoMissingVersions as e:
        # Nothing to import is reported, but is not an error exit.
        report_fatal("nothing-to-do", str(e))
        return 0
    except SnapshotDownloadError as e:
        report_fatal(
            'snapshot-download-failed',
            'Downloading %s failed: %s' % (e.url, e.inner))
        return 1
    except MalformedTransform as e:
        report_fatal('malformed-transform', str(e))
        return 1

    imported_versions = [str(version) for _tag, version, _rev in imported]

    if os.environ.get('SVP_API') == '1':
        with open(os.environ['SVP_RESULT'], 'w') as result_file:
            result = {
                'description': 'Import archive changes missing from the VCS.',
                'versions': versions_dict(),
                'commit-message': "Import missing uploads: %s." % (
                    ", ".join(imported_versions)),
                'context': {
                    'tags': [
                        (tag_name, str(version))
                        for (tag_name, version, _rev) in imported],
                },
            }
            json.dump(result, result_file)

    note('Imported uploads: %s.', imported_versions)


if __name__ == "__main__":
    import sys

    # Run the async entry point and use its return value as the exit status.
    exit_status = asyncio.run(main())
    sys.exit(exit_status)
