# CSci 556, Multiparadigm Programming, Fall 2018
# Expression Tree Calculator; Functional Evaluation Table Version
# H. Conrad Cunningham

# 345678901234567890123456789012345678901234567890123456789012345678901234567890

# 2018-09-16: Develop from ExprFuncMod.py code
# 2018-11-02: Reformatted with black 

from abc import ABC, abstractmethod


# Make validity checks module-level functions


def is_valid_name(name):
    """Return True iff *name* is a non-empty str.

    Any non-empty string is accepted for now. None is never a valid name,
    because it is not a str.
    """
    if not isinstance(name, str):
        return False
    return len(name) >= 1


def is_valid_value(value):
    """Return True iff *value* is an int, float, or complex number.

    Any number is accepted for now. None is never a valid value. Note that
    bool subclasses int, so True/False pass this check as well.
    """
    accepted_types = (int, float, complex)
    return isinstance(value, accepted_types)


# Tree hierarchy holds expression tree data for type checks;
# Could use dataclasses


class Tree(ABC):
    """Common base class for expression-tree nodes (Sum, Var, Const).

    It declares no abstract methods; it exists so that the operations in
    this module can recognise nodes with ``isinstance(x, Tree)``.
    """


class Sum(Tree):
    """Addition node representing ``left + right``.

    An operand that is not a Tree instance is stored as None; the eval,
    derive and simplify functions in this module then yield None for the
    whole sum.
    """

    def __init__(self, l, r):
        def as_tree(node):
            # Keep genuine subtrees; replace anything else with None.
            return node if isinstance(node, Tree) else None

        self.left = as_tree(l)
        self.right = as_tree(r)

    def __repr__(self):
        return "Sum(%r,%r)" % (self.left, self.right)


class Var(Tree):
    """Variable-reference node.

    ``name`` holds *n* when ``is_valid_name(n)`` accepts it, otherwise None.
    """

    def __init__(self, n):
        self.name = None
        if is_valid_name(n):
            self.name = n

    def __repr__(self):
        return "Var({})".format(self.name)


class Const(Tree):
    """Numeric-constant node.

    ``value`` holds *v* when ``is_valid_value(v)`` accepts it, otherwise None.
    """

    def __init__(self, v):
        self.value = None
        if is_valid_value(v):
            self.value = v

    def __repr__(self):
        return "Const({})".format(self.value)


# Build evaluation table that maps class name string to function object


def eval_Sum(tree, env=None):
    """Evaluate a Sum node as ``left + right`` under the bindings *env*.

    Args:
        tree: a Sum node.
        env: mapping from variable names to values. It defaults to an empty
            mapping; a fresh dict is created per call instead of using a
            shared mutable default.

    Returns:
        The sum of both operand values, or None if either operand is missing
        (stored as None) or evaluates to None.
    """
    if env is None:
        env = {}
    lv = eval(tree.left, env) if tree.left else None
    rv = eval(tree.right, env) if tree.right else None
    return lv + rv if lv is not None and rv is not None else None


def eval_Var(tree, env=None):
    """Look up a Var node's name in *env*.

    Args:
        tree: a Var node (reads ``tree.name``).
        env: mapping from variable names to values. None, or omitted, means
            "no bindings" (this replaces the former shared ``{}`` default).

    Returns:
        The bound value, or None when the name is unbound or no env is given.
    """
    if env is None:
        return None
    return env.get(tree.name)


def eval_Const(tree, env=None):
    """Return a Const node's stored value.

    *env* is unused. It is accepted so every evaluator in the dispatch table
    shares the ``(tree, env)`` signature, and its default is None rather
    than a shared mutable ``{}``.
    """
    return tree.value


# Dispatch table for eval: node class name -> evaluator function.
eval_tab = {
    node_cls.__name__: evaluator
    for node_cls, evaluator in (
        (Sum, eval_Sum),
        (Var, eval_Var),
        (Const, eval_Const),
    )
}


def eval(tree, env=None):
    """Evaluate expression *tree* under the variable bindings *env*.

    This module-level function shadows the builtin ``eval`` inside this
    module.

    Dispatch is by the node's class name through ``eval_tab``.

    Args:
        tree: the expression to evaluate; non-Tree input yields None.
        env: mapping from variable names to values. None, or omitted, means
            an empty mapping; a fresh dict is created per call instead of
            using a shared mutable default.

    Returns:
        Whatever the registered evaluator returns, or None if *tree* is not
        a Tree or its class has no evaluator.
    """
    if env is None:
        env = {}
    if not isinstance(tree, Tree):
        return None
    ef = eval_tab.get(tree.__class__.__name__)
    return ef(tree, env) if ef else None


# Build derivative table that maps class name string to function object


def derive_Sum(tree, n):
    """Differentiate a Sum: d(l + r)/dn = dl/dn + dr/dn.

    Returns None if either operand is missing or cannot be differentiated.
    """
    parts = [derive(side, n) if side else None for side in (tree.left, tree.right)]
    if not all(parts):
        return None
    return Sum(*parts)


def derive_Var(tree, n):
    """Differentiate a Var with respect to the variable named *n*.

    Returns:
        - Const(1) if *n* is a valid name equal to this Var's name.
        - Const(0) for any other valid Var (including when *n* is invalid).
        - None if this Var is invalid (its name is None).
    """
    same_variable = is_valid_name(n) and n == tree.name
    if same_variable:
        return Const(1)
    if tree.name is None:
        return None
    return Const(0)


def derive_Const(tree, n):
    """Differentiate a Const: the result is always Const(0).

    An invalid Const (value None) yields None instead. *n* is unused.
    """
    if tree.value is None:
        return None
    return Const(0)


# Dispatch table for derive: node class name -> differentiation function.
derive_tab = dict(
    zip(
        (node_cls.__name__ for node_cls in (Sum, Var, Const)),
        (derive_Sum, derive_Var, derive_Const),
    )
)


def derive(tree, n):
    """Differentiate expression *tree* with respect to the variable named *n*.

    Dispatch is by the node's class name through ``derive_tab``.

    Args:
        tree: the expression to differentiate; non-Tree input yields None.
        n: name of the differentiation variable.

    Returns:
        The derivative tree from the registered rule, or None if *tree* is
        not a Tree or its class has no rule.
    """
    if not isinstance(tree, Tree):
        return None
    df = derive_tab.get(tree.__class__.__name__)
    # Forward the differentiation variable n. The rules take (tree, n);
    # there is no environment in scope here.
    return df(tree, n) if df else None


# Build simplification table that maps class name string to function object


def simplify_Sum(tree):
    """Simplify both operands of a Sum, then apply algebraic rewrites.

    Constant operands are folded into one Const, and a zero Const on either
    side is dropped (additive identity). Returns None if either operand
    simplifies to None.
    """
    left = simplify(tree.left) if tree.left else None
    right = simplify(tree.right) if tree.right else None
    left_is_const = isinstance(left, Const)
    right_is_const = isinstance(right, Const)

    if left_is_const and right_is_const:
        return Const(left.value + right.value)  # constant folding
    if left_is_const and left.value == 0:
        return right  # 0 + r == r
    if right_is_const and right.value == 0:
        return left  # l + 0 == l
    if left and right:
        return Sum(left, right)
    return None


def simplify_Var(tree):
    """A Var is already simplest: return it unchanged, or None if invalid."""
    if tree.name is None:
        return None
    return tree


def simplify_Const(tree):
    """A Const is already simplest: return it unchanged, or None if invalid."""
    if tree.value is None:
        return None
    return tree


# Dispatch table for simplify: node class name -> simplification function.
simplify_tab = dict(
    [
        (Sum.__name__, simplify_Sum),
        (Var.__name__, simplify_Var),
        (Const.__name__, simplify_Const),
    ]
)


def simplify(tree):
    """Simplify expression *tree* by dispatching on its class name.

    Returns the result of the registered rule in ``simplify_tab``, or None
    if *tree* is not a Tree or its class has no rule.
    """
    if not isinstance(tree, Tree):
        return None
    handler = simplify_tab.get(tree.__class__.__name__)
    if handler:
        return handler(tree)
    return None


# Smoke testing code (should be recoded to compare against desired output)

if __name__ == "__main__":

    def smoke_test02(expr, envir):
        """Print an expression's evaluation, derivatives and simplification.

        *envir* is the variable-binding dict used for evaluation. The
        evaluation label is built from it instead of being hard-coded.
        """
        bindings = ", ".join(f"{k}={v}" for k, v in envir.items())
        print(f"Expression: {expr}")
        print(f"Evaluation with {bindings}:\n  {eval(expr,envir)}")
        print(f"Derivative relative to x:\n  {derive(expr,'x')}")
        print(f"Derivative relative to y:\n  {derive(expr,'y')}")
        print(f"Derivative relative to None:\n  {derive(expr,None)}")
        print(f"Simplification:\n  {simplify(expr)}")
        print("")

    print("\nBegin smoke testing expression tree program\n")

    env = {"x": 5, "y": 7}

    x = Var("x")
    y = Var("y")
    z = Var("z")  # no value in env

    c0 = Const(0.0)
    c1 = Const(1.0)
    c3 = Const(3.0)
    c6 = Const(6.0)
    c7 = Const(7.0)
    cm3 = Const(-3.0)

    smoke_test02(c0, env)
    smoke_test02(cm3, env)

    smoke_test02(x, env)
    smoke_test02(y, env)
    smoke_test02(z, env)

    s0L = Sum(c0, c3)
    s0R = Sum(c3, c0)
    s1 = Sum(c7, cm3)
    s2 = Sum(c1, y)
    s3 = Sum(x, c3)
    s4 = Sum(x, y)
    s5 = Sum(s1, s0L)
    s6 = Sum(Sum(s1, s2), Sum(s1, s4))

    smoke_test02(s0L, env)
    smoke_test02(s0R, env)
    smoke_test02(s1, env)
    smoke_test02(s2, env)
    smoke_test02(s3, env)
    smoke_test02(s4, env)
    smoke_test02(s5, env)
    smoke_test02(s6, env)

    exp = Sum(Sum(x, x), Sum(c7, y))
    exp2 = Sum(Sum(Const(0), Const(0)), Sum(Const(0), Const(1)))

    smoke_test02(exp, env)
    smoke_test02(exp2, env)

    # Invalid nodes: every operation should yield None for these.
    n1 = Const(None)
    n2 = Var(None)
    n3 = Sum(None, None)

    smoke_test02(n1, env)
    smoke_test02(n2, env)
    smoke_test02(n3, env)
