#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This module contains utility functions for various stuff.
"""

# Copyright (C) 2007-2009 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2007-06-15
# Last changed: 2009-03-19

from itertools import izip
import swiginac
#import strings

import numpy.testing
_last_m = numpy.testing.memusage()
def printmem(msg):
    global _last_m
    m = numpy.testing.memusage()
    diff = m - _last_m
    print msg, diff, "b, ", diff/1024, "Kb, ", diff/1024**2, "Mb"
    _last_m = m

def dot_product(seq1, seq2):
    """Return the sum of the pairwise products of seq1 and seq2.

    The sequences are paired with izip, so pairing stops at the end of
    the shorter one. The sum starts from 0, which is also the result
    when there are no pairs.
    """
    products = (left*right for (left, right) in izip(seq1, seq2))
    return sum(products)

def make_name_valid(name):
    """Return name with every '.', '[' and ']' replaced by an underscore."""
    for forbidden in (".", "[", "]"):
        name = name.replace(forbidden, "_")
    return name

def index_string(i):
    """Format an index as a string.

    An int is converted with str(). Anything else is iterated over and
    the str() of its components are joined with underscores.
    """
    if not isinstance(i, int):
        components = [str(component) for component in i]
        return "_".join(components)
    return str(i)

def fe_is_discontinuous(fe):
    """Tell whether fe names one of the element classes listed below.

    fe is either a string or an object that is turned into a class name
    with strings.finite_element_classname. If the name contains an
    underscore, only its second underscore-separated field is compared
    against the list.

    NOTE(review): the function begins by calling sfc_error("FIXME"), and
    neither sfc_error nor strings is imported in the visible part of this
    module (the strings import is commented out). Presumably the function
    is unfinished - confirm before relying on it.
    """
    sfc_error("FIXME")

    classname = fe if isinstance(fe, str) else strings.finite_element_classname(fe)

    if "_" in classname:
        classname = classname.split('_')[1]

    discontinuous_names = ["DiscontinuousLagrange",
                           "VectorDiscontinuousLagrange",
                           "TensorDiscontinuousLagrange",
                           "ArnoldFalkWintherWeakSymU",
                           "ArnoldFalkWintherWeakSymP"]
    return classname in discontinuous_names

def fe_is_signed(fe):
    """Tell whether fe names one of the element classes listed below.

    fe is either a string or an object that is turned into a class name
    with strings.finite_element_classname. If the name contains an
    underscore, only its second underscore-separated field is compared
    against the list.

    NOTE(review): the function begins by calling sfc_error("FIXME"), and
    neither sfc_error nor strings is imported in the visible part of this
    module (the strings import is commented out). Presumably the function
    is unfinished - confirm before relying on it.
    """
    sfc_error("FIXME")

    classname = fe if isinstance(fe, str) else strings.finite_element_classname(fe)

    if "_" in classname:
        classname = classname.split('_')[1]

    signed_names = ["ArnoldFalkWintherWeakSymSigma",
                    "Nedelec2Hdiv",
                    "Nedelec2HdivPtv",
                    "Nedelec",
                    "RaviartThomas",
                    "Robust",
                    "RobustPtv"]
    return classname in signed_names

def check_range(i, a, b, msg="Invalid range."):
    """Raise ValueError(msg) unless i lies in the half-open range [a,b).

    When a == b no exception is raised, whatever the value of i.
    """
    outside = (i < a) or (i >= b)
    if outside and (a != b):
        raise ValueError(msg)

def unique(sequence):
    """Yield the items of sequence in order, skipping any item seen before.

    Items already yielded are remembered in a set, so they must be hashable.
    """
    seen = set()
    for item in sequence:
        if item in seen:
            continue
        seen.add(item)
        yield item

def indices_subset(indices, keep):
    """Mask out components of each multiindex and drop duplicate results.

    Every multiindex in indices becomes a tuple in which the component at
    position p is kept if keep[p] is true and replaced by None otherwise.
    The distinct masked tuples are returned as a tuple, in the order they
    are first produced (duplicates are removed with unique).
    """
    masked = [
        tuple((component if keep[position] else None)
              for (position, component) in enumerate(multiindex))
        for multiindex in indices
    ]
    return tuple(unique(masked))

def shape(dims):
    """Return dims converted to a tuple.

    An empty dims gives the empty tuple (); an earlier variant that
    returned (1,) in that case is disabled.
    """
    as_tuple = tuple(dims)
    return as_tuple

def permute(shape):
    """Return a list of all multiindices within the range of a tensor shape.

    For a shape of rank 0 (empty shape) the result is [(0,)]. Otherwise
    the result holds one tuple per combination of indices, with component
    number d running over range(shape[d]). The tuples are ordered with the
    last index varying fastest. Shapes of any rank are supported; if any
    dimension is 0 the result is an empty list.
    """
    if len(shape) == 0:
        return [(0,)]
    # Grow the multiindices one dimension at a time. Appending the new
    # index as the innermost loop keeps the last index varying fastest.
    multiindices = [()]
    for dim in shape:
        multiindices = [prefix + (k,) for prefix in multiindices for k in range(dim)]
    return multiindices

def list_items(l):
    """Pair every item of l with its position, as (position, item) tuples.

    The pairs are produced by zipping range(len(l)) with l.
    """
    positions = range(len(l))
    return zip(positions, l)

def as_list_with_len(x, wantlen):
    """Return x as a list of length wantlen.

    A tuple is converted to a list, and anything that is neither a tuple
    nor a list is wrapped in a one-element list. A one-element list is
    then repeated wantlen times. Raises ValueError if the resulting
    list does not have length wantlen.
    """
    if isinstance(x, tuple):
        items = list(x)
    elif isinstance(x, list):
        items = x
    else:
        items = [x,]

    # A single value is broadcast to the wanted length.
    if len(items) == 1:
        items = wantlen*items

    if len(items) != wantlen:
        raise ValueError("".join(["Got list with size ", str(len(items)),
                                  ", need ", str(wantlen), "."]))
    return items

def matrix_to_list(m):
    """Return the entries m[0], ..., m[len(m)-1] as a list.

    The entries are read by integer indexing, not by iterating over m.
    """
    entries = []
    for position in range(len(m)):
        entries.append(m[position])
    return entries

def list_to_matrix(m, n, l):
    """Create a swiginac.matrix with m rows and n columns from the list l.

    The arguments are passed on unchanged as swiginac.matrix(m, n, l).
    Raises ValueError if len(l) differs from m*n.
    """
    if len(l) != m*n:
        # Report the length of l; formatting the list itself with %d
        # would raise TypeError instead of the intended ValueError.
        raise ValueError("Invalid sizes: %d * %d != %d" % (m, n, len(l)))
    return swiginac.matrix(m, n, l)

def list_to_vector(l):
    """Create a swiginac.matrix with len(l) rows and one column from l."""
    num_rows = len(l)
    return swiginac.matrix(num_rows, 1, l)

def is_function(f):
    """Return True if f has a func_name attribute, False otherwise."""
    has_func_name = hasattr(f, 'func_name')
    return has_func_name

def is_functor(f):
    """Return True if f has a __call__ attribute but no func_name attribute.

    Objects with a func_name attribute are never reported as functors.
    """
    return (not hasattr(f, 'func_name')) and hasattr(f, '__call__')

def get_func_code(f):
    """Return the func_code object of a function or functor object.

    For a function (see is_function) this is f.func_code. For a functor
    (see is_functor) it is the func_code of the function behind
    f.__call__. None is returned for a callable that is neither.
    Raises RuntimeError if f is not callable.
    """
    if not callable(f):
        raise RuntimeError("Object is not callable!")
    if is_function(f):
        return f.func_code
    if is_functor(f):
        return f.__call__.im_func.func_code
    return None

def get_callable_name(f):
    """Return the co_name of the code object that get_func_code finds for f."""
    code = get_func_code(f)
    return code.co_name

def get_callable_num_args(f):
    """Return the co_argcount of the code object that get_func_code finds for f.

    For a functor the code object is that of the function behind
    __call__, so the count is that function's own argument count.
    """
    code = get_func_code(f)
    return code.co_argcount

