Problem
315. Count of Smaller Numbers After Self
Solution
设数组长度为n
归并排序, O(nlogn) time, O(n) space
首先按照解决问题的一般套路,观察问题可不可以reduce(参考算法笔记-概述),发现reduce后不易解(实际上是可以的,不过这里记录解题时的思路,一开始想的时候并没有想出来如何用分治求解);观察解空间和解的形式也无用。
而对于算法而言,有一个很重要的特性,算法一般更容易处理结构化的数据。对于数组而言,结构化通常意味着有序。如果考虑排序,那么对于结果中的第i个元素,其值为:
* 排序过程中,原数组中第i个元素右边有多少比其小的元素出现在左边
这个性质意味着这个问题就可以用排序算法求解,由于只希望右边比自己小的元素出现在左边,那么意味着排序过程是稳定的,而稳定的排序算法中效率最高的是归并排序(归并排序基于分治,即该问题可以通过reduce来求解),O(nlogn) time.
事实上我们并不关心最后排序的结果,只关心排序过程中的逆序情况,所以可以直接对元素下标排序。
class Solution
{
public:
vector<int> countSmaller(vector<int> &nums)
{
int n = nums.size();
vector<int> indexs(n, 0);
iota(indexs.begin(), indexs.end(), 0);
vector<int> results(n, 0);
mergeSort(nums, 0, n - 1, indexs, results);
return results;
}
private:
void mergeSort(vector<int> &nums, int start, int end, vector<int> &indexs, vector<int> &results)
{
if (end - start >= 1)
{
int mid = (start + end) / 2;
mergeSort(nums, start, mid, indexs, results);
mergeSort(nums, mid + 1, end, indexs, results);
int i = start, j = mid + 1, pos = start;
vector<int> tmp;
int cnt = 0;
while (i <= mid && j <= end)
{
if (nums[indexs[j]] < nums[indexs[i]])
{
tmp.emplace_back(indexs[j++]);
cnt++;
}
else
{
results[indexs[i]] += cnt;
tmp.emplace_back(indexs[i++]);
}
}
while (i <= mid)
{
results[indexs[i]] += cnt;
tmp.emplace_back(indexs[i++]);
}
while (j <= mid)
tmp.emplace_back(indexs[j++]);
for (auto item : tmp)
indexs[pos++] = item;
}
}
};
二叉搜索树 O(nlogn) time, O(n) sapce
题目所求的是数组元素右边比其小的元素个数,相当于要求部分有序-能快速的求出右边比自己小的元素个数,而二叉搜索树就有类似的性质:根节点大于左子树的所有结点
如果数组从后向前建立搜索树,将节点插入后,统计其左子树节点个数即可。为了优化算法,可以记录每个节点左子树节点个数,插入时更新。这样就得到了一个O(nlogn) time的解法
class Solution
{
public:
vector<int> countSmaller(vector<int> &nums)
{
int n = nums.size();
vector<int> results(n, 0);
if (n <= 1)
return results;
Node *root = new Node(nums[n - 1]);
for (int i = n - 2; i >= 0; --i)
results[i] = insert(root, nums[i]);
return results;
}
private:
struct Node
{
int val;
int count;
int leftSize;
Node *left;
Node *right;
Node(int _val)
{
val = _val;
count = 1;
leftSize = 0;
left = nullptr;
right = nullptr;
}
};
int insert(Node *node, int val)
{
if (node->val > val)
{
node->leftSize++;
if (node->left == nullptr)
{
node->left = new Node(val);
return 0;
}
else
return insert(node->left, val);
}
else if (node->val < val)
{
if (node->right == nullptr)
{
node->right = new Node(val);
return node->count + node->leftSize;
}
else
return node->count + node->leftSize + insert(node->right, val);
}
else
{
node->count++;
return node->leftSize;
}
}
};