# CSci 556, Multiparadigm Programming, Fall 2018
# Expression Tree Calculator; Using Data Classes, Annotations, ABC
# H. Conrad Cunningham

# This version requires Python 3.7+. It is designed for use with
# data classes, deferred evaluation of annotations, and the "mypy"
# type checker.

# 345678901234567890123456789012345678901234567890123456789012345678901234567890

# 2018-09-15: Develop from ExprABC
# 2018-09-16: Put repeated testing code in procedure; update comments
# 2018-11-02: Reformatted with black

from __future__ import annotations  # 3.7 deferred annotations
from typing import cast, Any, Mapping, Optional, Union
from dataclasses import dataclass  # 3.7
from abc import ABC, abstractmethod

# Type aliases: ordinary module-level variables that mypy treats as aliases.
Name = str  # variable names (validity is checked by Tree.is_valid_name)
Number = Union[int, float, complex]  # numeric values, for annotations
# On Python 3.7 a typing.Union cannot be passed to isinstance(), so runtime
# type checks (Tree.is_valid_value) use this equivalent plain tuple instead.
NumberR = (int, float, complex)
Env = Mapping[Name, Number]  # bindings from variable name to value


class Tree(ABC):
    """Abstract base class for every expression-tree node.

    Concrete nodes must implement eval, derive, and simplify. Their
    Optional return types let a node report an invalid tree (for example
    one holding a None child) by returning None.
    """

    @abstractmethod
    def eval(self, env: Env = {}) -> Optional[Number]:
        """Return the value of this tree under the variable bindings env.

        The {} default is one object shared by all calls. That is harmless
        because Env is a read-only Mapping and is never mutated.
        """
        ...

    @abstractmethod
    def derive(self, v: Optional[Name]) -> Optional[Tree]:
        """Return the derivative of this tree with respect to variable v."""
        ...

    @abstractmethod
    def simplify(self) -> Optional[Tree]:
        """Return an equivalent, simplified version of this tree."""
        ...

    def is_valid_name(self, name: Any) -> bool:
        """Return True iff name is a nonempty str; None is never valid."""
        if not isinstance(name, Name):
            return False
        return len(name) > 0

    def is_valid_value(self, value: Any) -> bool:
        """Return True iff value is an int, float, or complex (not None)."""
        return isinstance(value, NumberR)


# The @dataclass decorator generates __init__, __repr__, and __eq__ from
# the type-annotated instance variables declared in each class body.


@dataclass
class Sum(Tree):
    """Binary addition node: value of left plus value of right.

    A None child makes the tree invalid, and then eval, derive, and
    simplify all return None.
    """

    left: Optional[Tree]
    right: Optional[Tree]

    def eval(self, env: Env = {}) -> Optional[Number]:
        """Evaluate both operands under env and return their sum.

        Returns None when either operand is missing or evaluates to None.
        The shared {} default is only read, never mutated.
        """
        left_value = None if not self.left else self.left.eval(env)
        right_value = None if not self.right else self.right.eval(env)
        if left_value is None or right_value is None:
            return None
        return left_value + right_value

    def derive(self, v: Optional[Name]) -> Optional[Tree]:
        """Return d(left + right)/dv as the (unsimplified) Sum of derivatives.

        Returns None if either operand's derivative is None.
        """
        left_deriv = self.left.derive(v) if self.left else None
        right_deriv = self.right.derive(v) if self.right else None
        if left_deriv and right_deriv:
            return Sum(left_deriv, right_deriv)
        return None

    def simplify(self) -> Optional[Tree]:
        """Simplify both operands, then fold constants and drop zero terms.

        Returns None if the tree is invalid.
        """
        left_simple = self.left.simplify() if self.left else None
        right_simple = self.right.simplify() if self.right else None
        left_is_const = isinstance(left_simple, Const)
        right_is_const = isinstance(right_simple, Const)

        if left_is_const and right_is_const:
            # Fold c1 + c2 into a single Const. The casts change nothing at
            # runtime; they only let mypy accept "+" (int and float values
            # are treated as complex for addition).
            left_num = cast(complex, cast(Const, left_simple).value)
            right_num = cast(complex, cast(Const, right_simple).value)
            return Const(left_num + right_num)
        if left_is_const and cast(Const, left_simple).value == 0:
            return right_simple  # 0 + e simplifies to e
        if right_is_const and cast(Const, right_simple).value == 0:
            return left_simple  # e + 0 simplifies to e
        if left_simple and right_simple:
            return Sum(left_simple, right_simple)
        return None


@dataclass
class Var(Tree):
    """Leaf node naming a variable whose value comes from an environment.

    A Var is valid only when name passes is_valid_name (a nonempty str).
    Every method applies that same test, so an invalid Var (name None or
    "") consistently evaluates, derives, and simplifies to None.
    """

    name: Optional[Name]

    def eval(self, env: Env = {}) -> Optional[Number]:
        """Return the value bound to this variable's name in env.

        Returns None when the name is invalid or not bound in env. The
        shared {} default is only read, never mutated.
        """
        name = self.name
        # The "is None" test narrows the type for mypy (no cast needed);
        # is_valid_name additionally rejects the empty string.
        if name is None or not self.is_valid_name(name):
            return None
        return env.get(name)

    def derive(self, v: Optional[Name]) -> Optional[Tree]:
        """Return Const(1) if v names this variable, else Const(0).

        Returns None when this Var's own name is invalid.
        """
        if not self.is_valid_name(self.name):
            return None
        if self.is_valid_name(v) and v == self.name:
            return Const(1)
        return Const(0)

    def simplify(self) -> Optional[Tree]:
        """Return this Var unchanged if its name is valid, else None."""
        return self if self.is_valid_name(self.name) else None


@dataclass
class Const(Tree):
    """Leaf node holding a numeric constant.

    A value of None marks the node as invalid. In that case eval returns
    None, and derive and simplify return None instead of a tree.
    """

    value: Optional[Number]

    def eval(self, env: Env = {}) -> Optional[Number]:
        """Return the stored value; env is accepted but ignored."""
        return self.value

    def derive(self, v: Optional[Name]) -> Optional[Tree]:
        """Return Const(0), the derivative of a constant, for any v.

        Returns None for an invalid (None-valued) constant. The explicit
        return annotation matches Tree.derive; without it mypy treats the
        result as Any.
        """
        return Const(0) if self.value is not None else None

    def simplify(self) -> Optional[Tree]:
        """Return this node unchanged if it is valid, otherwise None."""
        return self if self.value is not None else None


# Smoke-testing code. It only prints results for manual inspection; it
# should be rewritten to compare them against the expected output.

if __name__ == "__main__":

    def smoke_test01(expr, envir):
        """Print a smoke-test report for one expression tree.

        Shows expr itself, its value under envir, its derivatives with
        respect to "x", "y", and None, and its simplification.
        """
        # Build the label from envir itself rather than hard-coding
        # "x=5, y=7", so it stays accurate for any environment passed in.
        bindings = ", ".join(f"{key}={val}" for key, val in envir.items())
        print(f"Expression: {expr}")
        print(f"Evaluation with {bindings}:\n  {expr.eval(envir)}")
        print(f"Derivative relative to x:\n  {expr.derive('x')}")
        print(f"Derivative relative to y:\n  {expr.derive('y')}")
        print(f"Derivative relative to None:\n  {expr.derive(None)}")
        print(f"Simplification:\n  {expr.simplify()}")
        print("")

    print("\nBegin smoke testing expression tree program\n")

    env = {"x": 5, "y": 7}

    # Variable leaves
    x = Var("x")
    y = Var("y")
    z = Var("z")  # no value in env

    # Constant leaves
    c0 = Const(0.0)
    c1 = Const(1.0)
    c3 = Const(3.0)
    c7 = Const(7.0)
    cm3 = Const(-3.0)

    for tree in (c0, cm3, x, y, z):
        smoke_test01(tree, env)

    s0L = Sum(c0, c3)
    s0R = Sum(c3, c0)
    s1 = Sum(c7, cm3)
    s2 = Sum(c1, y)
    s3 = Sum(x, c3)
    s4 = Sum(x, y)
    s5 = Sum(s1, s0L)
    s6 = Sum(Sum(s1, s2), Sum(s1, s4))

    for tree in (s0L, s0R, s1, s2, s3, s4, s5, s6):
        smoke_test01(tree, env)

    exp = Sum(Sum(x, x), Sum(c7, y))
    exp2 = Sum(Sum(Const(0), Const(0)), Sum(Const(0), Const(1)))

    for tree in (exp, exp2):
        smoke_test01(tree, env)

    # Invalid trees: every operation on these yields None.
    n1 = Const(None)
    n2 = Var(None)
    n3 = Sum(None, None)

    for tree in (n1, n2, n3):
        smoke_test01(tree, env)
