235. Lowest Common Ancestor of a Binary Search Tree

Given a binary search tree (BST), find the lowest common ancestor (LCA) of two given nodes in the BST.

According to the definition of LCA on Wikipedia: “The lowest common ancestor is defined between two nodes v and w as the lowest node in T that has both v and w as descendants (where we allow a node to be a descendant of itself).”
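To make the definition concrete, here is a minimal brute-force sketch that applies it literally to any binary tree. It does not use the BST ordering. It checks every node, counting a node as its own descendant, and keeps the deepest one that has both targets below it. `TreeNode`, `is_descendant`, and `lca_by_definition` are hypothetical helper names. The method runs in O(n²) time, so it is only useful as a reference for checking faster solutions.

```python
from typing import Optional


class TreeNode:
    """Minimal binary tree node (hypothetical helper mirroring LeetCode's)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def is_descendant(node: Optional[TreeNode], target: TreeNode) -> bool:
    """True if target is node itself or appears somewhere below node."""
    if node is None:
        return False
    if node is target:
        return True
    return is_descendant(node.left, target) or is_descendant(node.right, target)


def lca_by_definition(root: TreeNode, v: TreeNode, w: TreeNode) -> Optional[TreeNode]:
    """Return the deepest node that has both v and w as descendants."""
    best, best_depth = None, -1
    stack = [(root, 0)]  # explicit DFS stack of (node, depth) pairs
    while stack:
        node, depth = stack.pop()
        if node is None:
            continue
        if depth > best_depth and is_descendant(node, v) and is_descendant(node, w):
            best, best_depth = node, depth
        stack.append((node.left, depth + 1))
        stack.append((node.right, depth + 1))
    return best
```

All common ancestors of v and w lie on one root-to-node chain, so "deepest" picks out exactly one node.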

        _______6______
       /              \
    ___2__          ___8__
   /      \        /      \
   0      _4       7       9
         /  \
         3   5

For example, the lowest common ancestor (LCA) of nodes 2 and 8 is 6. Another example: the LCA of nodes 2 and 4 is 2, since a node can be a descendant of itself according to the LCA definition.
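The two examples can be checked mechanically. The sketch below rebuilds the example tree by inserting its values into an empty BST in level order. It then walks down from the root while both values lie strictly on the same side of the current node, and records each value it visits. The last value in each list is the LCA. `insert` and `lca_trace` are hypothetical helper names. `lca_trace` takes values rather than nodes and assumes both values are present in the tree.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Standard iterative BST insertion; assumes distinct values."""
    node = TreeNode(val)
    if root is None:
        return node
    cur = root
    while True:
        if val < cur.val:
            if cur.left is None:
                cur.left = node
                return root
            cur = cur.left
        else:
            if cur.right is None:
                cur.right = node
                return root
            cur = cur.right


def lca_trace(root: TreeNode, a: int, b: int) -> List[int]:
    """Values visited while descending; the last one is the LCA of a and b."""
    visited = [root.val]
    while (root.val - a) * (root.val - b) > 0:  # a and b on the same side
        root = root.left if a < root.val else root.right
        visited.append(root.val)
    return visited


root = None
for v in [6, 2, 8, 0, 4, 7, 9, 3, 5]:  # level order of the example tree
    root = insert(root, v)

print(lca_trace(root, 2, 8))  # [6]
print(lca_trace(root, 2, 4))  # [6, 2]
print(lca_trace(root, 3, 5))  # [6, 2, 4]
```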

Thoughts:

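Because the tree is a BST, we can decide which branch to go down by comparing values. If both nodes are smaller than the current node, their LCA is in the left subtree. If both are larger, it is in the right subtree. Otherwise the current node is the LCA. There are two ways to implement this: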

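  1. Iterative, O(1) space: starting at the root, keep moving down to the side on which both nodes reside until the "split" is found. The split is the first node where the two nodes no longer lie on the same side.

  2. Recursive: the same descent written as a recursive call. It uses O(h) stack space for a tree of height h.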
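The "split" can also be seen by comparing the two root-to-node search paths. They agree from the root down to some node and then either diverge or one of them ends. The last shared node is the LCA. The sketch below is this path-comparison variant, not the O(1)-space method itself, because it stores both paths and so needs O(h) space. `search_path` and `lca_by_paths` are hypothetical names, and both values are assumed to be present in the tree.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def search_path(root: TreeNode, target: int) -> List[TreeNode]:
    """Nodes visited by an ordinary BST search for target, root first."""
    path = []
    node = root
    while node is not None:
        path.append(node)
        if target == node.val:
            break
        node = node.left if target < node.val else node.right
    return path


def lca_by_paths(root: TreeNode, a: int, b: int) -> TreeNode:
    """Last node shared by the two search paths, i.e. where they split."""
    path_a, path_b = search_path(root, a), search_path(root, b)
    lca = root
    for u, v in zip(path_a, path_b):
        if u is not v:
            break  # the paths have split; the previous node is the LCA
        lca = u
    return lca


# The example tree from the problem statement.
root = TreeNode(6,
                TreeNode(2, TreeNode(0), TreeNode(4, TreeNode(3), TreeNode(5))),
                TreeNode(8, TreeNode(7), TreeNode(9)))
print(lca_by_paths(root, 2, 8).val)  # 6
print(lca_by_paths(root, 3, 5).val)  # 4
```

The iterative method avoids storing the paths by walking both searches in lockstep and stopping at the first node where they would branch apart.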

Iterative:

Code 1

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
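        // Product > 0: p and q are strictly on the same side of root, so move toward that side.
        // The cast to long long avoids int overflow when node values are large.
        while ((long long)(root->val - p->val) * (root->val - q->val) > 0)
            root = root->val > p->val ? root->left : root->right;
        // Product <= 0: either root is p or q itself, or p and q lie on different sides,
        // so root is their lowest common ancestor.
        return root;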
    }
};

Code 1 (Java)

public TreeNode lowestCommonAncestor(TreeNode root, TreeNode p, TreeNode q) {
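    // Descend while p and q are both strictly on the same side of root.
    // The (long) cast keeps the product from overflowing int for large values.
    while ((long) (root.val - p.val) * (root.val - q.val) > 0)
        root = p.val < root.val ? root.left : root.right;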
    return root;
}

Code 1 (Python)

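def lowestCommonAncestor(self, root, p, q):
    # Descend while p and q are both strictly on the same side of root.
    while (root.val - p.val) * (root.val - q.val) > 0:
        # Indexing a tuple with a bool picks left (False) or right (True).
        root = (root.left, root.right)[p.val > root.val]
    return root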
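
Or, with explicit chained comparisons instead of the product:

def lowestCommonAncestor(self, root, p, q):
    while root:
        if p.val < root.val > q.val:    # both smaller than root: go left
            root = root.left
        elif p.val > root.val < q.val:  # both larger than root: go right
            root = root.right
        else:                           # split point, or root is p or q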
            return root
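The two Python variants express the "same side" test differently. One uses the sign of a product and the other uses chained comparisons. The sketch below checks exhaustively over a small range that the two forms agree. `same_side_product` and `same_side_chained` are hypothetical helper names.

```python
def same_side_product(x: int, a: int, b: int) -> bool:
    """Product form: positive iff a and b are both < x or both > x."""
    return (x - a) * (x - b) > 0


def same_side_chained(x: int, a: int, b: int) -> bool:
    """Chained form: `a < x > b` means a < x and x > b."""
    return a < x > b or a > x < b


# Exhaustively compare the two forms over a small range of values.
agree = all(
    same_side_product(x, a, b) == same_side_chained(x, a, b)
    for x in range(-3, 4)
    for a in range(-3, 4)
    for b in range(-3, 4)
)
print(agree)  # True
```

Python integers have arbitrary precision, so the product form is always exact here. In C++ and Java, the product of two 32-bit `int` differences can overflow when node values are large. In those languages the comparison form, or a 64-bit intermediate, is the safer choice.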

Recursive

Code 2

class Solution {
public:
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
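        // Recurse toward the side holding both p and q; stop at the split point.
        // The cast to long long avoids int overflow when node values are large.
        return (long long)(root->val - p->val) * (root->val - q->val) > 0
            ? lowestCommonAncestor(p->val < root->val ? root->left : root->right, p, q)
            : root;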
    }
};

Code 2 (Java)

class Solution {
    public TreeNode lowestCommonAncestor(TreeNode root, TreeNode p, TreeNode q) {
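        // Recurse toward the side holding both p and q; stop at the split point.
        // The (long) cast keeps the product from overflowing int for large values.
        return (long) (root.val - p.val) * (root.val - q.val) > 0
            ? lowestCommonAncestor(p.val < root.val ? root.left : root.right, p, q)
            : root;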
    }
}

Code 2 (Python)

class Solution(object):
    def lowestCommonAncestor(self, root, p, q):
        # Recurse into the side holding both p and q; return root at the split point.
        return self.lowestCommonAncestor((root.left, root.right)[p.val > root.val], p, q) \
                if (root.val - p.val) * (root.val - q.val) > 0 else root
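In Python, the practical difference between Code 1 and Code 2 is stack depth. The recursive version uses one stack frame per level, so a very deep BST, such as a completely skewed one, can exceed CPython's recursion limit. That limit defaults to 1000 and is reported by `sys.getrecursionlimit()`. The iterative loop is unaffected. The following is a minimal sketch, assuming a hypothetical right-skewed tree of 5,000 nodes built from ascending values. `build_right_chain`, `lca_iterative`, and `lca_recursive` are illustrative names that reuse the loop and recursion from the Python versions above.

```python
import sys
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_right_chain(n: int) -> List[TreeNode]:
    """Nodes 0..n-1 of a right-skewed BST; nodes[0] is the root."""
    nodes = [TreeNode(v) for v in range(n)]
    for parent, child in zip(nodes, nodes[1:]):
        parent.right = child
    return nodes


def lca_iterative(root, p, q):
    # Same loop as Code 1 (Python): constant extra space.
    while (root.val - p.val) * (root.val - q.val) > 0:
        root = (root.left, root.right)[p.val > root.val]
    return root


def lca_recursive(root, p, q):
    # Same logic as Code 2 (Python): one stack frame per level descended.
    if (root.val - p.val) * (root.val - q.val) > 0:
        return lca_recursive((root.left, root.right)[p.val > root.val], p, q)
    return root


nodes = build_right_chain(5000)
root, p, q = nodes[0], nodes[-2], nodes[-1]
print(lca_iterative(root, p, q).val)  # 4998
print("recursion limit:", sys.getrecursionlimit())
try:
    print(lca_recursive(root, p, q).val)
except RecursionError:
    # Expected with the default limit: ~5,000 nested calls exceed it.
    print("RecursionError: tree depth exceeds the recursion limit")
```

On balanced trees, where h is about log n, both versions behave the same. The difference only matters when the tree degenerates toward a linked list.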

Special thanks to StefanPochmann as he nailed this problem again over here.
