/* Exploring Languages with Interpreters and Functional Programming
   Chapter 9: Recursion Styles and Efficiency -- Exponentiation
   Copyright (C) 2018, H. Conrad Cunningham

1234567890123456789012345678901234567890123456789012345678901234567890

2016-02-10: Developed from 2015 Elixir version (for CSci 555)
2016-07-18: Modified to use BigInt instead of Double
            Also to remove b argument on exptIter in expt3
2018-07-04: Updated comments for 2018 textbook

I adapted this exponentiation module from an Elixir version, which
was, in turn, adapted from a Lua version, which was, in turn, adapted
from the Scheme code in section 1.2.4 of Abelson and Sussman's
Structure and Interpretation of Computer Programs (SICP).

I have only tested this minimally.

Scala and functional programming highlights:
- Use match expressions on integers
- Use nested function definitions
- Call sys.error for error exits
- Use accumulating parameters
- Use tail recursion
- Note complexity measures

*/

object Expt {

  /* Function "expt1" computes b^n (b raised to power n) using
     backward linear recursion: the multiplication by b is done
     after the recursive call returns.

     Parameters: b -- the base
                 n -- the exponent; must be nonnegative
     Returns:    b^n (expt1(b,0) is 1 for every b)
     Errors:     calls sys.error (throws a RuntimeException) if n < 0

     Time complexity:  O(n) recursive calls
     Space complexity: O(n) active recursive calls
  */

  def expt1(b: BigInt, n: Int): BigInt = n match {
    case 0          => 1
    case m if m > 0 => b * expt1(b,m-1)
    case _          =>
      sys.error("Cannot raise to a negative power " + n)
  }

  /* Function "expt2" computes b^n using tail recursion with an
     accumulating parameter p that holds the partial product.

     Parameters: b -- the base
                 n -- the exponent; must be nonnegative
     Returns:    b^n (expt2(b,0) is 1 for every b)
     Errors:     calls sys.error (throws a RuntimeException) if n < 0

     Time complexity:  O(n) recursive calls
     Space complexity: O(1); @tailrec guarantees the compiler turns
                       the auxiliary into a loop
  */

  def expt2(b: BigInt, n: Int): BigInt = {

    // Tail recursive auxiliary, called only if n >= 0.
    // Invariant: p * b^n1 == b^n.  The annotation makes the compiler
    // reject any edit that breaks the tail call (and the O(1) space).
    @scala.annotation.tailrec
    def exptIter(n1: Int, p: BigInt): BigInt = n1 match {
      case 0 => p
      case m => exptIter(m-1,b*p)
    }

    if (n >= 0)
      exptIter(n,1)
    else
      sys.error("Cannot raise to negative power " + n )
  }

  /* Function "expt3" computes b^n using a logarithmic algorithm and
     backward recursion.  It takes advantage of squaring:

     b^n = (b^(n/2)) * (b^(n/2))   if n even
     b^n = b * b^(n-1)             if n odd

     Every odd step is followed by an even step, so the exponent is
     at least halved every two calls.

     Parameters: b -- the base
                 n -- the exponent; must be nonnegative
     Returns:    b^n (expt3(b,0) is 1 for every b)
     Errors:     calls sys.error (throws a RuntimeException) if n < 0

     Time complexity:  O(log n) recursive calls
     Space complexity: O(log n) active recursive calls needing
                                stack frames
  */

  def expt3(b: BigInt, n: Int): BigInt = {

    // Backward recursive auxiliary (NOT tail recursive: each result is
    // squared or multiplied by b after the recursive call returns),
    // called only if n >= 0.  The base b comes from the enclosing scope.
    def exptIter(n1: Int): BigInt = n1 match {
      case 0                 => 1
      case m if (m % 2 == 0) => // i.e. even
        val exp = exptIter(m/2) // compute b^(m/2) once, then square it
        exp * exp               // backward recursion
      case m                 => // i.e. odd
        b * exptIter(m-1)       // backward recursion
    }

    if (n >= 0)
      exptIter(n)
    else
      sys.error("Cannot raise to negative power " + n )
  }

  /* Function "main" prints 2^10 as computed by each of the three
     functions, as a quick smoke test.  Command-line arguments are
     ignored.
  */

  def main(args: Array[String]): Unit = {
    println("expt1(2,10) = " + expt1(2,10))
    println("expt2(2,10) = " + expt2(2,10))
    println("expt3(2,10) = " + expt3(2,10))
  }

}

