完善了avl 并且修改了 升序为和核心 TODO: 需要修改vbt iterator 和 其他

This commit is contained in:
eson 2019-03-24 03:11:42 +08:00
parent 6a6596cee5
commit d0996f7aaf
5 changed files with 326 additions and 269 deletions

View File

@ -34,31 +34,31 @@ func New(compare compare.Compare) *Tree {
return &Tree{compare: compare} return &Tree{compare: compare}
} }
func (avl *Tree) String() string { func (tree *Tree) String() string {
if avl.size == 0 { if tree.size == 0 {
return "" return ""
} }
str := "AVLTree\n" str := "AVLTree\n"
output(avl.root, "", true, &str) output(tree.root, "", true, &str)
return str return str
} }
func (avl *Tree) Iterator() *Iterator { func (tree *Tree) Iterator() *Iterator {
return initIterator(avl) return initIterator(tree)
} }
func (avl *Tree) Size() int { func (tree *Tree) Size() int {
return avl.size return tree.size
} }
func (avl *Tree) Remove(key interface{}) *Node { func (tree *Tree) Remove(key interface{}) *Node {
if n, ok := avl.GetNode(key); ok { if n, ok := tree.GetNode(key); ok {
avl.size-- tree.size--
if avl.size == 0 { if tree.size == 0 {
avl.root = nil tree.root = nil
return n return n
} }
@ -68,7 +68,7 @@ func (avl *Tree) Remove(key interface{}) *Node {
if left == -1 && right == -1 { if left == -1 && right == -1 {
p := n.parent p := n.parent
p.children[getRelationship(n)] = nil p.children[getRelationship(n)] = nil
avl.fixRemoveHeight(p) tree.fixRemoveHeight(p)
return n return n
} }
@ -105,9 +105,9 @@ func (avl *Tree) Remove(key interface{}) *Node {
// 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度 // 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度
if cparent == n { if cparent == n {
avl.fixRemoveHeight(n) tree.fixRemoveHeight(n)
} else { } else {
avl.fixRemoveHeight(cparent) tree.fixRemoveHeight(cparent)
} }
return cur return cur
@ -153,7 +153,7 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(min) iter := NewIterator(min)
for iter.Prev() { for iter.Next() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == max { if iter.cur == max {
break break
@ -178,7 +178,7 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
result = make([]interface{}, 0, 16) result = make([]interface{}, 0, 16)
iter := NewIterator(max) iter := NewIterator(max)
for iter.Next() { for iter.Prev() {
result = append(result, iter.Value()) result = append(result, iter.Value())
if iter.cur == min { if iter.cur == min {
break break
@ -194,8 +194,8 @@ func (tree *Tree) GetRange(k1, k2 interface{}) (result []interface{}) {
return return
} }
func (avl *Tree) Get(key interface{}) (interface{}, bool) { func (tree *Tree) Get(key interface{}) (interface{}, bool) {
n, ok := avl.GetNode(key) n, ok := tree.GetNode(key)
if ok { if ok {
return n.value, true return n.value, true
} }
@ -227,6 +227,15 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
n = n.children[1] n = n.children[1]
lastc = c lastc = c
case 0: case 0:
iter := NewIterator(n)
iter.Prev()
for iter.Prev() {
if tree.compare(iter.cur.value, n.value) == 0 {
n = iter.cur
} else {
break
}
}
result[1] = n result[1] = n
n = nil n = nil
default: default:
@ -236,146 +245,83 @@ func (tree *Tree) getArountNode(key interface{}) (result [3]*Node) {
switch lastc { switch lastc {
case 1: case 1:
const il = 0
const ir = 1
if result[1] == nil { if result[1] != nil {
result[0] = GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1)
} else {
result[0] = last result[0] = last
result[2] = GetNext(last, 1)
parent := last
for ; parent != nil && parent.parent != nil; parent = parent.parent {
child := getRelationship(parent)
if child == (-lastc+2)/2 { // child 与 compare 后左右的关系
result[2] = parent.parent
break
}
}
} else {
l := result[1].children[il]
r := result[1].children[ir]
if l == nil {
result[0] = result[1].parent
} else {
for l.children[ir] != nil {
l = l.children[ir]
}
result[0] = l
}
if r == nil {
parent := result[1].parent
for ; parent != nil && parent.parent != nil; parent = parent.parent {
child := getRelationship(parent)
if child == (-lastc+2)/2 { // child 与 compare 后左右的关系
result[2] = parent.parent
break
}
}
} else {
for r.children[il] != nil {
r = r.children[il]
}
result[2] = r
}
} }
case -1: case -1:
const il = 1 if result[1] != nil {
const ir = 0 result[0] = GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1)
if result[1] == nil { } else {
result[2] = last result[2] = last
result[0] = GetPrev(last, 1)
parent := last
for ; parent != nil && parent.parent != nil; parent = parent.parent {
child := getRelationship(parent)
if child == (-lastc+2)/2 { // child 与 compare 后左右的关系
result[0] = parent.parent
break
}
} }
} else {
l := result[1].children[il]
r := result[1].children[ir]
if l == nil {
result[2] = result[1].parent
} else {
for l.children[ir] != nil {
l = l.children[ir]
}
result[2] = l
}
if r == nil {
parent := result[1].parent
for ; parent != nil && parent.parent != nil; parent = parent.parent {
child := getRelationship(parent)
if child == (-lastc+2)/2 { // child 与 compare 后左右的关系
result[0] = parent.parent
break
}
}
} else {
for r.children[il] != nil {
r = r.children[il]
}
result[0] = r
}
}
case 0: case 0:
const il = 0
const ir = 1
if result[1] == nil { if result[1] == nil {
return return
} }
result[0] = GetPrev(result[1], 1)
result[2] = GetNext(result[1], 1)
l := result[1].children[il] // const il = 0
r := result[1].children[ir] // const ir = 1
if l == nil { // if result[1] == nil {
result[0] = nil // return
} else { // }
for l.children[ir] != nil {
l = l.children[ir]
}
result[0] = l
}
if r == nil { // l := result[1].children[il]
result[2] = nil // r := result[1].children[ir]
} else {
for r.children[il] != nil { // if l == nil {
r = r.children[il] // result[0] = nil
} // } else {
result[2] = r // for l.children[ir] != nil {
} // l = l.children[ir]
// }
// result[0] = l
// }
// if r == nil {
// result[2] = nil
// } else {
// for r.children[il] != nil {
// r = r.children[il]
// }
// result[2] = r
// }
} }
return return
} }
func (avl *Tree) GetNode(value interface{}) (*Node, bool) { func (tree *Tree) GetNode(value interface{}) (*Node, bool) {
for n := avl.root; n != nil; { for n := tree.root; n != nil; {
switch c := avl.compare(value, n.value); c { switch c := tree.compare(value, n.value); c {
case -1: case -1:
n = n.children[0] n = n.children[0]
case 1: case 1:
n = n.children[1] n = n.children[1]
case 0: case 0:
iter := NewIterator(n)
iter.Prev()
for iter.Prev() {
if tree.compare(iter.cur.value, n.value) == 0 {
n = iter.cur
} else {
break
}
}
return n, true return n, true
default: default:
panic("Get compare only is allowed in -1, 0, 1") panic("Get compare only is allowed in -1, 0, 1")
@ -384,15 +330,15 @@ func (avl *Tree) GetNode(value interface{}) (*Node, bool) {
return nil, false return nil, false
} }
func (avl *Tree) Put(value interface{}) { func (tree *Tree) Put(value interface{}) {
avl.size++ tree.size++
node := &Node{value: value} node := &Node{value: value}
if avl.size == 1 { if tree.size == 1 {
avl.root = node tree.root = node
return return
} }
cur := avl.root cur := tree.root
parent := cur.parent parent := cur.parent
child := -1 child := -1
@ -402,13 +348,13 @@ func (avl *Tree) Put(value interface{}) {
parent.children[child] = node parent.children[child] = node
node.parent = parent node.parent = parent
if node.parent.height == 0 { if node.parent.height == 0 {
avl.fixPutHeight(node.parent) tree.fixPutHeight(node.parent)
} }
return return
} }
parent = cur parent = cur
c := avl.compare(value, cur.value) c := tree.compare(value, cur.value)
child = (c + 2) / 2 child = (c + 2) / 2
cur = cur.children[child] cur = cur.children[child]
} }
@ -559,7 +505,7 @@ func (tree *Tree) Traversal(every func(v interface{}) bool, traversalMethod ...i
} }
} }
func (avl *Tree) lrrotate(cur *Node) { func (tree *Tree) lrrotate(cur *Node) {
const l = 1 const l = 1
const r = 0 const r = 0
@ -603,7 +549,7 @@ func (avl *Tree) lrrotate(cur *Node) {
cur.height = getMaxChildrenHeight(cur) + 1 cur.height = getMaxChildrenHeight(cur) + 1
} }
func (avl *Tree) rlrotate(cur *Node) { func (tree *Tree) rlrotate(cur *Node) {
const l = 0 const l = 0
const r = 1 const r = 1
@ -645,7 +591,7 @@ func (avl *Tree) rlrotate(cur *Node) {
cur.height = getMaxChildrenHeight(cur) + 1 cur.height = getMaxChildrenHeight(cur) + 1
} }
func (avl *Tree) rrotateex(cur *Node) { func (tree *Tree) rrotateex(cur *Node) {
const l = 0 const l = 0
const r = 1 const r = 1
@ -683,7 +629,7 @@ func (avl *Tree) rrotateex(cur *Node) {
cur.height = getMaxChildrenHeight(cur) + 1 cur.height = getMaxChildrenHeight(cur) + 1
} }
func (avl *Tree) rrotate(cur *Node) { func (tree *Tree) rrotate(cur *Node) {
const l = 0 const l = 0
const r = 1 const r = 1
@ -720,7 +666,7 @@ func (avl *Tree) rrotate(cur *Node) {
cur.height = getMaxChildrenHeight(cur) + 1 cur.height = getMaxChildrenHeight(cur) + 1
} }
func (avl *Tree) lrotateex(cur *Node) { func (tree *Tree) lrotateex(cur *Node) {
const l = 1 const l = 1
const r = 0 const r = 0
@ -759,7 +705,7 @@ func (avl *Tree) lrotateex(cur *Node) {
cur.height = getMaxChildrenHeight(cur) + 1 cur.height = getMaxChildrenHeight(cur) + 1
} }
func (avl *Tree) lrotate(cur *Node) { func (tree *Tree) lrotate(cur *Node) {
const l = 1 const l = 1
const r = 0 const r = 0
@ -822,7 +768,7 @@ func getHeight(cur *Node) int {
return cur.height return cur.height
} }
func (avl *Tree) fixRemoveHeight(cur *Node) { func (tree *Tree) fixRemoveHeight(cur *Node) {
for { for {
lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur) lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur)
@ -836,16 +782,16 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
if diff < -1 { if diff < -1 {
r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式 r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式
if getHeight(r.children[0]) > getHeight(r.children[1]) { if getHeight(r.children[0]) > getHeight(r.children[1]) {
avl.lrrotate(cur) tree.lrrotate(cur)
} else { } else {
avl.lrotate(cur) tree.lrotate(cur)
} }
} else if diff > 1 { } else if diff > 1 {
l := cur.children[0] l := cur.children[0]
if getHeight(l.children[1]) > getHeight(l.children[0]) { if getHeight(l.children[1]) > getHeight(l.children[0]) {
avl.rlrotate(cur) tree.rlrotate(cur)
} else { } else {
avl.rrotate(cur) tree.rrotate(cur)
} }
} else { } else {
if cur.height == curheight { if cur.height == curheight {
@ -862,7 +808,7 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
} }
func (avl *Tree) fixPutHeight(cur *Node) { func (tree *Tree) fixPutHeight(cur *Node) {
for { for {
@ -874,16 +820,16 @@ func (avl *Tree) fixPutHeight(cur *Node) {
if diff < -1 { if diff < -1 {
r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式 r := cur.children[1] // 根据左旋转的右边节点的子节点 左右高度选择旋转的方式
if getHeight(r.children[0]) > getHeight(r.children[1]) { if getHeight(r.children[0]) > getHeight(r.children[1]) {
avl.lrrotate(cur) tree.lrrotate(cur)
} else { } else {
avl.lrotate(cur) tree.lrotate(cur)
} }
} else if diff > 1 { } else if diff > 1 {
l := cur.children[0] l := cur.children[0]
if getHeight(l.children[1]) > getHeight(l.children[0]) { if getHeight(l.children[1]) > getHeight(l.children[0]) {
avl.rlrotate(cur) tree.rlrotate(cur)
} else { } else {
avl.rrotate(cur) tree.rrotate(cur)
} }
} else { } else {
@ -973,11 +919,11 @@ func outputfordebug(node *Node, prefix string, isTail bool, str *string) {
} }
} }
func (avl *Tree) debugString() string { func (tree *Tree) debugString() string {
if avl.size == 0 { if tree.size == 0 {
return "" return ""
} }
str := "AVLTree\n" str := "AVLTree\n"
outputfordebug(avl.root, "", true, &str) outputfordebug(tree.root, "", true, &str)
return str return str
} }

View File

@ -8,6 +8,7 @@ import (
"os" "os"
"testing" "testing"
"474420502.top/eson/structure/compare"
"github.com/Pallinder/go-randomdata" "github.com/Pallinder/go-randomdata"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/emirpasic/gods/trees/avltree" "github.com/emirpasic/gods/trees/avltree"
@ -80,57 +81,166 @@ func TestIterator(t *testing.T) {
// └── 1` // └── 1`
iter := avl.Iterator() // root start point iter := avl.Iterator() // root start point
t.Error(iter.cur, spew.Sdump(iter.tstack))
l := []int{14, 15, 20, 21, 30} l := []int{14, 15, 20, 21, 30}
for i := 0; iter.Prev(); i++ { for i := 0; iter.Next(); i++ {
if iter.Value().(int) != l[i] { if iter.Value().(int) != l[i] {
t.Error("iter prev error", iter.Value(), l[i]) t.Error("iter Next error", iter.Value(), l[i])
} }
} }
iter.Prev() iter.Next()
if iter.Value().(int) != 30 { if iter.Value().(int) != 30 {
t.Error("prev == false", iter.Value(), iter.Prev(), iter.Value()) t.Error("Next == false", iter.Value(), iter.Next(), iter.Value())
} }
l = []int{21, 20, 15, 14, 7, 7, 6, 5, 4, 3, 2, 1} l = []int{21, 20, 15, 14, 7, 7, 6, 5, 4, 3, 2, 1}
t.Error(iter.cur, spew.Sdump(iter.tstack)) for i := 0; iter.Prev(); i++ { // cur is 30 next is 21
for i := 0; iter.Next(); i++ { // cur is 30 next is 21
if iter.Value().(int) != l[i] { if iter.Value().(int) != l[i] {
t.Error(iter.Value()) t.Error(iter.Value())
} }
} }
if iter.Next() != false { if iter.Prev() != false {
t.Error("Next is error, cur is tail, val = 1 Next return false") t.Error("Prev is error, cur is tail, val = 1 Prev return false")
} }
if iter.Value().(int) != 1 { // cur is 1 if iter.Value().(int) != 1 { // cur is 1
t.Error("next == false", iter.Value(), iter.Next(), iter.Value()) t.Error("next == false", iter.Value(), iter.Prev(), iter.Value())
} }
if iter.Prev() != true && iter.Value().(int) != 2 { if iter.Next() != true && iter.Value().(int) != 2 {
t.Error("next to prev is error") t.Error("next to prev is error")
} }
} }
func TestGetRange(t *testing.T) {
tree := New(compare.Int)
for _, v := range []int{5, 6, 8, 10, 13, 17, 1, 2, 40, 30} {
tree.Put(v)
}
// t.Error(tree.debugString())
// t.Error(tree.getArountNode(20))
// t.Error(tree.Values())
result := tree.GetRange(0, 20)
if spew.Sprint(result) != "[1 2 5 6 8 10 13 17]" {
t.Error(result)
}
result = tree.GetRange(-5, -1)
if spew.Sprint(result) != "[]" {
t.Error(result)
}
result = tree.GetRange(7, 20)
if spew.Sprint(result) != "[8 10 13 17]" {
t.Error(result)
}
result = tree.GetRange(30, 40)
if spew.Sprint(result) != "[30 40]" {
t.Error(result)
}
result = tree.GetRange(30, 60)
if spew.Sprint(result) != "[30 40]" {
t.Error(result)
}
result = tree.GetRange(40, 40)
if spew.Sprint(result) != "[40]" {
t.Error(result)
}
result = tree.GetRange(50, 60)
if spew.Sprint(result) != "[]" {
t.Error(result)
}
result = tree.GetRange(50, 1)
if spew.Sprint(result) != "[40 30 17 13 10 8 6 5 2 1]" {
t.Error(result)
}
result = tree.GetRange(30, 20)
if spew.Sprint(result) != "[30]" {
t.Error(result)
}
}
func TestGetAround(t *testing.T) { func TestGetAround(t *testing.T) {
avl := New(utils.IntComparator) tree := New(compare.Int)
for _, v := range []int{7, 14, 15, 20, 30, 21, 40, 40, 50, 3, 40, 40, 40} { for _, v := range []int{7, 14, 14, 14, 16, 17, 20, 30, 21, 40, 50, 3, 40, 40, 40, 15} {
avl.Put(v) tree.Put(v)
} }
if spew.Sprint(avl.GetAround(30)) != "[40 30 21]" { var Result string
t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(30)))
Result = spew.Sprint(tree.GetAround(14))
if Result != "[7 14 14]" {
t.Error(tree.Values())
t.Error("14 is root, tree.GetAround(14)) is error", Result)
t.Error(tree.debugString())
} }
if spew.Sprint(avl.GetAround(40)) != "[40 40 30]" { Result = spew.Sprint(tree.GetAround(17))
t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(40))) if Result != "[16 17 20]" {
t.Error(tree.Values())
t.Error("tree.GetAround(17)) is error", Result)
t.Error(tree.debugString())
} }
if spew.Sprint(avl.GetAround(50)) != "[<nil> 50 40]" { Result = spew.Sprint(tree.GetAround(3))
t.Error("avl.GetAround(50)) is error", spew.Sprint(avl.GetAround(50))) if Result != "[<nil> 3 7]" {
t.Error(tree.Values())
t.Error("tree.GetAround(3)) is error", Result)
t.Error(tree.debugString())
} }
Result = spew.Sprint(tree.GetAround(40))
if Result != "[30 40 40]" {
t.Error(tree.Values())
t.Error("tree.GetAround(40)) is error", Result)
t.Error(tree.debugString())
}
Result = spew.Sprint(tree.GetAround(50))
if Result != "[40 50 <nil>]" {
t.Error(tree.Values())
t.Error("tree.GetAround(50)) is error", Result)
t.Error(tree.debugString())
}
Result = spew.Sprint(tree.GetAround(18))
if Result != "[17 <nil> 20]" {
t.Error(tree.Values())
t.Error("18 is not in list, tree.GetAround(18)) is error", Result)
t.Error(tree.debugString())
}
Result = spew.Sprint(tree.GetAround(5))
if Result != "[3 <nil> 7]" {
t.Error(tree.Values())
t.Error("5 is not in list, tree.GetAround(5)) is error", Result)
t.Error(tree.debugString())
}
Result = spew.Sprint(tree.GetAround(2))
if Result != "[<nil> <nil> 3]" {
t.Error(tree.Values())
t.Error("2 is not in list, tree.GetAround(2)) is error", Result)
t.Error(tree.debugString())
}
Result = spew.Sprint(tree.GetAround(100))
if Result != "[50 <nil> <nil>]" {
t.Error(tree.Values())
t.Error("50 is not in list, tree.GetAround(50)) is error", Result)
t.Error(tree.debugString())
}
} }
// for test error case // for test error case
@ -300,7 +410,6 @@ func BenchmarkIterator(b *testing.B) {
tree := New(utils.IntComparator) tree := New(utils.IntComparator)
l := loadTestData() l := loadTestData()
b.N = len(l)
for _, v := range l { for _, v := range l {
tree.Put(v) tree.Put(v)
@ -308,12 +417,19 @@ func BenchmarkIterator(b *testing.B) {
b.ResetTimer() b.ResetTimer()
b.StartTimer() b.StartTimer()
b.N = 0
iter := tree.Iterator() iter := tree.Iterator()
for iter.Next() { for iter.Next() {
b.N++
} }
for iter.Prev() { for iter.Prev() {
b.N++
} }
for iter.Next() { for iter.Next() {
b.N++
}
for iter.Prev() {
b.N++
} }
} }

View File

@ -1,10 +1,6 @@
package avl package avl
import ( import (
"log"
"github.com/davecgh/go-spew/spew"
"474420502.top/eson/structure/lastack" "474420502.top/eson/structure/lastack"
) )
@ -50,63 +46,6 @@ func (iter *Iterator) Right() bool {
return false return false
} }
func GetPrev(cur *Node, idx int) *Node {
iter := NewIterator(cur)
iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up)
for i := 0; i < idx; i++ {
if iter.tstack.Size() == 0 {
if iter.up == nil {
return nil
}
iter.tstack.Push(iter.up)
iter.up = iter.getPrevUp(iter.up)
}
if v, ok := iter.tstack.Pop(); ok {
iter.cur = v.(*Node)
if i == idx-1 {
return iter.cur
}
iter.curPushPrevStack(iter.cur)
} else {
return nil
}
}
return cur
}
func (iter *Iterator) Prev() (result bool) {
if iter.dir > -1 {
if iter.dir == 1 && iter.cur != nil {
iter.tstack.Clear()
iter.curPushPrevStack(iter.cur)
iter.up = iter.getPrevUp(iter.cur)
}
iter.dir = -1
}
if iter.tstack.Size() == 0 {
if iter.up == nil {
return false
}
iter.tstack.Push(iter.up)
iter.up = iter.getPrevUp(iter.up)
}
if v, ok := iter.tstack.Pop(); ok {
iter.cur = v.(*Node)
iter.curPushPrevStack(iter.cur)
return true
}
return false
}
func GetNext(cur *Node, idx int) *Node { func GetNext(cur *Node, idx int) *Node {
iter := NewIterator(cur) iter := NewIterator(cur)
@ -139,16 +78,15 @@ func GetNext(cur *Node, idx int) *Node {
func (iter *Iterator) Next() (result bool) { func (iter *Iterator) Next() (result bool) {
if iter.dir < 1 { // 非 1(next 方向定义 -1 为 prev) if iter.dir > -1 {
if iter.dir == -1 && iter.cur != nil { // 如果上次为prev方向, 则清空辅助计算的栈 if iter.dir == 1 && iter.cur != nil {
iter.tstack.Clear() iter.tstack.Clear()
iter.curPushNextStack(iter.cur) // 把当前cur计算的逆向回朔 iter.curPushNextStack(iter.cur)
iter.up = iter.getNextUp(iter.cur) // cur 寻找下个要计算up iter.up = iter.getNextUp(iter.cur)
} }
iter.dir = 1 iter.dir = -1
} }
// 如果栈空了, 把up的递归计算入栈, 重新计算 下次的up值
if iter.tstack.Size() == 0 { if iter.tstack.Size() == 0 {
if iter.up == nil { if iter.up == nil {
return false return false
@ -163,6 +101,64 @@ func (iter *Iterator) Next() (result bool) {
return true return true
} }
return false
}
func GetPrev(cur *Node, idx int) *Node {
iter := NewIterator(cur)
iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up)
for i := 0; i < idx; i++ {
if iter.tstack.Size() == 0 {
if iter.up == nil {
return nil
}
iter.tstack.Push(iter.up)
iter.up = iter.getPrevUp(iter.up)
}
if v, ok := iter.tstack.Pop(); ok {
iter.cur = v.(*Node)
if i == idx-1 {
return iter.cur
}
iter.curPushPrevStack(iter.cur)
} else {
return nil
}
}
return cur
}
func (iter *Iterator) Prev() (result bool) {
if iter.dir < 1 { // 非 1(next 方向定义 -1 为 prev)
if iter.dir == -1 && iter.cur != nil { // 如果上次为prev方向, 则清空辅助计算的栈
iter.tstack.Clear()
iter.curPushPrevStack(iter.cur) // 把当前cur计算的逆向回朔
iter.up = iter.getPrevUp(iter.cur) // cur 寻找下个要计算up
}
iter.dir = 1
}
// 如果栈空了, 把up的递归计算入栈, 重新计算 下次的up值
if iter.tstack.Size() == 0 {
if iter.up == nil {
return false
}
iter.tstack.Push(iter.up)
iter.up = iter.getPrevUp(iter.up)
}
if v, ok := iter.tstack.Pop(); ok {
iter.cur = v.(*Node)
iter.curPushPrevStack(iter.cur)
return true
}
// 如果再次计算的栈为空, 则只能返回false // 如果再次计算的栈为空, 则只能返回false
return false return false
} }
@ -174,7 +170,7 @@ func getRelationship(cur *Node) int {
return 0 return 0
} }
func (iter *Iterator) getNextUp(cur *Node) *Node { func (iter *Iterator) getPrevUp(cur *Node) *Node {
for cur.parent != nil { for cur.parent != nil {
if getRelationship(cur) == 1 { // next 在 降序 小值. 如果child在右边, parent 比 child 小, parent才有效, 符合降序 if getRelationship(cur) == 1 { // next 在 降序 小值. 如果child在右边, parent 比 child 小, parent才有效, 符合降序
return cur.parent return cur.parent
@ -184,20 +180,19 @@ func (iter *Iterator) getNextUp(cur *Node) *Node {
return nil return nil
} }
func (iter *Iterator) curPushNextStack(cur *Node) { func (iter *Iterator) curPushPrevStack(cur *Node) {
next := cur.children[0] // 当前的左然后向右找, 找到最大, 就是最接近cur 并且小于cur的值 Prev := cur.children[0] // 当前的左然后向右找, 找到最大, 就是最接近cur 并且小于cur的值
if next != nil { if Prev != nil {
log.Println(spew.Sdump(iter.tstack)) iter.tstack.Push(Prev)
iter.tstack.Push(next) for Prev.children[1] != nil {
for next.children[1] != nil { Prev = Prev.children[1]
next = next.children[1] iter.tstack.Push(Prev) // 入栈 用于回溯
iter.tstack.Push(next) // 入栈 用于回溯
} }
} }
} }
func (iter *Iterator) getPrevUp(cur *Node) *Node { func (iter *Iterator) getNextUp(cur *Node) *Node {
for cur.parent != nil { for cur.parent != nil {
if getRelationship(cur) == 0 { // Prev 在 降序 大值. 如果child在左边, parent 比 child 大, parent才有效 , 符合降序 if getRelationship(cur) == 0 { // Prev 在 降序 大值. 如果child在左边, parent 比 child 大, parent才有效 , 符合降序
return cur.parent return cur.parent
@ -207,14 +202,14 @@ func (iter *Iterator) getPrevUp(cur *Node) *Node {
return nil return nil
} }
func (iter *Iterator) curPushPrevStack(cur *Node) { func (iter *Iterator) curPushNextStack(cur *Node) {
prev := cur.children[1] next := cur.children[1]
if prev != nil { if next != nil {
iter.tstack.Push(prev) iter.tstack.Push(next)
for prev.children[0] != nil { for next.children[0] != nil {
prev = prev.children[0] next = next.children[0]
iter.tstack.Push(prev) iter.tstack.Push(next)
} }
} }
} }

View File

@ -56,7 +56,7 @@ func BenchmarkGet(b *testing.B) {
func BenchmarkPush(b *testing.B) { func BenchmarkPush(b *testing.B) {
s := New() s := New()
b.N = 200000 b.N = 20000000
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)
s.Push(v) s.Push(v)
@ -65,7 +65,7 @@ func BenchmarkPush(b *testing.B) {
func BenchmarkGodsPush(b *testing.B) { func BenchmarkGodsPush(b *testing.B) {
s := arraystack.New() s := arraystack.New()
b.N = 200000 b.N = 2000000
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)
s.Push(v) s.Push(v)
@ -74,7 +74,7 @@ func BenchmarkGodsPush(b *testing.B) {
func BenchmarkPop(b *testing.B) { func BenchmarkPop(b *testing.B) {
s := New() s := New()
b.N = 200000 b.N = 2000000
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)
@ -91,7 +91,7 @@ func BenchmarkPop(b *testing.B) {
func BenchmarkGodsPop(b *testing.B) { func BenchmarkGodsPop(b *testing.B) {
s := arraystack.New() s := arraystack.New()
b.N = 200 b.N = 2000000
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)

View File

@ -28,7 +28,7 @@ func BenchmarkGodsPush(b *testing.B) {
func BenchmarkPop(b *testing.B) { func BenchmarkPop(b *testing.B) {
s := New() s := New()
b.N = 200 b.N = 200000
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)
@ -45,7 +45,7 @@ func BenchmarkPop(b *testing.B) {
func BenchmarkGodsPop(b *testing.B) { func BenchmarkGodsPop(b *testing.B) {
s := arraystack.New() s := arraystack.New()
b.N = 200 b.N = 200000
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)