Simple Memoization in Haskell

As a Haskell beginner, I took quite a while to figure out a nice solution to this problem. (Of course, since I’m a beginner, I can’t guarantee that better solutions don’t exist!)

What’s the problem? Well, Haskell is, of course, a purely functional language, meaning that:

  • It has no real state to speak of,
  • Everything that’s calculated is calculated by recursion.

See any contradictions there? The classic remedy for a recursive function that keeps recalculating the same values is memoization: save each result the first time it’s calculated and look it up afterwards. But memoization requires us to save state, something that’s unnatural and difficult in Haskell! So I’ll walk through a simple but general way to solve this problem.

Let’s use the classic example of the Fibonacci numbers, because I’m too lazy to come up with anything else and complexity would only get in the way. Note, however, that our solution works on recursive formulas of any number of variables.

fibs 0 = 1
fibs 1 = 1
fibs n = (fibs (n-1)) + (fibs (n-2))

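Looks great, but if you try “fibs 40” you’ll be waiting a very long time: every call spawns two more, so the same Fibonacci numbers get recalculated over and over and the running time grows exponentially with n. We’re going to fix this by storing our Fibonacci numbers in a Data.Array. The only downside is that we have to know ahead of time how large we want our Array to be.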
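To see how badly the naive version scales, we can count its calls. Evaluating fibs n for n ≥ 2 makes one call plus all the calls made by fibs (n-1) and fibs (n-2), so the count obeys nearly the same recurrence as fibs itself. The sketch below (naiveCalls is a name I made up) computes that count by stepping along pairs of consecutive values, rather than by actually running the slow recursion.

```haskell
-- Number of calls the naive fibs makes to evaluate fibs n:
--   calls 0 = 1, calls 1 = 1, calls n = 1 + calls (n-1) + calls (n-2)
naiveCalls :: Int -> Integer
naiveCalls n = fst (iterate step (1, 1) !! n)
  where
    -- step turns (calls k, calls (k+1)) into (calls (k+1), calls (k+2))
    step (a, b) = (b, 1 + a + b)

main :: IO ()
main = mapM_ (\n -> putStrLn (show n ++ ": " ++ show (naiveCalls n))) [10, 20, 30, 40]
```

naiveCalls 40 comes out to 331160281, so the naive definition makes roughly a third of a billion calls to produce that one number.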

OK, first we need to know how to make an Array. For more info on Arrays, check out the Arrays section of A Gentle Introduction to Haskell. For now, we’re going to use the mkArray function presented there:

import Data.Array
mkArray :: (Ix a) => (a -> b) -> (a,a) -> Array a b
mkArray f bnds = array bnds [(i, f i) | i <- range bnds]

-- make a one-dimensional array indexed from one to ten,
-- element at i is equal to i
arr1 :: Array Int Int
arr1 = mkArray id (1,10)

-- make a 2d array from (0,0) to (10,10),
-- entry (i,j) is equal to i+j
arr2 :: Array (Int,Int) Int
arr2 = mkArray (\(i,j) -> i + j) ((0,0),(10,10))

Good. Now, how do we apply this to fibs in a general way? Well, first we’ll rewrite the body of our function as a “let” construct:

fibs x =
       let f 0 = 1
           f 1 = 1
           -- recurse through fibs, not f, so fibs can later become an array lookup
           f n = (fibs (n-1)) + (fibs (n-2))
        in f x

Now we’re going to move that “let” construct out of fibs and use it to build our array instead, leaving fibs as a simple array lookup.

fibarr = mkArray func (0,10000)
    -- func is still to be defined; it will wrap the "let" construct from above

fibs x = fibarr!(x)

Here we see how to access our array: fibarr!(x) is the element stored at index x. More generally, if we want to access an n-dimensional array at a1,…,an, we index it with a tuple: myarray!(a1,a2,…,an).
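Putting that indexing rule into practice, here’s a small self-contained sketch. It rebuilds arrays like arr1 and arr2 from above (assuming Int indices) plus a hypothetical three-dimensional arr3, and reads one element from each with !.

```haskell
import Data.Array

-- mkArray as presented in A Gentle Introduction to Haskell
mkArray :: (Ix a) => (a -> b) -> (a, a) -> Array a b
mkArray f bnds = array bnds [(i, f i) | i <- range bnds]

-- One variable: the index is a plain Int.
arr1 :: Array Int Int
arr1 = mkArray id (1, 10)

-- Two variables: the index is a pair.
arr2 :: Array (Int, Int) Int
arr2 = mkArray (\(i, j) -> i + j) ((0, 0), (10, 10))

-- Three variables: the index is a triple (Ix has an instance for triples too).
arr3 :: Array (Int, Int, Int) Int
arr3 = mkArray (\(i, j, k) -> i * j * k) ((0, 0, 0), (3, 3, 3))

main :: IO ()
main = do
  print (arr1 ! 4)          -- element at index 4
  print (arr2 ! (2, 3))     -- element at (2, 3)
  print (arr3 ! (1, 2, 3))  -- element at (1, 2, 3)
```

This prints 4, 5 and 6.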

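OK, the last step is to connect the “func” used in building the array to our “let” construct. We’ll do this simply by defining func. The reason this works is that Data.Array is lazy in its elements: each entry starts out as an unevaluated thunk and is calculated at most once, the first time it’s looked up, so every fibs n is computed only once.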

fibarr = mkArray func (0,10000)
    where func elem =
            let f 0 = 1
                f 1 = 1
                f n = (fibs (n-1)) + (fibs (n-2))
             in f elem

fibs x = fibarr!(x)
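The whole trick depends on Data.Array keeping its elements unevaluated until they’re needed. A quick way to convince yourself of that (a sketch; lazyArr is just an illustrative name) is to store undefined in an array and never look at it:

```haskell
import Data.Array

-- The middle element would crash the program if it were ever evaluated.
-- Data.Array stores elements lazily, so building the array and reading
-- the other two entries never touches it.
lazyArr :: Array Int Integer
lazyArr = listArray (0, 2) [1, undefined, 3]

main :: IO ()
main = print (lazyArr ! 0 + lazyArr ! 2)
```

This prints 4 without crashing; only reading lazyArr ! 1 would hit the undefined. In fibarr, the same laziness means each entry stays a thunk until fibs first asks for it, and from then on the stored value is reused.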

Note here that elem is a tuple in the more general case. With three or more variables, fst and snd (which work only on pairs) can’t reach every component, so it’s simplest to pattern-match on the tuple directly, as in func (x1,x2,x3). But here’s how the construct looks in the 2-dimensional case:

-- l1,l2 are the lower bounds and u1,u2 the upper bounds
myarr = mkArray func ((l1,l2),(u1,u2))
    where func elem =
             let f x1 x2 = ...
                 ...
                 ...
              in f (fst elem) (snd elem)

origfunc x1 x2 = myarr!(x1,x2)
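To make the 2-dimensional template concrete, here’s a sketch that fills it in for binomial coefficients using Pascal’s rule, C(n,k) = C(n-1,k-1) + C(n-1,k). The names binomArr and choose, and the bound of 100, are my own illustrative choices; choose plays the role of origfunc.

```haskell
import Data.Array

mkArray :: (Ix a) => (a -> b) -> (a, a) -> Array a b
mkArray f bnds = array bnds [(i, f i) | i <- range bnds]

-- Memo table indexed by (n, k), with 0 <= n, k <= 100.
binomArr :: Array (Int, Int) Integer
binomArr = mkArray func ((0, 0), (100, 100))
  where
    func elem =
      let f n k
            | k > n = 0                 -- can't choose more than n items
            | k == 0 || k == n = 1      -- edges of Pascal's triangle
            | otherwise = choose (n - 1) (k - 1) + choose (n - 1) k
       in f (fst elem) (snd elem)

-- Every call, including the recursive ones in f, goes through the array.
choose :: Int -> Int -> Integer
choose n k = binomArr ! (n, k)

main :: IO ()
main = do
  print (choose 5 2)
  print (choose 100 50)
```

choose 5 2 gives 10. As with fibs, asking for an index outside the bounds is a runtime error.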

And here’s our complete code for fibs:

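import Data.Array

-- build an array by applying f to every index within the bounds
mkArray :: (Ix a) => (a -> b) -> (a,a) -> Array a b
mkArray f bnds = array bnds [(i, f i) | i <- range bnds]

-- memo table: entry n holds fibs n, calculated on its first lookup
fibarr :: Array Int Integer
fibarr = mkArray func (0,10000)
    where func elem =
            let f 0 = 1
                f 1 = 1
                f n = (fibs (n-1)) + (fibs (n-2))
             in f elem

-- every call, including the recursive ones above, is an array lookup;
-- indices outside 0..10000 are a runtime error
fibs :: Int -> Integer
fibs x = fibarr!(x)

main :: IO ()
main = print (fibs 10000)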
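If you find yourself repeating this recipe, the array-building part can be factored into one helper. The sketch below assumes a hypothetical helper named memoize (it is not part of Data.Array): you write the recursion with the recursive call passed in as an argument, and memoize ties the knot by passing in the array lookup. The gridPaths example, which counts the monotone lattice paths from (i, j) down to an edge of the grid, is likewise just an illustration.

```haskell
import Data.Array

-- Hypothetical helper: memoizes an "open" recursive function over the given
-- bounds. openF receives the memoized function as its first argument and
-- must use it for every recursive call.
memoize :: Ix i => (i, i) -> ((i -> a) -> i -> a) -> i -> a
memoize bnds openF = memo
  where
    table = array bnds [(i, openF memo i) | i <- range bnds]
    memo i = table ! i

-- The fibs recursion from above, with the recursive call abstracted out.
fibsOpen :: (Int -> Integer) -> Int -> Integer
fibsOpen _ 0 = 1
fibsOpen _ 1 = 1
fibsOpen self n = self (n - 1) + self (n - 2)

fibsMemo :: Int -> Integer
fibsMemo = memoize (0, 10000) fibsOpen

-- Two variables work the same way, with a pair as the index.
gridPathsOpen :: ((Int, Int) -> Integer) -> (Int, Int) -> Integer
gridPathsOpen self (i, j)
  | i == 0 || j == 0 = 1
  | otherwise = self (i - 1, j) + self (i, j - 1)

gridPaths :: (Int, Int) -> Integer
gridPaths = memoize ((0, 0), (50, 50)) gridPathsOpen

main :: IO ()
main = do
  print (fibsMemo 90)
  print (gridPaths (10, 10))
```

The key design choice is that fibsOpen never calls itself directly; every recursive call goes through self, which memoize points at the array, just as fibs did above. Because fibsMemo and gridPaths are top-level values, each table is built once and shared across calls.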