2.2 Improving the Algorithm
Our multiply1 function works well as far as the number of additions is concerned, but it also does ⌊log n⌋ recursive calls. Since function calls are expensive, we want to transform the program to avoid this expense.
One principle we’re going to take advantage of is this: It is often easier to do more work rather than less. Specifically, we’re going to compute
$$r + na$$
where r is a running result that accumulates the partial products na. In other words, we’re going to perform multiply-accumulate rather than just multiply. This principle turns out to be true not only in programming but also in hardware design and in mathematics, where it’s often easier to prove a general result than a specific one.
Here’s our multiply-accumulate function:
    int mult_acc0(int r, int n, int a) {
        if (n == 1) return r + a;
        if (odd(n)) {
            return mult_acc0(r + a, half(n), a + a);
        } else {
            return mult_acc0(r, half(n), a + a);
        }
    }
It obeys the invariant $r + na = r_0 + n_0 a_0$, where $r_0$, $n_0$, and $a_0$ are the initial values of those variables.
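The invariant can be checked one step at a time. Writing $r'$, $n'$, $a'$ for the arguments passed to the recursive call, and using $\mathrm{half}(n) = (n-1)/2$ for odd $n$ and $n/2$ for even $n$, each branch of mult_acc0 leaves $r + na$ unchanged:

```latex
\begin{aligned}
n \text{ odd:}\quad  r' + n'a' &= (r + a) + \tfrac{n-1}{2}\,(2a) = r + a + (n-1)a = r + na,\\
n \text{ even:}\quad r' + n'a' &= r + \tfrac{n}{2}\,(2a) = r + na.
\end{aligned}
```

When $n = 1$ the function returns $r + a = r + na$, which by the invariant equals $r_0 + n_0 a_0$.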
We can improve this further by simplifying the recursion. Notice that the two recursive calls differ only in their first argument. Instead of having two recursive calls for the odd and even cases, we’ll just modify the value of r before we recurse, like this:
    int mult_acc1(int r, int n, int a) {
        if (n == 1) return r + a;
        if (odd(n)) r = r + a;
        return mult_acc1(r, half(n), a + a);
    }
Now our function is tail-recursive; that is, the recursive call is the last thing the function does, and its result is returned unchanged. We’ll take advantage of this fact shortly.
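Because the recursive call is the last thing mult_acc1 does, the triple (r, n, a) passed to each call is the entire state of the computation. The sketch below records those triples so the invariant can be watched directly. It assumes odd and half are the usual bit-test and shift helpers, and the names mult_acc1_traced and invariant_holds, along with the extra trace parameter, are hypothetical additions for illustration.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Assumed helper definitions: odd tests the low bit, half shifts right by one.
bool odd(int n) { return n & 0x1; }
int half(int n) { return n >> 1; }

// mult_acc1 from the text, plus a hypothetical `trace` parameter that
// records the (r, n, a) triple seen on entry to each call.
int mult_acc1_traced(int r, int n, int a, std::vector<std::array<int, 3>>& trace) {
    trace.push_back({r, n, a});
    if (n == 1) return r + a;
    if (odd(n)) r = r + a;
    return mult_acc1_traced(r, half(n), a + a, trace);
}

// Hypothetical checker: true if r + n * a is the same for every recorded triple.
bool invariant_holds(const std::vector<std::array<int, 3>>& trace) {
    if (trace.empty()) return true;
    const int expected = trace.front()[0] + trace.front()[1] * trace.front()[2];
    for (const auto& t : trace) {
        if (t[0] + t[1] * t[2] != expected) return false;
    }
    return true;
}
```

For mult_acc1_traced(0, 41, 59, trace) the recorded triples are (0, 41, 59), (59, 20, 118), (59, 10, 236), (59, 5, 472), (531, 2, 944), and (531, 1, 1888); each satisfies r + na = 2419, which is also the returned value.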
We make two observations:
- n is rarely 1.
- If n is even, there’s no point checking to see if it’s 1.
So we can reduce the number of times we have to compare with 1 by a factor of 2, simply by checking for oddness first:
    int mult_acc2(int r, int n, int a) {
        if (odd(n)) {
            r = r + a;
            if (n == 1) return r;
        }
        return mult_acc2(r, half(n), a + a);
    }
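To make the saving concrete, the sketch below counts comparisons with 1. mult_acc1 compares once per call, that is, once per binary digit of n; mult_acc2 compares only when n is odd, that is, once per 1 bit. The counter parameter cmp and the names mult_acc1_count and mult_acc2_count are hypothetical instrumentation, and odd and half are assumed to be the usual bit-test and shift helpers.

```cpp
#include <cassert>

// Assumed helper definitions.
bool odd(int n) { return n & 0x1; }
int half(int n) { return n >> 1; }

// Hypothetical instrumented mult_acc1: counts every comparison of n with 1.
int mult_acc1_count(int r, int n, int a, int& cmp) {
    ++cmp;
    if (n == 1) return r + a;
    if (odd(n)) r = r + a;
    return mult_acc1_count(r, half(n), a + a, cmp);
}

// Hypothetical instrumented mult_acc2: the comparison happens only for odd n.
int mult_acc2_count(int r, int n, int a, int& cmp) {
    if (odd(n)) {
        r = r + a;
        ++cmp;
        if (n == 1) return r;
    }
    return mult_acc2_count(r, half(n), a + a, cmp);
}
```

With n = 41 (binary 101001), mult_acc1 makes 6 comparisons and mult_acc2 makes 3. With n = 1023 (ten 1 bits) both make 10, so the actual saving depends on how many 0 bits n has; the factor of 2 holds on average when about half the bits are 1.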
Some programmers think that compiler optimizations will do these kinds of transformations for us, but that’s rarely true; they do not transform one algorithm into another.
What we have so far is pretty good, but we’re eventually going to want to eliminate the recursion to avoid the function call overhead. This is easier if the function is strictly tail-recursive.
Definition 2.1. A strictly tail-recursive procedure is one in which all the tail-recursive calls are done with the formal parameters of the procedure being the corresponding arguments.
Again, we can achieve this simply by assigning the desired values to the variables we’ll be passing before we do the recursion:
    int mult_acc3(int r, int n, int a) {
        if (odd(n)) {
            r = r + a;
            if (n == 1) return r;
        }
        n = half(n);
        a = a + a;
        return mult_acc3(r, n, a);
    }
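One way to see why strictness matters: when the recursive call passes the parameters themselves as arguments, the parameters already hold the values the next call would receive, so the call amounts to a jump back to the top of the body. The sketch below makes that literal with a goto. It is only an illustration of the idea, not the form the text uses; mult_acc3_goto is a hypothetical name, and odd and half are assumed to be the usual helpers.

```cpp
#include <cassert>

// Assumed helper definitions.
bool odd(int n) { return n & 0x1; }
int half(int n) { return n >> 1; }

// Hypothetical variant of mult_acc3: `return mult_acc3(r, n, a);` is replaced
// by a jump, which is valid only because the call passed r, n, a unchanged.
// Precondition: n >= 1.
int mult_acc3_goto(int r, int n, int a) {
top:
    if (odd(n)) {
        r = r + a;
        if (n == 1) return r;
    }
    n = half(n);
    a = a + a;
    goto top;
}
```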
Now it is easy to convert this to an iterative program by replacing the tail recursion with a while(true) construct:
    int mult_acc4(int r, int n, int a) {
        while (true) {
            if (odd(n)) {
                r = r + a;
                if (n == 1) return r;
            }
            n = half(n);
            a = a + a;
        }
    }
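A quick way to gain confidence in the iterative version is to compare it against r + n * a over a range of small inputs. The sketch below does that. Note that mult_acc4 requires n ≥ 1: with n = 0, odd(n) is never true and half(0) is 0, so the loop never ends. odd and half are assumed to be the usual helpers, and check_mult_acc4 is a hypothetical test helper.

```cpp
#include <cassert>

// Assumed helper definitions.
bool odd(int n) { return n & 0x1; }
int half(int n) { return n >> 1; }

// mult_acc4 from the text. Precondition: n >= 1 (n == 0 would loop forever).
int mult_acc4(int r, int n, int a) {
    while (true) {
        if (odd(n)) {
            r = r + a;
            if (n == 1) return r;
        }
        n = half(n);
        a = a + a;
    }
}

// Hypothetical helper: compares mult_acc4 with r + n * a for every r, a in
// [-limit, limit] and n in [1, limit]; keep limit small to avoid int overflow.
bool check_mult_acc4(int limit) {
    for (int r = -limit; r <= limit; ++r)
        for (int n = 1; n <= limit; ++n)
            for (int a = -limit; a <= limit; ++a)
                if (mult_acc4(r, n, a) != r + n * a) return false;
    return true;
}
```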
With our newly optimized multiply-accumulate function, we can write a new version of multiply that uses it as a helper:
    int multiply2(int n, int a) {
        if (n == 1) return a;
        return mult_acc4(a, n - 1, a);
    }
Notice that we skip one iteration of mult_acc4 by calling it with the result r already set to a.
This is pretty good, except when n is a power of 2. The first thing we do is subtract 1, which means that mult_acc4 will be called with a number whose binary representation is all 1s, the worst case for our algorithm. So we’ll avoid this by doing some of the work in advance when n is even, halving it (and doubling a) until n becomes odd:
    int multiply3(int n, int a) {
        while (!odd(n)) {
            a = a + a;
            n = half(n);
        }
        if (n == 1) return a;
        return mult_acc4(a, n - 1, a);
    }
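The difference shows up clearly for a power of 2. The sketch below counts the additions on r and a (not the subtraction n - 1 or the shifts) in instrumented copies of multiply2 and multiply3. The counter adds and the *_count names are hypothetical instrumentation, and odd and half are assumed to be the usual helpers.

```cpp
#include <cassert>

// Assumed helper definitions.
bool odd(int n) { return n & 0x1; }
int half(int n) { return n >> 1; }

// Hypothetical instrumented mult_acc4: counts additions on r and a.
int mult_acc4_count(int r, int n, int a, int& adds) {
    while (true) {
        if (odd(n)) {
            r = r + a; ++adds;
            if (n == 1) return r;
        }
        n = half(n);
        a = a + a; ++adds;
    }
}

// Hypothetical instrumented multiply2: n - 1 is passed straight through.
int multiply2_count(int n, int a, int& adds) {
    if (n == 1) return a;
    return mult_acc4_count(a, n - 1, a, adds);
}

// Hypothetical instrumented multiply3: halves n first while it is even.
int multiply3_count(int n, int a, int& adds) {
    while (!odd(n)) {
        a = a + a; ++adds;
        n = half(n);
    }
    if (n == 1) return a;
    return mult_acc4_count(a, n - 1, a, adds);
}
```

For n = 16 and a = 3, multiply2 hands mult_acc4 the value 15 (binary 1111) and performs 7 additions, while multiply3 reaches n = 1 after four doublings and performs 4; both return 48.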
But now we notice that we’re making mult_acc4 do one unnecessary test for n = 1, because we’re calling it with an even number. So we’ll do one halving and doubling on the arguments before we call it, giving us our final version:
    int multiply4(int n, int a) {
        while (!odd(n)) {
            a = a + a;
            n = half(n);
        }
        if (n == 1) return a;
        // even(n - 1) => n - 1 != 1
        return mult_acc4(a, half(n - 1), a + a);
    }
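For experimenting, here is a self-contained sketch that collects multiply4 and the mult_acc4 it relies on, assuming odd and half are the usual bit-test and shift helpers. Both functions require n ≥ 1. The comments spell out why the final call is safe: after the loop, n is odd and greater than 1, so n - 1 is even and at least 2, and half(n - 1) is therefore at least 1.

```cpp
#include <cassert>

// Assumed helper definitions.
bool odd(int n) { return n & 0x1; }
int half(int n) { return n >> 1; }

// Precondition: n >= 1.
int mult_acc4(int r, int n, int a) {
    while (true) {
        if (odd(n)) {
            r = r + a;
            if (n == 1) return r;
        }
        n = half(n);
        a = a + a;
    }
}

// Precondition: n >= 1.
int multiply4(int n, int a) {
    while (!odd(n)) {   // strip factors of 2 from n, doubling a each time
        a = a + a;
        n = half(n);
    }
    if (n == 1) return a;
    // n is odd and > 1, so n - 1 is even and >= 2; half(n - 1) >= 1
    // satisfies mult_acc4's precondition, and r already holds one copy of a.
    return mult_acc4(a, half(n - 1), a + a);
}
```

As a check on the arithmetic: multiply4(41, 59) calls mult_acc4(59, 20, 118), whose running result becomes 531 and then 2419.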