diff --git a/avlindex/avlindex.go b/avlindex/avlindex.go index d768323..9759641 100644 --- a/avlindex/avlindex.go +++ b/avlindex/avlindex.go @@ -1,6 +1,8 @@ package avlindex import ( + "log" + "github.com/davecgh/go-spew/spew" "github.com/emirpasic/gods/utils" @@ -54,13 +56,13 @@ func (avl *Tree) Size() int { return avl.root.size } -func (avl *Tree) Index(idx int) (interface{}, bool) { +func (avl *Tree) indexNode(idx int) *Node { cur := avl.root if idx >= 0 { for cur != nil { ls := getSize(cur.children[0]) if idx == ls { - return cur.value, true + return cur } else if idx < ls { cur = cur.children[0] } else { @@ -73,7 +75,7 @@ func (avl *Tree) Index(idx int) (interface{}, bool) { for cur != nil { rs := getSize(cur.children[1]) if idx == rs { - return cur.value, true + return cur } else if idx < rs { cur = cur.children[1] } else { @@ -82,67 +84,92 @@ func (avl *Tree) Index(idx int) (interface{}, bool) { } } } + return nil +} + +func (avl *Tree) Index(idx int) (interface{}, bool) { + n := avl.indexNode(idx) + if n != nil { + return n.value, true + } return nil, false } -func (avl *Tree) Remove(key interface{}) *Node { +func (avl *Tree) RemoveIndex(idx int) bool { + n := avl.indexNode(idx) + if n != nil { + avl.removeNode(n) + return true + } + return false +} - if n, ok := avl.GetNode(key); ok { - if avl.root.size == 1 { - avl.root = nil - return n - } - - ls, rs := getChildrenSize(n) - if ls == 0 && rs == 0 { - p := n.parent - p.children[getRelationship(n)] = nil - avl.fixRemoveHeight(p) - return n - } - - var cur *Node - if ls > rs { - cur = n.children[0] - for cur.children[1] != nil { - cur = cur.children[1] - } - - cleft := cur.children[0] - cur.parent.children[getRelationship(cur)] = cleft - if cleft != nil { - cleft.parent = cur.parent - } - - } else { - cur = n.children[1] - for cur.children[0] != nil { - cur = cur.children[0] - } - - cright := cur.children[1] - cur.parent.children[getRelationship(cur)] = cright - - if cright != nil { - cright.parent = cur.parent - } - } - - cparent := cur.parent - // 修改为interface 交换 - n.value, cur.value = cur.value, n.value - - // 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度 - if cparent == n { - avl.fixRemoveHeight(n) - } else { - avl.fixRemoveHeight(cparent) - } - - return cur +func (avl *Tree) removeNode(n *Node) { + if avl.root.size == 1 { + avl.root = nil + // return n + return } - return nil + ls, rs := getChildrenSize(n) + if ls == 0 && rs == 0 { + p := n.parent + p.children[getRelationship(n)] = nil + avl.fixRemoveHeight(p) + // return n + return + } + + var cur *Node + if ls > rs { + cur = n.children[0] + for cur.children[1] != nil { + cur = cur.children[1] + } + + cleft := cur.children[0] + cur.parent.children[getRelationship(cur)] = cleft + if cleft != nil { + cleft.parent = cur.parent + } + + } else { + cur = n.children[1] + for cur.children[0] != nil { + cur = cur.children[0] + } + + cright := cur.children[1] + cur.parent.children[getRelationship(cur)] = cright + + if cright != nil { + cright.parent = cur.parent + } + } + + cparent := cur.parent + // 修改为interface 交换 + n.value, cur.value = cur.value, n.value + + // 考虑到刚好替换的节点是 被替换节点的孩子节点的时候, 从自身修复高度 + if cparent == n { + avl.fixRemoveHeight(n) + } else { + avl.fixRemoveHeight(cparent) + } + + // return cur + return +} + +func (avl *Tree) Remove(key interface{}) bool { + + if n, ok := avl.GetNode(key); ok { + avl.removeNode(n) + return true + } + // return nil + return false } // Values 返回先序遍历的值 @@ -152,12 +179,71 @@ func (avl *Tree) Values() []interface{} { mszie = avl.root.size } result := make([]interface{}, 0, mszie) - avl.Traversal(func(v interface{}) { + avl.Traversal(func(v interface{}) bool { result = append(result, v) + return true }, LDR) return result } +func (avl *Tree) GetRange(idx1, idx2 int) (result []interface{}, ok bool) { // 0 -1 + + if idx1^idx2 < 0 { + if idx1 < 0 { + idx1 = avl.root.size + idx1 - 1 + } else { + idx2 = avl.root.size + idx2 - 1 + } + } + + if idx1 > idx2 { + ok = true + if idx1 >= avl.root.size { + idx1 = avl.root.size - 1 + ok = false + } + + n := avl.indexNode(idx1) + iter := NewIterator(n) + result = make([]interface{}, 0, idx1-idx2) + for i := idx2; i <= idx1; i++ { + if iter.Next() { + result = append(result, iter.Value()) + } else { + ok = false + return + } + } + + return + + } else { + ok = true + if idx2 >= avl.root.size { + idx2 = avl.root.size - 1 + ok = false + } + + if n := avl.indexNode(idx1); n != nil { + iter := NewIterator(n) + result = make([]interface{}, 0, idx2-idx1) + for i := idx1; i <= idx2; i++ { + if iter.Prev() { + result = append(result, iter.Value()) + } else { + ok = false + return + } + } + + return + } + + } + + return nil, false +} + func (avl *Tree) Get(key interface{}) (interface{}, bool) { n, ok := avl.GetNode(key) if ok { @@ -290,85 +376,127 @@ const ( ) // Traversal 遍历的方法 -func (avl *Tree) Traversal(every func(v interface{}), traversalMethod ...interface{}) { +func (avl *Tree) Traversal(every func(v interface{}) bool, traversalMethod ...interface{}) { if avl.root == nil { return } - method := DLR + method := LDR if len(traversalMethod) != 0 { method = traversalMethod[0].(TraversalMethod) } switch method { case DLR: - var traverasl func(cur *Node) - traverasl = func(cur *Node) { + var traverasl func(cur *Node) bool + traverasl = func(cur *Node) bool { if cur == nil { - return + return true } - every(cur.value) - traverasl(cur.children[0]) - traverasl(cur.children[1]) + if !every(cur.value) { + return false + } + if !traverasl(cur.children[0]) { + return false + } + if !traverasl(cur.children[1]) { + return false + } + return true } traverasl(avl.root) case LDR: - var traverasl func(cur *Node) - traverasl = func(cur *Node) { + var traverasl func(cur *Node) bool + traverasl = func(cur *Node) bool { if cur == nil { - return + return true } - traverasl(cur.children[0]) - every(cur.value) - traverasl(cur.children[1]) + if !traverasl(cur.children[0]) { + log.Println(cur) + return false + } + if !every(cur.value) { + return false + } + if !traverasl(cur.children[1]) { + return false + } + return true } traverasl(avl.root) case LRD: - var traverasl func(cur *Node) - traverasl = func(cur *Node) { + var traverasl func(cur *Node) bool + traverasl = func(cur *Node) bool { if cur == nil { - return + return true } - traverasl(cur.children[0]) - traverasl(cur.children[1]) - every(cur.value) + if !traverasl(cur.children[0]) { + return false + } + if !traverasl(cur.children[1]) { + return false + } + if !every(cur.value) { + return false + } + return true } traverasl(avl.root) case DRL: - var traverasl func(cur *Node) - traverasl = func(cur *Node) { + var traverasl func(cur *Node) bool + traverasl = func(cur *Node) bool { if cur == nil { - return + return true } - every(cur.value) - traverasl(cur.children[0]) - traverasl(cur.children[1]) + if !every(cur.value) { + return false + } + if !traverasl(cur.children[0]) { + return false + } + if !traverasl(cur.children[1]) { + return false + } + return true } traverasl(avl.root) case RDL: - var traverasl func(cur *Node) - traverasl = func(cur *Node) { + var traverasl func(cur *Node) bool + traverasl = func(cur *Node) bool { if cur == nil { - return + return true } - traverasl(cur.children[1]) - every(cur.value) - traverasl(cur.children[0]) + if !traverasl(cur.children[1]) { + return false + } + if !every(cur.value) { + return false + } + if !traverasl(cur.children[0]) { + return false + } + return true } traverasl(avl.root) case RLD: - var traverasl func(cur *Node) - traverasl = func(cur *Node) { + var traverasl func(cur *Node) bool + traverasl = func(cur *Node) bool { if cur == nil { - return + return true } - traverasl(cur.children[1]) - traverasl(cur.children[0]) - every(cur.value) + if !traverasl(cur.children[1]) { + return false + } + if !traverasl(cur.children[0]) { + return false + } + if !every(cur.value) { + return false + } + return true } traverasl(avl.root) } - return } func (avl *Tree) lrrotate(cur *Node) { diff --git a/avlindex/avlindex_test.go b/avlindex/avlindex_test.go index 5cf61a4..8fa7df6 100644 --- a/avlindex/avlindex_test.go +++ b/avlindex/avlindex_test.go @@ -15,10 +15,10 @@ import ( "github.com/emirpasic/gods/utils" ) -const CompartorSize = 1000000 +const CompartorSize = 100000 const NumberMax = 50000000 -func Save(t *testing.T) { +func TestSave(t *testing.T) { f, err := os.OpenFile("../l.log", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) if err != nil { @@ -71,6 +71,67 @@ func TestIterator(t *testing.T) { // t.Error(avl.debugString()) } +func TestGetRange(t *testing.T) { + tree := New(utils.IntComparator) + l := []int{7, 14, 14, 14, 16, 17, 20, 30, 21, 40, 50, 3, 40, 40, 40, 15} + for _, v := range l { + tree.Put(v) + } + // [3 7 14 14 14 15 16 17 20 21 30 40 40 40 40 50] + // t.Error(tree.Values(), tree.Size()) + + var result string + result = spew.Sprint(tree.GetRange(0, 5)) + if result != "[3 7 14 14 14 15] true" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(2, 5)) + if result != "[14 14 14 15] true" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(10, 100)) + if result != "[30 40 40 40 40 50] false" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(15, 0)) // size = 16, index max = 15 + if result != "[50 40 40 40 40 30 21 20 17 16 15 14 14 14 7 3] true" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(16, 0)) // size = 16, index max = 15 + if result != "[50 40 40 40 40 30 21 20 17 16 15 14 14 14 7 3] false" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(5, 1)) // size = 16, index max = 15 + if result != "[15 14 14 14 7] true" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(-1, -5)) // size = 16, index max = 15 + if result != "[50 40 40 40 40] true" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(-1, -16)) // size = 16, index max = 0 - 15 (-1,-16) + if result != "[50 40 40 40 40 30 21 20 17 16 15 14 14 14 7 3] true" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(-1, -17)) // size = 16, index max = 0 - 15 (-1,-16) + if result != "[50 40 40 40 40 30 21 20 17 16 15 14 14 14 7 3] false" { + t.Error(result) + } + + result = spew.Sprint(tree.GetRange(-5, -1)) // size = 16, index max = 0 - 15 (-1,-16) + if result != "[40 40 40 40 50] true" { + t.Error(result) + } +} + func TestGetAround(t *testing.T) { avl := New(utils.IntComparator) for _, v := range []int{7, 14, 14, 14, 16, 17, 20, 30, 21, 40, 50, 3, 40, 40, 40, 15} { @@ -143,7 +204,28 @@ func TestGet(t *testing.T) { if v, ok := avl.Get(10000); ok { t.Error("the val(1000) is not in tree, but is found", v) } +} +func TestTravalsal(t *testing.T) { + tree := New(utils.IntComparator) + + l := loadTestData() + N := len(l) + for i := 0; i < N; i++ { + tree.Put(l[i]) + } + + i := 0 + var result []interface{} + tree.Traversal(func(v interface{}) bool { + result = append(result, v) + i++ + if i >= 10 { + return false + } + return true + }) + t.Error(result) } func TestRemoveAll(t *testing.T) { @@ -236,12 +318,17 @@ func BenchmarkIterator(b *testing.B) { b.ResetTimer() b.StartTimer() iter := tree.Iterator() + b.N = 0 for iter.Next() { + b.N++ } for iter.Prev() { + b.N++ } for iter.Next() { + b.N++ } + b.Log(b.N, len(l)) } func BenchmarkRemove(b *testing.B) { @@ -280,63 +367,90 @@ func BenchmarkGodsRemove(b *testing.B) { } } -// func BenchmarkGodsRBRemove(b *testing.B) { -// tree := redblacktree.NewWithIntComparator() - -// l := loadTestData() - -// b.N = len(l) -// for _, v := range l { -// tree.Put(v, v) -// } - -// b.ResetTimer() -// b.StartTimer() - -// for i := 0; i < len(l); i++ { -// tree.Remove(l[i]) -// } -// } - -func BenchmarkGet(b *testing.B) { - - avl := New(utils.IntComparator) +func BenchmarkGodsRBRemove(b *testing.B) { + tree := redblacktree.NewWithIntComparator() l := loadTestData() + b.N = len(l) + for _, v := range l { + tree.Put(v, v) + } b.ResetTimer() b.StartTimer() - for i := 0; i < b.N; i++ { - avl.Get(l[i]) + + for i := 0; i < len(l); i++ { + tree.Remove(l[i]) } } -// func BenchmarkGodsRBGet(b *testing.B) { -// tree := redblacktree.NewWithIntComparator() +func BenchmarkGet(b *testing.B) { -// l := loadTestData() -// b.N = len(l) + tree := New(utils.IntComparator) -// b.ResetTimer() -// b.StartTimer() -// for i := 0; i < b.N; i++ { -// tree.Get(l[i]) -// } -// } + l := loadTestData() + b.N = len(l) + for i := 0; i < b.N; i++ { + tree.Put(l[i]) + } -// func BenchmarkGodsAvlGet(b *testing.B) { -// tree := avltree.NewWithIntComparator() + b.ResetTimer() + b.StartTimer() -// l := loadTestData() -// b.N = len(l) + execCount := 50 + b.N = len(l) * execCount -// b.ResetTimer() -// b.StartTimer() -// for i := 0; i < b.N; i++ { -// tree.Get(l[i]) -// } -// } + for i := 0; i < execCount; i++ { + for _, v := range l { + tree.Get(v) + } + } +} + +func BenchmarkGodsRBGet(b *testing.B) { + tree := redblacktree.NewWithIntComparator() + + l := loadTestData() + b.N = len(l) + for i := 0; i < b.N; i++ { + tree.Put(l[i], i) + } + + b.ResetTimer() + b.StartTimer() + + execCount := 50 + b.N = len(l) * execCount + + for i := 0; i < execCount; i++ { + for _, v := range l { + tree.Get(v) + } + } +} + +func BenchmarkGodsAvlGet(b *testing.B) { + tree := avltree.NewWithIntComparator() + + l := loadTestData() + b.N = len(l) + for i := 0; i < b.N; i++ { + tree.Put(l[i], i) + } + + b.ResetTimer() + b.StartTimer() + + execCount := 50 + b.N = len(l) * execCount + + for i := 0; i < execCount; i++ { + for _, v := range l { + tree.Get(v) + } + } +} func BenchmarkPut(b *testing.B) { l := loadTestData() @@ -344,7 +458,7 @@ func BenchmarkPut(b *testing.B) { b.ResetTimer() b.StartTimer() - execCount := 50 + execCount := 500 b.N = len(l) * execCount for i := 0; i < execCount; i++ { avl := New(utils.IntComparator) @@ -357,23 +471,88 @@ func BenchmarkPut(b *testing.B) { func TestPutStable(t *testing.T) { // l := []int{14, 18, 20, 21, 22, 23, 19} - // var l []int - // for i := 0; len(l) < 10; i++ { - // l = append(l, randomdata.Number(0, 65)) - // } + var l []int + for i := 0; len(l) < 10; i++ { + l = append(l, randomdata.Number(0, 65)) + } - // avl := New(utils.IntComparator) - // for _, v := range l { - // avl.Put(v) - // t.Error(avl.debugString(), v) - // } - // t.Error(avl.Values()) - // for _, v := range []int{10, 0, 9, 5, -11, -10, -1, -5} { - // t.Error(avl.Index(v)) - // } + avl := New(utils.IntComparator) + for _, v := range l { + avl.Put(v) + t.Error(avl.debugString(), v) + } + t.Error(avl.Values()) + for _, v := range []int{10, 0, 9, 5, -11, -10, -1, -5} { + t.Error(avl.Index(v)) + } + + avl.RemoveIndex(4) + t.Error(avl.Index(4)) + t.Error(avl.Values()) + t.Error(avl.debugString()) // t.Error(len(l), avl.debugString(), "\n", "-----------") // 3 6(4) } + +func BenchmarkIndex(b *testing.B) { + tree := New(utils.IntComparator) + + l := loadTestData() + b.N = len(l) + for i := 0; i < b.N; i++ { + tree.Put(l[i]) + } + + b.ResetTimer() + b.StartTimer() + + b.N = 1000000 + + var result [50]interface{} + for n := 0; n < b.N; n++ { + i := 0 + tree.Traversal(func(v interface{}) bool { + result[i] = v + i++ + if i < 50 { + return true + } + log.Print(i) + return false + }) + } +} + +func BenchmarkTraversal(b *testing.B) { + tree := New(utils.IntComparator) + + l := loadTestData() + b.N = len(l) + for i := 0; i < b.N; i++ { + tree.Put(l[i]) + } + + b.ResetTimer() + b.StartTimer() + + execCount := 50 + b.N = len(l) * execCount + + for n := 0; n < execCount; n++ { + i := 0 + var result []interface{} + tree.Traversal(func(v interface{}) bool { + result = append(result, v) + i++ + if i >= 50 { + return false + } + return true + }) + + } +} + func BenchmarkGodsRBPut(b *testing.B) { tree := redblacktree.NewWithIntComparator() diff --git a/for_test.go b/for_test.go new file mode 100644 index 0000000..a8c001c --- /dev/null +++ b/for_test.go @@ -0,0 +1,59 @@ +package structure + +import ( + "bytes" + "encoding/gob" + "io/ioutil" + "log" + "os" + "testing" + + randomdata "github.com/Pallinder/go-randomdata" +) + +const CompartorSize = 100 +const NumberMax = 50000000 + +func TestSave(t *testing.T) { + + f, err := os.OpenFile("../l.log", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + log.Println(err) + } + + //fmt.Println(userBytes) + + var l []int + + // for i := 0; len(l) < 1000; i++ { + // v := randomdata.Number(0, 65535) + // l = append(l, v) + // } + + //m := make(map[int]int) + for i := 0; len(l) < CompartorSize; i++ { + v := randomdata.Number(0, NumberMax) + // if _, ok := m[v]; !ok { + // m[v] = v + l = append(l, v) + // } + } + + var result bytes.Buffer + encoder := gob.NewEncoder(&result) + encoder.Encode(l) + lbytes := result.Bytes() + f.Write(lbytes) + +} + +func loadTestData() []int { + data, err := ioutil.ReadFile("../l.log") + if err != nil { + log.Println(err) + } + var l []int + decoder := gob.NewDecoder(bytes.NewReader(data)) + decoder.Decode(&l) + return l +} diff --git a/priority_list/priority_list_test.go b/priority_list/priority_list_test.go index c1b89b8..afa1ea5 100644 --- a/priority_list/priority_list_test.go +++ b/priority_list/priority_list_test.go @@ -1,12 +1,64 @@ package plist import ( + "bytes" + "encoding/gob" + "io/ioutil" + "log" + "os" "testing" "github.com/Pallinder/go-randomdata" "github.com/emirpasic/gods/utils" ) +const CompartorSize = 100 +const NumberMax = 50000000 + +func Save(t *testing.T) { + + f, err := os.OpenFile("../l.log", os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0666) + if err != nil { + log.Println(err) + } + + //fmt.Println(userBytes) + + var l []int + + // for i := 0; len(l) < 1000; i++ { + // v := randomdata.Number(0, 65535) + // l = append(l, v) + // } + + //m := make(map[int]int) + for i := 0; len(l) < CompartorSize; i++ { + v := randomdata.Number(0, NumberMax) + // if _, ok := m[v]; !ok { + // m[v] = v + l = append(l, v) + // } + } + + var result bytes.Buffer + encoder := gob.NewEncoder(&result) + encoder.Encode(l) + lbytes := result.Bytes() + f.Write(lbytes) + +} + +func loadTestData() []int { + data, err := ioutil.ReadFile("../l.log") + if err != nil { + log.Println(err) + } + var l []int + decoder := gob.NewDecoder(bytes.NewReader(data)) + decoder.Decode(&l) + return l +} + func TestInsert(t *testing.T) { pl := New(utils.IntComparator) for i := 0; i < 10; i++ { @@ -148,10 +200,10 @@ func TestRemove(t *testing.T) { } -func BenchmarkInsert(b *testing.B) { +func BenchmarkGet(b *testing.B) { pl := New(utils.IntComparator) - b.N = 3000 + b.N = 100 for i := 0; i < b.N; i++ { v := randomdata.Number(0, 65535) @@ -160,6 +212,7 @@ func BenchmarkInsert(b *testing.B) { b.ResetTimer() b.StartTimer() + for i := 0; i < b.N; i++ { if i%2 == 0 { pl.Get(i) @@ -167,3 +220,19 @@ func BenchmarkInsert(b *testing.B) { } } +func BenchmarkInsert(b *testing.B) { + + l := loadTestData() + + b.ResetTimer() + b.StartTimer() + + execCount := 500 + b.N = len(l) * execCount + for i := 0; i < execCount; i++ { + pl := New(utils.IntComparator) + for _, v := range l { + pl.Push(v) + } + } +}