--[[
Arithmetic Expression Tree Program Skeleton
Recursive Function Version with List-style Nodes

H. Conrad Cunningham, Professor
Computer and Information Science
University of Mississippi

Developed for CSci 658, Software Language Engineering, Fall 2013

1234567890123456789012345678901234567890123456789012345678901234567890

2013-08-29: Modified program from author's Scala functional version
2013-09-02: Completed prototype
2013-09-04: Made similar to Recursive Function Record version
            added valToString as alternative to treeConcat output
2013-09-07: Corrected typos and comments
--]]

-- Function "treeConcat" takes a recursive list of lists in argument
-- "t" and returns the corresponding parenthesized traversal string.
-- A list is stored in the low positive integer indices of a
-- list-style Lua table. (This was borrowed from the author's KILT
-- Utilities Module.) A non-table argument is rendered with tostring.
-- Only indices 1..#t are visited, so "t" should have no nil holes.
local function treeConcat(t)
  if type(t) ~= "table" then
    return tostring(t)
  end
  local res = {}
  for i = 1, #t do
    res[i] = treeConcat(t[i])
  end
  return "(" .. table.concat(res, " ") .. ")"
end

--[[ ARITHMETIC EXPRESSION TREES

This program represents an arithmetic expression tree by a list-style
table whose first element is an operator tag and whose subsequent
elements are the operands for that operation. Each operand may be an
arithmetic expression tree.
--]]

-- Constants for tree node type tags
local SUM_TYPE,   SUM_STR   = "Sum",   "Sum"
local VAR_TYPE,   VAR_STR   = "Var",   "Var"
local CONST_TYPE, CONST_STR = "Const", "Const"

-- Constants for frequent constant expression trees. These are shared
-- singleton tables: "derive" returns these same references, so code
-- that receives them must not modify them.
local CONST_ZERO = {CONST_TYPE, 0}
local CONST_ONE  = {CONST_TYPE, 1}

-- Function "eval" evaluates expression tree "t" in environment
-- "env", a table mapping variable names to their values (added
-- together by Sum nodes). It checks the operator tag (first element
-- of "t") to determine what actions to take.
--
-- Raises an error, reported at the caller's position, when "t" is not
-- a table with at least two elements, when "env" is nil, when a Sum
-- node does not have exactly two operands, when a Var node names a
-- variable not bound in "env", or when the tag is unknown.
local function eval(t, env)
  -- Check the type before using "#": "#" on a number or nil raises an
  -- unhelpful "attempt to get length" error instead of this message.
  if type(t) ~= "table" or #t < 2 then
    error("eval called with an invalid expression tree: " ..
          treeConcat(t), 2)
  end
  if not env then
    error("eval called with a nil environment.", 2)
  end
  local tag = t[1]
  if tag == SUM_TYPE then         -- {SUM_TYPE,left,right}
    if #t == 3 then
      return eval(t[2], env) + eval(t[3], env)
    end
    error("eval called for type " .. SUM_STR ..
          " node that does not have exactly two operands: " ..
          treeConcat(t), 2)
  elseif tag == VAR_TYPE then     -- {VAR_TYPE,name}
    local value = env[t[2]]
    -- Without this check an unbound variable would surface later as an
    -- obscure "arithmetic on a nil value" error inside a Sum.
    if value == nil then
      error("eval called with unbound variable: " .. tostring(t[2]), 2)
    end
    return value
  elseif tag == CONST_TYPE then   -- {CONST_TYPE,num}
    return t[2]
  else
    error("eval called with unknown tree node type: " ..
          treeConcat(t), 2)
  end
end

-- Function "derive" takes an arithmetic expression tree "t" and a
-- variable name "v" (a string) and returns the derivative, another
-- arithmetic expression tree. Leaf results are the shared CONST_ZERO
-- and CONST_ONE tables, so the returned tree must not be mutated.
--
-- Raises an error, reported at the caller's position, when "t" is not
-- a table with at least two elements, when "v" is not a string, when
-- a Sum node does not have exactly two operands, or when the tag is
-- unknown.
local function derive(t, v)
  -- Same guard as in "eval": avoid "#" on non-table values.
  if type(t) ~= "table" or #t < 2 then
    error("derive called with an invalid expression tree: " ..
          treeConcat(t), 2)
  end
  if type(v) ~= "string" then
    error("derive called with nonstring variable: " ..
          treeConcat(v), 2)
  end
  local tag = t[1]
  if tag == SUM_TYPE then         -- {SUM_TYPE,left,right}
    if #t == 3 then
      return {SUM_TYPE, derive(t[2], v), derive(t[3], v)}
    end
    error("derive called for type " .. SUM_STR ..
          " node that does not have exactly two operands: " ..
          treeConcat(t), 2)
  elseif tag == VAR_TYPE then     -- {VAR_TYPE,name}
    if v == t[2] then
      return CONST_ONE
    end
    return CONST_ZERO
  elseif tag == CONST_TYPE then   -- {CONST_TYPE,num}
    return CONST_ZERO
  else
    error("derive called with unknown tree node type: " ..
          treeConcat(t), 2)
  end
end

-- Function "valToString" takes an arithmetic expression tree "t" and
-- returns a string representation of the expression tree, e.g.
-- Sum(Var(x),Const(7)).
--
-- Raises an error, reported at the caller's position, when "t" (or any
-- subtree reached) is not a table or has an unknown tag.
local function valToString(t)
  if type(t) ~= "table" then
    error("valToString called with invalid expression: " ..
          tostring(t), 2)
  end
  local tag = t[1]
  if tag == SUM_TYPE then
    return SUM_STR .. "(" .. valToString(t[2]) .. "," ..
           valToString(t[3]) .. ")"
  elseif tag == VAR_TYPE then
    -- tostring, as in the Const branch, so a non-string name cannot
    -- raise a concatenation error.
    return VAR_STR .. "(" .. tostring(t[2]) .. ")"
  elseif tag == CONST_TYPE then
    return CONST_STR .. "(" .. tostring(t[2]) .. ")"
  else
    error("valToString called with unknown tree type: " ..
          tostring(tag), 2)
  end
end

-- MAIN PROGRAM

local exp = {SUM_TYPE,
              {SUM_TYPE, {VAR_TYPE, "x"}, {VAR_TYPE, "x"}},
              {SUM_TYPE, {CONST_TYPE, 7}, {VAR_TYPE, "y"}}}
local env = {x = 5, y = 7}

-- print("Expression: " .. treeConcat(exp))
print("Expression: " .. valToString(exp))
print("Evaluation with x=5, y=7: " .. eval(exp, env))
print("Derivative relative to x:\n " .. valToString(derive(exp, "x")))
print("Derivative relative to y:\n " .. valToString(derive(exp, "y")))