package compose.widgets

import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.ColumnScope
import androidx.compose.material.MaterialTheme
import androidx.compose.material.contentColorFor
import androidx.compose.runtime.*
import androidx.compose.runtime.snapshots.SnapshotStateMap
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.input.key.Key
import androidx.compose.ui.input.key.onKeyEvent
import androidx.compose.ui.input.key.onPreviewKeyEvent
import compose.input.onKeyShortcut
import compose.utils.*


/**
 * A widget showing a list of items, some of which can be expanded to reveal further items nested "inside" them.
 *
 * The tree and all of its related structures are parameterized by [I], the type of data attached to inner nodes, and
 * [L], the type of data attached to leaf nodes.
 *
 * Note that [childrenOf] must return the same items (compared with [Any.equals]) when invoked for the same input, over
 * time. This currently makes the widget unsuitable for displaying dynamic trees.
 */
@Composable
fun <I, L> TreeList(


    /**
     * The state object holding the expansion and selection state of the tree.
     */
    state: TreeListState<I, L>,


    /**
     * The nodes displayed at the top level (level 0) of the tree.
     */
    topLevelNodes: List<TreeListNode<I, L>>,


    /**
     * Provides the children of the given inner node.
     *
     * For the same input, this function must keep producing the same (compared with [Any.equals]) items over time.
     */
    childrenOf: (I) -> List<TreeListNode<I, L>>,


    /**
     * The composable displaying an inner tree node.
     *
     * It receives the corresponding [TreeListNode.Inner], the level of the node in the tree, and whether the node is
     * currently expanded.
     */
    innerContent: @Composable (node: TreeListNode.Inner<I>, level: Int, isExpanded: Boolean) -> Unit,


    /**
     * The composable displaying a leaf tree node.
     *
     * It receives the corresponding [TreeListNode.Leaf] and the level of the node in the tree.
     */
    leafContent: @Composable (node: TreeListNode.Leaf<L>, level: Int) -> Unit,


    /**
     * The background color used for selected rows.
     */
    selectedBackgroundColor: Color = MaterialTheme.colors.secondary,


    /**
     * The content color used for selected rows.
     */
    selectedContentColor: Color = contentColorFor(selectedBackgroundColor),


    /**
     * A modifier to apply to the widget.
     */
    modifier: Modifier = Modifier


) {
    // The pair of traverse keys is generated once and kept for as long as this call stays in the composition
    val traverseKeys = remember { generateTraverseKeys() }
    val rootKey = traverseKeys.first
    val childKey = traverseKeys.second

    // Hand the root key to the state; repeated whenever a different state object is passed in
    remember(state) {
        state.setRootNodeTraverseKey(rootKey)
    }

    // Tag the root element
    val rootModifier = modifier.then(TreeListRootModifierElement(rootKey))

    Column(modifier = rootModifier) {
        TreeListItems(
            level = 0,
            state = state,
            nodes = topLevelNodes,
            childrenOf = childrenOf,
            innerContent = innerContent,
            leafContent = leafContent,
            selectedBackgroundColor = selectedBackgroundColor,
            selectedContentColor = selectedContentColor,
            nodeTraverseKey = childKey,
        )
    }
}


/**
 * Emits a [TreeNode] for each of the given [nodes], all at the given [level]; each [TreeNode] in turn emits its own
 * children when expanded.
 */
@Composable
private fun <I, L> ColumnScope.TreeListItems(
    level: Int,
    state: TreeListState<I, L>,
    nodes: List<TreeListNode<I, L>>,
    childrenOf: (I) -> List<TreeListNode<I, L>>,
    innerContent: @Composable (TreeListNode.Inner<I>, Int, Boolean) -> Unit,
    leafContent: @Composable (TreeListNode.Leaf<L>, Int) -> Unit,
    selectedBackgroundColor: Color = MaterialTheme.colors.secondary,
    selectedContentColor: Color = contentColorFor(selectedBackgroundColor),
    nodeTraverseKey: String,
) {
    nodes.forEach { child ->
        // The node itself is the composition key of its subtree
        key(child) {
            TreeNode(
                node = child,
                level = level,
                state = state,
                childrenOf = childrenOf,
                innerContent = innerContent,
                leafContent = leafContent,
                selectedBackgroundColor = selectedBackgroundColor,
                selectedContentColor = selectedContentColor,
                nodeTraverseKey = nodeTraverseKey,
            )
        }
    }
}


/**
 * Emits the elements corresponding to [node] and, when it is an expanded inner node, all of its children, recursively.
 *
 * The node's own row is emitted inside a [SelectableBox] tagged with a [TreeListNodeModifierElement]; that element is
 * registered with [state] for as long as the node stays in the composition.
 *
 * [selectedBackgroundColor] and [selectedContentColor] are forwarded to the children, so that the colors are the same
 * at every level of the tree.
 */
@Composable
private fun <I, L> ColumnScope.TreeNode(
    node: TreeListNode<I, L>,
    level: Int,
    state: TreeListState<I, L>,
    childrenOf: (I) -> List<TreeListNode<I, L>>,
    innerContent: @Composable (TreeListNode.Inner<I>, Int, Boolean) -> Unit,
    leafContent: @Composable (TreeListNode.Leaf<L>, Int) -> Unit,
    selectedBackgroundColor: Color,
    selectedContentColor: Color,
    nodeTraverseKey: String,
) {
    // Derived states, so that this node only recomposes when its own selected/expanded flag actually flips
    val isSelected by remember(node, state) {
        derivedStateOf {
            state.selectedNode == node
        }
    }
    val isExpanded by remember(node, state) {
        derivedStateOf {
            (node is TreeListNode.Inner) && state.isExpanded(node)
        }
    }

    // The element tagging this node's row in the modifier tree
    val nodeModifierElement = remember(node) {
        TreeListNodeModifierElement(nodeTraverseKey, node)
    }
    SelectableBox(
        isSelected = isSelected,
        selectedBackgroundColor = selectedBackgroundColor,
        selectedContentColor = selectedContentColor,
        modifier = Modifier
            .then(nodeModifierElement)
    ) {
        when (node) {
            is TreeListNode.Leaf -> {
                leafContent(node, level)
            }
            is TreeListNode.Inner -> {
                innerContent(node, level, isExpanded)
            }
        }
    }

    // Keep the state's node -> element association in sync with this node's presence in the composition
    DisposableEffect(state, node, nodeModifierElement) {
        state.associateNode(node, nodeModifierElement)

        onDispose {
            state.disassociateNode(node, nodeModifierElement)
        }
    }

    if (isExpanded) {
        // isExpanded can only be true for inner nodes, but being a delegated property it doesn't smart-cast `node`
        val children = remember(childrenOf, node) {
            childrenOf((node as TreeListNode.Inner).value)
        }
        TreeListItems(
            level = level + 1,
            state = state,
            nodes = children,
            childrenOf = childrenOf,
            innerContent = innerContent,
            leafContent = leafContent,
            // Pass the colors on explicitly; otherwise nested levels would fall back to TreeListItems' defaults
            selectedBackgroundColor = selectedBackgroundColor,
            selectedContentColor = selectedContentColor,
            nodeTraverseKey = nodeTraverseKey,
        )
    }
}


/**
 * Represents the nodes in the tree, as given to us by the producer of the tree.
 *
 * A node is either an [Inner] node, carrying a value of type [I], or a [Leaf] node, carrying a value of type [L].
 * Both are data classes, so two nodes of the same kind are equal exactly when their values are equal.
 */
@Immutable
sealed interface TreeListNode<out I, out L> {


    /**
     * An inner node, i.e. one that can be expanded to show its children.
     */
    data class Inner<out I>(val value: I): TreeListNode<I, Nothing>


    /**
     * A leaf node, i.e. one that has no children.
     */
    data class Leaf<out L>(val value: L): TreeListNode<Nothing, L>


}


/**
 * The value carried by this tree list node, for trees where inner nodes and leaves carry the same type of data.
 */
val <T> TreeListNode<T, T>.value: T
    get() {
        return when (val node = this) {
            is TreeListNode.Leaf -> node.value
            is TreeListNode.Inner -> node.value
        }
    }


/**
 * The holder for the state of the tree: the set of expanded inner nodes and the currently selected node.
 */
@Stable
class TreeListState<I, L>(


    /**
     * Returns whether `this` inner node is an ancestor of the given node.
     */
    private val isAncestorOf: TreeListNode.Inner<I>.(TreeListNode<I, L>) -> Boolean


) {


    /**
     * The set of expanded nodes (a snapshot-state map used as a set; the values carry no information).
     */
    private val expandedNodes: MutableMap<TreeListNode.Inner<I>, Unit> = SnapshotStateMap()


    /**
     * Maps each [TreeListNode] to the [TagModifierElement] that tags it in the modifier tree.
     */
    private val tagModifierElementByNode = mutableMapOf<TreeListNode<I, L>, TreeListNodeModifierElement<I, L>>()


    /**
     * The traverse key for the tree list's root node.
     */
    private lateinit var rootNodeTraverseKey: String


    /**
     * The currently selected node; `null` if there is no selection.
     */
    var selectedNode: TreeListNode<I, L>? by mutableStateOf(null)


    /**
     * Notifies the state of the traverse key of the tree list's root node.
     */
    internal fun setRootNodeTraverseKey(traverseKey: String) {
        rootNodeTraverseKey = traverseKey
    }


    /**
     * Expands the given node.
     *
     * Returns whether actually expanded (`false` if it was already expanded).
     */
    fun expand(node: TreeListNode.Inner<I>): Boolean {
        return expandedNodes.put(node, Unit) == null
    }


    /**
     * Collapses the given node, and with it all of its expanded descendants.
     *
     * Returns whether actually collapsed (`false` if it was already collapsed).
     */
    fun collapse(node: TreeListNode.Inner<I>): Boolean {
        if (expandedNodes.remove(node) == null)
            return false

        // Collect the expanded descendants first and remove them afterwards, to avoid mutating the map while
        // iterating over its keys
        val descendants = expandedNodes.keys.filterTo(mutableSetOf()) { node.isAncestorOf(it) }
        expandedNodes.keys.removeAll(descendants)
        return true
    }


    /**
     * Returns whether the given node is currently expanded.
     */
    fun isExpanded(node: TreeListNode.Inner<I>): Boolean = node in expandedNodes


    /**
     * Toggles the expanded state of the given node.
     *
     * Returns the result of the underlying [collapse] or [expand] call.
     */
    fun toggleExpanded(node: TreeListNode.Inner<I>): Boolean =
        if (isExpanded(node)) collapse(node) else expand(node)


    /**
     * Selects the node returned by [findNode], given the currently selected node, if any.
     *
     * Does nothing if there is no selection, or if [findNode] returns `null`.
     */
    private inline fun selectNodeIfFound(
        findNode: (TagModifierNode<TreeListNode<I, L>>) -> TagModifierNode<TreeListNode<I, L>>?
    ) {
        selectedModifierNode()?.let(findNode)?.let {
            selectedNode = it.data
        }
    }


    /**
     * Selects the node immediately preceding the currently selected node, if any.
     */
    fun selectPrevious() {
        selectNodeIfFound { it.findPreviousSibling(rootNodeTraverseKey) }
    }


    /**
     * Selects the node immediately following the currently selected node, if any.
     */
    fun selectNext() {
        selectNodeIfFound { it.findNextSibling(rootNodeTraverseKey) }
    }


    /**
     * Selects the first node in the tree list.
     *
     * Note that this only works if there's a selected item already.
     */
    fun selectFirst() {
        selectNodeIfFound { it.findFirstSibling(rootNodeTraverseKey) }
    }


    /**
     * Selects the last node in the tree list.
     *
     * Note that this only works if there's a selected item already.
     */
    fun selectLast() {
        selectNodeIfFound { it.findLastSibling(rootNodeTraverseKey) }
    }


    /**
     * Selects the node that is [itemsInPage] items before the currently selected one.
     */
    fun selectPreviousPage(itemsInPage: Int) {
        selectNodeIfFound { it.findPrecedingSibling(rootNodeTraverseKey, count = itemsInPage) }
    }


    /**
     * Selects the node that is [itemsInPage] items after the currently selected one.
     */
    fun selectNextPage(itemsInPage: Int) {
        selectNodeIfFound { it.findFollowingSibling(rootNodeTraverseKey, count = itemsInPage) }
    }


    /**
     * Selects the direct parent of the currently selected item, or the first item if the currently selected one is
     * already a top-level node.
     */
    fun selectParentOrFirst() {
        selectNodeIfFound { selectedTagNode ->
            var firstNode: TagModifierNode<TreeListNode<I, L>>? = null
            var ancestor: TagModifierNode<TreeListNode<I, L>>? = null
            val selectedTreeNode = selectedTagNode.data

            // Walk the rows up to the selected one. Every ancestor found overwrites the previous one, so the
            // ancestor left at the end is the last one visited before the selected row.
            selectedTagNode.traverseSiblings(rootNodeTraverseKey) { sibling ->
                if (firstNode == null)
                    firstNode = sibling
                if (sibling == selectedTagNode) {
                    false  // Reached the selected row; stop traversing
                }
                else {
                    val treeListNode = sibling.data
                    if ((treeListNode is TreeListNode.Inner<I>) && treeListNode.isAncestorOf(selectedTreeNode)) {
                        ancestor = sibling
                    }
                    true
                }
            }

            // If nothing was visited at all, this is null and the selection is simply left unchanged
            ancestor ?: firstNode
        }
    }


    /**
     * Associates the given tree node with the [TagModifierElement] that tags it in the modifier tree.
     */
    internal fun associateNode(node: TreeListNode<I, L>, tagModifierElement: TreeListNodeModifierElement<I, L>) {
        tagModifierElementByNode[node] = tagModifierElement
    }


    /**
     * Disassociates the given tree node from the [TagModifierElement] that tags it in the modifier tree.
     *
     * The association is only removed if [node] is currently associated with [tagModifierElement]; in that case the
     * selection is also cleared if [node] was the selected node. If [node] is meanwhile associated with a different
     * element, both the association and the selection are left untouched.
     */
    internal fun disassociateNode(node: TreeListNode<I, L>, tagModifierElement: TreeListNodeModifierElement<I, L>) {
        val removed = tagModifierElementByNode.remove(node, tagModifierElement)
        if (removed && (selectedNode == node))
            selectedNode = null
    }


    /**
     * Returns the [TagModifierNode] corresponding to the currently selected tree list node, or `null` if there is no
     * selection.
     *
     * Throws [IllegalStateException] if there is a selected node, but no modifier node can be found for it.
     */
    private fun selectedModifierNode(): TagModifierNode<TreeListNode<I, L>>? {
        val selectedNode = this.selectedNode ?: return null
        return tagModifierElementByNode[selectedNode]?.modifierNode
            ?: error("Missing TreeListNodeModifierElement for node $selectedNode")
    }


}


/**
 * The prefix of the traverse key for the tree list's root element.
 *
 * Passed to `generateTraverseKey` by [generateTraverseKeys] to produce the root element's key.
 */
private const val ROOT_NODE_TRAVERSE_KEY_PREFIX = "compose.widgets.TREELIST_PARENT_NODE_KEY"


/**
 * The prefix of the traverse key for the tree list's child elements.
 *
 * Passed to `generateTraverseKey` by [generateTraverseKeys] to produce the key shared by the node elements.
 */
private const val CHILD_NODE_TRAVERSE_KEY_PREFIX = "compose.widgets.TREELIST_NODE_KEY"


/**
 * Generates a pair of traverse keys: first the key for the root element, second the key for the child node elements.
 */
private fun generateTraverseKeys() =
    generateTraverseKey(ROOT_NODE_TRAVERSE_KEY_PREFIX) to generateTraverseKey(CHILD_NODE_TRAVERSE_KEY_PREFIX)


/**
 * The [TagModifierElement] with which we tag nodes so that we can find e.g. the previous/next node.
 *
 * The data attached to the tag is the [TreeListNode] displayed by the tagged element.
 */
private typealias TreeListNodeModifierElement<I, L> = TagModifierElement<TreeListNode<I, L>>


/**
 * The [TagModifierElement] with which we tag the tree list's root element.
 *
 * The root carries no data of its own, so [Unit] is attached to the tag.
 */
private fun TreeListRootModifierElement(traverseKey: String) =  TagModifierElement(traverseKey, Unit)


/**
 * A modifier that moves the selection in the tree in response to the standard set of key shortcuts: up/down,
 * home/end, left/right and, when [itemsInPage] is provided, page-up/page-down.
 */
fun <I, L> Modifier.moveSelectionWithKeys(


    /**
     * The [TreeListState] to modify when keys are pressed.
     */
    state: TreeListState<I, L>,


    /**
     * Returns the number of items in a page. If `null`, page-up and page-down keys won't do anything.
     */
    itemsInPage: (() -> Int)? = null,


    /**
     * Whether to use [Modifier.onPreviewKeyEvent] when listening to key events.
     * Otherwise, [Modifier.onKeyEvent] will be used.
     */
    onPreview: Boolean = false


): Modifier = this
    .onKeyShortcut(Key.DirectionUp, onPreview = onPreview) {
        state.selectPrevious()
    }
    .onKeyShortcut(Key.DirectionDown, onPreview = onPreview) {
        state.selectNext()
    }
    .onKeyShortcut(Key.MoveHome, onPreview = onPreview) {
        state.selectFirst()
    }
    .onKeyShortcut(Key.MoveEnd, onPreview = onPreview) {
        state.selectLast()
    }
    .onKeyShortcut(Key.DirectionLeft, onPreview = onPreview) {
        state.selectedNode?.let { selected ->
            // Leave the event unconsumed unless we act on it, so that expandCollapseWithKeys gets to handle it
            consumeEvent = false

            // Left moves to the parent from leaves and from collapsed inner nodes
            val movesToParent = when (selected) {
                is TreeListNode.Leaf -> true
                is TreeListNode.Inner -> !state.isExpanded(selected)
            }
            if (movesToParent) {
                state.selectParentOrFirst()
                consumeEvent = true
            }
        }
    }
    .onKeyShortcut(Key.DirectionRight, onPreview = onPreview) {
        state.selectedNode?.let { selected ->
            // Leave the event unconsumed unless we act on it, so that expandCollapseWithKeys gets to handle it
            consumeEvent = false

            // Right moves to the next row from leaves and from expanded inner nodes
            val movesToNext = when (selected) {
                is TreeListNode.Leaf -> true
                is TreeListNode.Inner -> state.isExpanded(selected)
            }
            if (movesToNext) {
                state.selectNext()
                consumeEvent = true
            }
        }
    }
    .then(
        // Paging shortcuts are only installed when the page size is known
        itemsInPage?.let { pageSize ->
            Modifier
                .onKeyShortcut(Key.PageUp, onPreview = onPreview) {
                    state.selectPreviousPage(pageSize())
                }
                .onKeyShortcut(Key.PageDown, onPreview = onPreview) {
                    state.selectNextPage(pageSize())
                }
        } ?: Modifier
    )


/**
 * A modifier that collapses (left key) and expands (right key) the currently selected inner node.
 *
 * The key event is consumed only when the node's expanded state actually changed.
 */
fun <I, L> Modifier.expandCollapseWithKeys(


    /**
     * The [TreeListState] to modify when keys are pressed.
     */
    state: TreeListState<I, L>,


    /**
     * Whether to use [Modifier.onPreviewKeyEvent] when listening to key events.
     * Otherwise, [Modifier.onKeyEvent] will be used.
     */
    onPreview: Boolean = false


): Modifier = this
    .onKeyShortcut(Key.DirectionLeft, onPreview = onPreview) {
        // Leave the event unconsumed unless we act on it, so that moveSelectionWithKeys gets to handle it
        consumeEvent = false
        val selectedInner = state.selectedNode as? TreeListNode.Inner
        if (selectedInner != null) {
            consumeEvent = state.collapse(selectedInner)
        }
    }
    .onKeyShortcut(Key.DirectionRight, onPreview = onPreview) {
        // Leave the event unconsumed unless we act on it, so that moveSelectionWithKeys gets to handle it
        consumeEvent = false
        val selectedInner = state.selectedNode as? TreeListNode.Inner
        if (selectedInner != null) {
            consumeEvent = state.expand(selectedInner)
        }
    }
