Wednesday 11 February 2015

Dynamic Programming - Minimize the Number of Square Numbers Sum to N (Google Interview Question)

1. Problem Description
This is a Google interview question for Software Developer/Engineer on careercup. Here is the original thread.

"

You are given a positive integer number N. Return the minimum number K, such that N can be represented as K integer squares.
Examples:
9 --> 1 (9 = 3^2)
8 --> 2 (8 = 2^2 + 2^2)
15 --> 4 (15 = 3^2 + 2^2 + 1^2 + 1^2)
First reach a solution without any assumptions, then you can improve it using this mathematical lemma: For any positive integer N, there is at least one representation of N as 4 or less squares.
"
2. Analysis
Obviously this is a minimization problem: the objective is to find the minimum K such that a given number can be written as a sum of K square numbers. Let's consider the best case and the worst case of K.

Given a number N, the best case of K is 1, when N itself is a square number: there exists an integer x such that x*x = N.

Then what is the worst case of K? The easiest one that pops to mind is K equal to N, where all the squares are "1": N = 1*1 + 1*1 + ... + 1*1. This is of course the theoretical worst case. Since this is a minimization problem, can we do better than this? The tighter we can make the worst case, the faster the search can go.
Here is one candidate worst case, computed greedily:
    - K = 0
    - x = floor(sqrt(N)), N' = N - x*x, and K increments to 1
    - If N' < 4, K = K + N'
        * If N' is equal to 3, 2, 1 or 0, its contribution is itself
        * 3 = 1+1+1; 2 = 1+1; 1 = 1; 0 means N was a square number
        * return K (call this value K')
    - Else set N = N' and repeat the last two steps
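The greedy steps above can be sketched as a small standalone function (a hypothetical helper for illustration, separate from the implementation in section 4):

```cpp
#include <cmath>
#include <cstddef>

// Greedy upper bound K': repeatedly subtract the largest square <= n,
// then pay one square ("1") for each unit of the final remainder (0..3).
size_t GreedyUpperBound(size_t n)
{
    size_t k = 0;
    while (n > 3) {
        const size_t root = static_cast<size_t>(std::sqrt(static_cast<double>(n)));
        n -= root * root;
        ++k;
    }
    return k + n;   // the remainder 0..3 is covered by that many 1*1 terms
}
```

For example, GreedyUpperBound(15) subtracts 9, then 4, leaving 2, so K' = 4. Note this is only an upper bound: for 12 it returns 4 (9 + 1 + 1 + 1) although 12 = 4 + 4 + 4 needs only 3.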

Now let's think about how bad this K' can be. Every time we split N into two parts and repeat on one part until it is less than or equal to 3. Does this process ring a bell? To me it resembles binary search, which keeps halving until fewer than 2 elements remain, taking O(log2(N)) steps. But we are actually doing better than binary search, because each time we keep the smaller half.
Assume that N is not a square number; then there exists an integer i such that i*i < N < (i+1)*(i+1). Each time we keep N - i*i, which is less than (i+1)*(i+1) - i*i, because
    - N - i*i < (i+1)*(i+1) - i*i = 2*i + 1
    - The two halves are i*i and N - i*i, and N - i*i < 2*i + 1
    - So compare i*i against 2*i + 1, an upper bound on N - i*i
    - When is i*i less than or equal to 2*i + 1? i*i <= 2*i + 1
    - i*i - 2*i + 1 <= 1 + 1 ==> (i-1)^2 <= 2
    - In the positive range this holds only when i < 1 + sqrt(2) < 3, i.e. i <= 2
    - So as long as i >= 3, that is N > 9, i*i is greater than 2*i + 1
    - Therefore as long as N > 9, i*i is always larger than N - i*i
    - It means the smaller half is always kept for the next round
    - We can conclude that K' is no worse than ~ log2(N)
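The claim that the kept remainder is always the smaller half for N > 9 is easy to check empirically (a quick sanity-check loop, not part of the solution):

```cpp
#include <cmath>
#include <cstddef>

// For every non-square n in (9, limit], verify that the kept half
// n - i*i (with i = floor(sqrt(n))) is smaller than the discarded half i*i.
bool RemainderIsSmallerHalf(size_t limit)
{
    for (size_t n = 10; n <= limit; ++n) {
        const size_t i = static_cast<size_t>(std::sqrt(static_cast<double>(n)));
        if (i * i == n) {
            continue;   // square numbers terminate the recursion immediately
        }
        if (i * i <= n - i * i) {
            return false;
        }
    }
    return true;
}
```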

Here we have successfully narrowed down the worst case from N to K'. What does this tell us, and how does it help the computational complexity? It means we have narrowed the search depth from [1, N] to [1, K']. Any search branch whose depth exceeds K' - 1 can be ignored. Moreover, we can treat this recursively: whenever a better K is found (< K'), we update K' with the best value so far and prune any branch deeper than the new K' - 1.

Then what does this K' mean in the search space? It bounds the search depth (I assume everyone is familiar with search breadth and search depth). For instance, let's illustrate the worst-case path for N:
    - K = 0
    - j = 1, N' = N - j*j = N - 1, K increments to 1, search depth = 1
    - j = 1 again, N' = N - 2, K increments to 2, search depth = 2
    ......
    - j = 1 again, until N' = N - (K' - 2), K increments to K' - 2, search depth = K' - 2
    - Now we can stop this search, because
        * If N' is a square number, this path uses (K' - 2) + 1 = K' - 1 squares: update K' = K' - 1
        * If N' is not a square number, no matter what happens K will be no better than K'

Of course, this was the worst-case path; the other search paths are similar: K increments as the depth increases, the search stops when it reaches K' - 1, and K' is updated whenever a leaf node is reached at a smaller depth. The search breadth can be optimized as well. When considering j within [1, i], do we really need to consider j = 1? Not really: searching j within [1, i/2) merely repeats search paths already covered by j within [i/2, i].

3. Complexity
Given a number N, it takes constant time to check whether it is a square number. That is the best case: O(1) computation and a search depth of 1. Now consider a non-square number: there exists an integer i such that i*i < N < (i+1)*(i+1), where i = floor(sqrt(N)).

Based on the above analysis:
    - depth = 0; j within [i/2, i]; in the worst case
        N' = N - (i/2)*(i/2)
           < (i+1)*(i+1) - (i*i)/4
           = (3/4)*i*i + 2*i + 1
           = (3/4)*(i*i + 2*i + 1) + i/2 + 1/4
           = (3/4)*(i+1)*(i+1) + i/2 + 1/4
      Therefore the worst-case N' satisfies (3/4)*i*i < N' < (3/4)*(i+1)*(i+1) + i/2 + 1/4,
      i.e. the worst-case N' is a constant fraction of N.
    - depth = 1: the worst-case N' is again a constant fraction of N, and the search breadth is still ~ N^(1/2)
    - The search depth is limited to K', therefore the computational complexity is O(N^((K'-1)/2)),
      which should theoretically be better than O(N^(log2(N)/2)).
In the case of the lemma, K' = 4, so the computational complexity is O(N^(3/2)) and the space complexity is O(1).
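For contrast, the textbook bottom-up dynamic program (not the pruned search implemented below) computes the exact answer in O(N*sqrt(N)) time and O(N) space; here is a minimal sketch:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// dp[v] = minimum number of squares summing to v:
// dp[0] = 0, dp[v] = 1 + min over all squares j*j <= v of dp[v - j*j].
size_t MinSquaresBottomUpDP(size_t n)
{
    std::vector<size_t> dp(n + 1, 0);
    for (size_t v = 1; v <= n; ++v) {
        size_t best = v;                        // v ones is always feasible
        for (size_t j = 1; j * j <= v; ++j) {
            best = std::min(best, 1 + dp[v - j * j]);
        }
        dp[v] = best;
    }
    return dp[n];
}
```

The table trades the O(1) space of the pruned search for a simpler correctness argument; it is also a convenient oracle for testing the faster routine.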


4. C++ Implementation
// ********************************************************************************
// IMPLEMENTATION
// ********************************************************************************
// header file
#pragma once

class MinNumOfSquareOp
{
public:
    MinNumOfSquareOp();
    ~MinNumOfSquareOp();

    size_t operator()(const size_t x) const;

private:
    size_t FindTheSearchDepth_DP(const size_t x) const;
    void FindSolutionNotWorseThanK_DP(const size_t x, size_t depth, size_t& k) const;
};

// cpp file
#include "MinNumOfSquareOp.h"

#include <cmath>

MinNumOfSquareOp::MinNumOfSquareOp()
{
}


MinNumOfSquareOp::~MinNumOfSquareOp()
{
}

size_t MinNumOfSquareOp::FindTheSearchDepth_DP(const size_t x) const
{
    if (x <= 3) {
        return x;
    }

    // Greedily remove the largest square; the recursion depth gives the upper bound K'.
    const size_t root = static_cast<size_t>(std::sqrt(static_cast<double>(x)));
    return 1 + FindTheSearchDepth_DP(x - root*root);
}

void MinNumOfSquareOp::FindSolutionNotWorseThanK_DP(const size_t x,
                                                    size_t depth,
                                                    size_t& k) const
{
    if (depth >= k) {
        return;
    }

    if (x <= 3) {
        if ((depth + x) < k) {
            k = depth + x;
        }
        return;
    }

    // Branch only on val in [root/2, root]; smaller values merely repeat
    // search paths already covered by this range (see the analysis above).
    const size_t root = static_cast<size_t>(std::sqrt(static_cast<double>(x)));
    const size_t halfOfRoot = root >> 1;
    for (size_t val = root; val >= halfOfRoot; --val) {
        FindSolutionNotWorseThanK_DP(x - val*val, depth + 1, k);
        if (k == 2) {
            // A 1-square answer is only found on the first iteration at depth 0,
            // so once k reaches 2 no later branch can improve it.
            return;
        }
    }
}

size_t MinNumOfSquareOp::operator()(const size_t x) const
{
    // By the lemma (Lagrange's four-square theorem) K' = 4 always suffices,
    // so the greedy bound is not needed:
    //size_t k = FindTheSearchDepth_DP(x);
    size_t k = 4;
    if (k <= 2) {
        return k;
    }

    FindSolutionNotWorseThanK_DP(x, 0, k);
    return k;
}
}

// ********************************************************************************
// TEST
// ********************************************************************************
#include "MinNumOfSquareOp.h"

#include <cassert>
void TestCases()
{
    {
        MinNumOfSquareOp mns;
        assert(mns(0) == 0);
        assert(mns(1) == 1);
        assert(mns(2) == 2);
        assert(mns(3) == 3);
        assert(mns(4) == 1);
        assert(mns(5) == 2);
        assert(mns(6) == 3);
        assert(mns(7) == 4);
        assert(mns(8) == 2);
        assert(mns(9) == 1);
        assert(mns(10) == 2);
        assert(mns(11) == 3);
        assert(mns(12) == 3);
        assert(mns(13) == 2);
        assert(mns(14) == 3);
        assert(mns(15) == 4);
        assert(mns(16) == 1);
        assert(mns(17) == 2);
        assert(mns(18) == 2);
        assert(mns(19) == 3);
        assert(mns(20) == 2);
        assert(mns(21) == 3);
        assert(mns(22) == 3);
        assert(mns(23) == 4);
        assert(mns(24) == 3);
        assert(mns(25) == 1);
        assert(mns(103) == 4);
    }
}
// ********************************************************************************
