Search in a sorted 2D matrix

Problem Statement: Given an m*n 2D matrix and an integer, write a program to find if the given integer exists in the matrix.

The given matrix has the following properties:

  • Integers in each row are sorted from left to right.
  • The first integer of each row is greater than the last integer of the previous row.

Example 1:

Input: matrix = 
[[1,3,5,7],
 [10,11,16,20],
 [23,30,34,60]], 

target = 3

Output: true

Explanation: The given integer (target) exists in the given 2D matrix, so the function returns true.

Example 2:

Input: matrix = 
[[1,3,5,7],
 [10,11,16,20],
 [23,30,34,60]], 

target = 13

Output: false

Explanation: The given integer (target) does not exist in the given 2D matrix, so the function returns false.

Solution

Disclaimer: Don’t jump directly to the solution; try it out yourself first.

Solution 1: Naive approach

Approach: Traverse every element of the matrix and return true as soon as an element equals the target integer. If the traversal finishes without a match, return false, since no element in the matrix equals the target.

Time complexity: O(m*n)

Space complexity: O(1)
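
The naive approach above can be sketched in Python roughly as follows. This is a minimal illustration, not code from the original article; the function name search_matrix_naive is an assumption chosen for this example.

```python
from typing import List


def search_matrix_naive(matrix: List[List[int]], target: int) -> bool:
    """Check every element one by one: O(m*n) time, O(1) extra space."""
    for row in matrix:
        for value in row:
            if value == target:
                return True
    # Traversal finished without finding the target.
    return False


# The matrix from Example 1 and Example 2.
matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
print(search_matrix_naive(matrix, 3))
print(search_matrix_naive(matrix, 13))
```

Note that this version never uses the sorted property of the matrix, which is why it has to look at every element in the worst case.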

Solution 2: [Efficient] – Binary search

Intuition: Each row is sorted from left to right, and the first integer of each row is greater than the last integer of the previous row. Read row by row, the elements of the matrix are therefore in monotonically increasing order, so we can apply binary search. Consider the 2D matrix as a 1D array with indices from 0 to (m*n)-1 and apply binary search to it. The image below is a visual representation of the indices of a 3*4 matrix.
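
To make the index mapping concrete, here is a small Python sketch that lists each 1D index of the example matrix together with the (row, column) pair it corresponds to. The helper name to_row_col is an assumption for this illustration; the key point is that the divisor is the number of columns.

```python
from typing import Tuple


def to_row_col(index: int, cols: int) -> Tuple[int, int]:
    """Map a 1D index to (row, column) for a matrix with `cols` columns."""
    # divmod(index, cols) returns (index // cols, index % cols).
    return divmod(index, cols)


# The 3*4 matrix from the examples.
matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
cols = len(matrix[0])

for index in range(len(matrix) * cols):
    row, col = to_row_col(index, cols)
    print(index, (row, col), matrix[row][col])
```

Reading the printed values from index 0 to 11 gives the matrix elements in increasing order, which is what allows binary search over the indices.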

Approach: 

i) Initialize the low index to the first index of the considered 1D array (i.e., 0) and the high index to its last index (i.e., (m*n)-1).

int low = 0;
int high = (m*n)-1;

ii) Now apply binary search. Run a while loop with the condition low<=high and compute the middle index as (low+high)/2 (the full code below writes this as low+(high-low)/2, which gives the same value while avoiding integer overflow). Since each of the m rows holds n elements, the element at the middle index is matrix[middle/n][middle%n].

while(low<=high) {
    int middle = (low+high)/2;
    // steps iii) to v) go here
}

iii) If the element at the middle index is greater than the target, the target cannot exist at or beyond the middle index. So shrink the search space by updating the high index to middle-1.

if(matrix[middle/n][middle%n]>target)
    high = middle-1;

iv) If the element at the middle index is less than the target, shrink the search space by updating the low index to middle+1.

if(matrix[middle/n][middle%n]<target)
    low = middle+1;

v) If the element at the middle index is equal to the target integer, return true.

if(matrix[middle/n][middle%n]==target)
    return true;

vi) Once the loop terminates we can directly return false as we did not find the target element.

Code:

C++ Code

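#include <vector>
using namespace std;

class Solution {
public:
    bool searchMatrix(vector<vector<int>>& matrix, int target) {
        // An empty matrix (or one with empty rows) cannot contain the target.
        if(matrix.empty() || matrix[0].empty()) return false;
        int rows = matrix.size();
        int cols = matrix[0].size();
        int lo = 0;
        int hi = (rows * cols) - 1;

        while(lo <= hi) {
            // Written as lo + (hi - lo) / 2 to avoid overflow in lo + hi.
            int mid = lo + (hi - lo) / 2;
            // Map the 1D index back to a (row, column) pair.
            int value = matrix[mid / cols][mid % cols];
            if(value == target) {
                return true;
            }
            if(value < target) {
                lo = mid + 1;
            }
            else {
                hi = mid - 1;
            }
        }
        return false;
    }
};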

Java Code

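class Solution {
    public boolean searchMatrix(int[][] matrix, int target) {
        // An empty matrix (or one with empty rows) cannot contain the target.
        if(matrix.length == 0 || matrix[0].length == 0) return false;
        int rows = matrix.length;
        int cols = matrix[0].length;
        int lo = 0;
        int hi = (rows * cols) - 1;

        while(lo <= hi) {
            // Written as lo + (hi - lo) / 2 to avoid overflow in lo + hi.
            int mid = lo + (hi - lo) / 2;
            // Map the 1D index back to a (row, column) pair.
            int value = matrix[mid / cols][mid % cols];
            if(value == target) {
                return true;
            }
            if(value < target) {
                lo = mid + 1;
            }
            else {
                hi = mid - 1;
            }
        }
        return false;
    }
}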

Python Code

from typing import List

class Solution:
    def searchMatrix(self, matrix: List[List[int]], target: int) -> bool:
        # An empty matrix (or one with empty rows) cannot contain the target.
        if not matrix or not matrix[0]:
            return False
        cols = len(matrix[0])
        lo = 0
        hi = (len(matrix) * cols) - 1

        while lo <= hi:
            mid = lo + (hi - lo) // 2
            # Map the 1D index back to a (row, column) pair.
            value = matrix[mid // cols][mid % cols]
            if value == target:
                return True
            if value < target:
                lo = mid + 1
            else:
                hi = mid - 1
        return False

Time complexity: O(log(m*n)), since every iteration halves a search space of m*n elements.

Space complexity: O(1)
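
As a rough sanity check of the logarithmic bound, the Python sketch below counts loop iterations of the same binary search on a larger matrix. The function name search_with_count and the 100*100 test matrix are assumptions made for this illustration only. For a search space of N = m*n elements, binary search needs at most floor(log2(N)) + 1 iterations.

```python
import math
from typing import List, Tuple


def search_with_count(matrix: List[List[int]], target: int) -> Tuple[bool, int]:
    """Binary search over the flattened matrix; also returns the iteration count."""
    if not matrix or not matrix[0]:
        return False, 0
    cols = len(matrix[0])
    lo, hi = 0, len(matrix) * cols - 1
    iterations = 0
    while lo <= hi:
        iterations += 1
        mid = lo + (hi - lo) // 2
        value = matrix[mid // cols][mid % cols]
        if value == target:
            return True, iterations
        if value < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return False, iterations


# Hypothetical test matrix: 100 rows x 100 columns holding 0..9999 in order.
rows, cols = 100, 100
big = [[r * cols + c for c in range(cols)] for r in range(rows)]

# Upper bound on iterations for N = rows * cols elements.
bound = math.floor(math.log2(rows * cols)) + 1
# Try every present value plus one value below and one above the range.
worst = max(search_with_count(big, t)[1] for t in range(-1, rows * cols + 1))
print("worst-case iterations:", worst, "bound:", bound)
```

The naive approach would need up to 10,000 comparisons on this matrix, while the binary search stays within the bound printed above.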

Special thanks to Rishi Visvas and Sudip Ghosh for contributing to this article on takeUforward.