3478. 选出和最大的 K 个元素
题目描述
给你两个整数数组,nums1
和 nums2
,长度均为 n
,以及一个正整数 k
。
对从 0
到 n - 1
每个下标 i
,执行下述操作:
- 找出所有满足
nums1[j]
小于nums1[i]
的下标j
。 - 从这些下标对应的
nums2[j]
中选出 至多k
个,并 最大化 这些值的总和作为结果。
返回一个长度为 n
的数组 answer
,其中 answer[i]
表示对应下标 i
的结果。
示例 1:
输入:nums1 = [4,2,1,5,3], nums2 = [10,20,30,40,50], k = 2
输出:[80,30,0,80,50]
解释:
- 对于
i = 0
:满足nums1[j] < nums1[0]
的下标为[1, 2, 4]
,选出其中值最大的两个,结果为50 + 30 = 80
。 - 对于
i = 1
:满足nums1[j] < nums1[1]
的下标为[2]
,只能选择这个值,结果为30
。 - 对于
i = 2
:不存在满足nums1[j] < nums1[2]
的下标,结果为0
。 - 对于
i = 3
:满足nums1[j] < nums1[3]
的下标为[0, 1, 2, 4]
,选出其中值最大的两个,结果为50 + 30 = 80
。 - 对于
i = 4
:满足nums1[j] < nums1[4]
的下标为[1, 2]
,选出其中值最大的两个,结果为30 + 20 = 50
。
示例 2:
输入:nums1 = [2,2,2,2], nums2 = [3,1,2,3], k = 1
输出:[0,0,0,0]
解释:由于 nums1
中的所有元素相等,不存在满足条件 nums1[j] < nums1[i]
,所有位置的结果都是 0 。
提示:
n == nums1.length == nums2.length
1 <= n <= 105
1 <= nums1[i], nums2[i] <= 106
1 <= k <= n
解法
方法一:排序 + 优先队列(小根堆)
我们可以将数组 $\textit{nums1}$ 转换成一个数组 $\textit{arr}$,其中每个元素是一个二元组 $(x, i)$,表示 $\textit{nums1}[i]$ 的值为 $x$。然后对数组 $\textit{arr}$ 按照 $x$ 进行升序排序。
我们使用一个小根堆 $\textit{pq}$ 来维护数组 $\textit{nums2}$ 中的元素,初始时 $\textit{pq}$ 为空。用一个变量 $\textit{s}$ 来记录 $\textit{pq}$ 中的元素之和。另外,我们用一个指针 $j$ 来维护当前需要添加到 $\textit{pq}$ 中的元素在数组 $\textit{arr}$ 中的位置。
我们遍历数组 $\textit{arr}$,对于第 $h$ 个元素 $(x, i)$,我们将所有满足 $j < h$ 并且 $\textit{arr}[j][0] < x$ 的元素 $\textit{nums2}[\textit{arr}[j][1]]$ 添加到 $\textit{pq}$ 中,并将这些元素的和加到 $\textit{s}$ 中。如果 $\textit{pq}$ 的大小超过了 $k$,我们将 $\textit{pq}$ 中的最小元素弹出,并将其从 $\textit{s}$ 中减去。然后,我们更新 $\textit{ans}[i]$ 的值为 $\textit{s}$。
遍历结束后,返回答案数组 $\textit{ans}$。
时间复杂度 $O(n \times \log n)$,空间复杂度 $O(n)$。其中 $n$ 为数组长度。
Python3
class Solution:
def findMaxSum(self, nums1: List[int], nums2: List[int], k: int) -> List[int]:
arr = [(x, i) for i, x in enumerate(nums1)]
arr.sort()
pq = []
s = j = 0
n = len(arr)
ans = [0] * n
for h, (x, i) in enumerate(arr):
while j < h and arr[j][0] < x:
y = nums2[arr[j][1]]
heappush(pq, y)
s += y
if len(pq) > k:
s -= heappop(pq)
j += 1
ans[i] = s
return ans
Java
class Solution {
public long[] findMaxSum(int[] nums1, int[] nums2, int k) {
int n = nums1.length;
int[][] arr = new int[n][0];
for (int i = 0; i < n; ++i) {
arr[i] = new int[] {nums1[i], i};
}
Arrays.sort(arr, (a, b) -> a[0] - b[0]);
PriorityQueue<Integer> pq = new PriorityQueue<>();
long s = 0;
long[] ans = new long[n];
int j = 0;
for (int h = 0; h < n; ++h) {
int x = arr[h][0], i = arr[h][1];
while (j < h && arr[j][0] < x) {
int y = nums2[arr[j][1]];
pq.offer(y);
s += y;
if (pq.size() > k) {
s -= pq.poll();
}
++j;
}
ans[i] = s;
}
return ans;
}
}
C++
class Solution {
public:
vector<long long> findMaxSum(vector<int>& nums1, vector<int>& nums2, int k) {
int n = nums1.size();
vector<pair<int, int>> arr(n);
for (int i = 0; i < n; ++i) {
arr[i] = {nums1[i], i};
}
ranges::sort(arr);
priority_queue<int, vector<int>, greater<int>> pq;
long long s = 0;
int j = 0;
vector<long long> ans(n);
for (int h = 0; h < n; ++h) {
auto [x, i] = arr[h];
while (j < h && arr[j].first < x) {
int y = nums2[arr[j].second];
pq.push(y);
s += y;
if (pq.size() > k) {
s -= pq.top();
pq.pop();
}
++j;
}
ans[i] = s;
}
return ans;
}
};
Go
func findMaxSum(nums1 []int, nums2 []int, k int) []int64 {
n := len(nums1)
arr := make([][2]int, n)
for i, x := range nums1 {
arr[i] = [2]int{x, i}
}
ans := make([]int64, n)
sort.Slice(arr, func(i, j int) bool { return arr[i][0] < arr[j][0] })
pq := hp{}
var s int64
j := 0
for h, e := range arr {
x, i := e[0], e[1]
for j < h && arr[j][0] < x {
y := nums2[arr[j][1]]
heap.Push(&pq, y)
s += int64(y)
if pq.Len() > k {
s -= int64(heap.Pop(&pq).(int))
}
j++
}
ans[i] = s
}
return ans
}
type hp struct{ sort.IntSlice }
func (h hp) Less(i, j int) bool { return h.IntSlice[i] < h.IntSlice[j] }
func (h *hp) Push(v any) { h.IntSlice = append(h.IntSlice, v.(int)) }
func (h *hp) Pop() any {
a := h.IntSlice
v := a[len(a)-1]
h.IntSlice = a[:len(a)-1]
return v
}
TypeScript
function findMaxSum(nums1: number[], nums2: number[], k: number): number[] {
const n = nums1.length;
const arr = nums1.map((x, i) => [x, i]).sort((a, b) => a[0] - b[0]);
const pq = new MinPriorityQueue();
let [s, j] = [0, 0];
const ans: number[] = Array(k).fill(0);
for (let h = 0; h < n; ++h) {
const [x, i] = arr[h];
while (j < h && arr[j][0] < x) {
const y = nums2[arr[j++][1]];
pq.enqueue(y);
s += y;
if (pq.size() > k) {
s -= pq.dequeue();
}
}
ans[i] = s;
}
return ans;
}