Powered by:NEFU AB-IN
Link
文章目录
- 3209. 子数组按位与值为 K 的数目
- 题意
- 思路
- 代码
3209. 子数组按位与值为 K 的数目
题意
给你一个整数数组 nums 和一个整数 k ,请你返回 nums 中有多少个
子数组
满足:子数组中所有元素按位 AND 的结果为 k 。
思路
-
st+二分
https://leetcode.cn/problems/number-of-subarrays-with-and-value-of-k/solutions/2833382/stbiao-er-fen-by-time-v5-4qtm/
细节在于:- 由于一直and操作是非递减的,所以取个负号,这样就是非递增了,能配合bisect_left函数
- bisect_left 函数,可以直接这么用
l = bisect_left(range(i, n), -k, key=lambda r: -st.query(i, r))
,直接在 range(i, n) 上进行二分查找,通过 key 参数动态计算按位与结果,会快很多(不过这是 3.10 引进的特性),当然我也会自己实现- range(i, n) 生成从 i 到 n-1 的索引序列。
- -k 是要插入的元素,它是 k 的相反数。
- key=lambda r: -st.query(i, r) 是自定义的比较函数,用于对 range(i, n) 中的每个元素 r 进行比较。它将 st.query(i, r) 的结果取反,实际上是在对 -st.query(i, r) 进行比较。
相当于
and_results = [-st.query(i, r) for r in range(i, n)] l = bisect_left(and_results, -k) r = bisect_right(and_results, -k) ans += r - l
-
滚动数组 + 哈希 + dp
- 我们使用一个二维数组 dp 来记录以每个位置结尾的所有可能的按位 AND 结果及其出现次数。定义
dp[i][j]
为以 nums[i] 结尾的子数组中,按位 AND 结果为 j 的子数组数量。 - 初始化,dp[0][nums[0]] = 1:表示第一个元素单独形成一个子数组,且按位 AND 结果为 nums[0]。
- 对于每个元素 nums[i],我们需要遍历之前所有的状态来更新当前状态: d p [ i ] [ n u m s [ i ] & k e y ] + = d p [ i − 1 ] [ k e y ] dp[i][nums[i]\&key]+=dp[i−1][key] dp[i][nums[i]&key]+=dp[i−1][key]
- 由于每一层的状态仅依赖于前一层的状态,因此我们可以使用滚动数组来优化空间复杂度。
- 我们使用一个二维数组 dp 来记录以每个位置结尾的所有可能的按位 AND 结果及其出现次数。定义
代码
'''
Author: NEFU AB-IN
Date: 2024-07-06 21:52:11
FilePath: \LeetCode\CP134_2\d\d.py
LastEditTime: 2024-07-08 16:31:44
'''
# 3.8.19 import
from collections import Counter, defaultdict, deque
from datetime import datetime, timedelta
from functools import lru_cache
from heapq import heapify, heappop, heappush, nlargest, nsmallest
from itertools import combinations, compress, permutations, starmap, tee
from math import ceil, fabs, floor, gcd, log, sqrt
from string import ascii_lowercase, ascii_uppercase
from sys import exit, setrecursionlimit, stdin
from typing import Any, Dict, List, Tuple, TypeVar, Union# Constants
TYPE = TypeVar('TYPE')
N = int(2e5 + 10) # If using AR, modify accordingly
M = int(20) # If using AR, modify accordingly
INF = int(2e9)
OFFSET = int(100)# Set recursion limit
setrecursionlimit(INF)class Arr:array = staticmethod(lambda x=0, size=N: [x] * size)array2d = staticmethod(lambda x=0, rows=N, cols=M: [Arr.array(x, cols) for _ in range(rows)])graph = staticmethod(lambda size=N: [[] for _ in range(size)])@staticmethoddef to_1_indexed(data: Union[List, str, List[List]]):"""Adds a zero prefix to the data and returns the modified data and its length."""if isinstance(data, list):if all(isinstance(item, list) for item in data): # Check if it's a 2D arraynew_data = [[0] * (len(data[0]) + 1)] + [[0] + row for row in data]return new_data, len(new_data) - 1, len(new_data[0]) - 1else:new_data = [0] + datareturn new_data, len(new_data) - 1elif isinstance(data, str):new_data = '0' + datareturn new_data, len(new_data) - 1else:raise TypeError("Input must be a list, a 2D list, or a string")class Str:letter_to_num = staticmethod(lambda x: ord(x.upper()) - 65) # A -> 0num_to_letter = staticmethod(lambda x: ascii_uppercase[x]) # 0 -> Aremoveprefix = staticmethod(lambda s, prefix: s[len(prefix):] if s.startswith(prefix) else s)removesuffix = staticmethod(lambda s, suffix: s[:-len(suffix)] if s.endswith(suffix) else s)class Math:max = staticmethod(lambda a, b: a if a > b else b)min = staticmethod(lambda a, b: a if a < b else b)class IO:input = staticmethod(lambda: stdin.readline().rstrip("\r\n"))read = staticmethod(lambda: map(int, IO.input().split()))read_list = staticmethod(lambda: list(IO.read()))class Std:@staticmethoddef find(container: Union[List[TYPE], str], value: TYPE):"""Returns the index of value in container or -1 if value is not found."""if isinstance(container, list):try:return container.index(value)except ValueError:return -1elif isinstance(container, str):return container.find(value)@staticmethoddef pairwise(iterable):"""Return successive overlapping pairs taken from the input iterable."""a, b = tee(iterable)next(b, None)return zip(a, b)@staticmethoddef bisect_left(a, x, key=lambda y: y):"""The insertion point is the first position where the element is not less than x."""left, right = 0, len(a)while left < right:mid = (left + right) // 2if key(a[mid]) < x:left = mid + 1else:right = midreturn left@staticmethoddef bisect_right(a, x, key=lambda y: y):"""The insertion point is the first position where the element is greater than x."""left, right = 0, len(a)while left < right:mid = (left + right) // 2if key(a[mid]) <= x:left = mid + 1else:right = midreturn leftclass SparseTable:def __init__(self, data: list, func=lambda x, y: x | y):"""Initialize the Sparse Table with the given data and function."""self.func = funcself.st = [list(data)]i, n = 1, len(self.st[0])while 2 * i <= n:pre = self.st[-1]self.st.append([func(pre[j], pre[j + i]) for j in range(n - 2 * i + 1)])i <<= 1def query(self, begin: int, end: int):"""Query the combined result over the interval [begin, end]."""lg = (end - begin + 1).bit_length() - 1return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])# ————————————————————— Division line ——————————————————————
class Solution:def countSubarrays(self, nums: List[int], k: int) -> int:st = Std.SparseTable(nums, func=lambda x, y: x & y)ans = 0n = len(nums)for i in range(n):l = Std.bisect_left(range(i, n), -k, key=lambda r: -st.query(i, r))r = Std.bisect_right(range(i, n), -k, key=lambda r: -st.query(i, r))ans += r - lreturn ansclass Solution:def countSubarrays(self, nums: List[int], k: int) -> int:nums, n = Arr.to_1_indexed(nums)dp = Counter()res = 0for i in range(1, n + 1):cur_dp = Counter()cur_dp[nums[i]] += 1for num, val in dp.items():# 转移方程cur_dp[nums[i] & num] += valres += cur_dp[k]dp = cur_dpreturn res