#!/usr/bin/env python
"""Contains the ParameterDict class, useful for defining
recursive dictionaries of parameters and using attribute
syntax for later access.

Some useful features:
- Recursive copy function of parameter subsets
- Recursive update function including parameter subsets
- Recursive indented pretty-print
- Valid parameters are declared as keyword arguments to the constructor,
  and assigning (via attribute syntax) to undeclared parameters is not allowed.

See help(ParameterDict) for an interactive example.
"""

__author__ = "Martin Sandve Alnes <martinal@simula.no>, Johan Hake <hake@simula.no>"
__date__ = "2008-06-22 -- 2008-06-23"
__copyright__ = "(C) 2008 Martin Sandve Alnes and Simula Research Laboratory"
__license__ = "GNU GPL Version 2, or (at your option) any later version"


class ParameterDict(dict):
    """A dictionary with attribute-style access,
    that maps attribute access to the real dictionary.

    Parameter values are stored only as dictionary items, so ``p.x`` and
    ``p["x"]`` always refer to the same value (also after item assignment
    or deletion). Attribute assignment is only allowed for keys that
    already exist; new keys cannot be added through attribute syntax.
    Keys that collide with dict or ParameterDict method names (for example
    ``keys``, ``copy`` or ``format``) resolve to the method on attribute
    access and must be read with item syntax, e.g. ``p["keys"]``.

    Interactive example:
    >>> m = ParameterDict(Re = 1.0, f = "sin(x)")
    >>> print(m)
    Re = 1.0
    f = 'sin(x)'
    >>> s = ParameterDict(max_iterations = 10, tolerance = 1e-8)
    >>> print(s)
    max_iterations = 10
    tolerance = 1e-08
    >>> p = ParameterDict(model = m, solver = s)
    >>> print(p)
    model = {
        Re = 1.0
        f = 'sin(x)'
    }
    solver = {
        max_iterations = 10
        tolerance = 1e-08
    }
    >>> q = p.copy()
    >>> q.model.Re = 2.3e6
    >>> q.solver.max_iterations = 100
    >>> print(q)
    model = {
        Re = 2300000.0
        f = 'sin(x)'
    }
    solver = {
        max_iterations = 100
        tolerance = 1e-08
    }
    >>> print(p)
    model = {
        Re = 1.0
        f = 'sin(x)'
    }
    solver = {
        max_iterations = 10
        tolerance = 1e-08
    }
    >>> p.update(q)
    >>> print(p)
    model = {
        Re = 2300000.0
        f = 'sin(x)'
    }
    solver = {
        max_iterations = 100
        tolerance = 1e-08
    }
    >>> s.nothere = 123
    Traceback (most recent call last):
        ...
    AttributeError: nothere is not an item in this parameter dict.
    """
    def __init__(self, **params):
        # Values live only in the underlying dict; attribute reads are served
        # by __getattr__, so nothing is copied into the instance __dict__
        # (a copy there would shadow __getattr__ and go stale on p[k] = v).
        dict.__init__(self, **params)

    def __getstate__(self):
        """Return extra instance attributes for pickling, as a list of pairs.

        Dictionary items are pickled separately by the dict machinery.
        A list (not a dict view) is returned so the state is also
        picklable under Python 3.
        """
        return list(self.__dict__.items())

    def __setstate__(self, items):
        """Restore instance attributes produced by __getstate__."""
        for key, val in items:
            # Skip names that are parameters: their values come from the dict
            # items. Pickles written when values were mirrored into __dict__
            # carry such entries, and keeping them would shadow __getattr__.
            if not dict.__contains__(self, key):
                self.__dict__[key] = val

    def __str__(self):
        return self.format()

    def __repr__(self):
        args = ", ".join("%s = %r" % (k, dict.__getitem__(self, k))
                         for k in sorted(dict.keys(self)))
        return "%s(%s)" % (self.__class__.__name__, args)

    def __delitem__(self, key):
        # Plain dict deletion; since values are not mirrored into __dict__,
        # attribute access to the deleted key raises AttributeError afterwards.
        return dict.__delitem__(self, key)

    def __setattr__(self, key, value):
        """Assign an existing parameter; refuse to create new ones.

        Raises AttributeError if key is not already an item of the dict.
        """
        if not dict.__contains__(self, key):
            raise AttributeError("%s is not an item in this parameter dict." % key)
        dict.__setitem__(self, key, value)

    def __getattr__(self, key):
        """Called only when normal lookup fails: serve the dict item.

        Raises AttributeError (not KeyError) so hasattr(), copy and pickle
        protocol lookups behave normally for missing names.
        """
        if not dict.__contains__(self, key):
            raise AttributeError("%s is not an item in this parameter dict." % key)
        return dict.__getitem__(self, key)

    def format(self, indent=None):
        """Make a recursive indented pretty-print string of self and parameter subsets.

        indent: nesting level (4 spaces per level); None means 0.
        Keys are sorted; scalar values are shown with repr(), nested
        ParameterDicts as a braced, further-indented block. An empty nested
        subset is rendered as an opening and closing brace only.
        """
        if indent is None:
            indent = 0
        pad = "    " * indent
        lines = []
        for k in sorted(dict.keys(self)):
            v = dict.__getitem__(self, k)
            if isinstance(v, ParameterDict):
                lines.append("%s%s = {" % (pad, k))
                inner = v.format(indent + 1)
                if inner:  # avoid a blank line for empty subsets
                    lines.append(inner)
                lines.append(pad + "}")
            else:
                lines.append("%s%s = %r" % (pad, k, v))
        return "\n".join(lines)

    def copy(self):
        """Make a copy of self, including recursive copying of parameter subsets.
        Parameter values themselves are not copied.

        Any dict-valued entry (ParameterDict or plain dict) is copied via its
        own copy() method; the result is always a ParameterDict.
        """
        # TODO: Make this an external function to handle recursive dicts...
        items = {}
        for k, v in dict.items(self):
            if isinstance(v, dict):
                items[k] = v.copy()
            else:
                items[k] = v
        ch = ParameterDict()
        # dict.update instead of ParameterDict(**items) so non-string keys
        # added through item assignment are copied as well.
        dict.update(ch, items)
        return ch

    def update(self, other):
        """A recursive update that handles parameter subsets correctly unlike dict.update.

        other: a ParameterDict or plain dict whose keys must all already
        exist in self. Dict-valued entries in self are updated in place via
        their own update(); other entries are replaced.
        Raises AttributeError for a key of other that self does not have
        (entries processed before it have already been applied).
        """
        # TODO: Make this an external function to handle recursive dicts...
        for k in dict.keys(other):
            if not dict.__contains__(self, k):
                raise AttributeError("%s is not an item in this parameter dict." % k)
            sv = dict.__getitem__(self, k)
            # Item access works for plain dicts too (getattr did not).
            ov = dict.__getitem__(other, k)
            if isinstance(sv, dict):
                # Update my own subdict with others subdict
                sv.update(ov)
            else:
                # Set my own value to others value
                dict.__setitem__(self, k, ov)


# Test code
if __name__ == "__main__":
    # Smoke-test script. Uses single-argument print(...) calls and explicit
    # %-formatting so the output is the same under Python 2 and Python 3.

    def default_a():
        return ParameterDict(abla=123, abli="sin")

    def default_b():
        return ParameterDict(bblal=987, bling="akjh")

    def default_params():
        return ParameterDict(something=3,
                             other=.1239,
                             a=default_a(),
                             b=default_b())

    # Get a defined set of parameters
    p = default_params()

    # Test parameter setting
    p.something = 9
    p.other = "8134"

    # Test parameter setting exceptions: only AttributeError is the expected
    # failure, and the RuntimeError is raised from the else-branch so that a
    # missing exception is reported instead of being swallowed by the handler.
    try:
        p.blatti = 7
    except AttributeError:
        pass
    else:
        raise RuntimeError("Failed to throw exception on erroneous parameter assignment.")

    # Test iteration (the iter* methods exist only on Python 2 dicts):
    has_iter_methods = hasattr(dict, "iterkeys")
    for k in p.keys():
        print(k)
    if has_iter_methods:
        for k in p.iterkeys():
            print(k)
    for v in p.values():
        print(v)
    if has_iter_methods:
        for v in p.itervalues():
            print(v)
    for k, v in p.items():
        print("%s %s" % (k, v))
    if has_iter_methods:
        for k, v in p.iteritems():
            print("%s %s" % (k, v))

    # Test random access:
    ap1 = p.a
    ap2 = p["a"]
    assert ap1 is ap2

    # Test printing of parameter set
    print("")
    print("str(p):")
    print(str(p))
    print("")
    print("repr(p):")
    print(repr(p))
    print("")

    # Test copy
    q = p.copy()
    q.something = "q specific!"
    q.a.abla = "q.a specific!"
    print("")
    print("Should be different:")
    print(repr(p))
    print(repr(q))

    # Test update
    p.update(q)
    print("")
    print("Should be equal:")
    print(repr(q))
    print(repr(p))

    # Test indented formatting:
    print("")
    print(q.format())

    print(p)

# Run doctest
def _test():
    import doctest
    doctest.testmod()
if __name__ == "__main__":
    _test()
