FFT-Based Integer Multiplication, Part 2

How the Schonhage-Strassen algorithm uses the number theoretic transform (NTT) to multiply N-bit integers in O(N log N log log N) time

November 20, 2018

In the first half of this article we discussed convolution, using the FFT for convolution, and the difficulties associated with using floating-point numbers when exact integer output is required. This second half of the article is about the Schonhage-Strassen algorithm for fast integer multiplication. It multiplies N -bit integers in O(N\log N\log\log N) time via an FFT over modular arithmetic, known as the number theoretic transform (NTT). Schonhage-Strassen is no longer asymptotically the fastest, but it held the title for over thirty years, and its ideas form the basis for newer algorithms that slightly reduce the runtime.

Key to Schonhage-Strassen is NTT, which is a modified DFT. The DFT basically multiplies the original signal by an N th root of unity raised to various powers. A standard DFT uses the root e^{2\pi j/N} , but as shown in the appendix of this article, any N th root of unity works, where a root of unity is defined as \omega such that \omega^k\neq1 for 1\le k < N and \omega^N=1 . The NTT uses roots of unity that arise from modular arithmetic. For example, 2 is a 5th root of unity modulo 31, because 2^1, 2^2, 2^3, 2^4\not\equiv 1 \pmod{31} , but 2^5\equiv 1 \pmod{31} . The NTT returns the results of the convolution modulo M but otherwise isn’t that different from a standard DFT:

  • FFT algorithms such as Cooley-Tukey also work on the NTT, so we can compute an NTT in O(N\log N) time
  • As shown in the appendix, pointwise multiplication in the NTT domain is equivalent to convolving the original signal (a quick numeric check of both claims follows this list)
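
Here’s a quick numeric check of both claims, using the 5th root of unity \omega=2 modulo M=31 from the example above. This is purely illustrative and not part of the algorithm; naive O(N^2) transforms are plenty for a demonstration, and the inverse NTT multiplies by N^{-1} \bmod M :

# toy check: NTT, pointwise multiplication, then inverse NTT reproduces the
# cyclic convolution mod M (omega = 2 is a 5th root of unity mod 31)
M, N, omega = 31, 5, 2

def ntt(x, w):
    return [sum(x[n] * pow(w, n * k, M) for n in range(N)) % M for k in range(N)]

def cyclic_convolve(x, y):
    return [sum(x[m] * y[(n - m) % N] for m in range(N)) % M for n in range(N)]

x, y = [1, 2, 3, 0, 0], [4, 5, 6, 0, 0]
X, Y = ntt(x, omega), ntt(y, omega)
pointwise = [a * b % M for a, b in zip(X, Y)]
inv_omega, inv_N = pow(omega, -1, M), pow(N, -1, M)   # modular inverses (Python 3.8+)
recovered = [v * inv_N % M for v in ntt(pointwise, inv_omega)]
assert recovered == cyclic_convolve(x, y)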

Together these two things allow us to use NTT to do O(N\log N) convolution. There are two big questions left to answer:

  • What \omega modulo what M do we want to use? Unlike with complex numbers, there are many N th roots of unity, depending on the M we choose. For example, x is a tenth root of unity modulo x^{10}-1 for all integers x>1 . Are certain x better than others?
  • How should we split the integer? Given an N -bit integer, should it be represented as a length N signal of bits, a length N/2 signal of pairs of bits, etc.

These two questions are somewhat related because we want to choose an M large enough such that each element of the convolution is within [0,\ldots, M) –otherwise, it’ll overflow modulo M and be wrong. The max value of each element in the convolution depends on how the integer is split, so these things must be decided together.

Suppose we split our N -bit integer into B blocks, each of length L (so BL=N ). We now want to calculate how large M needs to be to avoid overflow. By the definition of a cyclic convolution,

(f*g)[n]=\sum_{m=0}^{B-1}f[m]g[n-m]

Since each element of f and g is an L -bit integer, its value is at most 2^L-1 . So:

(f*g)[n]\le\sum_{m=0}^{B-1}(2^L-1)^2=B(2^L-1)^2

We therefore need M>B(2^L-1)^2 .
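
As a concrete (toy) illustration of this bound, here’s a small check; the block sizes and input value below are arbitrary choices of mine:

# split a 32-bit integer into B = 8 blocks of L = 4 bits, cyclically convolve the
# resulting signal with itself, and confirm no element reaches B * (2^L - 1)^2
def split(x, B, L):
    return [(x >> (L * i)) & ((1 << L) - 1) for i in range(B)]

B, L = 8, 4
f = split(0xDEADBEEF, B, L)
conv = [sum(f[m] * f[(n - m) % B] for m in range(B)) for n in range(B)]
assert max(conv) < B * ((1 << L) - 1) ** 2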

Don’t recurse too much

TLDR: FFT-based multiplication is recursive because no matter how we split our N -bit integer into a signal, our convolution result’s elements are of unbounded size. We can’t avoid the recursive calls that come from pointwise-multiplying the signals in the frequency domain. But with \omega=2 and M=2^k\pm1 for any k , we can replace what would’ve been additional recursive multiplies inside the FFT itself with bit-shifts, adds, and subtracts. This reduction in recursion is key to good asymptotic performance.

Reproduced below is the pseudocode from part 1 of this article for computing f*g , where f and g are two signals of length B :

  1. Use FFT to compute F:=DFT(f) and G:=DFT(g) . Time: O(B\log B)
  2. By the convolution theorem, DFT(f * g)[k]=F[k]\cdot G[k] so we pointwise multiply our results from step 1. Time: O(B)
  3. Use FFT to apply the inverse DFT to our result from the previous step. Time: O(B\log B)

You might be wondering why Schonhage-Strassen is O(N\log N\log\log N) when our pseudocode above seems to suggest we only need O(N\log N) time. The problem is that the above costs are in terms of the number of arithmetic operations needed, and arithmetic isn’t O(1) because we are dealing with arbitrarily large integers. We need \log M bits to store each element of the FFT result, and as shown above, M>B(2^L-1)^2 so \log M>\log B + 2\log(2^L-1)\approx\log B+2L . Since N=BL , no matter how we set B and L , \log M will grow unboundedly as N grows. This means step 2 of our pseudocode involves multiplying integers of unbounded size–time for recursion! (Our recursion ends up being \log\log N deep, and each recursion level takes N\log N time, giving our final runtime.)

There’s more, though–FFT requires us to multiply elements of our input signal by \omega raised to some power. For more detail, see the pseudocode for FFT-based NTT below:

# Cooley-Tukey radix-2 FFT over the integers mod M (an NTT); assumes b is a
# power of 2 and omega is a b-th root of unity modulo M
def fft(signal, b, omega, M):
    if b == 1:
        return [signal[0] % M]
    signal_even = signal[0::2]   # signal[0], signal[2], ..., signal[b - 2]
    signal_odd  = signal[1::2]   # signal[1], signal[3], ..., signal[b - 1]
    fft_even = fft(signal_even, b // 2, omega * omega % M, M)
    fft_odd  = fft(signal_odd,  b // 2, omega * omega % M, M)
    res = [0] * b
    for k in range(b // 2):
        # in practice we'd precompute all omega^k, so this power is effectively free
        prod = pow(omega, k, M) * fft_odd[k] % M
        res[k] = (fft_even[k] + prod) % M
        res[k + b // 2] = (fft_even[k] - prod) % M
    return res
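
As an aside, step 3’s inverse transform can reuse this exact routine: run it with \omega^{-1} in place of \omega , then multiply each output by b^{-1} \bmod M . A minimal sketch, assuming \omega and b are powers of two and M is odd, so the needed modular inverses exist (the helper name is mine):

def inverse_fft(signal, b, omega, M):
    inv_omega = pow(omega, -1, M)   # modular inverse (Python 3.8+)
    inv_b = pow(b, -1, M)
    return [v * inv_b % M for v in fft(signal, b, inv_omega, M)]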

Most operations are fine; adding fft_even[k] to prod, copying an element of signal into signal_even, and so on all take O(L) time. The cost of all these operations over the entire FFT would be O(B\log B)O(L)=O(BL\log B)=O(N\log B) , which is no big deal. But calculating prod involves the multiplication of two L -bit integers, which means more recursion. If we let \newcommand{\M}{\mathcal{M}} \M(N) be the cost of multiplying two N -bit integers, we get the recurrence

\M(N)=\underbrace{B\M(\log M)}_\text{step 2} + \underbrace{O(B\log B)\M(L)}_\text{steps 1 and 3, multiplies}+\underbrace{O(N\log B)}_{\substack{\text{steps 1 and 3,} \\ \text{shift/add/subtract}}}

which solves to \M(N)=N^{1+\epsilon} ; not very good.

The FFT inherently involves multiplication, but the actual recursive calls to our multiplication routine can be avoided with a trick. Here’s the analogy for humans: it’s easy to compute 1019068913 * 100000, but hard to compute 1019068913 * 193126 even though 100000 and 193126 are the same length. Similarly, it’s easy for us to compute numbers modulo powers of 10, but hard modulo other integers. For computers, the same applies, but with powers of two since they work in binary. For example, multiplication by a power of two becomes a left shift. In general, integers with sparse representations in binary ( 2^a\pm2^b for some a,b ) support efficient multiplication and remainder. Crucial to Schonhage-Strassen is that \omega=2 is an N th root of unity modulo 2^N-1 , and a 2N th root of unity modulo 2^N+1 . We won’t go into the details, but computing integers modulo 2^N\pm1 can be done in linear time (as can multiplying by 2^k ).
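
To make the “nice modulus” claim concrete, here is a small sketch of my own (not the algorithm’s actual reduction code): reducing modulo 2^n-1 just repeatedly folds the high bits back onto the low bits using only masks, shifts, and adds, and the root-of-unity facts can be checked directly:

def mod_mersenne(x, n):
    # since 2^n = 1 (mod 2^n - 1), the bits above position n fold back onto the low bits
    m = (1 << n) - 1
    while x > m:
        x = (x & m) + (x >> n)
    return 0 if x == m else x

assert mod_mersenne(123456789, 10) == 123456789 % ((1 << 10) - 1)
assert pow(2, 20, (1 << 20) - 1) == 1   # 2 is a 20th root of unity mod 2^20 - 1
assert pow(2, 40, (1 << 20) + 1) == 1   # ...and a 40th root of unity mod 2^20 + 1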

Optimizing B and L

TLDR: We let B=N^p and L=N^{1-p} and solve our recurrence with p as a parameter. With p<1/2 we get \M(N)=O(N\log^2 N) ; at p=1/2 we get \M(N)=O(N\log N\log\log N) ; at p>1/2 we get \M(N)=O(N\log N) .

With our choices of \omega and M , we’re left with the recurrence

\M(N)=B\M(\log M)+O(N\log N)

and the task of finding the optimal B and L . Let’s assume for now that B=N^p and L=N^{1-p} . We now find an approximation for \log M . Earlier we found that we need M>B(2^L-1)^2 , so let’s set M=B\cdot2^{2L} .

\begin{aligned} \log M &= \log(B\cdot2^{2L})=\log B + 2L\\ &= \log(N^p)+2N^{1-p}=p\log N + 2N^{1-p} \end{aligned}

as long as p<1 the second term will dominate, so for now we’ll say \log M\approx 2N^{1-p} . This simplifies our recurrence to:

\M(N)=N^p\M(2N^{1-p}) + O(N\log N)

There’s no systematic way of solving this kind of recurrence (at least not that I know of), so we just have to expand it and compute the sum. By the definition of big-O notation, there is some constant c such that \M(N)\le N^p\M(2N^{1-p}) + cN\log N . We start expanding and get:

\begin{aligned} \M(N)&=cN\log N + N^p((2N^{1-p})^p \M(2(2N^{1-p})^{1-p}) + c(2N^{1-p}\log(2N^{1-p})))\\ &= cN\log N+2cN\log(2N^{1-p})+2^p N^{2p-p^2} \M(2^{2-p} N^{(1-p)^2}) \end{aligned}

Things quickly get ugly, but notice that the recursive calls are always of the form 2^w N^x \M(2^y N^z) . We can derive that

\begin{aligned} 2^w N^x \M(2^y N^z) &= 2^w N^x [(2^y N^z)^p \M(2 (2^y N^z)^{1-p}) + c2^y N^z \log(2^y N^z)]\\ &= 2^{w+yp} N^{x+zp} \M(2^{(1-p)y+1} N^{(1-p)z}) + c2^{w+y}N^{x+z}(y\log2+z\log N) \end{aligned}

To simplify our notation, let (w, x, y, z) be shorthand for 2^w N^x \M(2^y N^z) . Suppose we expand our recursion k times and get (w_k, x_k, y_k, z_k) . Once we expand an additional level we’ll have the following tuple:

  • w_{k+1}=w_k+y_k p
  • x_{k+1}=x_k+z_k p
  • y_{k+1}=(1-p)y_k + 1
  • z_{k+1}=(1-p)z_k

We’ll also add a c2^{w_k+y_k}N^{x_k+z_k}(y_k\log2+z_k\log N) term to our runtime (from here on we’ll drop the constant c to reduce clutter). Our total runtime should be the sum of this term over all depths of the recursion:

\M(N)=\sum_{k=0}^d 2^{w_k+y_k} N^{x_k+z_k} (y_k\log2 + z_k\log N)

where d denotes the total depth; we’ll compute d later. We start by calling \M(N)=2^0 N^0 \M(2^0 N^1) so our initial tuple is (0, 0, 0, 1) . Note that for all k , w_{k+1}+y_{k+1}=w_k+y_k+1 and x_{k+1}+z_{k+1}=x_k+z_k . Combined with the base case, this allows us to say w_k+y_k=k and x_k+z_k=1 for all k . Our sum therefore simplifies to

\M(N)=N\sum_{k=0}^d 2^k(y_k\log2 + z_k\log N)

It’s also trivial to derive that z_k=(1-p)^k . Finally, y_k is upper bounded by 1/p ; it starts at zero and for any y_k<1/p , we have y_{k+1}=(1-p)y_k+1 < (1-p) (1/p) + 1 = 1/p . Now we can further simplify the summation to:

\begin{aligned} \M(N)&\le N\sum_{k=0}^d 2^k\left(\frac{1}{p}\log2 + (1-p)^k\log N\right)\\ &=\frac{N}{p}\log2 \sum_{k=0}^d 2^k + N\log N\sum_{k=0}^d 2^k(1-p)^k\\ &\le \underbrace{\frac{N}{p}2^{d+1}\log2}_{t_1} + N\log N\underbrace{\sum_{k=0}^d 2^k(1-p)^k}_{t_2} \end{aligned}
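
As a quick numeric sanity check of the tuple recurrences and the identities used above (purely illustrative; p=1/2 here), a few iterations confirm w_k+y_k=k , x_k+z_k=1 , z_k=(1-p)^k , and y_k<1/p :

p = 0.5
w, x, y, z = 0.0, 0.0, 0.0, 1.0   # the starting tuple (0, 0, 0, 1)
for k in range(1, 10):
    w, x, y, z = w + y * p, x + z * p, (1 - p) * y + 1, (1 - p) * z
    assert abs((w + y) - k) < 1e-9 and abs((x + z) - 1) < 1e-9
    assert abs(z - (1 - p) ** k) < 1e-9 and y < 1 / p + 1e-9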

These expressions for y_k and z_k are also useful for finding the depth d of our recursion. We compute d by solving for when our input into \M , which is 2^yN^z , becomes O(1) :

\begin{aligned} O(1)&=2^{y_d}N^{z_d}\le 2^{1/p}N^{(1-p)^d}\\ O(1)&=N^{(1-p)^d}=2^{(1-p)^d\log N}\\ O(1)&=(1-p)^d \log N\\ d &= -\log_{1-p} O(\log N) = O(\log\log N)\\ \end{aligned}

Plugging this back into our sum, t_1 simplifies to O(N\log N) . Meanwhile, t_2 ’s value depends on the size of r:=2(1-p) . If r>1 , the sum simplifies to r^d=(\log N)^{\log r} . If r=1 , we get d=O(\log\log N) , and if r<1 we get O(1) . Overall we get:

  • \M(N)=O(N\log N) if r<1\rightarrow p>1/2
  • \M(N)=O(N\log N\log\log N) if r=1\rightarrow p=1/2
  • \M(N)=O(N\log^{1+\log(2-2p)} N) if r>1\rightarrow p<1/2

O(N\log N\log\log N) is the best we can do

Unfortunately, our choices for \omega and M make the p>1/2 case impossible. The underlying problem is that \omega=2, M=2^N\pm1 is a very sparse root of unity. For a modulus of N bits, we’re only getting an O(N) th root of unity. In contrast, a prime P of length N bits is guaranteed to have a P-1=O(2^N) th root of unity (proof here).

The sparsity of our root of unity intuitively makes sense because among all numbers in [0,\ldots, M) we’d expect “nice” numbers, which are easy to multiply by, to be very rare. We need these nice numbers to avoid recursive multiplies in our FFT routine that would otherwise ruin our recursion. These numbers’ rarity gets us in trouble though, because for p>1/2 we need more than an O(N) th root of unity from an N -bit modulus. For example, for p=3/4 , we need an N^{3/4} th root of unity, but our modulus only has O(N^{1/4}) bits; scaled down, that is like asking an n -bit modulus for an n^3 th root of unity. It seems like this is impossible while constraining ourselves to “nice” \omega and M which allow the FFT to run in O(N\log N) time, so the best we can achieve is p=1/2 with \omega=2 and M=2^N\pm 1 .

Final Details

We now have the basic idea of Schonhage-Strassen down. Suppose we want to multiply two N -bit integers. We (a runnable sketch of these steps follows the list):

  • split each input integer into B=\sqrt{N} pieces of L=\sqrt{N} bits each
  • Perform an NTT on each, using \omega=2 and M=2^{2\sqrt{N}}\pm 1 (the original paper uses M=2^{2\sqrt{N}}+1 , but 2^{2\sqrt{N}}-1 works as well)
  • Recurse to compute the pointwise product of the two NTTs
  • Perform an inverse NTT (with same \omega and M )
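
To see how these pieces fit together, here is a runnable sketch of the pipeline (split, NTT, pointwise multiply, inverse NTT, carry propagation). It is deliberately not real Schonhage-Strassen: purely for clarity it uses a small NTT-friendly prime modulus ( 998244353=119\cdot2^{23}+1 , primitive root 3) instead of 2^{2\sqrt{N}}\pm1 , a naive O(B^2) transform instead of the FFT above, and the language’s built-in big integers for the pointwise products instead of recursing. The block size and helper names are my own choices:

M = 998244353          # a prime with 2^23 dividing M - 1, so power-of-two NTTs exist
ROOT = 3               # a primitive root modulo M

def ntt(signal, omega):
    # naive O(B^2) transform; good enough for a demonstration
    b = len(signal)
    return [sum(signal[n] * pow(omega, n * k, M) for n in range(b)) % M
            for k in range(b)]

def multiply(x, y, L=8):
    # split into L-bit blocks, padded to a power-of-two length large enough that
    # the cyclic convolution never wraps around (so it equals the acyclic one)
    nblocks = max((x.bit_length() + y.bit_length()) // L + 2, 2)
    b = 1 << (nblocks - 1).bit_length()                # assumes b <= 2^23
    assert b * ((1 << L) - 1) ** 2 < M                 # no overflow modulo M
    fx = [(x >> (L * i)) & ((1 << L) - 1) for i in range(b)]
    fy = [(y >> (L * i)) & ((1 << L) - 1) for i in range(b)]
    omega = pow(ROOT, (M - 1) // b, M)                 # a b-th root of unity mod M
    X, Y = ntt(fx, omega), ntt(fy, omega)
    Z = [a * c % M for a, c in zip(X, Y)]              # pointwise products
    inv_omega, inv_b = pow(omega, -1, M), pow(b, -1, M)
    z = [v * inv_b % M for v in ntt(Z, inv_omega)]     # inverse NTT
    return sum(v << (L * i) for i, v in enumerate(z))  # carry propagation

assert multiply(123456789, 987654321) == 123456789 * 987654321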

However, there are a lot of details that need to be ironed out. First, we need to pad our inputs such that N=2^k . This is needed because the FFT algorithm we use requires the signal to be a power-of-two length. This at most doubles the length of our integers, which doesn’t change the asymptotic complexity. Next is how exactly to take the square root of N , since N isn’t necessarily a perfect square. We can assume N=2^k for some k , and if k is even we can take the square root exactly. If k is odd, two options come to mind:

  • B=\sqrt{2N}, L=\sqrt{N/2} ; we have more chunks, and each chunk is smaller. This leads to the recurrence \M(N)=O(N\log N)+\sqrt{2N}\M(2\sqrt{N/2}) .
  • B=\sqrt{N/2}, L=\sqrt{2N} ; fewer chunks that are each larger. This leads to the recurrence \M(N)=O(N\log N)+\sqrt{N/2}\M(2\sqrt{2N}) .

Note we actually do need to solve these recurrences, because they are slightly different from the one we solved above, which assumes we can perfectly divide the integer into square root-sized chunks. Fortunately, both choices for B, L give us O(N\log N\log\log N) , so either can be used.

About that \log M approximation…

TLDR: We’ve been using \log M\approx 2L but \log M in fact must be slightly bigger. We can’t tweak our NTT to make this work, so we instead compute the convolution modulo B using a totally different technique. This result modulo B can be computed fast enough to not affect our recurrence, and can be combined with our NTT result (which was computed modulo 2^{2L}-1 ) to get a result modulo B(2^{2L}-1) , which is sufficient.

The trickiest detail involves our use of the approximation \log M\approx 2N^{1-p}=2L . Recall we in fact need \log M=\log B + 2L to fit all possible convolution results, so we’re a little under. Unfortunately, there aren’t any M in our very constrained selection of \omega and M such that \log M=\log B + 2L ; the closest we can do is \omega=2^3, M=2^{3L}\pm1 , in which case \log M=3L , which greatly overshoots our needs. Setting M this high makes the recursion too “fat”:

\M(N)=O(N\log N)+\sqrt{N}\M(3\sqrt{N})

results in an O(N\log^2 N) runtime. Schonhage-Strassen gets around this by keeping M=2^{2L}-1 , but then also calculating the result of the convolution modulo B . These two results can then be combined to get the convolution modulo B(2^{2L}-1) . It looks like this:

SSA Simple Block Diagram

Up until now we’ve only been talking about things going on in the blue box. The weird thing is that although the auxiliary work looks exactly like the primary work, the auxiliary work is calculated in a completely different manner. We can’t use the NTT for our auxiliary work because B is a power of two, and there’s no easy way to find roots of unity modulo a power of two. We also need to use a power-of-two modulus for our auxiliary work in order for it to be combinable with our main work, which was done modulo 2^{2L}-1 . In general, a result modulo x and a result modulo y can be used to deduce the result modulo xy if and only if x and y are relatively prime. It’s easy to show B and 2^{2L}-1 are relatively prime, because one is a power of two and the other is odd. But no such guarantee can be made for B\pm1 , or any other number that isn’t a power of two.

Although we can’t use the NTT, a very fast and powerful tool for convolution, our auxiliary work is so small that we can use simpler, less efficient convolution algorithms and still scrape under the time budget. As long as our auxiliary work takes O(N\log N) time, it won’t affect our recurrence because the NTTs in our primary task already took O(N\log N) time. It’s fairly easy to fit in O(N\log N) because our modulus B is so small. One way to look at things is that we only need \log B bits to store each result modulo B . Since there are B chunks, our auxiliary result is only B\log B bits total. The main result needed 2BL=O(B^2) bits, which is far larger.

Naive convolution is too slow

Unfortunately, we don’t quite have the time to calculate our convolution modulo B naively. This would take B^2 arithmetic operations. Each operation is done on \log B -bit integers, so overall we’d spend B^2\M(\log B) time. If we recursively call our own multiplication routine, and assume \M(N)=O(N\log N\log\log N) still (which isn’t necessarily correct, because the additional recursive calls could slow \M down), this runtime would evaluate to

\begin{aligned} &O(B^2\log B\log\log B\log\log\log B)=\\ &O(N\log N\log\log N\log\log\log N) \end{aligned}

which is too slow.

Karatsuba to the rescue

We can see from above that naive convolution is just barely too slow to work; even a convolution taking B^{2-\epsilon} arithmetic operations would be enough. We can’t use the NTT, which would give B\log B operations, but there are other sub-quadratic convolution algorithms we can resort to. One option is Karatsuba convolution, which takes O(B^{\log_2 3})\approx O(B^{1.59}) arithmetic operations. It’s more commonly known as a multiplication algorithm, but since multiplication and convolution are equivalent, it can also be used for our purposes. Say we want to multiply two numbers x and y . Let’s split the digits of each number into halves: x=AB and y=CD . For example, if x=123456 then A=123, B=456 . The naive way of computing xy is:

    A  B
x   C  D
--------
   AD BD
AC BC
--------
Result: AC (AD + BC) BD

This takes 4 multiplications, each half the size, giving the recurrence T(N)=4T(N/2)+O(N) , which solves to T(N)=O(N^2) as expected. Karatsuba is a way to compute the result with 3 multiplications: p_1=AC , p_2=(A+B)(C+D) , and p_3=BD . We can then compute (AD+BC)=p_2-p_1-p_3 to get the middle term of our result. This leads to the recurrence T(N)=3T(N/2)+O(N) , which gives T(N)=O(N^{\log3}) as desired.
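
Here is a compact sketch of Karatsuba for integers following this scheme; it splits by bit count rather than decimal digits, and the cutoff below which it falls back to built-in multiplication is an arbitrary choice of mine:

def karatsuba(x, y, cutoff=64):
    if x < (1 << cutoff) or y < (1 << cutoff):
        return x * y                         # small enough: multiply directly
    half = max(x.bit_length(), y.bit_length()) // 2
    A, B = x >> half, x & ((1 << half) - 1)  # high and low halves of x
    C, D = y >> half, y & ((1 << half) - 1)  # high and low halves of y
    p1, p3 = karatsuba(A, C), karatsuba(B, D)
    p2 = karatsuba(A + B, C + D)
    return (p1 << (2 * half)) + ((p2 - p1 - p3) << half) + p3

assert karatsuba(123456789123456789, 987654321987654321) == \
       123456789123456789 * 987654321987654321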

Using Karatsuba for convolution is exactly the same: given two signals x and y , we split each into halves, x=x_{hi}:x_{lo}, y=y_{hi}:y_{lo} . Then, we compute s_1:=x_{lo}*y_{lo} , s_2:=(x_{lo}+x_{hi})*(y_{lo}+y_{hi}) , and s_3:=x_{hi}*y_{hi} ; these are analogous to p_1,p_2,p_3 , respectively, from above. Then x*y equals the sum of the following:

  • s_3 , padded with N zeros at the back
  • s_2-s_1-s_3 , padded with N/2 zeros before and after
  • s_1 , padded with N zeros at the front

This allows us to convolve our integers in B^{\log3} operations, each of which is on \log B bit integers. We can just use naive quadratic-time multiplication for these operations, which gives us O(B^{\log3}\log^2 B)=O(N) time overall. One final detail is that Karatsuba gives us an acyclic convolution, but the cyclic convolution can trivially be computed from the acyclic convolution so this is no problem.
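
Here is a sketch of the convolution version, under the assumption that both signals share the same power-of-two length; it uses the convention that index 0 holds the lowest coefficient, and element arithmetic is left to the language’s integers:

def karatsuba_convolve(x, y):
    # returns the acyclic convolution of x and y (length 2*len(x) - 1)
    n = len(x)
    if n == 1:
        return [x[0] * y[0]]
    half = n // 2
    x_lo, x_hi = x[:half], x[half:]
    y_lo, y_hi = y[:half], y[half:]
    s1 = karatsuba_convolve(x_lo, y_lo)
    s2 = karatsuba_convolve([a + b for a, b in zip(x_lo, x_hi)],
                            [a + b for a, b in zip(y_lo, y_hi)])
    s3 = karatsuba_convolve(x_hi, y_hi)
    mid = [b - a - c for a, b, c in zip(s1, s2, s3)]   # s2 - s1 - s3
    res = [0] * (2 * n - 1)
    for i, v in enumerate(s1):
        res[i] += v                # low part
    for i, v in enumerate(mid):
        res[i + half] += v         # middle part
    for i, v in enumerate(s3):
        res[i + n] += v            # high part
    return res

def cyclic_from_acyclic(acyc, n):
    # fold the upper half back onto the lower half (see the final section below)
    return [acyc[i] + (acyc[i + n] if i + n < len(acyc) else 0) for i in range(n)]

assert karatsuba_convolve([1, 2], [3, 4]) == [3, 10, 8]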

Combining auxiliary and main work

If we know a certain number x satisfies x\equiv a \pmod{B} and x\equiv b \pmod{2^{2L}-1} , then

x\equiv \left((b-a)\bmod B\right)\cdot(2^{2L}-1)+b\pmod{B(2^{2L}-1)}

This follows from the Chinese Remainder Theorem, using the fact that 2^{2L}-1\equiv -1 \pmod{B} because B is a power of two; and since our convolution elements satisfy 0\le x<B(2^{2L}-1) , the right-hand side is x exactly. Crucially, no multiplication is required: letting t:=(b-a)\bmod B (a bit-mask, since B is a power of two), we can compute t\cdot2^{2L} with a bit shift, then subtract t and add b . Bit shifts, addition, and subtraction all take linear time, so each element of our convolution can be found in O(L) time. We have B elements, so this takes O(N) time total.
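
A tiny sketch of this combine step (the function name and toy parameters below are my own):

def combine(a, b, B, L):
    # a = value mod B, b = value mod (2^(2L) - 1), with 0 <= value < B * (2^(2L) - 1)
    t = (b - a) % B                  # cheap: B is a power of two, so this is a mask
    # t * (2^(2L) - 1) needs no multiplication: shift t up by 2L bits, then subtract t
    return ((t << (2 * L)) - t) + b

value = 401                          # toy example with B = 8 and L = 3 (moduli 8 and 63)
assert combine(value % 8, value % 63, 8, 3) == value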

One final thing…

After the combine step, we now have the exact value of each element of the cyclic convolution. However, we still don’t have the exact result of the integer multiplication, because multiplication is an acyclic convolution. These convolutions are related but different. Let z_{cyc} denote the cyclic convolution of the length- N signals x and y , and z_{acyc} the acyclic convolution of the same two signals. z_{cyc} is of length N while z_{acyc} is of length 2N (strictly 2N-1 ; pad with a trailing zero to make it 2N ). The cyclic convolution is equivalent to adding together the lower and upper halves of the acyclic convolution:

z_{cyc}[i] = z_{acyc}[i] + z_{acyc}[i + N]

If we interpret our convolution results as integers (for concreteness, take the signal elements to be bits, so x and y represent N -bit integers), we can express z_{acyc} as 2^N z_{hi}+z_{lo} where z_{lo} and z_{hi} are the two halves of z_{acyc} . Since 2^N\equiv 1 \pmod{2^N-1} , we get z_{acyc}\equiv z_{hi}+z_{lo}\equiv z_{cyc}\pmod{2^N-1} . This shows that our multiplication result is unrecoverable from z_{cyc} unless we can guarantee that the result is below 2^N-1 . We can ensure this is true by padding x and y with N zeros each. Our signal is now of length 2N , and the result must be below 2^{2N}-1 because xy\le(2^N-1)^2<2^{2N}-1 .
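
A toy check of the wrap-around relation and of the zero-padding fix, treating the signals as bit vectors (lowest bit first) so the integers involved stay small; the helper names are mine:

def cyclic_convolve(x, y):
    n = len(x)
    return [sum(x[m] * y[(k - m) % n] for m in range(n)) for k in range(n)]

def acyclic_convolve(x, y):
    n = len(x)
    out = [0] * (2 * n)   # length 2n - 1 plus one trailing zero
    for i in range(n):
        for j in range(n):
            out[i + j] += x[i] * y[j]
    return out

x, y = [1, 0, 1, 1], [1, 1, 0, 1]    # the bits of 13 and 11, low bit first
cyc, acyc = cyclic_convolve(x, y), acyclic_convolve(x, y)
assert all(cyc[i] == acyc[i] + acyc[i + 4] for i in range(4))

# zero-padding to length 2N makes the cyclic result agree with the acyclic one,
# so the true product becomes recoverable
xp, yp = x + [0] * 4, y + [0] * 4
assert cyclic_convolve(xp, yp) == acyclic_convolve(x, y)
assert sum(v << i for i, v in enumerate(acyclic_convolve(x, y))) == 13 * 11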

This solves our problem but also doubles our problem size. Doing this once initially is fine, but if we do the zero-padding on every recursive call to our multiply routine, it would blow up our recursion and result in an O(N\log^2 N) runtime. The workaround is to re-define \M(N) as the cost of multiplying two N -bit integers and returning the result modulo 2^N-1 (as opposed to the full result). We zero-pad our inputs in the initial call, but not in the recursive routine. This turns out to work fine because our recursive calls to the multiply routine were used to compute the pointwise products F[k]\cdot G[k] \pmod{2^{2L}-1} ; we wanted things modulo 2^{2L}-1 anyways, not the full result.

And finally we’re done! Here’s a block diagram of the algorithm, with recursive calls denoted by red dashed arrows:

SSA Full Block Diagram

Appendix: NTT-Based Convolution Works

We show that NTT^{-1}(NTT(x)\cdot NTT(y))\equiv (x*y) \pmod M , where x and y are discrete signals of length N , \cdot represents pointwise multiplication, and * represents convolution. Following the standard definition of DFT/NTT:

NTT(x)[k] = \sum_{n=0}^{N-1} x[n]\omega^{-nk}

NTT and DFT share the same formula and the property that \omega is an N th root of unity. In DFT, \omega=e^{2\pi j/N} ; in NTT, we let \omega be an integer such that \omega^N\equiv 1 \pmod M and \omega^k\not\equiv 1 \pmod M for all k\in[1,2,\ldots,N) . Note this implies \omega^k\equiv 1 \pmod M if and only if k\equiv 0 \pmod N . Now we put our NTT definition into the inverse DFT/NTT formula:

\begin{aligned} NTT^{-1}(X\cdot Y)[n] =& \frac{1}{N}\sum_{k=0}^{N-1} \omega^{nk} X[k]Y[k] \\ =& \frac{1}{N}\sum_{k=0}^{N-1} \omega^{nk} \left(\sum_{m_1=0}^{N-1} x[m_1]\omega^{-m_1 k}\right) \left(\sum_{m_2=0}^{N-1} y[m_2]\omega^{-m_2 k}\right) \end{aligned}

Notice the product of the m_1 and m_2 sums looks a lot like convolving–exactly what we want. Let’s expand this product, then group the terms by their exponent on \omega :

NTT^{-1}(X\cdot Y)[n] = \frac{1}{N}\sum_{k=0}^{N-1} \omega^{nk} \sum_{s=0}^{2N-2} \omega^{-sk}\sum_{\substack{m_1+m_2=s, \\ 0\le m_1,m_2 < N}} x[m_1]y[m_2]

If we move all terms to the inner sum, and then re-arrange the sums, we get:

\begin{aligned} NTT^{-1}(X\cdot Y)[n] =& \frac{1}{N}\sum_{k=0}^{N-1} \sum_{s=0}^{2N-2} \sum_{\substack{m_1+m_2=s, \\ 0\le m_1,m_2 < N}} \omega^{nk-sk} x[m_1]y[m_2] \\ =& \frac{1}{N}\sum_{s=0}^{2N-2} \sum_{\substack{m_1+m_2=s, \\ 0\le m_1,m_2 < N}} x[m_1]y[m_2] \sum_{k=0}^{N-1} \omega^{nk-sk} \end{aligned}

Now we’ll prove that the innermost sum acts as a filter on s ; \sum_{k=0}^{N-1}\omega^{k(n-s)}\equiv0 \pmod M unless n\equiv s\pmod N . Let \sigma = \omega^{n-s} . Earlier we defined \omega such that \omega^N\equiv1 \pmod M , which we’ll use now:

\sigma^N = \omega^{N(n-s)} \equiv 1^{n-s} \equiv 1 \pmod M

\sigma^N - 1 \equiv 0 \pmod M

(\sigma - 1)\sum_{k=0}^{N-1}\sigma^k \equiv 0 \pmod M

Where the last step uses the well-known identity \sigma^N-1=(\sigma-1)(\sigma^{N-1}+\cdots+\sigma+1) . Assuming \sigma-1 is either zero or not a zero divisor modulo M (automatic when M is prime), this implies one of two things:

  1. \sigma=\omega^{n-s}\equiv 1 \pmod M
  2. \sum_{k=0}^{N-1} \sigma^k = \sum_{k=0}^{N-1} \omega^{k(n-s)} \equiv 0 \pmod M , as desired

By our definition of \omega , the former case can only occur if n-s\equiv 0 \pmod N . Otherwise, the sum is zero mod M and can be dropped from the formula. So, our inverse NTT expression now simplifies to:

\begin{aligned} NTT^{-1}(X\cdot Y)[n] \equiv & \frac{1}{N} \sum_{\substack{m_1+m_2\equiv n \pmod N, \\ 0\le m_1,m_2 < N}} x[m_1]y[m_2] \sum_{k=0}^{N-1} \omega^0\pmod M \\ \equiv & \sum_{\substack{m_1+m_2\equiv n \pmod N, \\ 0\le m_1,m_2 < N}} x[m_1]y[m_2] \pmod M \end{aligned}

Finally, since 0\le m_1,m_2 < N , then m_1+m_2 ranges from 0 to 2N-2 , so either m_1+m_2=n or m_1+m_2=n+N . Assuming that x and y “wrap around” (as in y[-5]=y[N-5] ) this expression leads to the convolution formula as desired.

\begin{aligned} NTT^{-1}(X\cdot Y)[n] \equiv & \sum_{m_1=0}^n x[m_1]y[n-m_1] + \sum_{m_1=n+1}^{N-1} x[m_1]y[n+N-m_1]\pmod M \\ \equiv & \sum_{m_1=0}^{N-1} x[m_1]y[n-m_1] \pmod M \end{aligned}