# -*- coding: utf-8 -*-

from __future__ import absolute_import, unicode_literals

import itertools

import resolvelib

import plette
import requirementslib
import vistir

from ..internals.hashes import get_hashes
from ..internals.reporters import StdOutReporter
from ..internals.traces import trace_graph
from ..internals.utils import identify_requirment
from .caches import HashCache
from .metadata import set_metadata
from .providers import BasicProvider, EagerUpgradeProvider, PinReuseProvider


def _get_requirements(model, section_name):
    """Produce a mapping of identifier: requirement from the section.
    """
    if not model:
        return {}
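    # Convert each Pipfile entry into a requirementslib Requirement, keyed by
    # its identifier so the two sections can later be merged key-by-key.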
    return {identify_requirment(r): r for r in (
        requirementslib.Requirement.from_pipfile(name, package._data)
        for name, package in model.get(section_name, {}).items()
    )}


def _get_requires_python(pipfile):
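    # Prefer the exact interpreter version (e.g. "3.6.5") when declared;
    # otherwise fall back to the looser major.minor form (e.g. "3.6"). An
    # empty string means the Pipfile declares no Python requirement.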
    try:
        requires = pipfile.requires
    except AttributeError:
        return ""
    try:
        return requires.python_full_version
    except AttributeError:
        pass
    try:
        return requires.python_version
    except AttributeError:
        return ""


def _collect_derived_entries(state, traces, identifiers):
    """Produce a mapping containing all candidates derived from `identifiers`.

    `identifiers` should be a collection of requirement identifiers from one
    section (i.e. `packages` or `dev-packages`). This function uses `traces`
    to keep only those candidates in the state that are present because of an
    entry in that collection.
    """
    identifiers = set(identifiers)
    if not identifiers:
        return {}

    entries = {}
    extras = {}
    for identifier, requirement in state.mapping.items():
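        # Each trace is a path from the graph root to this candidate; the
        # element at index 1 is the top-level requirement that (directly or
        # transitively) pulled the candidate in.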
        routes = {trace[1] for trace in traces[identifier] if len(trace) > 1}
        if identifier not in identifiers and not (identifiers & routes):
            continue
        name = requirement.normalized_name
        if requirement.extras:
            # Aggregate extras from multiple routes so we can produce their
            # union in the lock file. (sarugaku/passa#24)
            try:
                extras[name].extend(requirement.extras)
            except KeyError:
                extras[name] = list(requirement.extras)
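        # as_pipfile() returns a single-item mapping keyed by the package
        # name; only the entry data is needed here.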
        entries[name] = next(iter(requirement.as_pipfile().values()))
    for name, ext in extras.items():
        entries[name]["extras"] = ext

    return entries


class AbstractLocker(object):
    """Helper class to produce a new lock file for a project.

    This is not intended for instantiation. You should use one of its concrete
    subclasses instead. The class contains logic to:

    * Prepare a project for locking
    * Perform the actual resolver invocation
    * Convert resolver output into lock file format
    * Update the project to have the new lock file
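
    A minimal usage sketch, using one of the concrete subclasses below and
    assuming ``project`` is the project object passed to the constructor::

        locker = BasicLocker(project)
        locker.lock()
        # project.lockfile now holds the newly generated lock file.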
    """
    def __init__(self, project):
        self.project = project
        self.default_requirements = _get_requirements(
            project.pipfile, "packages",
        )
        self.develop_requirements = _get_requirements(
            project.pipfile, "dev-packages",
        )

        # Merge requirements from both sections into a single mapping. The
        # default section comes last in the chain, so its definitions win
        # when a name appears in both sections.
        self.requirements = {k: r for k, r in itertools.chain(
            self.develop_requirements.items(),
            self.default_requirements.items(),
        )}.values()

        self.sources = [s._data.copy() for s in project.pipfile.sources]
        self.allow_prereleases = bool(
            project.pipfile.get("pipenv", {}).get("allow_prereleases", False),
        )
        self.requires_python = _get_requires_python(project.pipfile)

    def __repr__(self):
        return "<{0} @ {1!r}>".format(type(self).__name__, self.project.root)

    def get_provider(self):
        raise NotImplementedError

    def get_reporter(self):
        # TODO: Build SpinnerReporter, and use this only in verbose mode.
        return StdOutReporter(self.requirements)

    def lock(self):
        """Lock specified (abstract) requirements into (concrete) candidates.

        The locking procedure consists of four stages:

        * Resolve versions and dependency graph (powered by ResolveLib).
        * Walk the graph to determine "why" each candidate came to be, i.e.
          what top-level requirements result in a given candidate.
        * Populate hashes for resolved candidates.
        * Populate markers based on dependency specifications of each
          candidate, and the dependency graph.
        """
        provider = self.get_provider()
        reporter = self.get_reporter()
        resolver = resolvelib.Resolver(provider, reporter)

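        # Stage 1: resolve versions and the dependency graph. Resolution runs
        # inside the project root so requirements referencing relative paths
        # resolve against it.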
        with vistir.cd(self.project.root):
            state = resolver.resolve(self.requirements)

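        # Stage 2: walk the dependency graph to record why each candidate is
        # present, i.e. which top-level requirements lead to it.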
        traces = trace_graph(state.graph)

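        # Stage 3: fetch hashes for any candidate that does not already carry
        # them.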
        hash_cache = HashCache()
        for r in state.mapping.values():
            if not r.hashes:
                r.hashes = get_hashes(hash_cache, r)

        set_metadata(
            state.mapping, traces,
            provider.fetched_dependencies,
            provider.collected_requires_pythons,
        )

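        # Split the resolved candidates back into "default" and "develop"
        # sections, based on the top-level requirements each was derived from.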
        lockfile = plette.Lockfile.with_meta_from(self.project.pipfile)
        lockfile["default"] = _collect_derived_entries(
            state, traces, self.default_requirements,
        )
        lockfile["develop"] = _collect_derived_entries(
            state, traces, self.develop_requirements,
        )
        self.project.lockfile = lockfile


class BasicLocker(AbstractLocker):
    """Basic concrete locker.

    This takes a project, generates a lock file from its Pipfile, and sets
    the resulting lock file on the project.
    """
    def get_provider(self):
        return BasicProvider(
            self.requirements, self.sources,
            self.requires_python, self.allow_prereleases,
        )


class PinReuseLocker(AbstractLocker):
    """A specialized locker to handle re-locking based on existing pins.

    See :class:`.providers.PinReuseProvider` for more information.
    """
    def __init__(self, project):
        super(PinReuseLocker, self).__init__(project)
        pins = _get_requirements(project.lockfile, "develop")
        pins.update(_get_requirements(project.lockfile, "default"))
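        # Drop markers carried over from the old lock file; markers are
        # re-computed for the new dependency graph during locking.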
        for pin in pins.values():
            pin.markers = None
        self.preferred_pins = pins

    def get_provider(self):
        return PinReuseProvider(
            self.preferred_pins, self.requirements, self.sources,
            self.requires_python, self.allow_prereleases,
        )


class EagerUpgradeLocker(PinReuseLocker):
    """A specialized locker to handle the "eager" upgrade strategy.

    See :class:`.providers.EagerUpgradeProvider` for more information.
    """
    def __init__(self, tracked_names, *args, **kwargs):
        super(EagerUpgradeLocker, self).__init__(*args, **kwargs)
        self.tracked_names = tracked_names
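        # Requirements named here (and, following the eager strategy, their
        # dependencies) are re-resolved instead of reusing existing pins.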

    def get_provider(self):
        return EagerUpgradeProvider(
            self.tracked_names, self.preferred_pins,
            self.requirements, self.sources,
            self.requires_python, self.allow_prereleases,
        )