See the problem description on BTS - Find the Path With Maximal Sum in Pyramid.
2. Data Structure and Algorithm
Instead of a tree structure, a plain array also does the job, because each level has a deterministic number of nodes:
- Level 0 (root) - 1 node
- Level 1 - 2 nodes
- Level 2 - 3 nodes
- ......
Each node's children can also be determined directly from the node's position, so no explicit links are needed. Use Node(level, index) to represent a node in the pyramid, where
- level - represents which level this node is in, starting from 0
- index - represents the node's position within its level, starting from 0
- Root node - Node(0, 0)
- Level 1 - Node(1, 0) and Node(1, 1)
- Level 2 - Node(2, 0), Node(2, 1) and Node(2, 2)
- Level 3 - Node(3, 0), Node(3, 1), Node(3, 2) and Node(3, 3)
In this mapping the children of a node are also deterministic:
- Node(0, 0)'s children: Node(1, 0) and Node(1, 1)
- Node(1, 0)'s children: Node(2, 0) and Node(2, 1)
- Node(1, 1)'s children: Node(2, 1) and Node(2, 2)
- Node(2, 0)'s children: Node(3, 0) and Node(3, 1)
- Node(2, 1)'s children: Node(3, 1) and Node(3, 2)
- Node(2, 2)'s children: Node(3, 2) and Node(3, 3)
- ......
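In summary, for each level and node:
- Level i has i+1 nodes
- Non-leaf Node(i, j)'s children: Node(i+1, j) and Node(i+1, j+1)
- Leaf Node(i, j) (i.e. i is the last level) has no children
- Node(i, j) is stored in the array at index i*(i+1)/2 + j, because levels 0..i-1 hold 1 + 2 + ... + i = i*(i+1)/2 nodes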
Finding the path with the maximal sum is easy too, because it follows the same idea as the tree-based solution: the best path from a node is the node's value plus the better of the best paths from its two children.
3. C++ Implementation
// ********************************************************************************
// Implementation
// ********************************************************************************
/// Pyramid ("triangle") of values stored level by level in a flat array.
///
/// Level i holds i + 1 values, so Node(level, node) lives at array index
/// level * (level + 1) / 2 + node. Its children are Node(level + 1, node) and
/// Node(level + 1, node + 1).
template <typename T>
class PyramidArray
{
private:
    // Position of a node in the pyramid:
    //   root    : level 0 - one node    (0)
    //   level 1 : two nodes             (0, 1)
    //   level 2 : three nodes           (0, 1, 2)
    //   ......
    struct PyramidArrNode {
        PyramidArrNode(size_t l, size_t n)
            : level(l), node(n)
        {}

        // Index of this node in the flat array: the number of nodes on all
        // levels above it (1 + 2 + ... + level == level * (level + 1) / 2)
        // plus its position within its own level.
        // e.g. Node(0, 0) -> 0, Node(1, 1) -> 2, Node(2, 0) -> 3, Node(2, 1) -> 4
        size_t GetIndex() const {
            return (level*(level + 1) >> 1) + node;
        }

        size_t level;
        size_t node;
    };

public:
    /// Creates an empty pyramid. Every query throws until ConstructPyramid succeeds.
    PyramidArray()
        : m_Depth(0)
    {}

    /// Creates a pyramid from `input`, laid out level by level starting at the root.
    /// @throws PyramixException if ConstructPyramid rejects `input`.
    PyramidArray(const std::vector<T> &input)
        : PyramidArray()
    {
        ConstructPyramid(input);
    }

    /// (Re)builds the pyramid from `input`, laid out level by level starting at the root.
    /// The input is validated before any member is modified. On failure, the
    /// previously constructed pyramid (if any) is left unchanged.
    /// @throws PyramixException("Construction failure - invalid input") when
    ///         ValidInput(input) returns 0.
    void ConstructPyramid(const std::vector<T> &input) {
        const size_t level = ValidInput(input);
        if (level == 0) {
            throw PyramixException("Construction failure - invalid input");
        }
        m_data = input;
        m_Depth = level;
    }

    /// @return the largest sum over all root-to-leaf paths, accumulated in double.
    /// @throws PyramixException("Pyramid not constructed yet") on an empty pyramid.
    double FindMaxSum() const {
        if (m_data.empty()) {
            throw PyramixException("Pyramid not constructed yet");
        }
        return FindMaxSumInternal();
    }

    /// @return the root-to-leaf path with the largest sum, together with that sum.
    ///         When both children of a node give equal sums, the left child
    ///         (smaller node index) is chosen.
    /// @throws PyramixException("Pyramid not constructed yet") on an empty pyramid.
    PyramidPath<T> FindMaxSumAndPath() const {
        if (m_data.empty()) {
            throw PyramixException("Pyramid not constructed yet");
        }
        return FindMaxSumAndPathInternal();
    }

    /// @return the number of levels in the pyramid.
    /// @throws PyramixException("Pyramid not constructed yet") on an empty pyramid.
    size_t GetDepth() const {
        if (m_data.empty()) {
            throw PyramixException("Pyramid not constructed yet");
        }
        return m_Depth;
    }

private:
    // Both helpers below use bottom-up dynamic programming. Neighbouring
    // nodes share a child, so plain recursion into both children revisits
    // shared sub-pyramids and costs O(2^depth). Processing the levels from
    // the leaves upwards evaluates every node exactly once.
    // Precondition: the pyramid is constructed (m_Depth >= 1).

    // Maximal root-to-leaf sum, in O(number of nodes) time.
    double FindMaxSumInternal() const {
        const size_t lastLevel = m_Depth - 1;

        // best[node] = maximal sum of a path from Node(level, node) down to a
        // leaf, where `level` is the level currently held in the vector.
        std::vector<double> best;
        best.reserve(m_Depth);
        for (size_t node = 0; node <= lastLevel; ++node) {
            best.push_back(static_cast<double>(m_data[PyramidArrNode(lastLevel, node).GetIndex()]));
        }

        for (size_t level = lastLevel; level-- > 0; ) {
            // Updating left to right is safe: best[node + 1] still holds the
            // value for the level below when best[node] is overwritten.
            for (size_t node = 0; node <= level; ++node) {
                const double leftBranchSum = best[node];
                const double rightBranchSum = best[node + 1];
                best[node] = (leftBranchSum > rightBranchSum ? leftBranchSum : rightBranchSum)
                             + m_data[PyramidArrNode(level, node).GetIndex()];
            }
            best.pop_back();
        }
        return best.front();
    }

    // Maximal root-to-leaf path. Each node's best path is built from its
    // better child's best path, preferring the left child on ties.
    PyramidPath<T> FindMaxSumAndPathInternal() const {
        const size_t lastLevel = m_Depth - 1;

        // below[node] = best path starting at Node(level + 1, node) for the
        // level currently being processed. It starts out holding the leaves.
        std::vector<PyramidPath<T>> below;
        below.reserve(m_Depth);
        for (size_t node = 0; node <= lastLevel; ++node) {
            below.push_back(PyramidPath<T>(m_data[PyramidArrNode(lastLevel, node).GetIndex()]));
        }

        std::vector<PyramidPath<T>> current;
        current.reserve(m_Depth);
        for (size_t level = lastLevel; level-- > 0; ) {
            current.clear();
            for (size_t node = 0; node <= level; ++node) {
                const PyramidPath<T> &leftBranchPath = below[node];
                const PyramidPath<T> &rightBranchPath = below[node + 1];
                const T &value = m_data[PyramidArrNode(level, node).GetIndex()];
                if (leftBranchPath.sum >= rightBranchPath.sum) {
                    current.push_back(PyramidPath<T>(leftBranchPath, value));
                }
                else {
                    current.push_back(PyramidPath<T>(rightBranchPath, value));
                }
            }
            below.swap(current);
        }
        return std::move(below.front());
    }

    std::vector<T> m_data;   // node values, level by level from the root
    size_t m_Depth;          // number of levels; 0 until constructed
};
// ********************************************************************************
// Test
// ********************************************************************************
/// Exercises PyramidArray. Failures abort via assert(). The test covers:
/// - error reporting on a never-constructed pyramid and on invalid input,
/// - depth, max-sum and max-path queries on valid pyramids, including one
///   where choosing the larger child at each step would give the wrong path
///   and one with only negative values,
/// - state retention after a rejected reconstruction.
void TestPyramidArray()
{
    {
        PyramidArray<int> pyramid;

        // Every query on a never-constructed pyramid must throw.
        try {
            pyramid.GetDepth();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSum();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSumAndPath();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        // Invalid input is rejected and the pyramid stays unconstructed.
        try {
            pyramid.ConstructPyramid({ 1, 2, 3, 4 });
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Construction failure - invalid input");
        }
        try {
            pyramid.FindMaxSum();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSumAndPath();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        pyramid.ConstructPyramid({ 1, 2, 3 });
        assert(pyramid.GetDepth() == 2);
        assert(pyramid.FindMaxSum() == 4);
        auto path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 4);
        assert(path.path == std::vector<int>({ 1, 3 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6 });
        assert(pyramid.GetDepth() == 3);
        assert(pyramid.FindMaxSum() == 10);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 10);
        assert(path.path == std::vector<int>({ 1, 3, 6 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        assert(pyramid.GetDepth() == 4);
        assert(pyramid.FindMaxSum() == 20);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 20);
        assert(path.path == std::vector<int>({ 1, 3, 6, 10 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 });
        assert(pyramid.GetDepth() == 5);
        assert(pyramid.FindMaxSum() == 35);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 35);
        assert(path.path == std::vector<int>({ 1, 3, 6, 10, 15 }));

        // A rejected reconstruction must leave the previous pyramid intact.
        try {
            pyramid.ConstructPyramid({ 1, 2, 3, 4 });
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Construction failure - invalid input");
        }
        assert(pyramid.GetDepth() == 5);
        assert(pyramid.FindMaxSum() == 35);
    }

    {
        // The optimal path is not the rightmost edge and not the greedy one:
        // taking the larger child (9) at level 1 yields at most 1 + 9 + 1 = 11,
        // whereas 1 + 2 + 100 = 103.
        PyramidArray<int> pyramid;
        pyramid.ConstructPyramid({ 1, 9, 2, 1, 1, 100 });
        assert(pyramid.GetDepth() == 3);
        assert(pyramid.FindMaxSum() == 103);
        auto path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 103);
        assert(path.path == std::vector<int>({ 1, 2, 100 }));
    }

    {
        // With only negative values, the answer is the least negative path.
        PyramidArray<int> pyramid;
        pyramid.ConstructPyramid({ -1, -2, -3 });
        assert(pyramid.GetDepth() == 2);
        assert(pyramid.FindMaxSum() == -3);
        auto path = pyramid.FindMaxSumAndPath();
        assert(path.sum == -3);
        assert(path.path == std::vector<int>({ -1, -2 }));
    }
}
No comments:
Post a Comment