package avl

import (
	"github.com/davecgh/go-spew/spew"

	"github.com/emirpasic/gods/utils"
)

// Node is a single element of an AVL tree.
//
// Fields:
//   - children: children[0] is the left subtree, holding values the comparator
//     ranks below this node; children[1] is the right subtree, holding values
//     ranked equal or above (Put sends ties to the right).
//   - parent: the parent node; nil for the root.
//   - height: height of the subtree rooted here. A leaf has height 0 and
//     getHeight reports -1 for a nil subtree.
//   - child: index of this node within parent.children. For the root this
//     field is not kept consistent (replace sets -1, while rotations may leave
//     a stale index), so it is only meaningful when parent != nil.
//   - value: the stored element.
type Node struct {
	children [2]*Node
	parent   *Node
	height   int
	child    int
	value    interface{}
}

// func (n *Node) String() string {
// 	if n == nil {
// 		return "nil"
// 	}
// 	return spew.Sprint(n.value)
// }

// String formats n for debugging as value(parent-child|height), where parent
// is the parent's value or "nil" for the root. A nil *Node formats as "nil".
func (n *Node) String() string {
	if n == nil {
		return "nil"
	}

	var parentLabel string
	switch {
	case n.parent == nil:
		parentLabel = "nil"
	default:
		parentLabel = spew.Sprint(n.parent.value)
	}

	label := spew.Sprint(n.value)
	return label + "(" + parentLabel + "-" + spew.Sprint(n.child) + "|" + spew.Sprint(n.height) + ")"
}

// AVL is a self-balancing binary search tree ordered by comparator.
// Duplicate values are allowed; Put stores a value that compares equal to an
// existing one in that node's right subtree.
//
// Fields:
//   - root: the root node, nil when the tree is empty.
//   - size: number of values stored, counting duplicates.
//   - comparator: orders values. The lookup methods panic unless it returns
//     exactly -1, 0 or 1.
type AVL struct {
	root       *Node
	size       int
	comparator utils.Comparator
}

// New returns an empty tree that orders its values with comparator.
func New(comparator utils.Comparator) *AVL {
	tree := new(AVL)
	tree.comparator = comparator
	return tree
}

// String renders the tree as an indented diagram under an "AVLTree" header,
// with each node's right subtree drawn above it and its left subtree below.
// An empty tree renders as "".
func (avl *AVL) String() string {
	var str string
	if avl.size != 0 {
		str = "AVLTree" + "\n"
		output(avl.root, "", true, &str)
	}
	return str
}

// Iterator returns a new iterator over the tree; construction is delegated to
// initIterator.
func (avl *AVL) Iterator() *Iterator {
	it := initIterator(avl)
	return it
}

// Remove deletes one node whose value compares equal to v (the one GetNode
// finds) and returns it, or returns nil if no such value is stored.
//
// A leaf is simply unlinked from its parent. Otherwise the node is replaced by
// its in-order predecessor (the rightmost node of the left subtree) when the
// left subtree is strictly taller, and by its in-order successor (the leftmost
// node of the right subtree) otherwise. Heights are then repaired, with
// rebalancing, from the lowest modified node upwards.
//
// The returned node's own parent/children fields are not cleared and may
// still point into the tree.
func (avl *AVL) Remove(v interface{}) *Node {

	if n, ok := avl.GetNode(v); ok {

		avl.size--
		// Removing the last node empties the tree.
		if avl.size == 0 {
			avl.root = nil
			return n
		}

		left := getHeight(n.children[0])
		right := getHeight(n.children[1])

		// Leaf: unlink it. Assumes n is not the root here (a root leaf would
		// have been the last node, handled above), so n.parent is non-nil.
		if left == -1 && right == -1 {
			p := n.parent
			p.children[n.child] = nil
			avl.fixRemoveHeight(p)
			return n
		}

		// cur becomes the replacement for n, taken from the taller side.
		var cur *Node
		if left > right {
			// In-order predecessor: rightmost node of the left subtree.
			cur = n.children[0]
			for cur.children[1] != nil {
				cur = cur.children[1]
			}

			// Splice cur out, moving its (possibly nil) left child into its slot.
			cleft := cur.children[0]
			cur.parent.children[cur.child] = cleft
			if cleft != nil {
				cleft.child = cur.child
				cleft.parent = cur.parent
			}

		} else {
			// In-order successor: leftmost node of the right subtree.
			cur = n.children[1]
			for cur.children[0] != nil {
				cur = cur.children[0]
			}

			// Splice cur out, moving its (possibly nil) right child into its slot.
			cright := cur.children[1]
			cur.parent.children[cur.child] = cright
			if cright != nil {
				cright.child = cur.child
				cright.parent = cur.parent
			}
		}

		// Remember where cur was detached from before replace rewires it.
		cparent := cur.parent
		avl.replace(n, cur)
		// If the replacement was a direct child of the removed node, its former
		// parent (n) is no longer in the tree, so repair heights starting from
		// the replacement node itself.
		if cparent == n {
			avl.fixRemoveHeight(cur)
		} else {
			avl.fixRemoveHeight(cparent)
		}

		return n
	}

	return nil
}

// Get looks up a value that compares equal to v. It returns the stored value
// and true when one is found, or nil and false otherwise.
func (avl *AVL) Get(v interface{}) (interface{}, bool) {
	if n, ok := avl.GetNode(v); ok {
		return n.value, true
	}
	// Return a literal nil: returning the nil *Node from GetNode here would
	// produce a non-nil interface{} holding a typed nil pointer.
	return nil, false
}

// GetAround returns the values of the nodes reported by GetAroundNode(v):
// result[i] is the value of GetAroundNode(v)[i], or nil when that entry is
// absent.
func (avl *AVL) GetAround(v interface{}) (result [3]interface{}) {
	for i, n := range avl.GetAroundNode(v) {
		// GetAroundNode leaves missing neighbours as nil pointers; skip them
		// instead of dereferencing.
		if n != nil {
			result[i] = n.value
		}
	}
	return result
}

// GetAroundNode searches for v and reports it together with its in-order
// neighbours:
//
//   - result[0]: the in-order successor, i.e. the closest node on the
//     children[1] side; when v is absent, the closest node that v compares
//     less than.
//   - result[1]: the node that compares equal to v, or nil when v is absent.
//   - result[2]: the in-order predecessor, i.e. the closest node on the
//     children[0] side; when v is absent, the closest node that v compares
//     greater than.
//
// Entries are nil when no such node exists. With duplicates, the first equal
// node on the search path is used, and its neighbours may themselves equal v.
// It panics if the comparator returns anything other than -1, 0 or 1.
func (avl *AVL) GetAroundNode(v interface{}) (result [3]*Node) {
	for n := avl.root; n != nil; {
		switch c := avl.comparator(v, n.value); c {
		case -1:
			// n ranks above v; the deepest such node on the path is the nearest.
			result[0] = n
			n = n.children[0]
		case 1:
			// n ranks below v; the deepest such node on the path is the nearest.
			result[2] = n
			n = n.children[1]
		case 0:
			result[1] = n
			// The successor is the leftmost node of the right subtree when there
			// is one; otherwise it is the last ancestor where the search went
			// left, which is already in result[0].
			if succ := n.children[1]; succ != nil {
				for succ.children[0] != nil {
					succ = succ.children[0]
				}
				result[0] = succ
			}
			// Symmetrically for the predecessor.
			if pred := n.children[0]; pred != nil {
				for pred.children[1] != nil {
					pred = pred.children[1]
				}
				result[2] = pred
			}
			return result
		default:
			panic("Get comparator only is allowed in -1, 0, 1")
		}
	}
	return result
}

// GetNode returns the first node found whose value compares equal to v,
// together with true, or (nil, false) when there is none. It panics if the
// comparator returns anything other than -1, 0 or 1.
func (avl *AVL) GetNode(v interface{}) (*Node, bool) {
	for cur := avl.root; cur != nil; {
		c := avl.comparator(v, cur.value)
		if c == 0 {
			return cur, true
		}
		if c != -1 && c != 1 {
			panic("Get comparator only is allowed in -1, 0, 1")
		}
		side := 0
		if c == 1 {
			side = 1
		}
		cur = cur.children[side]
	}
	return nil, false
}

// Put inserts v into the tree. A value that compares equal to an existing one
// is kept as a duplicate and descends into that node's right subtree.
//
// Rebalancing starts only when the new node's parent was a leaf (height 0).
// If the parent already had a child, attaching a second one cannot change
// its height.
func (avl *AVL) Put(v interface{}) {
	avl.size++
	node := &Node{value: v}
	if avl.size == 1 {
		avl.root = node
		return
	}

	parent := avl.root
	for {
		side := 0
		if avl.comparator(node.value, parent.value) > -1 {
			side = 1 // ties and greater values go right
		}

		if next := parent.children[side]; next != nil {
			parent = next
			continue
		}

		parent.children[side] = node
		node.parent = parent
		node.child = side
		if parent.height == 0 {
			avl.fixPutHeight(parent)
		}
		return
	}
}

// replace puts newN into old's position: newN adopts old's children, height
// and parent link. If old was the root, newN becomes the root with child
// index -1. old's own fields are not modified. Callers such as Remove detach
// newN from its previous position before calling this.
func (avl *AVL) replace(old *Node, newN *Node) {
	parent := old.parent

	setChild(newN, 0, old.children[0])
	setChild(newN, 1, old.children[1])
	newN.height = old.height
	newN.parent = parent

	if parent == nil {
		newN.child = -1
		avl.root = newN
		return
	}

	newN.child = old.child
	parent.children[old.child] = newN
}

// setChild stores node in p.children[child] and, when node is non-nil,
// points node back at p with the matching child index. A nil node just
// clears the slot.
func setChild(p *Node, child int, node *Node) {
	p.children[child] = node
	if node == nil {
		return
	}
	node.child, node.parent = child, p
}

// setChildNotNil is setChild without the nil check: node must be non-nil.
// With a nil node the slot is still overwritten before the dereference panics.
func setChildNotNil(p *Node, child int, node *Node) {
	p.children[child] = node
	node.child, node.parent = child, p
}

// debugString renders the tree diagram under an "AVL" header, where each
// node line also carries (parent-child|height) details. An empty tree renders
// as "".
func (avl *AVL) debugString() string {
	var str string
	if avl.size != 0 {
		str = "AVL" + "\n"
		outputfordebug(avl.root, "", true, &str)
	}
	return str
}

// TraversalBreadth returns every stored value in pre-order: a node first,
// then its left subtree, then its right subtree. An empty tree yields a nil
// slice.
//
// NOTE(review): despite the name this is a depth-first pre-order walk, not a
// level-order (breadth-first) traversal. Rename it or re-implement it if level
// order was intended.
func (avl *AVL) TraversalBreadth() (result []interface{}) {
	stack := []*Node{avl.root}
	for len(stack) > 0 {
		top := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if top == nil {
			continue
		}
		result = append(result, top.value)
		// Push the right child first so the left subtree is emitted first.
		stack = append(stack, top.children[1], top.children[0])
	}
	return result
}

// TraversalDepth returns every stored value via an in-order walk. The order is
// ascending (left subtree first) when leftright < 0 and descending otherwise.
// An empty tree yields a nil slice.
func (avl *AVL) TraversalDepth(leftright int) (result []interface{}) {
	first, second := 1, 0
	if leftright < 0 {
		first, second = 0, 1
	}

	var walk func(n *Node)
	walk = func(n *Node) {
		if n == nil {
			return
		}
		walk(n.children[first])
		result = append(result, n.value)
		walk(n.children[second])
	}
	walk(avl.root)
	return result
}

// lrrotate performs the double rotation used when cur is right-heavy and its
// right child r is left-heavy (fixRemoveHeight/fixPutHeight select it in that
// case; elsewhere this is often called a right-left rotation).
//
// r's left child rl takes cur's place, with cur as rl's left child and r as
// rl's right child. rl's former left subtree becomes cur's right subtree and
// its former right subtree becomes r's left subtree. The heights of cur, r and
// rl are recomputed bottom-up, and rl is returned so the caller can continue
// height repair from the new subtree root.
//
// NOTE(review): when cur is the root, rl.child is left at its old value.
func (avl *AVL) lrrotate(cur *Node) *Node {

	r := cur.children[1]
	rl := r.children[0]
	// Hang rl where cur used to be.
	if cur.parent == nil {
		avl.root = rl
		rl.parent = nil
	} else {
		setChildNotNil(cur.parent, cur.child, rl)
	}

	rll := rl.children[0]
	rlr := rl.children[1]

	// Redistribute rl's former subtrees.
	setChild(cur, 1, rll)
	setChild(r, 0, rlr)

	setChildNotNil(rl, 0, cur)
	setChildNotNil(rl, 1, r)

	// cur and r are now below rl, so update them before rl.
	cur.height = getMaxChildrenHeight(cur) + 1
	r.height = getMaxChildrenHeight(r) + 1
	rl.height = getMaxChildrenHeight(rl) + 1

	return rl
}

// rlrotate performs the double rotation used when cur is left-heavy and its
// left child l is right-heavy (fixRemoveHeight/fixPutHeight select it in that
// case; elsewhere this is often called a left-right rotation).
//
// l's right child lr takes cur's place, with l as lr's left child and cur as
// lr's right child. lr's former left subtree becomes l's right subtree and its
// former right subtree becomes cur's left subtree. The heights of cur, l and
// lr are recomputed bottom-up, and lr is returned so the caller can continue
// height repair from the new subtree root.
//
// NOTE(review): when cur is the root, lr.child is left at its old value.
func (avl *AVL) rlrotate(cur *Node) *Node {

	l := cur.children[0]
	lr := l.children[1]
	// Hang lr where cur used to be.
	if cur.parent == nil {
		avl.root = lr
		lr.parent = nil
	} else {
		setChildNotNil(cur.parent, cur.child, lr)
	}

	lrr := lr.children[1]
	lrl := lr.children[0]

	// Redistribute lr's former subtrees, then hang cur and l under lr.
	setChild(cur, 0, lrr)
	setChild(l, 1, lrl)
	setChildNotNil(lr, 1, cur)
	setChildNotNil(lr, 0, l)

	// cur and l are now below lr, so update them before lr.
	cur.height = getMaxChildrenHeight(cur) + 1
	l.height = getMaxChildrenHeight(l) + 1
	lr.height = getMaxChildrenHeight(lr) + 1

	return lr
}

// rrotate performs a single right rotation around cur. cur's left child
// (the pivot) takes cur's place, cur becomes the pivot's right child, and the
// pivot's former right subtree becomes cur's left subtree. The heights of cur
// and the pivot are recomputed, and the pivot is returned so the caller can
// continue height repair from the node that replaced cur.
func (avl *AVL) rrotate(cur *Node) *Node {
	pivot := cur.children[0]
	parent, side := cur.parent, cur.child

	// The pivot's right subtree moves across to become cur's left subtree.
	setChild(cur, 0, pivot.children[1])

	// The pivot takes cur's former position (or becomes the root).
	pivot.parent = parent
	if parent != nil {
		parent.children[side] = pivot
	} else {
		avl.root = pivot
	}
	pivot.child = side

	// cur hangs off the pivot's right.
	setChildNotNil(pivot, 1, cur)

	// cur is now below the pivot, so recompute its height first.
	cur.height = getMaxChildrenHeight(cur) + 1
	pivot.height = getMaxChildrenHeight(pivot) + 1
	return pivot
}

// lrotate performs a single left rotation around cur. cur's right child
// (the pivot) takes cur's place, cur becomes the pivot's left child, and the
// pivot's former left subtree becomes cur's right subtree. The heights of cur
// and the pivot are recomputed (lowest first), and the pivot is returned.
func (avl *AVL) lrotate(cur *Node) *Node {
	pivot := cur.children[1]
	parent, side := cur.parent, cur.child

	// The pivot's left subtree moves across to become cur's right subtree.
	setChild(cur, 1, pivot.children[0])

	// The pivot takes cur's former position (or becomes the root).
	pivot.parent = parent
	if parent != nil {
		parent.children[side] = pivot
	} else {
		avl.root = pivot
	}
	pivot.child = side

	// cur becomes the pivot's left child.
	setChildNotNil(pivot, 0, cur)

	// Recompute heights of the changed nodes, lowest first.
	cur.height = getMaxChildrenHeight(cur) + 1
	pivot.height = getMaxChildrenHeight(pivot) + 1
	return pivot
}

// getMaxAndChildrenHeight returns the heights of cur's left (h1) and right
// (h2) subtrees, with -1 for a missing subtree, along with the larger of the
// two (maxh).
func getMaxAndChildrenHeight(cur *Node) (h1, h2, maxh int) {
	h1, h2 = getHeight(cur.children[0]), getHeight(cur.children[1])
	maxh = h2
	if h1 > h2 {
		maxh = h1
	}
	return h1, h2, maxh
}

// getMaxChildrenHeight returns the height of cur's taller subtree, or -1 when
// cur has no children.
func getMaxChildrenHeight(cur *Node) int {
	left, right := getHeight(cur.children[0]), getHeight(cur.children[1])
	if right >= left {
		return right
	}
	return left
}

// getHeight returns cur's stored height, treating a nil subtree as height -1
// so that a leaf (two nil children) ends up with height 0.
func getHeight(cur *Node) int {
	h := -1
	if cur != nil {
		h = cur.height
	}
	return h
}

// fixRemoveHeight walks from cur towards the root after a removal. At each
// node it recomputes the height and rotates the node when its subtrees differ
// in height by more than one.
//
// It stops early at the first node that is balanced and whose height did not
// change. After a rotation it always continues with the parent of the rotated
// subtree, and it stops once it reaches the root.
func (avl *AVL) fixRemoveHeight(cur *Node) {

	for {

		lefth, rigthh, lrmax := getMaxAndChildrenHeight(cur)

		// Check whether this node's height changes. If it does not (and no
		// rotation is needed), there is no need to repair further up.
		isBreak := false
		if cur.height == lrmax+1 {
			isBreak = true
		} else {
			cur.height = lrmax + 1
		}

		// Height difference between the subtrees; rotate when its absolute
		// value exceeds 1.
		diff := lefth - rigthh
		if diff < -1 {
			r := cur.children[1] // choose the rotation from the heights of the right child's subtrees
			// Right child left-heavy: double rotation; otherwise (ties included)
			// a single left rotation.
			if getHeight(r.children[0]) > getHeight(r.children[1]) {
				cur = avl.lrrotate(cur)
			} else {
				cur = avl.lrotate(cur)
			}
		} else if diff > 1 {
			l := cur.children[0]
			// Left child right-heavy: double rotation; otherwise (ties included)
			// a single right rotation.
			if getHeight(l.children[1]) > getHeight(l.children[0]) {
				cur = avl.rlrotate(cur)
			} else {
				cur = avl.rrotate(cur)
			}
		} else {

			if isBreak {
				return
			}

		}

		if cur.parent == nil {
			return
		}

		cur = cur.parent
	}

}

// fixPutHeight walks from cur (the parent of a newly inserted node) towards
// the root. At each node it recomputes the height and rotates the node when
// its subtrees differ in height by more than one.
//
// It stops at the root, or as soon as the current subtree is shorter than its
// parent's recorded height. In that case the parent's height, and so
// everything above it, is unaffected.
func (avl *AVL) fixPutHeight(cur *Node) {

	for {

		lefth := getHeight(cur.children[0])
		rigthh := getHeight(cur.children[1])

		// Height difference between the subtrees; rotate when its absolute
		// value exceeds 1.
		diff := lefth - rigthh
		if diff < -1 {
			r := cur.children[1] // choose the rotation from the heights of the right child's subtrees
			if getHeight(r.children[0]) > getHeight(r.children[1]) {
				cur = avl.lrrotate(cur)
			} else {
				cur = avl.lrotate(cur)
			}

		} else if diff > 1 {
			l := cur.children[0]
			if getHeight(l.children[1]) > getHeight(l.children[0]) {
				cur = avl.rlrotate(cur)
			} else {
				cur = avl.rrotate(cur)
			}
		} else {
			// No rotation needed: height is the taller child's height + 1.
			if lefth > rigthh {
				cur.height = lefth + 1
			} else {
				cur.height = rigthh + 1
			}
		}

		if cur.parent == nil || cur.height < cur.parent.height {
			return
		}
		cur = cur.parent
	}
}

// output appends a sideways drawing of the subtree rooted at node to *str.
// The right subtree is drawn above the node and the left subtree below it.
// prefix carries the connector columns inherited from ancestors, and isTail
// reports whether node is drawn below its parent (a left child, or the root).
func output(node *Node, prefix string, isTail bool, str *string) {
	above, below, branch := "    ", "│   ", "┌── "
	if isTail {
		above, below, branch = "│   ", "    ", "└── "
	}

	if right := node.children[1]; right != nil {
		output(right, prefix+above, false, str)
	}

	*str += prefix + branch + spew.Sprint(node.value) + "\n"

	if left := node.children[0]; left != nil {
		output(left, prefix+below, true, str)
	}
}

// outputfordebug draws the subtree rooted at node into *str with the same
// layout as output. Each node line also carries "(parent-child|height)", where
// parent is the parent's value or "nil" for the root.
func outputfordebug(node *Node, prefix string, isTail bool, str *string) {
	above, below, branch := "    ", "│   ", "┌── "
	if isTail {
		above, below, branch = "│   ", "    ", "└── "
	}

	if right := node.children[1]; right != nil {
		outputfordebug(right, prefix+above, false, str)
	}

	parentv := "nil"
	if node.parent != nil {
		parentv = spew.Sprint(node.parent.value)
	}
	suffix := "(" + parentv + "-" + spew.Sprint(node.child) + "|" + spew.Sprint(node.height) + ")"
	*str += prefix + branch + spew.Sprint(node.value) + suffix + "\n"

	if left := node.children[0]; left != nil {
		outputfordebug(left, prefix+below, true, str)
	}
}