421. Maximum XOR of Two Numbers in an Array

/**
 * 421. Maximum XOR of Two Numbers in an Array
 *
 * Given an integer array nums, return the maximum result of nums[i] XOR nums[j], where 0 ≤ i ≤ j < n.
 *
 * Follow up: Could you do this in O(n) runtime?
 *
 * Example 1:
 * Input: nums = [3,10,5,25,2,8]
 * Output: 28
 * Explanation: The maximum result is 5 XOR 25 = 28.
 */

See the solution on GitHub here: https://github.com/zcoderz/leetcode/blob/main/src/main/java/trie/MaxXorOfTwoArrays.java

A naïve approach is to XOR every element with every other element and return the largest result. This takes O(N^2) time.
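For reference, the naïve approach looks roughly like this in Java. It is a minimal sketch; the class and method names (`MaxXorBruteForce`, `maxXor`) are my own and not taken from the linked repository. Pairs with i == j are included to match the 0 ≤ i ≤ j < n constraint, although they only ever contribute 0.

```java
// Naive O(N^2) baseline (hypothetical names): try every pair (i, j) with i <= j
// and keep the largest XOR seen.
class MaxXorBruteForce {
    static int maxXor(int[] nums) {
        int best = 0;
        for (int i = 0; i < nums.length; i++) {
            for (int j = i; j < nums.length; j++) {
                best = Math.max(best, nums[i] ^ nums[j]);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] nums = {3, 10, 5, 25, 2, 8};
        System.out.println(maxXor(nums)); // prints 28
    }
}
```

It is slow for large inputs but handy as a baseline for checking a faster solution on small random arrays.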

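However, the problem can be solved in O(N) time with a trie. Each number is processed one bit at a time, so the total work is O(N·L) for L bits per number (at most 31 for non-negative ints), which is linear in N for a fixed word size. This is an unusual use of a trie, which is why I have included this question here.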

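Observations:

  • XOR works bit by bit: a result bit is 1 exactly when the two input bits differ.
  • The binary representations of numbers can share prefixes (for example, 2 = 00010 and 3 = 00011 share the prefix 0001), so the numbers can be stored in a trie keyed on bits.
  • A higher bit outweighs all lower bits combined (2^k > 2^(k-1) + … + 1), so the XOR can be maximized greedily from the most significant bit down by preferring the opposite bit at each level. Walking the trie this way finds the best partner for a number in a single pass over its bits.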
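To make these observations concrete, here is the winning pair from Example 1 written out as 5-bit strings. This is an illustrative sketch; the `XorBitsDemo` class and its `bits` helper are hypothetical, and `bits` assumes `width` is at least the number of significant bits in `num`.

```java
// Illustrative sketch (hypothetical names): print 5, 25 and 5 ^ 25 as
// zero-padded 5-bit strings to show XOR acting bit by bit.
class XorBitsDemo {
    // Left-pads the binary form of num with zeros to exactly `width` characters.
    // Assumes width >= number of significant bits (String.repeat rejects negatives).
    static String bits(int num, int width) {
        String s = Integer.toBinaryString(num);
        return "0".repeat(width - s.length()) + s;
    }

    public static void main(String[] args) {
        int a = 5, b = 25;
        System.out.println(bits(a, 5) + "  (" + a + ")");         // 00101  (5)
        System.out.println(bits(b, 5) + "  (" + b + ")");         // 11001  (25)
        System.out.println(bits(a ^ b, 5) + "  (" + (a ^ b) + ")"); // 11100  (28)
    }
}
```

The result 11100 has a 1 in every position where 5 and 25 differ. Since 25 is the only number in the example with bit 4 set, the maximum must pair 25 with another number, and 5 is the partner that differs from 25 at bits 4, 3 and 2, which is exactly the one a greedy walk from the top bit picks.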

Solution:

  • Convert each number to a binary char array, padded so that all numbers have the same length. Here maxNum is the largest value in nums; OR-ing in a bit just above maxNum's highest bit and then dropping that leading 1 preserves the leading zeros:
    • int len = Integer.toBinaryString(maxNum).length();
    • int bitMask = 1 << len;
    • String binary = Integer.toBinaryString(num | bitMask).substring(1);
  • For each number, walk its bits from most significant to least significant, moving two pointers through the trie:
    • Insert pointer: create the child node for the current bit if it doesn't already exist, then move to it.
    • XOR pointer: if a child for the opposite bit exists, move to it and set this bit of the running XOR to 1.
    • Otherwise, move the XOR pointer to the child for the same bit; this bit of the running XOR stays 0.
    • Before handling each bit, left-shift the running XOR by 1 so earlier bits move up to their final positions.
  • Each number's running XOR is its best XOR with any number inserted so far; the answer is the maximum of these values.
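The padding step can be checked in isolation. This is a small sketch assuming non-negative inputs, which the padding trick relies on; `PaddingDemo` and `padded` are hypothetical names, and `maxNum` is computed here as the largest element of `nums`.

```java
import java.util.Arrays;

// Sketch of the padding trick (hypothetical names): OR in a bit one position
// above maxNum's highest bit, then drop that leading '1', so every number
// becomes a binary string of the same length with its leading zeros kept.
class PaddingDemo {
    static String padded(int num, int bitMask) {
        return Integer.toBinaryString(num | bitMask).substring(1);
    }

    public static void main(String[] args) {
        int[] nums = {3, 10, 5, 25, 2, 8};
        int maxNum = Arrays.stream(nums).max().getAsInt(); // 25
        int len = Integer.toBinaryString(maxNum).length();  // 5
        int bitMask = 1 << len;                             // 32, i.e. 100000 in binary
        for (int num : nums) {
            System.out.println(num + " -> " + padded(num, bitMask));
        }
    }
}
```

For Example 1 this prints 3 -> 00011, 10 -> 01010, 5 -> 00101, 25 -> 11001, 2 -> 00010 and 8 -> 01000. Note how 3 and 2 share the prefix 0001, so in the trie they share the nodes for their first four bits. When maxNum needs 31 bits, `1 << 31` is negative, but `toBinaryString` then returns 32 characters and `substring(1)` still leaves 31.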

Here is the code:

(If the description above isn’t clear, the inline comments in the code below walk through the logic.)

int buildTrie(int num, TrieNode node, int bitMask) {
    //pad with the bit mask, then drop the leading 1, so each number has the same number of bits
    String binary = Integer.toBinaryString(num | bitMask).substring(1);
    TrieNode xorNode = node;
    int maxXor = 0;
    //while inserting num, walk a second path that takes the opposite bit whenever possible;
    //maxXor records the XOR bits produced along that path so far
    for (char ch : binary.toCharArray()) {
        maxXor = maxXor << 1; //left shift at each iteration so each bit ends up at its correct final index
        int bit = ch == '0' ? 0 : 1;
        //create the node for this bit if it doesn't already exist
        TrieNode newNode = node.map.computeIfAbsent(bit, k -> new TrieNode());
        int opposite = 1 - bit;
        //prefer the opposite bit, since it makes this XOR bit 1
        TrieNode nextXorNode = xorNode.map.get(opposite);
        if (nextXorNode == null) {
            //no opposite branch: follow the current bit's branch instead (it always exists)
            nextXorNode = xorNode.map.get(bit);
        } else {
            //the bits differ, so set this bit of maxXor to 1
            maxXor += 1;
        }
        //advance both pointers for the next iteration
        xorNode = nextXorNode;
        node = newNode;
    }
    //best XOR of num with any number inserted so far, including num itself
    return maxXor;
}
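`buildTrie` only returns the best XOR of one number against the numbers inserted before it (and itself), so it needs a driver that feeds every number through a shared root and keeps the maximum. Below is a minimal runnable sketch of that wiring. The `TrieNode` class here is a stand-in that assumes the repository's version exposes a `Map<Integer, TrieNode> map` field, and `MaxXorTrie` / `findMaximumXOR` are my names for the driver, not necessarily what the linked file uses.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical driver around buildTrie. TrieNode is a minimal stand-in,
// assuming the repository's TrieNode has a Map<Integer, TrieNode> map field.
class MaxXorTrie {
    static class TrieNode {
        final Map<Integer, TrieNode> map = new HashMap<>();
    }

    // Same logic as buildTrie above: insert num's bits while walking the
    // opposite-bit path, returning num's best XOR with any inserted number.
    static int buildTrie(int num, TrieNode node, int bitMask) {
        String binary = Integer.toBinaryString(num | bitMask).substring(1);
        TrieNode xorNode = node;
        int maxXor = 0;
        for (char ch : binary.toCharArray()) {
            maxXor <<= 1;
            int bit = ch - '0';
            TrieNode newNode = node.map.computeIfAbsent(bit, k -> new TrieNode());
            TrieNode nextXorNode = xorNode.map.get(1 - bit);
            if (nextXorNode == null) {
                nextXorNode = xorNode.map.get(bit); // no differing bit here: follow the same bit
            } else {
                maxXor += 1;                         // bits differ: this XOR bit is 1
            }
            xorNode = nextXorNode;
            node = newNode;
        }
        return maxXor;
    }

    // Pads to the bit length of the largest value, then keeps the best per-number result.
    static int findMaximumXOR(int[] nums) {
        int maxNum = 0;
        for (int num : nums) maxNum = Math.max(maxNum, num);
        int bitMask = 1 << Integer.toBinaryString(maxNum).length();
        TrieNode root = new TrieNode();
        int best = 0;
        for (int num : nums) {
            best = Math.max(best, buildTrie(num, root, bitMask));
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(findMaximumXOR(new int[]{3, 10, 5, 25, 2, 8})); // prints 28
    }
}
```

Because the insert pointer and the XOR pointer advance in lockstep, `xorNode` never becomes null: it either follows the number being inserted (whose node for the current bit was just created) or a previously inserted path, and every path has the same length thanks to the padding. Each number costs O(L) map operations, for O(N·L) overall.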