Monday, 14 March 2016

Find the Path With Maximal Sum in Pyramid (II)

1. Problem Description
See the problem description on BTS - Find the Path With Maximal Sum in Pyramid.

2. Data Structure and Algorithm
Instead of a tree structure, a plain (native) array also does the job, because each level has a deterministic number of nodes:
    - Level 0 (root) - 1 node
    - Level 1           - 2 nodes
    - Level 2           - 3 nodes
    - ......

And each node's children can be determined once the pyramid is given. Use Node(level, index) to represent a node in the pyramid, where
   - level - represents which level this node is in, starting from 0
   - index - represents which location this node is in at this level, starting from 0
   - Root node - Node(0, 0)
   - Level 1 - Node(1, 0) and Node(1, 1)
   - Level 2 - Node(2, 0), Node(2, 1) and Node(2, 2)
   - Level 3 - Node(3, 0), Node(3, 1), Node(3, 2) and Node(3, 3)

In this mapping the children of a node are also deterministic.
    - Node(0, 0)'s children: Node(1, 0) and Node(1, 1)
    - Node(1, 0)'s children: Node(2, 0) and Node(2, 1)
    - Node(1, 1)'s children: Node(2, 1) and Node(2, 2)
    - Node(2, 0)'s children: Node(3, 0) and Node(3, 1)
    - Node(2, 1)'s children: Node(3, 1) and Node(3, 2)
    - Node(2, 2)'s children: Node(3, 2) and Node(3, 3)
    - ......

In summary, for each level and node:
    - Level i has i+1 nodes
    - Non-leaf Node(i, j)'s children: Node(i+1, j) and Node(i+1, j+1)
    - Leaf Node(i, j) has no children

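Finding the path with the maximal sum is also easy, because it can be solved with the same idea as in the tree-based solution. Note, however, that unlike in a tree, adjacent nodes share a child (e.g. Node(1, 0) and Node(1, 1) both have Node(2, 1) as a child), so a naive recursion re-evaluates shared sub-pyramids many times; computing the best sums level by level from the bottom up avoids this repeated work.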

3. C++ Implementation
// ********************************************************************************
// Implementation
// ********************************************************************************
template <typename T>
class PyramidArray
{
private:
    // Node representation in array
    struct PyramidArrNode {
        PyramidArrNode(size_t l, size_t n)
            : level(l), node(n)
        {}
       
        size_t GetIndex() const {
            return (level*(level + 1) >> 1) + node;
        }

        // root : level 0
        // level 1 - two nodes (0, 1)
        // level 2 - three nodes(0, 1, 2)
        // ......
        size_t level;
        size_t node;
    };

public:
    PyramidArray()
        : m_Depth(0)
    {}

    PyramidArray(const std::vector<T> &input)
        : PyramidArray()
    {
        ConstructPyramid(input);
    }

    void ConstructPyramid(const std::vector<T> &input) {
        const size_t level = ValidInput(input);
        if (level == 0) {
            throw PyramixException("Construction failure - invalid input");
        }
        m_data = input;
        m_Depth = level;
    }

    double FindMaxSum() const {
        if (m_data.empty()) {
            throw PyramixException("Pyramid not constructed yet");
        }

        return FindMaxSumInternal(PyramidArrNode(0, 0));
    }

    PyramidPath<T> FindMaxSumAndPath() const {
        if (m_data.empty()) {
            throw PyramixException("Pyramid not constructed yet");
        }

        return FindMaxSumAndPathInternal(PyramidArrNode(0, 0));
    }

    size_t GetDepth() const {
        if (m_data.empty()) {
            throw PyramixException("Pyramid not constructed yet");
        }
        return m_Depth;
    }

private:
    // sub-problem
    double FindMaxSumInternal(const PyramidArrNode &node) const {
        // Node(0, 0) -> 0
        // Node(1, 0) -> 1
        // Node(1, 1) -> 2
        // Node(2, 0) -> ((1+2)/2)*2+0 = 3;
        // Node(2, 1) ->
        if (!HaveChildren(node)) {
            return m_data[node.GetIndex()];
        }

        const double leftBranchSum = FindMaxSumInternal(PyramidArrNode(node.level + 1, node.node));
        const double rightBranchSum = FindMaxSumInternal(PyramidArrNode(node.level + 1, node.node + 1));
        return leftBranchSum > rightBranchSum ? leftBranchSum + m_data[node.GetIndex()] :
                                                rightBranchSum + m_data[node.GetIndex()];
    }

    // sub-problem
    PyramidPath<T> FindMaxSumAndPathInternal(const PyramidArrNode &node) const {
        if (!HaveChildren(node)) {
            return PyramidPath<T>(m_data[node.GetIndex()]);
        }
        const PyramidPath<T> leftBranchPath = FindMaxSumAndPathInternal(PyramidArrNode(node.level + 1, node.node));
        const PyramidPath<T> rightBranchPath = FindMaxSumAndPathInternal(PyramidArrNode(node.level + 1, node.node + 1));
        if (leftBranchPath.sum >= rightBranchPath.sum) {
            return PyramidPath<T>(leftBranchPath, m_data[node.GetIndex()]);
        }
        return PyramidPath<T>(rightBranchPath, m_data[node.GetIndex()]);
    }

    bool HaveChildren(const PyramidArrNode &node) const {
        const size_t size = (node.level + 1) * (node.level + 2) >> 1;
        return m_data.size() > size;
    }

    std::vector<T> m_data;
    size_t m_Depth;
};

// ********************************************************************************
// Test
// ********************************************************************************
void TestPyramidArray()
{
    {
        PyramidArray<int> pyramid;
        try {
            pyramid.GetDepth();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSum();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        try {
            pyramid.FindMaxSumAndPath();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        try {
            pyramid.ConstructPyramid({ 1, 2, 3, 4 });
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Construction failure - invalid input");
        }

        try {
            pyramid.FindMaxSum();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        try {
            pyramid.FindMaxSumAndPath();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        pyramid.ConstructPyramid({ 1, 2, 3 });
        assert(pyramid.GetDepth() == 2);
        assert(pyramid.FindMaxSum() == 4);
        auto path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 4);
        assert(path.path == std::vector<int>({ 1, 3 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6 });
        assert(pyramid.GetDepth() == 3);
        assert(pyramid.FindMaxSum() == 10);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 10);
        assert(path.path == std::vector<int>({ 1, 3, 6 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        assert(pyramid.GetDepth() == 4);
        assert(pyramid.FindMaxSum() == 20);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 20);
        assert(path.path == std::vector<int>({ 1, 3, 6, 10 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 });
        assert(pyramid.GetDepth() == 5);
        assert(pyramid.FindMaxSum() == 35);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 35);
        assert(path.path == std::vector<int>({ 1, 3, 6, 10, 15 }));
    }
}


No comments:

Post a Comment