In the first half of this article we discussed convolution, using FFT for convolution, and the difficulties associated with using floating-point numbers when exact integer output is required. This second half of the article is about the Schonhage-Strassen algorithm for fast integer multiplication. It multiplies $N$-bit integers in $O(N \log N \log \log N)$ time via a modular-arithmetic-based FFT, known as the number theoretic transform (NTT). Schonhage-Strassen is no longer asymptotically the fastest, but it held the title for over thirty years, and its ideas form the basis for newer algorithms that slightly reduce the runtime.
Key to Schonhage-Strassen is the NTT, which is a modified DFT. The DFT basically multiplies the original signal by a $b$th root of unity raised to various powers, where $b$ is the length of the signal. A standard DFT uses the root $e^{-2\pi i / b}$, but as shown in the appendix of this article, any $b$th root of unity works, where a $b$th root of unity is defined as an $\omega$ such that $\omega^k \neq 1$ for $0 < k < b$ and $\omega^b = 1$. The NTT uses roots of unity that arise from modular arithmetic. For example, 2 is a 5th root of unity modulo 31, because $2^5 = 32 \equiv 1 \pmod{31}$, but $2^1, 2^2, 2^3, 2^4 \not\equiv 1 \pmod{31}$. The NTT returns the results of the convolution modulo $M$, where $M$ is the modulus, but otherwise isn't that different from a standard DFT:
- FFT algorithms such as Cooley-Tukey also work on NTT, so we can compute a length-$b$ NTT in $O(b \log b)$ time
- As shown in the appendix, pointwise multiplication in the NTT domain is equivalent to convolving the original signal
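As a small sanity check of the two points above, here is a sketch in Python. The function names (`is_primitive_root`, `ntt`, `inverse_ntt`, `cyclic_convolution`) are made up for this illustration, the transform is the naive quadratic one rather than an FFT, and the modular inverse via `pow(x, -1, m)` assumes Python 3.8 or newer. It reuses the example of 2 as a 5th root of unity modulo 31.

```python
def is_primitive_root(omega, order, modulus):
    """True if omega^order == 1 (mod modulus) and no smaller positive power is 1."""
    if pow(omega, order, modulus) != 1:
        return False
    return all(pow(omega, k, modulus) != 1 for k in range(1, order))


def ntt(signal, omega, modulus):
    """Naive O(b^2) transform: X[k] = sum_j signal[j] * omega^(j*k) mod modulus."""
    b = len(signal)
    return [
        sum(signal[j] * pow(omega, j * k, modulus) for j in range(b)) % modulus
        for k in range(b)
    ]


def inverse_ntt(spectrum, omega, modulus):
    """Same formula with omega^-1, then every entry is scaled by b^-1."""
    b = len(spectrum)
    omega_inv = pow(omega, -1, modulus)
    b_inv = pow(b, -1, modulus)
    return [value * b_inv % modulus for value in ntt(spectrum, omega_inv, modulus)]


def cyclic_convolution(x, y, omega, modulus):
    """Transform both signals, multiply pointwise, transform back."""
    fx, fy = ntt(x, omega, modulus), ntt(y, omega, modulus)
    product = [a * c % modulus for a, c in zip(fx, fy)]
    return inverse_ntt(product, omega, modulus)


x = [1, 2, 3, 0, 0]
y = [4, 5, 0, 0, 0]
result = cyclic_convolution(x, y, omega=2, modulus=31)
print(is_primitive_root(2, 5, 31), result)  # True [4, 13, 22, 15, 0]
```

Every entry of the true convolution is below 31 here, so the result modulo 31 is also the exact result. With larger inputs the entries would wrap around the modulus.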
Together these two things allow us to use NTT to do convolution. There are two big questions left to answer:
- What $\omega$ modulo what $M$ do we want to use? Unlike with complex numbers, there are many $b$th roots of unity, depending on the $M$ we choose. For example, $2^k$ is a tenth root of unity modulo $2^{5k} + 1$ for all integers $k \geq 1$. Are certain $M$ better than others?
- How should we split the integer? Given an $N$-bit integer, should it be represented as a length-$N$ signal of bits, a length-$N/2$ signal of pairs of bits, etc.?
These two questions are somewhat related because we want to choose an $M$ large enough such that each element of the convolution is within $[0, M)$; otherwise, it'll overflow modulo $M$ and be wrong. The max value of each element in the convolution depends on how the integer is split, so these things must be decided together.
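Suppose we split our $N$-bit integer into $b$ blocks, each of length $n$ (so $N = bn$). We now want to calculate how large $M$ needs to be to avoid overflow. By the definition of a cyclic convolution,

$$(x * y)[k] = \sum_{i=0}^{b-1} x[i] \, y[(k - i) \bmod b]$$

Since each element of $x$ and $y$ is an $n$-bit integer, its value is at most $2^n - 1$. So:

$$(x * y)[k] \leq b \, (2^n - 1)^2 < b \, 2^{2n}$$

We therefore need $M \geq b \, 2^{2n}$, or equivalently $\log M \geq 2n + \log b$.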
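To make the bound concrete, here is a tiny Python sketch. The helper names are hypothetical. It computes the largest possible convolution element for a given split and the number of bits needed to hold it.

```python
def max_convolution_element(b, n):
    """Largest value a cyclic convolution entry can take when both signals
    have b entries and every entry is an n-bit integer (at most 2^n - 1)."""
    return b * ((1 << n) - 1) ** 2


def bits_needed(b, n):
    """Number of bits required to store the worst-case convolution entry."""
    return max_convolution_element(b, n).bit_length()


# b = 4 blocks of n = 3 bits: worst case is 4 * 7^2 = 196, which needs 8 bits,
# matching 2n + log2(b) = 6 + 2.
print(max_convolution_element(4, 3), bits_needed(4, 3))  # 196 8
```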
Don't recurse too much
TLDR: FFT-based multiplication is recursive because no matter how we split our $N$-bit integer into a signal, our convolution result's elements are of unbounded size. We can't avoid the recursive calls from multiplying signals in the frequency domain. But with a modulus of the form $M = 2^k \pm 1$ (for any $k$) and $\omega$ a power of two, we can replace what would've been additional recursive multiplies with bit-shifts, adds, and subtracts. This reduction in recursion is key to good asymptotic performance.
Reproduced below is the pseudocode from part 1 of this article for computing $x * y$, where $x$ and $y$ are two signals of length $b$:
- Use FFT to compute $\mathrm{DFT}(x)$ and $\mathrm{DFT}(y)$. Time: $O(b \log b)$
- By the convolution theorem, $\mathrm{DFT}(x * y) = \mathrm{DFT}(x) \cdot \mathrm{DFT}(y)$, so we pointwise multiply our results from step 1. Time: $O(b)$
- Use FFT to apply the inverse DFT to our result from the previous step. Time: $O(b \log b)$
You might be wondering why Schonhage-Strassen is $O(N \log N \log \log N)$ when our pseudocode above seems to suggest we only need $O(b \log b)$ time. The problem is that the above costs are in terms of the number of arithmetic operations needed, and arithmetic isn't $O(1)$ because we are dealing with arbitrarily large integers. We need $\log M$ bits to store each element of the FFT result, and as shown above, $M \geq b \, 2^{2n}$ so $\log M \geq 2n + \log b$. Since $N = bn$, no matter how we set $b$ and $n$, $\log M$ will grow unboundedly as $N$ grows. This means step 2 of our pseudocode involves multiplying integers of unbounded size--time for recursion! (Our recursion ends up being $O(\log \log N)$ deep, and each recursion level takes $O(N \log N)$ time, giving our final runtime.)
There's more, though--FFT requires us to multiply elements of our input signal by $\omega$ raised to some power. For more detail, see the pseudocode for FFT-based NTT below:
```python
# Cooley-Tukey FFT over the integers modulo M.
# Assumes b is a power of 2 and omega is a b-th root of unity modulo M.
def fft(signal, b, omega, M):
    if b == 1:
        return [signal[0]]
    signal_even = signal[0::2]  # signal[0], signal[2], ..., signal[b - 2]
    signal_odd = signal[1::2]   # signal[1], signal[3], ..., signal[b - 1]
    fft_even = fft(signal_even, b // 2, omega * omega % M, M)
    fft_odd = fft(signal_odd, b // 2, omega * omega % M, M)
    res = [0] * b
    for k in range(b // 2):
        # the cost analysis assumes all omega^k are precomputed, so this power is free
        prod = pow(omega, k, M) * fft_odd[k] % M
        res[k] = (fft_even[k] + prod) % M
        res[k + b // 2] = (fft_even[k] - prod) % M
    return res
```
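Most operations are fine; adding `prod`, copying an element of `signal_even`, etc. all take $O(\log M)$ time. The cost of all these operations over the entire FFT would be $O(b \log b \log M)$, which is no big deal. But calculating `prod` involves the multiplication of two $\log M$-bit integers, which means more recursion. If we let $T(N)$ be the cost of multiplying two $N$-bit integers, we get the recurrence

$$T(N) = O(b \log b) \cdot T(\log M) + O(b \log b \log M)$$

which solves to something considerably worse than the $O(N \log N \log \log N)$ we are aiming for; not very good.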
The FFT inherently involves multiplication, but the actual recursive calls to our multiplication routine can be avoided with a trick. Here's the analogy for humans: it's easy to compute 1019068913 * 100000, but hard to compute 1019068913 * 193126 even though 100000 and 193126 are the same length. Similarly, it's easy for us to compute numbers modulo powers of 10, but hard modulo other integers. For computers, the same applies, but with powers of two since they work in binary. For example, multiplication by a power of two becomes a left shift. In general, integers with sparse representations in binary ($2^k \pm 1$ for some $k$) support efficient multiplication and remainder. Crucial to Schonhage-Strassen is that 2 is a $2k$th root of unity modulo $2^k + 1$, and a $k$th root of unity modulo $2^k - 1$. We won't go into the details, but computing integers modulo $2^k \pm 1$ can be done in linear time (as can multiplying by a power of two).
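The paragraph above skips the details, so here is a rough Python sketch of the idea for the modulus $2^k + 1$. The function names are hypothetical and this is only one of several ways to do the reduction. Because $2^k \equiv -1$, the $k$-bit chunks of a number can simply be added and subtracted alternately, so no division or general multiplication is needed.

```python
def reduce_mod_2k_plus_1(value, k):
    """Reduce a non-negative integer modulo 2^k + 1 with shifts, masks, adds, subtracts."""
    modulus = (1 << k) + 1
    mask = (1 << k) - 1
    total = 0
    add_next = True
    while value:
        chunk = value & mask          # lowest k bits
        # 2^k == -1 (mod 2^k + 1), so successive k-bit chunks alternate in sign
        total = total + chunk if add_next else total - chunk
        value >>= k
        add_next = not add_next
    while total < 0:                  # a few corrections bring us into [0, modulus)
        total += modulus
    while total >= modulus:
        total -= modulus
    return total


def times_power_of_two(value, exponent, k):
    """Compute value * 2^exponent modulo 2^k + 1: a left shift, then the cheap reduction."""
    return reduce_mod_2k_plus_1(value << exponent, k)


k = 8                                  # modulus 2^8 + 1 = 257
print(times_power_of_two(200, 5, k))   # 232, the same as (200 * 32) % 257
print(pow(2, 2 * k, 257), pow(2, k, 257))  # 1 256, so 2^16 == 1 and 2^8 == -1
```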
TLDR: We let and and solve our recurrence with as a parameter. With we get ; at we get ; at we get .
With our choices of and , we're left with the recurrence
and the task of finding the optimal and . Let's assume for now that and . We now find an approximation for . Earlier we found that we need , so let's set .
as long as the second term will dominate, so for now we'll say . This simplifies our recurrence to:
There is no systematic way of solving this kind of recurrence (at least not that I know of), so we just have to expand it and compute the sum. By how big-O notation is defined, there is some constant such that . We start expanding and get:
Things quickly get ugly, but notice that the recursive calls are always of the form . We can derive that
To simplify our notation, let be shorthand for . Suppose we expand our recursion times and get . Once we expand an additional level we'll have the following tuple:
We'll also add a term to our runtime. Our total runtime should be the sum of this term over all depths of the recursion:
where denotes the total depth; we'll compute later. We start by calling so our initial tuple is . Note that for all , and . Combined with the base case, this allows us to say and for all . Our sum therefore simplifies to
It's also trivial to derive that . Finally, is upper bounded by ; it starts at zero and for any , we have . Now we can further simplify the summation to:
These expressions for and are also useful for finding the depth of our recursion. We compute by solving for when our input into , which is , becomes :
Plugging this back into our sum, simplifies to . Meanwhile 's value depends on the size of . If , the sum simplifies to . If , we get , and if we get . Overall we get:
$O(N \log N \log \log N)$ is the best we can do
Unfortunately, our choices for $M$ and $\omega$ make the $O(N \log N)$ case impossible. The underlying problem is that 2 is a very sparse root of unity. For a modulus of $k$ bits, we're only getting an $O(k)$th root of unity. In contrast, a prime of length $k$ bits is guaranteed to have a $\Theta(2^k)$th root of unity (proof here).
The sparsity of our root of unity intuitively makes sense because among all numbers in $[0, M)$ we'd expect "nice" numbers, which are easy to multiply by, to be very rare. We need these nice numbers to avoid recursive multiplies in our FFT routine that would otherwise ruin our recursion. These numbers' rarity gets us in trouble though, because for the faster case we need $b$ to grow faster than $n$, which means more than an $O(n)$th root of unity from an $O(n)$-bit modulus. For example, for $b = N^{2/3}$ and $n = N^{1/3}$, we need an $N^{2/3}$th root of unity but only have $O(N^{1/3})$ bits; this is equivalent to trying to get an $n^2$th root of unity from $O(n)$ bits. It seems like this is impossible while constraining ourselves to "nice" $M$ and $\omega$ which allow the FFT to run in $O(b \log b \log M)$ time, so the best we can achieve is $O(N \log N \log \log N)$ with $b = \sqrt{N}$ and $n = \sqrt{N}$.
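For readers who want the recurrence analysis from the previous section in one place, here is a condensed sketch. The notation is an assumption made for this sketch and may not match the symbols used in the derivation above: the split is parameterised as $b = N^{\alpha}$ and $n = N^{1-\alpha}$ for a fixed $0 < \alpha < 1$, $\log M \approx 2n$, $s_i$ is the size of one subproblem at depth $i$, $c$ is the big-O constant, and $d$ is the recursion depth.

$$
\begin{aligned}
T(N) &\leq N^{\alpha} \, T\!\left(2 N^{1-\alpha}\right) + c N \log N \\
s_0 &= N, \qquad s_{i+1} = 2 \, s_i^{\,1-\alpha} \\
(\text{number of subproblems at depth } i) \cdot s_i &= 2^i N \\
\log s_i &= (1-\alpha)^i \log N + O(1) \\
T(N) &\leq \sum_{i=0}^{d} c \, 2^i N \log s_i = c N \log N \sum_{i=0}^{d} \bigl(2(1-\alpha)\bigr)^i + O\!\left(2^d N\right), \qquad d \approx \log_{1/(1-\alpha)} \log N
\end{aligned}
$$

Under these assumptions the geometric sum explains the three regimes. With $\alpha = 1/2$ every term equals 1 and $d \approx \log_2 \log N$, giving $O(N \log N \log \log N)$, while the trailing term is only $O(N \log N)$. With $\alpha > 1/2$ the ratio $2(1-\alpha)$ is below 1 and the sum is $O(1)$, giving $O(N \log N)$. With $\alpha < 1/2$ the ratio exceeds 1 and the total grows like $N \log^{c'} N$ for some constant $c' > 1$ depending on $\alpha$.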
We now have the basic idea of Schonhage-Strassen down. Suppose we want to multiply two $N$-bit integers. We:
- Split each input integer into $\sqrt{N}$ pieces of $\sqrt{N}$ bits each
- Perform an NTT on each, using and (the original paper uses , but works as well)
- Recurse to compute the pointwise product of the two NTTs
- Perform an inverse NTT (with same and )
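The four steps above can be tried out end to end with a toy Python sketch. To keep it short it swaps in several simplifications that are not part of Schonhage-Strassen: the modulus is the prime 998244353 (with primitive root 3) rather than a number of the form $2^k \pm 1$, the transform is the naive quadratic one, the chunk size is fixed at 8 bits instead of $\sqrt{N}$, and the pointwise products use Python's built-in multiplication instead of recursing. All names are hypothetical, and `pow(x, -1, m)` assumes Python 3.8 or newer.

```python
P = 998244353   # prime with P - 1 = 119 * 2**23, so power-of-two roots of unity exist
G = 3           # a primitive root modulo P


def ntt(signal, omega):
    """Naive O(b^2) number theoretic transform modulo P."""
    b = len(signal)
    return [
        sum(signal[j] * pow(omega, j * k, P) for j in range(b)) % P
        for k in range(b)
    ]


def split(value, chunk_bits):
    """Little-endian list of chunk_bits-bit pieces of a non-negative integer."""
    mask = (1 << chunk_bits) - 1
    pieces = []
    while value:
        pieces.append(value & mask)
        value >>= chunk_bits
    return pieces or [0]


def toy_multiply(u, v, chunk_bits=8):
    x, y = split(u, chunk_bits), split(v, chunk_bits)
    b = 1
    while b < len(x) + len(y):       # zero-pad so the cyclic convolution never wraps
        b *= 2
    # b must divide 2**23, and every convolution element must stay below P
    assert b <= 1 << 23 and b * ((1 << chunk_bits) - 1) ** 2 < P
    x += [0] * (b - len(x))
    y += [0] * (b - len(y))
    omega = pow(G, (P - 1) // b, P)  # a primitive b-th root of unity modulo P
    fx, fy = ntt(x, omega), ntt(y, omega)
    fz = [a * c % P for a, c in zip(fx, fy)]   # pointwise product
    omega_inv, b_inv = pow(omega, -1, P), pow(b, -1, P)
    z = [value * b_inv % P for value in ntt(fz, omega_inv)]   # inverse NTT
    # chunk i carries weight 2^(chunk_bits * i); summing shifted values handles carries
    return sum(coeff << (chunk_bits * i) for i, coeff in enumerate(z))


print(toy_multiply(123456789, 987654321) == 123456789 * 987654321)  # True
```

The assertion inside `toy_multiply` is the overflow condition from earlier in the article: the modulus has to exceed the largest possible convolution element, otherwise the recombined integer would be wrong.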
However, there are a lot of details that need to be ironed out. First, we need to pad our inputs such that $N$ is a power of two. This is needed because the FFT algorithm we use requires the signal to be a power-of-two length. This at most doubles the length of our integers, which doesn't change the asymptotic complexity. Next is how exactly to take the square root of $N$, since $N$ isn't necessarily a perfect square. We can assume $N = 2^k$ for some $k$, and if $k$ is even we can take the square root exactly. If $k$ is odd, two options come to mind:
- $b = \sqrt{2N}$, $n = \sqrt{N/2}$; we have more chunks, and each chunk is smaller. This leads to the recurrence $T(N) = \sqrt{2N} \, T(\sqrt{2N}) + O(N \log N)$.
- $b = \sqrt{N/2}$, $n = \sqrt{2N}$; fewer chunks that are each larger. This leads to the recurrence $T(N) = \sqrt{N/2} \, T(2\sqrt{2N}) + O(N \log N)$.
Note we actually do need to solve these recurrences, because they are slightly different from the one we solved above, which assumes we can perfectly divide the integer into square root-sized chunks. Fortunately, both choices for $b$ give us $O(N \log N \log \log N)$, so either can be used.
About that approximation...
TLDR: We've been using $\log M \approx 2n$ but in fact $M$ must be slightly bigger. We can't tweak our NTT to make this work, so we instead compute the convolution modulo $b$ using a totally different technique. This result modulo $b$ can be computed fast enough to not affect our recurrence, and can be combined with our NTT result (which is computed modulo $M$) to get a result modulo $bM$, which is sufficient.
The trickiest detail involves our use of the approximation $\log M \approx 2n$. Recall we in fact need $\log M \geq 2n + \log b$ to fit all possible convolution results, so we're a little under. Unfortunately, there aren't any $M$ in our very constrained selection of $M$ and $\omega$ such that $\log M \approx 2n + \log b$; the closest admissible modulus has noticeably more bits than that, which greatly overshoots our needs. Setting $M$ this high makes the recursion too "fat" and results in a slower runtime. Schonhage-Strassen gets around this by keeping $\log M \approx 2n$, but then also calculating the result of the convolution modulo $b$. These two results can then be combined to get the convolution modulo $bM$. It looks like this:
Up until now we've only been talking about things going on in the blue box. The weird thing is that although the auxiliary work looks exactly like the primary work, the auxiliary work is calculated in a completely different manner. We can't use NTT for our auxiliary work because $b$ is a power of two, and there's no easy way to find roots of unity modulo a power of two. We also need to use a power-of-two modulus for our auxiliary work in order for it to be combinable with our main work, which was done modulo $M$. In general, a result modulo $p$ and a result modulo $q$ can be used to deduce the result modulo $pq$ only if $p$ and $q$ are relatively prime. It's easy to show $b$ and $M$ are relatively prime, because one is a power of two and the other is odd. But no such guarantee can be made for an auxiliary modulus that is not a power of two.
Although we can't use NTT, a very fast and powerful tool for convolution, our auxiliary work is so small that we can use simpler, less efficient convolution algorithms and still scrape under the time budget. As long as our auxiliary work takes $O(N \log N)$ time, it won't affect our recurrence because the NTTs in our primary task already took $O(N \log N)$ time. It's fairly easy to fit in because our modulus is so small. One way to look at things is we only need $\log b$ bits to store each result modulo $b$. Since there are $b$ chunks, our auxiliary result is only $b \log b$ bits total. The main result needed $b \log M \approx 2N$ bits, which is far larger.
Naive convolution is too slow
Unfortunately, we don't quite have the time to calculate our convolution modulo $b$ naively. This would take $O(b^2)$ arithmetic operations. Each operation is done on $\log b$-bit integers, so overall we'd spend $O(b^2 \, T(\log b))$ time. If we recursively call our own multiplication routine, and assume $T(N) = O(N \log N \log \log N)$ still (which isn't necessarily correct, because the additional recursive calls could slow $T$), this runtime would evaluate to

$$O\!\left(b^2 \cdot \log b \, \log \log b \, \log \log \log b\right) = O(N \log N \log \log N \log \log \log N)$$

which is too slow.
Karatsuba to the rescue
We can see from above that naive convolution is just barely too slow to work; even a convolution taking $O(b^{2 - \epsilon})$ arithmetic operations, for any constant $\epsilon > 0$, would be enough. We can't use NTT, which gives $O(b \log b)$ operations, but there are other sub-quadratic convolution algorithms we can resort to. One option is Karatsuba convolution, which takes $O(b^{\log_2 3}) \approx O(b^{1.585})$ arithmetic operations. It's more commonly known as a multiplication algorithm, but since multiplication and convolution are equivalent, it can also be used for our purposes. Say we want to multiply two numbers $u$ and $v$. Let's split the digits of each number into halves: $u = AB$ and $v = CD$. For example, if $u = 1234$ then $A = 12$ and $B = 34$. The naive way of computing $uv$ is:
```
          A  B
   x      C  D
   -----------
         AD BD
      AC BC
   -----------
Result: AC (AD + BC) BD
```
This takes 4 multiplications, each half the size. This gives the recurrence $T(N) = 4T(N/2) + O(N)$, which solves to $O(N^2)$ as expected. Karatsuba is a way to compute the result with 3 multiplications: $AC$, $BD$, and $(A + B)(C + D)$. We can then compute $(A + B)(C + D) - AC - BD = AD + BC$ to get the middle term of our result. This leads to the recurrence $T(N) = 3T(N/2) + O(N)$, which gives $O(N^{\log_2 3})$ as desired.
Using Karatsuba for convolution is exactly the same; given two signals $x$ and $y$ of length $b$, we split each into halves: $x = (x_1, x_2)$ and $y = (y_1, y_2)$. Then, we compute $x_1 * y_1$, $x_2 * y_2$, and $(x_1 + x_2) * (y_1 + y_2)$; these are analogous to $AC$, $BD$, and $(A + B)(C + D)$, respectively, from above. Then $x * y$ equals the sum of the following:
- $x_1 * y_1$, padded with $b$ zeros at the back
- $(x_1 + x_2) * (y_1 + y_2) - x_1 * y_1 - x_2 * y_2$, padded with $b/2$ zeros before and after
- $x_2 * y_2$, padded with $b$ zeros at the front
This allows us to convolve our signals in $O(b^{\log_2 3})$ operations, each of which is on $\log b$-bit integers. We can just use naive quadratic-time multiplication for these operations, which gives us $O(b^{\log_2 3} \log^2 b)$ time overall. One final detail is that Karatsuba gives us an acyclic convolution, but the cyclic convolution can trivially be computed from the acyclic convolution so this is no problem.
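Here is a minimal Python sketch of Karatsuba convolution as described above, checked against the naive quadratic convolution. It assumes both signals have the same power-of-two length, and the function names are made up for this example. Reducing every entry modulo the small auxiliary modulus is left out to keep the sketch short.

```python
import random


def naive_convolution(x, y):
    """Quadratic-time acyclic convolution, used here only as a reference."""
    out = [0] * (len(x) + len(y) - 1)
    for i, a in enumerate(x):
        for j, c in enumerate(y):
            out[i + j] += a * c
    return out


def karatsuba_convolution(x, y):
    """Acyclic convolution of two equal, power-of-two-length signals using
    three half-size convolutions instead of four."""
    b = len(x)
    if b == 1:
        return [x[0] * y[0]]
    h = b // 2
    x1, x2 = x[:h], x[h:]
    y1, y2 = y[:h], y[h:]
    low = karatsuba_convolution(x1, y1)     # analogous to AC
    high = karatsuba_convolution(x2, y2)    # analogous to BD
    mixed = karatsuba_convolution(
        [p + q for p, q in zip(x1, x2)],
        [p + q for p, q in zip(y1, y2)],
    )                                       # analogous to (A + B)(C + D)
    middle = [m - lo - hi for m, lo, hi in zip(mixed, low, high)]
    out = [0] * (2 * b - 1)
    for i in range(b - 1):                  # each partial result has length b - 1
        out[i] += low[i]                    # starts at index 0
        out[i + h] += middle[i]             # starts at index b/2
        out[i + b] += high[i]               # starts at index b
    return out


def fold_to_cyclic(acyclic, b):
    """Cyclic convolution of length b from an acyclic one of length 2b - 1."""
    return [
        acyclic[k] + (acyclic[k + b] if k + b < len(acyclic) else 0)
        for k in range(b)
    ]


print(karatsuba_convolution([1, 2, 3, 4], [5, 6, 7, 8]))
# [5, 16, 34, 60, 61, 52, 32]
```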
Combining auxiliary and main work
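Suppose we know a certain number $z$ satisfies $z \equiv z_1 \pmod{b}$ and $z \equiv z_2 \pmod{M}$. Because $b$ and $M$ are relatively prime, these two residues determine $z$ modulo $bM$.
This can be derived from the Chinese Remainder Theorem. Crucially, recovering $z$ does not require any general multiplication; the only products involved are by powers of two, which we can do with bit shifts, followed by additions and subtractions to get the result. Bit shifts, addition, and subtraction all take linear time, so each element of our convolution can be found in $O(\log M)$ time. We have $b$ elements, so this takes $O(b \log M) = O(N)$ time total.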
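As an illustration of such a multiplication-free combine step, here is a Python sketch under explicit assumptions: the main modulus is taken to be $2^k + 1$, the auxiliary modulus $b$ is a power of two with $b \leq 2^k$, and the function name `combine` is hypothetical. It uses one standard Chinese Remainder Theorem formula, $z = z_2 + (2^k + 1) \cdot ((z_1 - z_2) \bmod b)$, which may differ in form from the one the article originally displayed.

```python
def combine(z1, z2, b, k):
    """Return the unique z in [0, b * (2^k + 1)) with z % b == z1 and
    z % (2^k + 1) == z2.

    Assumes b is a power of two with b <= 2^k, 0 <= z1 < b, 0 <= z2 <= 2^k.
    """
    # (z1 - z2) mod b; masking works because b is a power of two
    delta = (z1 - z2) & (b - 1)
    # z2 + (2^k + 1) * delta, computed with one shift and two additions
    return z2 + (delta << k) + delta


k, b = 8, 16
modulus = (1 << k) + 1               # 257
z = 3000
print(combine(z % b, z % modulus, b, k))  # 3000
```

The formula works because $2^k + 1 \equiv 1 \pmod{b}$ whenever $b$ divides $2^k$, so adding a multiple of $2^k + 1$ shifts the residue modulo $b$ by exactly that multiple without disturbing the residue modulo $2^k + 1$.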
One final thing...
After the combine step, we now have the exact value of each element of the cyclic convolution. However, we still don't have the exact result of the integer multiplication, because multiplication is an acyclic convolution. These convolutions are related but different. Let $c$ denote the cyclic convolution of the length-$b$ signals $x$ and $y$, and $a$ the acyclic convolution of the same two signals. $c$ is of length $b$ while $a$ is of length $2b - 1$. The cyclic convolution is equivalent to adding together the lower and upper halves of the acyclic convolution:

$$c[k] = a[k] + a[k + b], \qquad 0 \leq k < b, \text{ taking } a[2b - 1] = 0$$

If we interpret our convolution results as integers, we can express $a$ as $a_{lo} + 2^N a_{hi}$, where $a_{lo}$ and $a_{hi}$ are the two halves of $a$. Since $c = a_{lo} + a_{hi}$, we get $c \equiv a \pmod{2^N - 1}$. This shows that our multiplication result is unrecoverable from $c$ unless we can guarantee that the result is below $2^N - 1$. We can ensure this is true by padding $x$ and $y$ with $b$ zeros each. Our signal is now of length $2b$, and the result must be below $2^{2N} - 1$ because both inputs are below $2^N$.
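The relationship between the two convolutions is easy to check numerically. The Python sketch below uses hypothetical helper names and a made-up pair of signals with $n = 3$ bits per element. It verifies that the cyclic result matches the true product only modulo $2^N - 1$, and that zero-padding both signals makes it exact.

```python
def acyclic_convolution(x, y):
    """Plain quadratic-time acyclic convolution."""
    out = [0] * (len(x) + len(y) - 1)
    for i, a in enumerate(x):
        for j, c in enumerate(y):
            out[i + j] += a * c
    return out


def fold_to_cyclic(acyclic, b):
    """Add the upper half of the acyclic result onto the lower half."""
    return [
        acyclic[k] + (acyclic[k + b] if k + b < len(acyclic) else 0)
        for k in range(b)
    ]


def signal_to_int(signal, n):
    """Interpret a signal as an integer whose i-th element has weight 2^(n*i)."""
    return sum(value << (n * i) for i, value in enumerate(signal))


n = 3                              # bits per element
x = [3, 1, 2, 5]
y = [7, 0, 4, 6]
b = len(x)
N = b * n                          # total bit length of each input

product = signal_to_int(x, n) * signal_to_int(y, n)
a = acyclic_convolution(x, y)
c = fold_to_cyclic(a, b)

# The acyclic convolution is the exact product; the cyclic one only agrees mod 2^N - 1.
exact_from_acyclic = signal_to_int(a, n) == product
agrees_mod = signal_to_int(c, n) % (2 ** N - 1) == product % (2 ** N - 1)

# After padding each signal with b zeros, the cyclic convolution no longer wraps.
padded = fold_to_cyclic(acyclic_convolution(x + [0] * b, y + [0] * b), 2 * b)
exact_after_padding = signal_to_int(padded, n) == product

print(exact_from_acyclic, agrees_mod, exact_after_padding)  # True True True
```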
This solves our problem but also doubles our problem size. Doing this once initially is fine, but if we do the zero-padding on every recursive call to our multiply routine, it would blow up our recursion and result in an asymptotically slower runtime. The workaround to this is to redefine $T(N)$ as the cost of multiplying two $N$-bit integers and returning the result modulo $2^N - 1$ (as opposed to the full result). We zero-pad our inputs in the initial call, but not in the recursive routine. This turns out to work fine because our recursive calls to the multiply routine were used to compute the pointwise products of the two NTTs; we wanted things modulo $M$ anyways, not the full result.
And finally we're done! Here's a block diagram of the algorithm, with recursive calls denoted by red dashed arrows:
These links might be useful to learn more about Schonhage-Strassen and other ideas in the realm of FFT-based multiplication:
- Modern Computer Algebra - more basic overview
- Modern Computer Arithmetic - more focused on polynomial multiplication which is slightly different
- KTH Royal Institute of Technology, algorithm notes - excellent discussion of Schonhage-Strassen's subtleties, but fairly dense
Appendix: NTT-Based Convolution Works
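We show that $\mathrm{NTT}^{-1}(\mathrm{NTT}(x) \cdot \mathrm{NTT}(y)) = x * y$, where $x$ and $y$ are discrete signals of length $b$, $\cdot$ represents pointwise multiplication, and $*$ represents convolution. Following the standard definition of DFT/NTT:

$$\mathrm{NTT}(x)[k] = \sum_{i=0}^{b-1} x[i] \, \omega^{ik}$$

NTT and DFT share the same formula and the property that $\omega$ is a $b$th root of unity. In DFT, $\omega = e^{-2\pi i / b}$; in NTT, we let $\omega$ be an integer such that $\omega^b \equiv 1 \pmod{M}$ and $\omega^k \not\equiv 1 \pmod{M}$ for all $0 < k < b$. Note this implies $\omega^k \equiv 1$ if and only if $k \equiv 0 \pmod{b}$. Now we put our NTT definition into the inverse DFT/NTT formula:

$$z[n] = b^{-1} \sum_{k=0}^{b-1} \omega^{-nk} \left( \sum_{i=0}^{b-1} x[i] \, \omega^{ik} \right) \left( \sum_{j=0}^{b-1} y[j] \, \omega^{jk} \right)$$

Notice the product of the $x$ and $y$ sums looks a lot like convolving--exactly what we want. Let's expand this product, then group the terms by their exponent on $\omega$:

$$z[n] = b^{-1} \sum_{k=0}^{b-1} \omega^{-nk} \sum_{i=0}^{b-1} \sum_{j=0}^{b-1} x[i] \, y[j] \, \omega^{(i+j)k}$$

If we move all $\omega$ terms to the inner sum, and then re-arrange the sums, we get:

$$z[n] = b^{-1} \sum_{i=0}^{b-1} \sum_{j=0}^{b-1} x[i] \, y[j] \sum_{k=0}^{b-1} \omega^{(i+j-n)k}$$

Now we'll prove that the innermost sum acts as a filter on $i + j - n$; it is $0$ unless $i + j - n \equiv 0 \pmod{b}$. Let $m = i + j - n$. Earlier we defined $\omega$ such that $\omega^b \equiv 1$, which we'll use now:

$$0 \equiv (\omega^b)^m - 1 = (\omega^m)^b - 1 = (\omega^m - 1) \sum_{k=0}^{b-1} \omega^{mk} \pmod{M}$$

where the last step uses a well known algebraic identity. This implies one of two things:
- $\omega^m - 1 \equiv 0$
- $\sum_{k=0}^{b-1} \omega^{mk} \equiv 0$, as desired
(Strictly speaking, this either-or step relies on $\omega^m - 1$ being invertible modulo $M$ whenever it is nonzero, which always holds when $M$ is prime.) By our definition of $\omega$, the former case can only occur if $m \equiv 0 \pmod{b}$, in which case every term of the sum is 1 and the sum equals $b$. Otherwise, the sum is zero mod $M$ and can be dropped from the formula. So, our inverse NTT expression now simplifies to:

$$z[n] = b^{-1} \sum_{i + j \equiv n \pmod{b}} x[i] \, y[j] \cdot b = \sum_{i + j \equiv n \pmod{b}} x[i] \, y[j]$$

Finally, since $0 \leq i, j, n < b$, then $i + j - n$ ranges from $-(b - 1)$ to $2(b - 1)$, so either $i + j = n$ or $i + j = n + b$. Assuming that $x$ and $y$ "wrap around" (as in $y[-1] = y[b - 1]$) this expression leads to the convolution formula $z[n] = \sum_{i=0}^{b-1} x[i] \, y[n - i]$ as desired.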
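The "filter" behaviour of the innermost sum can be observed directly with a few lines of Python. This sketch reuses the earlier example of $\omega = 2$ as a 5th root of unity modulo the prime 31; the function name is hypothetical, and negative exponents in `pow` assume Python 3.8 or newer.

```python
def filter_sum(omega, m, b, modulus):
    """Sum of omega^(m*k) for k = 0 .. b-1, reduced modulo the modulus."""
    return sum(pow(omega, m * k, modulus) for k in range(b)) % modulus


b, omega, modulus = 5, 2, 31
# m = i + j - n ranges from -(b - 1) to 2(b - 1) in the proof above
sums = {m: filter_sum(omega, m, b, modulus) for m in range(-(b - 1), 2 * b - 1)}
print(sums[0], sums[1], sums[5], sums[-4])  # 5 0 5 0
```

The sum equals $b$ exactly when $m$ is a multiple of $b$ and vanishes otherwise, which is what lets the inverse transform pick out the terms with $i + j \equiv n \pmod{b}$.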