#!/usr/bin/env python # coding: utf-8 #

230. Kth Smallest Element in a BST

#
# # #

Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

# #

 

# #

Example 1:

# #
Input: root = [3,1,4,null,2], k = 1
#    3
#   / \
#  1   4
#   \
#    2
# Output: 1
# #

 

#

Example 2:

# #
Input: root = [5,3,6,2,4,null,null,1], k = 3
#        5
#       / \
#      3   6
#     / \
#    2   4
#   /
#  1
# Output: 3
# 
# #

 

#

Constraints:

# # # # # #

 

# Source #
# In[1]:


# Definition for a binary tree node.
class TreeNode:
    """A binary tree node holding a value and links to two children.

    Args:
        val: value stored in the node (defaults to 0).
        left: left child node, or None.
        right: right child node, or None.
    """

    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right

Code

# In[2]:


def kth_smallest(root, k):
    """Return the k-th smallest value (1-indexed) in a BST -- recursive approach.

    An in-order traversal (left subtree, node, right subtree) visits the
    values of a BST in ascending order, so the k-th visited node holds the
    answer. The traversal stops as soon as that node is reached.

    Args:
        root: root of the BST (any object with ``val``, ``left`` and
            ``right`` attributes, e.g. ``TreeNode``), or None.
        k: 1-based rank of the value to return.

    Returns:
        The value of the k-th node in in-order (ascending) sequence.

    Raises:
        IndexError: if ``k < 1`` or the tree has fewer than ``k`` nodes.

    Time complexity: O(H + k), where H is the tree height: the traversal
    first descends to the leftmost leaf, then visits k nodes.
        O(log N + k) for a balanced tree
        O(N) for a completely unbalanced (skewed) tree
    Space complexity: O(H) for the recursion stack; only a counter is kept,
    not the visited values.
        O(log N) for a balanced tree
        O(N) in the worst case of a skewed tree
    Note: a very deep (skewed) tree can exceed Python's default recursion
    limit (1000 frames); the iterative versions do not have that limit.
    """
    count = 0  # nodes visited so far, in ascending order
    found = False  # set once the k-th node is reached, to prune the rest
    result = None

    def inorder(node):
        nonlocal count, found, result
        if node is None or found:
            return
        inorder(node.left)
        if found:  # the answer was in the left subtree; skip this node
            return
        count += 1
        if count == k:
            result = node.val
            found = True
            return
        inorder(node.right)

    if k >= 1:  # no node can have rank < 1, so skip the traversal entirely
        inorder(root)
    if not found:
        raise IndexError(f"k={k} is out of range for this tree")
    return result
#

Follow up #1:

#

Solve it both recursively and iteratively

#

Code

# In[ ]:


def kth_smallest(root, k):
    """Return the k-th smallest value (1-indexed) in a BST -- iterative approach.

    Simulates the recursive in-order traversal with an explicit stack:
    push the path down to the leftmost node, pop it (the next value in
    ascending order), then continue with that node's right subtree.

    Raises:
        IndexError: if ``k < 1`` or the tree has fewer than ``k`` nodes.

    Time complexity: O(H + k); space complexity: O(H) for the stack,
    where H is the tree height.
    """
    if k < 1:
        raise IndexError(f"k={k} is out of range for this tree")
    stack = []
    node = root
    count = 0  # nodes popped so far, in ascending order
    while node is not None or stack:
        # go down to the leftmost node of the current subtree
        while node is not None:
            stack.append(node)
            node = node.left
        # no more left children: the top of the stack is the next smallest value
        node = stack.pop()
        count += 1
        if count == k:
            return node.val
        # continue with the right subtree (and its leftmost path)
        node = node.right
    # traversal exhausted: the tree has fewer than k nodes
    raise IndexError(f"k={k} is out of range for this tree")


# In[ ]:


def kth_smallest(root, k):
    """Return the k-th smallest value (1-indexed) -- alternative iterative solution.

    Same explicit-stack in-order traversal as above, written as one loop
    with three cases: descend left, pop-and-visit, or stop when exhausted.

    Raises:
        IndexError: if ``k < 1`` or the tree has fewer than ``k`` nodes
            (including an empty tree).
    """
    if k < 1:
        raise IndexError(f"k={k} is out of range for this tree")
    stack = []
    node = root
    count = 0  # nodes visited so far, in ascending order
    while True:
        if node is not None:
            # descend left, remembering the path
            stack.append(node)
            node = node.left
        elif stack:
            node = stack.pop()
            count += 1
            if count == k:
                return node.val
            node = node.right
        else:
            # nothing left to visit before reaching the k-th node
            raise IndexError(f"k={k} is out of range for this tree")


# In[ ]:


from itertools import islice


def kth_smallest(root, k):
    """Return the k-th smallest value (1-indexed) -- using a lazy generator.

    ``inorder`` is a recursive generator yielding the BST values in
    ascending order; ``islice`` consumes it lazily, so the traversal stops
    right after the k-th value has been produced.

    Raises:
        IndexError: if ``k < 1`` or the tree has fewer than ``k`` nodes.
    """
    def inorder(node):
        if node is None:
            return
        yield from inorder(node.left)
        yield node.val
        yield from inorder(node.right)

    if k < 1:
        # islice() rejects a negative start with ValueError; report it like
        # the other versions instead
        raise IndexError(f"k={k} is out of range for this tree")
    missing = object()  # sentinel: a node value may legitimately be falsy
    value = next(islice(inorder(root), k - 1, k), missing)
    if value is missing:
        # a bare next() would leak StopIteration here
        raise IndexError(f"k={k} is out of range for this tree")
    # Equivalent alternative: list(islice(inorder(root), k - 1, k))[0],
    # which raises IndexError by itself when the tree has fewer than k nodes.
    return value
#

Follow up #2:

#

What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?

#

Code

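# In[ ]:


# In a BST, insert and delete have a time complexity of O(H), where H is the
# tree height (H = log N for a balanced tree).
# Without any optimisation, an insert/delete followed by a search for the k-th
# smallest element costs O(H) + O(H + k) = O(2H + k).
# We could combine the BST with a doubly linked list that links the nodes in
# sorted order. Then it would have:
#   O(H) time for the insert and delete (the new node's in-order neighbours are
#   found during the same descent, and splicing the list is O(1)).
#   O(k) time for the search of the k-th smallest (walk k steps from the list
#   head, with no descent to the leftmost leaf).
# The overall time complexity for insert/delete + search of the k-th smallest
# is O(H + k) instead of O(2H + k). Both are the same asymptotic class, so the
# gain is the saved O(H) descent on every query, not a better big-O.
# Time complexity: O(H + k)
#   O(log N + k) in the average case
#   O(N + k) in the worst case
# Space complexity: O(N) to keep the linked list.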