#!/usr/bin/env python

import unittest
import os, sys, glob, shutil, commands

import ufl

from ufl import FiniteElement
from ufl import VectorElement
from ufl import TensorElement
from ufl import MixedElement

from ufl import BasisFunction
from ufl import TestFunction
from ufl import TrialFunction

from ufl import Function
from ufl import Constant

from ufl import dx, ds

import SyFi
import newsfc as sfc

from dolfin import Mesh, MeshEditor, assemble

#import instant
#instant.set_logging_level("error")
#instant.set_logging_level("warning")
#instant.set_logging_level("info")
#instant.set_logging_level("debug")


def num_integrals(form):
    """Return the integral counts of a compiled form as a 3-tuple.

    The tuple is ordered (cell, exterior facet, interior facet), as reported by
    the form's num_*_integrals() methods.
    """
    counters = (form.num_cell_integrals,
                form.num_exterior_facet_integrals,
                form.num_interior_facet_integrals)
    return tuple(count() for count in counters)

# Topological dimension of each cell type that UnitCell can build.
cell2dim = { "interval": 1, "triangle": 2, "tetrahedron": 3, "quadrilateral": 2, "hexahedron": 3 }

# Length/area/volume of the single reference cell UnitCell builds for each type:
# unit interval, right triangle/tetrahedron with unit legs, unit square/cube.
cell2volume = { "interval": 1.0, "triangle": 0.5, "tetrahedron": 1.0/6.0, "quadrilateral": 1.0, "hexahedron": 1.0 }
        

def UnitCell(celltype):
    """Return a DOLFIN Mesh consisting of a single reference cell.

    celltype must be one of the keys of cell2dim ("interval", "triangle",
    "tetrahedron", "quadrilateral", "hexahedron"); any other name raises
    KeyError. The geometric dimension is set equal to the topological one.
    """
    dim = cell2dim[celltype]
    mesh = Mesh()
    editor = MeshEditor()
    editor.open(mesh, celltype, dim, dim)
    # Vertex coordinates of each reference cell; cell 0 connects them in list order.
    reference_vertices = {
        "interval": [(0.0,), (1.0,)],
        "triangle": [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)],
        "tetrahedron": [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0),
                        (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)],
        "quadrilateral": [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)],
        "hexahedron": [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0),
                       (1.0, 1.0, 0.0), (0.0, 1.0, 0.0),
                       (0.0, 0.0, 1.0), (1.0, 0.0, 1.0),
                       (1.0, 1.0, 1.0), (0.0, 1.0, 1.0)],
    }
    coords = reference_vertices[celltype]
    count = len(coords)
    editor.initVertices(count)
    editor.initCells(1)
    for vertex_index, point in enumerate(coords):
        editor.addVertex(vertex_index, *point)
    editor.addCell(0, *range(count))
    editor.close()
    return mesh


def assemble_on_cell(form, celltype, coeffs):
    """Assemble UFC form on a unit cell mesh and return the result as a float or numpy array.

    form is assembled with the given coefficient values on UnitCell(celltype).
    A plain float result (rank-0 form) is returned unchanged; any other result
    is converted through its array() method.
    """
    result = assemble(form, UnitCell(celltype), coeffs)
    return result if isinstance(result, float) else result.array()

# Scratch directory, relative to the cwd at setUp time, in which each test generates code.
_test_temp_dir = "temp_dir"
# Directory into which tearDown moves the files generated by each test.
_done_test_temp_dir = "done_temp_dir"
class SFCJitTest(unittest.TestCase):
    """JIT-compile small UFL forms with SFC and assemble them on a single unit cell.

    Each test runs inside a freshly created _test_temp_dir, so generated code
    starts from a clean directory. tearDown moves whatever was generated there
    into _done_test_temp_dir.
    """

    def setUp(self):
        # Generate code in a clean directory:
        shutil.rmtree(_test_temp_dir, ignore_errors=True)
        os.mkdir(_test_temp_dir)
        os.chdir(_test_temp_dir)

    def tearDown(self):
        """Leave the scratch dir and move its contents into _done_test_temp_dir."""
        # glob("*") skips dot-files; those are removed by the next setUp's rmtree.
        dirs = glob.glob("*")
        os.chdir("..")
        # test() creates this directory, but when the TestCase is run through
        # another runner it may be missing, which would make every rename fail.
        if not os.path.isdir(_done_test_temp_dir):
            os.mkdir(_done_test_temp_dir)
        for d in dirs:
            src = os.path.join(_test_temp_dir, d)
            dst = os.path.join(_done_test_temp_dir, d)
            # os.rename cannot replace an existing non-empty directory (and on
            # Windows no existing target at all), so drop an older copy first.
            if os.path.isdir(dst) and not os.path.islink(dst):
                shutil.rmtree(dst)
            elif os.path.lexists(dst):
                os.remove(dst)
            os.rename(src, dst)

    def _assertArrayAlmostEqual(self, actual, expected):
        """Element-wise assertAlmostEqual for (nested) sequences of the same shape."""
        self.assertEqual(len(actual), len(expected))
        for a, e in zip(actual, expected):
            if hasattr(e, "__len__"):
                self._assertArrayAlmostEqual(a, e)
            else:
                self.assertAlmostEqual(a, e)

    def _p1TriangleMass(self, scale):
        """Expected P1 mass matrix on the reference triangle (area 1/2), times scale.

        Entries are scale*|T|/12*(1 + delta_ij). The pattern (2 on the diagonal,
        1 elsewhere) is unchanged by any permutation of the three vertex dofs,
        so the result does not depend on dof numbering.
        """
        return [[scale*(2.0 if i == j else 1.0)/24.0 for j in range(3)]
                for i in range(3)]

    def testSetup(self):
        pass

    def _testJitVolume(self, polygon):
        "Test that the integral of 1.0 over a unit cell equals the length/area/volume of the unit cell."
        c = Constant(polygon)
        a = c*dx
        form = sfc.jit(a)
        self.assertEqual(form.rank(), 0)
        self.assertEqual(form.num_coefficients(), 1)
        self.assertEqual(num_integrals(form), (1,0,0))
        A = assemble_on_cell(form, polygon, coeffs=[1.0])
        self.assertAlmostEqual(A, cell2volume[polygon])

    def testJitVolumeInterval(self):
        self._testJitVolume("interval")

    def testJitVolumeTriangle(self):
        self._testJitVolume("triangle")

    def testJitVolumeTetrahedron(self):
        self._testJitVolume("tetrahedron")

    def _testJitVolumeQuadrilateral(self): # Not supported by dolfin yet
        self._testJitVolume("quadrilateral")

    def _testJitVolumeHexahedron(self): # Not supported by dolfin yet
        self._testJitVolume("hexahedron")

    def _testJitConstant(self, polygon, degree):
        """Test that the integral of a constant coefficient over a unit
        cell mesh equals the constant times the volume of the unit cell."""
        element = FiniteElement("CG", polygon, degree)
        f = Function(element)
        a = f*dx
        form = sfc.jit(a)
        self.assertEqual(form.rank(), 0)
        self.assertEqual(form.num_coefficients(), 1)
        self.assertEqual(num_integrals(form), (1,0,0))
        const = 1.23
        A = assemble_on_cell(form, polygon, coeffs=[const])
        self.assertAlmostEqual(A, const*cell2volume[polygon])

    def testJitConstantInterval(self):
        polygon = "interval"
        self._testJitConstant(polygon, 1)
        self._testJitConstant(polygon, 2)

    def testJitConstantTriangle(self):
        polygon = "triangle"
        self._testJitConstant(polygon, 1)
        self._testJitConstant(polygon, 2)

    def testJitConstantTetrahedron(self):
        polygon = "tetrahedron"
        self._testJitConstant(polygon, 1)
        self._testJitConstant(polygon, 2)

    def _testJitConstantQuadrilateral(self): # Not supported by dolfin yet
        polygon = "quadrilateral"
        self._testJitConstant(polygon, 1)
        self._testJitConstant(polygon, 2)

    def _testJitConstantHexahedron(self): # Not supported by dolfin yet
        polygon = "hexahedron"
        self._testJitConstant(polygon, 1)
        self._testJitConstant(polygon, 2)

    def testJitSource(self):
        "Test the source vector."
        element = FiniteElement("CG", "triangle", 1)
        v = TestFunction(element)
        f = Function(element)
        a = f*v*dx
        form = sfc.jit(a)
        self.assertEqual(form.rank(), 1)
        self.assertEqual(form.num_coefficients(), 1)
        self.assertEqual(num_integrals(form), (1,0,0))
        fval = 3.14
        A = assemble_on_cell(form, "triangle", coeffs=[fval])
        # With a constant f, each entry is f times the integral of one P1 basis
        # function over the reference triangle: f*|T|/3 = f/6.
        self._assertArrayAlmostEqual(A, [fval/6.0]*3)

    def testJitMass(self):
        "Test the mass matrix."
        element = FiniteElement("CG", "triangle", 1)
        v = TestFunction(element)
        u = TrialFunction(element)
        f = Function(element)
        a = f*u*v*dx
        form = sfc.jit(a)
        self.assertEqual(form.rank(), 2)
        self.assertEqual(form.num_coefficients(), 1)
        self.assertEqual(num_integrals(form), (1,0,0))
        fval = 5.43
        A = assemble_on_cell(form, "triangle", coeffs=[fval])
        self._assertArrayAlmostEqual(A, self._p1TriangleMass(fval))

    def testJitSplitTerms(self):
        "Test a form split over two foo*dx terms, using the mass matrix."
        element = FiniteElement("CG", "triangle", 1)
        v = TestFunction(element)
        u = TrialFunction(element)
        f = Function(element)
        a = u*v*dx + f*u*v*dx
        form = sfc.jit(a)
        self.assertEqual(form.rank(), 2)
        self.assertEqual(form.num_coefficients(), 1)
        self.assertEqual(num_integrals(form), (1,0,0))
        fval = 4.43
        A = assemble_on_cell(form, "triangle", coeffs=[fval])
        # u*v*dx + f*u*v*dx == (1 + f)*u*v*dx for constant f.
        self._assertArrayAlmostEqual(A, self._p1TriangleMass(1.0 + fval))


def test(verbosity=0):
    """Run all SFCJitTest tests with a fresh _done_test_temp_dir.

    verbosity is passed to unittest.TextTestRunner. The unittest TestResult is
    returned so callers can check wasSuccessful().
    """
    # Start from an empty collection directory for the generated code.
    shutil.rmtree(_done_test_temp_dir, ignore_errors=True)
    os.mkdir(_done_test_temp_dir)

    classes = [SFCJitTest]
    # TestLoader.loadTestsFromTestCase replaces the deprecated unittest.makeSuite
    # and collects the same "test*" methods.
    loader = unittest.TestLoader()
    suites = [loader.loadTestsFromTestCase(c) for c in classes]
    testsuites = unittest.TestSuite(suites)
    return unittest.TextTestRunner(verbosity=verbosity).run(testsuites)

if __name__ == "__main__":
    # Run the test suite with the default verbosity when executed as a script.
    test()

