Month: June 2014

Algorithm #12: Matrix Exponentiation

Please read the previous post on Binary Exponentiation before you start with this one.

Let's first understand what a recurrence relation is. You probably know about the Fibonacci series. It is a sequence of numbers in which the first number is 0, the second number is 1 and every subsequent number is determined using the formula:
f(n) = f(n-1) + f(n-2)

An equation such as the one above, in which one term of a sequence is defined using the previous terms, is called a Recurrence relation. Therefore, relations like
f(n)=f(n-3) + f(n-2) + f(n-1) [ Tribonacci Series ]
or
f(n)=3*f(n-1) + 7*f(n-2) [ an arbitrary example ]
etc.
are recurrence relations.

If we are given the problem to find the nth Fibonacci number modulo a prime number M, the naive solution would be like this:

long long findFibonacci(long long n,long long M)
{
  if(n==1)
    return 0;
  if(n==2)
    return 1%M;
  //prevterm2 holds f(i-2) and prevterm holds f(i-1)
  long long i,prevterm2=0,prevterm=1%M,thisterm=0;
  for(i=3;i<=n;i++)
  {
    thisterm=(prevterm+prevterm2)%M;
    prevterm2=prevterm;
    prevterm=thisterm;
  }
  return thisterm;
}

In fact, we can write code to find nth element of any recurrence relation in a similar manner.
The problem with this approach is that it takes O(n) time, i.e. it has linear complexity.
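
As an illustration of the "similar manner" mentioned above, here is a minimal sketch of the same loop for the Tribonacci relation f(n) = f(n-1) + f(n-2) + f(n-3). The function name findTribonacci and the starting terms f(1)=0, f(2)=0, f(3)=1 are assumptions made for this example; they are not given in the post.

```c
#include <assert.h>

/* nth term (1-indexed) of f(n) = f(n-1) + f(n-2) + f(n-3), modulo M.
   Assumed starting terms (not from the post): f(1)=0, f(2)=0, f(3)=1. */
long long findTribonacci(long long n, long long M)
{
    assert(n >= 1 && M >= 1);
    if (n <= 2)
        return 0;
    /* a, b, c hold f(i-3), f(i-2), f(i-1) for the term being computed */
    long long a = 0, b = 0, c = 1 % M;
    long long i;
    for (i = 4; i <= n; i++)
    {
        long long next = (a + b + c) % M;
        a = b;
        b = c;
        c = next;
    }
    return c;
}
```

With these starting terms, the sequence begins 0, 0, 1, 1, 2, 4, 7, 13, so findTribonacci(8, 1000) gives 13. Like the Fibonacci version, this still takes O(n) time.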

Matrix exponentiation is a faster method that can be used to find the nth element of a series defined by a recurrence relation.
We’ll take Fibonacci series as an example.

In matrix exponentiation, we first convert the addition in a recurrence relation to multiplication. The advantage of doing this will become clear as you read on.

So the question is: how can we convert the addition in a recurrence relation to multiplication? The answer is matrices!

The general recurrence relation for a series in which a term depends on the previous 2 terms is:
f(n) = a*f(n-1) + b*f(n-2)
( For Fibonacci, a=1 and b=1 )
The matrix form of this equation is:

| f(n)   | =  | p  q | X | f(n-1) |
| f(n-1) |    | r  s |   | f(n-2) |

For convenience let
| p  q | = Z
| r  s |

Therefore, we get
f(n) = p * f(n-1) + q * f(n-2)
and
f(n-1) = r * f(n-1) + s * f(n-2)

For each recurrence relation, the values of p,q,r and s will be different.
For the Fibonacci sequence, comparing the first equation with f(n) = f(n-1) + f(n-2) gives p=1 and q=1. The second equation must reduce to f(n-1) = f(n-1), which gives r=1 and s=0.

So, the Z matrix for Fibonacci sequence is

| 1  1 |
| 1  0 |

and the matrix form for f(n) = f(n-1) + f(n-2) is:

| f(n)   | = | 1  1 | X | f(n-1) |
| f(n-1) |   | 1  0 |   | f(n-2) |

Now let's get to the method for finding the nth element.
Initially we have the matrix,

| f(2) |
| f(1) |

Using the matrix form of Fibonacci series given above, if we have to find the next Fibonacci number, i.e. f(3), we will multiply Z matrix by the above matrix:

| 1  1 |  X | f(2) | = | f(3) |
| 1  0 |    | f(1) |   | f(2) |

If we again multiply Z with the result, we get the next pair:

| 1  1 |  X | f(3) | = | f(4) |
| 1  0 |    | f(2) |   | f(3) |

So, we have the following equation for the nth Fibonacci number.

| f(n)   | = Z^(n-2) X | f(2) |
| f(n-1) |             | f(1) |
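
To make the equation above concrete, here is a small sketch that applies Z to the initial matrix n-2 times, one multiplication at a time. The names applyZ and fibByRepeatedZ are made up for this example. It is still a linear-time loop and is only meant to demonstrate the equation.

```c
#include <assert.h>

/* Multiply Z = | 1 1 | by the column vector v = | f(k)   | in place,
                | 1 0 |                          | f(k-1) |
   so that v becomes | f(k+1) |
                     | f(k)   |                                        */
void applyZ(long long v[2], long long M)
{
    long long top = (1 * v[0] + 1 * v[1]) % M;
    long long bottom = (1 * v[0] + 0 * v[1]) % M;
    v[0] = top;
    v[1] = bottom;
}

/* f(n) % M for n >= 2, obtained by applying Z exactly n-2 times */
long long fibByRepeatedZ(long long n, long long M)
{
    assert(n >= 2 && M >= 1);
    long long v[2] = {1 % M, 0};   /* the initial matrix: f(2)=1 on top, f(1)=0 below */
    long long i;
    for (i = 0; i < n - 2; i++)
        applyZ(v, M);
    return v[0];
}
```

For example, fibByRepeatedZ(5, 1000) applies Z three times: (1, 0) becomes (1, 1), then (2, 1), then (3, 2), so f(5) = 3.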

So, we have successfully changed the addition in the recurrence equation to multiplication.
But now what??
As I mentioned in my previous post, we have an algorithm called Binary Exponentiation that can perform the operation base^power in O(log n) time.
Because now our job is to find Z^(n-2), we can do this by using Binary Exponentiation in O(log n) time.

Z^(n-2) will then be multiplied by the initial matrix to get the answer:

Z^(n-2) X | f(2) | = | f(n)   |
          | f(1) |   | f(n-1) |

Of course, there is the small matter that multiplying matrices adds overhead: each 2x2 matrix multiplication costs 8 integer multiplications instead of 1. But that constant overhead is tiny compared to the speed-up we obtain by reducing O(n) to O(log n).

Here is the Matrix Exponentiation code for finding the nth Fibonacci number.
Compare it with the iterative version of Binary Exponentiation. You’ll observe that the only change is that we are now performing matrix multiplication instead of simple integer multiplication.
That’s why this algorithm is called Matrix Exponentiation.

void matmult(long long  a[][2],long long  b[][2],long long c[][2],long long  M)//multiply matrix a and b. put result in c
{
	int i,j,k;
	for(i=0;i<2;i++)
	{
		for(j=0;j<2;j++)
		{
			c[i][j]=0;
			for(k=0;k<2;k++)
			{
				c[i][j]+=(a[i][k]*b[k][j]);
				c[i][j]=c[i][j]%M;
			}
		}
	}

}
void matpow(long long Z[][2],long long n,long long ans[][2],long long M)
//find ( Z^n )% M and return result in ans. Note that Z is overwritten in the process
{

	long long temp[2][2];
	//assign ans= the identity matrix
	ans[0][0]=1;
	ans[1][0]=0;
	ans[0][1]=0;
	ans[1][1]=1;
	int i,j;
	while(n>0)
	{
		if(n&1)
		{
			matmult(ans,Z,temp,M);
			for(i=0;i<2;i++)
				for(j=0;j<2;j++)
					ans[i][j]=temp[i][j];
		}
		matmult(Z,Z,temp,M);
		for(i=0;i<2;i++)
			for(j=0;j<2;j++)
				Z[i][j]=temp[i][j];


		n=n/2;
	}
	return;
	
}
long long findFibonacci(long long n,long long M)
{
	
	long long fib;
	if(n>2)
	{
		long long int Z[2][2]={{1,1},{1,0}},result[2][2];//modify matrix Z[][] for other recurrence relations
		matpow(Z,n-2,result,M);
		fib=(result[0][0]*1 + result[0][1]*0)%M;	//final multiplication of Z^(n-2) with the initial terms f(2)=1 and f(1)=0
	}
	else
		fib=(n-1)%M;	//f(1)=0 and f(2)=1
		
	return fib;
}

The challenging part of this algorithm is to find the Z matrix for a recurrence relation.
For the recurrence relation: f(n) = f(n-1) + 2*f(n-2) + 3*f(n-3), we have

| f(n)   |    | p  q  r |   | f(n-1) |
| f(n-1) | =  | s  t  u | X | f(n-2) |
| f(n-2) |    | v  w  x |   | f(n-3) |

Write out the equations for f(n), f(n-1) and f(n-2) from the above matrix equation and you’ll find that the Z matrix is:

| 1  2  3 |
| 1  0  0 |
| 0  1  0 |
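
Here is a rough sketch of how the 2x2 code above could be adapted to this 3x3 Z matrix. The names K, matmult3, copy3 and findTerm3 are hypothetical, and since the post does not give starting terms for this relation, f(1), f(2) and f(3) are passed in as the parameters f1, f2 and f3. The sketch assumes non-negative starting terms and that M is small enough (below about 3*10^9) for the products to fit in a long long.

```c
#include <assert.h>

#define K 3   /* order of the recurrence */

/* c = (a * b) % M for K x K matrices; c must not be the same array as a or b */
void matmult3(long long a[K][K], long long b[K][K], long long c[K][K], long long M)
{
    int i, j, k;
    for (i = 0; i < K; i++)
        for (j = 0; j < K; j++)
        {
            c[i][j] = 0;
            for (k = 0; k < K; k++)
                c[i][j] = (c[i][j] + a[i][k] * b[k][j]) % M;
        }
}

void copy3(long long dst[K][K], long long src[K][K])
{
    int i, j;
    for (i = 0; i < K; i++)
        for (j = 0; j < K; j++)
            dst[i][j] = src[i][j];
}

/* f(n) % M for f(n) = f(n-1) + 2*f(n-2) + 3*f(n-3), given f(1), f(2), f(3) */
long long findTerm3(long long n, long long f1, long long f2, long long f3, long long M)
{
    assert(n >= 1 && M >= 1);
    if (n == 1) return f1 % M;
    if (n == 2) return f2 % M;
    if (n == 3) return f3 % M;

    long long Z[K][K] = {{1, 2, 3}, {1, 0, 0}, {0, 1, 0}};
    long long ans[K][K] = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};   /* identity matrix */
    long long temp[K][K];
    long long power = n - 3;   /* we need Z^(n-3) */

    while (power > 0)
    {
        if (power & 1)
        {
            matmult3(ans, Z, temp, M);
            copy3(ans, temp);
        }
        matmult3(Z, Z, temp, M);
        copy3(Z, temp);
        power /= 2;
    }

    /* first row of Z^(n-3) multiplied by the column (f(3), f(2), f(1)) */
    long long term = (ans[0][0] * (f3 % M)) % M;
    term = (term + ans[0][1] * (f2 % M)) % M;
    term = (term + ans[0][2] * (f1 % M)) % M;
    return term;
}
```

Taking f(1)=f(2)=f(3)=1 purely as an example, the series runs 1, 1, 1, 6, 11, 26, 66, 151, so findTerm3(8, 1, 1, 1, 1000000007) gives 151.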

Related Problems:
CLIMBING STAIRS (Codechef)


Algorithm #11: Binary Exponentiation

Some programming problems require us to find the value of base^power modulo some positive prime number M (power>=0).
If you had to find the value of ( base^power ) % M, how would you do it??
The easiest method that comes to mind is iterating a loop.

long long findPower(long long base,long long power,long long M)
{
	long long ans=1;
	long long i;//must be long long, an int counter would overflow for large values of power
	for(i=1;i<=power;i++)
		ans=(ans*base)%M;
	return ans;
}

The above method is simple but not at all efficient. For values of power around 10^8 or more, this method will take a lot of time to run. If you increase the value of power to 10^16, you’ll pass out from college before the computation ends! Increase it to around 10^25, the sun will become a Red Giant and swallow Earth before the value is computed!!

But why would we need to find such large values :P. Yeah that’s a valid question; but it shouldn’t stop us from learning a better algorithm!

There is a much better algorithm than the linear one described above, which is the topic of this post.
I’ve already introduced Binary Exponentiation as a part of an earlier post. But, I need to formalize it for the next post.

Binary Exponentiation is based on the idea that,
to find base^power, all we need to do is find base^(power/2) and square it (multiplying by base once more when power is odd). The same method can then be repeated to find base^(power/2).

Suppose that we need to find 5^8.
5^8=5^4 * 5^4
5^4=5^2 * 5^2
5^2=5 * 5

The number of steps required to find 5^8 has been reduced from 8 to just 3.
As another example, consider 8^51
8^51 = 8^25 * 8^25 * 8
8^25 = 8^12 * 8^12 * 8
8^12 = 8^6 * 8^6
8^6 = 8^3 * 8^3
8^3 = 8 * 8 * 8

If we used the linear algorithm, we would have required 51 steps. But, using this awesome trick we required just 5 steps.
In general, the number of steps is of the order O(log2 n), where n is the power.
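
The step counts in the two examples above can be checked with a tiny helper. This is only a sketch; the name halvingSteps is made up for this example. It counts how many times the power can be halved before it reaches 1, which is FLOOR(log2(power)).

```c
#include <assert.h>

/* Number of times power can be halved (integer division) before it reaches 1,
   i.e. floor(log2(power)). */
int halvingSteps(long long power)
{
    assert(power >= 1);
    int steps = 0;
    while (power > 1)
    {
        power /= 2;
        steps++;
    }
    return steps;
}
```

halvingSteps(8) gives 3 and halvingSteps(51) gives 5, matching the two examples.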

This algorithm can be implemented recursively as well as iteratively.

RECURSIVE IMPLEMENTATION:

long long fastPower(long long base,long long power,long long M)
{
    if(power==0)
       return 1%M;
    if(power==1)
       return base%M;
    long long halfn=fastPower(base,power/2,M);
    if(power%2==0)
        return ( halfn * halfn ) % M;
    else
        return ( ( ( halfn * halfn ) % M ) * base ) % M;
}

ITERATIVE IMPLEMENTATION:

long long int fastPower(long long base,long long power,long long M)
{
        long long result=1%M;
        base%=M;        //reduce base first so that base*base stays within long long
        while (power>0) 
        {
                if (power%2==1)         
                        result = (result*base)%M;
                base = (base*base)%M;
                power/=2;
        }
        return result;
}

The iterative implementation can be a little tricky to understand. Take a pen and paper and trace the values of variables through the iterations. That’s an effective way to understand and convince yourself.
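
If you would rather let the computer do the tracing, here is a sketch of the same loop with a printf added at the end of every iteration. The name fastPowerTraced is made up for this example, and it assumes base and power are non-negative.

```c
#include <assert.h>
#include <stdio.h>

/* Same loop as the iterative fastPower, but prints the variables
   after every iteration so that the algorithm can be followed. */
long long fastPowerTraced(long long base, long long power, long long M)
{
    assert(base >= 0 && power >= 0 && M >= 1);
    long long result = 1 % M;
    base %= M;
    while (power > 0)
    {
        if (power % 2 == 1)          /* lowest bit of power is set */
            result = (result * base) % M;
        base = (base * base) % M;    /* square the base for the next bit */
        power /= 2;
        printf("power=%lld base=%lld result=%lld\n", power, base, result);
    }
    return result;
}
```

Calling fastPowerTraced(3, 5, 1000) prints

power=2 base=9 result=3
power=1 base=81 result=3
power=0 base=561 result=243

and returns 243, which is 3^5. Notice that result is multiplied by base only in the iterations where power is odd.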
The iterative version runs a little faster than the recursive version.

Binary exponentiation will find the value of base^(10^25) in only about 85 steps. Now that's really cool!

The maximum number of multiplications required to compute base^power is 2 X FLOOR(log2(power)). Think why 😛