Print Nodes at Distance K in a Binary Tree

Problem Statement: Given a binary tree, a node from the tree, and an integer ‘K’, find all nodes that are at a distance ‘K’ from the given node and return the list of these nodes.

Examples

Example 1:

Input Format: 3 5 1 6 2 0 8 -1 -1 7 4 -1 -1 -1 -1 (tree in level order, -1 = no child), target = 5, K = 2

Output: 1, 7, 4
Explanation: Starting from the given node 5, the nodes at distance 2 are 7, 4, and 1.

Example 2:

Input Format: 2 3 21 6 1 56 4 -1 -1 9 55 -1 -1 44 7 (tree in level order, -1 = no child), target = 56, K = 3

Output: 3, 44, 7
Explanation: Starting from the given node 56, the nodes at distance 3 are 44, 7 and 3.
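
Both examples list the tree in level order, with -1 standing for a missing child. Below is a minimal Python sketch of how such a line could be turned into linked nodes. The helpers `build_tree` and `level_order` are hypothetical names chosen for illustration, and the sketch assumes -1 is never a real node value.

```python
from collections import deque


class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def build_tree(tokens):
    """Build a binary tree from level-order tokens where -1 means "no child".

    Hypothetical helper; assumes -1 is never a real node value.
    """
    if not tokens or tokens[0] == -1:
        return None
    root = TreeNode(tokens[0])
    queue = deque([root])
    i = 1
    while queue and i < len(tokens):
        node = queue.popleft()
        # The next token is the left child, the one after it the right child
        if i < len(tokens) and tokens[i] != -1:
            node.left = TreeNode(tokens[i])
            queue.append(node.left)
        i += 1
        if i < len(tokens) and tokens[i] != -1:
            node.right = TreeNode(tokens[i])
            queue.append(node.right)
        i += 1
    return root


def level_order(root):
    """Return node values level by level, skipping missing children."""
    out, queue = [], deque([root] if root else [])
    while queue:
        node = queue.popleft()
        out.append(node.val)
        for child in (node.left, node.right):
            if child:
                queue.append(child)
    return out


example1 = build_tree([int(t) for t in "3 5 1 6 2 0 8 -1 -1 7 4 -1 -1 -1 -1".split()])
print(level_order(example1))  # [3, 5, 1, 6, 2, 0, 8, 7, 4]
```

Because the tokens are consumed in pairs as nodes leave the queue, each real node is given exactly the two tokens that describe its children.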

Disclaimer: Don’t jump directly to the solution, try it out yourself first.

Solution:

Approach: 

Picture the nodes at distance ‘K’ as a ring around the target: starting from the target and moving outward one edge at a time (up to the parent or down to a child), the nodes first reached after exactly ‘K’ steps are the answer.

To move away from the target one step at a time, we need access to all adjacent nodes of each node: its left child, right child, and parent. The children are reachable directly through the left and right pointers, but reaching the parent requires an additional hashmap that maps each node (key) to its parent (value).
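
As a small illustration, assuming a parent map already exists, the three possible neighbors of a node can be gathered as below. The `neighbors` function and the hand-filled `parent` dict are hypothetical, for illustration only.

```python
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def neighbors(node, parent):
    """Return the left child, right child and parent of `node`, skipping missing ones.

    `parent` is a dict mapping each non-root node to its parent node.
    """
    candidates = (node.left, node.right, parent.get(node))
    return [n for n in candidates if n is not None]


# Small tree:    3
#               / \
#              5   1
#             /
#            6
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left = TreeNode(6)

# Parent map filled in by hand here; Step 1 of the algorithm builds it with a BFS
parent = {root.left: root, root.right: root, root.left.left: root.left}

print([n.val for n in neighbors(root.left, parent)])  # [6, 3]
print([n.val for n in neighbors(root, parent)])       # [5, 1]
```

With this view, the tree behaves like an undirected graph, which is what lets the search walk upward as easily as downward.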

The approach has three main steps: first, build the node-to-parent mapping with a BFS over the whole tree; next, locate the target node; finally, run a level-by-level BFS outward from the target to collect the nodes at distance ‘K’.

Algorithm:

Step 1: Build the Parent Map with BFS

  1. Initialize a queue and a parent hashmap that stores the parent of each node.
  2. Insert the root node into the queue.
  3. While the queue is not empty, pop the front node; for each of its existing children, record the popped node as that child's parent and push the child into the queue.
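
A minimal Python sketch of Step 1 follows. The name `build_parent_map` is our own, and it assumes a `TreeNode` class with `val`, `left`, and `right` fields, like the one in the code section. Because Python objects hash by identity by default, nodes can be used directly as dictionary keys.

```python
from collections import deque


class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def build_parent_map(root):
    """Map every non-root node to its parent using a level-order (BFS) walk."""
    parent = {}
    if root is None:
        return parent
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node  # the popped node is this child's parent
                queue.append(child)
    return parent


# Tree from Example 1: 3 -> (5, 1), 5 -> (6, 2), 1 -> (0, 8), 2 -> (7, 4)
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

parent = build_parent_map(root)
print({child.val: p.val for child, p in parent.items()})
# {5: 3, 1: 3, 6: 5, 2: 5, 0: 1, 8: 1, 7: 2, 4: 2}
```

The root never appears as a key, which is how the later search knows it has no parent to move up to.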

Step 2: Locate the Target Node

If a reference to the target node is provided, use it directly. If only the value of the target node is given, use any traversal (inorder, preorder, or postorder) to find the node with that value, and store a reference to it in the ‘target’ node pointer.
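
If only the value is known, a preorder search like the sketch below is one option. It assumes node values are unique (with duplicates it returns the first match in preorder), and `find_node` is a hypothetical helper name. An explicit stack is used instead of recursion so that very deep trees do not run into Python's recursion limit.

```python
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def find_node(root, value):
    """Return the first node whose val equals `value` in preorder, or None."""
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        if node.val == value:
            return node
        # Push right first so the left subtree is explored first (preorder)
        if node.right is not None:
            stack.append(node.right)
        if node.left is not None:
            stack.append(node.left)
    return None


root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.right = TreeNode(2)

target = find_node(root, 5)
print(target is root.left)  # True
print(find_node(root, 42))  # None
```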

Read more about this in detail here: Inorder, Preorder, Postorder Traversal of Binary Tree

Step 3: Perform a BFS from the target node to find nodes at distance ‘K’

  1. Initialize a queue and a visited hashmap.
  2. Insert the target node into the queue, mark it as visited, and set the distance from the target ‘dis’ to 0.
  3. While the queue is not empty and ‘dis’ is less than K, process one level: pop every node currently in the queue and push each of its non-visited adjacent nodes (left child, right child, parent), marking them as visited.
  4. After a whole level has been processed, increment ‘dis’ by one.
  5. Once ‘dis’ equals K, the nodes left in the queue are exactly the nodes at distance K; add their values to the result array.
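
The level-by-level search of Step 3 can be sketched in Python as follows. Two choices here are ours rather than the original's: a `set` stands in for the visited hashmap, and the hypothetical `nodes_at_distance_k` takes a prebuilt parent map. Note that the target is marked visited before the loop starts; without that, a child of the target would push the target back into the queue one level later.

```python
from collections import deque


class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def build_parent_map(root):
    """Step 1: map each non-root node to its parent with a BFS."""
    parent, queue = {}, deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                queue.append(child)
    return parent


def nodes_at_distance_k(target, parent, k):
    """Step 3: expand outward from `target` one level (one edge) at a time."""
    queue = deque([target])
    visited = {target}  # mark the target itself before the first level
    dis = 0
    while queue and dis < k:
        for _ in range(len(queue)):  # pop exactly the nodes of the current level
            node = queue.popleft()
            for nxt in (node.left, node.right, parent.get(node)):
                if nxt is not None and nxt not in visited:
                    visited.add(nxt)
                    queue.append(nxt)
        dis += 1  # every node now in the queue is `dis` edges from the target
    # If the tree ran out of nodes before reaching K, the queue is empty here
    return [node.val for node in queue]


# Tree from Example 1, target = node 5, K = 2
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

parent = build_parent_map(root)
print(nodes_at_distance_k(root.left, parent, 2))  # [7, 4, 1]
```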

Step 4: Return the list of nodes found at distance ‘K’ from the target node.
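
The outward search does not have to be breadth-first. As a swapped-in alternative (not the approach used in the code below), a depth-first search that carries the remaining distance and skips the neighbor it just came from finds the same set of nodes, because a tree has no cycles. The helper names `parent_map` and `collect_at_distance` are hypothetical, and the sketch is recursive, so a very deep tree could exceed Python's default recursion limit.

```python
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None


def parent_map(node, parent=None, acc=None):
    """Recursively map every non-root node to its parent."""
    if acc is None:
        acc = {}
    if node is not None:
        if parent is not None:
            acc[node] = parent
        parent_map(node.left, node, acc)
        parent_map(node.right, node, acc)
    return acc


def collect_at_distance(node, parents, k, came_from=None, result=None):
    """DFS outward from `node`, collecting values exactly `k` edges away.

    `came_from` is the neighbour we just stepped off; since a tree has no
    cycles, never stepping straight back is enough to avoid revisiting nodes.
    """
    if result is None:
        result = []
    if k == 0:
        result.append(node.val)
        return result
    for nxt in (node.left, node.right, parents.get(node)):
        if nxt is not None and nxt is not came_from:
            collect_at_distance(nxt, parents, k - 1, node, result)
    return result


# Tree from Example 1, target = node 5, K = 2
root = TreeNode(3)
root.left, root.right = TreeNode(5), TreeNode(1)
root.left.left, root.left.right = TreeNode(6), TreeNode(2)
root.right.left, root.right.right = TreeNode(0), TreeNode(8)
root.left.right.left, root.left.right.right = TreeNode(7), TreeNode(4)

parents = parent_map(root)
print(collect_at_distance(root.left, parents, 2))  # [7, 4, 1]
```

The DFS version needs no visited hashmap, at the cost of recursion depth; the BFS version used in the code below avoids recursion and naturally stops once level K is reached.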

Code:

C++ Code

#include <iostream>
#include <unordered_map>
#include <vector>
#include <queue>

using namespace std;

// TreeNode structure
struct TreeNode {
    int val;
    TreeNode *left;
    TreeNode *right;
    TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Class to find nodes at a
// distance K from a target node
class Solution {
    
    // Helper function to mark parents
    // of nodes in the tree
    void markParents(TreeNode *root, unordered_map<TreeNode*,
        TreeNode*> &parent_track, TreeNode* target) {
        
        // Level Order Traversal by taking a queue
        queue<TreeNode*> queue;
        queue.push(root);
        
        // Iterate over all nodes
        while (!queue.empty()) {
            TreeNode* current = queue.front();
            queue.pop();
            
            // Assign parents to left child and
            // right child if they exist
            if (current->left) {
                parent_track[current->left] = current;
                queue.push(current->left);
            }
            
            if (current->right) {
                parent_track[current->right] = current;
                queue.push(current->right);
            }
        }
    }

public:
    // Function to find nodes at a
    // distance K from the target node
    vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
        
        // Map to mark the parents of all nodes
        unordered_map<TreeNode*, TreeNode*> parent_track;
        // Mark parents of all nodes
        markParents(root, parent_track, target); 
        // Keeps track of visited nodes
        unordered_map<TreeNode*, bool> visited; 
        // Queue to perform level-order traversal
        queue<TreeNode*> queue; 
        // Start traversal from the target node
        queue.push(target);
        // Mark the target as visited; otherwise its children
        // would push it back into the queue one level later
        visited[target] = true;
        // Tracks the current level
        // while traversing the tree
        int curr_level = 0;

        // Continue traversal until the queue is empty
        while (!queue.empty()) {
            // Get the number of nodes at the current level
            int size = queue.size();
            if (curr_level++ == k) {
                // Break if the current level
                // matches the required distance (k)
                break;
            }

            // Traverse the current level of the tree
            for (int i = 0; i < size; i++) {
                // Get the front node in the queue
                TreeNode* current = queue.front();
                // Remove the front node from the queue
                queue.pop();

                // Add unvisited left child to the queue
                if (current->left && !visited[current->left]) {
                    queue.push(current->left);
                    // Mark left child as visited
                    visited[current->left] = true;
                }

                // Add unvisited right child to the queue
                if (current->right && !visited[current->right]) {
                    queue.push(current->right);
                    // Mark right child as visited
                    visited[current->right] = true;
                }

                // Add unvisited parent node to the queue
                if (parent_track[current] && !visited[parent_track[current]]) {
                    queue.push(parent_track[current]);
                    // Mark parent node as visited
                    visited[parent_track[current]] = true;
                }
            }
        }

        // Stores nodes at distance k from the target
        vector<int> result;
        while (!queue.empty()) {
            // Extract nodes at distance k from the queue
            TreeNode* current = queue.front();
            queue.pop();
            // Store node values in the result vector
            result.push_back(current->val);
        }

        // Return nodes at distance
        // K from the target
        return result;
    }
};

// Main function
int main() {
    // Create a sample tree for testing
    TreeNode* root = new TreeNode(3);
    root->left = new TreeNode(5);
    root->right = new TreeNode(1);
    root->left->left = new TreeNode(6);
    root->left->right = new TreeNode(2);
    root->right->left = new TreeNode(0);
    root->right->right = new TreeNode(8);
    root->left->right->left = new TreeNode(7);
    root->left->right->right = new TreeNode(4);

    Solution sol;
    TreeNode* target = root->left;
    int k = 2;
    // Find nodes at distance 2 from the target node
    vector<int> result = sol.distanceK(root, target, k);

    // Print the elements at distance k from the target node
    cout << "Nodes at distance " << k << " from target node " << target->val << ": ";
    for (int val : result) {
        cout << val << " ";
    }
    cout << endl;

    return 0;
}

Output: Nodes at distance 2 from target node 5: 7 4 1

Time Complexity: O(N). Building the parent hashmap visits every node once, which is O(N), and the BFS from the target visits each node at most once, which is also O(N) in the worst case. Hashmap lookups and insertions take O(1) time on average, so the total is O(N + N) = O(N). The same analysis applies to the Java, Python, and JavaScript versions below, with one caveat: the JavaScript version dequeues with Array.prototype.shift(), which can take time proportional to the queue length, so on large trees a head index into the array keeps the BFS linear.

Space Complexity: O(N). The parent hashmap, the BFS queue, and the visited hashmap can each hold up to N entries, giving O(3N), which simplifies to O(N).

Java Code

import java.util.*;

// TreeNode class
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int x) {
        val = x;
        left = null;
        right = null;
    }
}

// Class to find nodes at a distance K from a target node
class Solution {

    // Helper function to mark parents of nodes in the tree
    void markParents(TreeNode root,
            HashMap<TreeNode, TreeNode>parentTrack, TreeNode target) {
        // Level Order Traversal by taking a queue
        Queue<TreeNode> queue = new LinkedList<>();
        queue.add(root);

        // Iterate over all nodes
        while (!queue.isEmpty()) {
            TreeNode current = queue.poll();

            // Assign parents to left child and right child if they exist
            if (current.left != null) {
                parentTrack.put(current.left, current);
                queue.add(current.left);
            }

            if (current.right != null) {
                parentTrack.put(current.right, current);
                queue.add(current.right);
            }
        }
    }

    // Function to find nodes at a
    // distance K from the target node
    List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
        // Map to mark the parents of all nodes
        HashMap<TreeNode, TreeNode> parentTrack = new HashMap<>();
        // Mark parents of all nodes
        markParents(root, parentTrack, target);
        // Keeps track of visited nodes
        HashMap<TreeNode, Boolean> visited = new HashMap<>();
        // Queue to perform level-order traversal
        Queue<TreeNode> queue = new LinkedList<>();
        // Start traversal from the target node
        queue.add(target);
        // Mark the target as visited; otherwise its children
        // would push it back into the queue one level later
        visited.put(target, true);
        // Tracks the current level
        // while traversing the tree
        int currLevel = 0;

        // Continue traversal until the queue is empty
        while (!queue.isEmpty()) {
            // Get the number of nodes
            // at the current level
            int size = queue.size();
            if (currLevel++ == k) {
                // Break if the current level
                // matches the required distance (k)
                break;
            }

            // Traverse the current level of the tree
            for (int i = 0; i < size; i++) {
                // Get the front node in the queue
                TreeNode current = queue.poll();

                // Add unvisited left child to the queue
                if (current.left != null && !visited.containsKey(current.left)) {
                    queue.add(current.left);
                    // Mark left child as visited
                    visited.put(current.left, true);
                }

                // Add unvisited right child to the queue
                if (current.right != null && !visited.containsKey(current.right)) {
                    queue.add(current.right);
                    // Mark right child as visited
                    visited.put(current.right, true);
                }

                // Add unvisited parent node to the queue
                if (parentTrack.containsKey(current) &&
                        !visited.containsKey(parentTrack.get(current))) {
                    queue.add(parentTrack.get(current));
                    // Mark parent node as visited
                    visited.put(parentTrack.get(current), true);
                }
            }
        }

        // Stores nodes at distance k from the target
        List<Integer> result = new ArrayList<>();
        while (!queue.isEmpty()) {
            // Extract nodes at distance k from the queue
            TreeNode current = queue.poll();
            // Store node values in the result list
            result.add(current.val);
        }

        // Return nodes at distance K from the target
        return result;
    }

    // Main method
    public static void main(String[] args) {
        // Create a sample tree for testing
        TreeNode root = new TreeNode(3);
        root.left = new TreeNode(5);
        root.right = new TreeNode(1);
        root.left.left = new TreeNode(6);
        root.left.right = new TreeNode(2);
        root.right.left = new TreeNode(0);
        root.right.right = new TreeNode(8);
        root.left.right.left = new TreeNode(7);
        root.left.right.right = new TreeNode(4);

        Solution sol = new Solution();
        TreeNode target = root.left;
        int k = 2;
        // Find nodes at distance 2 from the target node
        List<Integer> result = sol.distanceK(root, target, k);

        // Print the elements at distance k from the target node
        System.out.print("Nodes at distance " + k + " from target node " + target.val + ": ");
        for (int val : result) {
            System.out.print(val + " ");
        }
        System.out.println();
    }
}

Output: Nodes at distance 2 from target node 5: 7 4 1

Python Code

from collections import deque

# TreeNode class definition
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

# Solution class to find nodes
# at a distance K from a target node
class Solution:

    # Helper function to mark
    # parents of nodes in the tree
    def markParents(self, root, parent_track, target):

        # Level Order Traversal by using a deque
        queue = deque()
        queue.append(root)

        # Iterate over all nodes
        while queue:
            current = queue.popleft()

            # Assign parents to left child
            # and right child if they exist
            if current.left:
                parent_track[current.left] = current
                queue.append(current.left)
            
            if current.right:
                parent_track[current.right] = current
                queue.append(current.right)

    # Function to find nodes at a
    # distance K from the target node
    def distanceK(self, root, target, k):

        # Map to mark the parents of all nodes
        parent_track = {}
        # Mark parents of all nodes
        self.markParents(root, parent_track, target)
        # Keeps track of visited nodes
        visited = {}
        # Queue to perform level-order traversal
        queue = deque()
        # Start traversal from the target node
        queue.append(target)
        # Mark the target as visited; otherwise its children
        # would push it back into the queue one level later
        visited[target] = True
        # Tracks the current level
        # while traversing the tree
        curr_level = 0

        # Continue traversal until
        # the queue is empty
        while queue:
            # Get the number of nodes
            # at the current level
            size = len(queue)
            if curr_level == k:
                # Break if the current level
                # matches the required distance (k)
                break

            # Traverse the current level of the tree
            for _ in range(size):
                # Get the front node in the queue
                current = queue.popleft()

                # Add unvisited left child to the queue
                if current.left and current.left not in visited:
                    queue.append(current.left)
                    # Mark left child as visited
                    visited[current.left] = True

                # Add unvisited right child to the queue
                if current.right and current.right not in visited:
                    queue.append(current.right)
                    # Mark right child as visited
                    visited[current.right] = True

                # Add unvisited parent node to the queue
                if current in parent_track and parent_track[current] not in visited:
                    queue.append(parent_track[current])
                    # Mark parent node as visited
                    visited[parent_track[current]] = True

            curr_level += 1

        # Stores nodes at distance
        # k from the target
        result = []
        for node in queue:
            # Store node values
            # in the result list
            result.append(node.val)

        # Return nodes at distance
        # K from the target
        return result

# Main function
if __name__ == "__main__":
    # Create a sample tree for testing
    root = TreeNode(3)
    root.left = TreeNode(5)
    root.right = TreeNode(1)
    root.left.left = TreeNode(6)
    root.left.right = TreeNode(2)
    root.right.left = TreeNode(0)
    root.right.right = TreeNode(8)
    root.left.right.left = TreeNode(7)
    root.left.right.right = TreeNode(4)

    sol = Solution()
    target = root.left
    k = 2
    # Find nodes at distance 2
    # from the target node
    result = sol.distanceK(root, target, k)

    # Print the elements at distance
    # k from the target node
    print(f"Nodes at distance {k} from target node {target.val}: ", end="")
    for val in result:
        print(val, end=" ")
    print()

Output: Nodes at distance 2 from target node 5: 7 4 1

JavaScript Code

// TreeNode structure
class TreeNode {
    constructor(val) {
        this.val = val;
        this.left = null;
        this.right = null;
    }
}

// Class to find nodes at a
// distance K from a target node
class Solution {
    
    // Helper function to mark
    // parents of nodes in the tree
    markParents(root, parent_track, target) {
        
        // Level Order Traversal
        // by taking a queue
        let queue = [];
        queue.push(root);
        
        // Iterate over all nodes
        while (queue.length > 0) {
            let current = queue.shift();
            
            // Assign parents to left child
            // and right child if they exist
            if (current.left) {
                parent_track.set(current.left, current);
                queue.push(current.left);
            }
            
            if (current.right) {
                parent_track.set(current.right, current);
                queue.push(current.right);
            }
        }
    }

    // Function to find nodes at a
    // distance K from the target node
    distanceK(root, target, k) {
        
        // Map to mark the parents of all nodes
        let parent_track = new Map();
        // Mark parents of all nodes
        this.markParents(root, parent_track, target); 
        // Keeps track of visited nodes
        let visited = new Map(); 
        // Queue to perform level-order traversal
        let queue = []; 
        // Start traversal from the target node
        queue.push(target);
        // Mark the target as visited; otherwise its children
        // would push it back into the queue one level later
        visited.set(target, true);
        // Tracks the current level
        // while traversing the tree
        let curr_level = 0;

        // Continue traversal until
        // the queue is empty
        while (queue.length > 0) {
            // Get the number of nodes
            // at the current level
            let size = queue.length;
            if (curr_level++ === k) {
                // Break if the current level
                // matches the required distance (k)
                break;
            }

            // Traverse the current level of the tree
            for (let i = 0; i < size; i++) {
                // Get the front node in the queue
                let current = queue.shift();

                // Add unvisited left child to the queue
                if (current.left && !visited.get(current.left)) {
                    queue.push(current.left);
                    // Mark left child as visited
                    visited.set(current.left, true);
                }

                // Add unvisited right child to the queue
                if (current.right && !visited.get(current.right)) {
                    queue.push(current.right);
                    // Mark right child as visited
                    visited.set(current.right, true);
                }

                // Add unvisited parent node to the queue
                if (parent_track.get(current) &&
                        !visited.get(parent_track.get(current))) {
                    queue.push(parent_track.get(current));
                    // Mark parent node as visited
                    visited.set(parent_track.get(current), true);
                }
            }
        }

        // Stores nodes at distance k from the target
        let result = [];
        while (queue.length > 0) {
            // Extract nodes at distance k from the queue
            let current = queue.shift();
            // Store node values in the result array
            result.push(current.val);
        }

        // Return nodes at distance
        // K from the target
        return result;
    }
}

// Main function
function main() {
    // Create a sample tree for testing
    let root = new TreeNode(3);
    root.left = new TreeNode(5);
    root.right = new TreeNode(1);
    root.left.left = new TreeNode(6);
    root.left.right = new TreeNode(2);
    root.right.left = new TreeNode(0);
    root.right.right = new TreeNode(8);
    root.left.right.left = new TreeNode(7);
    root.left.right.right = new TreeNode(4);

    let sol = new Solution();
    let target = root.left;
    let k = 2;
    // Find nodes at distance 2 from the target node
    let result = sol.distanceK(root, target, k);

    // Print the elements at distance
    // k from the target node
    console.log(`Nodes at distance ${k} from target node ${target.val}: ${result.join(' ')}`);
}

main();

Output: Nodes at distance 2 from target node 5: 7 4 1

Special thanks to Gauri Tomar for contributing to this article on takeUforward.