背景
了解到二叉索引树这个数据结构,是在leetcode
的 307 题,题目是要求实现一个数据结构,可以返回数组任意区间的和以及更新数组的某个值。
307、Range Sum Query - MutableGiven an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive.The update(i, val) function modifies nums by updating the element at index i to val.Example:Given nums = [1, 3, 5]sumRange(0, 2) -> 9update(1, 2)sumRange(0, 2) -> 8Constraints:- The array is only modifiable by the update function.- You may assume the number of calls to update and sumRange function is distributed evenly.- 0 <= i <= j <= nums.length - 1
常规解法
先介绍下常规的解法,树状数组有用到他们之中的一些思想或者过程。
解法一
最暴力的解法,sumRange
直接for
循环算,update
直接更新数组中的值。
/*** @param {number[]} nums*/var NumArray = function (nums) {this.nums = [...nums];};/*** @param {number} i* @param {number} val* @return {void}*/NumArray.prototype.update = function (i, val) {this.nums[i] = val;};/*** @param {number} i* @param {number} j* @return {number}*/NumArray.prototype.sumRange = function (i, j) {let sum = 0;for (let k = i; k <= j; k++) {sum += this.nums[k];}return sum;};/*** Your NumArray object will be instantiated and called as such:* var obj = new NumArray(nums)* obj.update(i,val)* var param_2 = obj.sumRange(i,j)*/
时间复杂度:update
是O(1)
,sumRange
是O(n)
。
解法二
303 题 做过sumRange
的优化,我们用一个数组保存累计的和,numsAccumulate[i]
存储0
到i - 1
累计的和。
如果我们想求i
累积到j
的和,只需要用numsAccumulate[j + 1]
减去numsAccumulate[i]
。
结合下边的图应该很好理解,我们要求的是橙色部分,相当于B
的部分减去A
的部分。
所以我们可以提前把一些前缀和存起来,然后查询区间和的时候在可以通过差实现。
/*** @param {number[]} nums*/var NumArray = function (nums) {this.nums = [...nums];this.numsAccumulate = [0];let sum = 0;for (let i = 0; i < nums.length; i++) {sum += nums[i];this.numsAccumulate.push(sum);}};/*** @param {number} i* @param {number} val* @return {void}*/NumArray.prototype.update = function (i, val) {let sub = val - this.nums[i];this.nums[i] = val;for (let k = i + 1; k < this.numsAccumulate.length; k++) {this.numsAccumulate[k] += sub;}};/*** @param {number} i* @param {number} j* @return {number}*/NumArray.prototype.sumRange = function (i, j) {return this.numsAccumulate[j + 1] - this.numsAccumulate[i];};/*** Your NumArray object will be instantiated and called as such:* var obj = new NumArray(nums)* obj.update(i,val)* var param_2 = obj.sumRange(i,j)*/
时间复杂度:update
是O(n)
,sumRange
是O(1)
。
虽然sumRange
的时间复杂度优化了,但是update
又变成了O(n)
。因为更新一个值的时候,这个值后边的累计和都需要更新。
解法三
解法一和解法二时间复杂度两个方法始终一个是O(1)
,一个是O(n)
。这里再分享 官方题解 提供的一个解法,可以优化查询区间的时间复杂度。
我们可以将原数据分成若干个组,然后提前计算这些组的和,举个例子。
组号: 0 1 2 3数组: [2 4 5 6] [9 9 3 8] [1 2 3 4] [4 2 3 4]下标: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15和:17 2910 13
如果我们要计算sumRange(1,13)
,之前我们需要循环累加下标1
到13
的数字的和。
现在我们只需要循环累加1
到3
的和,加上循环累加12
到13
的和,再累加中间组提前算好的和,也就是第1
组和第2
组的和29
和10
,就是最终的结果了。
至于更新的话,我们也不需要像解法二那样更新那么多。我们只需要更新当前元素所在的组即可。
下一个问题,每组的大小定多少呢?
如果定的小了,那么组数就会特别多。
如果定的大了,那么组内元素就会特别多。
组数和组内元素个数都会影响到sumRange
的时间复杂度。
这里,我们在组数和组内元素个数之间取个平衡,假设数组大小是n
,那么组内元素个数取
,这样的话组数也是 ,这样就可以保证我们查询的时间复杂度是 了。因为最坏的情况,无非是查询范围跨越整个数组,中间我们需要累加 个组,第0
组最多累加 次,最后一组也最多累加 次,整体上就是 了。
结合代码理解一下。
/*** @param {number[]} nums*/var NumArray = function (nums) {this.nums = [...nums];this.groupSize = Math.floor(Math.sqrt(this.nums.length));this.group = [];let sum = 0;let i = 0;for (i = 0; i < nums.length; i++) {sum += nums[i];if ((i + 1) % this.groupSize === 0) {this.group.push(sum);sum = 0;}}//有可能数组大小不能整除组的大小, 最后会遗漏下几个元素if (i % this.groupSize !== 0) {this.group.push(sum);}};/*** @param {number} i* @param {number} val* @return {void}*/NumArray.prototype.update = function (i, val) {let sub = val - this.nums[i];let groudId = Math.floor(i / this.groupSize);this.group[groudId] += sub;this.nums[i] = val;};/*** @param {number} i* @param {number} j* @return {number}*/NumArray.prototype.sumRange = function (i, j) {let groupI = Math.floor(i / this.groupSize);let groupJ = Math.floor(j / this.groupSize);let sum = 0;//在同一组内, 直接累加if (groupI === groupJ) {for (let k = i; k <= j; k++) {sum += this.nums[k];}} else {//左边组的元素累加for (let k = i; k < (groupI + 1) * this.groupSize; k++) {sum += this.nums[k];}//累加中间所有的组for (let g = groupI + 1; g < groupJ; g++) {sum += this.group[g];}//右边组的元素累加for (let k = groupJ * this.groupSize; k <= j; k++) {sum += this.nums[k];}}return sum;};/*** Your NumArray object will be instantiated and called as such:* var obj = new NumArray(nums)* obj.update(i,val)* var param_2 = obj.sumRange(i,j)*/
时间复杂度:update
是O(1)
,sumRange
是
。
树状数组
有了上边的背景,我们再回到树状数组。
这个解法写法很简单,但理解的话可能稍微难一些,很多文章都直接讲该怎么用,没有介绍最初的动机,于是去看了提出这个解法的 原始论文,看看能不能理解。
这个解法叫Fenwick tree
或者binary indexed tree
,翻译过来的话叫做树状数组或者二叉索引树,但我觉得binary
翻译成二进制更好,叫做二进制索引树更贴切些,二叉树容易引起误解。
回想一下解法三,我们预先求出了若干个区间和,然后查询的区间可以根据之前预先求出来的区间来求出。这里的话同样的思想,先预先求一些区间和,然后把要求的区间分解成若干个之前求好的区间和即可。相比于解法三,这里的分解会更加巧妙一些。
我们知道计算机中的数都是由二进制来表示的,任何一个数都可以分解成2
的幂次的和,进制转换不熟的话可以参考 再谈进制转换。
举个例子
, 等等。
接下来就是神奇的地方了,每一个数都可以拆成这样的x = a + b + c + ...
的形式。
我们把等式左侧的数x
看做是区间[1, x]
,等式右边看做从x
开始每个区间的长度,也就变成了下边的样子。
[1, x] = [x, x - a + 1] + [x - a, x - a - b + 1] + [x - a - b, x - a - b - c + 1] + ...
。
看起来有些复杂,举个具体的例子就简单多了。
以
为例,可以转换为下边的等式。
[1, 11] = [11, 11] + [10, 9] + [8, 1]
。
[11, 11]
、[10, 9]
、[8, 1]
长度分别是1
、2
、8
。
我们成功把一个大区间,分成了若干个小区间,这就是树状数组最核心的地方了,只要理解了上边讲的,下边就很简单了。
首先,因为数组的下标是从0
开始的,上边的区间范围是从1
开始的,所以我们在原数组开头补一个0
,这样区间就是从1
开始了。
因此我们可以通过分解快速的求出[1, x]
任意前缀区间的和,知道了前缀区间的和,就回到了解法二,通过做差可以算出任意区间的和了。
最后,我们需要解决子区间该怎么求?
[1, 11] = [11, 11] + [10, 9] + [8, 1]
我们用V
表示子区间,用F
表示某个区间。
F[1,11] = V[11] + V[10] + V[8]
其中,V[11] = F[11,11], V[10] = F[10,9], V[8]=F[8...1]
,为什么是这样?
回到二进制,F[0001,1011] = V[1011] + V[1010] + V[1000]
1010 = 1011 - 0001
,0001
就是十进制的1
,所以V[1011]
存1
个数,所以V[11] = F[11,11]
。
1000 = 1010 - 0010
,0010
就是十进制的2
,所以V[1010]
存2
个数,所以V[10] = F[10,9]
。
0000 = 1000 - 1000
,1000
就是十进制的8
,所以V[1000]
存8
个数,所以V[8] = F[8...1]
。
V[1011]
存1
个数,V[1010]
存2
个数,看的是二进制最右边的一个1
到末尾的大小。1010
就是10
,1000
就是1000
。
怎么得到一个数最右边的1
到末尾的大小,是二进制操作的一个技巧,会用到一些补码的知识,可以参考 趣谈计算机补码。
将原数取反,然后再加1
得到的新数和原数按位相与就得到了最右边的1
到末尾的数。
举个例子,对于101000
,先取反得到010111
,再加1
变成011000
,再和原数相与,101000 & 011000
,刚好就得到了1000
。其中,取反再加一,根据补码的知识,可以通过取相反数得到。
所以对于i
的话,i & -i
就得到了最右边的1
到末尾的数,也就是V[i]
这个区间存多少个数。
如果len = i & -i
,那么V[i] = F[i,i-1,i-2, ... i-len+1]
。
参考下边的代码,BIT
就是我们上边要求的V
数组。
/*** @param {number[]} nums*/var NumArray = function (nums) {this.nums = [0, ...nums]; //补一个 0this.BIT = new Array(this.nums.length);for (let i = 1; i < this.BIT.length; i++) {let index = i - ( i & -i ) + 1;this.BIT[i] = 0;//累加 index 到 i 的和while (true) {this.BIT[i] += this.nums[index];index += 1;if (index > i) {break;}}}};
有了BIT
这个数组,一切就都好说了。如果我们想求F[1, 11]
也就是前11
个数的和。
F[1,11] = BIT[11] + BIT[10] + BIT[8]
,看下二进制BIT[0001,1011] = BIT[1011] + BIT[1010] + BIT[1000]
。
1011 -> 1010 -> 1000
,对于BIT
每次的下标就是依次把当前数最右边的1
变成0
。
这里有两种做法,一种是我们求出当前数最右边的1
到末尾的数,然后用原数减一下。
举个例子,1010
最右边的1
到末尾的数是10
,然后用1010 - 10
就得到1000
了。
另外一种做法,就是n & (n - 1)
,比如1010 & (1010 - 1)
,刚好就是1000
了。
知道了这个,我们可以实现一个函数,用来求区间[1, n]
的和。
NumArray.prototype.range = function (index) {let sum = 0;while (index > 0) {sum += this.BIT[index];index -= index & -index;//index = index & (index - 1); //这样也可以}return sum;};
有了range
函数,题目中的sumRange
也就很好实现了。
NumArray.prototype.sumRange = function (i, j) {//range 求的区间范围下标是从 1 开始的,所以这里的 j 需要加 1return this.range(j + 1) - this.range(i);};
接下来是更新函数怎么写。
更新函数的话,最关键的就是找出,当我们更新数组第i
个值,会影响到我们的哪些子区间,也就是代码中的BIT
数组需要更新哪些。
我们来回忆下之前做了什么事情。
这是论文中的一张图,含义就是我们之前分析的,BIT[8]
存的是F[1...8]
,对应图中的就是从第8
个位置到第1
个位置的矩形。BIT[6]
存的是F[6,5]
, 对应图中的就是从第6
个位置一直到第5
个位置的矩形。
然后我们水平从某个数画一条线,比如从3
那里画一条线。
穿过了3
对应的矩形,4
对应的矩形,8
对应的矩形。因此如果改变第3
个数,BIT[3]
,BIT[4]
以及BIT[8]
就需要更新。通过这种方式我们把每个数会影响到哪个区间画出来,找一下规律。
当改变了第5
个元素的时候,会依次影响到BIT[5]
,BIT[6]
,BIT[8]
,BIT[16]
。
00101 -> 00110 -> 01000 -> 10000
。
00101 + 1 = 00110
。
00110 + 10 = 01000
01000 + 1000 = 10000
可以看到每次都是加上当前数最右边的1
到末尾的数,即next = current + (current & -current)
。
所以更新的代码也就出来了。
/*** @param {number} i* @param {number} val* @return {void}*/NumArray.prototype.update = function (i, val) {i += 1;//对应的下标要进行加 1const sub = val - this.nums[i];this.nums[i] = val;while (i < this.nums.length) {this.BIT[i] += sub;i += i & -i;}};
综上,这道题就解决了,我们把代码合在一起。
/*** @param {number[]} nums*/var NumArray = function (nums) {this.nums = [0, ...nums];this.BIT = new Array(this.nums.length);for (let i = 1; i < this.BIT.length; i++) {let index = i - ( i & -i ) + 1;this.BIT[i] = 0;while (true) {this.BIT[i] += this.nums[index];index += 1;if (index > i) {break;}}}};/*** @param {number} i* @param {number} val* @return {void}*/NumArray.prototype.update = function (i, val) {i += 1;const sub = val - this.nums[i];this.nums[i] = val;while (i < this.nums.length) {this.BIT[i] += sub;i += i & -i;}};/*** @param {number} i* @param {number} j* @return {number}*/NumArray.prototype.sumRange = function (i, j) {return this.range(j + 1) - this.range(i);};NumArray.prototype.range = function (index) {let sum = 0;while (index > 0) {sum += this.BIT[index];// index -= index & -index;index = index & (index - 1); //这样也可以}return sum;};/*** Your NumArray object will be instantiated and called as such:* var obj = new NumArray(nums)* obj.update(i,val)* var param_2 = obj.sumRange(i,j)*/
时间复杂度的话,初始化、更新、查询其实都和二进制的位数有关,以查询为例。每次将二进制的最后一位变成0
,最坏的情况就是初始值是全1
,即1111
这种,执行次数就是4
次,也就是二进制的位数。
如果是n
,那么位数大约就是log(n)
,可以结合 再谈进制转换 理解。把一个数展开为2
的幂次和,位数其实就是最高位的幂次加1
。比如
,最高幂次是3
,所以11
的二进制(1011)
位数就是4
。如果要求的数是n
,最高的次幂是x
, ,近似一下 ,x = log(n)
,位数就是log(n) + 1
。
所以update
和sumRange
的时间复杂度就是O(log(n))
。
对于初始化函数,因为要执行n
次,所以就是O(nlog(n))
。当然我们也可以利用解法二,把前缀和都求出来,然后更新数组BIT
的每个值,这样就是O(n)
了。但不是很有必要,因为如果查询和更新的次数很多,远大于n
次,那么初始化这里的时间复杂度也就无关紧要了。
总
讲了很多,其实树状数组最根本的就是开头所提到的二进制幂次的分解,
,然后把右边的分解出来的数看做子区间的长度。