/* Natural Number Arithmetic using Peano-Inspired Structures
   Using Function Module Style with Case Classes
   H. Conrad Cunningham, Professor
   Computer and Information Science
   University of Mississippi

1234567890123456789012345678901234567890123456789012345678901234567890

14 Feb 2012: Developed from regular OO version
10 Feb 2016: Clean code and format, change obsolete Scala usage

This version uses case classes and implements the comparison methods on
the case class/object instances, but it implements most functionality
in a module of functions. See the extensive description in the
traditional object-oriented version.

Case classes generate the needed implementations of equals and
toString and simplify construction.

*/

/** Ordering comparisons derived from a single primitive `<`.
  *
  * Implementers supply only `<`; the other three operators are built from
  * `<` and universal equality (`==`), so they agree with the implementer's
  * `equals`.
  *
  * @tparam T the type this value is compared against
  */
trait Ord[T] {

  /** Strict less-than: the one comparison implementers must provide. */
  def <(that: T): Boolean

  /** Less-than-or-equal: strictly less, otherwise falls back to equality. */
  def <=(that: T): Boolean =
    if (this < that) true
    else this == that

  /** Greater-than: the negation of `<=`. */
  def >(that: T): Boolean = !(this <= that)

  /** Greater-than-or-equal: the negation of `<`. */
  def >=(that: T): Boolean = !(this < that)
}

/** A natural number in Peano form: Zero, Succ(n), or the error value Err.
  * Sealed so the compiler can check pattern matches over the three cases
  * for exhaustiveness (all subclasses live in this file).
  */
sealed abstract class Nat extends Ord[Nat]

/** The natural number zero, the base case of the Peano construction. */
case object Zero extends Nat {

  /** Zero is less than every successor and not less than itself.
    * Comparing against Err is an error (there is no ordering involving
    * Err), reported explicitly rather than as a MatchError.
    */
  def <(that: Nat): Boolean = that match {
    case Succ(_) => true
    case Zero    => false
    case Err     => sys.error(s"No ordering between $this and $that")
  }
}
    
/** The successor of the natural number `p` (that is, `p + 1`).
  *
  * @param p the predecessor; must be Zero or another Succ, never Err
  * @throws IllegalArgumentException if `p` is Err
  */
case class Succ(p: Nat) extends Nat {
  // Err is not a natural number, so it cannot have a successor.
  require(p != Err, "Succ cannot wrap Err")

  /** Compares by peeling one Succ off each side until one reaches Zero.
    * Comparing against Err is an error, reported explicitly rather than
    * as a MatchError.
    */
  def <(that: Nat): Boolean = that match {
    case Succ(q) => p < q
    case Zero    => false
    case Err     => sys.error(s"No ordering between $this and $that")
  }
}

/** Error value returned by operations whose result is not a natural
  * number (e.g. `pred(Zero)`, subtracting a larger number, converting a
  * negative Int). It has no ordering relative to any Nat.
  */
case object Err extends Nat {

  /** Always fails: Err cannot be ordered against anything. */
  def <(n: Nat): Boolean = {
    val message = s"No ordering between Err and $n"
    sys.error(message)
  }
}

/** Arithmetic and conversions on Nat, written as a module of functions
  * rather than as methods on the case classes.
  *
  * The arithmetic functions propagate Err: an Err operand, or a result that
  * would not be a natural number, yields Err.
  */
object Nat {

  /** Predecessor of `n`; Err when `n` is Zero (no natural predecessor) or Err. */
  def pred(n: Nat): Nat = n match {
    case Succ(p) => p
    case _       => Err
  }

  /** Sum `m + n`, using a pattern match on the pair of Nats.
    * Each step moves one Succ from `n` onto `m`; the recursion is a tail
    * call, so large operands do not grow the stack.
    */
  @scala.annotation.tailrec
  def add(m: Nat, n: Nat): Nat = (m, n) match {
    case (Succ(_), Succ(q)) => add(Succ(m), q)
    case (_, Zero)          => m
    case (Zero, _)          => n
    case (_, Err)           => Err
    case (Err, _)           => Err
  }

  /** Difference `m - n`; Err when `n > m` (negative result) or either
    * operand is Err. Strips one Succ from both sides per tail-recursive step.
    */
  @scala.annotation.tailrec
  def sub(m: Nat, n: Nat): Nat = (m, n) match {
    case (Succ(p), Succ(q)) => sub(p, q)
    case (_, Zero)          => m
    case (Zero, _)          => Err
    case (_, Err)           => Err
    case (Err, _)           => Err
  }

  /** Converts an Int to a Nat: `n` Succs around Zero for `n >= 0`, Err for
    * negative `n`. Built with an accumulator so large `n` does not exhaust
    * the call stack.
    */
  def toNat(n: Int): Nat = {
    @scala.annotation.tailrec
    def wrap(remaining: Int, acc: Nat): Nat =
      if (remaining == 0) acc else wrap(remaining - 1, Succ(acc))

    if (n < 0) Err else wrap(n, Zero)
  }

  /** Converts a Nat to an Int: the number of Succ layers, or -1 if the
    * value is (or wraps) Err. Counts with an accumulator so deep values do
    * not exhaust the call stack.
    */
  def toInt(n: Nat): Int = {
    @scala.annotation.tailrec
    def count(m: Nat, acc: Int): Int = m match {
      case Succ(p) => count(p, acc + 1)
      case Zero    => acc
      case Err     => -1
    }

    count(n, 0)
  }
}

import Nat._

object TestCaseExtNats {
  import scala.util.control.NonFatal

  /** Prints `label`, then evaluates `expr` and prints its value. If the
    * evaluation throws a non-fatal exception, prints `errPrefix` followed by
    * the exception's message instead. `expr` is by-name so it is evaluated
    * only after the label has been printed.
    */
  private def report(label: String, errPrefix: String = "Error:  ")(expr: => Any): Unit = {
    print(label)
    try { println(expr) }
    catch { case NonFatal(ex) => println(errPrefix + ex.getMessage) }
  }

  /** Test driver: prints the result of each Nat operation on sample values. */
  def main(args: Array[String]): Unit = {

    // Constants to use in tests
    val three = Succ(Succ(Succ(Zero)))
    val six   = Succ(Succ(Succ(three)))
    val big   = 100
    val bad   = -1

    // Test conversion from Int to Nat and testing toString
    println("toNat(0)             ==> " + toNat(0))
    println("toNat(5)             ==> " + toNat(5))
    println("toNat(big)           ==> " + toNat(big))
    println("toNat(bad)           ==> " + toNat(bad))

    // Test Zero methods

    println("Zero + Zero          ==> " + add(Zero,Zero))
    println("Zero + three         ==> " + add(Zero,three))
    println("Zero + Err           ==> " + add(Zero,Err))

    println("Zero - Zero          ==> " + sub(Zero,Zero))
    println("Zero - three         ==> " + sub(Zero,three))
    println("Zero - Err           ==> " + sub(Zero,Err))

    println("Zero == Zero         ==> " + (Zero == Zero))
    println("Zero == three        ==> " + (Zero == three))
    println("Zero == Err          ==> " + (Zero == Err))
    println("Zero == bad          ==> " + (Zero == bad))

    println("Zero < Zero          ==> " + (Zero < Zero))
    println("Zero < three         ==> " + (Zero < three))
    report("Zero < Err           ==> ")(Zero < Err)

    println("Zero <= Zero         ==> " + (Zero <= Zero))
    report("Zero <= Err          ==> ")(Zero <= Err)

    println("Zero > Zero          ==> " + (Zero > Zero))
    println("Zero > three         ==> " + (Zero > three))
    report("Zero > Err           ==> ")(Zero > Err)

    println("pred(Zero)           ==> " + pred(Zero))
    println("toInt(Zero)          ==> " + toInt(Zero))

    // Test Succ methods

    println("three + Zero         ==> " + add(three,Zero))
    println("three + three        ==> " + add(three,three))
    println("three + Err          ==> " + add(three,Err))

    println("three - Zero         ==> " + sub(three,Zero))
    println("three - three        ==> " + sub(three,three))
    println("three - six          ==> " + sub(three,six))
    println("three - Err          ==> " + sub(three,Err))

    println("three == Zero        ==> " + (three == Zero))
    println("three == three       ==> " + (three == three))
    println("three == six         ==> " + (three == six))

    println("three < Zero         ==> " + (three < Zero))
    println("three < three        ==> " + (three < three))
    println("three < six          ==> " + (three < six))
    println("six   < three        ==> " + (six < three))
    report("three < Err          ==> ")(three < Err)

    println("three <= Zero        ==> " + (three <= Zero))
    println("three <= three       ==> " + (three <= three))
    println("three <= six         ==> " + (three <= six))
    println("six   <= three       ==> " + (six <= three))
    report("three <= Err         ==> ")(three <= Err)

    println("three > Zero         ==> " + (three > Zero))
    println("three > three        ==> " + (three > three))
    println("three > six          ==> " + (three > six))
    println("six   > three        ==> " + (six > three))
    report("three > Err          ==> ")(three > Err)

    println("pred(three)          ==> " + pred(three))
    println("toInt(three)         ==> " + toInt(three))

    // Test Err methods

    println("Err + Zero           ==> " + add(Err,Zero))
    println("Err + three          ==> " + add(Err,three))
    println("Err + Err          ==> " + add(Err,Err))

    println("Err - Zero           ==> " + sub(Err,Zero))
    println("Err - three          ==> " + sub(Err,three))
    println("Err - Err            ==> " + sub(Err,Err))

    println("Err == Zero          ==> " + (Err == Zero))
    println("Err == three         ==> " + (Err == three))
    println("Err == Err           ==> " + (Err == Err))
    println("Err == bad           ==> " + (Err == bad))

    report("Err < three          ==> ")(Err < three)

    println("pred(Err)            ==> " + pred(Err))
    println("toInt(Err)           ==> " + toInt(Err))

    // Test precondition on Succ construction
    report("Succ(Err)          ==> ",
           "Error on Succ constructor precondition check:  ")(Succ(Err))
  }
}
