/*  Notes on Scala for Java Programmers
    Scala Expression Tree using a Traditional Object-Oriented Approach
    Adapted by H. Conrad Cunningham

1234567890123456789012345678901234567890123456789012345678901234567890

2008-09-10: (V1) Based partially on code from "A Scala Tutorial
            for Java Programmers" for a programming assignment
2016-02-04: (V2) Adapted for another assignment.
2018-02-03: (V2a) Updated comments.
2019-02-01: Updated comments
2022-02-25: Updated to be compatible with Scala 3 (procedures)

This program was constructed to coordinate with ExprCase.scala.

I created this traditional OO version from the case-class-based
functional version adapted from Section 6 of "A Scala Tutorial for
Java Programmers" for a programming assignment in the initial
Scala-based offering of Multiparadigm Programming (later numbered CSci
556) in Fall 2008. I later adapted it for a programming assignment in
the Scala-based Spring 2016 offering of CSci 555 Functional
Programming.

*/

/* Expression Tree class hierarchy

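   Note: The subclasses of Tree override "toString" but not "equals",
   so two separately constructed but structurally identical trees are
   not equal (reference equality is inherited). They probably should
   also override "equals" and, to keep the equals/hashCode contract,
   "hashCode".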
*/

/** Root of the expression-tree hierarchy.
  *
  * Every concrete node can compute its numeric value in an environment
  * and build the tree of its symbolic derivative.
  */
abstract class Tree {

  /** Lookup function from a variable name to its numeric value. */
  type Environment = String => Double

  /** Builds the derivative of this expression with respect to the
    * variable named `v`.
    */
  def derive(v: String): Tree

  /** Computes the value of this expression, resolving variables
    * through `env`.
    */
  def eval(env: Environment): Double

  // Currently commented out: simplify the expression by replacing
  // constant subexpressions with a single constant.
  // def simplify: Tree
}


/* Addition operator subclass Sum */

/** Binary addition node representing `l + r`. */
class Sum(l: Tree, r: Tree) extends Tree {

  /** Value of the sum: left operand's value plus right operand's value
    * (left is evaluated first).
    */
  def eval(env: Environment): Double = {
    val left  = l.eval(env)
    val right = r.eval(env)
    left + right
  }

  /** Sum rule: d(l + r)/dv = dl/dv + dr/dv. */
  def derive(v: String): Sum = {
    val dLeft  = l.derive(v)
    val dRight = r.derive(v)
    new Sum(dLeft, dRight)
  }

  /** Renders as `Sum(<left>,<right>)`. */
  override def toString: String = s"Sum($l,$r)"
}

/* Variable (name) subclass Var */

/** Variable node: a name whose value is looked up in the environment. */
class Var(n: String) extends Tree {

  /** Value of the variable is whatever `env` maps its name to; if `env`
    * has no binding for the name, whatever `env` does (e.g. throw)
    * propagates to the caller.
    */
  def eval(env: Environment): Double = env(n)

  /** Derivative of a variable: 1.0 with respect to itself, 0.0 with
    * respect to any other variable.
    */
  def derive(v: String): Const = {
    val slope = if (n == v) 1.0 else 0.0
    new Const(slope)
  }

  /** Accessor for the variable's name. */
  def getn: String = n

  /** Renders as `Var(<name>)`. */
  override def toString: String = s"Var($n)"
}


/* Constant (value) subclass Const */

/** Constant node holding a fixed numeric value. */
class Const(v: Double) extends Tree {

  /** A constant ignores the environment and yields its own value. */
  def eval(env: Environment): Double = v

  /** The derivative of any constant is zero.
    *
    * NOTE(review): inside this method the parameter `v` (a variable
    * name) shadows the constructor parameter `v` (the constant's value).
    * The body deliberately references neither; rename with care, since
    * both names are visible to callers using named arguments.
    */
  def derive(v: String): Const = new Const(0.0)

  /** Accessor for the constant's value. */
  def getv: Double = v

  /** Renders as `Const(<value>)`. */
  override def toString: String = s"Const($v)"
}


/* Main method for testing the expression Tree hierarchy  */

object ExprObj {

  type Environment = String => Double

  /** Builds a set of sample expression trees and prints, for each one,
    * the expression, its value (when defined), and its derivatives with
    * respect to x and y.
    */
  def main(args: Array[String]): Unit = {

    // Pattern-matching function literal: applying it to any name other
    // than "x" or "y" throws scala.MatchError.
    val env: Environment = { case "x" => 5 case "y" => 7 }

    println("Begin testing expression tree program -- Subclass version")

    val c0:  Tree = new Const(0.0)
    val c1:  Tree = new Const(1.0)
    val c3:  Tree = new Const(3.0)
    val c7:  Tree = new Const(7.0)
    val cm3: Tree = new Const(-3.0)

    val x: Tree = new Var("x")
    val y: Tree = new Var("y")
    val z: Tree = new Var("z")  // no value in env

    val s0L: Tree = new Sum(c0, c3)
    val s0R: Tree = new Sum(c3, c0)
    val s1:  Tree = new Sum(c7, cm3)
    val s2:  Tree = new Sum(c1, y)
    val s3:  Tree = new Sum(x, c3)
    val s4:  Tree = new Sum(x, y)
    val s5:  Tree = new Sum(s1, s0L)
    val s6:  Tree = new Sum(new Sum(s1, s2), new Sum(s1, s4))

    val exp: Tree = new Sum(new Sum(x, x), new Sum(c7, y))

    report(c0, env)
    report(cm3, env)
    report(x, env)
    // z is undefined in env; there is currently no provision to handle
    // that, so its evaluation line is skipped rather than thrown.
    report(z, env, evaluate = false)
    List(s0L, s0R, s1, s2, s3, s4, s5, s6, exp).foreach(t => report(t, env))
  }

  /** Prints one test stanza for tree `t`: the expression, optionally its
    * value in `env`, its derivatives with respect to x and y, and a
    * separator line.
    *
    * @param t        expression to report on
    * @param env      environment used for evaluation
    * @param evaluate when false, the evaluation line is omitted (used for
    *                 expressions with variables `env` cannot resolve)
    */
  private def report(t: Tree, env: Environment, evaluate: Boolean = true): Unit = {
    println("Expression: "                + t)
    if (evaluate) println("Evaluation with x=5, y=7: " + t.eval(env))
    println("Derivative relative to x:\n" + t.derive("x"))
    println("Derivative relative to y:\n" + t.derive("y"))
//  println("Simplification\n"            + t.simplify)
    println(" ")
  }
}
