TODO: GetRange

This commit is contained in:
huangsimin 2019-03-19 19:15:54 +08:00
parent 50e4bd754f
commit 06d55a2f9e
6 changed files with 375 additions and 313 deletions

View File

@ -120,6 +120,19 @@ func (avl *Tree) Remove(key interface{}) *Node {
return nil return nil
} }
// Values 返回先序遍历的值
func (avl *Tree) Values() []interface{} {
mszie := 0
if avl.root != nil {
mszie = avl.size
}
result := make([]interface{}, 0, mszie)
avl.Traversal(func(v interface{}) {
result = append(result, v)
}, DLR)
return result
}
func (avl *Tree) Get(key interface{}) (interface{}, bool) { func (avl *Tree) Get(key interface{}) (interface{}, bool) {
n, ok := avl.GetNode(key) n, ok := avl.GetNode(key)
if ok { if ok {
@ -143,57 +156,35 @@ func (avl *Tree) GetAround(key interface{}) (result [3]interface{}) {
} }
func (avl *Tree) GetAroundNode(value interface{}) (result [3]*Node) { func (avl *Tree) GetAroundNode(value interface{}) (result [3]*Node) {
n := avl.root if cur, ok := avl.GetNode(value); ok {
for { var iter *Iterator
if n == nil { iter = NewIterator(cur)
return iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up)
if v, ok := iter.tstack.Pop(); ok {
result[0] = v.(*Node)
// iter.curPushPrevStack(iter.cur)
} else {
result[0] = iter.up
} }
lastc := 0 iter = NewIterator(cur)
switch c := avl.comparator(value, n.value); c { iter.curPushNextStack(iter.up)
case -1: iter.up = iter.getNextUp(iter.up)
if c != -lastc {
result[0] = n
}
lastc = c
n = n.children[0]
case 1:
if c != -lastc {
result[2] = n
}
lastc = c
n = n.children[1]
case 0:
switch lastc { if v, ok := iter.tstack.Pop(); ok {
case -1: result[2] = v.(*Node)
if n.children[1] != nil { // iter.curPushNextStack(iter.cur)
result[0] = n.children[1] } else {
} result[2] = iter.up
case 1:
if n.children[0] != nil {
result[2] = n.children[0]
}
case 0:
if n.children[1] != nil {
result[0] = n.children[1]
}
if n.children[0] != nil {
result[2] = n.children[0]
}
result[1] = n
return
}
default:
panic("Get comparator only is allowed in -1, 0, 1")
} }
result[1] = cur
} }
return
} }
func (avl *Tree) GetNode(value interface{}) (*Node, bool) { func (avl *Tree) GetNode(value interface{}) (*Node, bool) {
@ -251,47 +242,64 @@ func (avl *Tree) debugString() string {
return str return str
} }
func (avl *Tree) TraversalBreadth() (result []interface{}) { type TraversalMethod int
result = make([]interface{}, 0, avl.size)
var traverasl func(cur *Node)
traverasl = func(cur *Node) {
if cur == nil {
return
}
result = append(result, cur.value)
traverasl(cur.children[0])
traverasl(cur.children[1])
}
traverasl(avl.root)
return
}
func (avl *Tree) TraversalDepth(leftright int) (result []interface{}) { const (
result = make([]interface{}, 0, avl.size) _ TraversalMethod = iota
if leftright < 0 { //DLR 前序遍历
DLR
//LDR 中序遍历
LDR
//LRD 后序遍历
LRD
)
// Traversal 遍历的方法
func (avl *Tree) Traversal(every func(v interface{}), traversalMethod ...interface{}) {
if avl.root == nil {
return
}
method := LDR
if len(traversalMethod) != 0 {
method = traversalMethod[0].(TraversalMethod)
}
switch method {
case DLR:
var traverasl func(cur *Node) var traverasl func(cur *Node)
traverasl = func(cur *Node) { traverasl = func(cur *Node) {
if cur == nil { if cur == nil {
return return
} }
traverasl(cur.children[0]) traverasl(cur.children[0])
result = append(result, cur.value) every(cur.value)
traverasl(cur.children[1]) traverasl(cur.children[1])
} }
traverasl(avl.root) traverasl(avl.root)
} else { case LRD:
var traverasl func(cur *Node) var traverasl func(cur *Node)
traverasl = func(cur *Node) { traverasl = func(cur *Node) {
if cur == nil { if cur == nil {
return return
} }
traverasl(cur.children[1]) traverasl(cur.children[1])
result = append(result, cur.value) every(cur.value)
traverasl(cur.children[0]) traverasl(cur.children[0])
} }
traverasl(avl.root) traverasl(avl.root)
case LDR:
var traverasl func(cur *Node)
traverasl = func(cur *Node) {
if cur == nil {
return
}
every(cur.value)
traverasl(cur.children[0])
traverasl(cur.children[1])
}
traverasl(avl.root)
} }
return return
} }
@ -564,7 +572,6 @@ func abs(n int) int {
} }
func (avl *Tree) fixRemoveHeight(cur *Node) { func (avl *Tree) fixRemoveHeight(cur *Node) {
for { for {
lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur) lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur)

View File

@ -16,7 +16,7 @@ import (
"github.com/emirpasic/gods/utils" "github.com/emirpasic/gods/utils"
) )
const CompartorSize = 1000000 const CompartorSize = 100
const NumberMax = 600 const NumberMax = 600
func TestSave(t *testing.T) { func TestSave(t *testing.T) {
@ -124,7 +124,7 @@ func TestGetAround(t *testing.T) {
} }
if spew.Sprint(avl.GetAround(40)) != "[40 40 30]" { if spew.Sprint(avl.GetAround(40)) != "[40 40 30]" {
t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(50))) t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(40)))
} }
if spew.Sprint(avl.GetAround(50)) != "[<nil> 50 40]" { if spew.Sprint(avl.GetAround(50)) != "[<nil> 50 40]" {
@ -241,7 +241,7 @@ ALL:
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
avl.Remove(l[i]) avl.Remove(l[i])
gods.Remove(l[i]) gods.Remove(l[i])
s1 := spew.Sprint(avl.TraversalDepth(-1)) s1 := spew.Sprint(avl.Values())
s2 := spew.Sprint(gods.Values()) s2 := spew.Sprint(gods.Values())
if s1 != s2 { if s1 != s2 {
t.Error("avl remove error", "avlsize = ", avl.Size()) t.Error("avl remove error", "avlsize = ", avl.Size())
@ -251,6 +251,7 @@ ALL:
} }
} }
} }
} }
func TestRemove(t *testing.T) { func TestRemove(t *testing.T) {
@ -279,7 +280,7 @@ ALL:
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
avl.Remove(l[i]) avl.Remove(l[i])
gods.Remove(l[i]) gods.Remove(l[i])
if spew.Sprint(gods.Values()) != spew.Sprint(avl.TraversalDepth(-1)) && avl.size != 0 { if spew.Sprint(gods.Values()) != spew.Sprint(avl.Values()) && avl.size != 0 {
// if gods.String() != avl.String() && gods.Size() != 0 && avl.size != 0 { // if gods.String() != avl.String() && gods.Size() != 0 && avl.size != 0 {
t.Error(src1) t.Error(src1)
t.Error(src2) t.Error(src2)
@ -316,27 +317,6 @@ func BenchmarkIterator(b *testing.B) {
} }
func BenchmarkGodsIterator(b *testing.B) {
tree := avltree.NewWithIntComparator()
l := loadTestData()
b.N = len(l)
for _, v := range l {
tree.Put(v, v)
}
b.ResetTimer()
b.StartTimer()
iter := tree.Iterator()
for iter.Next() {
}
for iter.Prev() {
}
for iter.Next() {
}
}
func BenchmarkRemove(b *testing.B) { func BenchmarkRemove(b *testing.B) {
tree := New(utils.IntComparator) tree := New(utils.IntComparator)

View File

@ -20,8 +20,10 @@ func initIterator(avltree *Tree) *Iterator {
return iter return iter
} }
func NewIterator(tree *Tree) *Iterator { func NewIterator(n *Node) *Iterator {
return initIterator(tree) iter := &Iterator{tstack: lastack.New()}
iter.up = n
return iter
} }
func (iter *Iterator) Value() interface{} { func (iter *Iterator) Value() interface{} {

View File

@ -36,8 +36,10 @@ func New(comparator utils.Comparator) *Tree {
func (avl *Tree) String() string { func (avl *Tree) String() string {
str := "AVLTree\n" str := "AVLTree\n"
if avl.root == nil {
return str + "nil"
}
output(avl.root, "", true, &str) output(avl.root, "", true, &str)
return str return str
} }
@ -55,7 +57,7 @@ func (avl *Tree) Size() int {
func (avl *Tree) Remove(key interface{}) *Node { func (avl *Tree) Remove(key interface{}) *Node {
if n, ok := avl.GetNode(key); ok { if n, ok := avl.GetNode(key); ok {
if avl.root == n { if avl.root.size == 1 {
avl.root = nil avl.root = nil
return n return n
} }
@ -112,6 +114,19 @@ func (avl *Tree) Remove(key interface{}) *Node {
return nil return nil
} }
// Values 返回先序遍历的值
func (avl *Tree) Values() []interface{} {
mszie := 0
if avl.root != nil {
mszie = avl.root.size
}
result := make([]interface{}, 0, mszie)
avl.Traversal(func(v interface{}) {
result = append(result, v)
}, LDR)
return result
}
func (avl *Tree) Get(key interface{}) (interface{}, bool) { func (avl *Tree) Get(key interface{}) (interface{}, bool) {
n, ok := avl.GetNode(key) n, ok := avl.GetNode(key)
if ok { if ok {
@ -121,6 +136,35 @@ func (avl *Tree) Get(key interface{}) (interface{}, bool) {
} }
func (avl *Tree) GetRange(min, max interface{}) []interface{} { func (avl *Tree) GetRange(min, max interface{}) []interface{} {
var minN *Node
for minN = avl.root; minN != nil; {
switch c := avl.comparator(min, minN.value); c {
case -1:
minN = minN.children[0]
case 1:
minN = minN.children[1]
case 0:
break
default:
panic("Get comparator only is allowed in -1, 0, 1")
}
}
var maxN *Node
for maxN = avl.root; maxN != nil; {
switch c := avl.comparator(min, maxN.value); c {
case -1:
maxN = maxN.children[0]
case 1:
maxN = maxN.children[1]
case 0:
break
default:
panic("Get comparator only is allowed in -1, 0, 1")
}
}
return nil return nil
} }
@ -135,57 +179,36 @@ func (avl *Tree) GetAround(key interface{}) (result [3]interface{}) {
} }
func (avl *Tree) GetAroundNode(value interface{}) (result [3]*Node) { func (avl *Tree) GetAroundNode(value interface{}) (result [3]*Node) {
n := avl.root
for { if cur, ok := avl.GetNode(value); ok {
if n == nil { var iter *Iterator
return
iter = NewIterator(cur)
iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up)
if v, ok := iter.tstack.Pop(); ok {
result[0] = v.(*Node)
// iter.curPushPrevStack(iter.cur)
} else {
result[0] = iter.up
} }
lastc := 0 iter = NewIterator(cur)
switch c := avl.comparator(value, n.value); c { iter.curPushNextStack(iter.up)
case -1: iter.up = iter.getNextUp(iter.up)
if c != -lastc {
result[0] = n
}
lastc = c
n = n.children[0]
case 1:
if c != -lastc {
result[2] = n
}
lastc = c
n = n.children[1]
case 0:
switch lastc { if v, ok := iter.tstack.Pop(); ok {
case -1: result[2] = v.(*Node)
if n.children[1] != nil { // iter.curPushNextStack(iter.cur)
result[0] = n.children[1] } else {
} result[2] = iter.up
case 1:
if n.children[0] != nil {
result[2] = n.children[0]
}
case 0:
if n.children[1] != nil {
result[0] = n.children[1]
}
if n.children[0] != nil {
result[2] = n.children[0]
}
result[1] = n
return
}
default:
panic("Get comparator only is allowed in -1, 0, 1")
} }
result[1] = cur
} }
return
} }
func (avl *Tree) GetNode(value interface{}) (*Node, bool) { func (avl *Tree) GetNode(value interface{}) (*Node, bool) {
@ -246,61 +269,107 @@ func (avl *Tree) Put(value interface{}) {
} }
} }
func (avl *Tree) debugString() string { type TraversalMethod int
str := "AVLTree\n"
outputfordebug(avl.root, "", true, &str)
return str
}
func (avl *Tree) TraversalBreadth() (result []interface{}) { const (
// L = left R = right D = Value(dest)
_ TraversalMethod = iota
//DLR 先值 然后左递归 右递归 下面同理
DLR
//LDR 先从左边有序访问到右边 从小到大
LDR
// LRD 同理
LRD
// DRL 同理
DRL
// RDL 先从右边有序访问到左边 从大到小
RDL
// RLD 同理
RLD
)
// Traversal 遍历的方法
func (avl *Tree) Traversal(every func(v interface{}), traversalMethod ...interface{}) {
if avl.root == nil { if avl.root == nil {
return return
} }
result = make([]interface{}, 0, avl.root.size) method := DLR
var traverasl func(cur *Node) if len(traversalMethod) != 0 {
traverasl = func(cur *Node) { method = traversalMethod[0].(TraversalMethod)
if cur == nil { }
return
switch method {
case DLR:
var traverasl func(cur *Node)
traverasl = func(cur *Node) {
if cur == nil {
return
}
every(cur.value)
traverasl(cur.children[0])
traverasl(cur.children[1])
} }
result = append(result, cur.value) traverasl(avl.root)
traverasl(cur.children[0]) case LDR:
traverasl(cur.children[1])
}
traverasl(avl.root)
return
}
func (avl *Tree) TraversalDepth(leftright int) (result []interface{}) {
if avl.root == nil {
return
}
result = make([]interface{}, 0, avl.root.size)
if leftright < 0 {
var traverasl func(cur *Node) var traverasl func(cur *Node)
traverasl = func(cur *Node) { traverasl = func(cur *Node) {
if cur == nil { if cur == nil {
return return
} }
traverasl(cur.children[0]) traverasl(cur.children[0])
result = append(result, cur.value) every(cur.value)
traverasl(cur.children[1]) traverasl(cur.children[1])
} }
traverasl(avl.root) traverasl(avl.root)
} else { case LRD:
var traverasl func(cur *Node)
traverasl = func(cur *Node) {
if cur == nil {
return
}
traverasl(cur.children[0])
traverasl(cur.children[1])
every(cur.value)
}
traverasl(avl.root)
case DRL:
var traverasl func(cur *Node)
traverasl = func(cur *Node) {
if cur == nil {
return
}
every(cur.value)
traverasl(cur.children[0])
traverasl(cur.children[1])
}
traverasl(avl.root)
case RDL:
var traverasl func(cur *Node) var traverasl func(cur *Node)
traverasl = func(cur *Node) { traverasl = func(cur *Node) {
if cur == nil { if cur == nil {
return return
} }
traverasl(cur.children[1]) traverasl(cur.children[1])
result = append(result, cur.value) every(cur.value)
traverasl(cur.children[0]) traverasl(cur.children[0])
} }
traverasl(avl.root) traverasl(avl.root)
case RLD:
var traverasl func(cur *Node)
traverasl = func(cur *Node) {
if cur == nil {
return
}
traverasl(cur.children[1])
traverasl(cur.children[0])
every(cur.value)
}
traverasl(avl.root)
} }
return return
} }
@ -503,11 +572,6 @@ func (avl *Tree) fixRemoveHeight(cur *Node) {
} }
} }
// func abs(n int) int {
// y := n >> 31
// return (n ^ y) - y
// }
func (avl *Tree) fixPutHeight(cur *Node, lefts, rigths int) { func (avl *Tree) fixPutHeight(cur *Node, lefts, rigths int) {
if lefts > rigths { if lefts > rigths {
l := cur.children[0] l := cur.children[0]
@ -598,3 +662,12 @@ func outputfordebug(node *Node, prefix string, isTail bool, str *string) {
outputfordebug(node.children[0], newPrefix, true, str) outputfordebug(node.children[0], newPrefix, true, str)
} }
} }
func (avl *Tree) debugString() string {
str := "AVLTree\n"
if avl.root == nil {
return str + "nil"
}
outputfordebug(avl.root, "", true, &str)
return str
}

View File

@ -62,109 +62,70 @@ func loadTestData() []int {
return l return l
} }
// func TestIterator(t *testing.T) { func TestIterator(t *testing.T) {
// avl := New(utils.IntComparator) avl := New(utils.IntComparator)
// for _, v := range []int{1, 2, 7, 4, 5, 6, 7, 14, 15, 20, 30, 21, 3} { for _, v := range []int{7, 14, 14, 14, 16, 17, 20, 30, 21, 40, 40, 50, 3, 40, 40, 40, 15} {
// // t.Error(v) avl.Put(v)
// avl.Put(v) }
t.Error(avl.Values())
t.Error(avl.debugString())
}
// } func TestGetAround(t *testing.T) {
// // ` AVLTree avl := New(utils.IntComparator)
// // │ ┌── 30 for _, v := range []int{7, 14, 14, 14, 16, 17, 20, 30, 21, 40, 50, 3, 40, 40, 40, 15} {
// // │ │ └── 21 avl.Put(v)
// // │ ┌── 20 }
// // │ │ └── 15
// // └── 14
// // │ ┌── 7
// // │ ┌── 7
// // │ │ └── 6
// // └── 5
// // │ ┌── 4
// // │ │ └── 3
// // └── 2
// // └── 1`
// iter := avl.Iterator() // root start point var Result string
// l := []int{14, 15, 20, 21, 30} Result = spew.Sprint(avl.GetAround(3))
if Result != "[7 3 <nil>]" {
t.Error("avl.GetAround(3)) is error", Result)
}
// for i := 0; iter.Prev(); i++ { Result = spew.Sprint(avl.GetAround(40))
// if iter.Value().(int) != l[i] { if Result != "[40 40 30]" {
// t.Error("iter prev error", iter.Value(), l[i]) t.Error("avl.GetAround(40)) is error", Result)
// } }
// }
// iter.Prev() Result = spew.Sprint(avl.GetAround(50))
// if iter.Value().(int) != 30 { if Result != "[<nil> 50 40]" {
// t.Error("prev == false", iter.Value(), iter.Prev(), iter.Value()) t.Error("avl.GetAround(50)) is error", Result)
// } }
}
// l = []int{21, 20, 15, 14, 7, 7, 6, 5, 4, 3, 2, 1}
// for i := 0; iter.Next(); i++ { // cur is 30 next is 21
// if iter.Value().(int) != l[i] {
// t.Error(iter.Value())
// }
// }
// if iter.Next() != false {
// t.Error("Next is error, cur is tail, val = 1 Next return false")
// }
// if iter.Value().(int) != 1 { // cur is 1
// t.Error("next == false", iter.Value(), iter.Next(), iter.Value())
// }
// if iter.Prev() != true && iter.Value().(int) != 2 {
// t.Error("next to prev is error")
// }
// }
// func TestGetAround(t *testing.T) {
// avl := New(utils.IntComparator)
// for _, v := range []int{7, 14, 15, 20, 30, 21, 40, 40, 50, 3, 40, 40, 40} {
// avl.Put(v)
// }
// if spew.Sprint(avl.GetAround(30)) != "[40 30 21]" {
// t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(30)))
// }
// if spew.Sprint(avl.GetAround(40)) != "[40 40 30]" {
// t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(50)))
// }
// if spew.Sprint(avl.GetAround(50)) != "[<nil> 50 40]" {
// t.Error("avl.GetAround(40)) is error", spew.Sprint(avl.GetAround(50)))
// }
// }
// // for test error case // // for test error case
// func TestPutComparatorRandom(t *testing.T) { func TestPutComparatorRandom(t *testing.T) {
// for n := 0; n < 300000; n++ { for n := 0; n < 300000; n++ {
// avl := New(utils.IntComparator) avl := New(utils.IntComparator)
// godsavl := avltree.NewWithIntComparator() godsavl := avltree.NewWithIntComparator()
// content := "" content := ""
// m := make(map[int]int) m := make(map[int]int)
// for i := 0; len(m) < 10; i++ { for i := 0; len(m) < 10; i++ {
// v := randomdata.Number(0, 65535) v := randomdata.Number(0, 65535)
// if _, ok := m[v]; !ok { if _, ok := m[v]; !ok {
// m[v] = v m[v] = v
// content += spew.Sprint(v) + "," content += spew.Sprint(v) + ","
// avl.Put(v) avl.Put(v)
// godsavl.Put(v, v) godsavl.Put(v, v)
// } }
// } }
// if avl.String() != godsavl.String() { s1 := spew.Sprint(avl.Values())
// t.Error(godsavl.String()) s2 := spew.Sprint(godsavl.Values())
// t.Error(avl.debugString())
// t.Error(content, n) if s1 != s2 {
// break t.Error(godsavl.String())
// } t.Error(avl.debugString())
// } t.Error(content, n)
// } break
}
}
}
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
avl := New(utils.IntComparator) avl := New(utils.IntComparator)
@ -186,7 +147,6 @@ func TestGet(t *testing.T) {
} }
func TestRemoveAll(t *testing.T) { func TestRemoveAll(t *testing.T) {
ALL: ALL:
for c := 0; c < 5000; c++ { for c := 0; c < 5000; c++ {
avl := New(utils.IntComparator) avl := New(utils.IntComparator)
@ -194,7 +154,7 @@ ALL:
var l []int var l []int
m := make(map[int]int) m := make(map[int]int)
for i := 0; len(l) < 100; i++ { for i := 0; len(l) < 10; i++ {
v := randomdata.Number(0, 100000) v := randomdata.Number(0, 100000)
if _, ok := m[v]; !ok { if _, ok := m[v]; !ok {
m[v] = v m[v] = v
@ -204,13 +164,16 @@ ALL:
} }
} }
for i := 0; i < 100; i++ { for i := 0; i < 10; i++ {
avl.Remove(l[i]) avl.Remove(l[i])
gods.Remove(l[i]) gods.Remove(l[i])
s1 := spew.Sprint(avl.TraversalDepth(-1))
s1 := spew.Sprint(avl.Values())
s2 := spew.Sprint(gods.Values()) s2 := spew.Sprint(gods.Values())
if s1 != s2 { if s1 != s2 {
t.Error("avl remove error", "avlsize = ", avl.Size()) t.Error("avl remove error", "avlsize = ", avl.Size())
t.Error(avl.root, i, l[i])
t.Error(s1) t.Error(s1)
t.Error(s2) t.Error(s2)
break ALL break ALL
@ -245,7 +208,7 @@ ALL:
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
avl.Remove(l[i]) avl.Remove(l[i])
gods.Remove(l[i]) gods.Remove(l[i])
if avl.root != nil && spew.Sprint(gods.Values()) != spew.Sprint(avl.TraversalDepth(-1)) { if avl.root != nil && spew.Sprint(gods.Values()) != spew.Sprint(avl.Values()) {
// if gods.String() != avl.String() && gods.Size() != 0 && avl.size != 0 { // if gods.String() != avl.String() && gods.Size() != 0 && avl.size != 0 {
t.Error(src1) t.Error(src1)
t.Error(src2) t.Error(src2)
@ -260,48 +223,26 @@ ALL:
} }
} }
// func BenchmarkIterator(b *testing.B) { func BenchmarkIterator(b *testing.B) {
// tree := New(utils.IntComparator) tree := New(utils.IntComparator)
// l := loadTestData() l := loadTestData()
// b.N = len(l) b.N = len(l)
// for _, v := range l { for _, v := range l {
// tree.Put(v) tree.Put(v)
// } }
// b.ResetTimer() b.ResetTimer()
// b.StartTimer() b.StartTimer()
// iter := tree.Iterator() iter := tree.Iterator()
// for iter.Next() { for iter.Next() {
// } }
// for iter.Prev() { for iter.Prev() {
// } }
// for iter.Next() { for iter.Next() {
// } }
}
// }
// func BenchmarkGodsIterator(b *testing.B) {
// tree := avltree.NewWithIntComparator()
// l := loadTestData()
// b.N = len(l)
// for _, v := range l {
// tree.Put(v, v)
// }
// b.ResetTimer()
// b.StartTimer()
// iter := tree.Iterator()
// for iter.Next() {
// }
// for iter.Prev() {
// }
// for iter.Next() {
// }
// }
func BenchmarkRemove(b *testing.B) { func BenchmarkRemove(b *testing.B) {
tree := New(utils.IntComparator) tree := New(utils.IntComparator)

View File

@ -5,8 +5,6 @@ import (
) )
type Iterator struct { type Iterator struct {
op *Tree
dir int dir int
up *Node up *Node
cur *Node cur *Node
@ -15,13 +13,15 @@ type Iterator struct {
} }
func initIterator(avltree *Tree) *Iterator { func initIterator(avltree *Tree) *Iterator {
iter := &Iterator{op: avltree, tstack: lastack.New()} iter := &Iterator{tstack: lastack.New()}
iter.up = avltree.root iter.up = avltree.root
return iter return iter
} }
func NewIterator(tree *Tree) *Iterator { func NewIterator(n *Node) *Iterator {
return initIterator(tree) iter := &Iterator{tstack: lastack.New()}
iter.up = n
return iter
} }
func (iter *Iterator) Value() interface{} { func (iter *Iterator) Value() interface{} {
@ -46,6 +46,36 @@ func (iter *Iterator) Right() bool {
return false return false
} }
func GetPrev(cur *Node, idx int) *Node {
iter := NewIterator(cur)
iter.curPushPrevStack(iter.up)
iter.up = iter.getPrevUp(iter.up)
for i := 0; i < idx; i++ {
if iter.tstack.Size() == 0 {
if iter.up == nil {
return nil
}
iter.tstack.Push(iter.up)
iter.up = iter.getPrevUp(iter.up)
}
if v, ok := iter.tstack.Pop(); ok {
iter.cur = v.(*Node)
if i == idx-1 {
return iter.cur
}
iter.curPushPrevStack(iter.cur)
} else {
return nil
}
}
return cur
}
func (iter *Iterator) Prev() (result bool) { func (iter *Iterator) Prev() (result bool) {
if iter.dir > -1 { if iter.dir > -1 {
@ -73,6 +103,35 @@ func (iter *Iterator) Prev() (result bool) {
return false return false
} }
func GetNext(cur *Node, idx int) *Node {
iter := NewIterator(cur)
iter.curPushNextStack(iter.up)
iter.up = iter.getNextUp(iter.up)
for i := 0; i < idx; i++ {
if iter.tstack.Size() == 0 {
if iter.up == nil {
return nil
}
iter.tstack.Push(iter.up)
iter.up = iter.getNextUp(iter.up)
}
if v, ok := iter.tstack.Pop(); ok {
iter.cur = v.(*Node)
if i == idx-1 {
return iter.cur
}
iter.curPushNextStack(iter.cur)
} else {
return nil
}
}
return cur
}
func (iter *Iterator) Next() (result bool) { func (iter *Iterator) Next() (result bool) {