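1. Problem
This is a Google interview question for software-engineer interns, posted on CareerCup (the link to the original thread is missing here). Given a pyramid of numbers, find the path from the top to the bottom with the maximal sum, where each step moves to one of the two adjacent numbers in the row below.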
For example:
       3
      / \
     9   4
    / \ / \
   1   8   2
  / \ / \ / \
 4   5   8   2
Answer: <3, 9, 8, 8>, sum = 3 + 9 + 8 + 8 = 28
2. Data Structure and Algorithm
My first attempt was to use a tree structure. A tree maps very directly onto the pyramid, with one difference: from level 2 onward (the root is level 0), every node except the first and the last on its level has two parents. There are two obvious ways to bridge this difference. One is to duplicate the shared nodes so that each copy has only one parent. The other is to keep the nodes as they are and let two parents point to the same child.
In my implementation I picked the second option: two parents point to the same child. This choice has a drawback: the usual recursive tree-free function no longer works, because it frees shared nodes twice. A child shared by two parents is freed once through the right branch of its left parent, and a second time through the left branch of its right parent. To avoid the double free, the pyramid has to be freed level by level (BFS), as implemented in PyramidTree::Destory().
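Finding the path with the maximal sum is relatively easy with the tree structure; it is a typical dynamic-programming problem. Let $F(s, n)$ be the best total reachable from node $n$ when the path above it already sums to $s$, with $F(s, \text{null}) = s$:
- $F(s, n) = \max\big(F(s + v_n,\ n_{\text{left}}),\ F(s + v_n,\ n_{\text{right}})\big)$
- The answer is $F(0, \text{root}) = \max\big(F(v_{\text{root}},\ \text{root}_{\text{left}}),\ F(v_{\text{root}},\ \text{root}_{\text{right}})\big)$, and so on down the levels.
Because interior nodes are shared by two parents, the same sub-pyramid is reached along many paths. The equivalent bottom-up form $B(n) = v_n + \max\big(B(n_{\text{left}}),\ B(n_{\text{right}})\big)$ with $B(\text{null}) = 0$ depends only on $n$. Caching $B$ per node therefore keeps the work linear in the number of nodes, instead of exponential in the depth.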
See the implementation of PyramidTree::FindMaxSum() and PyramidTree::FindMaxSumAndPath().
3. C++ Implementation
// ********************************************************************************
// Implementation
// ********************************************************************************
#include "TreeNode.h"
#include <cassert>
#include <cmath>
#include <exception>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
namespace Pyramid
{
// Checks whether `input` can be laid out as a pyramid: row k (0-based) holds
// k + 1 values, so a pyramid with `level` rows needs exactly
// level * (level + 1) / 2 values (a triangular number).
// Returns the number of rows when input.size() is triangular, otherwise 0.
// An empty input also returns 0.
template <typename T>
size_t ValidInput(const std::vector<T> &input) {
    if (input.empty()) {
        return 0;
    }
    // 1 + 2 + 3 + 4 + 5 = (1+5)*5/2 = 15, so 15 values form 5 rows.
    // Solving level*(level+1)/2 == n gives level ~= sqrt(2n):
    // floor(sqrt(30)) = floor(5.47...) = 5, and 5 * 6 / 2 = 15.
    const size_t inputLen = input.size();
    size_t level = static_cast<size_t>(std::sqrt(2.0 * static_cast<double>(inputLen)));
    // The floating-point estimate can be off by one for large sizes. Nudge it
    // to the largest level whose triangular number does not exceed inputLen.
    while (level > 0 && level * (level + 1) / 2 > inputLen) {
        --level;
    }
    while ((level + 1) * (level + 2) / 2 <= inputLen) {
        ++level;
    }
    return level * (level + 1) / 2 == inputLen ? level : 0;
}
// A path through the pyramid: the visited values plus their running total.
template<typename T>
struct PyramidPath
{
    // Empty path; the total starts at zero.
    PyramidPath()
        : path(), sum(0.)
    {}

    // Path made of the single value `val`, which is also its total.
    PyramidPath(const T& val)
        : path(1, val), sum(val)
    {}

    // Builds a new path by placing `val` in front of every value of `rhs`;
    // the total is `rhs.sum` increased by `val`.
    PyramidPath(const PyramidPath &rhs, T val)
        : path(), sum(val + rhs.sum)
    {
        path.reserve(rhs.path.size() + 1);
        path.push_back(val);
        for (const T& item : rhs.path) {
            path.push_back(item);
        }
    }

    // Values along the path, in order (the two-argument constructor prepends).
    std::vector<T> path;
    // Total of the values, accumulated as a double.
    double sum;
};
// Error type thrown by PyramidTree for invalid input or for queries on a
// pyramid that has not been constructed. It derives *publicly* from
// std::exception so callers can also catch it as `const std::exception&`.
// (The original `class X : std::exception` inherited privately, which hid
// the base from handlers.)
class PyramixException : public std::exception
{
public:
    // msg: human-readable description returned by what().
    PyramixException(const std::string& msg)
        : m_ErrMsg(msg)
    {}
    // Must be noexcept to override std::exception::what(); a looser exception
    // specification on this override is ill-formed.
    const char* what() const noexcept override {
        return m_ErrMsg.c_str();
    }
private:
    std::string m_ErrMsg;
};
template<typename T>
class PyramidTree
{
public:
PyramidTree()
: m_root(NULL), m_Depth(0)
{}
PyramidTree(const std::vector<T> &input)
: m_root(NULL), m_Depth(0)
{
ConstructPyramid(input);
}
~PyramidTree() {
Destory();
}
void ConstructPyramid(const std::vector<T> &input)
{
const size_t level = ValidInput(input);
if (level == 0) {
throw PyramixException("Construction failure - invalid input");
}
Destory();
// construct the tree
m_Depth = level;
std::queue<TreeNode<T>*> leafNodes;
auto iter = input.begin();
while (iter != input.end()) {
if (leafNodes.empty()) {
m_root = new TreeNode<T>(*iter);
leafNodes.push(m_root);
++iter;
}
else {
size_t leafNodeSize = leafNodes.size();
TreeNode<T>* newNode = new TreeNode<T>(*iter);
leafNodes.push(newNode);
while (leafNodeSize) {
TreeNode<T>* curNode = leafNodes.front();
curNode->left = newNode;
++iter;
newNode = new TreeNode<T>(*iter);
curNode->right = newNode;
leafNodes.pop();
leafNodes.push(newNode);
--leafNodeSize;
}
++iter;
}
}
}
double FindMaxSum() const {
if (!m_root) {
throw PyramixException("Pyramid not constructed yet");
}
return FindMaxSumInternal(m_root);
}
PyramidPath<T> FindMaxSumAndPath() const {
if (!m_root) {
throw PyramixException("Pyramid not constructed yet");
}
return FindMaxSumAndPathInternal(m_root);
}
size_t GetDepth() const {
if (!m_root) {
throw PyramixException("Pyramid not constructed yet");
}
return m_Depth;
}
void Destory() {
if (m_root) {
// delete the tree
std::queue<TreeNode<T>*> nodesAtSameLevel;
nodesAtSameLevel.push(m_root);
TreeNode<T>* curNode = NULL;
while (!nodesAtSameLevel.empty()) {
size_t numOfNodesAtSameLevel = nodesAtSameLevel.size();
curNode = nodesAtSameLevel.front();
if (curNode->left) {
nodesAtSameLevel.push(curNode->left);
}
while (numOfNodesAtSameLevel) {
curNode = nodesAtSameLevel.front();
if (curNode->right) {
nodesAtSameLevel.push(curNode->right);
}
delete curNode;
nodesAtSameLevel.pop();
--numOfNodesAtSameLevel;
}
}
m_root = NULL;
}
m_Depth = 0;
}
private:
double FindMaxSumInternal(TreeNode<T> *curNode) const {
if (!curNode) {
return 0.;
}
// tree sub-problem
const double leftBranchSum = FindMaxSumInternal(curNode->left) + *curNode->data;
const double rightBranchSum = FindMaxSumInternal(curNode->right) + *curNode->data;
return leftBranchSum > rightBranchSum ? leftBranchSum : rightBranchSum;
}
PyramidPath<T> FindMaxSumAndPathInternal(TreeNode<T> *curNode) const {
if (!curNode) {
return PyramidPath<T>();
}
// tree sub-problem
const PyramidPath<T> leftBranchPath = FindMaxSumAndPathInternal(curNode->left);
const PyramidPath<T> rightBranchPath = FindMaxSumAndPathInternal(curNode->right);;
if (leftBranchPath.sum >= rightBranchPath.sum) {
return PyramidPath<T>(leftBranchPath, *curNode->data);
}
return PyramidPath<T>(rightBranchPath, *curNode->data);
}
TreeNode<T> * m_root;
size_t m_Depth;
};
// ********************************************************************************
// Test
// ********************************************************************************
using namespace Pyramid;
// Self-checks for ValidInput and PyramidTree; each failed expectation fires an
// assert.
void TestPyramidTree()
{
    // Triangular sizes map to their row count.
    {
        const std::vector<int> input;
        assert(Pyramid::ValidInput(input) == 0);
        const std::vector<int> input1 = { 1 };
        assert(Pyramid::ValidInput(input1) == 1);
        const std::vector<int> input2 = { 1, 2, 3 };
        assert(Pyramid::ValidInput(input2) == 2);
        const std::vector<int> input3 = { 1, 2, 3, 4, 5, 6 };
        assert(Pyramid::ValidInput(input3) == 3);
        const std::vector<int> input4 = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
        assert(Pyramid::ValidInput(input4) == 4);
        const std::vector<int> input5 = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
        assert(Pyramid::ValidInput(input5) == 5);
    }
    // Non-triangular sizes are rejected.
    {
        const std::vector<int> input1 = { 1, 0 };
        assert(Pyramid::ValidInput(input1) == 0);
        const std::vector<int> input2 = { 1, 2, 3, 0 };
        assert(Pyramid::ValidInput(input2) == 0);
        const std::vector<int> input3 = { 1, 2, 3, 4, 5, 6, 0 };
        assert(Pyramid::ValidInput(input3) == 0);
        const std::vector<int> input4 = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0 };
        assert(Pyramid::ValidInput(input4) == 0);
        const std::vector<int> input5 = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0 };
        assert(Pyramid::ValidInput(input5) == 0);
    }
    {
        PyramidTree<int> pyramid;
        // Queries on an unconstructed pyramid throw.
        try {
            pyramid.GetDepth();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSum();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSumAndPath();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.ConstructPyramid({1, 2, 3, 4});
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Construction failure - invalid input");
        }
        try {
            pyramid.FindMaxSum();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
        try {
            pyramid.FindMaxSumAndPath();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }

        pyramid.ConstructPyramid({ 1, 2, 3 });
        assert(pyramid.GetDepth() == 2);
        assert(pyramid.FindMaxSum() == 4);
        auto path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 4);
        assert(path.path == std::vector<int>({1, 3}));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6 });
        assert(pyramid.GetDepth() == 3);
        assert(pyramid.FindMaxSum() == 10);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 10);
        assert(path.path == std::vector<int>({ 1, 3, 6 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        assert(pyramid.GetDepth() == 4);
        assert(pyramid.FindMaxSum() == 20);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 20);
        assert(path.path == std::vector<int>({ 1, 3, 6, 10 }));

        pyramid.ConstructPyramid({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 });
        assert(pyramid.GetDepth() == 5);
        assert(pyramid.FindMaxSum() == 35);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 35);
        assert(path.path == std::vector<int>({ 1, 3, 6, 10, 15 }));

        // Input is validated before the old pyramid is freed, so a rejected
        // input leaves the previous pyramid intact.
        try {
            pyramid.ConstructPyramid({ 1, 2 });
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Construction failure - invalid input");
        }
        assert(pyramid.GetDepth() == 5);
        assert(pyramid.FindMaxSum() == 35);

        // The worked example from the problem statement.
        pyramid.ConstructPyramid({ 3, 9, 4, 1, 8, 2, 4, 5, 8, 2 });
        assert(pyramid.GetDepth() == 4);
        assert(pyramid.FindMaxSum() == 28);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 28);
        assert(path.path == std::vector<int>({ 3, 9, 8, 8 }));

        // A single node.
        pyramid.ConstructPyramid({ 7 });
        assert(pyramid.GetDepth() == 1);
        assert(pyramid.FindMaxSum() == 7);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 7);
        assert(path.path == std::vector<int>({ 7 }));

        // Both branches under the root sum to 6: the left branch is preferred.
        pyramid.ConstructPyramid({ 0, 1, 2, 5, 4, 3 });
        assert(pyramid.FindMaxSum() == 6);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == 6);
        assert(path.path == std::vector<int>({ 0, 1, 5 }));

        // Negative values.
        pyramid.ConstructPyramid({ -1, -2, -3 });
        assert(pyramid.FindMaxSum() == -3);
        path = pyramid.FindMaxSumAndPath();
        assert(path.sum == -3);
        assert(path.path == std::vector<int>({ -1, -2 }));

        // Destory() returns the pyramid to the unconstructed state.
        pyramid.Destory();
        try {
            pyramid.GetDepth();
            assert(false);
        }
        catch (const PyramixException &e) {
            assert(std::string(e.what()) == "Pyramid not constructed yet");
        }
    }
}
Reader comment: Wouldn't Dijkstra's shortest-path algorithm (modified to compute the longest path) on a graph solve this problem?