
Algorithm #12: Matrix Exponentiation

Please read the previous post on Binary Exponentiation before you start with this one.

Let's first understand what a recurrence relation is. You probably know the Fibonacci series: a sequence of numbers in which the first number is 0, the second number is 1, and every subsequent number is given by the formula:
f(n) = f(n-1) + f(n-2)

An equation such as the one above, in which each term of a sequence is defined using previous terms, is called a recurrence relation. So relations like
f(n) = f(n-1) + f(n-2) + f(n-3)   [ Tribonacci series ]
or
f(n) = 3*f(n-1) + 7*f(n-2)   [ an arbitrary example ]
are also recurrence relations.

Suppose we are asked to find the nth Fibonacci number modulo a prime number M. The naive solution looks like this:

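long long findFibonacci(long long n,long long M)
{
  //f(1)=0 and f(2)=1 are the first two terms of the series
  if(n==1)
    return 0;
  if(n==2)
    return 1%M;
  //prevterm holds f(i-2), prevterm2 holds f(i-1)
  long long i,prevterm=0,prevterm2=1,thisterm=0;
  for(i=3;i<=n;i++)
  {
    thisterm=(prevterm+prevterm2)%M;  //f(i)=f(i-1)+f(i-2)
    prevterm=prevterm2;
    prevterm2=thisterm;
  }
  return thisterm;
}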

In fact, we can find the nth element of any recurrence relation with a loop like this one.
The problem is that the loop runs n-2 times, so its complexity is O(n), i.e. linear in n, which is too slow when n is very large.
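As a sketch of what "a similar manner" means, here is the same loop generalized to any relation of the form f(n) = a*f(n-1) + b*f(n-2). The name naiveRecurrence and its parameters f1 and f2 (the first two terms) are hypothetical, and the sketch assumes non-negative inputs and a modulus around 1e9 or smaller so that products fit in a long long.

```c
#include <assert.h>

/* Hypothetical helper: nth term (1-indexed) of f(n) = a*f(n-1) + b*f(n-2),
   given f(1) = f1 and f(2) = f2, computed modulo M in O(n) time.
   Assumes a, b, f1, f2 >= 0 and M <= about 1e9, so that the product of
   two values below M fits in a long long. */
long long naiveRecurrence(long long n, long long a, long long b,
                          long long f1, long long f2, long long M)
{
    assert(n >= 1 && M >= 1);
    if (n == 1)
        return f1 % M;
    if (n == 2)
        return f2 % M;
    long long older = f1 % M;   /* f(i-2) */
    long long newer = f2 % M;   /* f(i-1) */
    for (long long i = 3; i <= n; i++) {
        long long next = ((a % M) * newer % M + (b % M) * older % M) % M;
        older = newer;
        newer = next;
    }
    return newer;               /* f(n) */
}
```

For Fibonacci (a = b = 1, f1 = 0, f2 = 1), naiveRecurrence(10, 1, 1, 0, 1, 1000000007) gives f(10) = 34. The loop still performs n-2 iterations, which is exactly the cost matrix exponentiation removes.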

Matrix exponentiation is a faster method that can be used to find the nth element of a series defined by a recurrence relation.
We’ll take the Fibonacci series as an example.

In matrix exponentiation, we first convert the addition in a recurrence relation to multiplication. The advantage of doing this will become clear as you read on.

So the question is: how can we convert the addition in a recurrence relation to multiplication? The answer is matrices!

The general recurrence relation for a series in which a term depends on the previous 2 terms is:
f(n) = a*f(n-1) + b*f(n-2)
( For Fibonacci, a=1 and b=1 )
The matrix form of this equation is:

| f(n)   | =  | p  q | X | f(n-1) |
| f(n-1) |    | r  s |   | f(n-2) |

For convenience let
| p  q | = Z
| r  s |

Therefore, we get
f(n) = p * f(n-1) + q * f(n-2)
and
f(n-1) = r * f(n-1) + s * f(n-2)

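The values of p, q, r and s depend on the recurrence relation.
Comparing the first equation with f(n) = a*f(n-1) + b*f(n-2) gives p=a and q=b, and the second equation holds in general when r=1 and s=0. For the Fibonacci sequence (a=1, b=1) we therefore get p=1, q=1, r=1 and s=0.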

So, the Z matrix for the Fibonacci sequence is

| 1  1 |
| 1  0 |

and the matrix form for f(n) = f(n-1) + f(n-2) is:

| f(n)   | = | 1  1 | X | f(n-1) |
| f(n-1) |   | 1  0 |   | f(n-2) |
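Multiplying out the right-hand side confirms this choice of Z: the top row reproduces the Fibonacci recurrence, and the bottom row simply copies f(n-1) down so the vector keeps the two most recent terms.

```latex
\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}
\begin{pmatrix} f(n-1) \\ f(n-2) \end{pmatrix}
=
\begin{pmatrix} 1 \cdot f(n-1) + 1 \cdot f(n-2) \\ 1 \cdot f(n-1) + 0 \cdot f(n-2) \end{pmatrix}
=
\begin{pmatrix} f(n) \\ f(n-1) \end{pmatrix}
```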

Now let's get to the method for finding the nth element.
We start with the column vector of the first two terms:

| f(2) |
| f(1) |

Using the matrix form of the Fibonacci series given above, we can find the next Fibonacci number, f(3), by multiplying Z by this vector:

| 1  1 |  X | f(2) | = | f(3) |
| 1  0 |    | f(1) |   | f(2) |

If we multiply Z by the result again, we get the next term:

| 1  1 |  X | f(3) | = | f(4) |
| 1  0 |    | f(2) |   | f(3) |
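The single step shown above is just a matrix-vector product, so it can be coded directly. This is a sketch under stated assumptions: stepVector and fibBySteps are hypothetical names, M is assumed to be around 1e9 or smaller, and the loop is still O(n). It only shows that each multiplication by Z advances the vector by one term.

```c
#include <assert.h>

/* One step of the matrix form: out = Z x v, all arithmetic mod M.
   v[0] holds f(i) and v[1] holds f(i-1); out receives f(i+1) and f(i).
   Assumes M <= about 1e9 so that products fit in a long long. */
void stepVector(const long long Z[2][2], const long long v[2],
                long long out[2], long long M)
{
    out[0] = (Z[0][0] * v[0] % M + Z[0][1] * v[1] % M) % M;
    out[1] = (Z[1][0] * v[0] % M + Z[1][1] * v[1] % M) % M;
}

/* Hypothetical helper: f(n) for n >= 2, starting from (f(2), f(1)) = (1, 0)
   and multiplying by Z exactly n-2 times. */
long long fibBySteps(long long n, long long M)
{
    assert(n >= 2 && M >= 1);
    const long long Z[2][2] = {{1, 1}, {1, 0}};
    long long v[2] = {1 % M, 0};   /* (f(2), f(1)) */
    long long next[2];
    for (long long i = 2; i < n; i++) {
        stepVector(Z, v, next, M);
        v[0] = next[0];
        v[1] = next[1];
    }
    return v[0];                    /* top entry is f(n) */
}
```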

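Each multiplication by Z advances the vector by one term, and it takes n-2 such steps to get from (f(2), f(1)) to (f(n), f(n-1)). So, we have the following equation for the nth Fibonacci number: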

| f(n)   | = Z^(n-2) X | f(2) |
| f(n-1) |             | f(1) |
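Unrolling the one-step equation shows where the exponent n-2 comes from: each factor of Z lowers both indices in the vector by one, and n-2 factors bring (f(n), f(n-1)) all the way down to (f(2), f(1)).

```latex
\begin{pmatrix} f(n) \\ f(n-1) \end{pmatrix}
= Z \begin{pmatrix} f(n-1) \\ f(n-2) \end{pmatrix}
= Z^{2} \begin{pmatrix} f(n-2) \\ f(n-3) \end{pmatrix}
= \cdots
= Z^{n-2} \begin{pmatrix} f(2) \\ f(1) \end{pmatrix}
```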

So, we have successfully changed the addition in the recurrence relation to multiplication.
But what does that buy us?
As I mentioned in my previous post, Binary Exponentiation computes base^power using only O(log power) multiplications. It relies only on multiplication being associative, and matrix multiplication is associative too.
Since our job now is to find Z^(n-2), we can compute it with Binary Exponentiation using O(log n) matrix multiplications.
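For reference, here is a minimal sketch of the scalar, iterative Binary Exponentiation loop. The name modPow is hypothetical and the previous post's code may differ in details; the sketch assumes non-negative inputs and M around 1e9 or smaller. Matrix exponentiation keeps this exact loop and swaps the integer multiplication for a matrix multiplication.

```c
#include <assert.h>

/* Returns (base^power) % M using O(log power) multiplications.
   Assumes base >= 0, power >= 0 and M <= about 1e9, so the product
   of two values below M fits in a long long. */
long long modPow(long long base, long long power, long long M)
{
    assert(power >= 0 && M >= 1);
    long long result = 1 % M;        /* multiplicative identity */
    base %= M;
    while (power > 0) {
        if (power & 1)               /* lowest bit set: multiply this power in */
            result = result * base % M;
        base = base * base % M;      /* base^(2^k) becomes base^(2^(k+1)) */
        power >>= 1;
    }
    return result;
}
```

For example, modPow(3, 13, 1000000007) returns 1594323, after only four passes through the loop (13 is 1101 in binary).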

Finally, Z^(n-2) is multiplied by the initial vector (f(2), f(1)) to get (f(n), f(n-1)), whose top entry is f(n).

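Of course, a matrix multiplication costs more than a single integer multiplication: multiplying two k x k matrices takes O(k^3) operations, so the whole method takes O(k^3 log n) time. For a small fixed k (k=2 for Fibonacci), that overhead is tiny compared to the speedup from reducing O(n) to O(log n).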

Here is the Matrix Exponentiation code for finding the nth Fibonacci number.
Compare it with the iterative version of Binary Exponentiation. You’ll observe that the only change is that we are now performing matrix multiplication instead of simple integer multiplication.
That’s why this algorithm is called Matrix Exponentiation.

void matmult(long long  a[][2],long long  b[][2],long long c[][2],long long  M)//c = (a*b) % M; c must be a different array from a and b
{
	int i,j,k;
	for(i=0;i<2;i++)
	{
		for(j=0;j<2;j++)
		{
			c[i][j]=0;
			for(k=0;k<2;k++)
			{
				c[i][j]+=(a[i][k]*b[k][j]);
				c[i][j]=c[i][j]%M;
			}
		}
	}

}
void matpow(long long Z[][2],long long n,long long ans[][2],long long M)
//find (Z^n) % M and store the result in ans; Z is overwritten, since it is squared in place
{

	long long temp[2][2];
	//assign ans= the identity matrix
	ans[0][0]=1;
	ans[1][0]=0;
	ans[0][1]=0;
	ans[1][1]=1;
	int i,j;
	while(n>0)
	{
		if(n&1)
		{
			matmult(ans,Z,temp,M);
			for(i=0;i<2;i++)
				for(j=0;j<2;j++)
					ans[i][j]=temp[i][j];
		}
		matmult(Z,Z,temp,M);
		for(i=0;i<2;i++)
			for(j=0;j<2;j++)
				Z[i][j]=temp[i][j];


		n=n/2;
	}
	return;
	
}
long long findFibonacci(long long n,long long M)
{
	
	long long fib;
	if(n>2)
	{
		long long int Z[2][2]={{1,1},{1,0}},result[2][2];//modify matrix Z[][] for other recurrence relations
		matpow(Z,n-2,result,M);
		fib=result[0][0]*1 + result[0][1]*0;	//final multiplication of Z^(n-2) with the initial terms of the series
	}
	else
		fib=n-1;
		
	return fib;
}
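The code above passes matrices as arrays and copies results back by hand. A slightly more compact alternative, sketched here as an option rather than the author's method, wraps the 2x2 matrix in a struct so functions can take and return it by value. Mat2, mulMod, powMod and fibMatrix are hypothetical names, and M is again assumed to be around 1e9 or smaller.

```c
#include <assert.h>

/* A 2x2 matrix in a struct, so it can be passed and returned by value. */
typedef struct {
    long long a[2][2];
} Mat2;

/* Returns (x * y) % M. Assumes entries are below M and M <= about 1e9,
   so each product fits in a long long. */
static Mat2 mulMod(Mat2 x, Mat2 y, long long M)
{
    Mat2 r;
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++) {
            long long s = 0;
            for (int k = 0; k < 2; k++)
                s = (s + x.a[i][k] * y.a[k][j]) % M;
            r.a[i][j] = s;
        }
    return r;
}

/* Returns (z^e) % M by binary exponentiation. z is a copy, so the caller's
   matrix is untouched; e is a long long, so very large n is supported. */
static Mat2 powMod(Mat2 z, long long e, long long M)
{
    Mat2 result = {{{1 % M, 0}, {0, 1 % M}}};   /* identity matrix */
    while (e > 0) {
        if (e & 1)
            result = mulMod(result, z, M);
        z = mulMod(z, z, M);
        e >>= 1;
    }
    return result;
}

/* f(n) mod M with f(1) = 0 and f(2) = 1, matching the indexing above. */
long long fibMatrix(long long n, long long M)
{
    assert(n >= 1 && M >= 1);
    if (n <= 2)
        return (n - 1) % M;
    Mat2 z = {{{1, 1}, {1, 0}}};
    Mat2 p = powMod(z, n - 2, M);
    /* (f(n), f(n-1)) = Z^(n-2) x (f(2), f(1)) = Z^(n-2) x (1, 0) */
    return p.a[0][0];
}
```

With this indexing, fibMatrix(50, 1000000007) returns 778742000, which is 7778742049 (the 50th term) reduced modulo 1000000007.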

The challenging part of this algorithm is to find the Z matrix for a recurrence relation.
For the recurrence relation: f(n) = f(n-1) + 2*f(n-2) + 3*f(n-3), we have

| f(n)   |    | p  q  r |   | f(n-1) |
| f(n-1) | =  | s  t  u | X | f(n-2) |
| f(n-2) |    | v  w  x |   | f(n-3) |

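Write out the equation for each row of this matrix equation. The first row must reproduce the recurrence, so p=1, q=2, r=3; the second row must give f(n-1) = f(n-1), so s=1, t=0, u=0; and the third row must give f(n-2) = f(n-2), so v=0, w=1, x=0. Therefore the Z matrix is: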

| 1  2  3 |
| 1  0  0 |
| 0  1  0 |
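To see the same method work for this third-order recurrence, here is a sketch with a 3x3 Z matrix. The post does not give starting values, so the sketch ASSUMES f(1) = f(2) = f(3) = 1; mulK and tripleRecurrence are hypothetical names, and M is assumed to be around 1e9 or smaller. Because each term needs three previous terms, the starting vector is (f(3), f(2), f(1)) and the exponent becomes n-3.

```c
#include <assert.h>
#include <string.h>

#define K 3   /* order of the recurrence = size of the Z matrix */

/* c = (x * y) % M for K x K matrices. A temporary is used, so c may be
   the same array as x or y. Assumes entries are below M and M <= about 1e9. */
static void mulK(long long c[K][K], long long x[K][K], long long y[K][K], long long M)
{
    long long t[K][K];
    for (int i = 0; i < K; i++)
        for (int j = 0; j < K; j++) {
            t[i][j] = 0;
            for (int k = 0; k < K; k++)
                t[i][j] = (t[i][j] + x[i][k] * y[k][j]) % M;
        }
    memcpy(c, t, sizeof t);
}

/* f(n) mod M for f(n) = f(n-1) + 2*f(n-2) + 3*f(n-3),
   with ASSUMED starting values f(1) = f(2) = f(3) = 1. */
long long tripleRecurrence(long long n, long long M)
{
    assert(n >= 1 && M >= 1);
    const long long init[K] = {1, 1, 1};   /* (f(3), f(2), f(1)), assumed */
    if (n <= 3)
        return init[3 - n] % M;
    long long z[K][K] = {{1, 2, 3}, {1, 0, 0}, {0, 1, 0}};
    long long p[K][K] = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};   /* identity */
    for (int i = 0; i < K; i++)
        for (int j = 0; j < K; j++) {
            z[i][j] %= M;
            p[i][j] %= M;
        }
    for (long long e = n - 3; e > 0; e >>= 1) {   /* p = Z^(n-3) */
        if (e & 1)
            mulK(p, p, z, M);
        mulK(z, z, z, M);
    }
    /* f(n) is the top row of Z^(n-3) times (f(3), f(2), f(1)) */
    long long fn = 0;
    for (int j = 0; j < K; j++)
        fn = (fn + p[0][j] * (init[j] % M)) % M;
    return fn;
}
```

With those assumed starting values the sequence begins 1, 1, 1, 6, 11, 26, 66, 151, so tripleRecurrence(8, 1000000007) returns 151. Only K and the two matrices change from the 2x2 version; each multiplication now costs O(K^3) operations.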

Related Problems:
CLIMBING STAIRS (Codechef)
