AlBlue’s Blog

Macs, Modularity and More

Introduction to Scala – functions

Howto 2007 Scala

In the previous post, I covered the basics of the Scala interpreter. This time, we'll cover how to do more exciting things than just the basics of arithmetic expressions.

Any programming language allows you to abstract away common behaviour into routines of some description. Functional programming languages represent everything as functions; imperative languages (like C) have a mixture of both functions and procedures with side-effects; object-oriented languages represent everything as classes and methods.

Scala uses a hybrid approach. It's possible to write pure functions (those that don't change state), procedures (those that do), and classes and methods. That is, it's possible to use whichever style of programming suits both you and the task at hand. We'll start off with basic functions and cover other styles in future posts.

Functions are defined with def, and for simple (one-line) functions an explicit return is not necessary. Functions are called by name, with any arguments passed in parentheses:

scala> def sum(a:Int,b:Int) = a+b
sum: (Int,Int)Int
scala> sum(1,2)
res0: Int = 3

We need to specify the type of the arguments, using the same notation as before. In this example, the arguments are positional and passed in by value. It becomes easy to write code for calculating the factorial (n*(n-1)*(n-2)*...*3*2*1) of a number:

scala> def factorial(n:Int):Int = if (n==0) 1 else n * factorial(n-1)
factorial: (Int)Int

In this case, factorial is recursive; it calls itself to calculate the value. We need to say that the return type of factorial is Int, and we do that in the same way as for other types in Scala: using : to denote the type.

Why didn't we need to do this for sum? Well, it turns out that the compiler can infer the type of a+b, since it knows what a and b are. It would have been just as correct if you specified the return type, but for brevity we didn't in this example. On the other hand, the factorial code calls itself, so its return type depends on its own return type (which doesn't tell the compiler much), and we need to give it an extra nudge. (Some type inference engines can figure out that since 1 is the terminal case, and 1 is an Int, the return type of factorial must be an Int as well; however, Scala's doesn't.)
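
As a minimal sketch of that difference, the definitions below put the inferred and explicit forms side by side. The object name InferenceSketch and the method sumExplicit are made up for illustration; only sum and factorial come from the examples above.

```scala
object InferenceSketch {
  // Non-recursive: the result type Int is inferred from a + b.
  def sum(a: Int, b: Int) = a + b

  // The same function with the result type written out; the behaviour is identical.
  def sumExplicit(a: Int, b: Int): Int = a + b

  // Recursive: the result type has to be written out, because the body
  // refers to factorial itself. Removing ": Int" here is a compile error.
  def factorial(n: Int): Int = if (n == 0) 1 else n * factorial(n - 1)
}
```

Calling InferenceSketch.sum(1, 2) and InferenceSketch.sumExplicit(1, 2) should give the same Int; the annotation only changes what the compiler has to work out for itself.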

If you want to spread the function over multiple lines, you can wrap the code in { and }. The value of a block of expressions is the value of its last expression; if you need to return out of a function early, use return:

scala> def factorial(n:Int):Int =
     | {
     |   if (n==0)
     |     return 1
     |   else
     |     return n * factorial(n-1)
     | }
factorial: (Int)Int

You might have noticed that there's little syntax in terms of punctuation in the Scala source. The ; delimiter allows multiple statements to be put on the same line; but it's rarely used in practice, simply because it's not needed. The indentation is conventional as well; spaces have no effect on syntax.
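
Here's a small sketch of both points: a block whose value is simply its last expression (so no return is needed), and the one place a ; earns its keep. BlockSketch and sumOfSquares are names invented for this example.

```scala
object BlockSketch {
  // The braces group the expressions; the value of the block is the value
  // of its last expression, so neither branch needs `return`.
  def factorial(n: Int): Int = {
    if (n == 0)
      1
    else
      n * factorial(n - 1)
  }

  // A semicolon is only needed when two statements share a line.
  def sumOfSquares(a: Int, b: Int): Int = {
    val x = a * a; val y = b * b
    x + y
  }
}
```

Both functions behave the same however the lines are indented; the layout is purely for the reader.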

There are a couple of problems with this naïve recursive implementation: firstly, for big numbers, the Int overflows; and secondly, for really big numbers, we'll get a StackOverflowError:

scala> factorial(31)
res1: Int = 738197504
scala> factorial(32)
res2: Int = -2147483648
scala> factorial(33)
res3: Int = -2147483648
scala> factorial(34)
res4: Int = 0
scala> factorial(1000000)
java.lang.StackOverflowError
        at .factorial(<console>:4)
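
The overflow actually starts well before 31. As a rough sketch (the names OverflowSketch, factorialInt, factorialLong and firstOverflow are all hypothetical), we can compare the Int result against the same calculation done with Long, which stays exact up to 20!, and look for the first argument where they disagree:

```scala
object OverflowSketch {
  // The naive Int version from above; it silently wraps around on overflow.
  def factorialInt(n: Int): Int = if (n == 0) 1 else n * factorialInt(n - 1)

  // The same recursion using Long, used here only as a wider reference value.
  def factorialLong(n: Int): Long = if (n == 0) 1L else n * factorialLong(n - 1)

  // Smallest n between 1 and 20 where the Int result differs from the Long one,
  // or -1 if there is none. Long is exact for this whole range.
  def firstOverflow: Int =
    (1 to 20).find(n => factorialInt(n).toLong != factorialLong(n)).getOrElse(-1)
}
```

OverflowSketch.firstOverflow comes out as 13: 12! is 479001600, which still fits in an Int, but 13! is 6227020800, which doesn't.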

There are a couple of things we can do. First, we can replace Int with BigInt, which allows arbitrary-precision arithmetic on integer values:

scala> def factorial(n:BigInt):BigInt = if (n==0) 1 else n * factorial(n-1)
factorial: (BigInt)BigInt
scala> factorial(31)
res5: BigInt = 8222838654177922817725562880000000
scala> factorial(32)
res6: BigInt = 263130836933693530167218012160000000
scala> factorial(1000000)
java.lang.StackOverflowError
        at .factorial(<console>:4)

We can solve the stack overflow by turning the function into a tail-recursive function. This is one in which the last operation is a call to itself. In our factorial example it's not, because the last operation is the multiplication; if we expand it, it looks like this:

factorial(5)
5*(factorial(4))
5*(4*(factorial(3)))
5*(4*(3*(factorial(2))))
5*(4*(3*(2*(factorial(1)))))

At each call, we end up with more and more *s in the expression, until finally Java gives up and we get the StackOverflowError. The way we fix this problem is to create a second argument to the function (sometimes known as an 'accumulator'); but you can think of it as the 'result':

scala> def factorial2(n:BigInt,result:BigInt):BigInt = if (n==0) result else factorial2(n-1,n*result)
factorial2: (BigInt,BigInt)BigInt
scala> factorial2(5,1)
res7: BigInt = 120
scala> factorial2(10000,1)
res8: BigInt = 2846.....

This is important for a couple of reasons. Firstly, what's happening under the covers is something called tail-call optimisation. That is, if the last call is a call to itself, then instead of creating a new stack frame, we simply replace the current evaluation with a different call, and this optimisation allows us to run in constant stack space. Using our factorial(5) example from earlier:

factorial2(5,1)
factorial2(4,5)
factorial2(3,20)
factorial2(2,60)
factorial2(1,120)

At each point the size of the stack space is constant, so we can perform bigger calculations than we would have been able to do otherwise. Secondly, Scala performs this optimisation itself for functions that call themselves in tail position, so it doesn't depend on the JVM. (Incidentally, there's a proposal to allow tail-call optimisation to be done at the Java bytecode/JVM level in the future, and a whole host of support for other languages including the invokedynamic bytecode.)
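
If you're on a later Scala version (2.8 or newer, which is an assumption beyond what this post uses), the scala.annotation.tailrec annotation asks the compiler to check that a function really is tail recursive, and to refuse to compile it if not. The sketch below pairs it with a hand-written loop that is roughly what the optimisation amounts to; TailCallSketch and factorialLoop are names made up for illustration.

```scala
import scala.annotation.tailrec

object TailCallSketch {
  // @tailrec makes compilation fail if the recursive call is not in tail
  // position, rather than leaving us to find out with a StackOverflowError.
  @tailrec
  def factorial2(n: BigInt, result: BigInt): BigInt =
    if (n == 0) result else factorial2(n - 1, n * result)

  // Roughly the loop the tail call is turned into: the two arguments
  // become variables that are updated in place, so the stack never grows.
  def factorialLoop(start: BigInt): BigInt = {
    var n = start
    var result = BigInt(1)
    while (n != 0) {
      result = n * result
      n = n - 1
    }
    result
  }
}
```

Adding @tailrec to the original n * factorial(n-1) version should be rejected by the compiler, since the multiplication happens after the recursive call returns.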

Mind you, the solution isn't perfect; you have to supply a starting value as the initial value of result. If you supply a wrong value, the calculation of factorial could end up being wrong. Ideally, we'd like to abstract that away from the caller of the function, whilst still allowing the function to be defined as tail recursive. In order to achieve this, we can create a function that bootstraps the recursive function:

scala> def factorial(n:BigInt):BigInt = factorial2(n,1)
factorial: (BigInt)BigInt
scala> factorial(10000)
res9: BigInt = 2846.....
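
Another way to hide the accumulator, sketched here as an alternative rather than as what the example above does, is to nest the tail-recursive helper inside the public function so callers can't reach it at all. NestedSketch and loop are hypothetical names.

```scala
object NestedSketch {
  def factorial(n: BigInt): BigInt = {
    // `loop` is local to factorial, so nobody outside can call it
    // with a wrong starting value for result.
    def loop(remaining: BigInt, result: BigInt): BigInt =
      if (remaining == 0) result else loop(remaining - 1, remaining * result)

    // The only place the initial accumulator value is chosen.
    loop(n, 1)
  }
}
```

NestedSketch.factorial(5) gives 120, the same as factorial2(5,1), but there's no second argument for the caller to get wrong.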

We've covered some basics of functions in this post, but there's more to come in future posts. However, next up we'll look at the way we can deal with some simple object types. Stay tuned.