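"""
AdaptiveVariationalProblem: adaptive mesh refinement for variational
problems, driven by goal-oriented or norm-based error estimates.
"""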

__author__ = "Marie E. Rognes (meg@simula.no)"
__copyright__ = "Copyright (C) 2009 - 2010 Marie E. Rognes"
__license__  = "GNU GPL version 3 or any later version"

# First added:  2009-09-19
# Last changed: 2010-06-22
#
# Modified by Anders Logg, 2009-2010

from time import time

from dolfin import rhs, lhs, derivative, action, assemble, refine
from dolfin import VariationalProblem, TrialFunction
from dolfin import plot, info, info_green, info_red, error, interactive, warning

from dolfin.cpp import Parameters, BoundaryMesh, File

from ufl.algorithms import preprocess

from dolfin.adaptivity.adaptivedata import AdaptiveData
from dolfin.adaptivity.marking import mark_cells
from dolfin.adaptivity.errorestimators import *
from dolfin.adaptivity.normestimators import *
from dolfin.adaptivity.updates import *
from dolfin.adaptivity.utils import *
from dolfin.adaptivity.formmanipulation import extract_mesh

# Public API of this module
__all__ = ["AdaptiveVariationalProblem"]

# Lookup table from the names accepted by the "estimator" and "indicator"
# entries of the "error_estimation" parameters to the classes that
# implement them (the classes come from the estimator modules imported
# above).
estimator_classes = dict(
    error_representation=ErrorRepresentationEstimator,
    dual_weighted_residual=DualWeightedResidualEstimator,
    both_residuals=BothResidualsEstimator,
    cauchy_schwarz=CauchySchwarzEstimator,
    becker_rannacher=BeckerRannacherEstimator,
    energy_norm=EnergyNormEstimator,
    bank_weiser=BankWeiserEstimator,
)

class AdaptiveVariationalProblem:
    """
    Adaptive solver for a (possibly nonlinear) variational problem.

    The problem is solved repeatedly on a sequence of successively
    refined meshes following the loop

        SOLVE -- ESTIMATE -- INDICATE -- MARK -- REFINE

    until the stopping criterion is met or max_iterations is reached.
    Error estimation is goal-oriented when a goal functional is given and
    (experimentally) norm-based otherwise.
    """

    def __init__(self, F, bcs=None, goal_functional=None,
                 u=None,
                 reference=None,
                 goal_exterior_domain=None,
                 snapping_boundary=None):
        """
        Create an adaptive variational problem.

        *Arguments*
            F
                The variational form. If u is None, F is split into
                lhs(F) and rhs(F) and solved as a linear problem;
                otherwise F is treated as a nonlinear residual in u.
            bcs
                Boundary conditions (optional).
            goal_functional
                Goal functional for goal-oriented error estimation
                (optional). When absent, energy-norm estimators are the
                default.
            u
                Unknown Function of a nonlinear problem (optional).
            reference
                Optional reference value, stored in the adaptivity data.
                When given, final_summary also saves the data to file.
            goal_exterior_domain
                Optional exterior facet domain. It is converted to a mesh
                function on each mesh and used when assembling the goal
                functional.
            snapping_boundary
                Optional boundary passed to mesh.snap_boundary after each
                refinement.
        """

        # Store variational problem
        self.F = F
        self.u = u
        self.bcs = bcs
        self.goal_exterior_domain = goal_exterior_domain
        self.snapping_boundary = snapping_boundary

        # Store goal and (optional) reference value
        self.goal_functional = goal_functional
        self.reference = reference

        # Create parameter set (error estimation defaults depend on goal)
        self.parameters = self.default_parameters(goal_functional)

        # Create container for mesh hierarchy
        self.meshes = []

        # Create containers for data produced during adaptivity
        self.goal_exterior_facets = None
        self.data = []

    def default_parameters(self, goal):
        """
        Initialize default parameters.

        Without a goal, the energy-norm estimator and indicator are
        selected. With a goal, error representation and dual weighted
        residual are selected, with an extrapolation dual strategy.
        """
        parameters = Parameters("adaptive_variational_problem")
        parameters.add("max_iterations", 20)
        parameters.add("plot_indicators", True)
        parameters.add("save_indicators", "indicators")
        parameters.add("plot_results", True)
        parameters.add("linear_solver", "direct")
        parameters.add("stopping_criterion", "tolerance")

        # Error estimation parameters
        error_estimators = Parameters("error_estimation")
        if not goal:
            error_estimators.add("estimator", "energy_norm")
            error_estimators.add("indicator", "energy_norm")
        else:
            error_estimators.add("estimator", "error_representation")
            error_estimators.add("indicator", "dual_weighted_residual")
            error_estimators.add("dual_strategy", "extrapolation")
        parameters.add(error_estimators)

        # Mesh marking parameters
        refinements = Parameters("marking")
        refinements.add("strategy", "dorfler")
        refinements.add("fraction", 0.5)
        refinements.add("tolerance", 0.0)
        parameters.add(refinements)

        return parameters

    def solve(self, tolerance):
        """
        Solve until the stopping criterion is met and return the solution.

        SOLVE -- ESTIMATE -- INDICATE -- MARK -- REFINE

        Each iteration records timings and estimates in an AdaptiveData
        object that is appended to self.data.

        *Arguments*
            tolerance
                With stopping_criterion "tolerance", the loop stops when
                abs(error estimate) < tolerance. With "dimension", it
                stops when the number of degrees of freedom exceeds
                tolerance. The value is also stored as the marking
                tolerance.

        *Returns*
            The last computed primal solution. This is also returned when
            max_iterations is reached without meeting the criterion.
        """

        # Validate the iteration limit first: with zero iterations there
        # would be no solution to return.
        max_iterations = self.parameters["max_iterations"]
        if max_iterations < 1:
            error("max_iterations must be at least 1, got %d" % max_iterations)

        # Display current parameters
        self.summary()

        # Select error estimator and indicator classes
        error_params = self.parameters["error_estimation"]
        Estimator = estimator_classes[error_params["estimator"]]
        Indicator = estimator_classes[error_params["indicator"]]

        # Select marking strategy and inform of tolerance
        self.parameters["marking"]["tolerance"] = tolerance
        mark = lambda m, indicators: mark_cells(m, indicators,
                                                self.parameters["marking"])

        # Extract mesh and store
        mesh = extract_mesh(self.F)
        self.meshes.append(mesh)

        # File collecting the error indicators of all iterations
        indicator_file = File("%s.pvd" % self.parameters["save_indicators"])

        # Adaptive loop
        u_h = None
        for i in range(max_iterations):

            print_iteration(i)

            # Create storage for adaptivity meta data
            data = AdaptiveData()

            # Convert exterior_facets to mesh function on the current mesh
            self.goal_exterior_facets = domain_to_mf(self.goal_exterior_domain, mesh)

            # (1) Solve primal problem
            print_stage("(%d.1) Solving primal problem" % i)
            t0 = time()
            u_h = self.solve_primal(mesh)
            data.t_primal = time() - t0

            # (2) Estimate error
            print_stage("(%d.2) Estimating error" % i)
            t0 = time()
            if not self.goal_functional:
                warning("Norm-based error estimation is experimental. Proceed at own risk.")
                estimator = Estimator(self.F, error_params)
            else:
                estimator = Estimator(self.F, self.bcs, self.goal_functional,
                                      self.u, error_params,
                                      self.goal_exterior_facets)
            error_estimate = estimator.estimate_error(u_h)
            data.t_error_estimation = time() - t0

            # Save various data
            data.functional_value = self.evaluate_goal(u_h)
            data.tolerance = tolerance
            data.error_estimate = error_estimate
            data.num_cells = mesh.num_cells()
            data.num_dofs = u_h.function_space().dim()
            data.reference = self.reference
            data.refinement_level = i
            data.summary()
            self.data += [data]

            # (2b) Check stopping criterion
            if self.done(tolerance, error_estimate, u_h):
                self.final_summary()
                info_green("Solution computed to within given tolerance, stopping")
                return u_h
            elif i + 1 >= max_iterations:
                self.final_summary()
                info_red("Reached maximum number of iterations (%d): error estimate = %g > %g" % \
                         (max_iterations, error_estimate, tolerance))
                return u_h

            # (3) Assemble error indicators
            print_stage("(%d.3) Computing error indicators" % i)
            t0 = time()
            indicator = Indicator(estimator)
            error_indicators = indicator.assemble_error_indicators(u_h)
            t_error_indicators = time() - t0

            # Store indicators (data is already in self.data, so later
            # attribute assignments are still recorded)
            data.error_indicators = error_indicators
            data.sum_indicators = indicator.estimate_error(u_h)
            mf = convert_to_mesh_function(error_indicators)
            indicator_file << mf
            if self.parameters["plot_indicators"]:
                plot(mf, title="Error indicators (refinement level %d)" % i)

            print_value_red("Sum error indicators: ", data.sum_indicators)

            # (4) Mark
            print_stage("(%d.4) Marking cells for refinement" % i)
            t0 = time()
            cell_markers = mark(mesh, error_indicators)
            t_marking = time() - t0

            # (5) Refine mesh
            print_stage("(%d.5) Refining mesh" % i)
            t0 = time()
            mesh = refine(mesh, cell_markers)
            t_refinement = time() - t0

            # Store mesh
            self.meshes.append(mesh)
            meshfile = File("%s_mesh%d.xml" % (self.parameters["save_indicators"], i+1))
            meshfile << mesh

            # Smooth boundary
            if self.snapping_boundary:
                mesh.snap_boundary(self.snapping_boundary, False)

            # Update forms etc after mesh refinement
            self._update(mesh, u_h.function_space())

            # Update adaptivity data
            data.t_error_indicators = t_error_indicators
            data.t_marking = t_marking
            data.t_refinement = t_refinement

        # Defensive: the loop always returns once max_iterations >= 1
        return u_h

    def done(self, tolerance, error_estimate, u_h):
        """
        Return True if the configured stopping criterion is satisfied.

        "tolerance": abs(error_estimate) < tolerance.
        "dimension": dimension of u_h's function space > tolerance.
        Any other criterion is reported through dolfin's error().
        """
        is_done = False
        if self.parameters["stopping_criterion"] == "tolerance":
            is_done = abs(error_estimate) < tolerance
        elif self.parameters["stopping_criterion"] == "dimension":
            is_done = u_h.function_space().dim() > tolerance
        else:
            error("Unknown stopping criterion: %s" % self.parameters["stopping_criterion"])

        return is_done

    def _update(self, mesh, V_h):
        """ Update all variables common for the loop of the
        variational problem to the mesh 'mesh'.

        V_h is the function space of the solution on the previous mesh.
        It is passed on to update_bcs.
        """

        # Update F; w is the counterpart of self.u returned by update_form
        (self.F, w) = update_form(self.F, mesh, extract=self.u)

        # Update goal-functional, exchanging the old unknown for w
        if self.goal_functional:
            info("--- Updating goal functional")
            self.goal_functional = update_form(self.goal_functional, mesh,
                                               exchange=(self.u, w))[0]

        # Update bcs
        if self.bcs:
            info("--- Updating boundary conditions")
            self.bcs = update_bcs(self.bcs, mesh, V_h)

        # Update u
        self.u = w

    def solve_primal(self, mesh):
        """
        Solve the primal problem and return its solution.

        If self.u is None, lhs(F) = rhs(F) is solved as a linear problem.
        Otherwise F is linearized with respect to self.u via derivative()
        and solved as a nonlinear problem in place, and self.u is
        returned. The mesh argument is not used by this implementation.
        """

        # FIXME: Waiting for better VariationalProblem here
        if self.u is None:
            pde = VariationalProblem(lhs(self.F), rhs(self.F), self.bcs)
            pde.parameters["linear_solver"] = self.parameters["linear_solver"]
            u_h = pde.solve()

        else:
            du = TrialFunction(self.u.function_space())
            a = derivative(self.F, self.u, du)
            pde = VariationalProblem(a, self.F, self.bcs, nonlinear=True)
            pde.parameters["linear_solver"] = self.parameters["linear_solver"]
            pde.solve(self.u)
            u_h = self.u

        return u_h

    def evaluate_goal(self, u_h):
        """
        Evaluate the goal functional at u_h.

        Returns None when no goal functional is set. If the action of
        the goal on u_h cannot be formed, the goal is assembled as given.
        """

        if self.goal_functional is None: return

        # FIXME: Temporary until we make the functional a functional
        # (not a linear form)
        try:
            goal_functional = action(self.goal_functional, u_h)
        except Exception:
            # Do not swallow KeyboardInterrupt/SystemExit here
            goal_functional = self.goal_functional
            info("Can't compute action so I assume it's a functional already..")

        mesh = u_h.function_space().mesh()

        # Compute value of goal functional
        return assemble(goal_functional, mesh=mesh,
                        exterior_facet_domains=self.goal_exterior_facets)


    def solution(self):
        """Return stored solution u."""
        return self.u

    def final_summary(self):
        """
        Print the parameter summary and the data of every iteration.

        When a reference value is set, the data is also saved to
        '<save_indicators>.py'.
        """

        info("-"*80 + "\nSummary of iterations:\n" + "-"*80)
        self.summary()

        for datum in self.data:
            datum.summary()

        if self.reference is not None:
            filename = "%s.py" % self.parameters["save_indicators"]
            save_data(self.data, filename)

    def summary(self):
        "Pretty-print summary of the selected estimation and marking parameters"

        # Print some parameter values
        info("")
        print_value("Error estimator   ", self.parameters["error_estimation"]["estimator"])
        print_value("Error indicator   ", self.parameters["error_estimation"]["indicator"])
        if self.goal_functional:
            print_value("Dual strategy     ", self.parameters["error_estimation"]["dual_strategy"])
        print_value("Mesh refinement   ", self.parameters["marking"]["strategy"])
        info("")
